@@ -261,14 +261,16 @@ def eval(self, tokens: Sequence[llama_cpp.llama_token]):
261
261
]
262
262
self .eval_logits .extend (logits )
263
263
264
- def _sample_top_p_top_k (
264
+ def _sample (
265
265
self ,
266
266
last_n_tokens_data , # type: llama_cpp.Array[llama_cpp.llama_token]
267
267
last_n_tokens_size : llama_cpp .c_int ,
268
268
top_k : llama_cpp .c_int ,
269
269
top_p : llama_cpp .c_float ,
270
270
temp : llama_cpp .c_float ,
271
271
repeat_penalty : llama_cpp .c_float ,
272
+ frequency_penalty : llama_cpp .c_float ,
273
+ presence_penalty : llama_cpp .c_float ,
272
274
):
273
275
assert self .ctx is not None
274
276
assert len (self .eval_logits ) > 0
@@ -298,6 +300,14 @@ def _sample_top_p_top_k(
298
300
candidates = llama_cpp .ctypes .byref (candidates ), # type: ignore
299
301
penalty = repeat_penalty ,
300
302
)
303
+ llama_cpp .llama_sample_frequency_and_presence_penalties (
304
+ ctx = self .ctx ,
305
+ candidates = llama_cpp .ctypes .byref (candidates ), # type: ignore
306
+ last_tokens_data = last_n_tokens_data ,
307
+ last_tokens_size = last_n_tokens_size ,
308
+ alpha_frequency = frequency_penalty ,
309
+ alpha_presence = presence_penalty ,
310
+ )
301
311
if float (temp .value ) == 0.0 :
302
312
return llama_cpp .llama_sample_token_greedy (
303
313
ctx = self .ctx ,
@@ -344,6 +354,8 @@ def sample(
344
354
top_p : float ,
345
355
temp : float ,
346
356
repeat_penalty : float ,
357
+ frequency_penalty : float = 0.0 ,
358
+ presence_penalty : float = 0.0 ,
347
359
):
348
360
"""Sample a token from the model.
349
361
@@ -360,7 +372,7 @@ def sample(
360
372
last_n_tokens_data = [llama_cpp .llama_token (0 )] * max (
361
373
0 , self .last_n_tokens_size - len (self .eval_tokens )
362
374
) + list (self .eval_tokens )[- self .last_n_tokens_size :]
363
- return self ._sample_top_p_top_k (
375
+ return self ._sample (
364
376
last_n_tokens_data = (llama_cpp .llama_token * self .last_n_tokens_size )(
365
377
* last_n_tokens_data
366
378
),
@@ -369,6 +381,8 @@ def sample(
369
381
top_p = llama_cpp .c_float (top_p ),
370
382
temp = llama_cpp .c_float (temp ),
371
383
repeat_penalty = llama_cpp .c_float (repeat_penalty ),
384
+ frequency_penalty = llama_cpp .c_float (frequency_penalty ),
385
+ presence_penalty = llama_cpp .c_float (presence_penalty ),
372
386
)
373
387
374
388
def generate (
@@ -378,6 +392,8 @@ def generate(
378
392
top_p : float ,
379
393
temp : float ,
380
394
repeat_penalty : float ,
395
+ frequency_penalty : float = 0.0 ,
396
+ presence_penalty : float = 0.0 ,
381
397
reset : bool = True ,
382
398
) -> Generator [
383
399
llama_cpp .llama_token , Optional [Sequence [llama_cpp .llama_token ]], None
@@ -431,6 +447,8 @@ def generate(
431
447
top_k = top_k ,
432
448
top_p = top_p ,
433
449
temp = temp ,
450
+ frequency_penalty = frequency_penalty ,
451
+ presence_penalty = presence_penalty ,
434
452
repeat_penalty = repeat_penalty ,
435
453
)
436
454
tokens_or_none = yield token
@@ -505,6 +523,8 @@ def _create_completion(
505
523
logprobs : Optional [int ] = None ,
506
524
echo : bool = False ,
507
525
stop : Optional [List [str ]] = [],
526
+ frequency_penalty : float = 0.0 ,
527
+ presence_penalty : float = 0.0 ,
508
528
repeat_penalty : float = 1.1 ,
509
529
top_k : int = 40 ,
510
530
stream : bool = False ,
@@ -563,6 +583,8 @@ def _create_completion(
563
583
top_k = top_k ,
564
584
top_p = top_p ,
565
585
temp = temperature ,
586
+ frequency_penalty = frequency_penalty ,
587
+ presence_penalty = presence_penalty ,
566
588
repeat_penalty = repeat_penalty ,
567
589
):
568
590
if token == llama_cpp .llama_token_eos ():
@@ -737,6 +759,8 @@ def create_completion(
737
759
logprobs : Optional [int ] = None ,
738
760
echo : bool = False ,
739
761
stop : Optional [List [str ]] = [],
762
+ frequency_penalty : float = 0.0 ,
763
+ presence_penalty : float = 0.0 ,
740
764
repeat_penalty : float = 1.1 ,
741
765
top_k : int = 40 ,
742
766
stream : bool = False ,
@@ -772,6 +796,8 @@ def create_completion(
772
796
logprobs = logprobs ,
773
797
echo = echo ,
774
798
stop = stop ,
799
+ frequency_penalty = frequency_penalty ,
800
+ presence_penalty = presence_penalty ,
775
801
repeat_penalty = repeat_penalty ,
776
802
top_k = top_k ,
777
803
stream = stream ,
@@ -792,6 +818,8 @@ def __call__(
792
818
logprobs : Optional [int ] = None ,
793
819
echo : bool = False ,
794
820
stop : Optional [List [str ]] = [],
821
+ frequency_penalty : float = 0.0 ,
822
+ presence_penalty : float = 0.0 ,
795
823
repeat_penalty : float = 1.1 ,
796
824
top_k : int = 40 ,
797
825
stream : bool = False ,
@@ -827,6 +855,8 @@ def __call__(
827
855
logprobs = logprobs ,
828
856
echo = echo ,
829
857
stop = stop ,
858
+ frequency_penalty = frequency_penalty ,
859
+ presence_penalty = presence_penalty ,
830
860
repeat_penalty = repeat_penalty ,
831
861
top_k = top_k ,
832
862
stream = stream ,
@@ -899,6 +929,8 @@ def create_chat_completion(
899
929
stream : bool = False ,
900
930
stop : Optional [List [str ]] = [],
901
931
max_tokens : int = 256 ,
932
+ presence_penalty : float = 0.0 ,
933
+ frequency_penalty : float = 0.0 ,
902
934
repeat_penalty : float = 1.1 ,
903
935
) -> Union [ChatCompletion , Iterator [ChatCompletionChunk ]]:
904
936
"""Generate a chat completion from a list of messages.
@@ -932,6 +964,8 @@ def create_chat_completion(
932
964
stream = stream ,
933
965
max_tokens = max_tokens ,
934
966
repeat_penalty = repeat_penalty ,
967
+ presence_penalty = presence_penalty ,
968
+ frequency_penalty = frequency_penalty ,
935
969
)
936
970
if stream :
937
971
chunks : Iterator [CompletionChunk ] = completion_or_chunks # type: ignore