Commit 0a7e05b

tests: don't mock sampling functions

1 parent d7388f1 · commit 0a7e05b

1 file changed: +27 -15 lines changed

tests/test_llama.py: 27 additions & 15 deletions
@@ -47,7 +47,6 @@ def test_llama_cpp_tokenization():
 @pytest.fixture
 def mock_llama(monkeypatch):
     def setup_mock(llama: llama_cpp.Llama, output_text: str):
-        llama.reset()
         n_vocab = llama.n_vocab()
         output_tokens = llama.tokenize(
             output_text.encode("utf-8"), add_bos=True, special=True
@@ -59,28 +58,41 @@ def mock_decode(ctx: llama_cpp.llama_context_p, batch: llama_cpp.llama_batch):
             nonlocal n
             nonlocal last_n_tokens
             # Test some basic invariants of this mocking technique
-            assert ctx == llama._ctx.ctx
-            assert llama.n_tokens == n
-            assert batch.n_tokens > 0
-            n += batch.n_tokens
+            assert ctx == llama._ctx.ctx, "context does not match mock_llama"
+            assert batch.n_tokens > 0, "no tokens in batch"
+            assert all(
+                batch.n_seq_id[i] == 1 for i in range(batch.n_tokens)
+            ), "n_seq >1 not supported by mock_llama"
+            assert all(
+                batch.seq_id[i][0] == 0 for i in range(batch.n_tokens)
+            ), "n_seq >1 not supported by mock_llama"
+            assert batch.logits[
+                batch.n_tokens - 1
+            ], "logits not allocated for last token"
+            # Update the mock context state
+            n = max(batch.pos[i] for i in range(batch.n_tokens)) + 1
             last_n_tokens = batch.n_tokens
             return 0
 
         def mock_get_logits(*args, **kwargs):
-            nonlocal last_n_tokens
-            size = n_vocab * last_n_tokens
-            return (llama_cpp.c_float * size)()
-
-        def mock_sample(*args, **kwargs):
             nonlocal n
-            if n < len(output_tokens):
-                return output_tokens[n]
-            else:
-                return llama.token_eos()
+            nonlocal last_n_tokens
+            assert n > 0, "mock_llama_decode not called"
+            assert last_n_tokens > 0, "mock_llama_decode not called"
+            logits = (llama_cpp.c_float * (last_n_tokens * n_vocab))(-100.0)
+            for logits_idx, output_idx in enumerate(
+                range(n - last_n_tokens + 1, n + 1)
+            ):
+                if output_idx < len(output_tokens):
+                    logits[
+                        logits_idx * last_n_tokens + output_tokens[output_idx]
+                    ] = 100.0
+                else:
+                    logits[logits_idx * last_n_tokens + llama.token_eos()] = 100.0
+            return logits
 
         monkeypatch.setattr("llama_cpp.llama_cpp.llama_decode", mock_decode)
         monkeypatch.setattr("llama_cpp.llama_cpp.llama_get_logits", mock_get_logits)
-        monkeypatch.setattr("llama_cpp.llama_cpp.llama_sample_token", mock_sample)
 
         return setup_mock

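For context, a minimal usage sketch of the updated fixture (not part of this commit): mock_llama now patches only llama_decode and llama_get_logits, so the library's real sampling code runs over the mocked logits, which put a high score (100.0) on the expected next token and boost EOS once the forced text is exhausted. The test name, model path, prompt, and assertion below are illustrative assumptions.

import llama_cpp

# Assumed vocab-only model path, for illustration only.
MODEL = "./vendor/llama.cpp/models/ggml-vocab-llama.gguf"

def test_completion_with_real_sampling(mock_llama):
    llama = llama_cpp.Llama(model_path=MODEL, vocab_only=True)

    prompt = "The quick brown fox"
    continuation = " jumps over the lazy dog."

    # Force decode/get_logits so that prompt + continuation is the
    # highest-scoring token sequence; sampling itself is not mocked.
    mock_llama(llama, prompt + continuation)

    completion = llama.create_completion(prompt, max_tokens=20)
    # Exact equality assumes the tokenizer round-trips this text cleanly.
    assert completion["choices"][0]["text"] == continuation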