Commit 65d9cc0

Add openai frequency and presence penalty parameters. Closes abetlen#169

1 parent 75d8619

2 files changed (+36, -6)
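The commit threads OpenAI-compatible frequency_penalty and presence_penalty parameters (both defaulting to 0.0, i.e. disabled) from the public completion APIs down to a new llama.cpp sampling call. A minimal usage sketch of the new parameters through the high-level API; the model path and prompt are placeholders, not part of the commit:

    # Sketch: exercising the two new keyword arguments added by this commit.
    from llama_cpp import Llama

    llm = Llama(model_path="./models/7B/ggml-model.bin")  # placeholder path
    output = llm(
        "Q: Name three planets in the solar system. A:",
        max_tokens=64,
        frequency_penalty=0.5,  # new: penalize tokens by how often they recur
        presence_penalty=0.5,   # new: penalize tokens that appeared at all
    )
    print(output["choices"][0]["text"])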

llama_cpp/llama.py (+36, -2)
@@ -261,14 +261,16 @@ def eval(self, tokens: Sequence[llama_cpp.llama_token]):
             ]
             self.eval_logits.extend(logits)
 
-    def _sample_top_p_top_k(
+    def _sample(
         self,
         last_n_tokens_data,  # type: llama_cpp.Array[llama_cpp.llama_token]
         last_n_tokens_size: llama_cpp.c_int,
         top_k: llama_cpp.c_int,
         top_p: llama_cpp.c_float,
         temp: llama_cpp.c_float,
         repeat_penalty: llama_cpp.c_float,
+        frequency_penalty: llama_cpp.c_float,
+        presence_penalty: llama_cpp.c_float,
     ):
         assert self.ctx is not None
         assert len(self.eval_logits) > 0
@@ -298,6 +300,14 @@ def _sample_top_p_top_k(
             candidates=llama_cpp.ctypes.byref(candidates),  # type: ignore
             penalty=repeat_penalty,
         )
+        llama_cpp.llama_sample_frequency_and_presence_penalties(
+            ctx=self.ctx,
+            candidates=llama_cpp.ctypes.byref(candidates),  # type: ignore
+            last_tokens_data=last_n_tokens_data,
+            last_tokens_size=last_n_tokens_size,
+            alpha_frequency=frequency_penalty,
+            alpha_presence=presence_penalty,
+        )
         if float(temp.value) == 0.0:
             return llama_cpp.llama_sample_token_greedy(
                 ctx=self.ctx,
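The new call into llama.cpp adjusts candidate logits with the OpenAI-style rule: each token's logit is reduced by alpha_frequency times its count in the recent-token window, plus alpha_presence if it appeared at all. A pure-Python sketch of that rule, assuming llama.cpp follows OpenAI's documented formula (illustrative only, not the library code):

    # Assumed behaviour of llama_sample_frequency_and_presence_penalties:
    #   logit[t] -= count(t) * alpha_frequency + (count(t) > 0) * alpha_presence
    from collections import Counter

    def penalize(logits, last_tokens, alpha_frequency, alpha_presence):
        counts = Counter(last_tokens)
        out = list(logits)
        for token, count in counts.items():  # count >= 1 here, so the
            out[token] -= count * alpha_frequency + alpha_presence  # indicator is 1
        return out

    # Example: token 7 appeared twice in the window, token 3 once.
    logits = [0.0] * 10
    print(penalize(logits, [7, 3, 7], alpha_frequency=0.5, alpha_presence=0.25))
    # token 7: -(2 * 0.5 + 0.25) = -1.25; token 3: -(1 * 0.5 + 0.25) = -0.75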
@@ -344,6 +354,8 @@ def sample(
         top_p: float,
         temp: float,
         repeat_penalty: float,
+        frequency_penalty: float = 0.0,
+        presence_penalty: float = 0.0,
     ):
         """Sample a token from the model.
 
@@ -360,7 +372,7 @@ def sample(
         last_n_tokens_data = [llama_cpp.llama_token(0)] * max(
             0, self.last_n_tokens_size - len(self.eval_tokens)
         ) + list(self.eval_tokens)[-self.last_n_tokens_size :]
-        return self._sample_top_p_top_k(
+        return self._sample(
             last_n_tokens_data=(llama_cpp.llama_token * self.last_n_tokens_size)(
                 *last_n_tokens_data
             ),
@@ -369,6 +381,8 @@ def sample(
             top_p=llama_cpp.c_float(top_p),
             temp=llama_cpp.c_float(temp),
             repeat_penalty=llama_cpp.c_float(repeat_penalty),
+            frequency_penalty=llama_cpp.c_float(frequency_penalty),
+            presence_penalty=llama_cpp.c_float(presence_penalty),
         )
 
     def generate(
@@ -378,6 +392,8 @@ def generate(
         top_p: float,
         temp: float,
         repeat_penalty: float,
+        frequency_penalty: float = 0.0,
+        presence_penalty: float = 0.0,
         reset: bool = True,
     ) -> Generator[
         llama_cpp.llama_token, Optional[Sequence[llama_cpp.llama_token]], None
@@ -431,6 +447,8 @@ def generate(
                 top_k=top_k,
                 top_p=top_p,
                 temp=temp,
+                frequency_penalty=frequency_penalty,
+                presence_penalty=presence_penalty,
                 repeat_penalty=repeat_penalty,
             )
             tokens_or_none = yield token
@@ -505,6 +523,8 @@ def _create_completion(
         logprobs: Optional[int] = None,
         echo: bool = False,
         stop: Optional[List[str]] = [],
+        frequency_penalty: float = 0.0,
+        presence_penalty: float = 0.0,
         repeat_penalty: float = 1.1,
         top_k: int = 40,
         stream: bool = False,
@@ -563,6 +583,8 @@ def _create_completion(
             top_k=top_k,
             top_p=top_p,
             temp=temperature,
+            frequency_penalty=frequency_penalty,
+            presence_penalty=presence_penalty,
             repeat_penalty=repeat_penalty,
         ):
             if token == llama_cpp.llama_token_eos():
@@ -737,6 +759,8 @@ def create_completion(
         logprobs: Optional[int] = None,
         echo: bool = False,
         stop: Optional[List[str]] = [],
+        frequency_penalty: float = 0.0,
+        presence_penalty: float = 0.0,
         repeat_penalty: float = 1.1,
         top_k: int = 40,
         stream: bool = False,
@@ -772,6 +796,8 @@ def create_completion(
             logprobs=logprobs,
             echo=echo,
             stop=stop,
+            frequency_penalty=frequency_penalty,
+            presence_penalty=presence_penalty,
             repeat_penalty=repeat_penalty,
             top_k=top_k,
             stream=stream,
@@ -792,6 +818,8 @@ def __call__(
         logprobs: Optional[int] = None,
         echo: bool = False,
         stop: Optional[List[str]] = [],
+        frequency_penalty: float = 0.0,
+        presence_penalty: float = 0.0,
         repeat_penalty: float = 1.1,
         top_k: int = 40,
         stream: bool = False,
@@ -827,6 +855,8 @@ def __call__(
             logprobs=logprobs,
             echo=echo,
             stop=stop,
+            frequency_penalty=frequency_penalty,
+            presence_penalty=presence_penalty,
             repeat_penalty=repeat_penalty,
             top_k=top_k,
             stream=stream,
@@ -899,6 +929,8 @@ def create_chat_completion(
         stream: bool = False,
         stop: Optional[List[str]] = [],
         max_tokens: int = 256,
+        presence_penalty: float = 0.0,
+        frequency_penalty: float = 0.0,
         repeat_penalty: float = 1.1,
     ) -> Union[ChatCompletion, Iterator[ChatCompletionChunk]]:
         """Generate a chat completion from a list of messages.
@@ -932,6 +964,8 @@ def create_chat_completion(
             stream=stream,
             max_tokens=max_tokens,
             repeat_penalty=repeat_penalty,
+            presence_penalty=presence_penalty,
+            frequency_penalty=frequency_penalty,
         )
         if stream:
             chunks: Iterator[CompletionChunk] = completion_or_chunks  # type: ignore
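create_chat_completion forwards both penalties as well, so the chat API gains the same knobs. A sketch of the chat-level call; the message content and penalty values are illustrative, not from the commit:

    # Sketch: the new penalties passed through the chat API.
    response = llm.create_chat_completion(
        messages=[{"role": "user", "content": "Write a short poem about the sea."}],
        max_tokens=128,
        presence_penalty=0.6,   # discourage reusing any already-seen token
        frequency_penalty=0.3,  # scale the discouragement by repetition count
    )
    print(response["choices"][0]["message"]["content"])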

llama_cpp/server/app.py (+0, -4)
@@ -214,8 +214,6 @@ def create_completion(
         exclude={
             "model",
             "n",
-            "frequency_penalty",
-            "presence_penalty",
             "best_of",
             "logit_bias",
             "user",
@@ -315,8 +313,6 @@ def create_chat_completion(
         exclude={
             "model",
             "n",
-            "presence_penalty",
-            "frequency_penalty",
             "logit_bias",
             "user",
         }
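On the server side the change is purely subtractive: frequency_penalty and presence_penalty are no longer stripped from the request before it is forwarded, so values a client sends now reach the model call. Roughly, a sketch of the forwarding pattern implied by the exclude sets above (not the exact server code):

    # Sketch: with the two fields removed from `exclude`, the request's penalty
    # values survive the dict conversion and are passed through to the model.
    completion_kwargs = request.dict(
        exclude={"model", "n", "best_of", "logit_bias", "user"}
    )
    completion = llama(**completion_kwargs)  # penalties now included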
