@@ -329,6 +329,8 @@ def __init__(
             (n_ctx, self._n_vocab), dtype=np.single
         )

+        self._mirostat_mu = ctypes.c_float(2.0 * 5.0)  # TODO: Move this to sampling context
+
     @property
     def ctx(self) -> llama_cpp.llama_context_p:
         assert self._ctx.ctx is not None
@@ -516,7 +518,7 @@ def sample(
                 candidates=self._candidates,
                 tau=mirostat_tau,
                 eta=mirostat_eta,
-                mu=2.0 * mirostat_tau,
+                mu=ctypes.pointer(self._mirostat_mu),
                 m=100,
             )
         elif mirostat_mode == 2:
@@ -525,7 +527,7 @@ def sample(
                 candidates=self._candidates,
                 tau=mirostat_tau,
                 eta=mirostat_eta,
-                mu=2.0 * mirostat_tau,
+                mu=ctypes.pointer(self._mirostat_mu)
             )
         else:
             self._ctx.sample_top_k(candidates=self._candidates, k=top_k, min_keep=1)
@@ -581,6 +583,10 @@ def generate(
         Yields:
             The generated tokens.
         """
+        # Reset mirostat sampling
+        self._mirostat_mu = ctypes.c_float(2.0 * mirostat_tau)
+
+        # Check for kv cache prefix match
         if reset and self.n_tokens > 0:
             longest_prefix = 0
             for a, b in zip(self._input_ids, tokens[:-1]):
@@ -595,12 +601,15 @@ def generate(
                 tokens = tokens[longest_prefix:]
                 self.n_tokens = longest_prefix

+        # Reset the model state
         if reset:
             self.reset()

+        # Reset the grammar
         if grammar is not None:
             grammar.reset()

+        # Eval and sample
         while True:
             self.eval(tokens)
             token = self.sample(
0 commit comments