diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py
index 7dd862515..80fadffe2 100644
--- a/llama_cpp/llama.py
+++ b/llama_cpp/llama.py
@@ -7,6 +7,7 @@
 import json
 import ctypes
 import typing
+import random
 import fnmatch
 import warnings
 import contextlib
@@ -301,9 +302,11 @@ def __init__(
         self.n_threads = n_threads or max(multiprocessing.cpu_count() // 2, 1)
         self.n_threads_batch = n_threads_batch or multiprocessing.cpu_count()
 
+        # Used by the sampler
+        self._seed = seed or llama_cpp.LLAMA_DEFAULT_SEED
+
         # Context Params
         self.context_params = llama_cpp.llama_context_default_params()
-        self.context_params.seed = seed
         self.context_params.n_ctx = n_ctx
         self.context_params.n_batch = self.n_batch
         self.context_params.n_threads = self.n_threads
@@ -613,8 +616,7 @@ def set_seed(self, seed: int):
         Args:
             seed: The random seed.
         """
-        # TODO: Fix this
-        # llama_cpp.llama_set_rng_seed(self._ctx.ctx, seed)
+        self._seed = seed
 
     def reset(self):
         """Reset the model state."""
@@ -672,7 +674,6 @@ def _init_sampler(
         penalize_nl: bool = True,
         logits_processor: Optional[LogitsProcessorList] = None,
         grammar: Optional[LlamaGrammar] = None,
-        seed: Optional[int] = None,
     ):
         sampler = internals.LlamaSampler()
 
@@ -715,7 +716,7 @@ def apply_func(token_data_array: llama_cpp.llama_token_data_array_p):
 
         if temp < 0.0:
             sampler.add_softmax()
-            sampler.add_dist(seed or llama_cpp.LLAMA_DEFAULT_SEED)
+            sampler.add_dist(self._seed)
         elif temp == 0.0:
             sampler.add_greedy()
         else:
@@ -723,14 +724,14 @@ def apply_func(token_data_array: llama_cpp.llama_token_data_array_p):
                 mirostat_m = 100
                 sampler.add_mirostat(
                     self._n_vocab,
-                    seed or llama_cpp.LLAMA_DEFAULT_SEED,
+                    self._seed,
                     mirostat_tau,
                     mirostat_eta,
                     mirostat_m,
                 )
             elif mirostat_mode == 2:
                 sampler.add_mirostat_v2(
-                    seed or llama_cpp.LLAMA_DEFAULT_SEED,
+                    self._seed,
                     mirostat_tau,
                     mirostat_eta,
                 )
@@ -743,7 +744,7 @@ def apply_func(token_data_array: llama_cpp.llama_token_data_array_p):
                 sampler.add_top_p(top_p, min_keep)
                 sampler.add_min_p(min_p, min_keep)
                 sampler.add_temp(temp)
-                sampler.add_dist(seed or llama_cpp.LLAMA_DEFAULT_SEED)
+                sampler.add_dist(self._seed)
         return sampler
 
     def sample(
@@ -826,7 +827,6 @@ def generate(
         logits_processor: Optional[LogitsProcessorList] = None,
         stopping_criteria: Optional[StoppingCriteriaList] = None,
         grammar: Optional[LlamaGrammar] = None,
-        seed: Optional[int] = None,
     ) -> Generator[int, Optional[Sequence[int]], None]:
         """Create a generator of tokens from a prompt.
 
@@ -865,7 +865,6 @@ def generate(
             penalize_nl=penalize_nl,
             logits_processor=logits_processor,
             grammar=grammar,
-            seed=seed,
         )
 
         # Check for kv cache prefix match
@@ -1301,9 +1300,10 @@ def logit_bias_processor(
         if self.verbose:
             print("Llama._create_completion: cache miss", file=sys.stderr)
 
-        # TODO: Fix this
-        # if seed is not None:
-        #     self._ctx.set_rng_seed(seed)
+        if seed is not None:
+            self.set_seed(seed)
+        else:
+            self.set_seed(random.Random(self._seed).randint(0, 2 ** 32))
 
         finish_reason = "length"
         multibyte_fix = 0
@@ -1324,7 +1324,6 @@ def logit_bias_processor(
             stopping_criteria=stopping_criteria,
             logits_processor=logits_processor,
             grammar=grammar,
-            seed=seed,
         ):
             if llama_cpp.llama_token_is_eog(self._model.model, token):
                 text = self.detokenize(completion_tokens, prev_tokens=prompt_tokens)
@@ -2136,14 +2135,17 @@ def save_state(self) -> LlamaState:
             n_tokens=self.n_tokens,
             llama_state=bytes(llama_state_compact),
             llama_state_size=n_bytes,
+            seed=self._seed,
         )
 
     def load_state(self, state: LlamaState) -> None:
         # Only filling in up to `n_tokens` and then zero-ing out the rest
         self.scores[: state.n_tokens, :] = state.scores.copy()
-        self.scores[state.n_tokens :, :] = 0.0
+        rest = self.scores[state.n_tokens :, :]
+        rest[rest > 0] = 0.0
         self.input_ids = state.input_ids.copy()
         self.n_tokens = state.n_tokens
+        self._seed = state.seed
         state_size = state.llama_state_size
         LLamaStateArrayType = ctypes.c_uint8 * state_size
         llama_state = LLamaStateArrayType.from_buffer_copy(state.llama_state)
@@ -2321,12 +2323,14 @@ def __init__(
         n_tokens: int,
         llama_state: bytes,
         llama_state_size: int,
+        seed: int,
     ):
         self.input_ids = input_ids
         self.scores = scores
         self.n_tokens = n_tokens
         self.llama_state = llama_state
         self.llama_state_size = llama_state_size
+        self.seed = seed
 
 
 LogitsProcessor = Callable[
diff --git a/tests/test_llama.py b/tests/test_llama.py
index cf134c2e7..fc182ae20 100644
--- a/tests/test_llama.py
+++ b/tests/test_llama.py
@@ -171,3 +171,48 @@ def logit_processor_func(input_ids, logits):
         logits_processor=logit_processors
     )
     assert output["choices"][0]["text"].lower().startswith("rot")
+
+    model.set_seed(1337)
+
+    state = model.save_state()
+
+    output = model.create_completion(
+        "Pick a number from 1 to 10?:\n",
+        max_tokens=4,
+        top_k=50,
+        top_p=0.9,
+        temperature=0.8,
+        grammar=llama_cpp.LlamaGrammar.from_string("""
+root ::= "1" | "2" | "3" | "4" | "5" | "6" | "7" | "8" | "9" | "10"
+""")
+    )
+    number_1 = output["choices"][0]["text"]
+
+    output = model.create_completion(
+        "Pick a number from 1 to 10?:\n",
+        max_tokens=4,
+        top_k=50,
+        top_p=0.9,
+        temperature=0.8,
+        grammar=llama_cpp.LlamaGrammar.from_string("""
+root ::= "1" | "2" | "3" | "4" | "5" | "6" | "7" | "8" | "9" | "10"
+""")
+    )
+    number_2 = output["choices"][0]["text"]
+
+    model.load_state(state)
+
+    output = model.create_completion(
+        "Pick a number from 1 to 10?:\n",
+        max_tokens=4,
+        top_k=50,
+        top_p=0.9,
+        temperature=0.8,
+        grammar=llama_cpp.LlamaGrammar.from_string("""
+root ::= "1" | "2" | "3" | "4" | "5" | "6" | "7" | "8" | "9" | "10"
+""")
+    )
+    number_3 = output["choices"][0]["text"]
+
+    assert number_1 != number_2
+    assert number_1 == number_3
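For reviewers, a minimal sketch (not part of the patch) of the seeding behavior this change is intended to produce; the model path and sampling parameters below are placeholders, not values taken from the diff:

# Sketch only: assumes a local GGUF model at ./model.gguf.
from llama_cpp import Llama

llm = Llama(model_path="./model.gguf", seed=1337)  # seed is now kept on the Llama object
state = llm.save_state()                           # the snapshot records the current seed

a = llm.create_completion("Pick a number from 1 to 10:\n", max_tokens=4, temperature=0.8)
# With no explicit seed argument, each completion derives a fresh seed from the stored
# one, so repeated calls are free to differ.
b = llm.create_completion("Pick a number from 1 to 10:\n", max_tokens=4, temperature=0.8)

llm.load_state(state)                              # restores tokens, scores, and the seed
c = llm.create_completion("Pick a number from 1 to 10:\n", max_tokens=4, temperature=0.8)

# Mirrors the new test: a and b usually differ, while c reproduces a.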