@@ -290,7 +290,11 @@ def _sample(
         mirostat_tau: llama_cpp.c_float,
         mirostat_eta: llama_cpp.c_float,
         penalize_nl: bool = True,
+        logits_processors=None,
     ):
+        if logits_processors is None:
+            logits_processors = []
+
         assert self.ctx is not None
         assert len(self.eval_logits) > 0
         n_vocab = int(llama_cpp.llama_n_vocab(self.ctx))
@@ -302,6 +306,9 @@ def _sample(
             else last_n_tokens_size
         )
         logits = self.eval_logits[-1]
+        for processor in logits_processors:
+            logits = processor(last_n_tokens_data, logits)
+
         nl_logit = logits[int(Llama.token_nl())]
         data = (llama_cpp.llama_token_data * n_vocab)(
             *[
@@ -420,6 +427,7 @@ def sample(
         mirostat_eta: float = 0.1,
         mirostat_tau: float = 5.0,
         penalize_nl: bool = True,
+        logits_processors=None,
     ):
         """Sample a token from the model.

@@ -452,6 +460,7 @@ def sample(
             mirostat_tau=llama_cpp.c_float(mirostat_tau),
             mirostat_eta=llama_cpp.c_float(mirostat_eta),
             penalize_nl=penalize_nl,
+            logits_processors=logits_processors,
         )

     def generate(
def generate (
@@ -468,6 +477,7 @@ def generate(
468
477
mirostat_mode : int = 0 ,
469
478
mirostat_tau : float = 5.0 ,
470
479
mirostat_eta : float = 0.1 ,
480
+ logits_processors = None
471
481
) -> Generator [int , Optional [Sequence [int ]], None ]:
472
482
"""Create a generator of tokens from a prompt.
473
483
@@ -525,6 +535,7 @@ def generate(
                 mirostat_mode=mirostat_mode,
                 mirostat_tau=mirostat_tau,
                 mirostat_eta=mirostat_eta,
+                logits_processors=logits_processors,
             )
             tokens_or_none = yield token
             tokens = [token]
@@ -609,6 +620,8 @@ def _create_completion(
         mirostat_tau: float = 5.0,
         mirostat_eta: float = 0.1,
         model: Optional[str] = None,
+        logits_processors=None,
+        stopping_criterias=None,
    ) -> Union[Iterator[Completion], Iterator[CompletionChunk]]:
        assert self.ctx is not None
        completion_id: str = f"cmpl-{str(uuid.uuid4())}"
@@ -672,13 +685,22 @@ def _create_completion(
             frequency_penalty=frequency_penalty,
             presence_penalty=presence_penalty,
             repeat_penalty=repeat_penalty,
+            logits_processors=logits_processors,
         ):
             if token == Llama.token_eos():
                 text = self.detokenize(completion_tokens)
                 finish_reason = "stop"
                 break

             completion_tokens.append(token)
+            for stopping_crit in stopping_criterias or []:
+                if stopping_crit(completion_tokens, None):
+                    text = self.detokenize(completion_tokens)
+                    finish_reason = "stop"
+                    break
+
+            if finish_reason == "stop":
+                break

             all_text = self.detokenize(completion_tokens)

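
A stopping criterion is any callable taking `(tokens, logits)` and returning `True` to end generation; note the loop above passes `None` for the logits argument. A minimal sketch (the factory name and limit are illustrative):

```python
def max_new_tokens_criteria(limit: int):
    """Stop once `limit` completion tokens have been produced."""
    def criteria(tokens, logits):
        # tokens is the list of completion tokens generated so far
        return len(tokens) >= limit
    return criteria
```
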
@@ -978,6 +1000,8 @@ def create_completion(
         mirostat_tau: float = 5.0,
         mirostat_eta: float = 0.1,
         model: Optional[str] = None,
+        logits_processors=None,
+        stopping_criterias=None,
    ) -> Union[Completion, Iterator[CompletionChunk]]:
        """Generate text from a prompt.

@@ -1020,6 +1044,8 @@ def create_completion(
             mirostat_tau=mirostat_tau,
             mirostat_eta=mirostat_eta,
             model=model,
+            logits_processors=logits_processors,
+            stopping_criterias=stopping_criterias,
         )
         if stream:
             chunks: Iterator[CompletionChunk] = completion_or_chunks
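
Putting the two hooks together through the public API might look like this (the model path, token id, and the helper factories sketched above are placeholders):

```python
from llama_cpp import Llama

llm = Llama(model_path="./models/7B/ggml-model.bin")  # illustrative path

completion = llm.create_completion(
    "Q: Name the planets in the solar system. A: ",
    max_tokens=64,
    logits_processors=[ban_token_processor(123)],       # placeholder token id
    stopping_criterias=[max_new_tokens_criteria(32)],
)
print(completion["choices"][0]["text"])
```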