Commit 07e47f5

Add support for logit_bias outside of server api. Closes abetlen#827
1 parent c21edb6 commit 07e47f5

File tree

3 files changed: +44 −38 lines changed

llama_cpp/llama.py
llama_cpp/llama_chat_format.py
llama_cpp/server/app.py

‎llama_cpp/llama.py

25 additions, 0 deletions
@@ -1327,6 +1327,7 @@ def _create_completion(
         stopping_criteria: Optional[StoppingCriteriaList] = None,
         logits_processor: Optional[LogitsProcessorList] = None,
         grammar: Optional[LlamaGrammar] = None,
+        logit_bias: Optional[Dict[int, float]] = None,
     ) -> Union[
         Iterator[CreateCompletionResponse], Iterator[CreateCompletionStreamResponse]
     ]:
@@ -1355,6 +1356,28 @@ def _create_completion(
         )
         model_name: str = model if model is not None else self.model_path

+        # NOTE: This likely doesn't work correctly for the first token in the prompt
+        # because of the extra space added to the start of the prompt_tokens
+        if logit_bias is not None:
+            logit_bias_map = {int(k): float(v) for k, v in logit_bias.items()}
+
+            def logit_bias_processor(
+                input_ids: npt.NDArray[np.intc],
+                scores: npt.NDArray[np.single],
+            ) -> npt.NDArray[np.single]:
+                new_scores = np.copy(
+                    scores
+                )  # Does it make sense to copy the whole array or can we just overwrite the original one?
+                for input_id, score in logit_bias_map.items():
+                    new_scores[input_id] = score + scores[input_id]
+                return new_scores
+
+            _logit_bias_processor = LogitsProcessorList([logit_bias_processor])
+            if logits_processor is None:
+                logits_processor = _logit_bias_processor
+            else:
+                logits_processor = logits_processor.extend(_logit_bias_processor)
+
         if self.verbose:
             self._ctx.reset_timings()

@@ -1963,6 +1986,7 @@ def create_chat_completion(
         model: Optional[str] = None,
         logits_processor: Optional[LogitsProcessorList] = None,
         grammar: Optional[LlamaGrammar] = None,
+        logit_bias: Optional[Dict[str, float]] = None,
     ) -> Union[
         CreateChatCompletionResponse, Iterator[CreateChatCompletionStreamResponse]
     ]:
@@ -2011,6 +2035,7 @@ def create_chat_completion(
             model=model,
             logits_processor=logits_processor,
             grammar=grammar,
+            logit_bias=logit_bias,
         )

     def __getstate__(self):
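With this change, logit_bias can be supplied directly through the Python API (for example via create_chat_completion) rather than only through the server. A minimal usage sketch, assuming a local GGUF model; the model path, prompt, and token id below are illustrative placeholders, not part of the commit:

from llama_cpp import Llama

# Hypothetical local model path, for illustration only.
llm = Llama(model_path="./models/llama-2-7b-chat.Q4_K_M.gguf")

response = llm.create_chat_completion(
    messages=[{"role": "user", "content": "Name a primary colour."}],
    max_tokens=8,
    # Keys are token ids (the mapping is coerced with int()/float(), so string
    # keys also work); each value is added to that token's logit on every
    # sampling step. A strongly negative bias effectively bans the token,
    # a positive bias makes it more likely.
    logit_bias={"2277": -100.0},
)
print(response["choices"][0]["message"]["content"])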

‎llama_cpp/llama_chat_format.py

3 additions, 0 deletions
@@ -45,6 +45,7 @@ def __call__(
         model: Optional[str] = None,
         logits_processor: Optional[llama.LogitsProcessorList] = None,
         grammar: Optional[llama.LlamaGrammar] = None,
+        logit_bias: Optional[Dict[str, float]] = None,
         **kwargs,  # type: ignore
     ) -> Union[
         llama_types.CreateChatCompletionResponse,
@@ -308,6 +309,7 @@ def basic_create_chat_completion(
         model: Optional[str] = None,
         logits_processor: Optional[llama.LogitsProcessorList] = None,
         grammar: Optional[llama.LlamaGrammar] = None,
+        logit_bias: Optional[Dict[str, float]] = None,
         **kwargs,  # type: ignore
     ) -> Union[
         llama_types.CreateChatCompletionResponse,
@@ -350,6 +352,7 @@ def basic_create_chat_completion(
             model=model,
             logits_processor=logits_processor,
             grammar=grammar,
+            logit_bias=logit_bias,
         )
         return _convert_completion_to_chat(completion_or_chunks, stream=stream)

‎llama_cpp/server/app.py

16 additions, 38 deletions
@@ -646,36 +646,16 @@ class CreateCompletionRequest(BaseModel):
     }


-def make_logit_bias_processor(
+def _logit_bias_tokens_to_input_ids(
     llama: llama_cpp.Llama,
     logit_bias: Dict[str, float],
-    logit_bias_type: Optional[Literal["input_ids", "tokens"]],
-):
-    if logit_bias_type is None:
-        logit_bias_type = "input_ids"
-
-    to_bias: Dict[int, float] = {}
-    if logit_bias_type == "input_ids":
-        for input_id, score in logit_bias.items():
-            input_id = int(input_id)
-            to_bias[input_id] = score
-
-    elif logit_bias_type == "tokens":
-        for token, score in logit_bias.items():
-            token = token.encode("utf-8")
-            for input_id in llama.tokenize(token, add_bos=False, special=True):
-                to_bias[input_id] = score
-
-    def logit_bias_processor(
-        input_ids: npt.NDArray[np.intc],
-        scores: npt.NDArray[np.single],
-    ) -> npt.NDArray[np.single]:
-        new_scores = np.copy(scores)  # Does it make sense to copy the whole array or can we just overwrite the original one?
-        for input_id, score in to_bias.items():
-            new_scores[input_id] = score + scores[input_id]
-        return new_scores
-
-    return logit_bias_processor
+) -> Dict[str, float]:
+    to_bias: Dict[str, float] = {}
+    for token, score in logit_bias.items():
+        token = token.encode("utf-8")
+        for input_id in llama.tokenize(token, add_bos=False, special=True):
+            to_bias[str(input_id)] = score
+    return to_bias


 @router.post(
@@ -694,17 +674,16 @@ async def create_completion(
     exclude = {
         "n",
         "best_of",
-        "logit_bias",
         "logit_bias_type",
         "user",
     }
     kwargs = body.model_dump(exclude=exclude)

     if body.logit_bias is not None:
-        kwargs["logits_processor"] = llama_cpp.LogitsProcessorList(
-            [
-                make_logit_bias_processor(llama, body.logit_bias, body.logit_bias_type),
-            ]
+        kwargs["logit_bias"] = (
+            _logit_bias_tokens_to_input_ids(llama, body.logit_bias)
+            if body.logit_bias_type == "tokens"
+            else body.logit_bias
         )

     if body.grammar is not None:
@@ -851,17 +830,16 @@ async def create_chat_completion(
 ) -> llama_cpp.ChatCompletion:
     exclude = {
         "n",
-        "logit_bias",
         "logit_bias_type",
         "user",
     }
     kwargs = body.model_dump(exclude=exclude)

     if body.logit_bias is not None:
-        kwargs["logits_processor"] = llama_cpp.LogitsProcessorList(
-            [
-                make_logit_bias_processor(llama, body.logit_bias, body.logit_bias_type),
-            ]
+        kwargs["logit_bias"] = (
+            _logit_bias_tokens_to_input_ids(llama, body.logit_bias)
+            if body.logit_bias_type == "tokens"
+            else body.logit_bias
         )

     if body.grammar is not None:
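On the server side, request handling is now reduced to key conversion: when logit_bias_type is "tokens", each key is tokenized and remapped to token-id keys before being forwarded as the new logit_bias argument. A rough sketch of a request against a locally running llama-cpp-python server; the port, prompt, and bias values are illustrative placeholders:

import requests

# Assumes `python -m llama_cpp.server` is running locally on the default port.
resp = requests.post(
    "http://localhost:8000/v1/completions",
    json={
        "prompt": "The capital of France is",
        "max_tokens": 8,
        # With logit_bias_type="tokens", keys are literal token strings; the
        # server tokenizes each key and rewrites the mapping to token-id keys.
        # With the default "input_ids", keys are already token ids.
        "logit_bias": {" Paris": 10.0},
        "logit_bias_type": "tokens",
    },
)
print(resp.json()["choices"][0]["text"])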

0 commit comments

Comments
0 (0)