Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit cc81afe

Browse files
committed
feat: Add stopping_criteria to ChatFormatter, allow stopping on arbitrary token ids, fixes llama3 instruct
1 parent d17c188 commit cc81afe
Copy full SHA for cc81afe

File tree

Expand file tree / Collapse file tree

2 files changed

+25
-2
lines changed
Filter options
Expand file tree / Collapse file tree

2 files changed

+25
-2
lines changed

‎llama_cpp/llama.py

Copy file name to clipboard · Expand all lines: llama_cpp/llama.py
+4 −1 lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -426,7 +426,10 @@ def __init__(
426426
print(f"Using chat bos_token: {bos_token}", file=sys.stderr)
427427

428428
self.chat_handler = llama_chat_format.Jinja2ChatFormatter(
429-
template=template, eos_token=eos_token, bos_token=bos_token
429+
template=template,
430+
eos_token=eos_token,
431+
bos_token=bos_token,
432+
stop_token_ids=[eos_token_id],
430433
).to_chat_handler()
431434

432435
if self.chat_format is None and self.chat_handler is None:

‎llama_cpp/llama_chat_format.py

Copy file name to clipboard · Expand all lines: llama_cpp/llama_chat_format.py
+21 −1 lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,9 @@
1010

1111
import jinja2
1212

13+
import numpy as np
14+
import numpy.typing as npt
15+
1316
import llama_cpp.llama as llama
1417
import llama_cpp.llama_types as llama_types
1518
import llama_cpp.llama_grammar as llama_grammar
@@ -150,6 +153,7 @@ class ChatFormatterResponse:
150153

151154
prompt: str
152155
stop: Optional[Union[str, List[str]]] = None
156+
stopping_criteria: Optional[llama.StoppingCriteriaList] = None
153157

154158

155159
class ChatFormatter(Protocol):
@@ -173,12 +177,14 @@ def __init__(
173177
eos_token: str,
174178
bos_token: str,
175179
add_generation_prompt: bool = True,
180+
stop_token_ids: Optional[List[int]] = None,
176181
):
177182
"""A chat formatter that uses jinja2 templates to format the prompt."""
178183
self.template = template
179184
self.eos_token = eos_token
180185
self.bos_token = bos_token
181186
self.add_generation_prompt = add_generation_prompt
187+
self.stop_token_ids = set(stop_token_ids) if stop_token_ids is not None else None
182188

183189
self._environment = jinja2.Environment(
184190
loader=jinja2.BaseLoader(),
@@ -211,7 +217,16 @@ def raise_exception(message: str):
211217
tool_choice=tool_choice,
212218
)
213219

214-
return ChatFormatterResponse(prompt=prompt, stop=[self.eos_token])
220+
stopping_criteria = None
221+
if self.stop_token_ids is not None:
222+
def stop_on_last_token(
223+
tokens: npt.NDArray[np.intc],
224+
logits: npt.NDArray[np.single]
225+
) -> bool:
226+
return tokens[-1] in self.stop_token_ids
227+
stopping_criteria = llama.StoppingCriteriaList([stop_on_last_token])
228+
229+
return ChatFormatterResponse(prompt=prompt, stop=[self.eos_token], stopping_criteria=stopping_criteria)
215230

216231
def to_chat_handler(self) -> LlamaChatCompletionHandler:
217232
return chat_formatter_to_chat_completion_handler(self)
@@ -533,6 +548,10 @@ def chat_completion_handler(
533548
rstop = result.stop if isinstance(result.stop, list) else [result.stop]
534549
stop = stop + rstop
535550

551+
stopping_criteria = None
552+
if result.stopping_criteria is not None:
553+
stopping_criteria = result.stopping_criteria
554+
536555
if response_format is not None and response_format["type"] == "json_object":
537556
grammar = _grammar_for_response_format(response_format, verbose=llama.verbose)
538557

@@ -598,6 +617,7 @@ def chat_completion_handler(
598617
mirostat_eta=mirostat_eta,
599618
model=model,
600619
logits_processor=logits_processor,
620+
stopping_criteria=stopping_criteria,
601621
grammar=grammar,
602622
logit_bias=logit_bias,
603623
)

0 commit comments

Comments: 0
Morty Proxy This is a proxified and sanitized view of the page, visit original site.