Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit 19ba9d3

Browse files
committed
Use numpy arrays for logits_processors and stopping_criteria. Closes abetlen#491
1 parent 5eab1db commit 19ba9d3
Copy full SHA for 19ba9d3

File tree

Expand file tree / Collapse file tree

2 files changed

+24
-16
lines changed
Filter options
Expand file tree / Collapse file tree

2 files changed

+24
-16
lines changed

‎llama_cpp/llama.py

Copy file name to clipboard | Expand all lines: llama_cpp/llama.py
+18 -13 | Lines changed: 18 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -27,6 +27,7 @@
2727
import numpy as np
2828
import numpy.typing as npt
2929

30+
3031
class BaseLlamaCache(ABC):
3132
"""Base cache class for a llama.cpp model."""
3233

@@ -179,21 +180,27 @@ def __init__(
179180
self.llama_state_size = llama_state_size
180181

181182

182-
LogitsProcessor = Callable[[List[int], List[float]], List[float]]
183+
LogitsProcessor = Callable[
184+
[npt.NDArray[np.intc], npt.NDArray[np.single]], npt.NDArray[np.single]
185+
]
183186

184187

185188
class LogitsProcessorList(List[LogitsProcessor]):
186-
def __call__(self, input_ids: List[int], scores: List[float]) -> List[float]:
189+
def __call__(
190+
self, input_ids: npt.NDArray[np.intc], scores: npt.NDArray[np.single]
191+
) -> npt.NDArray[np.single]:
187192
for processor in self:
188193
scores = processor(input_ids, scores)
189194
return scores
190195

191196

192-
StoppingCriteria = Callable[[List[int], List[float]], bool]
197+
StoppingCriteria = Callable[[npt.NDArray[np.intc], npt.NDArray[np.single]], bool]
193198

194199

195200
class StoppingCriteriaList(List[StoppingCriteria]):
196-
def __call__(self, input_ids: List[int], logits: List[float]) -> bool:
201+
def __call__(
202+
self, input_ids: npt.NDArray[np.intc], logits: npt.NDArray[np.single]
203+
) -> bool:
197204
return any([stopping_criteria(input_ids, logits) for stopping_criteria in self])
198205

199206

@@ -274,9 +281,11 @@ def __init__(
274281
self._c_tensor_split = None
275282

276283
if self.tensor_split is not None:
277-
#Type conversion and expand the list to the length of LLAMA_MAX_DEVICES
284+
# Type conversion and expand the list to the length of LLAMA_MAX_DEVICES
278285
FloatArray = ctypes.c_float * llama_cpp.LLAMA_MAX_DEVICES.value
279-
self._c_tensor_split = FloatArray(*tensor_split) # keep a reference to the array so it is not gc'd
286+
self._c_tensor_split = FloatArray(
287+
*tensor_split
288+
) # keep a reference to the array so it is not gc'd
280289
self.params.tensor_split = self._c_tensor_split
281290

282291
self.params.rope_freq_base = rope_freq_base
@@ -503,11 +512,7 @@ def _sample(
503512
logits: npt.NDArray[np.single] = self._scores[-1, :]
504513

505514
if logits_processor is not None:
506-
logits = np.array(
507-
logits_processor(self._input_ids.tolist(), logits.tolist()),
508-
dtype=np.single,
509-
)
510-
self._scores[-1, :] = logits
515+
logits[:] = logits_processor(self._input_ids, logits)
511516

512517
nl_logit = logits[self._token_nl]
513518
candidates = self._candidates
@@ -725,7 +730,7 @@ def generate(
725730
logits_processor=logits_processor,
726731
)
727732
if stopping_criteria is not None and stopping_criteria(
728-
self._input_ids.tolist(), self._scores[-1, :].tolist()
733+
self._input_ids, self._scores[-1, :]
729734
):
730735
return
731736
tokens_or_none = yield token
@@ -1014,7 +1019,7 @@ def _create_completion(
10141019
break
10151020

10161021
if stopping_criteria is not None and stopping_criteria(
1017-
self._input_ids.tolist(), self._scores[-1, :].tolist()
1022+
self._input_ids, self._scores[-1, :]
10181023
):
10191024
text = self.detokenize(completion_tokens)
10201025
finish_reason = "stop"

‎llama_cpp/server/app.py

Copy file name to clipboard | Expand all lines: llama_cpp/server/app.py
+6 -3 | Lines changed: 6 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -16,6 +16,9 @@
1616
from pydantic_settings import BaseSettings
1717
from sse_starlette.sse import EventSourceResponse
1818

19+
import numpy as np
20+
import numpy.typing as npt
21+
1922

2023
class Settings(BaseSettings):
2124
model: str = Field(
@@ -336,9 +339,9 @@ def make_logit_bias_processor(
336339
to_bias[input_id] = score
337340

338341
def logit_bias_processor(
339-
input_ids: List[int],
340-
scores: List[float],
341-
) -> List[float]:
342+
input_ids: npt.NDArray[np.intc],
343+
scores: npt.NDArray[np.single],
344+
) -> npt.NDArray[np.single]:
342345
new_scores = [None] * len(scores)
343346
for input_id, score in enumerate(scores):
344347
new_scores[input_id] = score + to_bias.get(input_id, 0.0)

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.