Commit 5212fb0

feat: add MinTokensLogitsProcessor and min_tokens argument to server (abetlen#1333)
* implement min_tokens
* set default to 0
* pass min_tokens
* fix
* remove copy
* implement MinTokensLogitsProcessor
* format
* fix condition
1 parent 389e09c commit 5212fb0

3 files changed: +44 -0 lines changed

llama_cpp/llama.py

16 additions & 0 deletions
@@ -2084,3 +2084,19 @@ def __call__(
         self, input_ids: npt.NDArray[np.intc], logits: npt.NDArray[np.single]
     ) -> bool:
         return any([stopping_criteria(input_ids, logits) for stopping_criteria in self])
+
+
+class MinTokensLogitsProcessor(LogitsProcessor):
+    def __init__(self, min_tokens: int, token_eos: int):
+        self.min_tokens = min_tokens
+        self.token_eos = token_eos
+        self.prompt_tokens = None
+
+    def __call__(
+        self, input_ids: npt.NDArray[np.intc], scores: npt.NDArray[np.single]
+    ) -> npt.NDArray[np.single]:
+        if self.prompt_tokens is None:
+            self.prompt_tokens = len(input_ids)
+        if len(input_ids) - self.prompt_tokens < self.min_tokens:
+            scores[self.token_eos] = -np.inf
+        return scores
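Below is a minimal usage sketch of the new processor through the Python API; the model path, prompt, and parameter values are placeholders rather than part of the commit.

# Sketch: suppress EOS until at least 32 completion tokens have been generated.
# The model path, prompt, and numbers are illustrative placeholders.
from llama_cpp import Llama, LogitsProcessorList, MinTokensLogitsProcessor

llama = Llama(model_path="./models/example.gguf")  # placeholder path

processors = LogitsProcessorList(
    [MinTokensLogitsProcessor(min_tokens=32, token_eos=llama.token_eos())]
)

output = llama.create_completion(
    "Summarize this change in one paragraph.",
    max_tokens=128,
    logits_processor=processors,
)
print(output["choices"][0]["text"])

Because the processor records the prompt length lazily on its first call, a fresh instance should be created for each generation rather than reused across requests.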

llama_cpp/server/app.py

20 additions & 0 deletions
@@ -275,6 +275,7 @@ async def create_completion(
         "best_of",
         "logit_bias_type",
         "user",
+        "min_tokens",
     }
     kwargs = body.model_dump(exclude=exclude)

@@ -288,6 +289,15 @@ async def create_completion(
     if body.grammar is not None:
         kwargs["grammar"] = llama_cpp.LlamaGrammar.from_string(body.grammar)

+    if body.min_tokens > 0:
+        _min_tokens_logits_processor = llama_cpp.LogitsProcessorList(
+            [llama_cpp.MinTokensLogitsProcessor(body.min_tokens, llama.token_eos())]
+        )
+        if "logits_processor" not in kwargs:
+            kwargs["logits_processor"] = _min_tokens_logits_processor
+        else:
+            kwargs["logits_processor"].extend(_min_tokens_logits_processor)
+
     iterator_or_completion: Union[
         llama_cpp.CreateCompletionResponse,
         Iterator[llama_cpp.CreateCompletionStreamResponse],

@@ -445,6 +455,7 @@ async def create_chat_completion(
         "n",
         "logit_bias_type",
         "user",
+        "min_tokens",
     }
     kwargs = body.model_dump(exclude=exclude)
     llama = llama_proxy(body.model)

@@ -458,6 +469,15 @@ async def create_chat_completion(
     if body.grammar is not None:
         kwargs["grammar"] = llama_cpp.LlamaGrammar.from_string(body.grammar)

+    if body.min_tokens > 0:
+        _min_tokens_logits_processor = llama_cpp.LogitsProcessorList(
+            [llama_cpp.MinTokensLogitsProcessor(body.min_tokens, llama.token_eos())]
+        )
+        if "logits_processor" not in kwargs:
+            kwargs["logits_processor"] = _min_tokens_logits_processor
+        else:
+            kwargs["logits_processor"].extend(_min_tokens_logits_processor)
+
     iterator_or_completion: Union[
         llama_cpp.ChatCompletion, Iterator[llama_cpp.ChatCompletionChunk]
     ] = await run_in_threadpool(llama.create_chat_completion, **kwargs)
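On the client side, a sketch of a request that exercises the new server argument, assuming a locally running instance on the default port; the URL and payload values are assumptions, not part of the commit.

# Illustrative request against the OpenAI-compatible chat completions endpoint.
# Host, port, and payload values are assumptions for demonstration only.
import requests

resp = requests.post(
    "http://localhost:8000/v1/chat/completions",
    json={
        "messages": [{"role": "user", "content": "Write a haiku about llamas."}],
        "min_tokens": 20,  # EOS is suppressed until 20 completion tokens exist
        "max_tokens": 64,
    },
)
print(resp.json()["choices"][0]["message"]["content"])

Note that min_tokens is added to the exclude set above, so it is stripped from the kwargs forwarded to the model call and takes effect only through the injected MinTokensLogitsProcessor.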

llama_cpp/server/types.py

8 additions & 0 deletions
@@ -16,6 +16,12 @@
     default=16, ge=1, description="The maximum number of tokens to generate."
 )

+min_tokens_field = Field(
+    default=0,
+    ge=0,
+    description="The minimum number of tokens to generate. It may return fewer tokens if another condition is met (e.g. max_tokens, stop).",
+)
+
 temperature_field = Field(
     default=0.8,
     description="Adjust the randomness of the generated text.\n\n"

@@ -111,6 +117,7 @@ class CreateCompletionRequest(BaseModel):
     max_tokens: Optional[int] = Field(
         default=16, ge=0, description="The maximum number of tokens to generate."
     )
+    min_tokens: int = min_tokens_field
     temperature: float = temperature_field
     top_p: float = top_p_field
     min_p: float = min_p_field

@@ -206,6 +213,7 @@ class CreateChatCompletionRequest(BaseModel):
         default=None,
         description="The maximum number of tokens to generate. Defaults to inf",
     )
+    min_tokens: int = min_tokens_field
     logprobs: Optional[bool] = Field(
         default=False,
         description="Whether to output the logprobs or not. Default is True"
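A self-contained sketch of how the shared Pydantic field validates; the ExampleRequest model is hypothetical, and only the field definition mirrors the commit.

# Hypothetical model reusing the field definition introduced in this commit.
from pydantic import BaseModel, Field, ValidationError

min_tokens_field = Field(
    default=0,
    ge=0,
    description="The minimum number of tokens to generate.",
)

class ExampleRequest(BaseModel):  # hypothetical request model
    min_tokens: int = min_tokens_field

print(ExampleRequest().min_tokens)               # 0: feature disabled by default
print(ExampleRequest(min_tokens=32).min_tokens)  # 32

try:
    ExampleRequest(min_tokens=-1)                # rejected by the ge=0 constraint
except ValidationError as err:
    print(err)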

0 commit comments
