Commit e811a81

Merge branch 'main' of https://github.com/abetlen/llama-cpp-python into main

2 parents: ca8e3c9 + 5212fb0
3 files changed: 49 additions, 5 deletions

llama_cpp/llama.py

21 additions, 5 deletions
@@ -410,8 +410,8 @@ def __init__(
         if self.verbose:
             print(f"Model metadata: {self.metadata}", file=sys.stderr)

-        eos_token_id = int(self.metadata.get("tokenizer.ggml.eos_token_id", self.token_eos()))
-        bos_token_id = int(self.metadata.get("tokenizer.ggml.bos_token_id", self.token_bos()))
+        eos_token_id = self.token_eos()
+        bos_token_id = self.token_bos()

         eos_token = self._model.token_get_text(eos_token_id)
         bos_token = self._model.token_get_text(bos_token_id)
@@ -961,9 +961,9 @@ def _create_completion(

         completion_id: str = f"cmpl-{str(uuid.uuid4())}"
         created: int = int(time.time())
-        prefix_token_id: int = int(self.metadata.get("tokenizer.ggml.prefix_token_id", self._model.token_prefix()))
-        middle_token_id: int = int(self.metadata.get("tokenizer.ggml.middle_token_id", self._model.token_middle()))
-        suffix_token_id: int = int(self.metadata.get("tokenizer.ggml.suffix_token_id", self._model.token_suffix()))
+        prefix_token_id: int = self._model.token_prefix()
+        middle_token_id: int = self._model.token_middle()
+        suffix_token_id: int = self._model.token_suffix()
         # If prompt is empty, initialize completion with BOS token to avoid
         # detokenization including a space at the beginning of the completion
         completion_tokens: List[int] = [] if len(prompt) > 0 else [self.token_bos()]
@@ -2084,3 +2084,19 @@ def __call__(
         self, input_ids: npt.NDArray[np.intc], logits: npt.NDArray[np.single]
     ) -> bool:
         return any([stopping_criteria(input_ids, logits) for stopping_criteria in self])
+
+
+class MinTokensLogitsProcessor(LogitsProcessor):
+    def __init__(self, min_tokens: int, token_eos: int):
+        self.min_tokens = min_tokens
+        self.token_eos = token_eos
+        self.prompt_tokens = None
+
+    def __call__(
+        self, input_ids: npt.NDArray[np.intc], scores: npt.NDArray[np.single]
+    ) -> npt.NDArray[np.single]:
+        if self.prompt_tokens is None:
+            self.prompt_tokens = len(input_ids)
+        if len(input_ids) - self.prompt_tokens < self.min_tokens:
+            scores[self.token_eos] = -np.inf
+        return scores
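
For context, a minimal usage sketch of the new MinTokensLogitsProcessor when driving the Llama class directly. The model path, prompt, and generation settings below are assumptions for illustration only, not part of this commit:

    import llama_cpp

    # Hypothetical model path; any local GGUF model would do here.
    llm = llama_cpp.Llama(model_path="./models/model.gguf")

    # Suppress the EOS logit until at least 32 tokens have been generated.
    processors = llama_cpp.LogitsProcessorList(
        [llama_cpp.MinTokensLogitsProcessor(32, llm.token_eos())]
    )

    output = llm.create_completion(
        "Explain what a KV cache is.",
        max_tokens=128,
        logits_processor=processors,
    )
    print(output["choices"][0]["text"])

The processor records the prompt length on its first call, then writes -inf over the EOS score until the number of generated tokens reaches min_tokens, after which sampling proceeds unchanged.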

llama_cpp/server/app.py

20 additions, 0 deletions
@@ -275,6 +275,7 @@ async def create_completion(
         "best_of",
         "logit_bias_type",
         "user",
+        "min_tokens",
     }
     kwargs = body.model_dump(exclude=exclude)

@@ -288,6 +289,15 @@ async def create_completion(
     if body.grammar is not None:
         kwargs["grammar"] = llama_cpp.LlamaGrammar.from_string(body.grammar)

+    if body.min_tokens > 0:
+        _min_tokens_logits_processor = llama_cpp.LogitsProcessorList(
+            [llama_cpp.MinTokensLogitsProcessor(body.min_tokens, llama.token_eos())]
+        )
+        if "logits_processor" not in kwargs:
+            kwargs["logits_processor"] = _min_tokens_logits_processor
+        else:
+            kwargs["logits_processor"].extend(_min_tokens_logits_processor)
+
     iterator_or_completion: Union[
         llama_cpp.CreateCompletionResponse,
         Iterator[llama_cpp.CreateCompletionStreamResponse],
@@ -445,6 +455,7 @@ async def create_chat_completion(
         "n",
         "logit_bias_type",
         "user",
+        "min_tokens",
     }
     kwargs = body.model_dump(exclude=exclude)
     llama = llama_proxy(body.model)
@@ -458,6 +469,15 @@ async def create_chat_completion(
     if body.grammar is not None:
         kwargs["grammar"] = llama_cpp.LlamaGrammar.from_string(body.grammar)

+    if body.min_tokens > 0:
+        _min_tokens_logits_processor = llama_cpp.LogitsProcessorList(
+            [llama_cpp.MinTokensLogitsProcessor(body.min_tokens, llama.token_eos())]
+        )
+        if "logits_processor" not in kwargs:
+            kwargs["logits_processor"] = _min_tokens_logits_processor
+        else:
+            kwargs["logits_processor"].extend(_min_tokens_logits_processor)
+
     iterator_or_completion: Union[
         llama_cpp.ChatCompletion, Iterator[llama_cpp.ChatCompletionChunk]
     ] = await run_in_threadpool(llama.create_chat_completion, **kwargs)
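
A hedged client-side sketch of the new min_tokens request field. It assumes the server is running locally on the default port 8000 and uses the requests package; neither the prompt nor the values below come from this diff:

    import requests

    # min_tokens=16 asks the server to mask the EOS logit for the
    # first 16 generated tokens; max_tokens still caps the length.
    resp = requests.post(
        "http://localhost:8000/v1/completions",
        json={
            "prompt": "Write one sentence about llamas.",
            "max_tokens": 64,
            "min_tokens": 16,
        },
    )
    print(resp.json()["choices"][0]["text"])

The same field is honored by /v1/chat/completions, since both handlers build the same MinTokensLogitsProcessor and append it to any logits_processor already present in kwargs.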

llama_cpp/server/types.py

8 additions, 0 deletions
@@ -16,6 +16,12 @@
     default=16, ge=1, description="The maximum number of tokens to generate."
 )

+min_tokens_field = Field(
+    default=0,
+    ge=0,
+    description="The minimum number of tokens to generate. It may return fewer tokens if another condition is met (e.g. max_tokens, stop).",
+)
+
 temperature_field = Field(
     default=0.8,
     description="Adjust the randomness of the generated text.\n\n"
@@ -111,6 +117,7 @@ class CreateCompletionRequest(BaseModel):
     max_tokens: Optional[int] = Field(
         default=16, ge=0, description="The maximum number of tokens to generate."
     )
+    min_tokens: int = min_tokens_field
     temperature: float = temperature_field
     top_p: float = top_p_field
     min_p: float = min_p_field
@@ -206,6 +213,7 @@ class CreateChatCompletionRequest(BaseModel):
         default=None,
         description="The maximum number of tokens to generate. Defaults to inf",
     )
+    min_tokens: int = min_tokens_field
     logprobs: Optional[bool] = Field(
         default=False,
         description="Whether to output the logprobs or not. Default is True"
