Commit 6ecf40c

Merge branch 'abetlen:main' into main
2 parents 0ede672 + c088a2b · commit 6ecf40c

File tree: 11 files changed, +744 / -313 lines

.github/workflows/test.yaml

3 additions & 3 deletions
@@ -26,7 +26,7 @@ jobs:
         python-version: ${{ matrix.python-version }}
     - name: Install dependencies
       run: |
-        python -m pip install --upgrade pip pytest cmake scikit-build setuptools
+        python -m pip install --upgrade pip pytest cmake scikit-build setuptools fastapi sse-starlette httpx uvicorn
         pip install . -v
     - name: Test with pytest
       run: |
@@ -49,7 +49,7 @@ jobs:
         python-version: ${{ matrix.python-version }}
     - name: Install dependencies
       run: |
-        python -m pip install --upgrade pip pytest cmake scikit-build setuptools
+        python -m pip install --upgrade pip pytest cmake scikit-build setuptools fastapi sse-starlette httpx uvicorn
         pip install . -v
     - name: Test with pytest
       run: |
@@ -72,7 +72,7 @@ jobs:
         python-version: ${{ matrix.python-version }}
     - name: Install dependencies
      run: |
-        python -m pip install --upgrade pip pytest cmake scikit-build setuptools
+        python -m pip install --upgrade pip pytest cmake scikit-build setuptools fastapi sse-starlette httpx uvicorn
         pip install . -v
     - name: Test with pytest
       run: |
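
The only change to the workflow is adding the server's runtime dependencies (fastapi, sse-starlette, httpx, uvicorn) to each job's install step so the server tests can run in CI. A minimal sanity-check sketch of what that enables, assuming nothing beyond the package names listed in the diff (the script itself is hypothetical, not part of the commit):

```python
# Hypothetical check: confirm the packages the workflow now installs are
# importable before pytest runs. Note sse-starlette imports as sse_starlette.
import importlib

for pkg in ("fastapi", "sse_starlette", "httpx", "uvicorn"):
    importlib.import_module(pkg)
print("server test dependencies import cleanly")
```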

llama_cpp/llama.py

99 additions & 20 deletions
@@ -127,7 +127,9 @@ def __init__(
         self.last_n_tokens_size = last_n_tokens_size
         self.n_batch = min(n_ctx, n_batch)
         self.eval_tokens: Deque[llama_cpp.llama_token] = deque(maxlen=n_ctx)
-        self.eval_logits: Deque[List[float]] = deque(maxlen=n_ctx)
+        self.eval_logits: Deque[List[llama_cpp.c_float]] = deque(
+            maxlen=n_ctx if logits_all else 1
+        )

         self.cache: Optional[LlamaCache] = None

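With this change the logits deque keeps every row only when logits_all is set; otherwise maxlen=1 retains just the most recent row, which is all the sampler needs and bounds memory to a single vocabulary-sized list. A standalone sketch of the standard-library behaviour being relied on (values are placeholders):

```python
# Illustration only (not llama_cpp code): a deque with maxlen=1 silently
# discards older entries, so only the latest logits row survives each eval.
from collections import deque

eval_logits = deque(maxlen=1)
for step in range(3):
    eval_logits.append([0.0, 0.1, 0.2, 0.3])  # stand-in for one row of vocab logits
print(len(eval_logits))  # 1
```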
@@ -236,17 +238,90 @@ def eval(self, tokens: Sequence[llama_cpp.llama_token]):
         )
         if int(return_code) != 0:
             raise RuntimeError(f"llama_eval returned {return_code}")
+        # Save tokens
         self.eval_tokens.extend(batch)
-        if self.params.logits_all:
-            n_vocab = llama_cpp.llama_n_vocab(self.ctx)
-            cols = int(n_vocab)
-            rows = n_tokens
-            logits_view = llama_cpp.llama_get_logits(self.ctx)
-            logits = [
-                [logits_view[i * cols + j] for j in range(cols)]
-                for i in range(rows)
-            ]
-            self.eval_logits.extend(logits)
+        # Save logits
+        rows = n_tokens if self.params.logits_all else 1
+        n_vocab = llama_cpp.llama_n_vocab(self.ctx)
+        cols = int(n_vocab)
+        logits_view = llama_cpp.llama_get_logits(self.ctx)
+        logits: List[List[llama_cpp.c_float]] = [
+            [logits_view[i * cols + j] for j in range(cols)] for i in range(rows)
+        ]
+        self.eval_logits.extend(logits)
+
+    def _sample_top_p_top_k(
+        self,
+        last_n_tokens_data,  # type: llama_cpp.Array[llama_cpp.llama_token]
+        last_n_tokens_size: llama_cpp.c_int,
+        top_k: llama_cpp.c_int,
+        top_p: llama_cpp.c_float,
+        temp: llama_cpp.c_float,
+        repeat_penalty: llama_cpp.c_float,
+    ):
+        assert self.ctx is not None
+        assert len(self.eval_logits) > 0
+        n_vocab = int(llama_cpp.llama_n_vocab(self.ctx))
+        logits = self.eval_logits[-1]
+        data = (llama_cpp.llama_token_data * n_vocab)(
+            *[
+                llama_cpp.llama_token_data(
+                    id=llama_cpp.llama_token(i),
+                    logit=logits[i],
+                    p=llama_cpp.c_float(0.0),
+                )
+                for i in range(n_vocab)
+            ]
+        )
+        size = llama_cpp.c_size_t(n_vocab)
+        sorted = False
+        candidates = llama_cpp.llama_token_data_array(
+            data=data,
+            size=size,
+            sorted=sorted,
+        )
+        llama_cpp.llama_sample_repetition_penalty(
+            ctx=self.ctx,
+            last_tokens_data=last_n_tokens_data,
+            last_tokens_size=last_n_tokens_size,
+            candidates=llama_cpp.ctypes.pointer(candidates),
+            penalty=repeat_penalty,
+        )
+        if temp == 0.0:
+            return llama_cpp.llama_sample_token_greedy(
+                ctx=self.ctx,
+                candidates=llama_cpp.ctypes.pointer(candidates),
+            )
+        else:
+            llama_cpp.llama_sample_top_k(
+                ctx=self.ctx,
+                candidates=llama_cpp.ctypes.pointer(candidates),
+                k=top_k,
+            )
+            llama_cpp.llama_sample_tail_free(
+                ctx=self.ctx,
+                candidates=llama_cpp.ctypes.pointer(candidates),
+                z=llama_cpp.c_float(1.0),
+            )
+            llama_cpp.llama_sample_typical(
+                ctx=self.ctx,
+                candidates=llama_cpp.ctypes.pointer(candidates),
+                p=llama_cpp.c_float(1.0)
+            )
+            llama_cpp.llama_sample_top_p(
+                ctx=self.ctx,
+                candidates=llama_cpp.ctypes.pointer(candidates),
+                p=top_p,
+            )
+            llama_cpp.llama_sample_temperature(
+                ctx=self.ctx,
+                candidates=llama_cpp.ctypes.pointer(candidates),
+                temp=temp,
+            )
+            return llama_cpp.llama_sample_token(
+                ctx=self.ctx,
+                candidates=llama_cpp.ctypes.pointer(candidates),
+            )

     def sample(
         self,
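
The new _sample_top_p_top_k helper copies the last saved logits row into a llama_token_data_array and applies the samplers in a fixed order: repetition penalty first, then either greedy selection (when temp == 0.0) or top-k, tail-free and typical sampling with neutral parameters (1.0), top-p, temperature, and the final draw. A rough pure-Python sketch of the top-k / top-p / temperature part of that pipeline, written over a plain list of logits rather than the llama.cpp C API; parameter values are placeholders:

```python
# Illustrative only: mirrors the filtering order used by the helper, skipping
# tail-free and typical sampling since they are called with neutral settings.
import math
import random

def sample_top_k_top_p(logits, top_k=40, top_p=0.95, temp=0.8):
    # top-k: keep the k highest-logit candidates
    cands = sorted(enumerate(logits), key=lambda kv: kv[1], reverse=True)[:top_k]
    # top-p: keep the smallest prefix whose softmax mass reaches top_p
    m = max(v for _, v in cands)
    exps = [(i, math.exp(v - m)) for i, v in cands]
    z = sum(e for _, e in exps)
    kept, cum = [], 0.0
    for i, e in exps:
        kept.append((i, logits[i]))
        cum += e / z
        if cum >= top_p:
            break
    # temperature scaling, then draw from the renormalised distribution
    weights = [math.exp((v - m) / temp) for _, v in kept]
    return random.choices([i for i, _ in kept], weights=weights, k=1)[0]

print(sample_top_k_top_p([0.1, 2.0, 1.5, -0.3, 0.7], top_k=3, top_p=0.9))
```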
@@ -270,8 +345,7 @@ def sample(
         last_n_tokens_data = [llama_cpp.llama_token(0)] * max(
             0, self.last_n_tokens_size - len(self.eval_tokens)
         ) + list(self.eval_tokens)[-self.last_n_tokens_size :]
-        return llama_cpp.llama_sample_top_p_top_k(
-            ctx=self.ctx,
+        return self._sample_top_p_top_k(
             last_n_tokens_data=(llama_cpp.llama_token * self.last_n_tokens_size)(
                 *last_n_tokens_data
             ),
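
sample() keeps its public signature and now delegates to the helper above instead of the removed llama_cpp.llama_sample_top_p_top_k binding. A hedged usage sketch of that public path; the model path and parameter values are placeholders, not taken from the diff:

```python
# Hypothetical usage: the call chain ends in the new _sample_top_p_top_k helper.
from llama_cpp import Llama

llm = Llama(model_path="./models/ggml-model.bin")  # placeholder model path
llm.eval(llm.tokenize(b"Q: Name the planets in the solar system. A:"))
token = llm.sample(top_k=40, top_p=0.95, temp=0.8, repeat_penalty=1.1)
print(llm.detokenize([token]))
```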
@@ -470,15 +544,15 @@ def _create_completion(
             all_text = self.detokenize(completion_tokens)

             # Contains multi-byte UTF8
-            for k,char in enumerate(all_text[-3:]):
+            for k, char in enumerate(all_text[-3:]):
                 k = 3 - k
-                for num,pattern in [(2, 192), (3, 224), (4, 240)]:
+                for num, pattern in [(2, 192), (3, 224), (4, 240)]:
                     # Bitwise AND check
-                    if (num > k and pattern & char == pattern):
+                    if num > k and pattern & char == pattern:
                         multibyte_fix = num - k

             # Stop incomplete bytes from passing
-            if (multibyte_fix > 0):
+            if multibyte_fix > 0:
                 multibyte_fix -= 1
                 continue

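The changes here are formatting cleanups, but the guard they touch is worth spelling out: UTF-8 lead bytes carry a length prefix (0b110xxxxx for 2-byte, 0b1110xxxx for 3-byte, 0b11110xxx for 4-byte sequences), so a lead byte among the last few generated bytes means the character is still incomplete and should not be emitted yet. A self-contained sketch of the same check (the function name is made up for illustration):

```python
# Standalone illustration of the bitwise test: 192, 224 and 240 are the
# prefixes of 2-, 3- and 4-byte UTF-8 lead bytes.
def pending_utf8_bytes(buf: bytes) -> int:
    pending = 0
    for k, char in enumerate(buf[-3:]):
        k = 3 - k  # distance of this byte from the end of the buffer
        for num, pattern in [(2, 192), (3, 224), (4, 240)]:
            if num > k and pattern & char == pattern:
                pending = num - k  # bytes still missing for this character
    return pending

print(pending_utf8_bytes("ab".encode() + "é".encode()[:1]))  # 1: lead byte, tail missing
print(pending_utf8_bytes("abé".encode()))                    # 0: character complete
```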
@@ -531,7 +605,9 @@ def _create_completion(
                     "model": self.model_path,
                     "choices": [
                         {
-                            "text": text[returned_characters:].decode("utf-8", errors="ignore"),
+                            "text": text[returned_characters:].decode(
+                                "utf-8", errors="ignore"
+                            ),
                             "index": 0,
                             "logprobs": None,
                             "finish_reason": finish_reason,
@@ -558,7 +634,8 @@ def _create_completion(

             all_tokens = prompt_tokens + completion_tokens
             all_token_strs = [
-                self.detokenize([token]).decode("utf-8", errors="ignore") for token in all_tokens
+                self.detokenize([token]).decode("utf-8", errors="ignore")
+                for token in all_tokens
             ]
             all_logprobs = [
                 [Llama.logit_to_logprob(logit) for logit in row]
@@ -577,7 +654,9 @@ def _create_completion(
                 )
                 token_logprobs.append(sorted_logprobs[int(token)][0])
                 top_logprob = {
-                    self.detokenize([llama_cpp.llama_token(i)]).decode("utf-8", errors="ignore"): logprob
+                    self.detokenize([llama_cpp.llama_token(i)]).decode(
+                        "utf-8", errors="ignore"
+                    ): logprob
                     for logprob, i in sorted_logprobs[:logprobs]
                 }
                 top_logprob.update({token_str: sorted_logprobs[int(token)][0]})
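
The last hunks are line-wrapping only; the structure being built is: for each generated position the per-token logprobs are sorted in descending order and the best `logprobs` entries are keyed by their decoded token text. A toy sketch of that shape with made-up token strings in place of real detokenized vocabulary entries:

```python
# Toy illustration of the top_logprob mapping's shape; token names are invented.
logprobs = 2  # number of alternatives requested per position
row = [-0.5, -2.3, -0.1, -4.0]  # stand-in logprobs for a 4-token vocabulary
sorted_logprobs = sorted(((lp, i) for i, lp in enumerate(row)), reverse=True)
top_logprob = {f"<tok{i}>": lp for lp, i in sorted_logprobs[:logprobs]}
print(top_logprob)  # {'<tok2>': -0.1, '<tok0>': -0.5}
```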

0 commit comments