Commit 350a176

Update sampling api

1 parent 7837c3f
2 files changed: 113 additions & 28 deletions

llama_cpp/llama.py: 99 additions & 20 deletions
@@ -127,7 +127,9 @@ def __init__(
         self.last_n_tokens_size = last_n_tokens_size
         self.n_batch = min(n_ctx, n_batch)
         self.eval_tokens: Deque[llama_cpp.llama_token] = deque(maxlen=n_ctx)
-        self.eval_logits: Deque[List[float]] = deque(maxlen=n_ctx)
+        self.eval_logits: Deque[List[llama_cpp.c_float]] = deque(
+            maxlen=n_ctx if logits_all else 1
+        )
 
         self.cache: Optional[LlamaCache] = None
 
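Note the new bound: unless logits_all was set, only the logits row from the most recent eval step is retained. A minimal standalone sketch of that deque behavior (illustrative values only):

from collections import deque

# Illustration only: mirrors eval_logits' maxlen choice from the diff above.
logits_all = False
n_ctx = 512
eval_logits = deque(maxlen=n_ctx if logits_all else 1)

eval_logits.append([0.1, 0.2])  # row from eval step 1
eval_logits.append([0.3, 0.4])  # row from eval step 2 evicts the first

assert len(eval_logits) == 1
assert eval_logits[-1] == [0.3, 0.4]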
@@ -236,17 +238,90 @@ def eval(self, tokens: Sequence[llama_cpp.llama_token]):
             )
             if int(return_code) != 0:
                 raise RuntimeError(f"llama_eval returned {return_code}")
+            # Save tokens
             self.eval_tokens.extend(batch)
-            if self.params.logits_all:
-                n_vocab = llama_cpp.llama_n_vocab(self.ctx)
-                cols = int(n_vocab)
-                rows = n_tokens
-                logits_view = llama_cpp.llama_get_logits(self.ctx)
-                logits = [
-                    [logits_view[i * cols + j] for j in range(cols)]
-                    for i in range(rows)
-                ]
-                self.eval_logits.extend(logits)
+            # Save logits
+            rows = n_tokens if self.params.logits_all else 1
+            n_vocab = llama_cpp.llama_n_vocab(self.ctx)
+            cols = int(n_vocab)
+            logits_view = llama_cpp.llama_get_logits(self.ctx)
+            logits: List[List[llama_cpp.c_float]] = [
+                [logits_view[i * cols + j] for j in range(cols)] for i in range(rows)
+            ]
+            self.eval_logits.extend(logits)
+
+    def _sample_top_p_top_k(
+        self,
+        last_n_tokens_data,  # type: llama_cpp.Array[llama_cpp.llama_token]
+        last_n_tokens_size: llama_cpp.c_int,
+        top_k: llama_cpp.c_int,
+        top_p: llama_cpp.c_float,
+        temp: llama_cpp.c_float,
+        repeat_penalty: llama_cpp.c_float,
+    ):
+        assert self.ctx is not None
+        assert len(self.eval_logits) > 0
+        n_vocab = int(llama_cpp.llama_n_vocab(self.ctx))
+        logits = self.eval_logits[-1]
+        data = (llama_cpp.llama_token_data * n_vocab)(
+            *[
+                llama_cpp.llama_token_data(
+                    id=llama_cpp.llama_token(i),
+                    logit=logits[i],
+                    p=llama_cpp.c_float(0.0),
+                )
+                for i in range(n_vocab)
+            ]
+        )
+        size = llama_cpp.c_size_t(n_vocab)
+        sorted = False
+        candidates = llama_cpp.llama_token_data_array(
+            data=data,
+            size=size,
+            sorted=sorted,
+        )
+        llama_cpp.llama_sample_repetition_penalty(
+            ctx=self.ctx,
+            last_tokens_data=last_n_tokens_data,
+            last_tokens_size=last_n_tokens_size,
+            candidates=llama_cpp.ctypes.pointer(candidates),
+            penalty=repeat_penalty,
+        )
+        if temp == 0.0:
+            return llama_cpp.llama_sample_token_greedy(
+                ctx=self.ctx,
+                candidates=llama_cpp.ctypes.pointer(candidates),
+            )
+        else:
+            llama_cpp.llama_sample_top_k(
+                ctx=self.ctx,
+                candidates=llama_cpp.ctypes.pointer(candidates),
+                k=top_k,
+            )
+            llama_cpp.llama_sample_tail_free(
+                ctx=self.ctx,
+                candidates=llama_cpp.ctypes.pointer(candidates),
+                z=llama_cpp.c_float(1.0),
+            )
+            llama_cpp.llama_sample_typical(
+                ctx=self.ctx,
+                candidates=llama_cpp.ctypes.pointer(candidates),
+                p=llama_cpp.c_float(1.0)
+            )
+            llama_cpp.llama_sample_top_p(
+                ctx=self.ctx,
+                candidates=llama_cpp.ctypes.pointer(candidates),
+                p=top_p,
+            )
+            llama_cpp.llama_sample_temperature(
+                ctx=self.ctx,
+                candidates=llama_cpp.ctypes.pointer(candidates),
+                temp=temp,
+            )
+            return llama_cpp.llama_sample_token(
+                ctx=self.ctx,
+                candidates=llama_cpp.ctypes.pointer(candidates),
+            )
 
     def sample(
         self,
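The new _sample_top_p_top_k replaces the removed single llama_sample_top_p_top_k C call with a chain of the finer-grained samplers (tail-free and locally typical sampling are present but neutralized with z=1.0 and p=1.0). A rough pure-Python sketch of the same order of operations, with no ctypes and hypothetical helper names, only to make the pipeline explicit:

import math
import random

def softmax(logits):
    m = max(logits)
    exps = [math.exp(l - m) for l in logits]
    total = sum(exps)
    return [e / total for e in exps]

def sample_pipeline(logits, last_tokens, top_k, top_p, temp, repeat_penalty):
    # Candidates as (token_id, logit) pairs, mirroring llama_token_data_array.
    cands = sorted(enumerate(logits), key=lambda c: c[1], reverse=True)
    # 1. Repetition penalty on recently generated tokens.
    recent = set(last_tokens)
    cands = [
        (i, (l / repeat_penalty if l > 0 else l * repeat_penalty) if i in recent else l)
        for i, l in cands
    ]
    if temp == 0.0:
        return max(cands, key=lambda c: c[1])[0]  # greedy path
    # 2. Top-k filter. (Tail-free z=1.0 and typical p=1.0 keep everything.)
    cands = cands[:top_k]
    # 3. Top-p (nucleus): keep the smallest prefix reaching cumulative prob top_p.
    probs = softmax([l for _, l in cands])
    keep, cum = [], 0.0
    for (i, l), p in zip(cands, probs):
        keep.append((i, l))
        cum += p
        if cum >= top_p:
            break
    # 4. Temperature scaling, then draw from the renormalized distribution.
    weights = softmax([l / temp for _, l in keep])
    return random.choices([i for i, _ in keep], weights=weights, k=1)[0]

print(sample_pipeline([2.0, 1.0, 0.5, -1.0], [0], top_k=3, top_p=0.9, temp=0.8, repeat_penalty=1.1))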
@@ -270,8 +345,7 @@ def sample(
         last_n_tokens_data = [llama_cpp.llama_token(0)] * max(
             0, self.last_n_tokens_size - len(self.eval_tokens)
         ) + list(self.eval_tokens)[-self.last_n_tokens_size :]
-        return llama_cpp.llama_sample_top_p_top_k(
-            ctx=self.ctx,
+        return self._sample_top_p_top_k(
             last_n_tokens_data=(llama_cpp.llama_token * self.last_n_tokens_size)(
                 *last_n_tokens_data
             ),
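The (llama_cpp.llama_token * self.last_n_tokens_size)(*last_n_tokens_data) expression is plain ctypes array construction. A self-contained sketch of the pattern, assuming llama_token is a c_int alias as in llama_cpp.py:

import ctypes

llama_token = ctypes.c_int  # assumption: matches the alias in llama_cpp.py

tokens = [0, 42, 7]
TokenArray = llama_token * len(tokens)  # a fixed-size C array type
arr = TokenArray(*tokens)               # contiguous buffer, passable to C

assert list(arr) == [0, 42, 7]
assert ctypes.sizeof(arr) == len(tokens) * ctypes.sizeof(llama_token)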
@@ -470,15 +544,15 @@ def _create_completion(
             all_text = self.detokenize(completion_tokens)
 
             # Contains multi-byte UTF8
-            for k,char in enumerate(all_text[-3:]):
+            for k, char in enumerate(all_text[-3:]):
                 k = 3 - k
-                for num,pattern in [(2, 192), (3, 224), (4, 240)]:
+                for num, pattern in [(2, 192), (3, 224), (4, 240)]:
                     # Bitwise AND check
-                    if (num > k and pattern & char == pattern):
+                    if num > k and pattern & char == pattern:
                         multibyte_fix = num - k
 
             # Stop incomplete bytes from passing
-            if (multibyte_fix > 0):
+            if multibyte_fix > 0:
                 multibyte_fix -= 1
                 continue
 
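The (2, 192), (3, 224), (4, 240) pairs are the UTF-8 lead-byte patterns: 192 (0b11000000) opens a 2-byte sequence, 224 (0b11100000) a 3-byte one, and 240 (0b11110000) a 4-byte one, so the loop detects a multi-byte character truncated at the end of the buffer. A worked check of the bitmask logic:

# "é" is a 2-byte character; its lead byte matches the 2-byte pattern only.
lead = "é".encode("utf-8")[0]  # 0xC3

assert lead & 192 == 192       # lead byte of a sequence of at least 2 bytes
assert lead & 224 != 224       # not a 3-byte lead (0xC3 & 0xE0 == 0xC0)
assert lead & 240 != 240       # not a 4-byte lead

# With only the lead byte present, the code above sets multibyte_fix = num - k
# and skips emitting output until the remaining bytes of the sequence arrive.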
@@ -531,7 +605,9 @@ def _create_completion(
                     "model": self.model_path,
                     "choices": [
                         {
-                            "text": text[returned_characters:].decode("utf-8", errors="ignore"),
+                            "text": text[returned_characters:].decode(
+                                "utf-8", errors="ignore"
+                            ),
                             "index": 0,
                             "logprobs": None,
                             "finish_reason": finish_reason,
@@ -558,7 +634,8 @@ def _create_completion(
 
         all_tokens = prompt_tokens + completion_tokens
         all_token_strs = [
-            self.detokenize([token]).decode("utf-8", errors="ignore") for token in all_tokens
+            self.detokenize([token]).decode("utf-8", errors="ignore")
+            for token in all_tokens
         ]
         all_logprobs = [
             [Llama.logit_to_logprob(logit) for logit in row]
@@ -577,7 +654,9 @@ def _create_completion(
             )
             token_logprobs.append(sorted_logprobs[int(token)][0])
             top_logprob = {
-                self.detokenize([llama_cpp.llama_token(i)]).decode("utf-8", errors="ignore"): logprob
+                self.detokenize([llama_cpp.llama_token(i)]).decode(
+                    "utf-8", errors="ignore"
+                ): logprob
                 for logprob, i in sorted_logprobs[:logprobs]
             }
             top_logprob.update({token_str: sorted_logprobs[int(token)][0]})

llama_cpp/llama_cpp.py: 14 additions & 8 deletions
@@ -495,36 +495,40 @@ def llama_sample_softmax(ctx: llama_context_p, candidates):
 
 
 # @details Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
-def llama_sample_top_k(ctx: llama_context_p, candidates, k: c_int, min_keep: c_int):
+def llama_sample_top_k(
+    ctx: llama_context_p, candidates, k: c_int, min_keep: c_size_t = c_size_t(1)
+):
     return _lib.llama_sample_top_k(ctx, candidates, k, min_keep)
 
 
 _lib.llama_sample_top_k.argtypes = [
     llama_context_p,
     llama_token_data_array_p,
     c_int,
-    c_int,
+    c_size_t,
 ]
 _lib.llama_sample_top_k.restype = None
 
 
 # @details Nucleus sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
-def llama_sample_top_p(ctx: llama_context_p, candidates, p: c_float, min_keep: c_int):
+def llama_sample_top_p(
+    ctx: llama_context_p, candidates, p: c_float, min_keep: c_size_t = c_size_t(1)
+):
     return _lib.llama_sample_top_p(ctx, candidates, p, min_keep)
 
 
 _lib.llama_sample_top_p.argtypes = [
     llama_context_p,
     llama_token_data_array_p,
     c_float,
-    c_int,
+    c_size_t,
 ]
 _lib.llama_sample_top_p.restype = None
 
 
 # @details Tail Free Sampling described in https://www.trentonbricken.com/Tail-Free-Sampling/.
 def llama_sample_tail_free(
-    ctx: llama_context_p, candidates, z: c_float, min_keep: c_int
+    ctx: llama_context_p, candidates, z: c_float, min_keep: c_size_t = c_size_t(1)
 ):
     return _lib.llama_sample_tail_free(ctx, candidates, z, min_keep)
 
@@ -533,21 +537,23 @@ def llama_sample_tail_free(
     llama_context_p,
     llama_token_data_array_p,
     c_float,
-    c_int,
+    c_size_t,
 ]
 _lib.llama_sample_tail_free.restype = None
 
 
 # @details Locally Typical Sampling implementation described in the paper https://arxiv.org/abs/2202.00666.
-def llama_sample_typical(ctx: llama_context_p, candidates, p: c_float, min_keep: c_int):
+def llama_sample_typical(
+    ctx: llama_context_p, candidates, p: c_float, min_keep: c_size_t = c_size_t(1)
+):
     return _lib.llama_sample_typical(ctx, candidates, p, min_keep)
 
 
 _lib.llama_sample_typical.argtypes = [
     llama_context_p,
     llama_token_data_array_p,
     c_float,
-    c_int,
+    c_size_t,
 ]
 _lib.llama_sample_typical.restype = None
 
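The c_int to c_size_t change matches llama.cpp's size_t min_keep parameter: declaring the wrong argtype would marshal a 4-byte value where an 8-byte one is expected on most 64-bit platforms. A quick, runnable way to see the width difference (results are platform-dependent):

import ctypes

print(ctypes.sizeof(ctypes.c_int))     # typically 4
print(ctypes.sizeof(ctypes.c_size_t))  # typically 8 on 64-bit platforms

The new = c_size_t(1) defaults also let Python callers, like _sample_top_p_top_k above, omit min_keep entirely.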