Commit 01ceae6

ochafik authored and orca-zhang committed
tool-call: support Command R7B (+ return tool_plan "thoughts" in API) (ggml-org#11585)
* `tool-call`: support Command R7B (w/ tool_plan return)
* `tool-call`: cleaner preservation of tokens + warn when likely bad chat template override
* `tool-call`: test cleanup / handle lazy grammar triggers
1 parent 46b7e4f commit 01ceae6

File tree: 8 files changed (+420 −56 lines)

‎common/chat.cpp

+84 −2 (84 additions, 2 deletions)

@@ -16,6 +16,7 @@ std::string common_chat_format_name(common_chat_format format) {
         case COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2: return "Functionary v3.2";
         case COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1: return "Functionary v3.1 Llama 3.1";
         case COMMON_CHAT_FORMAT_HERMES_2_PRO: return "Hermes 2 Pro";
+        case COMMON_CHAT_FORMAT_COMMAND_R7B: return "Command R7B";
         default:
             throw std::runtime_error("Unknown chat format");
     }
@@ -317,6 +318,79 @@ static common_chat_msg common_chat_parse_mistral_nemo(const std::string & input)
     return parse_prefixed_json_tool_call_array(input, "[TOOL_CALLS]");
 }
 
+static common_chat_params common_chat_params_init_command_r7b(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
+    common_chat_params data;
+    data.grammar_lazy = inputs.tool_choice != "required";
+    data.grammar = build_grammar([&](const common_grammar_builder & builder) {
+        auto schemas = json::array();
+        foreach_function(inputs.tools, [&](const json & tool) {
+            const auto & function = tool["function"];
+            schemas.push_back({
+                {"type", "object"},
+                {"properties", {
+                    {"tool_call_id", {
+                        {"type", "string"},
+                        // Command-R's template expects an integer string.
+                        {"pattern", "^[0-9]{1,10}$"},
+                    }},
+                    {"tool_name", {
+                        {"type", "string"},
+                        {"const", function["name"]},
+                    }},
+                    {"parameters", function["parameters"]},
+                }},
+                {"required", json::array({"tool_call_id", "tool_name", "parameters"})},
+            });
+        });
+        auto schema = json {
+            {"type", "array"},
+            {"items", schemas.size() == 1 ? schemas[0] : json {{"anyOf", schemas}}},
+            {"minItems", 1},
+        };
+        if (!inputs.parallel_tool_calls) {
+            schema["maxItems"] = 1;
+        }
+        builder.add_rule("root", "\"<|START_ACTION|>\" " + builder.add_schema("tool_calls", schema) + " \"<|END_ACTION|>\"");
+    }, grammar_options);
+    data.grammar_triggers.push_back({"<|START_ACTION|>", /* .at_start = */ false});
+    data.preserved_tokens = {
+        "<|START_RESPONSE|>",
+        "<|END_RESPONSE|>",
+        "<|START_THINKING|>",
+        "<|END_THINKING|>",
+        "<|END_ACTION|>",
+    };
+    data.prompt = tmpl.apply(inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
+    data.format = COMMON_CHAT_FORMAT_COMMAND_R7B;
+    return data;
+}
+static common_chat_msg common_chat_parse_command_r7b(const std::string & input) {
+    static std::regex response_regex("<\\|START_RESPONSE\\|>(.*?)<\\|END_RESPONSE\\|>");
+    static std::regex thought_action_regex("<\\|START_THINKING\\|>([\\s\\S\\n\\r]*?)<\\|END_THINKING\\|><\\|START_ACTION\\|>([\\s\\S\\n\\r]*?)<\\|END_ACTION\\|>");
+    std::smatch match;
+
+    common_chat_msg result;
+    result.role = "assistant";
+    if (std::regex_match(input, match, response_regex)) {
+        result.content = match[1].str();
+    } else if (std::regex_match(input, match, thought_action_regex)) {
+        result.tool_plan = match[1].str();
+        auto actions_str = match[2].str();
+        auto actions = json::parse(actions_str);
+        for (const auto & action : actions) {
+            result.tool_calls.push_back({
+                /* .name = */ action["tool_name"],
+                /* .arguments = */ action["parameters"].dump(),
+                /* .id = */ action["tool_call_id"],
+            });
+        }
+    } else {
+        LOG_ERR("Failed to parse command_r output");
+        result.content = input;
+    }
+    return result;
+}
+
 static void expect_tool_parameters(const std::string & name, const json & parameters, const std::vector<std::string> & expected_properties) {
     if (!parameters.is_object() || !parameters.contains("type") || parameters["type"] != "object" || !parameters.contains("properties") || !parameters.contains("required")) {
         throw std::runtime_error("Parameters of tool " + name + " must be an object w/ required properties");
@@ -462,6 +536,10 @@ static common_chat_params common_chat_params_init_deepseek_r1(const common_chat_
             "\"<|tool▁call▁begin|>function<|tool▁sep|>" + name + "\\n```json\\n\" " + args_rule + " \"```<|tool▁call▁end|>\""));
         });
         data.grammar_triggers.push_back({"<|tool▁calls▁begin|>", /* .at_start = */ false});
+        data.preserved_tokens = {
+            "<|tool▁sep|>",
+            "<|tool▁call▁end|>",
+        };
         builder.add_rule("root", "\"<|tool▁calls▁begin|>\" (" + string_join(tool_rules, " | ") + ")" + (inputs.parallel_tool_calls ? "*" : "") + " space");
     }, grammar_options);
     data.prompt = tmpl.apply(inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
@@ -704,8 +782,7 @@ static common_chat_params common_chat_params_init_hermes_2_pro(const common_chat
         auto tool_call = "\"<tool_call>\" space " + builder.add_rule("tool_call", string_join(tool_rules, " | ")) + " \"</tool_call>\" space";
         builder.add_rule("root", inputs.parallel_tool_calls ? "(" + tool_call + ")+" : tool_call);
         data.grammar_triggers.push_back({"<tool_call>", /* .at_start = */ false});
-        // Not really a trigger but need to print this special token to get a successful parse.
-        data.grammar_triggers.push_back({"</tool_call>", /* .at_start = */ false});
+        data.preserved_tokens = { "</tool_call>" };
     }, grammar_options);
 
     data.prompt = tmpl.apply(inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
@@ -822,6 +899,9 @@ common_chat_params common_chat_params_init(const common_chat_template & tmpl, co
     if (src.find("[TOOL_CALLS]") != std::string::npos) {
         return common_chat_params_init_mistral_nemo(tmpl, inputs);
     }
+    if (src.find("<|END_THINKING|><|START_ACTION|>") != std::string::npos) {
+        return common_chat_params_init_command_r7b(tmpl, inputs);
+    }
     return common_chat_params_init_generic(tmpl, inputs);
 }
 
@@ -855,6 +935,8 @@ common_chat_msg common_chat_parse(const std::string & input, common_chat_format
             return common_chat_parse_hermes_2_pro(input);
         case COMMON_CHAT_FORMAT_FIREFUNCTION_V2:
             return common_chat_parse_firefunction_v2(input);
+        case COMMON_CHAT_FORMAT_COMMAND_R7B:
+            return common_chat_parse_command_r7b(input);
         default:
             throw std::runtime_error("Unsupported format: " + common_chat_format_name(format));
     }

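Taken together, the new grammar constrains the model to emit a well-formed `<|START_ACTION|>…<|END_ACTION|>` block, and the new parser turns that block (plus any preceding `<|START_THINKING|>` "thoughts") into a `common_chat_msg`. A minimal usage sketch with a hypothetical model output and tool name (assumes the `common/chat.hpp` and `common/common.h` declarations from this commit are on the include path):

```cpp
#include <cstdio>
#include <string>

#include "chat.hpp"   // common_chat_parse, COMMON_CHAT_FORMAT_COMMAND_R7B
#include "common.h"   // common_chat_msg

int main() {
    // Hypothetical raw Command R7B generation: a thinking block followed by an
    // action block; tool_call_id is an integer string, as the grammar requires.
    const std::string output =
        "<|START_THINKING|>I should look up the weather first.<|END_THINKING|>"
        "<|START_ACTION|>[{\"tool_call_id\": \"0\", \"tool_name\": \"get_weather\","
        " \"parameters\": {\"location\": \"Paris\"}}]<|END_ACTION|>";

    common_chat_msg msg = common_chat_parse(output, COMMON_CHAT_FORMAT_COMMAND_R7B);

    printf("tool_plan: %s\n", msg.tool_plan.c_str());               // the "thoughts"
    printf("tool:      %s\n", msg.tool_calls[0].name.c_str());      // "get_weather"
    printf("args:      %s\n", msg.tool_calls[0].arguments.c_str()); // {"location":"Paris"}
    printf("id:        %s\n", msg.tool_calls[0].id.c_str());        // "0"
    return 0;
}
```

The `tool_plan` field is new in this commit; it is what the server later surfaces as the "thoughts" in the API response (see `server.cpp` below).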
‎common/chat.hpp

+2 −0 (2 additions, 0 deletions)

@@ -32,6 +32,7 @@ enum common_chat_format {
     COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2,
     COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1,
     COMMON_CHAT_FORMAT_HERMES_2_PRO,
+    COMMON_CHAT_FORMAT_COMMAND_R7B,
 
     COMMON_CHAT_FORMAT_COUNT, // Not a format, just the # formats
 };
@@ -42,6 +43,7 @@ struct common_chat_params {
     std::string grammar;
     bool grammar_lazy = false;
     std::vector<common_grammar_trigger> grammar_triggers;
+    std::vector<std::string> preserved_tokens;
     std::vector<std::string> additional_stops;
 };
 

‎common/common.h

+3 −0 (3 additions, 0 deletions)

@@ -4,6 +4,7 @@
 
 #include "llama-cpp.h"
 
+#include <set>
 #include <string>
 #include <vector>
 #include <sstream>
@@ -163,6 +164,7 @@ struct common_params_sampling {
     bool grammar_lazy = false;
     std::vector<common_grammar_trigger> grammar_trigger_words; // optional trigger words to trigger lazy grammar
     std::vector<llama_token> grammar_trigger_tokens; // optional trigger tokens to trigger lazy grammar and print trigger special tokens.
+    std::set<llama_token> preserved_tokens;
 
     std::vector<llama_logit_bias> logit_bias; // logit biases to apply
@@ -621,6 +623,7 @@ struct common_chat_msg {
     std::string role;
     std::string content;
     std::vector<common_tool_call> tool_calls;
+    std::string tool_plan = "";
 };
 
 // Check if the template supplied via "--chat-template" is supported or not. Returns true if it's valid

‎examples/server/README.md

+15 −7 (15 additions, 7 deletions)

@@ -1128,6 +1128,7 @@ curl http://localhost:8080/v1/chat/completions \
 - Hermes 2/3, Qwen 2.5
 - Mistral Nemo
 - Firefunction v2
+- Command R7B
 - DeepSeek R1 (WIP / seems reluctant to call any tools?)
 
 <details>
@@ -1202,21 +1203,28 @@ curl http://localhost:8080/v1/chat/completions \
 ```shell
 # Native support:
 llama-server --jinja -fa -hf bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M
-llama-server --jinja -fa -hf bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M
-llama-server --jinja -fa -hf bartowski/Llama-3.2-3B-Instruct-GGUF:Q6_K
+llama-server --jinja -fa -hf bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q6_K_L
 llama-server --jinja -fa -hf bartowski/functionary-small-v3.2-GGUF:Q4_K_M
-llama-server --jinja -fa -hf bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M \
-  --chat-template-file <( python scripts/get_chat_template.py NousResearch/Hermes-2-Pro-Llama-3-8B )
+llama-server --jinja -fa -hf bartowski/Llama-3.3-70B-Instruct-GGUF:Q4_K_M
 
 # Native support requires the right template for these GGUFs:
+
+llama-server --jinja -fa -hf bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M \
+  --chat-template-file <( python scripts/get_chat_template.py NousResearch/Hermes-2-Pro-Llama-3-8B tool_use )
+
 llama-server --jinja -fa -hf bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M \
   --chat-template-file <( python scripts/get_chat_template.py NousResearch/Hermes-3-Llama-3.1-8B tool_use )
+
 llama-server --jinja -fa -hf bartowski/firefunction-v2-GGUF -hff firefunction-v2-IQ1_M.gguf \
-  --chat-template-file <( python scripts/get_chat_template.py fireworks-ai/firellama-3-firefunction-v2 )
+  --chat-template-file <( python scripts/get_chat_template.py fireworks-ai/llama-3-firefunction-v2 tool_use )
+
+llama-server --jinja -fa -hf bartowski/c4ai-command-r7b-12-2024-GGUF:Q6_K_L \
+  --chat-template-file <( python scripts/get_chat_template.py CohereForAI/c4ai-command-r7b-12-2024 tool_use )
 
 # Generic format support
-llama-server --jinja -fa -hf bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M
-llama-server --jinja -fa -hf bartowski/gemma-2-2b-it-GGUF:Q4_K_M
+llama-server --jinja -fa -hf bartowski/phi-4-GGUF:Q4_0
+llama-server --jinja -fa -hf bartowski/gemma-2-2b-it-GGUF:Q8_0
+llama-server --jinja -fa -hf bartowski/c4ai-command-r-v01-GGUF:Q2_K
 ```
 
- Test in CLI:

‎examples/server/server.cpp

+38 −14 (38 additions, 14 deletions)

@@ -131,6 +131,11 @@ struct slot_params {
             lora.push_back({{"id", i}, {"scale", this->lora[i].scale}});
         }
 
+        std::vector<std::string> grammar_trigger_words;
+        for (const auto & trigger : sampling.grammar_trigger_words) {
+            grammar_trigger_words.push_back(trigger.word);
+        }
+
         return json {
             {"n_predict", n_predict}, // Server configured n_predict
             {"seed", sampling.seed},
@@ -165,8 +170,9 @@ struct slot_params {
             {"n_probs", sampling.n_probs},
             {"min_keep", sampling.min_keep},
             {"grammar", sampling.grammar},
-            // {"grammar_trigger_words", sampling.grammar_trigger_words},
+            {"grammar_trigger_words", grammar_trigger_words},
             {"grammar_trigger_tokens", sampling.grammar_trigger_tokens},
+            {"preserved_tokens", sampling.preserved_tokens},
             {"samplers", samplers},
             {"speculative.n_max", speculative.n_max},
             {"speculative.n_min", speculative.n_min},
@@ -363,12 +369,26 @@ struct server_task {
                 if (ids.size() == 1) {
                     LOG_DBG("Grammar trigger token: %d (`%s`)\n", ids[0], trigger.word.c_str());
                     params.sampling.grammar_trigger_tokens.push_back(ids[0]);
+                    params.sampling.preserved_tokens.insert(ids[0]);
                     continue;
                 }
                 LOG_DBG("Grammar trigger word: `%s`\n", trigger.word.c_str());
                 params.sampling.grammar_trigger_words.push_back(trigger);
             }
         }
+        const auto preserved_tokens = data.find("preserved_tokens");
+        if (preserved_tokens != data.end()) {
+            for (const auto & t : *preserved_tokens) {
+                auto ids = common_tokenize(vocab, t.get<std::string>(), /* add_special= */ false, /* parse_special= */ true);
+                if (ids.size() == 1) {
+                    LOG_DBG("Preserved token: %d\n", ids[0]);
+                    params.sampling.preserved_tokens.insert(ids[0]);
+                } else {
+                    // This may happen when using a tool call style meant for a model with special tokens to preserve on a model without said tokens.
+                    LOG_WRN("Not preserved because more than 1 token (wrong chat template override?): %s\n", t.get<std::string>().c_str());
+                }
+            }
+        }
         if (params.sampling.grammar_lazy) {
             GGML_ASSERT(params.sampling.grammar_trigger_tokens.size() > 0 || params.sampling.grammar_trigger_words.size() > 0);
         }
@@ -695,19 +715,19 @@ struct server_task_result_cmpl_final : server_task_result {
 
     json to_json_oaicompat_chat() {
         std::string finish_reason = "length";
-        common_chat_msg message;
+        common_chat_msg msg;
         if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) {
             LOG_DBG("Parsing chat message: %s\n", content.c_str());
-            message = common_chat_parse(content, oaicompat_chat_format);
-            finish_reason = message.tool_calls.empty() ? "stop" : "tool_calls";
+            msg = common_chat_parse(content, oaicompat_chat_format);
+            finish_reason = msg.tool_calls.empty() ? "stop" : "tool_calls";
         } else {
-            message.content = content;
+            msg.content = content;
         }
 
         json tool_calls;
-        if (!message.tool_calls.empty()) {
+        if (!msg.tool_calls.empty()) {
             tool_calls = json::array();
-            for (const auto & tc : message.tool_calls) {
+            for (const auto & tc : msg.tool_calls) {
                 tool_calls.push_back({
                     {"type", "function"},
                     {"function", {
@@ -719,14 +739,19 @@ struct server_task_result_cmpl_final : server_task_result {
             }
         }
 
+        json message {
+            {"content", msg.content},
+            {"tool_calls", tool_calls},
+            {"role", "assistant"},
+        };
+        if (!msg.tool_plan.empty()) {
+            message["tool_plan"] = msg.tool_plan;
+        }
+
         json choice {
             {"finish_reason", finish_reason},
            {"index", 0},
-            {"message", json {
-                {"content", message.content},
-                {"tool_calls", tool_calls},
-                {"role", "assistant"},
-            }},
+            {"message", message},
         };
 
         if (!stream && probs_output.size() > 0) {
@@ -2833,8 +2858,7 @@ struct server_context {
         server_slot * slot_batched = nullptr;
 
         auto accept_special_token = [&](server_slot & slot, llama_token token) {
-            const auto & trigger_tokens = slot.params.sampling.grammar_trigger_tokens;
-            return params_base.special || std::find(trigger_tokens.begin(), trigger_tokens.end(), token) != trigger_tokens.end();
+            return params_base.special || slot.params.sampling.preserved_tokens.find(token) != slot.params.sampling.preserved_tokens.end();
         };
 
         // frist, add sampled tokens from any ongoing sequences

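To make the net effect of `to_json_oaicompat_chat()` concrete, here is a rough sketch (assumed tool name and arguments, and an assumed layout for the tool-call entry, since that part of the diff is truncated above) of the assistant message object it now builds when the parser returned a tool plan:

```cpp
#include <cstdio>
#include <string>

#include <nlohmann/json.hpp>

using json = nlohmann::json;

int main() {
    // Values as common_chat_parse() might have produced them for a Command R7B turn.
    std::string content;  // empty: the turn was pure tool calling
    std::string tool_plan = "I should look up the weather first.";

    json tool_call {
        {"type", "function"},
        {"function", {{"name", "get_weather"}, {"arguments", "{\"location\":\"Paris\"}"}}},
        {"id", "0"},
    };
    json tool_calls = json::array({tool_call});

    // Mirrors the construction added in this commit: "tool_plan" is only attached
    // when the model actually emitted a <|START_THINKING|> block.
    json message {
        {"content", content},
        {"tool_calls", tool_calls},
        {"role", "assistant"},
    };
    if (!tool_plan.empty()) {
        message["tool_plan"] = tool_plan;
    }

    // Prints the OpenAI-style message with the extra "tool_plan" key alongside
    // the standard "content"/"tool_calls"/"role" keys.
    printf("%s\n", message.dump(2).c_str());
    return 0;
}
```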
‎examples/server/utils.hpp

+1 −0 (1 addition, 0 deletions)

@@ -662,6 +662,7 @@ static json oaicompat_completion_params_parse(
         });
     }
     llama_params["grammar_triggers"] = grammar_triggers;
+    llama_params["preserved_tokens"] = chat_params.preserved_tokens;
     for (const auto & stop : chat_params.additional_stops) {
         llama_params["stop"].push_back(stop);
     }

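Finally, the request-side view: a rough sketch of the tool-calling fields `oaicompat_completion_params_parse()` now forwards for a Command R7B request (the trigger-object field names are an assumption for illustration; the preserved-token values are the ones set in `common/chat.cpp` above):

```cpp
#include <nlohmann/json.hpp>

using json = nlohmann::json;

int main() {
    json llama_params;

    // Built from chat_params.grammar_triggers earlier in the same function
    // (field names assumed here).
    json trigger {{"word", "<|START_ACTION|>"}, {"at_start", false}};
    llama_params["grammar_triggers"] = json::array({trigger});

    // New in this commit: special tokens the server must keep verbatim in the
    // generated text so common_chat_parse() can still see them.
    llama_params["preserved_tokens"] = json::array({
        "<|START_RESPONSE|>", "<|END_RESPONSE|>",
        "<|START_THINKING|>", "<|END_THINKING|>",
        "<|END_ACTION|>",
    });

    // server.cpp then tokenizes each entry: single-token entries go into
    // params.sampling.preserved_tokens; multi-token entries only produce the
    // "wrong chat template override?" warning.
    return 0;
}
```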