Commit b3805bb

Implement logprobs parameter for text completion. Closes abetlen#2
1 parent 2a60eb8 commit b3805bb

File tree

2 files changed: +111, -16
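
For context, this change lets a caller request OpenAI-style logprobs from text completion, provided the model was loaded with logits_all=True. A minimal usage sketch follows (not part of the commit; the model path and the use of the public create_completion wrapper are assumptions):

# Hypothetical usage; the model path is a placeholder.
from llama_cpp import Llama

llm = Llama(model_path="./models/7B/ggml-model.bin", logits_all=True)
output = llm.create_completion(
    "Q: What is the capital of France? A:",
    max_tokens=16,
    logprobs=5,  # return the top-5 logprobs per generated token
)
print(output["choices"][0]["logprobs"]["top_logprobs"][0])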

llama_cpp/llama.py

+109, -16 (109 additions, 16 deletions)
@@ -2,6 +2,7 @@
 import sys
 import uuid
 import time
+import math
 import multiprocessing
 from typing import List, Optional, Union, Generator, Sequence, Iterator
 from collections import deque
@@ -76,6 +77,9 @@ def __init__(
         )
         self.tokens_consumed = 0
         self.n_batch = min(n_ctx, n_batch)
+        self.n_tokens = 0
+        self.n_past = 0
+        self.all_logits: List[List[float]] = []  # TODO: Use an array instead of a list.

         self.n_threads = n_threads or max(multiprocessing.cpu_count() // 2, 1)

@@ -136,6 +140,9 @@ def reset(self):
             [llama_cpp.llama_token(0)] * self.last_n_tokens_size
         )
         self.tokens_consumed = 0
+        self.n_tokens = 0
+        self.n_past = 0
+        self.all_logits = []

     def eval(self, tokens: Sequence[llama_cpp.llama_token]):
         """Evaluate a list of tokens.
@@ -147,18 +154,31 @@ def eval(self, tokens: Sequence[llama_cpp.llama_token]):
         n_ctx = int(llama_cpp.llama_n_ctx(self.ctx))
         for i in range(0, len(tokens), self.n_batch):
             batch = tokens[i : min(len(tokens), i + self.n_batch)]
-            n_past = min(n_ctx - len(batch), self.tokens_consumed)
+            self.n_past = min(n_ctx - len(batch), self.tokens_consumed)
+            self.n_tokens = len(batch)
             return_code = llama_cpp.llama_eval(
                 ctx=self.ctx,
                 tokens=(llama_cpp.llama_token * len(batch))(*batch),
-                n_tokens=llama_cpp.c_int(len(batch)),
-                n_past=llama_cpp.c_int(n_past),
+                n_tokens=llama_cpp.c_int(self.n_tokens),
+                n_past=llama_cpp.c_int(self.n_past),
                 n_threads=llama_cpp.c_int(self.n_threads),
             )
             if int(return_code) != 0:
                 raise RuntimeError(f"llama_eval returned {return_code}")
             self.last_n_tokens_data.extend(batch)
             self.tokens_consumed += len(batch)
+            if self.params.logits_all:
+                self.all_logits.extend(self._logits())
+
+    def _logits(self) -> List[List[float]]:
+        """Return the logits from the last call to llama_eval."""
+        assert self.ctx is not None
+        n_vocab = llama_cpp.llama_n_vocab(self.ctx)
+        cols = int(n_vocab)
+        rows = self.n_tokens if self.params.logits_all else 1
+        logits_view = llama_cpp.llama_get_logits(self.ctx)
+        logits = [[logits_view[i * cols + j] for j in range(cols)] for i in range(rows)]
+        return logits

     def sample(
         self,
@@ -327,14 +347,55 @@ def _create_completion(
         else:
             stop_sequences = []

-        finish_reason = None
-        for token in self.generate(
-            prompt_tokens,
-            top_k=top_k,
-            top_p=top_p,
-            temp=temperature,
-            repeat_penalty=repeat_penalty,
-        ):
+        text_offset = 0
+        text_offsets: List[int] = []
+        token_logprobs: List[float] = []
+        tokens: List[str] = []
+        top_logprobs: List[Dict[str, float]] = []
+
+        self.reset()
+        self.eval(prompt_tokens)
+
+        if logprobs is not None and self.params.logits_all is False:
+            raise ValueError(
+                "logprobs is not supported for models created with logits_all=False"
+            )
+
+        if logprobs is not None:
+            token_strs = [
+                self.detokenize([token]).decode("utf-8") for token in prompt_tokens
+            ]
+            logprobs_all = [
+                [Llama.logit_to_logprob(logit) for logit in row]
+                for row in self.all_logits
+            ]
+            for token, token_str, logprobs_token in zip(
+                prompt_tokens, token_strs, logprobs_all
+            ):
+                text_offsets.append(text_offset)
+                text_offset += len(token_str)
+                tokens.append(token_str)
+                sorted_logprobs = list(
+                    sorted(
+                        zip(logprobs_token, range(len(logprobs_token))), reverse=True
+                    )
+                )
+                token_logprobs.append(sorted_logprobs[int(token)][0])
+                top_logprob = {
+                    self.detokenize([llama_cpp.llama_token(i)]).decode("utf-8"): logprob
+                    for logprob, i in sorted_logprobs[:logprobs]
+                }
+                top_logprob.update({token_str: sorted_logprobs[int(token)][0]})
+                top_logprobs.append(top_logprob)
+
+        finish_reason = "length"
+        while True:
+            token = self.sample(
+                top_k=top_k,
+                top_p=top_p,
+                temp=temperature,
+                repeat_penalty=repeat_penalty,
+            )
             if token == llama_cpp.llama_token_eos():
                 text = self.detokenize(completion_tokens)
                 finish_reason = "stop"
@@ -377,13 +438,35 @@ def _create_completion(
                         }
                     ],
                 }
+
+            if logprobs is not None:
+                # TODO: Confirm whether this should happen before or after
+                # next eval.
+                token_str = self.detokenize([token]).decode("utf-8")
+                text_offsets.append(text_offset)
+                text_offset += len(token_str)
+                tokens.append(token_str)
+                logprobs_token = [
+                    Llama.logit_to_logprob(logit) for logit in self.all_logits[-1]
+                ]
+                sorted_logprobs = list(
+                    sorted(
+                        zip(logprobs_token, range(len(logprobs_token))), reverse=True
+                    )
+                )
+                token_logprobs.append(sorted_logprobs[int(token)][0])
+                top_logprob = {
+                    self.detokenize([llama_cpp.llama_token(i)]).decode("utf-8"): logprob
+                    for logprob, i in sorted_logprobs[:logprobs]
+                }
+                top_logprob.update({token_str: logprobs_token[int(token)]})
+                top_logprobs.append(top_logprob)
+
             if len(completion_tokens) >= max_tokens:
                 text = self.detokenize(completion_tokens)
                 finish_reason = "length"
                 break
-
-        if finish_reason is None:
-            finish_reason = "length"
+            self.eval([token])

         if stream:
             yield {
@@ -410,8 +493,14 @@ def _create_completion(
         if suffix is not None:
            text = text + suffix

+        logprobs_or_none: Optional[CompletionLogprobs] = None
         if logprobs is not None:
-            raise NotImplementedError("logprobs not implemented")
+            logprobs_or_none = {
+                "tokens": tokens,
+                "text_offset": text_offsets,
+                "token_logprobs": token_logprobs,
+                "top_logprobs": top_logprobs,
+            }

         if self.verbose:
             llama_cpp.llama_print_timings(self.ctx)
@@ -425,7 +514,7 @@ def _create_completion(
                 {
                     "text": text,
                     "index": 0,
-                    "logprobs": None,
+                    "logprobs": logprobs_or_none,
                     "finish_reason": finish_reason,
                 }
             ],
@@ -704,3 +793,7 @@ def token_eos() -> llama_cpp.llama_token:
     def token_bos() -> llama_cpp.llama_token:
         """Return the beginning-of-sequence token."""
         return llama_cpp.llama_token_bos()
+
+    @staticmethod
+    def logit_to_logprob(x: float) -> float:
+        return math.log(1.0 + math.exp(x))
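
The per-token bookkeeping above can be hard to follow inside the diff. Below is a simplified, self-contained sketch of what gets recorded for one sampled token, using a toy vocabulary and toy logits (all hypothetical) and a plain list lookup in place of detokenize(); it illustrates the commit's logit_to_logprob transform and top-N selection rather than reproducing its exact indexing:

import math

def logit_to_logprob(x: float) -> float:
    # Same transform as the new Llama.logit_to_logprob static method.
    return math.log(1.0 + math.exp(x))

vocab = ["a", "b", "c", "d"]        # toy 4-token vocabulary (hypothetical)
logits_row = [2.0, -1.0, 0.5, 3.0]  # toy logits for one position (hypothetical)
sampled_token = 3                   # toy sampled token id
logprobs = 2                        # as if the caller passed logprobs=2

logprobs_token = [logit_to_logprob(l) for l in logits_row]
# Sort (logprob, token_id) pairs so the highest values come first.
sorted_logprobs = sorted(zip(logprobs_token, range(len(logprobs_token))), reverse=True)

token_logprob = logprobs_token[sampled_token]                    # value recorded for the sampled token
top_logprob = {vocab[i]: lp for lp, i in sorted_logprobs[:logprobs]}
top_logprob[vocab[sampled_token]] = token_logprob                # sampled token is always included
print(token_logprob, top_logprob)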

llama_cpp/server/__main__.py

+2, -0 (2 additions, 0 deletions)
@@ -33,6 +33,7 @@ class Settings(BaseSettings):
     use_mlock: bool = False  # This causes a silent failure on platforms that don't support mlock (e.g. Windows), took forever to figure out...
     embedding: bool = True
     last_n_tokens_size: int = 64
+    logits_all: bool = False


 app = FastAPI(
@@ -52,6 +53,7 @@ class Settings(BaseSettings):
     f16_kv=settings.f16_kv,
     use_mlock=settings.use_mlock,
     embedding=settings.embedding,
+    logits_all=settings.logits_all,
     n_threads=settings.n_threads,
     n_batch=settings.n_batch,
     n_ctx=settings.n_ctx,
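
Because Settings is a pydantic BaseSettings class, the new flag can be toggled from the environment when launching the server, and the resulting /v1/completions route can then honor logprobs. The sketch below is not part of the commit: the LOGITS_ALL and MODEL environment variable names, the default localhost:8000 address, and the requests dependency are assumptions.

# Hypothetical client call; assumes the server was started with something like
#   MODEL=./models/7B/ggml-model.bin LOGITS_ALL=true python3 -m llama_cpp.server
# and is listening on localhost:8000.
import requests

resp = requests.post(
    "http://localhost:8000/v1/completions",
    json={"prompt": "Q: Name the planets. A:", "max_tokens": 16, "logprobs": 3},
)
print(resp.json()["choices"][0]["logprobs"])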
