Commit 22cedad

xu-song and abetlen authored

fix: Fix memory allocation of ndarray (abetlen#1704)

* Fix memory allocation of ndarray
* Add basic LlamaState tests
* Improve LlamaState test and fix rng / seed

Co-authored-by: Andrei <abetlen@gmail.com>

1 parent 9b64bb5 commit 22cedad
2 files changed (+64, -15 lines)
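This commit routes all sampler randomness through a single Llama._seed field: the constructor's seed argument initializes it, set_seed overwrites it, save_state/load_state round-trip it, and create_completion either honors an explicit seed or derives a fresh one from it. A minimal usage sketch of the affected public API (the model path and sampling parameters below are placeholders, not part of the commit):

    from llama_cpp import Llama

    llm = Llama(model_path="./model.gguf", seed=1337)  # placeholder path

    llm.set_seed(1337)        # now simply records the seed used by the sampler
    state = llm.save_state()  # the saved LlamaState now carries the seed

    a = llm.create_completion("Pick a number from 1 to 10?:\n", max_tokens=4)
    b = llm.create_completion("Pick a number from 1 to 10?:\n", max_tokens=4)

    llm.load_state(state)     # restores scores, input_ids, n_tokens, and the seed
    c = llm.create_completion("Pick a number from 1 to 10?:\n", max_tokens=4)

    # Per the new test below: a and b use different derived seeds, while c
    # replays the seed saved in `state` and should reproduce a.
    assert a["choices"][0]["text"] == c["choices"][0]["text"]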

llama_cpp/llama.py: 19 additions & 15 deletions
@@ -7,6 +7,7 @@
 import json
 import ctypes
 import typing
+import random
 import fnmatch
 import warnings
 import contextlib
@@ -301,9 +302,11 @@ def __init__(
         self.n_threads = n_threads or max(multiprocessing.cpu_count() // 2, 1)
         self.n_threads_batch = n_threads_batch or multiprocessing.cpu_count()

+        # Used by the sampler
+        self._seed = seed or llama_cpp.LLAMA_DEFAULT_SEED
+
         # Context Params
         self.context_params = llama_cpp.llama_context_default_params()
-        self.context_params.seed = seed
         self.context_params.n_ctx = n_ctx
         self.context_params.n_batch = self.n_batch
         self.context_params.n_threads = self.n_threads
@@ -613,8 +616,7 @@ def set_seed(self, seed: int):
         Args:
             seed: The random seed.
         """
-        # TODO: Fix this
-        # llama_cpp.llama_set_rng_seed(self._ctx.ctx, seed)
+        self._seed = seed

     def reset(self):
         """Reset the model state."""
@@ -672,7 +674,6 @@ def _init_sampler(
         penalize_nl: bool = True,
         logits_processor: Optional[LogitsProcessorList] = None,
         grammar: Optional[LlamaGrammar] = None,
-        seed: Optional[int] = None,
     ):
         sampler = internals.LlamaSampler()

@@ -715,22 +716,22 @@ def apply_func(token_data_array: llama_cpp.llama_token_data_array_p):

         if temp < 0.0:
             sampler.add_softmax()
-            sampler.add_dist(seed or llama_cpp.LLAMA_DEFAULT_SEED)
+            sampler.add_dist(self._seed)
         elif temp == 0.0:
             sampler.add_greedy()
         else:
             if mirostat_mode == 1:
                 mirostat_m = 100
                 sampler.add_mirostat(
                     self._n_vocab,
-                    seed or llama_cpp.LLAMA_DEFAULT_SEED,
+                    self._seed,
                     mirostat_tau,
                     mirostat_eta,
                     mirostat_m,
                 )
             elif mirostat_mode == 2:
                 sampler.add_mirostat_v2(
-                    seed or llama_cpp.LLAMA_DEFAULT_SEED,
+                    self._seed,
                     mirostat_tau,
                     mirostat_eta,
                 )
@@ -743,7 +744,7 @@ def apply_func(token_data_array: llama_cpp.llama_token_data_array_p):
             sampler.add_top_p(top_p, min_keep)
             sampler.add_min_p(min_p, min_keep)
             sampler.add_temp(temp)
-            sampler.add_dist(seed or llama_cpp.LLAMA_DEFAULT_SEED)
+            sampler.add_dist(self._seed)
         return sampler

     def sample(
@@ -826,7 +827,6 @@ def generate(
         logits_processor: Optional[LogitsProcessorList] = None,
         stopping_criteria: Optional[StoppingCriteriaList] = None,
         grammar: Optional[LlamaGrammar] = None,
-        seed: Optional[int] = None,
     ) -> Generator[int, Optional[Sequence[int]], None]:
         """Create a generator of tokens from a prompt.

@@ -865,7 +865,6 @@ def generate(
             penalize_nl=penalize_nl,
             logits_processor=logits_processor,
             grammar=grammar,
-            seed=seed,
         )

         # Check for kv cache prefix match
@@ -1301,9 +1300,10 @@ def logit_bias_processor(
         if self.verbose:
             print("Llama._create_completion: cache miss", file=sys.stderr)

-        # TODO: Fix this
-        # if seed is not None:
-        #     self._ctx.set_rng_seed(seed)
+        if seed is not None:
+            self.set_seed(seed)
+        else:
+            self.set_seed(random.Random(self._seed).randint(0, 2 ** 32))

         finish_reason = "length"
         multibyte_fix = 0
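The hunk above replaces the old TODO: an explicit seed argument to create_completion takes precedence, and otherwise each call advances the stored seed by drawing a new one from random.Random(self._seed). A small plain-Python sketch of that derivation, with a helper name that is mine, not the library's:

    import random

    def next_seed(current_seed: int) -> int:
        # Stand-in for: self.set_seed(random.Random(self._seed).randint(0, 2 ** 32))
        return random.Random(current_seed).randint(0, 2 ** 32)

    s0 = 1337
    s1 = next_seed(s0)  # seed the sampler actually uses for the first completion
    s2 = next_seed(s1)  # seed used by the following completion

    # The walk is deterministic: restoring s0 (as load_state does) replays s1.
    assert next_seed(s0) == s1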
@@ -1324,7 +1324,6 @@ def logit_bias_processor(
             stopping_criteria=stopping_criteria,
             logits_processor=logits_processor,
             grammar=grammar,
-            seed=seed,
         ):
             if llama_cpp.llama_token_is_eog(self._model.model, token):
                 text = self.detokenize(completion_tokens, prev_tokens=prompt_tokens)
@@ -2136,14 +2135,17 @@ def save_state(self) -> LlamaState:
             n_tokens=self.n_tokens,
             llama_state=bytes(llama_state_compact),
             llama_state_size=n_bytes,
+            seed=self._seed,
         )

     def load_state(self, state: LlamaState) -> None:
         # Only filling in up to `n_tokens` and then zero-ing out the rest
         self.scores[: state.n_tokens, :] = state.scores.copy()
-        self.scores[state.n_tokens :, :] = 0.0
+        rest = self.scores[state.n_tokens :, :]
+        rest[rest > 0] = 0.0
         self.input_ids = state.input_ids.copy()
         self.n_tokens = state.n_tokens
+        self._seed = state.seed
         state_size = state.llama_state_size
         LLamaStateArrayType = ctypes.c_uint8 * state_size
         llama_state = LLamaStateArrayType.from_buffer_copy(state.llama_state)
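The load_state change above is the ndarray memory fix named in the commit title, as I read it: assigning 0.0 across the whole tail slice writes every element and can force the lazily allocated pages behind the large pre-sized scores array to be committed, whereas the masked form reads the tail (cheap for never-touched pages) and writes only where a positive value is already present. A standalone NumPy sketch of the two forms (shape and values are illustrative only):

    import numpy as np

    scores = np.zeros((8, 16), dtype=np.float32)  # stand-in for Llama.scores
    scores[:3, :] = 1.5                           # rows that actually hold logits
    n_tokens = 3

    # Old form: writes 0.0 to every element past n_tokens.
    # scores[n_tokens:, :] = 0.0

    # New form: writes only where the tail already holds positive values.
    rest = scores[n_tokens:, :]
    rest[rest > 0] = 0.0

    assert not scores[n_tokens:, :].any()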
@@ -2321,12 +2323,14 @@ def __init__(
         n_tokens: int,
         llama_state: bytes,
         llama_state_size: int,
+        seed: int,
     ):
         self.input_ids = input_ids
         self.scores = scores
         self.n_tokens = n_tokens
         self.llama_state = llama_state
         self.llama_state_size = llama_state_size
+        self.seed = seed


 LogitsProcessor = Callable[

tests/test_llama.py: 45 additions & 0 deletions
@@ -171,3 +171,48 @@ def logit_processor_func(input_ids, logits):
         logits_processor=logit_processors
     )
     assert output["choices"][0]["text"].lower().startswith("rot")
+
+    model.set_seed(1337)
+
+    state = model.save_state()
+
+    output = model.create_completion(
+        "Pick a number from 1 to 10?:\n",
+        max_tokens=4,
+        top_k=50,
+        top_p=0.9,
+        temperature=0.8,
+        grammar=llama_cpp.LlamaGrammar.from_string("""
+root ::= "1" | "2" | "3" | "4" | "5" | "6" | "7" | "8" | "9" | "10"
+""")
+    )
+    number_1 = output["choices"][0]["text"]
+
+    output = model.create_completion(
+        "Pick a number from 1 to 10?:\n",
+        max_tokens=4,
+        top_k=50,
+        top_p=0.9,
+        temperature=0.8,
+        grammar=llama_cpp.LlamaGrammar.from_string("""
+root ::= "1" | "2" | "3" | "4" | "5" | "6" | "7" | "8" | "9" | "10"
+""")
+    )
+    number_2 = output["choices"][0]["text"]
+
+    model.load_state(state)
+
+    output = model.create_completion(
+        "Pick a number from 1 to 10?:\n",
+        max_tokens=4,
+        top_k=50,
+        top_p=0.9,
+        temperature=0.8,
+        grammar=llama_cpp.LlamaGrammar.from_string("""
+root ::= "1" | "2" | "3" | "4" | "5" | "6" | "7" | "8" | "9" | "10"
+""")
+    )
+    number_3 = output["choices"][0]["text"]
+
+    assert number_1 != number_2
+    assert number_1 == number_3
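A note on what the new test checks: number_1 and number_2 come from back-to-back completions that use different derived seeds, so they are expected to differ; load_state(state) restores the seed recorded by save_state right after set_seed(1337), so the third completion replays the first and number_1 == number_3.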
