Commit aa203a0

Added mirostat sampling to the high level API.

1 parent: 2f2ea00

1 file changed: 83 additions, 1 deletion

llama_cpp/llama.py

@@ -257,6 +257,11 @@ def _sample_top_p_top_k(
         top_k: llama_cpp.c_int,
         top_p: llama_cpp.c_float,
         temp: llama_cpp.c_float,
+        mirostat_mode: llama_cpp.c_int,
+        mirostat_tau: llama_cpp.c_float,
+        mirostat_eta: llama_cpp.c_float,
+        mirostat_mu: llama_cpp.c_float,
+        mirostat_m: llama_cpp.c_int,
         repeat_penalty: llama_cpp.c_float,
     ):
         assert self.ctx is not None
@@ -287,7 +292,34 @@ def _sample_top_p_top_k(
             candidates=llama_cpp.ctypes.pointer(candidates),
             penalty=repeat_penalty,
         )
-        if float(temp.value) == 0.0:
+        if mirostat_mode == 1:
+            llama_cpp.llama_sample_temperature(
+                ctx=self.ctx,
+                candidates=llama_cpp.ctypes.pointer(candidates),
+                temp=temp,
+            )
+            llama_cpp.llama_sample_token_mirostat(
+                ctx=self.ctx,
+                candidates=llama_cpp.ctypes.pointer(candidates),
+                tau=mirostat_tau,
+                eta=mirostat_eta,
+                mu=mirostat_mu,
+                m=mirostat_m,
+            )
+        elif mirostat_mode == 2:
+            llama_cpp.llama_sample_temperature(
+                ctx=self.ctx,
+                candidates=llama_cpp.ctypes.pointer(candidates),
+                temp=temp,
+            )
+            llama_cpp.llama_sample_token_mirostat_v2(
+                ctx=self.ctx,
+                candidates=llama_cpp.ctypes.pointer(candidates),
+                tau=mirostat_tau,
+                eta=mirostat_eta,
+                mu=mirostat_mu,
+            )
+        elif float(temp.value) == 0.0:
             return llama_cpp.llama_sample_token_greedy(
                 ctx=self.ctx,
                 candidates=llama_cpp.ctypes.pointer(candidates),
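
For orientation, the feedback loop that llama_sample_token_mirostat_v2 performs natively can be sketched in pure Python. This is a rough illustration of the Mirostat v2 algorithm, not the binding's actual implementation; the function name mirostat_v2_step and the use of math/random here are ours:

# Rough pure-Python sketch of the Mirostat v2 feedback loop;
# illustrative only -- the real work happens in C inside
# llama_sample_token_mirostat_v2.
import math
import random

def mirostat_v2_step(probs, mu, tau=5.0, eta=0.1):
    """probs: strictly positive candidate probabilities (sum to 1).
    Returns (chosen_index, new_mu)."""
    # Surprise of each candidate in bits: s_i = -log2(p_i).
    surprise = [-math.log2(p) for p in probs]
    # Truncate candidates whose surprise exceeds mu; always keep the
    # least surprising token so the pool is never empty.
    kept = [i for i, s in enumerate(surprise) if s <= mu]
    if not kept:
        kept = [min(range(len(probs)), key=surprise.__getitem__)]
    # Renormalize over the kept pool and sample one token.
    total = sum(probs[i] for i in kept)
    r, acc, choice = random.random() * total, 0.0, kept[-1]
    for i in kept:
        acc += probs[i]
        if acc >= r:
            choice = i
            break
    # Feedback: nudge mu so the observed surprise tracks the target tau.
    mu -= eta * (surprise[choice] - tau)
    return choice, mu

Mirostat v1 works similarly but first estimates a Zipf exponent from the top mirostat_m candidates to pick the truncation point, which is why only the v1 call above passes m.
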
@@ -328,6 +360,11 @@ def sample(
         top_k: int,
         top_p: float,
         temp: float,
+        mirostat_mode: int,
+        mirostat_tau: float,
+        mirostat_eta: float,
+        mirostat_mu: float,
+        mirostat_m: int,
         repeat_penalty: float,
     ):
         """Sample a token from the model.
@@ -353,6 +390,11 @@ def sample(
             top_k=llama_cpp.c_int(top_k),
             top_p=llama_cpp.c_float(top_p),
             temp=llama_cpp.c_float(temp),
+            mirostat_mode=llama_cpp.c_int(mirostat_mode),
+            mirostat_mu=llama_cpp.c_float(mirostat_mu),
+            mirostat_tau=llama_cpp.c_float(mirostat_tau),
+            mirostat_eta=llama_cpp.c_float(mirostat_eta),
+            mirostat_m=llama_cpp.c_int(mirostat_m),
             repeat_penalty=llama_cpp.c_float(repeat_penalty),
         )
@@ -362,6 +404,11 @@ def generate(
         top_k: int,
         top_p: float,
         temp: float,
+        mirostat_mode: int,
+        mirostat_tau: float,
+        mirostat_eta: float,
+        mirostat_mu: float,
+        mirostat_m: int,
         repeat_penalty: float,
         reset: bool = True,
     ) -> Generator[
@@ -416,6 +463,11 @@ def generate(
             top_k=top_k,
             top_p=top_p,
             temp=temp,
+            mirostat_mode=mirostat_mode,
+            mirostat_tau=mirostat_tau,
+            mirostat_eta=mirostat_eta,
+            mirostat_mu=mirostat_mu,
+            mirostat_m=mirostat_m,
             repeat_penalty=repeat_penalty,
         )
         tokens_or_none = yield token
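
As a hedged usage sketch (not part of this diff), driving the updated generate signature could look like the following, assuming the existing Llama.tokenize/detokenize helpers; the model path, prompt, and parameter values are placeholders:

# Hypothetical caller of Llama.generate() with the new mirostat knobs.
import llama_cpp
from llama_cpp import Llama

llm = Llama(model_path="./models/7B/ggml-model.bin")
tokens = llm.tokenize(b"Q: What does mirostat control? A:")
for token in llm.generate(
    tokens,
    top_k=40,
    top_p=0.95,
    temp=0.8,
    mirostat_mode=2,    # 0 = disabled, 1 = mirostat, 2 = mirostat v2
    mirostat_tau=5.0,   # target surprise in bits
    mirostat_eta=0.1,   # learning rate for the mu update
    mirostat_mu=10.0,   # initial mu, conventionally 2 * tau
    mirostat_m=100,     # candidates for v1's estimate (unused by v2)
    repeat_penalty=1.1,
):
    if token == llama_cpp.llama_token_eos():
        break
    print(llm.detokenize([token]).decode("utf-8"), end="", flush=True)
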
@@ -486,6 +538,11 @@ def _create_completion(
         suffix: Optional[str] = None,
         max_tokens: int = 16,
         temperature: float = 0.8,
+        mirostat_mode: int = 0,
+        mirostat_tau: float = 5.0,
+        mirostat_eta: float = 0.1,
+        mirostat_mu: float = 10,
+        mirostat_m: int = 100,
         top_p: float = 0.95,
         logprobs: Optional[int] = None,
         echo: bool = False,
@@ -536,6 +593,11 @@ def _create_completion(
             top_k=top_k,
             top_p=top_p,
             temp=temperature,
+            mirostat_mode=mirostat_mode,
+            mirostat_tau=mirostat_tau,
+            mirostat_eta=mirostat_eta,
+            mirostat_mu=mirostat_mu,
+            mirostat_m=mirostat_m,
             repeat_penalty=repeat_penalty,
         ):
             if token == llama_cpp.llama_token_eos():
@@ -707,6 +769,11 @@ def create_completion(
         suffix: Optional[str] = None,
         max_tokens: int = 128,
         temperature: float = 0.8,
+        mirostat_mode: int = 0,
+        mirostat_tau: float = 5.0,
+        mirostat_eta: float = 0.1,
+        mirostat_mu: float = 10,
+        mirostat_m: int = 100,
         top_p: float = 0.95,
         logprobs: Optional[int] = None,
         echo: bool = False,
@@ -742,6 +809,11 @@ def create_completion(
             suffix=suffix,
             max_tokens=max_tokens,
             temperature=temperature,
+            mirostat_mode=mirostat_mode,
+            mirostat_tau=mirostat_tau,
+            mirostat_eta=mirostat_eta,
+            mirostat_mu=mirostat_mu,
+            mirostat_m=mirostat_m,
             top_p=top_p,
             logprobs=logprobs,
             echo=echo,
@@ -762,6 +834,11 @@ def __call__(
         suffix: Optional[str] = None,
         max_tokens: int = 128,
         temperature: float = 0.8,
+        mirostat_mode: int = 0,
+        mirostat_tau: float = 5.0,
+        mirostat_eta: float = 0.1,
+        mirostat_mu: float = 10,
+        mirostat_m: int = 100,
         top_p: float = 0.95,
         logprobs: Optional[int] = None,
         echo: bool = False,
@@ -797,6 +874,11 @@ def __call__(
             suffix=suffix,
             max_tokens=max_tokens,
             temperature=temperature,
+            mirostat_mode=mirostat_mode,
+            mirostat_tau=mirostat_tau,
+            mirostat_eta=mirostat_eta,
+            mirostat_mu=mirostat_mu,
+            mirostat_m=mirostat_m,
             top_p=top_p,
             logprobs=logprobs,
             echo=echo,
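
With this change the mirostat parameters flow all the way through create_completion and __call__, so a high-level caller can opt in without touching the sampler directly. A minimal sketch, assuming a local GGML model file (the path and prompt are placeholders):

from llama_cpp import Llama

llm = Llama(model_path="./models/7B/ggml-model.bin")
output = llm(
    "Q: Name the planets in the solar system. A:",
    max_tokens=64,
    mirostat_mode=2,   # enable mirostat v2; 0 (the default) keeps old behavior
    mirostat_tau=5.0,  # target surprise
    mirostat_eta=0.1,  # feedback learning rate
)
print(output["choices"][0]["text"])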