Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit 3babe35

Browse files
committed
Fix mirostat sampling
1 parent 141293a commit 3babe35
Copy full SHA for 3babe35

File tree

Expand file tree / Collapse file tree

1 file changed

+11
-2
lines changed
Filter options
Expand file tree / Collapse file tree

1 file changed

+11
-2
lines changed

‎llama_cpp/llama.py

Copy file name to clipboard · Expand all lines: llama_cpp/llama.py
+11 −2 lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -329,6 +329,8 @@ def __init__(
329329
(n_ctx, self._n_vocab), dtype=np.single
330330
)
331331

332+
self._mirostat_mu = ctypes.c_float(2.0 * 5.0) # TODO: Move this to sampling context
333+
332334
@property
333335
def ctx(self) -> llama_cpp.llama_context_p:
334336
assert self._ctx.ctx is not None
@@ -516,7 +518,7 @@ def sample(
516518
candidates=self._candidates,
517519
tau=mirostat_tau,
518520
eta=mirostat_eta,
519-
mu=2.0 * mirostat_tau,
521+
mu=ctypes.pointer(self._mirostat_mu),
520522
m=100,
521523
)
522524
elif mirostat_mode == 2:
@@ -525,7 +527,7 @@ def sample(
525527
candidates=self._candidates,
526528
tau=mirostat_tau,
527529
eta=mirostat_eta,
528-
mu=2.0 * mirostat_tau,
530+
mu=ctypes.pointer(self._mirostat_mu)
529531
)
530532
else:
531533
self._ctx.sample_top_k(candidates=self._candidates, k=top_k, min_keep=1)
@@ -581,6 +583,10 @@ def generate(
581583
Yields:
582584
The generated tokens.
583585
"""
586+
# Reset mirostat sampling
587+
self._mirostat_mu = ctypes.c_float(2.0 * mirostat_tau)
588+
589+
# Check for kv cache prefix match
584590
if reset and self.n_tokens > 0:
585591
longest_prefix = 0
586592
for a, b in zip(self._input_ids, tokens[:-1]):
@@ -595,12 +601,15 @@ def generate(
595601
tokens = tokens[longest_prefix:]
596602
self.n_tokens = longest_prefix
597603

604+
# Reset the model state
598605
if reset:
599606
self.reset()
600607

608+
# Reset the grammar
601609
if grammar is not None:
602610
grammar.reset()
603611

612+
# Eval and sample
604613
while True:
605614
self.eval(tokens)
606615
token = self.sample(

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.