Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit d957422

Browse files
committed
Implement sampling as in llama.cpp main example
1 parent 93a9019 commit d957422
Copy full SHA for d957422

File tree

Expand file tree / Collapse file tree

1 file changed

+70
-80
lines changed
Filter options
Expand file tree / Collapse file tree

1 file changed

+70
-80
lines changed

‎llama_cpp/llama.py

Copy file name to clipboard · Expand all lines: llama_cpp/llama.py
+70 −80 · Lines changed: 70 additions & 80 deletions
Original file line number · Diff line number · Diff line change
@@ -268,14 +268,13 @@ def _sample(
268268
top_k: llama_cpp.c_int,
269269
top_p: llama_cpp.c_float,
270270
temp: llama_cpp.c_float,
271-
mirostat_mode: llama_cpp.c_int,
272-
mirostat_tau: llama_cpp.c_float,
273-
mirostat_eta: llama_cpp.c_float,
274-
mirostat_mu: llama_cpp.c_float,
275-
mirostat_m: llama_cpp.c_int,
271+
tfs_z: llama_cpp.c_float,
276272
repeat_penalty: llama_cpp.c_float,
277273
frequency_penalty: llama_cpp.c_float,
278274
presence_penalty: llama_cpp.c_float,
275+
mirostat_mode: llama_cpp.c_int,
276+
mirostat_tau: llama_cpp.c_float,
277+
mirostat_eta: llama_cpp.c_float,
279278
):
280279
assert self.ctx is not None
281280
assert len(self.eval_logits) > 0
@@ -305,45 +304,48 @@ def _sample(
305304
candidates=llama_cpp.ctypes.byref(candidates), # type: ignore
306305
penalty=repeat_penalty,
307306
)
308-
if mirostat_mode.value == 1:
307+
llama_cpp.llama_sample_frequency_and_presence_penalties(
308+
ctx=self.ctx,
309+
candidates=llama_cpp.ctypes.byref(candidates), # type: ignore
310+
last_tokens_data=last_n_tokens_data,
311+
last_tokens_size=last_n_tokens_size,
312+
alpha_frequency=frequency_penalty,
313+
alpha_presence=presence_penalty,
314+
)
315+
if temp.value == 0.0:
316+
return llama_cpp.llama_sample_token_greedy(
317+
ctx=self.ctx,
318+
candidates=llama_cpp.ctypes.byref(candidates), # type: ignore
319+
)
320+
elif mirostat_mode.value == 1:
321+
mirostat_mu = llama_cpp.c_float(2.0 * mirostat_tau.value)
322+
mirostat_m = llama_cpp.c_int(100)
309323
llama_cpp.llama_sample_temperature(
310324
ctx=self.ctx,
311-
candidates=llama_cpp.ctypes.byref(candidates), # type: ignore
325+
candidates=llama_cpp.ctypes.byref(candidates), # type: ignore
312326
temp=temp,
313327
)
314-
llama_cpp.llama_sample_token_mirostat(
328+
return llama_cpp.llama_sample_token_mirostat(
315329
ctx=self.ctx,
316-
candidates=llama_cpp.ctypes.byref(candidates), # type: ignore
330+
candidates=llama_cpp.ctypes.byref(candidates), # type: ignore
317331
tau=mirostat_tau,
318332
eta=mirostat_eta,
319-
mu=llama_cpp.ctypes.byref(mirostat_mu), # type: ignore
320-
m=mirostat_m
333+
mu=llama_cpp.ctypes.byref(mirostat_mu), # type: ignore
334+
m=mirostat_m,
321335
)
322336
elif mirostat_mode.value == 2:
337+
mirostat_mu = llama_cpp.c_float(2.0 * mirostat_tau.value)
323338
llama_cpp.llama_sample_temperature(
324339
ctx=self.ctx,
325340
candidates=llama_cpp.ctypes.pointer(candidates),
326341
temp=temp,
327342
)
328-
llama_cpp.llama_sample_token_mirostat_v2(
343+
return llama_cpp.llama_sample_token_mirostat_v2(
329344
ctx=self.ctx,
330-
candidates=llama_cpp.ctypes.byref(candidates), # type: ignore
345+
candidates=llama_cpp.ctypes.byref(candidates), # type: ignore
331346
tau=mirostat_tau,
332347
eta=mirostat_eta,
333-
mu=llama_cpp.ctypes.byref(mirostat_mu) # type: ignore
334-
)
335-
llama_cpp.llama_sample_frequency_and_presence_penalties(
336-
ctx=self.ctx,
337-
candidates=llama_cpp.ctypes.byref(candidates), # type: ignore
338-
last_tokens_data=last_n_tokens_data,
339-
last_tokens_size=last_n_tokens_size,
340-
alpha_frequency=frequency_penalty,
341-
alpha_presence=presence_penalty,
342-
)
343-
if float(temp.value) == 0.0:
344-
return llama_cpp.llama_sample_token_greedy(
345-
ctx=self.ctx,
346-
candidates=llama_cpp.ctypes.byref(candidates), # type: ignore
348+
mu=llama_cpp.ctypes.byref(mirostat_mu), # type: ignore
347349
)
348350
else:
349351
llama_cpp.llama_sample_top_k(
@@ -355,7 +357,7 @@ def _sample(
355357
llama_cpp.llama_sample_tail_free(
356358
ctx=self.ctx,
357359
candidates=llama_cpp.ctypes.byref(candidates), # type: ignore
358-
z=llama_cpp.c_float(1.0),
360+
z=tfs_z,
359361
min_keep=llama_cpp.c_size_t(1),
360362
)
361363
llama_cpp.llama_sample_typical(
@@ -382,17 +384,16 @@ def _sample(
382384

383385
def sample(
384386
self,
385-
top_k: int,
386-
top_p: float,
387-
temp: float,
388-
mirostat_mode: int,
389-
mirostat_tau: float,
390-
mirostat_eta: float,
391-
mirostat_mu: float,
392-
mirostat_m: int,
393-
repeat_penalty: float,
387+
top_k: int = 40,
388+
top_p: float = 0.95,
389+
temp: float = 0.80,
390+
repeat_penalty: float = 1.1,
394391
frequency_penalty: float = 0.0,
395392
presence_penalty: float = 0.0,
393+
tfs_z: float = 1.0,
394+
mirostat_mode: int = 0,
395+
mirostat_eta: float = 0.1,
396+
mirostat_tau: float = 5.0,
396397
):
397398
"""Sample a token from the model.
398399
@@ -417,14 +418,13 @@ def sample(
417418
top_k=llama_cpp.c_int(top_k),
418419
top_p=llama_cpp.c_float(top_p),
419420
temp=llama_cpp.c_float(temp),
420-
mirostat_mode=llama_cpp.c_int(mirostat_mode),
421-
mirostat_mu=llama_cpp.c_float(mirostat_mu),
422-
mirostat_tau=llama_cpp.c_float(mirostat_tau),
423-
mirostat_eta=llama_cpp.c_float(mirostat_eta),
424-
mirostat_m=llama_cpp.c_int(mirostat_m),
421+
tfs_z=llama_cpp.c_float(tfs_z),
425422
repeat_penalty=llama_cpp.c_float(repeat_penalty),
426423
frequency_penalty=llama_cpp.c_float(frequency_penalty),
427424
presence_penalty=llama_cpp.c_float(presence_penalty),
425+
mirostat_mode=llama_cpp.c_int(mirostat_mode),
426+
mirostat_tau=llama_cpp.c_float(mirostat_tau),
427+
mirostat_eta=llama_cpp.c_float(mirostat_eta),
428428
)
429429

430430
def generate(
@@ -433,15 +433,13 @@ def generate(
433433
top_k: int,
434434
top_p: float,
435435
temp: float,
436-
mirostat_mode: int,
437-
mirostat_tau: float,
438-
mirostat_eta: float,
439-
mirostat_mu: float,
440-
mirostat_m: int,
441436
repeat_penalty: float,
437+
reset: bool = True,
442438
frequency_penalty: float = 0.0,
443439
presence_penalty: float = 0.0,
444-
reset: bool = True,
440+
mirostat_mode: int = 0,
441+
mirostat_tau: float = 5.0,
442+
mirostat_eta: float = 0.1,
445443
) -> Generator[
446444
llama_cpp.llama_token, Optional[Sequence[llama_cpp.llama_token]], None
447445
]:
@@ -494,14 +492,12 @@ def generate(
494492
top_k=top_k,
495493
top_p=top_p,
496494
temp=temp,
495+
repeat_penalty=repeat_penalty,
496+
frequency_penalty=frequency_penalty,
497+
presence_penalty=presence_penalty,
497498
mirostat_mode=mirostat_mode,
498499
mirostat_tau=mirostat_tau,
499500
mirostat_eta=mirostat_eta,
500-
mirostat_mu=mirostat_mu,
501-
mirostat_m=mirostat_m,
502-
frequency_penalty=frequency_penalty,
503-
presence_penalty=presence_penalty,
504-
repeat_penalty=repeat_penalty,
505501
)
506502
tokens_or_none = yield token
507503
tokens = [token]
@@ -571,11 +567,6 @@ def _create_completion(
571567
suffix: Optional[str] = None,
572568
max_tokens: int = 16,
573569
temperature: float = 0.8,
574-
mirostat_mode: int = 0,
575-
mirostat_tau: float = 5.0,
576-
mirostat_eta: float = 0.1,
577-
mirostat_mu: float = 10,
578-
mirostat_m: int = 100,
579570
top_p: float = 0.95,
580571
logprobs: Optional[int] = None,
581572
echo: bool = False,
@@ -585,6 +576,9 @@ def _create_completion(
585576
repeat_penalty: float = 1.1,
586577
top_k: int = 40,
587578
stream: bool = False,
579+
mirostat_mode: int = 0,
580+
mirostat_tau: float = 5.0,
581+
mirostat_eta: float = 0.1,
588582
) -> Union[Iterator[Completion], Iterator[CompletionChunk]]:
589583
assert self.ctx is not None
590584
completion_id: str = f"cmpl-{str(uuid.uuid4())}"
@@ -643,8 +637,6 @@ def _create_completion(
643637
mirostat_mode=mirostat_mode,
644638
mirostat_tau=mirostat_tau,
645639
mirostat_eta=mirostat_eta,
646-
mirostat_mu=mirostat_mu,
647-
mirostat_m=mirostat_m,
648640
frequency_penalty=frequency_penalty,
649641
presence_penalty=presence_penalty,
650642
repeat_penalty=repeat_penalty,
@@ -817,11 +809,6 @@ def create_completion(
817809
suffix: Optional[str] = None,
818810
max_tokens: int = 128,
819811
temperature: float = 0.8,
820-
mirostat_mode: int = 0,
821-
mirostat_tau: float = 5.0,
822-
mirostat_eta: float = 0.1,
823-
mirostat_mu: float = 10,
824-
mirostat_m: int = 100,
825812
top_p: float = 0.95,
826813
logprobs: Optional[int] = None,
827814
echo: bool = False,
@@ -831,6 +818,9 @@ def create_completion(
831818
repeat_penalty: float = 1.1,
832819
top_k: int = 40,
833820
stream: bool = False,
821+
mirostat_mode: int = 0,
822+
mirostat_tau: float = 5.0,
823+
mirostat_eta: float = 0.1,
834824
) -> Union[Completion, Iterator[CompletionChunk]]:
835825
"""Generate text from a prompt.
836826
@@ -859,11 +849,6 @@ def create_completion(
859849
suffix=suffix,
860850
max_tokens=max_tokens,
861851
temperature=temperature,
862-
mirostat_mode=mirostat_mode,
863-
mirostat_tau=mirostat_tau,
864-
mirostat_eta=mirostat_eta,
865-
mirostat_mu=mirostat_mu,
866-
mirostat_m=mirostat_m,
867852
top_p=top_p,
868853
logprobs=logprobs,
869854
echo=echo,
@@ -873,6 +858,9 @@ def create_completion(
873858
repeat_penalty=repeat_penalty,
874859
top_k=top_k,
875860
stream=stream,
861+
mirostat_mode=mirostat_mode,
862+
mirostat_tau=mirostat_tau,
863+
mirostat_eta=mirostat_eta,
876864
)
877865
if stream:
878866
chunks: Iterator[CompletionChunk] = completion_or_chunks
@@ -886,11 +874,6 @@ def __call__(
886874
suffix: Optional[str] = None,
887875
max_tokens: int = 128,
888876
temperature: float = 0.8,
889-
mirostat_mode: int = 0,
890-
mirostat_tau: float = 5.0,
891-
mirostat_eta: float = 0.1,
892-
mirostat_mu: float = 10,
893-
mirostat_m: int = 100,
894877
top_p: float = 0.95,
895878
logprobs: Optional[int] = None,
896879
echo: bool = False,
@@ -900,6 +883,9 @@ def __call__(
900883
repeat_penalty: float = 1.1,
901884
top_k: int = 40,
902885
stream: bool = False,
886+
mirostat_mode: int = 0,
887+
mirostat_tau: float = 5.0,
888+
mirostat_eta: float = 0.1,
903889
) -> Union[Completion, Iterator[CompletionChunk]]:
904890
"""Generate text from a prompt.
905891
@@ -928,11 +914,6 @@ def __call__(
928914
suffix=suffix,
929915
max_tokens=max_tokens,
930916
temperature=temperature,
931-
mirostat_mode=mirostat_mode,
932-
mirostat_tau=mirostat_tau,
933-
mirostat_eta=mirostat_eta,
934-
mirostat_mu=mirostat_mu,
935-
mirostat_m=mirostat_m,
936917
top_p=top_p,
937918
logprobs=logprobs,
938919
echo=echo,
@@ -942,6 +923,9 @@ def __call__(
942923
repeat_penalty=repeat_penalty,
943924
top_k=top_k,
944925
stream=stream,
926+
mirostat_mode=mirostat_mode,
927+
mirostat_tau=mirostat_tau,
928+
mirostat_eta=mirostat_eta,
945929
)
946930

947931
def _convert_text_completion_to_chat(
@@ -1014,6 +998,9 @@ def create_chat_completion(
1014998
presence_penalty: float = 0.0,
1015999
frequency_penalty: float = 0.0,
10161000
repeat_penalty: float = 1.1,
1001+
mirostat_mode: int = 0,
1002+
mirostat_tau: float = 5.0,
1003+
mirostat_eta: float = 0.1,
10171004
) -> Union[ChatCompletion, Iterator[ChatCompletionChunk]]:
10181005
"""Generate a chat completion from a list of messages.
10191006
@@ -1048,6 +1035,9 @@ def create_chat_completion(
10481035
repeat_penalty=repeat_penalty,
10491036
presence_penalty=presence_penalty,
10501037
frequency_penalty=frequency_penalty,
1038+
mirostat_mode=mirostat_mode,
1039+
mirostat_tau=mirostat_tau,
1040+
mirostat_eta=mirostat_eta,
10511041
)
10521042
if stream:
10531043
chunks: Iterator[CompletionChunk] = completion_or_chunks # type: ignore

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.