Commit 28103f4

Server: fix seed for multiple slots (#6835)

* Server: add tests for consistent results
* sampling: separate rng per sampling context

1 parent c0d1b3e commit 28103f4
File tree: 11 files changed, +145 -30 lines changed
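
The core of the fix: before this change every server slot drew from the single std::mt19937 stored inside llama_context, so a slot's output depended on how the other slots' requests happened to interleave. Giving each llama_sampling_context its own generator makes a fixed seed reproducible per slot. A minimal standalone sketch of that idea (plain C++, not llama.cpp code):

// Standalone illustration (not llama.cpp code): with one generator per "slot",
// identical seeds give identical draws no matter how slots interleave; with a
// shared generator, each slot's draws depend on what the other slots consumed.
#include <cassert>
#include <random>
#include <vector>

static int draw(std::mt19937 & rng, const std::vector<double> & probs) {
    std::discrete_distribution<> dist(probs.begin(), probs.end());
    return dist(rng);
}

int main() {
    const std::vector<double> probs = {0.1, 0.2, 0.3, 0.4};

    // per-slot generators, both seeded with the same user-supplied seed
    std::mt19937 slot_a(42), slot_b(42);
    std::vector<int> out_a, out_b;
    for (int i = 0; i < 8; ++i) {          // interleaved "token" draws for two slots
        out_a.push_back(draw(slot_a, probs));
        out_b.push_back(draw(slot_b, probs));
    }
    assert(out_a == out_b);                // same seed -> same sequence, per slot

    // one shared generator: slot B's first draw depends on slot A having drawn first
    std::mt19937 shared(42);
    draw(shared, probs);                   // slot A consumes a value...
    int b_first = draw(shared, probs);     // ...shifting everything slot B will see
    (void) b_first;
    return 0;
}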

‎common/common.cpp

2 additions & 0 deletions
@@ -242,7 +242,9 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
             invalid_param = true;
             return true;
         }
+        // This is temporary, in the future the sampling state will be moved fully to llama_sampling_context.
         params.seed = std::stoul(argv[i]);
+        sparams.seed = std::stoul(argv[i]);
         return true;
     }
     if (arg == "-t" || arg == "--threads") {
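
The sampling code reads its own copy of the seed, so --seed has to populate both structs until the sampling state moves fully into llama_sampling_context, as the new comment notes. A rough sketch of that relationship, using abbreviated stand-in structs rather than the real gpt_params / llama_sampling_params definitions:

// Abbreviated stand-ins (illustrative only; the real definitions live in
// common/common.h and common/sampling.h).
#include <cstdint>
#include <string>

struct sampling_params_sketch { uint32_t seed = 0xFFFFFFFF; };  // 0xFFFFFFFF == LLAMA_DEFAULT_SEED
struct gpt_params_sketch {
    uint32_t seed = 0xFFFFFFFF;          // legacy location, still read by older code paths
    sampling_params_sketch sparams;      // the copy llama_sampling_init() actually consumes
};

static void apply_seed_arg(gpt_params_sketch & params, const std::string & value) {
    params.seed         = std::stoul(value);  // what "--seed" used to set
    params.sparams.seed = std::stoul(value);  // what the per-slot RNG is now seeded from
}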

‎common/sampling.cpp

12 additions & 1 deletion
@@ -1,4 +1,6 @@
+#define LLAMA_API_INTERNAL
 #include "sampling.h"
+#include <random>
 
 struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_params & params) {
     struct llama_sampling_context * result = new llama_sampling_context();
@@ -33,6 +35,8 @@ struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_
 
     result->prev.resize(params.n_prev);
 
+    llama_sampling_set_rng_seed(result, params.seed);
+
     return result;
 }
 
@@ -62,6 +66,13 @@ void llama_sampling_reset(llama_sampling_context * ctx) {
     ctx->cur.clear();
 }
 
+void llama_sampling_set_rng_seed(struct llama_sampling_context * ctx, uint32_t seed) {
+    if (seed == LLAMA_DEFAULT_SEED) {
+        seed = time(NULL);
+    }
+    ctx->rng.seed(seed);
+}
+
 void llama_sampling_cp(llama_sampling_context * src, llama_sampling_context * dst) {
     if (dst->grammar) {
         llama_grammar_free(dst->grammar);
@@ -203,7 +214,7 @@ static llama_token llama_sampling_sample_impl(
 
     sampler_queue(ctx_main, params, cur_p, min_keep);
 
-    id = llama_sample_token(ctx_main, &cur_p);
+    id = llama_sample_token_with_rng(ctx_main, &cur_p, ctx_sampling->rng);
 
     //{
     //    const int n_top = 10;
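
With these changes, every owner of a llama_sampling_context also owns the RNG that goes with it, seeded from params.seed or, when that is left at LLAMA_DEFAULT_SEED, from time(NULL). A minimal usage sketch against the API as modified here (error handling omitted):

// Sketch only: two independently seeded sampling contexts, as a multi-slot
// server would now hold. Uses the common/sampling.h API from this commit.
#include "sampling.h"

static void init_two_slots(uint32_t seed) {
    llama_sampling_params sparams;
    sparams.seed = seed;                                            // e.g. 42 from the request

    llama_sampling_context * slot0 = llama_sampling_init(sparams);  // seeds slot0's rng
    llama_sampling_context * slot1 = llama_sampling_init(sparams);  // seeds slot1's rng independently

    llama_sampling_set_rng_seed(slot0, seed);                       // re-seeding one slot leaves the other untouched

    llama_sampling_free(slot0);
    llama_sampling_free(slot1);
}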

‎common/sampling.h

27 additions & 20 deletions
@@ -4,9 +4,10 @@
 
 #include "grammar-parser.h"
 
+#include <random>
 #include <string>
-#include <vector>
 #include <unordered_map>
+#include <vector>
 
 // sampler types
 enum class llama_sampler_type : char {
@@ -20,25 +21,26 @@ enum class llama_sampler_type : char {
 
 // sampling parameters
 typedef struct llama_sampling_params {
-    int32_t n_prev = 64; // number of previous tokens to remember
-    int32_t n_probs = 0; // if greater than 0, output the probabilities of top n_probs tokens.
-    int32_t min_keep = 0; // 0 = disabled, otherwise samplers should return at least min_keep tokens
-    int32_t top_k = 40; // <= 0 to use vocab size
-    float top_p = 0.95f; // 1.0 = disabled
-    float min_p = 0.05f; // 0.0 = disabled
-    float tfs_z = 1.00f; // 1.0 = disabled
-    float typical_p = 1.00f; // 1.0 = disabled
-    float temp = 0.80f; // <= 0.0 to sample greedily, 0.0 to not output probabilities
-    float dynatemp_range = 0.00f; // 0.0 = disabled
-    float dynatemp_exponent = 1.00f; // controls how entropy maps to temperature in dynamic temperature sampler
-    int32_t penalty_last_n = 64; // last n tokens to penalize (0 = disable penalty, -1 = context size)
-    float penalty_repeat = 1.00f; // 1.0 = disabled
-    float penalty_freq = 0.00f; // 0.0 = disabled
-    float penalty_present = 0.00f; // 0.0 = disabled
-    int32_t mirostat = 0; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
-    float mirostat_tau = 5.00f; // target entropy
-    float mirostat_eta = 0.10f; // learning rate
-    bool penalize_nl = false; // consider newlines as a repeatable token
+    int32_t     n_prev                = 64;                 // number of previous tokens to remember
+    int32_t     n_probs               = 0;                  // if greater than 0, output the probabilities of top n_probs tokens.
+    int32_t     min_keep              = 0;                  // 0 = disabled, otherwise samplers should return at least min_keep tokens
+    int32_t     top_k                 = 40;                 // <= 0 to use vocab size
+    float       top_p                 = 0.95f;              // 1.0 = disabled
+    float       min_p                 = 0.05f;              // 0.0 = disabled
+    float       tfs_z                 = 1.00f;              // 1.0 = disabled
+    float       typical_p             = 1.00f;              // 1.0 = disabled
+    float       temp                  = 0.80f;              // <= 0.0 to sample greedily, 0.0 to not output probabilities
+    float       dynatemp_range        = 0.00f;              // 0.0 = disabled
+    float       dynatemp_exponent     = 1.00f;              // controls how entropy maps to temperature in dynamic temperature sampler
+    int32_t     penalty_last_n        = 64;                 // last n tokens to penalize (0 = disable penalty, -1 = context size)
+    float       penalty_repeat        = 1.00f;              // 1.0 = disabled
+    float       penalty_freq          = 0.00f;              // 0.0 = disabled
+    float       penalty_present       = 0.00f;              // 0.0 = disabled
+    int32_t     mirostat              = 0;                  // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
+    float       mirostat_tau          = 5.00f;              // target entropy
+    float       mirostat_eta          = 0.10f;              // learning rate
+    bool        penalize_nl           = false;              // consider newlines as a repeatable token
+    uint32_t    seed                  = LLAMA_DEFAULT_SEED; // the seed used to initialize llama_sampling_context
 
     std::vector<llama_sampler_type> samplers_sequence = {
         llama_sampler_type::TOP_K,
@@ -79,6 +81,8 @@ struct llama_sampling_context {
     // TODO: replace with ring-buffer
     std::vector<llama_token> prev;
    std::vector<llama_token_data> cur;
+
+    std::mt19937 rng;
 };
 
 #include "common.h"
@@ -93,6 +97,9 @@ void llama_sampling_free(struct llama_sampling_context * ctx);
 // - reset grammar
 void llama_sampling_reset(llama_sampling_context * ctx);
 
+// Set the sampler seed
+void llama_sampling_set_rng_seed(struct llama_sampling_context * ctx, uint32_t seed);
+
 // Copy the sampler context
 void llama_sampling_cp(llama_sampling_context * src, llama_sampling_context * dst);
 

‎examples/lookup/lookup-stats.cpp

0 additions & 1 deletion
@@ -30,7 +30,6 @@ int main(int argc, char ** argv){
 
     // load the model
     std::tie(model, ctx) = llama_init_from_gpt_params(params);
-    llama_set_rng_seed(ctx, params.seed);
     GGML_ASSERT(llama_n_vocab(model) < (1 << 16));
 
     // tokenize the prompt

‎examples/lookup/lookup.cpp

0 additions & 1 deletion
@@ -38,7 +38,6 @@ int main(int argc, char ** argv){
 
     // load the model
     std::tie(model, ctx) = llama_init_from_gpt_params(params);
-    llama_set_rng_seed(ctx, params.seed);
     GGML_ASSERT(llama_n_vocab(model) < (1 << 16));
 
     // tokenize the prompt

‎examples/main/main.cpp

0 additions & 1 deletion
@@ -240,7 +240,6 @@ int main(int argc, char ** argv) {
                 return 1;
             }
             session_tokens.resize(n_token_count_out);
-            llama_set_rng_seed(ctx, params.seed);
             LOG_TEE("%s: loaded a session with prompt size of %d tokens\n", __func__, (int)session_tokens.size());
         }
     }

‎examples/server/server.cpp

1 addition & 2 deletions
@@ -854,7 +854,7 @@ struct server_context {
         slot.sparams.penalize_nl = json_value(data, "penalize_nl", default_sparams.penalize_nl);
         slot.params.n_keep = json_value(data, "n_keep", slot.params.n_keep);
         slot.params.n_discard = json_value(data, "n_discard", default_params.n_discard);
-        slot.params.seed = json_value(data, "seed", default_params.seed);
+        slot.sparams.seed = json_value(data, "seed", default_sparams.seed);
         slot.sparams.n_probs = json_value(data, "n_probs", default_sparams.n_probs);
         slot.sparams.min_keep = json_value(data, "min_keep", default_sparams.min_keep);
 
@@ -1028,7 +1028,6 @@ struct server_context {
                 send_error(task, "Failed to parse grammar", ERROR_TYPE_INVALID_REQUEST);
                 return false;
             }
-            llama_set_rng_seed(ctx, slot.params.seed);
         }
 
         slot.command = SLOT_COMMAND_LOAD_PROMPT;
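
The practical effect on the server: a request's "seed" now lands in the slot's sampling parameters rather than in the shared llama_context, so concurrent slots can no longer clobber each other's RNG. A hypothetical sketch of the per-slot flow; member names beyond those visible in this diff are illustrative:

// Illustrative only: how a slot's seed reaches its own RNG after this change.
#include "sampling.h"

struct slot_sketch {
    llama_sampling_params    sparams;                 // now carries the request's "seed"
    llama_sampling_context * ctx_sampling = nullptr;  // per-slot sampler with its own std::mt19937
};

static void prepare_slot(slot_sketch & slot, uint32_t requested_seed) {
    slot.sparams.seed = requested_seed;               // from json_value(data, "seed", default_sparams.seed)

    if (slot.ctx_sampling != nullptr) {
        llama_sampling_free(slot.ctx_sampling);
    }
    slot.ctx_sampling = llama_sampling_init(slot.sparams);  // seeds the slot-local generator

    // no llama_set_rng_seed(ctx, ...) on the shared llama_context anymore
}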
examples/server/tests/features/results.feature (new file)

57 additions & 0 deletions

@@ -0,0 +1,57 @@
+@llama.cpp
+@results
+Feature: Results
+
+  Background: Server startup
+    Given a server listening on localhost:8080
+    And   a model file tinyllamas/split/stories15M-00001-of-00003.gguf from HF repo ggml-org/models
+    And   a model file test-model-00001-of-00003.gguf
+    And   128 as batch size
+    And   256 KV cache size
+    And   128 max tokens to predict
+
+  Scenario Outline: Multi users completion
+    Given <n_slots> slots
+    And   continuous batching
+    Then  the server is starting
+    Then  the server is healthy
+
+    Given 42 as seed
+    And a prompt:
+      """
+      Write a very long story about AI.
+      """
+
+    Given 42 as seed
+    And a prompt:
+      """
+      Write a very long story about AI.
+      """
+
+    Given 42 as seed
+    And a prompt:
+      """
+      Write a very long story about AI.
+      """
+
+    Given 42 as seed
+    And a prompt:
+      """
+      Write a very long story about AI.
+      """
+
+    Given 42 as seed
+    And a prompt:
+      """
+      Write a very long story about AI.
+      """
+
+    Given concurrent completion requests
+    Then the server is busy
+    Then the server is idle
+    And  all slots are idle
+    Then all predictions are equal
+    Examples:
+      | n_slots |
+      | 1       |
+      | 2       |

‎examples/server/tests/features/steps/steps.py

34 additions & 0 deletions
@@ -61,6 +61,7 @@ def step_server_config(context, server_fqdn, server_port):
     context.server_metrics = False
     context.server_process = None
     context.seed = None
+    context.draft = None
     context.server_seed = None
     context.user_api_key = None
     context.response_format = None
@@ -107,6 +108,11 @@ def step_n_gpu_layer(context, ngl):
     context.n_gpu_layer = ngl
 
 
+@step('{draft:d} as draft')
+def step_draft(context, draft):
+    context.draft = draft
+
+
 @step('{n_ctx:d} KV cache size')
 def step_n_ctx(context, n_ctx):
     context.n_ctx = n_ctx
@@ -254,6 +260,15 @@ def step_n_tokens_predicted(context, predicted_n):
     assert_n_tokens_predicted(context.completion, predicted_n)
 
 
+@step('all predictions are equal')
+@async_run_until_complete
+async def step_predictions_equal(context):
+    n_completions = await gather_tasks_results(context)
+    assert n_completions >= 2, "need at least 2 completions"
+    assert_all_predictions_equal(context.tasks_result)
+    context.tasks_result = []
+
+
 @step('the completion is truncated')
 def step_assert_completion_truncated(context):
     step_assert_completion_truncated(context, '')
@@ -1020,6 +1035,23 @@ def assert_n_tokens_predicted(completion_response, expected_predicted_n=None, re
     assert n_predicted == expected_predicted_n, (f'invalid number of tokens predicted:'
                                                  f' {n_predicted} <> {expected_predicted_n}')
 
+def assert_all_predictions_equal(completion_responses):
+    content_0 = completion_responses[0]['content']
+
+    if 'DEBUG' in os.environ and os.environ['DEBUG'] == 'ON':
+        print(f"content 0: {content_0}")
+
+    i = 1
+    for response in completion_responses[1:]:
+        content = response['content']
+
+        if 'DEBUG' in os.environ and os.environ['DEBUG'] == 'ON':
+            print(f"content {i}: {content}")
+
+        assert content == content_0, "contents not equal"
+
+        i += 1
+
 
 async def gather_tasks_results(context):
     n_tasks = len(context.concurrent_tasks)
@@ -1148,6 +1180,8 @@ def start_server_background(context):
         server_args.extend(['--ubatch-size', context.n_ubatch])
     if context.n_gpu_layer:
         server_args.extend(['--n-gpu-layers', context.n_gpu_layer])
+    if context.draft is not None:
+        server_args.extend(['--draft', context.draft])
     if context.server_continuous_batching:
         server_args.append('--cont-batching')
     if context.server_embeddings:

‎llama.cpp

5 additions & 2 deletions
@@ -13667,7 +13667,7 @@ llama_token llama_sample_token_greedy(struct llama_context * ctx, llama_token_da
     return result;
 }
 
-llama_token llama_sample_token(struct llama_context * ctx, llama_token_data_array * candidates) {
+llama_token llama_sample_token_with_rng(struct llama_context * ctx, llama_token_data_array * candidates, std::mt19937 & rng) {
     GGML_ASSERT(ctx);
 
     const int64_t t_start_sample_us = ggml_time_us();
@@ -13680,7 +13680,6 @@ llama_token llama_sample_token(struct llama_context * ctx, llama_token_data_arra
     }
 
     std::discrete_distribution<> dist(probs.begin(), probs.end());
-    auto & rng = ctx->rng;
     int idx = dist(rng);
 
     llama_token result = candidates->data[idx].id;
@@ -13690,6 +13689,10 @@ llama_token llama_sample_token(struct llama_context * ctx, llama_token_data_arra
     return result;
 }
 
+llama_token llama_sample_token(struct llama_context * ctx, llama_token_data_array * candidates) {
+    return llama_sample_token_with_rng(ctx, candidates, ctx->rng);
+}
+
 void llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar * grammar, llama_token token) {
     const int64_t t_start_sample_us = ggml_time_us();
 
‎llama.h

7 additions & 2 deletions
@@ -987,7 +987,7 @@ extern "C" {
             struct llama_context * ctx,
           llama_token_data_array * candidates);
 
-    /// @details Randomly selects a token from the candidates based on their probabilities.
+    /// @details Randomly selects a token from the candidates based on their probabilities using the RNG of ctx.
     LLAMA_API llama_token llama_sample_token(
             struct llama_context * ctx,
           llama_token_data_array * candidates);
@@ -1074,8 +1074,9 @@ extern "C" {
 // Internal API to be implemented by llama.cpp and used by tests/benchmarks only
 #ifdef LLAMA_API_INTERNAL
 
-#include <vector>
+#include <random>
 #include <string>
+#include <vector>
 
 struct ggml_tensor;
 
@@ -1112,6 +1113,10 @@ std::pair<std::vector<uint32_t>, llama_partial_utf8> decode_utf8(
         const std::string & src,
         llama_partial_utf8 partial_start);
 
+// Randomly selects a token from the candidates based on their probabilities using given std::mt19937.
+// This is a temporary workaround in order to fix race conditions when sampling with multiple sequences.
+llama_token llama_sample_token_with_rng(struct llama_context * ctx, llama_token_data_array * candidates, std::mt19937 & rng);
+
 #endif // LLAMA_API_INTERNAL
 
 #endif // LLAMA_H
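
Because llama_sample_token_with_rng sits behind LLAMA_API_INTERNAL, a caller has to opt in the same way common/sampling.cpp now does: define the macro before including llama.h and pass a generator it owns. A hedged sketch, with the token candidates assumed to be prepared elsewhere:

// Sketch: sampling with a caller-owned generator via the internal API.
#define LLAMA_API_INTERNAL
#include "llama.h"

#include <random>

static llama_token sample_with_local_rng(llama_context * ctx,
                                         llama_token_data_array * candidates,
                                         uint32_t seed) {
    std::mt19937 rng(seed);   // owned by the caller (e.g. one per sampling context), not by ctx
    return llama_sample_token_with_rng(ctx, candidates, rng);
}

// The public llama_sample_token() keeps its old behavior: per this commit it now
// just forwards the context-wide generator, i.e. llama_sample_token_with_rng(ctx, candidates, ctx->rng).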
