Commit f93f315

mscheong01 authored and hazelnutcloud committed

speculative : implement stochastic speculative sampling (ggml-org#5625)

* (WIP) Implement stochastic speculative decoding
* sample from residual distribution on draft accept failure
* fix ggml-org#5657: force greedy sampling with probs when temp is 0
* remove p_accept parameter
* fix style
* remove unused variables
* add srand() in speculative.cpp
* replace use of rand() with mt19937 sampling
* fixes based on review (@JohannesGaessler)
* fix r random generation
* randomly select next sequence to verify + fix bug in memory freeing
* fix bug in active_seqs sync
* fix uniform int distribution initialization
* remove warnings from comparison between int and size_t
* check grammar in `llama_sample_probability_distribution_impl`
* remove malloc code by utilizing vectors
* add PR link to README

1 parent d667ada commit f93f315
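The heart of this change is the standard stochastic speculative sampling rule: accept a drafted token x with probability min(1, p_target(x) / p_draft(x)); on rejection, draw the replacement from the normalized residual distribution max(0, p_target - p_draft). Below is a minimal self-contained sketch of that rule, using plain float vectors in place of llama.cpp's data structures (all names here are illustrative, not taken from the commit):

```cpp
#include <algorithm>
#include <random>
#include <vector>

// Sketch of the stochastic accept/resample step. p_tgt and p_drf are the
// target and draft model distributions at the same position (normalized),
// and `drafted` was sampled from p_drf, so p_drf[drafted] > 0.
int accept_or_resample(const std::vector<float> & p_tgt,
                       const std::vector<float> & p_drf,
                       int drafted, std::mt19937 & rng) {
    std::uniform_real_distribution<float> unif(0.0f, 1.0f);

    // accept with probability min(1, p_tgt/p_drf)
    if (unif(rng) < std::min(1.0f, p_tgt[drafted] / p_drf[drafted])) {
        return drafted;
    }

    // rejected: sample from the residual r(x) ∝ max(0, p_tgt(x) - p_drf(x));
    // rejection implies p_drf[drafted] > p_tgt[drafted], so r is non-degenerate
    std::vector<float> residual(p_tgt.size());
    for (size_t i = 0; i < residual.size(); ++i) {
        residual[i] = std::max(0.0f, p_tgt[i] - p_drf[i]);
    }
    std::discrete_distribution<int> pick(residual.begin(), residual.end());
    return pick(rng);
}
```

This scheme keeps the overall output distribution identical to sampling the target model directly, which is what lets the commit drop the fixed `p_accept` heuristic in the diffs below.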

File tree: 6 files changed (+256, -57 lines)

common/common.cpp (0 additions, 7 deletions)

```diff
@@ -513,12 +513,6 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
                 break;
             }
             params.n_sequences = std::stoi(argv[i]);
-        } else if (arg == "--p-accept" || arg == "-pa") {
-            if (++i >= argc) {
-                invalid_param = true;
-                break;
-            }
-            params.p_accept = std::stof(argv[i]);
         } else if (arg == "--p-split" || arg == "-ps") {
             if (++i >= argc) {
                 invalid_param = true;
@@ -1044,7 +1038,6 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     printf("  --chunks N            max number of chunks to process (default: %d, -1 = all)\n", params.n_chunks);
     printf("  -np N, --parallel N   number of parallel sequences to decode (default: %d)\n", params.n_parallel);
     printf("  -ns N, --sequences N  number of sequences to decode (default: %d)\n", params.n_sequences);
-    printf("  -pa N, --p-accept N   speculative decoding accept probability (default: %.1f)\n", (double)params.p_accept);
     printf("  -ps N, --p-split N    speculative decoding split probability (default: %.1f)\n", (double)params.p_split);
     printf("  -cb, --cont-batching  enable continuous batching (a.k.a dynamic batching) (default: disabled)\n");
     printf("  --mmproj MMPROJ_FILE  path to a multimodal projector file for LLaVA. see examples/llava/README.md\n");
```

common/common.h (1 addition, 2 deletions)

```diff
@@ -53,11 +53,10 @@ struct gpt_params {
     int32_t n_ctx       = 512;  // context size
     int32_t n_batch     = 512;  // batch size for prompt processing (must be >=32 to use BLAS)
     int32_t n_keep      = 0;    // number of tokens to keep from initial prompt
-    int32_t n_draft     = 8;    // number of tokens to draft during speculative decoding
+    int32_t n_draft     = 5;    // number of tokens to draft during speculative decoding
     int32_t n_chunks    = -1;   // max number of chunks to process (-1 = unlimited)
     int32_t n_parallel  = 1;    // number of parallel sequences to decode
     int32_t n_sequences = 1;    // number of sequences to decode
-    float   p_accept    = 0.5f; // speculative decoding accept probability
     float   p_split     = 0.1f; // speculative decoding split probability
     int32_t n_gpu_layers       = -1; // number of layers to store in VRAM (-1 - use default)
     int32_t n_gpu_layers_draft = -1; // number of layers to store in VRAM for the draft model (-1 - use default)
```
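For reference, `p_accept` previously acted as a fixed confidence cutoff while drafting, roughly like the following paraphrase (not the exact deleted code):

```cpp
// paraphrase of how the removed p_accept was used: stop extending the
// draft as soon as the draft model's top token falls below the threshold
if (cur_p[0].p < params.p_accept) {
    break;
}
```

With stochastic verification, each drafted token is instead tested against the target model's own probabilities, so a global threshold no longer has a role; the `n_draft` default dropping from 8 to 5 simply shortens the default draft window.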

common/sampling.cpp (79 additions, 0 deletions)

```diff
@@ -295,6 +295,77 @@ static llama_token llama_sampling_sample_impl(
     return id;
 }
 
+static llama_token_data_array llama_sample_probability_distribution_impl(
+                  struct llama_sampling_context * ctx_sampling,
+                  struct llama_context * ctx_main,
+                  struct llama_context * ctx_cfg,
+                  const int idx) {
+    const llama_sampling_params & params = ctx_sampling->params;
+
+    const int n_vocab = llama_n_vocab(llama_get_model(ctx_main));
+
+    const int32_t penalty_last_n  = params.penalty_last_n < 0 ? params.n_prev : params.penalty_last_n;
+    const float   penalty_repeat  = params.penalty_repeat;
+    const float   penalty_freq    = params.penalty_freq;
+    const float   penalty_present = params.penalty_present;
+    const bool    penalize_nl     = params.penalize_nl;
+
+    auto & prev = ctx_sampling->prev;
+    auto & cur  = ctx_sampling->cur;
+
+    // Get a pointer to the logits
+    float * logits = llama_get_logits_ith(ctx_main, idx);
+
+    // Declare original_logits at the beginning of the function scope
+    std::vector<float> original_logits;
+
+    // apply params.logit_bias map
+    for (auto it = params.logit_bias.begin(); it != params.logit_bias.end(); it++) {
+        logits[it->first] += it->second;
+    }
+
+    if (ctx_cfg) {
+        float * logits_guidance = llama_get_logits_ith(ctx_cfg, idx);
+        llama_sample_apply_guidance(ctx_main, logits, logits_guidance, params.cfg_scale);
+    }
+
+    cur.clear();
+
+    for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
+        cur.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
+    }
+
+    llama_token_data_array cur_p = { cur.data(), cur.size(), false };
+
+    // apply penalties
+    const auto & penalty_tokens = params.use_penalty_prompt_tokens ? params.penalty_prompt_tokens : prev;
+    const int penalty_tokens_used_size = std::min((int)penalty_tokens.size(), penalty_last_n);
+    if (penalty_tokens_used_size) {
+        const float nl_logit = logits[llama_token_nl(llama_get_model(ctx_main))];
+
+        llama_sample_repetition_penalties(ctx_main, &cur_p,
+                penalty_tokens.data() + penalty_tokens.size() - penalty_tokens_used_size,
+                penalty_tokens_used_size, penalty_repeat, penalty_freq, penalty_present);
+
+        if (!penalize_nl) {
+            for (size_t idx = 0; idx < cur_p.size; idx++) {
+                if (cur_p.data[idx].id == llama_token_nl(llama_get_model(ctx_main))) {
+                    cur_p.data[idx].logit = nl_logit;
+                    break;
+                }
+            }
+        }
+    }
+
+    // apply grammar checks
+    if (ctx_sampling->grammar != NULL) {
+        llama_sample_grammar(ctx_main, &cur_p, ctx_sampling->grammar);
+    }
+
+    llama_sample_softmax(ctx_main, &cur_p);
+    return cur_p;
+}
+
 llama_token llama_sampling_sample(
         struct llama_sampling_context * ctx_sampling,
         struct llama_context * ctx_main,
@@ -304,6 +375,14 @@ llama_token llama_sampling_sample(
     return llama_sampling_sample_impl(ctx_sampling, ctx_main, ctx_cfg, idx, false);
 }
 
+llama_token_data_array llama_sampling_probability_distribution(
+        struct llama_sampling_context * ctx_sampling,
+        struct llama_context * ctx_main,
+        struct llama_context * ctx_cfg,
+        const int idx) {
+    return llama_sample_probability_distribution_impl(ctx_sampling, ctx_main, ctx_cfg, idx);
+}
+
 void llama_sampling_accept(
         struct llama_sampling_context * ctx_sampling,
         struct llama_context * ctx_main,
```
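Sketched below is how a caller such as examples/speculative might consume this new API during verification; the helper, its name, and `p_drf` are illustrative assumptions, not code from the commit:

```cpp
#include <algorithm>
#include <random>

// Hypothetical helper: stochastic verification of one drafted token.
// p_drf is the probability the draft model assigned to id_drafted.
static bool accept_draft_token(llama_sampling_context * ctx_sampling,
                               llama_context * ctx_tgt,
                               llama_token id_drafted, float p_drf,
                               int idx, std::mt19937 & rng) {
    // softmax-normalized target distribution at position idx
    llama_token_data_array tgt =
        llama_sampling_probability_distribution(ctx_sampling, ctx_tgt, nullptr, idx);

    // look up the drafted token's probability under the target model
    float p_tgt = 0.0f;
    for (size_t i = 0; i < tgt.size; ++i) {
        if (tgt.data[i].id == id_drafted) {
            p_tgt = tgt.data[i].p;
            break;
        }
    }

    // accept with probability min(1, p_tgt / p_drf); on failure the caller
    // samples the replacement from the residual distribution instead
    std::uniform_real_distribution<float> unif(0.0f, 1.0f);
    return unif(rng) < std::min(1.0f, p_tgt / p_drf);
}
```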

common/sampling.h (7 additions, 0 deletions)

```diff
@@ -131,6 +131,13 @@ llama_token llama_sampling_sample(
         struct llama_context * ctx_cfg,
         int idx = 0);
 
+// returns the probability distribution of the next token (the probability that each token will be sampled)
+llama_token_data_array llama_sampling_probability_distribution(
+        struct llama_sampling_context * ctx_sampling,
+        struct llama_context * ctx_main,
+        struct llama_context * ctx_cfg,
+        int idx = 0);
+
 void llama_sampling_accept(
         struct llama_sampling_context * ctx_sampling,
         struct llama_context * ctx_main,
```

examples/speculative/README.md (1 addition, 0 deletions)

```diff
@@ -6,3 +6,4 @@ More info:
 
 - https://github.com/ggerganov/llama.cpp/pull/2926
 - https://github.com/ggerganov/llama.cpp/pull/3624
+- https://github.com/ggerganov/llama.cpp/pull/5625
```
