Commit 96d14f4

ggerganov authored and arthw committed

llama : move vocab, grammar and sampling into separate files (ggml-org#8508)

* llama : move sampling code into llama-sampling (ggml-ci)
* llama : move grammar code into llama-grammar (ggml-ci)
* cont (ggml-ci)
* cont : pre-fetch rules
* cont (ggml-ci)
* llama : deprecate llama_sample_grammar
* llama : move tokenizers into llama-vocab (ggml-ci)
* make : update llama.cpp deps [no ci]
* llama : redirect external API to internal APIs (ggml-ci)
* llama : suffix the internal APIs with "_impl" (ggml-ci)
* llama : clean-up

1 parent 4b93675 · commit 96d14f4

18 files changed: +3656 -3103 lines
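
The "redirect external API to internal APIs" and "_impl" suffix items describe the pattern used across the split: each public C entry point keeps its signature and forwards to a component-level implementation in the new files. An illustrative sketch of that shape (the member access and the _impl name shown here are assumptions for illustration, not lines copied from the diff):

// Illustrative only: a public API function in llama.cpp forwarding to the vocab component.
const char * llama_token_get_text(const struct llama_model * model, llama_token token) {
    return llama_token_get_text_impl(model->vocab, token);
}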

‎Makefile

31 additions, 1 deletion
@@ -876,6 +876,9 @@ OBJ_GGML += \
 
 OBJ_LLAMA = \
 	src/llama.o \
+	src/llama-vocab.o \
+	src/llama-grammar.o \
+	src/llama-sampling.o \
 	src/unicode.o \
 	src/unicode-data.o
 
@@ -1055,6 +1058,10 @@ src/unicode-data.o: \
 
 src/llama.o: \
 	src/llama.cpp \
+	src/llama-impl.h \
+	src/llama-vocab.h \
+	src/llama-grammar.h \
+	src/llama-sampling.h \
 	src/unicode.h \
 	include/llama.h \
 	ggml/include/ggml-cuda.h \
@@ -1064,6 +1071,29 @@ src/llama.o: \
 	ggml/include/ggml-backend.h
 	$(CXX) $(CXXFLAGS) -c $< -o $@
 
+src/llama-vocab.o: \
+	src/llama-vocab.cpp \
+	src/llama-vocab.h \
+	src/llama-impl.h \
+	include/llama.h
+	$(CXX) $(CXXFLAGS) -c $< -o $@
+
+src/llama-grammar.o: \
+	src/llama-grammar.cpp \
+	src/llama-grammar.h \
+	src/llama-impl.h \
+	src/llama-vocab.h \
+	src/llama-sampling.h \
+	include/llama.h
+	$(CXX) $(CXXFLAGS) -c $< -o $@
+
+src/llama-sampling.o: \
+	src/llama-sampling.cpp \
+	src/llama-sampling.h \
+	src/llama-impl.h \
+	include/llama.h
+	$(CXX) $(CXXFLAGS) -c $< -o $@
+
 $(LIB_LLAMA): \
 	$(OBJ_LLAMA) \
 	$(LIB_GGML)
@@ -1439,7 +1469,7 @@ run-benchmark-matmult: llama-benchmark-matmult
 .PHONY: run-benchmark-matmult swift
 
 tests/test-llama-grammar: tests/test-llama-grammar.cpp \
-	$(OBJ_GGML) $(OBJ_COMMON) src/unicode.o src/unicode-data.o
+	$(OBJ_ALL)
 	$(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)
 	$(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS)
 
‎Package.swift

3 additions, 0 deletions
@@ -4,6 +4,9 @@ import PackageDescription
 
 var sources = [
     "src/llama.cpp",
+    "src/llama-vocab.cpp",
+    "src/llama-grammar.cpp",
+    "src/llama-sampling.cpp",
     "src/unicode.cpp",
     "src/unicode-data.cpp",
     "ggml/src/ggml.c",

‎common/sampling.cpp

3 additions, 3 deletions
@@ -330,7 +330,7 @@ static llama_token llama_sampling_sample_impl(
            llama_token_data_array single_token_data_array = { &single_token_data, 1, false };
 
            // Apply grammar constraints to the single token
-           llama_sample_grammar(ctx_main, &single_token_data_array, ctx_sampling->grammar);
+           llama_grammar_sample(ctx_sampling->grammar, ctx_main, &single_token_data_array);
 
            // Check if the token is valid according to the grammar by seeing if its logit has been set to -INFINITY
            bool is_valid = single_token_data_array.data[0].logit != -INFINITY;
@@ -421,7 +421,7 @@ static llama_token_data_array llama_sampling_prepare_impl(
 
    // apply grammar checks before sampling logic
    if (apply_grammar && ctx_sampling->grammar != NULL) {
-       llama_sample_grammar(ctx_main, &cur_p, ctx_sampling->grammar);
+       llama_grammar_sample(ctx_sampling->grammar, ctx_main, &cur_p);
    }
 
    return cur_p;
@@ -455,6 +455,6 @@ void llama_sampling_accept(
    ctx_sampling->prev.push_back(id);
 
    if (ctx_sampling->grammar != NULL && apply_grammar) {
-       llama_grammar_accept_token(ctx_main, ctx_sampling->grammar, id);
+       llama_grammar_accept_token(ctx_sampling->grammar, ctx_main, id);
    }
}
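
For downstream code still built against the old entry points, the call sites above show the whole migration: the grammar moves to the first argument and the context to the second. llama_sample_grammar survives only as a deprecated wrapper, while llama_grammar_accept_token changes its argument order in place. A minimal before/after sketch, assuming the usual ctx_main and ctx_sampling variables from this file:

// Before (deprecated / removed ordering):
//   llama_sample_grammar      (ctx_main, &cur_p, ctx_sampling->grammar);
//   llama_grammar_accept_token(ctx_main, ctx_sampling->grammar, id);

// After: grammar first, context second, candidates/token last.
llama_grammar_sample      (ctx_sampling->grammar, ctx_main, &cur_p);
llama_grammar_accept_token(ctx_sampling->grammar, ctx_main, id);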

‎examples/gbnf-validator/gbnf-validator.cpp

10 additions, 5 deletions
@@ -16,20 +16,25 @@ static bool llama_sample_grammar_string(struct llama_grammar * grammar, const st
    auto decoded = decode_utf8(input_str, {});
    const auto & code_points = decoded.first;
 
+   const llama_grammar_rules  & rules      = llama_grammar_get_rules (grammar);
+         llama_grammar_stacks & cur_stacks = llama_grammar_get_stacks(grammar);
+
    size_t pos = 0;
    for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
-       auto prev_stacks = grammar->stacks;
-       llama_grammar_accept(grammar->rules, prev_stacks, *it, grammar->stacks);
-       if (grammar->stacks.empty()) {
+       const llama_grammar_stacks prev_stacks = llama_grammar_get_stacks(grammar); // copy
+
+       llama_grammar_accept(rules, prev_stacks, *it, cur_stacks);
+
+       if (cur_stacks.empty()) {
            error_pos = pos;
            error_msg = "Unexpected character '" + unicode_cpt_to_utf8(*it) + "'";
-           grammar->stacks = prev_stacks;
+           cur_stacks = prev_stacks;
            return false;
        }
        ++pos;
    }
 
-   for (const auto & stack : grammar->stacks) {
+   for (const auto & stack : cur_stacks) {
        if (stack.empty()) {
            return true;
        }
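
The validator no longer reaches into llama_grammar directly; it goes through the new accessors and rolls back on rejection. That pattern can be factored into a small helper. A sketch built only on the functions introduced in this commit (llama_grammar_get_rules, llama_grammar_get_stacks, llama_grammar_accept), with the helper name chosen for illustration:

#include <cstdint>
#include "llama.h"

// Advance the grammar by one Unicode code point. Returns false and restores
// the previous parse stacks if the grammar cannot derive that character.
static bool grammar_advance(struct llama_grammar * grammar, uint32_t cpt) {
    const llama_grammar_rules  & rules  = llama_grammar_get_rules (grammar);
          llama_grammar_stacks & stacks = llama_grammar_get_stacks(grammar);

    const llama_grammar_stacks prev = stacks; // copy, so we can roll back

    llama_grammar_accept(rules, prev, cpt, stacks);

    if (stacks.empty()) {
        stacks = prev; // character rejected: restore the previous state
        return false;
    }

    return true;
}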

‎include/llama.h

46 additions, 30 deletions
@@ -906,10 +906,10 @@ extern "C" {
    LLAMA_API llama_token llama_token_pad(const struct llama_model * model); // padding
 
    // Returns -1 if unknown, 1 for true or 0 for false.
-   LLAMA_API int32_t llama_add_bos_token(const struct llama_model * model);
+   LLAMA_API int32_t llama_add_bos_token(const struct llama_model * model);
 
    // Returns -1 if unknown, 1 for true or 0 for false.
-   LLAMA_API int32_t llama_add_eos_token(const struct llama_model * model);
+   LLAMA_API int32_t llama_add_eos_token(const struct llama_model * model);
 
    // Codellama infill tokens
    LLAMA_API llama_token llama_token_prefix(const struct llama_model * model); // Beginning of infill prefix
@@ -965,6 +965,10 @@ extern "C" {
                             bool   remove_special,
                             bool   unparse_special);
 
+   //
+   // Chat templates
+   //
+
    /// Apply chat template. Inspired by hf apply_chat_template() on python.
    /// Both "model" and "custom_template" are optional, but at least one is required. "custom_template" has higher precedence than "model"
    /// NOTE: This function does not use a jinja parser. It only support a pre-defined list of template. See more: https://github.com/ggerganov/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template
@@ -1003,6 +1007,23 @@ extern "C" {
 
    LLAMA_API struct llama_grammar * llama_grammar_copy(const struct llama_grammar * grammar);
 
+   /// @details Apply constraints from grammar
+   LLAMA_API void llama_grammar_sample(
+           const struct llama_grammar * grammar,
+           const struct llama_context * ctx,
+               llama_token_data_array * candidates);
+   LLAMA_API DEPRECATED(void llama_sample_grammar(
+             struct llama_context * ctx,
+           llama_token_data_array * candidates,
+     const struct llama_grammar * grammar),
+       "use llama_grammar_sample instead");
+
+   /// @details Accepts the sampled token into the grammar
+   LLAMA_API void llama_grammar_accept_token(
+           struct llama_grammar * grammar,
+           struct llama_context * ctx,
+           llama_token token);
+
    //
    // Sampling functions
    //
@@ -1084,12 +1105,6 @@ extern "C" {
            llama_token_data_array * candidates,
            float   temp);
 
-   /// @details Apply constraints from grammar
-   LLAMA_API void llama_sample_grammar(
-           struct llama_context * ctx,
-           llama_token_data_array * candidates,
-     const struct llama_grammar * grammar);
-
    /// @details Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
    /// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text.
    /// @param tau The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text.
@@ -1127,12 +1142,6 @@ extern "C" {
            struct llama_context * ctx,
            llama_token_data_array * candidates);
 
-   /// @details Accepts the sampled token into the grammar
-   LLAMA_API void llama_grammar_accept_token(
-           struct llama_context * ctx,
-           struct llama_grammar * grammar,
-           llama_token token);
-
    //
    // Model split
    //
@@ -1175,38 +1184,45 @@ extern "C" {
 
struct ggml_tensor;
 
+const std::vector<std::pair<std::string, struct ggml_tensor *>> & llama_internal_get_tensor_map(
+    struct llama_context * ctx
+);
+
struct llama_partial_utf8 {
    uint32_t value;    // bit value so far (unshifted)
    int      n_remain; // num bytes remaining; -1 indicates invalid sequence
};
 
-struct llama_grammar {
-    const std::vector<std::vector<llama_grammar_element>>   rules;
-    std::vector<std::vector<const llama_grammar_element *>> stacks;
-
-    // buffer for partially generated UTF-8 sequence from accepted tokens
-    llama_partial_utf8 partial_utf8;
-};
-
struct llama_grammar_candidate {
    size_t             index;
    const uint32_t   * code_points;
    llama_partial_utf8 partial_utf8;
};
 
-const std::vector<std::pair<std::string, struct ggml_tensor *>> & llama_internal_get_tensor_map(
-    struct llama_context * ctx
-);
+using llama_grammar_rule  = std::vector<      llama_grammar_element>;
+using llama_grammar_stack = std::vector<const llama_grammar_element *>;
+
+using llama_grammar_rules      = std::vector<llama_grammar_rule>;
+using llama_grammar_stacks     = std::vector<llama_grammar_stack>;
+using llama_grammar_candidates = std::vector<llama_grammar_candidate>;
+
+const llama_grammar_rules  & llama_grammar_get_rules (const struct llama_grammar * grammar);
+      llama_grammar_stacks & llama_grammar_get_stacks(      struct llama_grammar * grammar);
 
void llama_grammar_accept(
-        const std::vector<std::vector<llama_grammar_element>>         & rules,
-        const std::vector<std::vector<const llama_grammar_element *>> & stacks,
-        const uint32_t chr,
-              std::vector<std::vector<const llama_grammar_element *>> & new_stacks);
+        const llama_grammar_rules  & rules,
+        const llama_grammar_stacks & stacks,
+        const uint32_t chr,
+              llama_grammar_stacks & new_stacks);
+
+std::vector<llama_grammar_candidate> llama_grammar_reject_candidates_for_stack(
+        const llama_grammar_rules & rules,
+        const llama_grammar_stack & stack,
+        const llama_grammar_candidates & candidates);
 
std::pair<std::vector<uint32_t>, llama_partial_utf8> decode_utf8(
        const std::string & src,
-        llama_partial_utf8   partial_start);
+        llama_partial_utf8 partial_start);
 
// Randomly selects a token from the candidates based on their probabilities using given std::mt19937.
// This is a temporary workaround in order to fix race conditions when sampling with multiple sequences.
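
Seen from the caller's side, the relocated grammar declarations leave the constrained-sampling flow unchanged. Below is a minimal sketch of one sampling step using only the signatures above, with a greedy argmax standing in for the real sampler and grammar, ctx and cur_p assumed to be set up elsewhere:

#include <cstddef>
#include "llama.h"

// One grammar-constrained sampling step (sketch, not taken from the commit).
static llama_token sample_with_grammar(struct llama_grammar * grammar,
                                       struct llama_context * ctx,
                                       llama_token_data_array * cur_p) {
    // mask out candidates the grammar cannot accept (their logits become -INFINITY)
    llama_grammar_sample(grammar, ctx, cur_p);

    // placeholder sampler: greedy pick of the highest remaining logit
    size_t best = 0;
    for (size_t i = 1; i < cur_p->size; ++i) {
        if (cur_p->data[i].logit > cur_p->data[best].logit) {
            best = i;
        }
    }
    const llama_token id = cur_p->data[best].id;

    // advance the grammar state with the token that was actually chosen
    llama_grammar_accept_token(grammar, ctx, id);

    return id;
}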

‎src/CMakeLists.txt

3 additions, 0 deletions
@@ -14,6 +14,9 @@ endif()
add_library(llama
            ../include/llama.h
            llama.cpp
+            llama-vocab.cpp
+            llama-grammar.cpp
+            llama-sampling.cpp
            unicode.h
            unicode.cpp
            unicode-data.cpp
