Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit d28b753

Browse files
committed
Implement penalize_nl
1 parent f11e2a7 commit d28b753
Copy full SHA for d28b753

File tree

Expand / collapse file tree

1 file changed

+11
-0
lines changed
Filter options
Expand / collapse file tree

1 file changed

+11
-0
lines changed

llama_cpp/llama.py

Copy file name to clipboard. Expand all lines: llama_cpp/llama.py
+11 lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -291,6 +291,7 @@ def _sample(
291291
mirostat_mode: llama_cpp.c_int,
292292
mirostat_tau: llama_cpp.c_float,
293293
mirostat_eta: llama_cpp.c_float,
294+
penalize_nl: bool = True,
294295
):
295296
assert self.ctx is not None
296297
assert len(self.eval_logits) > 0
@@ -299,6 +300,7 @@ def _sample(
299300
top_k = llama_cpp.c_int(n_vocab) if top_k.value <= 0 else top_k
300301
last_n_tokens_size = llama_cpp.c_int(n_ctx) if last_n_tokens_size.value < 0 else last_n_tokens_size
301302
logits = self.eval_logits[-1]
303+
nl_logit = logits[llama_cpp.llama_token_nl().value]
302304
data = (llama_cpp.llama_token_data * n_vocab)(
303305
*[
304306
llama_cpp.llama_token_data(
@@ -331,6 +333,8 @@ def _sample(
331333
alpha_frequency=frequency_penalty,
332334
alpha_presence=presence_penalty,
333335
)
336+
if not penalize_nl:
337+
candidates.data[llama_cpp.llama_token_nl().value].logit = nl_logit
334338
if temp.value == 0.0:
335339
return llama_cpp.llama_sample_token_greedy(
336340
ctx=self.ctx,
@@ -413,6 +417,7 @@ def sample(
413417
mirostat_mode: int = 0,
414418
mirostat_eta: float = 0.1,
415419
mirostat_tau: float = 5.0,
420+
penalize_nl: bool = True,
416421
):
417422
"""Sample a token from the model.
418423
@@ -444,6 +449,7 @@ def sample(
444449
mirostat_mode=llama_cpp.c_int(mirostat_mode),
445450
mirostat_tau=llama_cpp.c_float(mirostat_tau),
446451
mirostat_eta=llama_cpp.c_float(mirostat_eta),
452+
penalize_nl=penalize_nl,
447453
)
448454

449455
def generate(
@@ -1170,6 +1176,11 @@ def token_bos() -> llama_cpp.llama_token:
11701176
"""Return the beginning-of-sequence token."""
11711177
return llama_cpp.llama_token_bos()
11721178

1179+
@staticmethod
1180+
def token_nl() -> llama_cpp.llama_token:
1181+
"""Return the newline token."""
1182+
return llama_cpp.llama_token_nl()
1183+
11731184
@staticmethod
11741185
def logits_to_logprobs(logits: List[float]) -> List[float]:
11751186
exps = [math.exp(float(x)) for x in logits]

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.