Commit 33629b4

Added support for logits processors and dynamic stopping criteria
1 parent 01a010b commit 33629b4

llama_cpp/llama.py: 26 additions, 0 deletions

@@ -290,7 +290,11 @@ def _sample(
         mirostat_tau: llama_cpp.c_float,
         mirostat_eta: llama_cpp.c_float,
         penalize_nl: bool = True,
+        logits_processors=None
     ):
+        if logits_processors is None:
+            logits_processors = []
+
         assert self.ctx is not None
         assert len(self.eval_logits) > 0
         n_vocab = int(llama_cpp.llama_n_vocab(self.ctx))
@@ -302,6 +306,9 @@ def _sample(
             else last_n_tokens_size
         )
         logits = self.eval_logits[-1]
+        for processor in logits_processors:
+            logits = processor(last_n_tokens_data, logits)
+
         nl_logit = logits[int(Llama.token_nl())]
         data = (llama_cpp.llama_token_data * n_vocab)(
             *[
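
As the two `_sample` hunks show, a logits processor is any callable that receives the recent token window and the raw logits for the next token, and returns a (possibly modified) logits sequence. A minimal sketch under that assumption, matching the call site `logits = processor(last_n_tokens_data, logits)` above; the function name and the suppressed token id are illustrative, not part of the commit:

```python
import math

# Hypothetical logits processor: suppress one token id by driving its
# logit to -inf so the sampler can never pick it.
def ban_token_processor(last_n_tokens_data, logits):
    BANNED_TOKEN_ID = 123  # illustrative token id, not from the commit
    logits = list(logits)
    logits[BANNED_TOKEN_ID] = -math.inf  # token becomes unsampleable
    return logits
```

Processors run in list order, each receiving the previous one's output, so several of them compose like a pipeline.
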
@@ -420,6 +427,7 @@ def sample(
         mirostat_eta: float = 0.1,
         mirostat_tau: float = 5.0,
         penalize_nl: bool = True,
+        logits_processors=None
     ):
         """Sample a token from the model.
@@ -452,6 +460,7 @@ def sample(
             mirostat_tau=llama_cpp.c_float(mirostat_tau),
             mirostat_eta=llama_cpp.c_float(mirostat_eta),
             penalize_nl=penalize_nl,
+            logits_processors=logits_processors
         )

     def generate(
@@ -468,6 +477,7 @@ def generate(
         mirostat_mode: int = 0,
         mirostat_tau: float = 5.0,
         mirostat_eta: float = 0.1,
+        logits_processors=None
     ) -> Generator[int, Optional[Sequence[int]], None]:
         """Create a generator of tokens from a prompt.
@@ -525,6 +535,7 @@ def generate(
                 mirostat_mode=mirostat_mode,
                 mirostat_tau=mirostat_tau,
                 mirostat_eta=mirostat_eta,
+                logits_processors=logits_processors
             )
             tokens_or_none = yield token
             tokens = [token]
@@ -609,6 +620,8 @@ def _create_completion(
         mirostat_tau: float = 5.0,
         mirostat_eta: float = 0.1,
         model: Optional[str] = None,
+        logits_processors=None,
+        stopping_criterias=None
     ) -> Union[Iterator[Completion], Iterator[CompletionChunk]]:
         assert self.ctx is not None
         completion_id: str = f"cmpl-{str(uuid.uuid4())}"
@@ -672,13 +685,22 @@ def _create_completion(
             frequency_penalty=frequency_penalty,
             presence_penalty=presence_penalty,
             repeat_penalty=repeat_penalty,
+            logits_processors=logits_processors
         ):
             if token == Llama.token_eos():
                 text = self.detokenize(completion_tokens)
                 finish_reason = "stop"
                 break

             completion_tokens.append(token)
+            for stopping_crit in stopping_criterias:
+                if stopping_crit(completion_tokens, None):
+                    text = self.detokenize(completion_tokens)
+                    finish_reason = "stop"
+                    break
+
+            if finish_reason == "stop":
+                break

             all_text = self.detokenize(completion_tokens)

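The loop added to `_create_completion` implies the stopping-criteria contract: each entry in `stopping_criterias` is called with the tokens generated so far plus a second argument (passed as `None` by this commit), and returns `True` to end generation with `finish_reason = "stop"`. A minimal sketch under those assumptions; the function name and cutoff are illustrative:

```python
# Hypothetical stopping criterion: stop once a fixed number of tokens has
# been produced. Invoked as `stopping_crit(completion_tokens, None)` above.
def stop_after_n_tokens(completion_tokens, logits=None):
    return len(completion_tokens) >= 32  # illustrative cutoff
```

Note that, unlike `logits_processors` in `_sample`, `stopping_criterias` gets no `None` guard in this hunk, so callers relying on the default would need to pass an explicit list (even an empty one).
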
@@ -978,6 +1000,8 @@ def create_completion(
         mirostat_tau: float = 5.0,
         mirostat_eta: float = 0.1,
         model: Optional[str] = None,
+        logits_processors=None,
+        stopping_criterias=None
     ) -> Union[Completion, Iterator[CompletionChunk]]:
         """Generate text from a prompt.
@@ -1020,6 +1044,8 @@ def create_completion(
             mirostat_tau=mirostat_tau,
             mirostat_eta=mirostat_eta,
             model=model,
+            logits_processors=logits_processors,
+            stopping_criterias=stopping_criterias
         )
         if stream:
             chunks: Iterator[CompletionChunk] = completion_or_chunks
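
Taken together, the public entry point now accepts both hooks. A hedged end-to-end sketch, assuming the signatures implied by the hunks above; the model path, prompt, and both callables are placeholders, not part of the commit:

```python
from llama_cpp import Llama

llm = Llama(model_path="./models/7B/ggml-model.bin")  # placeholder path

def halve_logits(last_n_tokens_data, logits):
    # illustrative processor: uniformly rescale all logits
    return [l / 2.0 for l in logits]

def stop_on_newline(completion_tokens, logits=None):
    # illustrative criterion: stop once a newline token has been generated
    return Llama.token_nl() in completion_tokens

output = llm.create_completion(
    "Q: Name the planets in the solar system. A:",
    max_tokens=48,
    logits_processors=[halve_logits],
    stopping_criterias=[stop_on_newline],  # explicit list; see note above
)
print(output["choices"][0]["text"])
```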
