Commit 68cf1e5

mscheong01 authored and tybalex committed

sampling : deduplicated code for probability distribution access (ggml-org#6240)

* sampling: remove duplicated code for probability distribution access
* free original_logits
* fix original_logits allocation
* fixes based on review @cebtenzzre
* change function name to `llama_sampling_prepare`

1 parent: bd69ff2
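For context, here is a minimal sketch (not part of the commit) of how a caller can recover the behavior of the old llama_sampling_probability_distribution with the renamed helper. The signature is taken from the common/sampling.h diff below; the helper name get_token_distribution is illustrative, and the surrounding setup (loading a model, creating ctx_sampling and ctx_main) is assumed to exist elsewhere:

#include "sampling.h" // common/sampling.h in the llama.cpp tree

#include <vector>

// Sketch: build a normalized token distribution at batch index `idx`.
static llama_token_data_array get_token_distribution(
        struct llama_sampling_context * ctx_sampling,
        struct llama_context          * ctx_main,
        int                             idx) {
    std::vector<float> original_logits; // receives a copy of the raw logits

    // Applies logit biases, repetition penalties and grammar constraints
    // to the candidate list.
    llama_token_data_array cur_p = llama_sampling_prepare(
            ctx_sampling, ctx_main, /*ctx_cfg=*/NULL, idx,
            /*apply_grammar=*/true, &original_logits);

    // The softmax call was removed from the helper (see the sampling.cpp
    // diff below), so normalize explicitly when probabilities are needed.
    llama_sample_softmax(ctx_main, &cur_p);

    return cur_p;
}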

4 files changed (+28, -76)

common/sampling.cpp (+21, -72)

@@ -168,76 +168,19 @@ static llama_token llama_sampling_sample_impl(
                   bool is_resampling) {  // Add a parameter to indicate if we are resampling
     const llama_sampling_params & params = ctx_sampling->params;
 
-    const int n_vocab = llama_n_vocab(llama_get_model(ctx_main));
-
     const float   temp            = params.temp;
-    const int32_t penalty_last_n  = params.penalty_last_n < 0 ? params.n_prev : params.penalty_last_n;
-    const float   penalty_repeat  = params.penalty_repeat;
-    const float   penalty_freq    = params.penalty_freq;
-    const float   penalty_present = params.penalty_present;
     const int     mirostat        = params.mirostat;
     const float   mirostat_tau    = params.mirostat_tau;
     const float   mirostat_eta   = params.mirostat_eta;
-    const bool    penalize_nl     = params.penalize_nl;
 
-    auto & prev = ctx_sampling->prev;
-    auto & cur  = ctx_sampling->cur;
-
-    llama_token id = 0;
-
-    // Get a pointer to the logits
-    float * logits = llama_get_logits_ith(ctx_main, idx);
-
-    // Declare original_logits at the beginning of the function scope
     std::vector<float> original_logits;
-
+    auto cur_p = llama_sampling_prepare(ctx_sampling, ctx_main, ctx_cfg, idx, !is_resampling, &original_logits);
     if (!is_resampling) {
-        // Only make a copy of the original logits if we are not in the resampling phase, not sure if I actually have to do this.
-        original_logits = std::vector<float>(logits, logits + llama_n_vocab(llama_get_model(ctx_main)));
-    }
-
-    // apply params.logit_bias map
-    for (auto it = params.logit_bias.begin(); it != params.logit_bias.end(); it++) {
-        logits[it->first] += it->second;
-    }
-
-    if (ctx_cfg) {
-        float * logits_guidance = llama_get_logits_ith(ctx_cfg, idx);
-        llama_sample_apply_guidance(ctx_main, logits, logits_guidance, params.cfg_scale);
-    }
-
-    cur.clear();
-
-    for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
-        cur.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
-    }
-
-    llama_token_data_array cur_p = { cur.data(), cur.size(), false };
-
-    // apply penalties
-    const auto& penalty_tokens = params.use_penalty_prompt_tokens ? params.penalty_prompt_tokens : prev;
-    const int penalty_tokens_used_size = std::min((int)penalty_tokens.size(), penalty_last_n);
-    if (penalty_tokens_used_size) {
-        const float nl_logit = logits[llama_token_nl(llama_get_model(ctx_main))];
-
-        llama_sample_repetition_penalties(ctx_main, &cur_p,
-                penalty_tokens.data() + penalty_tokens.size() - penalty_tokens_used_size,
-                penalty_tokens_used_size, penalty_repeat, penalty_freq, penalty_present);
-
-        if (!penalize_nl) {
-            for (size_t idx = 0; idx < cur_p.size; idx++) {
-                if (cur_p.data[idx].id == llama_token_nl(llama_get_model(ctx_main))) {
-                    cur_p.data[idx].logit = nl_logit;
-                    break;
-                }
-            }
-        }
-    }
-
-    // If we are in the resampling phase, apply grammar checks before sampling logic
-    if (is_resampling && ctx_sampling->grammar != NULL) {
-        llama_sample_grammar(ctx_main, &cur_p, ctx_sampling->grammar);
+        GGML_ASSERT(!original_logits.empty());
     }
+    llama_token id = 0;
+    // Get a pointer to the logits
+    float * logits = llama_get_logits_ith(ctx_main, idx);
 
     if (temp < 0.0) {
         // greedy sampling, with probs
@@ -302,11 +245,13 @@ static llama_token llama_sampling_sample_impl(
     return id;
 }
 
-static llama_token_data_array llama_sample_probability_distribution_impl(
+static llama_token_data_array llama_sampling_prepare_impl(
                   struct llama_sampling_context * ctx_sampling,
                   struct llama_context * ctx_main,
                   struct llama_context * ctx_cfg,
-                  const int idx) {
+                  const int idx,
+                  bool apply_grammar,
+                  std::vector<float> * original_logits) {
     const llama_sampling_params & params = ctx_sampling->params;
 
     const int n_vocab = llama_n_vocab(llama_get_model(ctx_main));
@@ -315,6 +260,7 @@ static llama_token_data_array llama_sample_probability_distribution_impl(
     const float   penalty_repeat  = params.penalty_repeat;
     const float   penalty_freq    = params.penalty_freq;
     const float   penalty_present = params.penalty_present;
+
     const bool    penalize_nl     = params.penalize_nl;
 
     auto & prev = ctx_sampling->prev;
@@ -323,8 +269,10 @@ static llama_token_data_array llama_sample_probability_distribution_impl(
     // Get a pointer to the logits
     float * logits = llama_get_logits_ith(ctx_main, idx);
 
-    // Declare original_logits at the beginning of the function scope
-    std::vector<float> original_logits;
+    if (apply_grammar && original_logits != NULL) {
+        // Only make a copy of the original logits if we are not applying grammar checks, not sure if I actually have to do this.
+        *original_logits = {logits, logits + llama_n_vocab(llama_get_model(ctx_main))};
+    }
 
     // apply params.logit_bias map
     for (auto it = params.logit_bias.begin(); it != params.logit_bias.end(); it++) {
@@ -364,12 +312,11 @@ static llama_token_data_array llama_sample_probability_distribution_impl(
         }
     }
 
-    // apply grammar checks
-    if (ctx_sampling->grammar != NULL) {
+    // apply grammar checks before sampling logic
+    if (apply_grammar && ctx_sampling->grammar != NULL) {
         llama_sample_grammar(ctx_main, &cur_p, ctx_sampling->grammar);
     }
 
-    llama_sample_softmax(ctx_main, &cur_p);
     return cur_p;
 }
 
@@ -382,12 +329,14 @@ llama_token llama_sampling_sample(
     return llama_sampling_sample_impl(ctx_sampling, ctx_main, ctx_cfg, idx, false);
 }
 
-llama_token_data_array llama_sampling_probability_distribution(
+llama_token_data_array llama_sampling_prepare(
                   struct llama_sampling_context * ctx_sampling,
                   struct llama_context * ctx_main,
                   struct llama_context * ctx_cfg,
-                  const int idx) {
-    return llama_sample_probability_distribution_impl(ctx_sampling,ctx_main, ctx_cfg, idx);
+                  const int idx,
+                  bool apply_grammar,
+                  std::vector<float> * original_logits) {
+    return llama_sampling_prepare_impl(ctx_sampling,ctx_main, ctx_cfg, idx, apply_grammar, original_logits);
 }
 
 void llama_sampling_accept(
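A note on the behavioral change above: llama_sampling_prepare_impl no longer runs llama_sample_softmax before returning, so the returned llama_token_data_array holds raw (biased and penalized) logits rather than normalized probabilities. Callers that need actual probabilities now apply the softmax themselves, as the speculative example further down does.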

common/sampling.h (+5, -3)

@@ -131,12 +131,14 @@ llama_token llama_sampling_sample(
                   struct llama_context * ctx_cfg,
                   int idx = 0);
 
-// returns the probability that token of given id will be sampled
-llama_token_data_array llama_sampling_probability_distribution(
+// Prepares and adjusts the set of token candidates for sampling based on penalties, biases, and sampling parameters.
+llama_token_data_array llama_sampling_prepare(
                   struct llama_sampling_context * ctx_sampling,
                   struct llama_context * ctx_main,
                   struct llama_context * ctx_cfg,
-                  int idx = 0);
+                  int idx = 0,
+                  bool apply_grammar = true,
+                  std::vector<float> * original_logits = nullptr);
 
 void llama_sampling_accept(
                   struct llama_sampling_context * ctx_sampling,
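Because the two new parameters are defaulted in the header above, pre-existing call sites that pass only an index stay source-compatible. A small sketch (the function name demo_defaults is illustrative; the contexts are assumed to be initialized elsewhere):

// Sketch: the defaulted parameters keep old-style calls compiling unchanged.
static void demo_defaults(struct llama_sampling_context * ctx_sampling,
                          struct llama_context          * ctx_main) {
    // Old-style call: grammar is applied, original logits are not captured...
    llama_token_data_array a = llama_sampling_prepare(ctx_sampling, ctx_main, NULL, 0);

    // ...which is shorthand for the fully spelled-out call:
    llama_token_data_array b = llama_sampling_prepare(ctx_sampling, ctx_main, NULL, 0,
                                                      /*apply_grammar=*/true,
                                                      /*original_logits=*/nullptr);
    (void) a; (void) b; // silence unused-variable warnings in this sketch
}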

examples/speculative/speculative.cpp (+2, -1)

@@ -219,7 +219,8 @@ int main(int argc, char ** argv) {
             if (params.sparams.temp > 0) {
                 // stochastic verification
 
-                llama_token_data_array dist_tgt = llama_sampling_probability_distribution(ctx_sampling, ctx_tgt, NULL, drafts[s_keep].i_batch_tgt[i_dft]);
+                llama_token_data_array dist_tgt = llama_sampling_prepare(ctx_sampling, ctx_tgt, NULL, drafts[s_keep].i_batch_tgt[i_dft], true, NULL);
+                llama_sample_softmax(ctx_tgt, &dist_tgt);
                 float p_tgt = 0, p_dft = 0;
 
                 // GGML_ASSERT(dist_tgt.size() == dist_dft.size());
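Once llama_sample_softmax has run, each candidate's probability is available in the .p field of llama_token_data, which is what the verification step compares when computing p_tgt and p_dft. A hedged sketch of that lookup (the helper prob_of_token is illustrative, not from the commit):

// Return the probability assigned to `tok`, or 0 if it is not among the candidates.
// Only meaningful after llama_sample_softmax has filled in the .p fields.
static float prob_of_token(const llama_token_data_array & dist, llama_token tok) {
    for (size_t i = 0; i < dist.size; ++i) {
        if (dist.data[i].id == tok) {
            return dist.data[i].p;
        }
    }
    return 0.0f;
}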

retrieval (binary file, 1.56 MB; not shown)
