Commit eb7645b

Add support for logit_bias and logit_bias_type parameters

1 parent 0da655b · commit eb7645b

2 files changed: +53 −2 lines changed

‎llama_cpp/llama.py

+2 −0 (2 additions, 0 deletions)
@@ -1380,6 +1380,7 @@ def create_chat_completion(
         mirostat_tau: float = 5.0,
         mirostat_eta: float = 0.1,
         model: Optional[str] = None,
+        logits_processor: Optional[LogitsProcessorList] = None,
     ) -> Union[ChatCompletion, Iterator[ChatCompletionChunk]]:
         """Generate a chat completion from a list of messages.

@@ -1421,6 +1422,7 @@ def create_chat_completion(
             mirostat_tau=mirostat_tau,
             mirostat_eta=mirostat_eta,
             model=model,
            logits_processor=logits_processor,
         )
         if stream:
             chunks: Iterator[CompletionChunk] = completion_or_chunks  # type: ignore
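For orientation, the new parameter can also be used directly against the library. A minimal sketch (the model path, token id, and bias value below are placeholders, not part of this commit; LogitsProcessorList is the llama_cpp type referenced in the diff):

    import llama_cpp

    # Placeholder model path; any local model file works here.
    llama = llama_cpp.Llama(model_path="./model.bin")

    def bias_processor(input_ids, scores):
        # Additively bias one (made-up) vocabulary entry before sampling.
        scores[1234] = scores[1234] + 10.0
        return scores

    completion = llama.create_chat_completion(
        messages=[{"role": "user", "content": "Hello"}],
        logits_processor=llama_cpp.LogitsProcessorList([bias_processor]),
    )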

‎llama_cpp/server/app.py

+51 −2 (51 additions, 2 deletions)
@@ -249,13 +249,14 @@ class CreateCompletionRequest(BaseModel):
     )
     presence_penalty: Optional[float] = presence_penalty_field
     frequency_penalty: Optional[float] = frequency_penalty_field
+    logit_bias: Optional[Dict[str, float]] = Field(None)
+    logit_bias_type: Optional[Literal["input_ids", "tokens"]] = Field(None)

     # ignored or currently unsupported
     model: Optional[str] = model_field
     n: Optional[int] = 1
     logprobs: Optional[int] = Field(None)
     best_of: Optional[int] = 1
-    logit_bias: Optional[Dict[str, float]] = Field(None)
     user: Optional[str] = Field(None)

     # llama.cpp specific parameters
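Note that the keys of logit_bias are strings even when they carry token ids, since JSON object keys are always strings; the processor below converts them with int(). A sketch of the two accepted shapes (the ids and words are made up):

    # logit_bias_type = "input_ids": keys are stringified token ids.
    bias_by_id = {"15043": 5.0, "27": -5.0}

    # logit_bias_type = "tokens": keys are literal text, tokenized server-side.
    bias_by_text = {"Hello": 5.0}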
@@ -274,6 +275,39 @@ class Config:
 CreateCompletionResponse = create_model_from_typeddict(llama_cpp.Completion)


+def make_logit_bias_processor(
+    llama: llama_cpp.Llama,
+    logit_bias: Dict[str, float],
+    logit_bias_type: Optional[Literal["input_ids", "tokens"]],
+):
+    if logit_bias_type is None:
+        logit_bias_type = "input_ids"
+
+    to_bias: Dict[int, float] = {}
+    if logit_bias_type == "input_ids":
+        for input_id, score in logit_bias.items():
+            input_id = int(input_id)
+            to_bias[input_id] = score
+
+    elif logit_bias_type == "tokens":
+        for token, score in logit_bias.items():
+            token = token.encode('utf-8')
+            for input_id in llama.tokenize(token, add_bos=False):
+                to_bias[input_id] = score
+
+    def logit_bias_processor(
+        input_ids: List[int],
+        scores: List[float],
+    ) -> List[float]:
+        new_scores = [None] * len(scores)
+        for input_id, score in enumerate(scores):
+            new_scores[input_id] = score + to_bias.get(input_id, 0.0)
+
+        return new_scores
+
+    return logit_bias_processor
+
+
 @router.post(
     "/v1/completions",
     response_model=CreateCompletionResponse,
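The factory builds a token-id-to-bias map once, and the inner processor then adds the bias to matching vocabulary entries at every sampling step. A standalone toy check of that additive step (fake 10-entry vocabulary, made-up ids; no model required):

    to_bias = {3: 5.0, 7: -5.0}   # token id -> additive bias
    scores = [0.0] * 10           # stand-in for per-token logits
    new_scores = [s + to_bias.get(i, 0.0) for i, s in enumerate(scores)]
    assert new_scores[3] == 5.0 and new_scores[7] == -5.0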
@@ -291,9 +325,16 @@ async def create_completion(
         "n",
         "best_of",
         "logit_bias",
+        "logit_bias_type",
         "user",
     }
     kwargs = body.dict(exclude=exclude)
+
+    if body.logit_bias is not None:
+        kwargs['logits_processor'] = llama_cpp.LogitsProcessorList([
+            make_logit_bias_processor(llama, body.logit_bias, body.logit_bias_type),
+        ])
+
     if body.stream:
         send_chan, recv_chan = anyio.create_memory_object_stream(10)

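With the endpoint wired up, the new fields can be exercised over HTTP. A sketch of such a request (the host, port, token id, and bias value are placeholders; using a strongly negative bias to effectively ban a token follows the usual OpenAI-style convention):

    import json
    import urllib.request

    payload = {
        "prompt": "The capital of France is",
        "max_tokens": 16,
        "logit_bias": {"1234": -100.0},   # made-up token id, effectively banned
        "logit_bias_type": "input_ids",
    }
    req = urllib.request.Request(
        "http://localhost:8000/v1/completions",
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
    )
    with urllib.request.urlopen(req) as resp:
        print(json.load(resp))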
@@ -372,11 +413,12 @@ class CreateChatCompletionRequest(BaseModel):
     stream: bool = stream_field
     presence_penalty: Optional[float] = presence_penalty_field
     frequency_penalty: Optional[float] = frequency_penalty_field
+    logit_bias: Optional[Dict[str, float]] = Field(None)
+    logit_bias_type: Optional[Literal["input_ids", "tokens"]] = Field(None)

     # ignored or currently unsupported
     model: Optional[str] = model_field
     n: Optional[int] = 1
-    logit_bias: Optional[Dict[str, float]] = Field(None)
     user: Optional[str] = Field(None)

     # llama.cpp specific parameters
@@ -413,9 +455,16 @@ async def create_chat_completion(
     exclude = {
         "n",
         "logit_bias",
+        "logit_bias_type",
         "user",
     }
     kwargs = body.dict(exclude=exclude)
+
+    if body.logit_bias is not None:
+        kwargs['logits_processor'] = llama_cpp.LogitsProcessorList([
+            make_logit_bias_processor(llama, body.logit_bias, body.logit_bias_type),
+        ])
+
     if body.stream:
         send_chan, recv_chan = anyio.create_memory_object_stream(10)

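The chat endpoint accepts the same pair of fields. A sketch of a payload using the "tokens" variant, where keys are plain strings that the server tokenizes before biasing (the word and bias value are made up; post it to /v1/chat/completions the same way as the request above):

    payload = {
        "messages": [{"role": "user", "content": "Name a fruit."}],
        "logit_bias": {" banana": 10.0},   # every token of " banana" gets +10
        "logit_bias_type": "tokens",
    }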