Commit db39019

llama : move grammar code into llama-grammar
ggml-ci
1 parent 516746a

12 files changed: +741 −672 lines

Makefile (+17 −1)
@@ -868,6 +868,8 @@ OBJ_GGML += \

 OBJ_LLAMA = \
     src/llama.o \
+    src/llama-vocab.o \
+    src/llama-grammar.o \
     src/llama-sampling.o \
     src/unicode.o \
     src/unicode-data.o
@@ -1058,6 +1060,20 @@ src/llama.o: \
     ggml/include/ggml-backend.h
     $(CXX) $(CXXFLAGS) -c $< -o $@

+src/llama-vocab.o: \
+    src/llama-vocab.cpp \
+    src/llama-vocab.h \
+    src/llama-impl.h \
+    include/llama.h
+    $(CXX) $(CXXFLAGS) -c $< -o $@
+
+src/llama-grammar.o: \
+    src/llama-grammar.cpp \
+    src/llama-grammar.h \
+    src/llama-impl.h \
+    include/llama.h
+    $(CXX) $(CXXFLAGS) -c $< -o $@
+
 src/llama-sampling.o: \
     src/llama-sampling.cpp \
     src/llama-sampling.h \
@@ -1440,7 +1456,7 @@ run-benchmark-matmult: llama-benchmark-matmult
 .PHONY: run-benchmark-matmult swift

 tests/test-llama-grammar: tests/test-llama-grammar.cpp \
-    $(OBJ_GGML) $(OBJ_COMMON) src/unicode.o src/unicode-data.o
+    $(OBJ_ALL)
     $(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)
     $(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS)

common/sampling.cpp (+2 −2)
@@ -330,7 +330,7 @@ static llama_token llama_sampling_sample_impl(
         llama_token_data_array single_token_data_array = { &single_token_data, 1, false };

         // Apply grammar constraints to the single token
-        llama_sample_grammar(ctx_main, &single_token_data_array, ctx_sampling->grammar);
+        llama_grammar_sample(ctx_main, &single_token_data_array, ctx_sampling->grammar);

         // Check if the token is valid according to the grammar by seeing if its logit has been set to -INFINITY
         bool is_valid = single_token_data_array.data[0].logit != -INFINITY;
@@ -421,7 +421,7 @@ static llama_token_data_array llama_sampling_prepare_impl(

     // apply grammar checks before sampling logic
     if (apply_grammar && ctx_sampling->grammar != NULL) {
-        llama_sample_grammar(ctx_main, &cur_p, ctx_sampling->grammar);
+        llama_grammar_sample(ctx_main, &cur_p, ctx_sampling->grammar);
     }

     return cur_p;
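
The rename from llama_sample_grammar to llama_grammar_sample groups the function under the llama_grammar_* prefix alongside llama_grammar_accept_token. Below is a minimal caller-side sketch of the intended call order, assuming a context, grammar, and candidate array prepared as in common/sampling.cpp; the helper name sample_with_grammar is hypothetical, and llama_sample_token_greedy is the pre-existing greedy sampler from llama.h:

#include "llama.h"

// Hypothetical helper sketching the renamed API's call order:
// mask candidates with the grammar, sample, then feed the choice back in.
static llama_token sample_with_grammar(
        struct llama_context   * ctx,
        struct llama_grammar   * grammar,
        llama_token_data_array * cur_p) {
    // set the logit of every token the grammar cannot accept to -INFINITY
    llama_grammar_sample(ctx, cur_p, grammar);

    // pick a token from the surviving candidates (greedy for simplicity)
    const llama_token id = llama_sample_token_greedy(ctx, cur_p);

    // advance the grammar state so the next call constrains correctly
    llama_grammar_accept_token(ctx, grammar, id);

    return id;
}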

examples/gbnf-validator/gbnf-validator.cpp (+8 −5)
@@ -16,20 +16,23 @@ static bool llama_sample_grammar_string(struct llama_grammar * grammar, const st
     auto decoded = decode_utf8(input_str, {});
     const auto & code_points = decoded.first;

+    llama_grammar_stacks & cur_stacks = llama_grammar_get_stacks(grammar);
+
     size_t pos = 0;
     for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
-        auto prev_stacks = grammar->stacks;
-        llama_grammar_accept(grammar->rules, prev_stacks, *it, grammar->stacks);
-        if (grammar->stacks.empty()) {
+        const llama_grammar_rules  & prev_rules  = llama_grammar_get_rules (grammar);
+        const llama_grammar_stacks   prev_stacks = llama_grammar_get_stacks(grammar); // copy
+        llama_grammar_accept(prev_rules, prev_stacks, *it, cur_stacks);
+        if (cur_stacks.empty()) {
             error_pos = pos;
             error_msg = "Unexpected character '" + unicode_cpt_to_utf8(*it) + "'";
-            grammar->stacks = prev_stacks;
+            cur_stacks = prev_stacks;
             return false;
         }
         ++pos;
     }

-    for (const auto & stack : grammar->stacks) {
+    for (const auto & stack : cur_stacks) {
         if (stack.empty()) {
             return true;
         }
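
After this change the validator walks the grammar through accessor functions rather than touching grammar->rules and grammar->stacks directly, which clears the way for llama_grammar to become an opaque type. A minimal sketch of the accept-with-rollback step in isolation, assuming a constructed grammar and a decoded code point (the helper grammar_accept_cpt is hypothetical):

// Hypothetical helper isolating the pattern above: try to advance the grammar
// by one code point; on failure, restore the previous stacks and report false.
static bool grammar_accept_cpt(struct llama_grammar * grammar, uint32_t cpt) {
    llama_grammar_stacks & cur_stacks = llama_grammar_get_stacks(grammar);

    const llama_grammar_rules  & rules       = llama_grammar_get_rules (grammar);
    const llama_grammar_stacks   prev_stacks = cur_stacks; // copy, for rollback

    // compute the stacks reachable from prev_stacks after consuming cpt
    llama_grammar_accept(rules, prev_stacks, cpt, cur_stacks);

    if (cur_stacks.empty()) {
        cur_stacks = prev_stacks; // no stack could accept cpt: roll back
        return false;
    }
    return true;
}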

include/llama.h (+31 −28)
@@ -1000,6 +1000,18 @@ extern "C" {

     LLAMA_API struct llama_grammar * llama_grammar_copy(const struct llama_grammar * grammar);

+    /// @details Apply constraints from grammar
+    LLAMA_API void llama_grammar_sample(
+            struct llama_context * ctx,
+          llama_token_data_array * candidates,
+      const struct llama_grammar * grammar);
+
+    /// @details Accepts the sampled token into the grammar
+    LLAMA_API void llama_grammar_accept_token(
+            struct llama_context * ctx,
+            struct llama_grammar * grammar,
+                     llama_token   token);
+
     //
     // Sampling functions
     //
@@ -1118,18 +1130,6 @@ extern "C" {
             struct llama_context * ctx,
           llama_token_data_array * candidates);

-    /// @details Apply constraints from grammar
-    LLAMA_API void llama_sample_grammar(
-            struct llama_context * ctx,
-          llama_token_data_array * candidates,
-      const struct llama_grammar * grammar);
-
-    /// @details Accepts the sampled token into the grammar
-    LLAMA_API void llama_grammar_accept_token(
-            struct llama_context * ctx,
-            struct llama_grammar * grammar,
-                     llama_token   token);
-
     //
     // Model split
     //
@@ -1172,38 +1172,41 @@ extern "C" {

 struct ggml_tensor;

+const std::vector<std::pair<std::string, struct ggml_tensor *>> & llama_internal_get_tensor_map(
+    struct llama_context * ctx
+);
+
 struct llama_partial_utf8 {
     uint32_t value;    // bit value so far (unshifted)
     int      n_remain; // num bytes remaining; -1 indicates invalid sequence
 };

-struct llama_grammar {
-    const std::vector<std::vector<llama_grammar_element>>   rules;
-    std::vector<std::vector<const llama_grammar_element *>> stacks;
-
-    // buffer for partially generated UTF-8 sequence from accepted tokens
-    llama_partial_utf8 partial_utf8;
-};
-
 struct llama_grammar_candidate {
     size_t               index;
     const uint32_t     * code_points;
     llama_partial_utf8   partial_utf8;
 };

-const std::vector<std::pair<std::string, struct ggml_tensor *>> & llama_internal_get_tensor_map(
-    struct llama_context * ctx
-);
+using llama_grammar_rules  = std::vector<std::vector<llama_grammar_element>>;
+using llama_grammar_stacks = std::vector<std::vector<const llama_grammar_element *>>;
+
+const llama_grammar_rules  & llama_grammar_get_rules (const struct llama_grammar * grammar);
+      llama_grammar_stacks & llama_grammar_get_stacks(      struct llama_grammar * grammar);

 void llama_grammar_accept(
-        const std::vector<std::vector<llama_grammar_element>>         & rules,
-        const std::vector<std::vector<const llama_grammar_element *>> & stacks,
-        const uint32_t chr,
-        std::vector<std::vector<const llama_grammar_element *>>       & new_stacks);
+        const llama_grammar_rules  & rules,
+        const llama_grammar_stacks & stacks,
+        const uint32_t chr,
+              llama_grammar_stacks & new_stacks);
+
+std::vector<llama_grammar_candidate> llama_grammar_reject_candidates_for_stack(
+        const std::vector<std::vector<llama_grammar_element>> & rules,
+        const std::vector<const llama_grammar_element *>      & stack,
+        const std::vector<llama_grammar_candidate>            & candidates);

 std::pair<std::vector<uint32_t>, llama_partial_utf8> decode_utf8(
         const std::string & src,
-        llama_partial_utf8 partial_start);
+        llama_partial_utf8   partial_start);

 // Randomly selects a token from the candidates based on their probabilities using given std::mt19937.
 // This is a temporary workaround in order to fix race conditions when sampling with multiple sequences.
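
With the llama_grammar_rules / llama_grammar_stacks aliases, internal callers no longer spell out the nested vector types. As a small illustration of the accessor-based style (a sketch, not part of this commit; grammar_is_complete is a hypothetical name), an empty stack signals that the grammar has reached an accepting state, mirroring the final loop in gbnf-validator.cpp:

// Hypothetical helper: true if any parse stack is empty, i.e. the input
// consumed so far forms a complete sentence of the grammar.
static bool grammar_is_complete(struct llama_grammar * grammar) {
    const llama_grammar_stacks & stacks = llama_grammar_get_stacks(grammar);
    for (const auto & stack : stacks) {
        if (stack.empty()) {
            return true;
        }
    }
    return false;
}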

src/CMakeLists.txt (+2 −0)
@@ -14,6 +14,8 @@ endif()
 add_library(llama
             ../include/llama.h
             llama.cpp
+            llama-vocab.cpp
+            llama-grammar.cpp
             llama-sampling.cpp
             unicode.h
             unicode.cpp
