Commit 03e2947 (1 parent: fafe471)

Fix unnecessary memory allocation while sampling
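In outline, the commit moves construction of the candidates buffer (a ctypes array with one llama_token_data entry per vocabulary token) out of the per-token _sample path and into __init__, where it is built once and then refilled in place on every call. Below is a minimal sketch of that allocate-once/refill-in-place pattern; TokenData and Sampler are illustrative stand-ins, not llama_cpp's actual types.

# Sketch of the allocate-once/refill-in-place pattern the commit adopts.
# TokenData and Sampler are hypothetical stand-ins, not llama_cpp's API.
import ctypes
from typing import Sequence


class TokenData(ctypes.Structure):
    _fields_ = [
        ("id", ctypes.c_int),
        ("logit", ctypes.c_float),
        ("p", ctypes.c_float),
    ]


class Sampler:
    def __init__(self, n_vocab: int) -> None:
        # Pay the O(n_vocab) ctypes allocation exactly once.
        self.n_vocab = n_vocab
        self.data = (TokenData * n_vocab)()

    def prepare(self, logits: Sequence[float]) -> None:
        # Per sampled token: overwrite fields of the existing structs
        # instead of constructing n_vocab new TokenData objects.
        for i, logit in enumerate(logits):
            self.data[i].id = i
            self.data[i].logit = logit
            self.data[i].p = 0.0


sampler = Sampler(n_vocab=4)
sampler.prepare([0.1, -1.2, 3.4, 0.0])
print([(d.id, round(d.logit, 1)) for d in sampler.data])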
1 file changed: +33 −21 lines

llama_cpp/llama.py (33 additions, 21 deletions)
@@ -176,6 +176,28 @@ def __init__(
 
         if self.verbose:
             print(llama_cpp.llama_print_system_info().decode("utf-8"), file=sys.stderr)
+
+
+        n_vocab = self.n_vocab()
+        n_ctx = self.n_ctx()
+        data = (llama_cpp.llama_token_data * n_vocab)(
+            *[
+                llama_cpp.llama_token_data(
+                    id=llama_cpp.llama_token(i),
+                    logit=llama_cpp.c_float(0.0),
+                    p=llama_cpp.c_float(0.0),
+                )
+                for i in range(n_vocab)
+            ]
+        )
+        size = llama_cpp.c_size_t(n_vocab)
+        sorted = False
+        candidates = llama_cpp.llama_token_data_array(
+            data=data,
+            size=size,
+            sorted=sorted,
+        )
+        self._candidates = candidates
 
     def tokenize(self, text: bytes, add_bos: bool = True) -> List[int]:
         """Tokenize a string.
@@ -296,33 +318,23 @@ def _sample(
     ):
         assert self.ctx is not None
         assert len(self.eval_logits) > 0
-        n_vocab = int(llama_cpp.llama_n_vocab(self.ctx))
-        n_ctx = int(llama_cpp.llama_n_ctx(self.ctx))
+        n_vocab = self.n_vocab()
+        n_ctx = self.n_ctx()
         top_k = llama_cpp.c_int(n_vocab) if top_k.value <= 0 else top_k
         last_n_tokens_size = (
             llama_cpp.c_int(n_ctx)
             if last_n_tokens_size.value < 0
             else last_n_tokens_size
         )
         logits = self.eval_logits[-1]
-        nl_logit = logits[int(Llama.token_nl())]
-        data = (llama_cpp.llama_token_data * n_vocab)(
-            *[
-                llama_cpp.llama_token_data(
-                    id=llama_cpp.llama_token(i),
-                    logit=logits[i],
-                    p=llama_cpp.c_float(0.0),
-                )
-                for i in range(n_vocab)
-            ]
-        )
-        size = llama_cpp.c_size_t(n_vocab)
-        sorted = False
-        candidates = llama_cpp.llama_token_data_array(
-            data=data,
-            size=size,
-            sorted=sorted,
-        )
+        nl_logit = logits[Llama.token_nl()]
+        candidates = self._candidates
+        for i, logit in enumerate(logits):
+            candidates.data[i].id = llama_cpp.llama_token(i)
+            candidates.data[i].logit = llama_cpp.c_float(logit)
+            candidates.data[i].p = llama_cpp.c_float(0.0)
+        candidates.sorted = llama_cpp.c_bool(False)
+        candidates.size = llama_cpp.c_size_t(n_vocab)
         llama_cpp.llama_sample_repetition_penalty(
             ctx=self.ctx,
             last_tokens_data=last_n_tokens_data,
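One subtlety in the reuse path, worth spelling out (this is an inference about why sorted and size are reset on every call, not something stated in the commit): llama.cpp's sampling filters mutate the candidates array in place, sorting it and shrinking size, so a recycled buffer must be restored to a full, unsorted state before each token. A self-contained sketch of that invariant, with a hypothetical top_k_filter standing in for the library calls:

# Sketch only: TokenData/TokenDataArray mirror the shape of llama.cpp's
# candidate structs, and top_k_filter is a hypothetical stand-in sampler.
import ctypes


class TokenData(ctypes.Structure):
    _fields_ = [
        ("id", ctypes.c_int),
        ("logit", ctypes.c_float),
        ("p", ctypes.c_float),
    ]


class TokenDataArray(ctypes.Structure):
    _fields_ = [
        ("data", ctypes.POINTER(TokenData)),
        ("size", ctypes.c_size_t),
        ("sorted", ctypes.c_bool),
    ]


def top_k_filter(arr: TokenDataArray, k: int) -> None:
    # Stand-in for a sampler: real filters sort the entries by logit and
    # then truncate; here only the in-place bookkeeping is modeled.
    arr.size = min(arr.size, k)
    arr.sorted = True


n_vocab = 8
buf = (TokenData * n_vocab)()  # long-lived buffer, allocated once
arr = TokenDataArray(data=buf, size=n_vocab, sorted=False)

top_k_filter(arr, k=3)
assert arr.size == 3 and arr.sorted

# Before reusing the buffer for the next token, restore the invariant
# "all n_vocab entries present, unsorted" -- which is what the commit's
# per-call resets of candidates.size and candidates.sorted do.
arr.size = n_vocab
arr.sorted = False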
@@ -339,7 +351,7 @@ def _sample(
             alpha_presence=presence_penalty,
         )
         if not penalize_nl:
-            candidates.data[int(Llama.token_nl())].logit = nl_logit
+            candidates.data[Llama.token_nl()].logit = nl_logit
         if temp.value == 0.0:
             return llama_cpp.llama_sample_token_greedy(
                 ctx=self.ctx,
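For a rough sense of what the change saves, here is a hypothetical micro-benchmark (not part of the commit) contrasting the old path, which builds n_vocab struct objects and a fresh ctypes array per sampled token, with refilling one preallocated buffer. Absolute numbers vary by machine and vocabulary size; the per-call construction is what the commit removes from the hot loop.

# Hypothetical timing sketch; N_VOCAB approximates a LLaMA vocabulary size.
import ctypes
import timeit


class TokenData(ctypes.Structure):
    _fields_ = [
        ("id", ctypes.c_int),
        ("logit", ctypes.c_float),
        ("p", ctypes.c_float),
    ]


N_VOCAB = 32_000


def build_each_call() -> None:
    # Old approach: n_vocab TokenData objects plus a fresh array per token.
    (TokenData * N_VOCAB)(
        *[TokenData(id=i, logit=0.0, p=0.0) for i in range(N_VOCAB)]
    )


BUF = (TokenData * N_VOCAB)()  # new approach: allocated once


def refill_buffer() -> None:
    # New approach: overwrite fields of the long-lived array in place.
    for i in range(N_VOCAB):
        BUF[i].id = i
        BUF[i].logit = 0.0
        BUF[i].p = 0.0


print("build each call:", timeit.timeit(build_each_call, number=20))
print("refill buffer  :", timeit.timeit(refill_buffer, number=20))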
