Commit 7837c3f

Fix return types and import comments
1 parent 55d6308 commit 7837c3f
File tree

1 file changed: llama_cpp/llama_cpp.py (38 additions, 34 deletions)
@@ -427,13 +427,16 @@ def llama_token_nl() -> llama_token:


# Sampling functions
+
+
+# @details Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix.
def llama_sample_repetition_penalty(
    ctx: llama_context_p,
    candidates,
    last_tokens_data,
    last_tokens_size: c_int,
    penalty: c_float,
-) -> llama_token:
+):
    return _lib.llama_sample_repetition_penalty(
        ctx, candidates, last_tokens_data, last_tokens_size, penalty
    )
@@ -446,18 +449,18 @@ def llama_sample_repetition_penalty(
    c_int,
    c_float,
]
-_lib.llama_sample_repetition_penalty.restype = llama_token
+_lib.llama_sample_repetition_penalty.restype = None


-# LLAMA_API void llama_sample_frequency_and_presence_penalties(struct llama_context * ctx, llama_token_data_array * candidates, llama_token * last_tokens, size_t last_tokens_size, float alpha_frequency, float alpha_presence);
+# @details Frequency and presence penalties described in OpenAI API https://platform.openai.com/docs/api-reference/parameter-details.
def llama_sample_frequency_and_presence_penalties(
    ctx: llama_context_p,
    candidates,
    last_tokens_data,
    last_tokens_size: c_int,
    alpha_frequency: c_float,
    alpha_presence: c_float,
-) -> llama_token:
+):
    return _lib.llama_sample_frequency_and_presence_penalties(
        ctx,
        candidates,
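The recurring fix in these hunks is the return type: these C sampling functions return void, so the old `restype = llama_token` declarations and `-> llama_token` annotations were misleading, and a binding with no explicit restype would have ctypes assume a c_int result. Below is a minimal sketch of the corrected pattern, using llama_sample_softmax (from the next hunk) because it has the shortest signature; the pointer aliases and library loading here are stand-ins for illustration, not the real definitions from llama_cpp.py.

import ctypes

# Stand-ins for the ctypes aliases that llama_cpp.py defines near the top of the file.
llama_context_p = ctypes.c_void_p
llama_token_data_array_p = ctypes.c_void_p

# Illustrative only; llama_cpp.py locates and loads the shared library more carefully.
_lib = ctypes.CDLL("libllama.so")

# C side: void llama_sample_softmax(struct llama_context * ctx, llama_token_data_array * candidates);
_lib.llama_sample_softmax.argtypes = [llama_context_p, llama_token_data_array_p]
_lib.llama_sample_softmax.restype = None  # void: nothing meaningful comes back


def llama_sample_softmax(ctx: llama_context_p, candidates):
    # The candidates array is modified in place, so the wrapper returns nothing.
    _lib.llama_sample_softmax(ctx, candidates)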
@@ -476,25 +479,23 @@ def llama_sample_frequency_and_presence_penalties(
    c_float,
    c_float,
]
-_lib.llama_sample_frequency_and_presence_penalties.restype = llama_token
+_lib.llama_sample_frequency_and_presence_penalties.restype = None


-# LLAMA_API void llama_sample_softmax(struct llama_context * ctx, llama_token_data_array * candidates);
-def llama_sample_softmax(ctx: llama_context_p, candidates) -> llama_token:
+# @details Sorts candidate tokens by their logits in descending order and calculate probabilities based on logits.
+def llama_sample_softmax(ctx: llama_context_p, candidates):
    return _lib.llama_sample_softmax(ctx, candidates)


_lib.llama_sample_softmax.argtypes = [
    llama_context_p,
    llama_token_data_array_p,
]
-_lib.llama_sample_softmax.restype = llama_token
+_lib.llama_sample_softmax.restype = None


-# LLAMA_API void llama_sample_top_k(struct llama_context * ctx, llama_token_data_array * candidates, int k, size_t min_keep = 1);
-def llama_sample_top_k(
-    ctx: llama_context_p, candidates, k: c_int, min_keep: c_int
-) -> llama_token:
+# @details Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
+def llama_sample_top_k(ctx: llama_context_p, candidates, k: c_int, min_keep: c_int):
    return _lib.llama_sample_top_k(ctx, candidates, k, min_keep)
@@ -504,12 +505,11 @@ def llama_sample_top_k(
    c_int,
    c_int,
]
+_lib.llama_sample_top_k.restype = None


-# LLAMA_API void llama_sample_top_p(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep = 1);
-def llama_sample_top_p(
-    ctx: llama_context_p, candidates, p: c_float, min_keep: c_int
-) -> llama_token:
+# @details Nucleus sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
+def llama_sample_top_p(ctx: llama_context_p, candidates, p: c_float, min_keep: c_int):
    return _lib.llama_sample_top_p(ctx, candidates, p, min_keep)
@@ -519,13 +519,13 @@ def llama_sample_top_p(
    c_float,
    c_int,
]
-_lib.llama_sample_top_p.restype = llama_token
+_lib.llama_sample_top_p.restype = None


-# LLAMA_API void llama_sample_tail_free(struct llama_context * ctx, llama_token_data_array * candidates, float z, size_t min_keep = 1);
+# @details Tail Free Sampling described in https://www.trentonbricken.com/Tail-Free-Sampling/.
def llama_sample_tail_free(
    ctx: llama_context_p, candidates, z: c_float, min_keep: c_int
-) -> llama_token:
+):
    return _lib.llama_sample_tail_free(ctx, candidates, z, min_keep)
@@ -535,13 +535,11 @@ def llama_sample_tail_free(
    c_float,
    c_int,
]
-_lib.llama_sample_tail_free.restype = llama_token
+_lib.llama_sample_tail_free.restype = None


-# LLAMA_API void llama_sample_typical(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep = 1);
-def llama_sample_typical(
-    ctx: llama_context_p, candidates, p: c_float, min_keep: c_int
-) -> llama_token:
+# @details Locally Typical Sampling implementation described in the paper https://arxiv.org/abs/2202.00666.
+def llama_sample_typical(ctx: llama_context_p, candidates, p: c_float, min_keep: c_int):
    return _lib.llama_sample_typical(ctx, candidates, p, min_keep)
@@ -551,13 +549,10 @@ def llama_sample_typical(
    c_float,
    c_int,
]
-_lib.llama_sample_typical.restype = llama_token
+_lib.llama_sample_typical.restype = None


-# LLAMA_API void llama_sample_temperature(struct llama_context * ctx, llama_token_data_array * candidates, float temp);
-def llama_sample_temperature(
-    ctx: llama_context_p, candidates, temp: c_float
-) -> llama_token:
+def llama_sample_temperature(ctx: llama_context_p, candidates, temp: c_float):
    return _lib.llama_sample_temperature(ctx, candidates, temp)
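The @details comments added above correspond to the usual llama.cpp sampling pipeline: wrap the logits of the last evaluated position in a llama_token_data_array, narrow the candidates with top-k / top-p (optionally tail-free or locally typical) filtering, rescale with temperature, and finally draw a token. A rough usage sketch against these bindings follows; it assumes an already-evaluated context `ctx` plus the llama_token_data / llama_token_data_array structures and the llama_n_vocab / llama_get_logits bindings defined elsewhere in llama_cpp.py.

import ctypes
import llama_cpp


def sample_next_token(ctx) -> int:
    # Wrap the logits of the last evaluated position in a llama_token_data_array.
    n_vocab = llama_cpp.llama_n_vocab(ctx)
    logits = llama_cpp.llama_get_logits(ctx)
    data = (llama_cpp.llama_token_data * n_vocab)()
    for token_id in range(n_vocab):
        data[token_id].id = token_id
        data[token_id].logit = logits[token_id]
        data[token_id].p = 0.0
    candidates = llama_cpp.llama_token_data_array(
        data=ctypes.cast(data, ctypes.POINTER(llama_cpp.llama_token_data)),
        size=n_vocab,
        sorted=False,
    )
    candidates_p = ctypes.byref(candidates)

    # Each filter mutates the candidates array in place and returns nothing
    # (restype = None), which is what this commit declares.
    llama_cpp.llama_sample_top_k(ctx, candidates_p, 40, 1)
    llama_cpp.llama_sample_top_p(ctx, candidates_p, 0.95, 1)
    llama_cpp.llama_sample_temperature(ctx, candidates_p, 0.8)

    # Only the terminal call returns an actual llama_token.
    return llama_cpp.llama_sample_token(ctx, candidates_p)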

@@ -566,10 +561,15 @@ def llama_sample_temperature(
    llama_token_data_array_p,
    c_float,
]
-_lib.llama_sample_temperature.restype = llama_token
+_lib.llama_sample_temperature.restype = None


-# LLAMA_API llama_token llama_sample_token_mirostat(struct llama_context * ctx, llama_token_data_array * candidates, float tau, float eta, int m, float * mu);
+# @details Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
+# @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text.
+# @param tau The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text.
+# @param eta The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates.
+# @param m The number of tokens considered in the estimation of `s_hat`. This is an arbitrary value that is used to calculate `s_hat`, which in turn helps to calculate the value of `k`. In the paper, they use `m = 100`, but you can experiment with different values to see how it affects the performance of the algorithm.
+# @param mu Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal.
def llama_sample_token_mirostat(
    ctx: llama_context_p, candidates, tau: c_float, eta: c_float, m: c_int, mu
) -> llama_token:
@@ -587,7 +587,11 @@ def llama_sample_token_mirostat(
_lib.llama_sample_token_mirostat.restype = llama_token


-# LLAMA_API llama_token llama_sample_token_mirostat_v2(struct llama_context * ctx, llama_token_data_array * candidates, float tau, float eta, float * mu);
+# @details Mirostat 2.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
+# @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text.
+# @param tau The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text.
+# @param eta The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates.
+# @param mu Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal.
def llama_sample_token_mirostat_v2(
    ctx: llama_context_p, candidates, tau: c_float, eta: c_float, mu
) -> llama_token:
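Unlike the in-place filters above, both Mirostat samplers do return a llama_token, and their `mu` argument is a `float *` that the algorithm reads and updates on every call. On the Python side that means keeping a mutable c_float alive across sampling steps and passing it by reference. A small illustrative sketch, reusing the candidates construction from the previous example (the helper name and constants are assumptions, not part of the commit):

import ctypes
import llama_cpp

TAU = 5.0  # target surprise (cross-entropy)
ETA = 0.1  # learning rate

# mu starts at 2 * tau and is updated in place by the sampler on each call.
mirostat_mu = ctypes.c_float(2.0 * TAU)


def sample_mirostat_v2(ctx, candidates_p) -> int:
    return llama_cpp.llama_sample_token_mirostat_v2(
        ctx, candidates_p, TAU, ETA, ctypes.byref(mirostat_mu)
    )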
@@ -604,7 +608,7 @@ def llama_sample_token_mirostat_v2(
_lib.llama_sample_token_mirostat_v2.restype = llama_token


-# LLAMA_API llama_token llama_sample_token_greedy(struct llama_context * ctx, llama_token_data_array * candidates);
+# @details Selects the token with the highest probability.
def llama_sample_token_greedy(ctx: llama_context_p, candidates) -> llama_token:
    return _lib.llama_sample_token_greedy(ctx, candidates)
@@ -616,7 +620,7 @@ def llama_sample_token_greedy(ctx: llama_context_p, candidates) -> llama_token:
_lib.llama_sample_token_greedy.restype = llama_token


-# LLAMA_API llama_token llama_sample_token(struct llama_context * ctx, llama_token_data_array * candidates);
+# @details Randomly selects a token from the candidates based on their probabilities.
def llama_sample_token(ctx: llama_context_p, candidates) -> llama_token:
    return _lib.llama_sample_token(ctx, candidates)
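The last two functions are the terminal step of the pipeline: llama_sample_token_greedy deterministically picks the highest-probability candidate, while llama_sample_token draws from the filtered, temperature-scaled distribution. In the sampling sketch above, greedy decoding would simply swap the final call:

# Deterministic decoding: take the argmax candidate instead of sampling.
token = llama_cpp.llama_sample_token_greedy(ctx, candidates_p)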
