@@ -31,9 +31,7 @@ class Settings(BaseSettings):
         ge=0,
         description="The number of layers to put on the GPU. The rest will be on the CPU.",
     )
-    seed: int = Field(
-        default=1337, description="Random seed. -1 for random."
-    )
+    seed: int = Field(default=1337, description="Random seed. -1 for random.")
     n_batch: int = Field(
         default=512, ge=1, description="The batch size to use per eval."
     )
@@ -80,12 +78,8 @@ class Settings(BaseSettings):
     verbose: bool = Field(
         default=True, description="Whether to print debug information."
     )
-    host: str = Field(
-        default="localhost", description="Listen address"
-    )
-    port: int = Field(
-        default=8000, description="Listen port"
-    )
+    host: str = Field(default="localhost", description="Listen address")
+    port: int = Field(default=8000, description="Listen port")
     interrupt_requests: bool = Field(
         default=True,
         description="Whether to interrupt requests when a new request is received.",
@@ -178,7 +172,7 @@ def get_settings():
     yield settings


-model_field = Field(description="The model to use for generating completions.")
+model_field = Field(description="The model to use for generating completions.", default=None)


 max_tokens_field = Field(
     default=16, ge=1, le=2048, description="The maximum number of tokens to generate."
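Note on this hunk: under Pydantic v2 an `Optional[...]` annotation no longer implies a `None` default, so `Field(description=...)` alone would make `model` a required request parameter; the added `default=None` keeps it optional. A minimal sketch of the behavior with a stand-in model (not from this file):

```python
from typing import Optional

from pydantic import BaseModel, Field


class Req(BaseModel):
    # Pydantic v2: without an explicit default this field is required,
    # even though it is annotated Optional[str].
    model: Optional[str] = Field(default=None, description="Model name.")


print(Req().model)  # None — validates because the default is explicit
```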
@@ -242,21 +236,18 @@ def get_settings():
     default=0,
     ge=0,
     le=2,
-    description="Enable Mirostat constant-perplexity algorithm of the specified version (1 or 2; 0 = disabled)"
+    description="Enable Mirostat constant-perplexity algorithm of the specified version (1 or 2; 0 = disabled)",
 )

 mirostat_tau_field = Field(
     default=5.0,
     ge=0.0,
     le=10.0,
-    description="Mirostat target entropy, i.e. the target perplexity - lower values produce focused and coherent text, larger values produce more diverse and less coherent text"
+    description="Mirostat target entropy, i.e. the target perplexity - lower values produce focused and coherent text, larger values produce more diverse and less coherent text",
 )

 mirostat_eta_field = Field(
-    default=0.1,
-    ge=0.001,
-    le=1.0,
-    description="Mirostat learning rate"
+    default=0.1, ge=0.001, le=1.0, description="Mirostat learning rate"
 )

@@ -294,22 +285,23 @@ class CreateCompletionRequest(BaseModel):
     model: Optional[str] = model_field
     n: Optional[int] = 1
     best_of: Optional[int] = 1
-    user: Optional[str] = Field(None)
+    user: Optional[str] = Field(default=None)

     # llama.cpp specific parameters
     top_k: int = top_k_field
     repeat_penalty: float = repeat_penalty_field
     logit_bias_type: Optional[Literal["input_ids", "tokens"]] = Field(None)

-    class Config:
-        schema_extra = {
-            "example": {
-                "prompt": "\n\n### Instructions:\nWhat is the capital of France?\n\n### Response:\n",
-                "stop": ["\n", "###"],
-            }
+    model_config = {
+        "json_schema_extra": {
+            "examples": [
+                {
+                    "prompt": "\n\n### Instructions:\nWhat is the capital of France?\n\n### Response:\n",
+                    "stop": ["\n", "###"],
+                }
+            ]
         }
-
-
+    }


 def make_logit_bias_processor(
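Reviewer note: the inner `class Config` with `schema_extra` is the Pydantic v1 idiom; v2 moves configuration into a `model_config` dict whose `json_schema_extra` entry is merged into the generated JSON schema, with `examples` now a list. A self-contained sketch of the v2 pattern (field name is illustrative):

```python
from pydantic import BaseModel


class Prompted(BaseModel):
    prompt: str

    # Pydantic v2 replaces `class Config: schema_extra = ...` with a
    # model_config dict; its extra keys are merged into the JSON schema.
    model_config = {
        "json_schema_extra": {
            "examples": [{"prompt": "What is the capital of France?"}]
        }
    }


# FastAPI reads this schema when rendering the OpenAPI docs.
print(Prompted.model_json_schema()["examples"])
```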
@@ -328,7 +320,7 @@ def make_logit_bias_processor(

     elif logit_bias_type == "tokens":
         for token, score in logit_bias.items():
-            token = token.encode('utf-8')
+            token = token.encode("utf-8")
             for input_id in llama.tokenize(token, add_bos=False):
                 to_bias[input_id] = score

@@ -352,7 +344,7 @@ async def create_completion(
     request: Request,
     body: CreateCompletionRequest,
     llama: llama_cpp.Llama = Depends(get_llama),
-):
+) -> llama_cpp.Completion:
     if isinstance(body.prompt, list):
         assert len(body.prompt) <= 1
         body.prompt = body.prompt[0] if len(body.prompt) > 0 else ""
@@ -364,7 +356,7 @@ async def create_completion(
         "logit_bias_type",
         "user",
     }
-    kwargs = body.dict(exclude=exclude)
+    kwargs = body.model_dump(exclude=exclude)

     if body.logit_bias is not None:
         kwargs['logits_processor'] = llama_cpp.LogitsProcessorList([
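Same migration as elsewhere in this diff: `.dict()` is deprecated in Pydantic v2 and `.model_dump()` is the replacement, with `exclude` behaving the same way. A quick sketch with a stand-in model:

```python
from typing import Optional

from pydantic import BaseModel


class Body(BaseModel):
    prompt: str = ""
    user: Optional[str] = None
    logit_bias_type: Optional[str] = None


body = Body(prompt="hi", user="u1")
# Pydantic v1: body.dict(exclude={"user", "logit_bias_type"})
kwargs = body.model_dump(exclude={"user", "logit_bias_type"})
print(kwargs)  # {'prompt': 'hi'}
```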
@@ -396,7 +388,7 @@ async def event_publisher(inner_send_chan: MemoryObjectSendStream):

         return EventSourceResponse(
             recv_chan, data_sender_callable=partial(event_publisher, send_chan)
-        )
+        )  # type: ignore
     else:
         completion: llama_cpp.Completion = await run_in_threadpool(llama, **kwargs)  # type: ignore
         return completion
@@ -405,16 +397,17 @@ async def event_publisher(inner_send_chan: MemoryObjectSendStream):
 class CreateEmbeddingRequest(BaseModel):
     model: Optional[str] = model_field
     input: Union[str, List[str]] = Field(description="The input to embed.")
-    user: Optional[str]
-
-    class Config:
-        schema_extra = {
-            "example": {
-                "input": "The food was delicious and the waiter...",
-            }
+    user: Optional[str] = Field(default=None)
+
+    model_config = {
+        "json_schema_extra": {
+            "examples": [
+                {
+                    "input": "The food was delicious and the waiter...",
+                }
+            ]
         }
-
-
+    }


 @router.post(
@@ -424,7 +417,7 @@ async def create_embedding(
     request: CreateEmbeddingRequest, llama: llama_cpp.Llama = Depends(get_llama)
 ):
     return await run_in_threadpool(
-        llama.create_embedding, **request.dict(exclude={"user"})
+        llama.create_embedding, **request.model_dump(exclude={"user"})
     )

@@ -461,21 +454,22 @@ class CreateChatCompletionRequest(BaseModel):
     repeat_penalty: float = repeat_penalty_field
     logit_bias_type: Optional[Literal["input_ids", "tokens"]] = Field(None)

-    class Config:
-        schema_extra = {
-            "example": {
-                "messages": [
-                    ChatCompletionRequestMessage(
-                        role="system", content="You are a helpful assistant."
-                    ),
-                    ChatCompletionRequestMessage(
-                        role="user", content="What is the capital of France?"
-                    ),
-                ]
-            }
+    model_config = {
+        "json_schema_extra": {
+            "examples": [
+                {
+                    "messages": [
+                        ChatCompletionRequestMessage(
+                            role="system", content="You are a helpful assistant."
+                        ).model_dump(),
+                        ChatCompletionRequestMessage(
+                            role="user", content="What is the capital of France?"
+                        ).model_dump(),
+                    ]
+                }
+            ]
         }
-
-
+    }


 @router.post(
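My reading of the added `.model_dump()` calls (not stated in the diff): `json_schema_extra` content must be JSON-serializable, and raw `ChatCompletionRequestMessage` instances are not, so each example message is dumped to a plain dict first. Sketch with a stand-in message type:

```python
from pydantic import BaseModel


class Message(BaseModel):
    role: str
    content: str


# Dumping to plain dicts keeps the schema example JSON-serializable.
example = [
    Message(role="system", content="You are a helpful assistant.").model_dump(),
    Message(role="user", content="What is the capital of France?").model_dump(),
]
print(example)
```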
@@ -486,14 +480,14 @@ async def create_chat_completion(
     body: CreateChatCompletionRequest,
     llama: llama_cpp.Llama = Depends(get_llama),
     settings: Settings = Depends(get_settings),
-) -> Union[llama_cpp.ChatCompletion]:  # type: ignore
+) -> llama_cpp.ChatCompletion:
     exclude = {
         "n",
         "logit_bias",
         "logit_bias_type",
         "user",
     }
-    kwargs = body.dict(exclude=exclude)
+    kwargs = body.model_dump(exclude=exclude)

     if body.logit_bias is not None:
         kwargs['logits_processor'] = llama_cpp.LogitsProcessorList([
@@ -526,7 +520,7 @@ async def event_publisher(inner_send_chan: MemoryObjectSendStream):
         return EventSourceResponse(
             recv_chan,
             data_sender_callable=partial(event_publisher, send_chan),
-        )
+        )  # type: ignore
     else:
         completion: llama_cpp.ChatCompletion = await run_in_threadpool(
             llama.create_chat_completion, **kwargs  # type: ignore
@@ -546,8 +540,6 @@ class ModelList(TypedDict):
     data: List[ModelData]


-
-
 @router.get("/v1/models")
 async def get_models(
     settings: Settings = Depends(get_settings),