@@ -141,7 +141,7 @@ def __getitem__(self, key: Sequence[int]) -> "LlamaState":
141
141
if _key is None :
142
142
raise KeyError ("Key not found" )
143
143
value : "LlamaState" = self .cache .pop (_key ) # type: ignore
144
- # NOTE: This puts an integer as key in cache, which breaks,
144
+ # NOTE: This puts an integer as key in cache, which breaks,
145
145
# Llama.longest_token_prefix(k, key) above since k is not a tuple of ints/tokens
146
146
# self.cache.push(_key, side="front") # type: ignore
147
147
return value
@@ -166,17 +166,15 @@ def __setitem__(self, key: Sequence[int], value: "LlamaState"):
166
166
class LlamaState :
167
167
def __init__ (
168
168
self ,
169
- eval_tokens : Deque [int ],
170
- eval_logits : Deque [List [float ]],
171
169
input_ids : npt .NDArray [np .intc ],
172
170
scores : npt .NDArray [np .single ],
171
+ n_tokens : int ,
173
172
llama_state : bytes ,
174
173
llama_state_size : int ,
175
174
):
176
- self .eval_tokens = eval_tokens
177
- self .eval_logits = eval_logits
178
175
self .input_ids = input_ids
179
176
self .scores = scores
177
+ self .n_tokens = n_tokens
180
178
self .llama_state = llama_state
181
179
self .llama_state_size = llama_state_size
182
180
@@ -267,8 +265,6 @@ def __init__(
267
265
268
266
self .last_n_tokens_size = last_n_tokens_size
269
267
self .n_batch = min (n_ctx , n_batch )
270
- self .eval_tokens : Deque [int ] = deque (maxlen = n_ctx )
271
- self .eval_logits : Deque [List [float ]] = deque (maxlen = n_ctx if logits_all else 1 )
272
268
273
269
self .cache : Optional [BaseLlamaCache ] = None
274
270
@@ -329,8 +325,30 @@ def __init__(
329
325
self ._token_nl = Llama .token_nl ()
330
326
self ._token_eos = Llama .token_eos ()
331
327
332
- self ._input_ids = np .array ([], dtype = np .intc )
333
- self ._scores : npt .NDArray [np .single ] = np .ndarray ((0 , self ._n_vocab ), dtype = np .single )
328
+ self .n_tokens = 0
329
+ self .input_ids : npt .NDArray [np .intc ] = np .ndarray ((n_ctx ,), dtype = np .intc )
330
+ self .scores : npt .NDArray [np .single ] = np .ndarray (
331
+ (n_ctx , self ._n_vocab ), dtype = np .single
332
+ )
333
+
334
+ @property
335
+ def _input_ids (self ) -> npt .NDArray [np .intc ]:
336
+ return self .input_ids [: self .n_tokens ]
337
+
338
+ @property
339
+ def _scores (self ) -> npt .NDArray [np .single ]:
340
+ return self .scores [: self .n_tokens , :]
341
+
342
+ @property
343
+ def eval_tokens (self ) -> Deque [int ]:
344
+ return deque (self .input_ids [: self .n_tokens ].tolist (), maxlen = self ._n_ctx )
345
+
346
+ @property
347
+ def eval_logits (self ) -> Deque [List [float ]]:
348
+ return deque (
349
+ self .scores [: self .n_tokens , :].tolist (),
350
+ maxlen = self ._n_ctx if self .params .logits_all else 1 ,
351
+ )
334
352
335
353
def tokenize (self , text : bytes , add_bos : bool = True ) -> List [int ]:
336
354
"""Tokenize a string.
@@ -397,10 +415,7 @@ def set_cache(self, cache: Optional[BaseLlamaCache]):
397
415
398
416
def reset (self ):
399
417
"""Reset the model state."""
400
- self .eval_tokens .clear ()
401
- self .eval_logits .clear ()
402
- self ._input_ids = np .array ([], dtype = np .intc )
403
- self ._scores = np .ndarray ((0 , self ._n_vocab ), dtype = np .single )
418
+ self .n_tokens = 0
404
419
405
420
def eval (self , tokens : Sequence [int ]):
406
421
"""Evaluate a list of tokens.
@@ -410,7 +425,6 @@ def eval(self, tokens: Sequence[int]):
410
425
"""
411
426
assert self .ctx is not None
412
427
n_ctx = self ._n_ctx
413
- scores : List [npt .NDArray [np .single ]] = []
414
428
for i in range (0 , len (tokens ), self .n_batch ):
415
429
batch = tokens [i : min (len (tokens ), i + self .n_batch )]
416
430
n_past = min (n_ctx - len (batch ), len (self ._input_ids ))
@@ -425,19 +439,14 @@ def eval(self, tokens: Sequence[int]):
425
439
if return_code != 0 :
426
440
raise RuntimeError (f"llama_eval returned { return_code } " )
427
441
# Save tokens
428
- self .eval_tokens .extend (batch )
429
- self ._input_ids : npt .NDArray [np .intc ] = np .concatenate (
430
- (self ._input_ids , np .array (batch , dtype = np .intc )), axis = 0
431
- )
442
+ self .input_ids [self .n_tokens : self .n_tokens + n_tokens ] = batch
432
443
# Save logits
433
444
rows = n_tokens if self .params .logits_all else 1
434
- n_vocab = self ._n_vocab
435
- cols = n_vocab
436
- logits_view = llama_cpp .llama_get_logits (self .ctx )
437
- logits = [logits_view [i * cols : (i + 1 ) * cols ] for i in range (rows )]
438
- self .eval_logits .extend (logits )
439
- scores .append (np .array (logits , dtype = np .single ))
440
- self ._scores = np .concatenate (scores )
445
+ cols = self ._n_vocab
446
+ offset = 0 if self .params .logits_all else n_tokens - 1 # NOTE: Only save the last token logits if logits_all is False
447
+ self .scores [self .n_tokens + offset : self .n_tokens + n_tokens , :].reshape (- 1 )[:] = llama_cpp .llama_get_logits (self .ctx )[:rows * cols ]
448
+ # Update n_tokens
449
+ self .n_tokens += n_tokens
441
450
442
451
def _sample (
443
452
self ,
@@ -457,8 +466,7 @@ def _sample(
457
466
logits_processor : Optional [LogitsProcessorList ] = None ,
458
467
):
459
468
assert self .ctx is not None
460
- assert len (self .eval_logits ) > 0
461
- assert self ._scores .shape [0 ] > 0
469
+ assert self .n_tokens > 0
462
470
n_vocab = self ._n_vocab
463
471
n_ctx = self ._n_ctx
464
472
top_k = llama_cpp .c_int (n_vocab ) if top_k .value <= 0 else top_k
@@ -475,7 +483,6 @@ def _sample(
475
483
dtype = np .single ,
476
484
)
477
485
self ._scores [- 1 , :] = logits
478
- self .eval_logits [- 1 ] = logits .tolist ()
479
486
480
487
nl_logit = logits [self ._token_nl ]
481
488
candidates = self ._candidates
@@ -672,14 +679,7 @@ def generate(
672
679
print ("Llama.generate: prefix-match hit" , file = sys .stderr )
673
680
reset = False
674
681
tokens = tokens [longest_prefix :]
675
- self ._input_ids = self ._input_ids [:longest_prefix ]
676
- self ._scores = self ._scores [:longest_prefix , :]
677
- for _ in range (len (self .eval_tokens ) - longest_prefix ):
678
- self .eval_tokens .pop ()
679
- try :
680
- self .eval_logits .pop ()
681
- except IndexError :
682
- pass
682
+ self .n_tokens = longest_prefix
683
683
684
684
if reset :
685
685
self .reset ()
@@ -819,7 +819,9 @@ def _create_completion(
819
819
llama_cpp .llama_reset_timings (self .ctx )
820
820
821
821
if len (prompt_tokens ) > self ._n_ctx :
822
- raise ValueError (f"Requested tokens ({ len (prompt_tokens )} ) exceed context window of { self ._n_ctx } " )
822
+ raise ValueError (
823
+ f"Requested tokens ({ len (prompt_tokens )} ) exceed context window of { self ._n_ctx } "
824
+ )
823
825
824
826
# Truncate max_tokens if requested tokens would exceed the context window
825
827
max_tokens = (
@@ -1437,6 +1439,9 @@ def create_chat_completion(
1437
1439
return self ._convert_text_completion_to_chat (completion )
1438
1440
1439
1441
def __del__ (self ):
1442
+ if self .model is not None :
1443
+ llama_cpp .llama_free_model (self .model )
1444
+ self .model = None
1440
1445
if self .ctx is not None :
1441
1446
llama_cpp .llama_free (self .ctx )
1442
1447
self .ctx = None
@@ -1510,22 +1515,20 @@ def save_state(self) -> LlamaState:
1510
1515
file = sys .stderr ,
1511
1516
)
1512
1517
return LlamaState (
1513
- eval_tokens = self .eval_tokens .copy (),
1514
- eval_logits = self .eval_logits .copy (),
1515
- scores = self ._scores .copy (),
1516
- input_ids = self ._input_ids .copy (),
1518
+ scores = self .scores .copy (),
1519
+ input_ids = self .input_ids .copy (),
1520
+ n_tokens = self .n_tokens ,
1517
1521
llama_state = bytes (llama_state_compact ),
1518
1522
llama_state_size = n_bytes ,
1519
1523
)
1520
1524
1521
1525
def load_state (self , state : LlamaState ) -> None :
1522
1526
assert self .ctx is not None
1523
- self .eval_tokens = state .eval_tokens .copy ()
1524
- self .eval_logits = state .eval_logits .copy ()
1525
- self ._scores = state .scores .copy ()
1526
- self ._input_ids = state .input_ids .copy ()
1527
+ self .scores = state .scores .copy ()
1528
+ self .input_ids = state .input_ids .copy ()
1529
+ self .n_tokens = state .n_tokens
1527
1530
state_size = state .llama_state_size
1528
- LLamaStateArrayType = ( llama_cpp .c_uint8 * state_size )
1531
+ LLamaStateArrayType = llama_cpp .c_uint8 * state_size
1529
1532
llama_state = LLamaStateArrayType .from_buffer_copy (state .llama_state )
1530
1533
1531
1534
if llama_cpp .llama_set_state_data (self .ctx , llama_state ) != state_size :
0 commit comments