Commit 6153baa

Clean up logprobs implementation

1 parent 26cc4ee

1 file changed: +39 -67 lines

llama_cpp/llama.py  (39 additions, 67 deletions)
@@ -351,55 +351,19 @@ def _create_completion(
         else:
             stop_sequences = []

-        text_offset = 0
-        text_offsets: List[int] = []
-        token_logprobs: List[float] = []
-        tokens: List[str] = []
-        top_logprobs: List[Dict[str, float]] = []
-
-        self.reset()
-        self.eval(prompt_tokens)
-
         if logprobs is not None and self.params.logits_all is False:
             raise ValueError(
                 "logprobs is not supported for models created with logits_all=False"
             )

-        if logprobs is not None:
-            token_strs = [
-                self.detokenize([token]).decode("utf-8") for token in prompt_tokens
-            ]
-            logprobs_all = [
-                [Llama.logit_to_logprob(logit) for logit in row]
-                for row in self.all_logits
-            ]
-            for token, token_str, logprobs_token in zip(
-                prompt_tokens, token_strs, logprobs_all
-            ):
-                text_offsets.append(text_offset)
-                text_offset += len(token_str)
-                tokens.append(token_str)
-                sorted_logprobs = list(
-                    sorted(
-                        zip(logprobs_token, range(len(logprobs_token))), reverse=True
-                    )
-                )
-                token_logprobs.append(sorted_logprobs[int(token)][0])
-                top_logprob = {
-                    self.detokenize([llama_cpp.llama_token(i)]).decode("utf-8"): logprob
-                    for logprob, i in sorted_logprobs[:logprobs]
-                }
-                top_logprob.update({token_str: sorted_logprobs[int(token)][0]})
-                top_logprobs.append(top_logprob)
-
         finish_reason = "length"
-        while True:
-            token = self.sample(
-                top_k=top_k,
-                top_p=top_p,
-                temp=temperature,
-                repeat_penalty=repeat_penalty,
-            )
+        for token in self.generate(
+            prompt_tokens,
+            top_k=top_k,
+            top_p=top_p,
+            temp=temperature,
+            repeat_penalty=repeat_penalty,
+        ):
             if token == llama_cpp.llama_token_eos():
                 text = self.detokenize(completion_tokens)
                 finish_reason = "stop"
@@ -443,34 +407,10 @@ def _create_completion(
                 ],
             }

-            if logprobs is not None:
-                # TODO: Confirm wether this should happen before or after
-                # next eval.
-                token_str = self.detokenize([token]).decode("utf-8")
-                text_offsets.append(text_offset)
-                text_offset += len(token_str)
-                tokens.append(token_str)
-                logprobs_token = [
-                    Llama.logit_to_logprob(logit) for logit in self.all_logits[-1]
-                ]
-                sorted_logprobs = list(
-                    sorted(
-                        zip(logprobs_token, range(len(logprobs_token))), reverse=True
-                    )
-                )
-                token_logprobs.append(sorted_logprobs[int(token)][0])
-                top_logprob = {
-                    self.detokenize([llama_cpp.llama_token(i)]).decode("utf-8"): logprob
-                    for logprob, i in sorted_logprobs[:logprobs]
-                }
-                top_logprob.update({token_str: logprobs_token[int(token)]})
-                top_logprobs.append(top_logprob)
-
             if len(completion_tokens) >= max_tokens:
                 text = self.detokenize(completion_tokens)
                 finish_reason = "length"
                 break
-            self.eval([token])

         if stream:
             yield {
@@ -499,6 +439,38 @@ def _create_completion(

         logprobs_or_none: Optional[CompletionLogprobs] = None
         if logprobs is not None:
+            text_offset = 0
+            text_offsets: List[int] = []
+            token_logprobs: List[float] = []
+            tokens: List[str] = []
+            top_logprobs: List[Dict[str, float]] = []
+
+            all_tokens = prompt_tokens + completion_tokens
+            all_token_strs = [
+                self.detokenize([token]).decode("utf-8") for token in all_tokens
+            ]
+            all_logprobs = [
+                [Llama.logit_to_logprob(logit) for logit in row]
+                for row in self.all_logits
+            ]
+            for token, token_str, logprobs_token in zip(
+                all_tokens, all_token_strs, all_logprobs
+            ):
+                text_offsets.append(text_offset)
+                text_offset += len(token_str)
+                tokens.append(token_str)
+                sorted_logprobs = list(
+                    sorted(
+                        zip(logprobs_token, range(len(logprobs_token))), reverse=True
+                    )
+                )
+                token_logprobs.append(sorted_logprobs[int(token)][0])
+                top_logprob = {
+                    self.detokenize([llama_cpp.llama_token(i)]).decode("utf-8"): logprob
+                    for logprob, i in sorted_logprobs[:logprobs]
+                }
+                top_logprob.update({token_str: sorted_logprobs[int(token)][0]})
+                top_logprobs.append(top_logprob)
             logprobs_or_none = {
                 "tokens": tokens,
                 "text_offset": text_offsets,

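Note on the third hunk: all logprob bookkeeping now happens in one post-processing pass over prompt_tokens + completion_tokens. The standalone sketch below mirrors the intent of that pass on toy data (hypothetical token ids, strings, and logprob rows; the patched code obtains these from self.detokenize and Llama.logit_to_logprob over self.all_logits): for each position it records the text offset, the sampled token's logprob, and a small dictionary of top-N alternatives.

from typing import Dict, List

# Hypothetical stand-ins for detokenized tokens and per-position logprob rows
# (toy vocabulary of size 4); in the patched code these come from the model.
token_ids: List[int] = [2, 0, 3]
token_strs: List[str] = ["Hel", "lo", "!"]
logprob_rows: List[List[float]] = [
    [-3.2, -0.9, -0.4, -2.7],
    [-0.2, -4.1, -2.5, -3.3],
    [-1.6, -2.2, -3.0, -0.3],
]
vocab_strs = {0: "lo", 1: " the", 2: "Hel", 3: "!"}  # toy id -> text map
top_n = 2  # plays the role of the `logprobs` request parameter

text_offset = 0
text_offsets: List[int] = []
token_logprobs: List[float] = []
tokens: List[str] = []
top_logprobs: List[Dict[str, float]] = []

for token, token_str, row in zip(token_ids, token_strs, logprob_rows):
    text_offsets.append(text_offset)      # character offset of this token
    text_offset += len(token_str)
    tokens.append(token_str)
    token_logprobs.append(row[token])     # logprob of the sampled token
    # Highest-probability alternatives, sorted as (logprob, token_id) pairs.
    sorted_row = sorted(zip(row, range(len(row))), reverse=True)
    top = {vocab_strs[i]: lp for lp, i in sorted_row[:top_n]}
    top[token_str] = row[token]           # always include the sampled token
    top_logprobs.append(top)

print(tokens)          # ['Hel', 'lo', '!']
print(text_offsets)    # [0, 3, 5]
print(token_logprobs)  # [-0.4, -0.2, -0.3]

The four lists built this way correspond to the fields packed into logprobs_or_none at the end of the hunk ("tokens" and "text_offset" are visible above; the remaining keys follow from the variable names).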