Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit 1fd8840

Browse filesBrowse files
authored
1 parent 4c74a82 commit 1fd8840
Copy full SHA for 1fd8840

File tree

Expand file treeCollapse file tree

1 file changed

+44
-14
lines changed
Filter options
Expand file treeCollapse file tree

1 file changed

+44
-14
lines changed

‎llama_cpp/llama_grammar.py

Copy file name to clipboardExpand all lines: llama_cpp/llama_grammar.py
+44-14Lines changed: 44 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -891,6 +891,7 @@ def handle_repetitions(min_times: int, max_times: int) -> None:
891891
pos += 1
892892
last_sym_start = out_elements.size()
893893
while pos[0] != '"':
894+
assert pos[0] is not None, "Unexpected end of input"
894895
char_pair = parse_char(pos) # type: Tuple[int, const_char_p]
895896
pos = char_pair[1]
896897
out_elements.push_back(
@@ -920,6 +921,7 @@ def handle_repetitions(min_times: int, max_times: int) -> None:
920921
# : start_type;
921922
# out_elements.push_back({type, char_pair.first});
922923
while pos[0] != "]":
924+
assert pos[0] is not None, "Unexpected end of input"
923925
char_pair = parse_char(pos) # type: Tuple[int, const_char_p]
924926
pos = char_pair[1]
925927
_type = (
@@ -935,6 +937,7 @@ def handle_repetitions(min_times: int, max_times: int) -> None:
935937
# }
936938
# }
937939
if pos[0] == "-" and pos[1] != "]":
940+
assert pos[1] is not None, "Unexpected end of input"
938941
endchar_pair = parse_char(pos + 1) # type: Tuple[int, const_char_p]
939942
pos = endchar_pair[1]
940943
out_elements.push_back(
@@ -1159,33 +1162,59 @@ def parse_rule(state: parse_state, src: const_char_p) -> const_char_p:
11591162
elif pos[0]:
11601163
raise RuntimeError("expecting newline or end at " + str(pos))
11611164
return parse_space(pos, True)
1165+
1166+
#parse_state parse(const char * src) {
1167+
# try {
1168+
# parse_state state;
1169+
# const char * pos = parse_space(src, true);
1170+
# while (*pos) {
1171+
# pos = parse_rule(state, pos);
1172+
# }
1173+
# // Validate the state to ensure that all rules are defined
1174+
# for (const auto & rule : state.rules) {
1175+
# for (const auto & elem : rule) {
1176+
# if (elem.type == LLAMA_GRETYPE_RULE_REF) {
1177+
# // Ensure that the rule at that location exists
1178+
# if (elem.value >= state.rules.size() || state.rules[elem.value].empty()) {
1179+
# // Get the name of the rule that is missing
1180+
# for (const auto & kv : state.symbol_ids) {
1181+
# if (kv.second == elem.value) {
1182+
# throw std::runtime_error("Undefined rule identifier '" + kv.first + "'");
1183+
# }
1184+
# }
1185+
# }
1186+
# }
1187+
# }
1188+
# }
1189+
# return state;
1190+
# } catch (const std::exception & err) {
1191+
# fprintf(stderr, "%s: error parsing grammar: %s\n", __func__, err.what());
1192+
# return parse_state();
1193+
# }
1194+
#}
11621195

11631196

1164-
# parse_state parse(const char * src) {
1165-
# try {
1166-
# parse_state state;
1167-
# const char * pos = parse_space(src, true);
1168-
# while (*pos) {
1169-
# pos = parse_rule(state, pos);
1170-
# }
1171-
# return state;
1172-
# } catch (const std::exception & err) {
1173-
# fprintf(stderr, "%s: error parsing grammar: %s\n", __func__, err.what());
1174-
# return parse_state();
1175-
# }
1176-
# }
11771197
def parse(src: const_char_p) -> parse_state:
11781198
try:
11791199
state = parse_state() # type: parse_state
11801200
pos = parse_space(src, True) # type: const_char_p
11811201
while pos[0]:
11821202
pos = parse_rule(state, pos)
1203+
# Validate the state to ensure that all rules are defined
1204+
for rule in state.rules:
1205+
for elem in rule:
1206+
if elem.type == llama_gretype.LLAMA_GRETYPE_RULE_REF:
1207+
# Ensure that the rule at that location exists
1208+
if elem.value >= len(state.rules) or not state.rules[elem.value]:
1209+
# Get the name of the rule that is missing
1210+
for kv in state.symbol_ids:
1211+
if kv.second == elem.value:
1212+
raise RuntimeError("Undefined rule identifier '" + kv.first + "'")
11831213
return state
11841214
except Exception as err:
11851215
print(f"{parse.__name__}: error parsing grammar: {err}")
11861216
return parse_state()
11871217

1188-
11891218
# void print_grammar_char(FILE * file, uint32_t c) {
11901219
# if (0x20 <= c && c <= 0x7f) {
11911220
# fprintf(file, "%c", static_cast<char>(c));
@@ -1283,6 +1312,7 @@ def print_rule(
12831312
# }
12841313

12851314

1315+
12861316
for i, elem in enumerate(rule[:-1]):
12871317
case = elem.type # type: llama_gretype
12881318
if case is llama_gretype.LLAMA_GRETYPE_END:

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.