Commit fe331ec

Replace eval_logits and eval_tokens with numpy arrays

1 parent: efb763b
1 file changed: 16 additions, 12 deletions

llama_cpp/llama.py
@@ -299,6 +299,8 @@ def reset(self):
         """Reset the model state."""
         self.eval_tokens.clear()
         self.eval_logits.clear()
+        self._input_ids = np.array([], dtype=np.intc)
+        self._scores = np.ndarray((0, self._n_vocab), dtype=np.single)
 
     def eval(self, tokens: Sequence[int]):
         """Evaluate a list of tokens.
@@ -310,7 +312,7 @@ def eval(self, tokens: Sequence[int]):
         n_ctx = self._n_ctx
         for i in range(0, len(tokens), self.n_batch):
             batch = tokens[i : min(len(tokens), i + self.n_batch)]
-            n_past = min(n_ctx - len(batch), len(self.eval_tokens))
+            n_past = min(n_ctx - len(batch), len(self._input_ids))
             n_tokens = len(batch)
             return_code = llama_cpp.llama_eval(
                 ctx=self.ctx,
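The n_past computation now reads the cached length from _input_ids instead of the eval_tokens deque. A worked example of the batching arithmetic, with illustrative values:

n_ctx, n_batch = 512, 8
tokens = list(range(20))  # 20 new tokens to evaluate
n_cached = 500            # stands in for len(self._input_ids)

for i in range(0, len(tokens), n_batch):
    batch = tokens[i : min(len(tokens), i + n_batch)]
    n_past = min(n_ctx - len(batch), n_cached)
    # prints: 8 500, then 8 504, then 4 508; n_past + len(batch) never exceeds n_ctx
    print(len(batch), n_past)
    n_cached += len(batch)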
@@ -356,6 +358,7 @@ def _sample(
     ):
         assert self.ctx is not None
         assert len(self.eval_logits) > 0
+        assert self._scores.shape[0] > 0
         n_vocab = self._n_vocab
         n_ctx = self._n_ctx
         top_k = llama_cpp.c_int(n_vocab) if top_k.value <= 0 else top_k
@@ -368,7 +371,7 @@ def _sample(
 
         if logits_processor is not None:
             logits = np.array(
-                logits_processor(list(self.eval_tokens), logits.tolist()),
+                logits_processor(self._input_ids.tolist(), logits.tolist()),
                 dtype=np.single,
             )
             self._scores[-1, :] = logits
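As the call site above shows, a logits processor now receives the token history as a plain list of ints (from _input_ids.tolist()) plus one row of logits, and returns adjusted logits of the same length. A hypothetical example of a compatible callable (not from this commit):

import math

def no_repeat_processor(input_ids: list, logits: list) -> list:
    # Ban every token that already occurred in the history (assumption:
    # a hard -inf ban; real repetition penalties are usually softer).
    out = list(logits)
    for tok in set(input_ids):
        out[tok] = -math.inf
    return out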
@@ -498,8 +501,8 @@ def sample(
         """
         assert self.ctx is not None
         last_n_tokens_data = [llama_cpp.llama_token(0)] * max(
-            0, self.last_n_tokens_size - len(self.eval_tokens)
-        ) + list(self.eval_tokens)[-self.last_n_tokens_size :]
+            0, self.last_n_tokens_size - len(self._input_ids)
+        ) + self._input_ids[-self.last_n_tokens_size :].tolist()
         return self._sample(
             last_n_tokens_data=(llama_cpp.llama_token * self.last_n_tokens_size)(
                 *last_n_tokens_data
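A worked example of the padding logic above: when fewer than last_n_tokens_size tokens have been seen, the window is left-padded so it is always exactly last_n_tokens_size long (plain ints stand in for llama_cpp.llama_token here):

import numpy as np

last_n_tokens_size = 6
_input_ids = np.array([5, 7, 9], dtype=np.intc)  # only 3 tokens seen so far

last_n_tokens_data = [0] * max(
    0, last_n_tokens_size - len(_input_ids)
) + _input_ids[-last_n_tokens_size:].tolist()

assert last_n_tokens_data == [0, 0, 0, 5, 7, 9]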
@@ -557,9 +560,9 @@ def generate(
         """
         assert self.ctx is not None
 
-        if reset and len(self.eval_tokens) > 0:
+        if reset and len(self._input_ids) > 0:
             longest_prefix = 0
-            for a, b in zip(self.eval_tokens, tokens[:-1]):
+            for a, b in zip(self._input_ids, tokens[:-1]):
                 if a == b:
                     longest_prefix += 1
                 else:
@@ -569,6 +572,8 @@ def generate(
                     print("Llama.generate: prefix-match hit", file=sys.stderr)
                 reset = False
                 tokens = tokens[longest_prefix:]
+                self._input_ids = self._input_ids[:longest_prefix]
+                self._scores = self._scores[:longest_prefix, :]
                 for _ in range(len(self.eval_tokens) - longest_prefix):
                     self.eval_tokens.pop()
                     try:
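The two added lines keep the numpy mirrors in sync with the prefix-match shortcut. A self-contained sketch of the whole bookkeeping step, with tiny illustrative values:

import numpy as np

n_vocab = 4  # assumption: tiny vocabulary for illustration
_input_ids = np.array([1, 2, 3, 4], dtype=np.intc)
_scores = np.zeros((4, n_vocab), dtype=np.single)
tokens = [1, 2, 9, 9]  # new prompt shares a 2-token prefix with the cache

longest_prefix = 0
for a, b in zip(_input_ids, tokens[:-1]):
    if a == b:
        longest_prefix += 1
    else:
        break

# Keep the cached state for the shared prefix, re-evaluate only the rest.
tokens = tokens[longest_prefix:]
_input_ids = _input_ids[:longest_prefix]
_scores = _scores[:longest_prefix, :]
assert list(_input_ids) == [1, 2] and _scores.shape == (2, n_vocab) and tokens == [9, 9]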
@@ -595,7 +600,7 @@ def generate(
                 logits_processor=logits_processor,
             )
             if stopping_criteria is not None and stopping_criteria(
-                list(self.eval_tokens), self.eval_logits[-1]
+                self._input_ids.tolist(), self._scores[-1, :].tolist()
             ):
                 return
             tokens_or_none = yield token
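A stopping criteria callable sees the same pair of arguments: the token history as a list of ints and the newest row of logits as a list of floats, returning True to end generation. A hypothetical example (not from this commit):

def stop_after_n_tokens(n: int):
    # Returns a criteria that stops once n tokens have been evaluated.
    def criteria(input_ids: list, logits: list) -> bool:
        return len(input_ids) >= n
    return criteria

stopping_criteria = stop_after_n_tokens(32)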
@@ -820,7 +825,7 @@ def _create_completion(
                     self.detokenize(completion_tokens[:returned_tokens])
                 )
                 token_offset = len(prompt_tokens) + returned_tokens
-                logits = self.eval_logits[token_offset - 1]
+                logits = self._scores[token_offset - 1, :].tolist()
                 current_logprobs = Llama.logits_to_logprobs(logits)
                 sorted_logprobs = list(
                     sorted(
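Here a row of _scores is converted back to a list before being passed to Llama.logits_to_logprobs. Numerically, that helper amounts to a log-softmax; a sketch of the computation (not the library's exact code):

import math

def logits_to_logprobs(logits: list) -> list:
    # Numerically stable log-softmax over one row of logits.
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]

logprobs = logits_to_logprobs([1.0, 2.0, 3.0])
assert abs(sum(math.exp(lp) for lp in logprobs) - 1.0) < 1e-9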
@@ -869,7 +874,7 @@ def _create_completion(
                 break
 
         if stopping_criteria is not None and stopping_criteria(
-            list(self.eval_tokens), self.eval_logits[-1]
+            self._input_ids.tolist(), self._scores[-1, :].tolist()
        ):
            text = self.detokenize(completion_tokens)
            finish_reason = "stop"
@@ -899,7 +904,7 @@ def _create_completion(
                     self.detokenize(completion_tokens[:returned_tokens])
                 )
                 token_offset = len(prompt_tokens) + returned_tokens - 1
-                logits = self.eval_logits[token_offset]
+                logits = self._scores[token_offset, :].tolist()
                 current_logprobs = Llama.logits_to_logprobs(logits)
                 sorted_logprobs = list(
                     sorted(
@@ -1001,8 +1006,7 @@ def _create_completion(
                 for token in all_tokens
             ]
             all_logprobs = [
-                Llama.logits_to_logprobs(list(map(float, row)))
-                for row in self.eval_logits
+                Llama.logits_to_logprobs(row.tolist()) for row in self._scores
             ][token_offset:]
             for token, token_str, logprobs_token in zip(
                 all_tokens, all_token_strs, all_logprobs
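Since _scores is now one 2-D array, the same per-row conversion could also be done in a single vectorized pass; a sketch under that assumption (not what this commit does):

import numpy as np

def logprobs_matrix(scores: np.ndarray) -> np.ndarray:
    # Row-wise numerically stable log-softmax over (n_tokens, n_vocab).
    m = scores.max(axis=1, keepdims=True)
    log_z = m + np.log(np.exp(scores - m).sum(axis=1, keepdims=True))
    return scores - log_z

scores = np.array([[1.0, 2.0, 3.0], [0.0, 0.0, 0.0]], dtype=np.single)
assert np.allclose(np.exp(logprobs_matrix(scores)).sum(axis=1), 1.0)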
