Commit 8649d76

fix: segfault when logits_all=False. Closes abetlen#1319

1 parent: f96de6d
1 file changed (+10 −8 lines)
‎llama_cpp/llama.py

```diff
@@ -535,14 +535,16 @@ def eval(self, tokens: Sequence[int]):
         # Save tokens
         self.input_ids[n_past : n_past + n_tokens] = batch
         # Save logits
-        rows = n_tokens
-        cols = self._n_vocab
-        offset = (
-            0 if self.context_params.logits_all else n_tokens - 1
-        )  # NOTE: Only save the last token logits if logits_all is False
-        self.scores[n_past + offset : n_past + n_tokens, :].reshape(-1)[
-            :
-        ] = self._ctx.get_logits()[offset * cols : rows * cols]
+        if self.context_params.logits_all:
+            rows = n_tokens
+            cols = self._n_vocab
+            logits = self._ctx.get_logits()[: rows * cols]
+            self.scores[n_past : n_past + n_tokens, :].reshape(-1)[: :] = logits
+        else:
+            rows = 1
+            cols = self._n_vocab
+            logits = self._ctx.get_logits()[: rows * cols]
+            self.scores[n_past + n_tokens - 1, :].reshape(-1)[: :] = logits
         # Update n_tokens
         self.n_tokens += n_tokens
```
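The fixed branch can be sketched outside the library with NumPy. When `logits_all` is False, llama.cpp keeps logits for only the last evaluated token, so the old `offset`-based slice started reading at `(n_tokens - 1) * n_vocab` in a buffer that holds roughly `n_vocab` floats, reading past its end. This is a hedged, self-contained illustration of the corrected copy logic, not the library's actual code: `save_logits`, `backend_logits`, and the sizes are hypothetical stand-ins for `self._ctx.get_logits()` and the model's vocabulary.

```python
import numpy as np

# Illustrative sizes; in the real class these come from the model/context.
n_tokens, n_vocab, n_past = 4, 8, 0
scores = np.zeros((16, n_vocab), dtype=np.float32)

def save_logits(backend_logits: np.ndarray, logits_all: bool) -> None:
    """Copy backend logits into the persistent scores buffer."""
    if logits_all:
        # Backend buffer holds one row of logits per evaluated token.
        rows = n_tokens
        scores[n_past : n_past + n_tokens, :].reshape(-1)[:] = (
            backend_logits[: rows * n_vocab]
        )
    else:
        # Backend buffer holds logits for the last token only, so read
        # exactly one row instead of slicing as if n_tokens rows existed.
        rows = 1
        scores[n_past + n_tokens - 1, :].reshape(-1)[:] = (
            backend_logits[: rows * n_vocab]
        )

# With logits_all=False the backend buffer is only n_vocab floats long;
# the fixed code copies it into the last token's row and touches nothing else.
save_logits(np.arange(n_vocab, dtype=np.float32), logits_all=False)
```

The key design point mirrored here is that the amount read from the backend buffer (`rows * n_vocab`) now matches what the backend actually allocated in each mode, instead of indexing by a token offset into a buffer that may be a single row long.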

0 commit comments
