Commit f568bae

Merge pull request abetlen#351 from player1537-forks/th/add-logits-bias-parameter
Add support for `logit_bias` and `logit_bias_type` parameters
2 parents abf6d4a + eb7645b commit f568bae

2 files changed (+53, −2)

llama_cpp/llama.py (2 additions, 0 deletions)
```diff
@@ -1378,6 +1378,7 @@ def create_chat_completion(
         mirostat_tau: float = 5.0,
         mirostat_eta: float = 0.1,
         model: Optional[str] = None,
+        logits_processor: Optional[LogitsProcessorList] = None,
     ) -> Union[ChatCompletion, Iterator[ChatCompletionChunk]]:
         """Generate a chat completion from a list of messages.
 
@@ -1419,6 +1420,7 @@ def create_chat_completion(
             mirostat_tau=mirostat_tau,
             mirostat_eta=mirostat_eta,
             model=model,
+            logits_processor=logits_processor,
         )
         if stream:
             chunks: Iterator[CompletionChunk] = completion_or_chunks  # type: ignore
```

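The llama.py change simply threads a new optional argument through to `create_completion`. A minimal, hypothetical sketch of using it directly from the Python API (the model path, token id, and `boost_token` helper are placeholders for illustration, not part of this commit):

```python
# Hypothetical usage of the new `logits_processor` parameter; the model
# path and token id 123 are placeholders, not part of this commit.
import llama_cpp

llama = llama_cpp.Llama(model_path="./models/7B/ggml-model.bin")

def boost_token(input_ids, scores):
    # Nudge the logit of token id 123 upward at every sampling step.
    scores[123] += 5.0
    return scores

completion = llama.create_chat_completion(
    messages=[{"role": "user", "content": "Hello!"}],
    logits_processor=llama_cpp.LogitsProcessorList([boost_token]),
)
```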
llama_cpp/server/app.py (51 additions, 2 deletions)
```diff
@@ -255,13 +255,14 @@ class CreateCompletionRequest(BaseModel):
     )
     presence_penalty: Optional[float] = presence_penalty_field
     frequency_penalty: Optional[float] = frequency_penalty_field
+    logit_bias: Optional[Dict[str, float]] = Field(None)
+    logit_bias_type: Optional[Literal["input_ids", "tokens"]] = Field(None)
 
     # ignored or currently unsupported
     model: Optional[str] = model_field
     n: Optional[int] = 1
     logprobs: Optional[int] = Field(None)
     best_of: Optional[int] = 1
-    logit_bias: Optional[Dict[str, float]] = Field(None)
     user: Optional[str] = Field(None)
 
     # llama.cpp specific parameters
```
```diff
@@ -280,6 +281,39 @@ class Config:
 CreateCompletionResponse = create_model_from_typeddict(llama_cpp.Completion)
 
 
+def make_logit_bias_processor(
+    llama: llama_cpp.Llama,
+    logit_bias: Dict[str, float],
+    logit_bias_type: Optional[Literal["input_ids", "tokens"]],
+):
+    if logit_bias_type is None:
+        logit_bias_type = "input_ids"
+
+    to_bias: Dict[int, float] = {}
+    if logit_bias_type == "input_ids":
+        for input_id, score in logit_bias.items():
+            input_id = int(input_id)
+            to_bias[input_id] = score
+
+    elif logit_bias_type == "tokens":
+        for token, score in logit_bias.items():
+            token = token.encode('utf-8')
+            for input_id in llama.tokenize(token, add_bos=False):
+                to_bias[input_id] = score
+
+    def logit_bias_processor(
+        input_ids: List[int],
+        scores: List[float],
+    ) -> List[float]:
+        new_scores = [None] * len(scores)
+        for input_id, score in enumerate(scores):
+            new_scores[input_id] = score + to_bias.get(input_id, 0.0)
+
+        return new_scores
+
+    return logit_bias_processor
+
+
 @router.post(
     "/v1/completions",
     response_model=CreateCompletionResponse,
```
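`make_logit_bias_processor` resolves the bias map once, up front: in "input_ids" mode the keys are token ids serialized as strings (JSON object keys are always strings), while in "tokens" mode each key is tokenized with `add_bos=False` and every resulting token id receives the bias. The returned closure then adds `to_bias[token_id]` to each logit. A hedged sketch of the two modes, where the model path and example entries are placeholders:

```python
# Sketch of the two modes accepted by make_logit_bias_processor (defined
# above); the model path and example entries are placeholders.
import llama_cpp

llama = llama_cpp.Llama(model_path="./models/7B/ggml-model.bin")

# "input_ids" mode (the default): keys are token ids as strings.
by_id = make_logit_bias_processor(llama, {"15043": 10.0}, "input_ids")

# "tokens" mode: keys are raw text; each string is tokenized and every
# resulting token id receives the bias.
by_text = make_logit_bias_processor(llama, {"Hello": 10.0}, "tokens")

# Either closure has the (input_ids, scores) -> scores shape that
# llama_cpp.LogitsProcessorList expects.
```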
```diff
@@ -297,9 +331,16 @@ async def create_completion(
         "n",
         "best_of",
         "logit_bias",
+        "logit_bias_type",
         "user",
     }
     kwargs = body.dict(exclude=exclude)
+
+    if body.logit_bias is not None:
+        kwargs['logits_processor'] = llama_cpp.LogitsProcessorList([
+            make_logit_bias_processor(llama, body.logit_bias, body.logit_bias_type),
+        ])
+
     if body.stream:
         send_chan, recv_chan = anyio.create_memory_object_stream(10)
 
```
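Both new fields are excluded from the kwargs passed straight through to the model; when `logit_bias` is present, it is instead converted into a `LogitsProcessorList` before dispatch. A hypothetical client call, assuming the server is running on its default port (the prompt and token id are placeholders):

```python
# Hypothetical request against a locally running server; host, port,
# prompt, and token id are placeholders.
import requests

response = requests.post(
    "http://localhost:8000/v1/completions",
    json={
        "prompt": "The capital of France is",
        "max_tokens": 8,
        "logit_bias": {"1234": -100.0},  # strongly suppress token id 1234
        "logit_bias_type": "input_ids",  # the default when omitted
    },
)
print(response.json()["choices"][0]["text"])
```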

```diff
@@ -378,11 +419,12 @@ class CreateChatCompletionRequest(BaseModel):
     stream: bool = stream_field
     presence_penalty: Optional[float] = presence_penalty_field
     frequency_penalty: Optional[float] = frequency_penalty_field
+    logit_bias: Optional[Dict[str, float]] = Field(None)
+    logit_bias_type: Optional[Literal["input_ids", "tokens"]] = Field(None)
 
     # ignored or currently unsupported
     model: Optional[str] = model_field
     n: Optional[int] = 1
-    logit_bias: Optional[Dict[str, float]] = Field(None)
     user: Optional[str] = Field(None)
 
     # llama.cpp specific parameters
```
```diff
@@ -419,9 +461,16 @@ async def create_chat_completion(
     exclude = {
         "n",
         "logit_bias",
+        "logit_bias_type",
         "user",
     }
     kwargs = body.dict(exclude=exclude)
+
+    if body.logit_bias is not None:
+        kwargs['logits_processor'] = llama_cpp.LogitsProcessorList([
+            make_logit_bias_processor(llama, body.logit_bias, body.logit_bias_type),
+        ])
+
     if body.stream:
         send_chan, recv_chan = anyio.create_memory_object_stream(10)
 
```
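The chat endpoint gets the same wiring. A hypothetical request using "tokens" mode, where the keys are strings tokenized server-side (all values are placeholders):

```python
# Hypothetical chat request using "tokens" mode; all values are
# placeholders, and the server is assumed to be on its default port.
import requests

response = requests.post(
    "http://localhost:8000/v1/chat/completions",
    json={
        "messages": [{"role": "user", "content": "Name a color."}],
        "logit_bias": {" blue": 5.0},  # bias every token of " blue"
        "logit_bias_type": "tokens",
    },
)
print(response.json()["choices"][0]["message"]["content"])
```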

0 commit comments