27
27
import numpy as np
28
28
import numpy .typing as npt
29
29
30
+
30
31
class BaseLlamaCache (ABC ):
31
32
"""Base cache class for a llama.cpp model."""
32
33
@@ -179,21 +180,27 @@ def __init__(
179
180
self .llama_state_size = llama_state_size
180
181
181
182
182
# Signature of a logits post-processing hook: receives the token ids seen so
# far and the raw scores for the next token, returns adjusted scores.
LogitsProcessor = Callable[
    [npt.NDArray[np.intc], npt.NDArray[np.single]], npt.NDArray[np.single]
]


class LogitsProcessorList(List[LogitsProcessor]):
    """A list of logits processors applied left-to-right, each one feeding
    its output scores into the next."""

    def __call__(
        self, input_ids: npt.NDArray[np.intc], scores: npt.NDArray[np.single]
    ) -> npt.NDArray[np.single]:
        current = scores
        for process in self:
            current = process(input_ids, current)
        return current
190
195
191
196
192
# Signature of a stop hook: receives the token ids generated so far and the
# last-step logits, returns True when generation should stop.
StoppingCriteria = Callable[[npt.NDArray[np.intc], npt.NDArray[np.single]], bool]


class StoppingCriteriaList(List[StoppingCriteria]):
    """A list of stopping criteria; generation stops as soon as any single
    criterion fires."""

    def __call__(
        self, input_ids: npt.NDArray[np.intc], logits: npt.NDArray[np.single]
    ) -> bool:
        # Generator expression (not a materialized list) so `any` can
        # short-circuit on the first criterion that returns True instead of
        # evaluating every criterion unconditionally.
        return any(criteria(input_ids, logits) for criteria in self)
198
205
199
206
@@ -274,9 +281,11 @@ def __init__(
274
281
self ._c_tensor_split = None
275
282
276
283
if self .tensor_split is not None :
277
- #Type conversion and expand the list to the length of LLAMA_MAX_DEVICES
284
+ # Type conversion and expand the list to the length of LLAMA_MAX_DEVICES
278
285
FloatArray = ctypes .c_float * llama_cpp .LLAMA_MAX_DEVICES .value
279
- self ._c_tensor_split = FloatArray (* tensor_split ) # keep a reference to the array so it is not gc'd
286
+ self ._c_tensor_split = FloatArray (
287
+ * tensor_split
288
+ ) # keep a reference to the array so it is not gc'd
280
289
self .params .tensor_split = self ._c_tensor_split
281
290
282
291
self .params .rope_freq_base = rope_freq_base
@@ -503,11 +512,7 @@ def _sample(
503
512
logits : npt .NDArray [np .single ] = self ._scores [- 1 , :]
504
513
505
514
if logits_processor is not None :
506
- logits = np .array (
507
- logits_processor (self ._input_ids .tolist (), logits .tolist ()),
508
- dtype = np .single ,
509
- )
510
- self ._scores [- 1 , :] = logits
515
+ logits [:] = logits_processor (self ._input_ids , logits )
511
516
512
517
nl_logit = logits [self ._token_nl ]
513
518
candidates = self ._candidates
@@ -725,7 +730,7 @@ def generate(
725
730
logits_processor = logits_processor ,
726
731
)
727
732
if stopping_criteria is not None and stopping_criteria (
728
- self ._input_ids . tolist () , self ._scores [- 1 , :]. tolist ()
733
+ self ._input_ids , self ._scores [- 1 , :]
729
734
):
730
735
return
731
736
tokens_or_none = yield token
@@ -1014,7 +1019,7 @@ def _create_completion(
1014
1019
break
1015
1020
1016
1021
if stopping_criteria is not None and stopping_criteria (
1017
- self ._input_ids . tolist () , self ._scores [- 1 , :]. tolist ()
1022
+ self ._input_ids , self ._scores [- 1 , :]
1018
1023
):
1019
1024
text = self .detokenize (completion_tokens )
1020
1025
finish_reason = "stop"
0 commit comments