Commit 6b2e0e0

perf: Don't convert logprobs arrays to lists (abetlen#1021)
Parent: 62944df

1 file changed: 6 additions, 7 deletions
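A note on why the change helps: Llama.logits_to_logprobs is built from numpy ufuncs (np.amax, np.subtract, np.exp, np.log), so handing it a Python list forces numpy to rebuild an array that already exists as a row of self._scores, and the .tolist() call itself copies every score into a Python float object first. The diff below removes that round trip. Here is a hypothetical micro-benchmark of the cost; the log_softmax helper, shapes, and iteration count are illustrative stand-ins, not code from this commit.

# Hypothetical micro-benchmark (not part of the commit) showing the cost of the
# ndarray -> list -> ndarray round trip that this change removes. Absolute
# timings are machine-dependent.
import timeit

import numpy as np

rng = np.random.default_rng(0)
scores = rng.standard_normal((16, 32000)).astype(np.single)  # n_tokens x n_vocab
row = scores[0, :]

def log_softmax(x, axis=-1):
    # Same numerically stable formulation as Llama.logits_to_logprobs.
    x = np.asarray(x, dtype=np.single)  # no-op for an ndarray, full copy for a list
    maxs = np.amax(x, axis=axis, keepdims=True)
    shifted = x - maxs
    return shifted - np.log(np.sum(np.exp(shifted), axis=axis, keepdims=True))

t_list = timeit.timeit(lambda: log_softmax(row.tolist()), number=200)  # old path
t_array = timeit.timeit(lambda: log_softmax(row), number=200)          # new path
print(f"list round trip: {t_list:.3f}s   ndarray direct: {t_array:.3f}s")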
llama_cpp/llama.py (+6, -7)
@@ -1552,7 +1552,7 @@ def logit_bias_processor(
                             self.detokenize(completion_tokens[:returned_tokens])
                         )
                         token_offset = len(prompt_tokens) + returned_tokens
-                        logits = self._scores[token_offset - 1, :].tolist()
+                        logits = self._scores[token_offset - 1, :]
                         current_logprobs = Llama.logits_to_logprobs(logits)
                         sorted_logprobs = list(
                             sorted(
@@ -1671,7 +1671,7 @@ def logit_bias_processor(
                         self.detokenize(completion_tokens[:returned_tokens])
                     )
                     token_offset = len(prompt_tokens) + returned_tokens - 1
-                    logits = self._scores[token_offset, :].tolist()
+                    logits = self._scores[token_offset, :]
                     current_logprobs = Llama.logits_to_logprobs(logits)
                     sorted_logprobs = list(
                         sorted(
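These first two hunks make the same change in the two streaming-logprobs paths: the row of scores is handed to Llama.logits_to_logprobs as-is instead of being flattened into a Python list first. One detail worth noting, sketched below under assumed shapes (not code from the repo): basic slicing such as self._scores[i, :] returns a view, so the new path copies nothing at all before the log-softmax runs.

import numpy as np

scores = np.zeros((4, 10), dtype=np.single)  # stand-in for Llama._scores
logits = scores[2, :]                        # basic slice: a view, no data copied
assert logits.base is scores                 # shares the matrix's buffer
scores[2, 5] = 1.0
assert logits[5] == 1.0                      # the view observes the update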
@@ -1785,9 +1785,8 @@ def logit_bias_processor(
                 self.detokenize([token]).decode("utf-8", errors="ignore")
                 for token in all_tokens
             ]
-            all_logprobs = [
-                Llama.logits_to_logprobs(row.tolist()) for row in self._scores
-            ][token_offset:]
+            all_logprobs = Llama.logits_to_logprobs(self._scores)[token_offset:]
+            # TODO: may be able to change this loop to use np.take_along_dim
             for token, token_str, logprobs_token in zip(
                 all_tokens, all_token_strs, all_logprobs
             ):
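The non-streaming path goes a step further: rather than mapping log-softmax over rows one at a time (with a list conversion per row), it now calls Llama.logits_to_logprobs once on the whole scores matrix. Since the reduction runs along axis=-1 by default, the single 2D call produces exactly the per-row results. A sketch of that equivalence, using scipy.special.log_softmax (the reference the code comment cites) as a stand-in:

import numpy as np
from scipy.special import log_softmax  # reference cited in logits_to_logprobs

scores = np.random.default_rng(1).standard_normal((8, 100)).astype(np.single)

per_row = np.stack([log_softmax(row) for row in scores])  # old: one row at a time
whole = log_softmax(scores, axis=-1)                      # new: one vectorized call
assert np.allclose(per_row, whole, atol=1e-6)

The TODO presumably refers to np.take_along_axis, numpy's name for the indexed gather that could replace the remaining per-token loop.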
@@ -2282,7 +2281,7 @@ def token_nl(self) -> int:
 
     @staticmethod
     def logits_to_logprobs(
-        logits: Union[List, npt.NDArray[np.single]], axis: int = -1
+        logits: Union[npt.NDArray[np.single], List], axis: int = -1
     ) -> npt.NDArray[np.single]:
         # https://docs.scipy.org/doc/scipy/reference/generated/scipy.special.log_softmax.html
         logits_maxs: np.ndarray = np.amax(logits, axis=axis, keepdims=True)
@@ -2293,7 +2292,7 @@ def logits_to_logprobs(
         subtract_maxs = np.subtract(logits, logits_maxs, dtype=np.single)
         exp = np.exp(subtract_maxs)
         # Suppress warnings about log of zero
-        with np.errstate(divide='ignore'):
+        with np.errstate(divide="ignore"):
             summed = np.sum(exp, axis=axis, keepdims=True)
             out = np.log(summed)
         return subtract_maxs - out
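The final two hunks touch logits_to_logprobs itself; the Union reordering just lists the now-preferred ndarray input first, and the quote change is cosmetic, so the math is untouched. That math is the standard numerically stable log-softmax: subtracting the row max keeps np.exp from overflowing, and the errstate guard silences the divide-by-zero warning np.log would emit if every shifted logit underflowed exp to zero. A self-contained check of both properties, with stand-in functions rather than the repo's code:

import numpy as np

def naive_log_softmax(x):
    e = np.exp(x)                    # overflows to inf for large logits
    return np.log(e / e.sum())

def stable_log_softmax(x):
    maxs = np.amax(x, axis=-1, keepdims=True)
    shifted = np.subtract(x, maxs, dtype=np.single)
    with np.errstate(divide="ignore"):   # guard mirrors the original code
        return shifted - np.log(np.sum(np.exp(shifted), axis=-1, keepdims=True))

small = np.array([1.0, 2.0, 3.0], dtype=np.single)
assert np.allclose(naive_log_softmax(small), stable_log_softmax(small), atol=1e-6)

big = np.array([1000.0, 1001.0], dtype=np.single)
assert np.isfinite(stable_log_softmax(big)).all()  # naive form would produce nan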
