Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit b07713c

Browse files
author
c0sogi
committed
reset grammar for every generation
1 parent 418aa83 commit b07713c
Copy full SHA for b07713c

File tree

Expand file tree / Collapse file tree

2 files changed

+39
-95
lines changed
Filter options
Expand file tree / Collapse file tree

2 files changed

+39
-95
lines changed

‎llama_cpp/llama.py

Copy file name to clipboard | Expand all lines: llama_cpp/llama.py
+4 -5 | Lines changed: 4 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -364,7 +364,7 @@ def __init__(
364364
)
365365
if grammar is not None:
366366
self.grammar = LlamaGrammar.from_file(
367-
grammar
367+
grammar, verbose=verbose
368368
) # type: Optional[LlamaGrammar]
369369
else:
370370
self.grammar = None
@@ -723,7 +723,6 @@ def generate(
723723
The generated tokens.
724724
"""
725725
assert self.ctx is not None
726-
727726
if reset and len(self._input_ids) > 0:
728727
longest_prefix = 0
729728
for a, b in zip(self._input_ids, tokens[:-1]):
@@ -741,6 +740,9 @@ def generate(
741740
if reset:
742741
self.reset()
743742

743+
if self.grammar is not None:
744+
self.grammar.reset()
745+
744746
while True:
745747
self.eval(tokens)
746748
token = self.sample(
@@ -1534,9 +1536,6 @@ def __del__(self):
15341536
if self.ctx is not None:
15351537
llama_cpp.llama_free(self.ctx)
15361538
self.ctx = None
1537-
if self.grammar is not None:
1538-
llama_cpp.llama_grammar_free(self.grammar.grammar)
1539-
self.grammar = None
15401539

15411540
def __getstate__(self):
15421541
return dict(

‎llama_cpp/llama_grammar.py

Copy file name to clipboard | Expand all lines: llama_cpp/llama_grammar.py
+35 -90 | Lines changed: 35 additions & 90 deletions
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,5 @@
11
"""C++ implementation of the llama grammar parser."""
22
# flake8: noqa
3-
import argparse
43
from pathlib import Path
54
import sys
65
from ctypes import * # type: ignore
@@ -19,7 +18,7 @@
1918
overload,
2019
)
2120

22-
import llama_cpp
21+
from . import llama_cpp
2322

2423
# Type aliases
2524
llama_grammar_element = llama_cpp.llama_grammar_element
@@ -41,11 +40,19 @@ class Sentinel:
4140
class LlamaGrammar:
4241
"""Keeps reference counts of all the arguments, so that they are not
4342
garbage collected by Python."""
43+
44+
def __del__(self) -> None:
45+
"""Free the grammar pointer when the object is deleted."""
46+
if self.grammar is not None:
47+
llama_cpp.llama_grammar_free(self.grammar)
48+
self.grammar = None
4449

4550
def __init__(
4651
self,
4752
parsed_grammar: "parse_state",
4853
) -> None:
54+
"""Initialize the grammar pointer from the parsed state."""
55+
self.parsed_grammar = parsed_grammar
4956
grammar_rules = (
5057
parsed_grammar.c_rules()
5158
) # type: std.vector[std.vector[llama_grammar_element]]
@@ -69,22 +76,25 @@ def __init__(
6976

7077
self.n_rules = c_size_t(grammar_rules.size())
7178
self.start_rule_index = c_size_t(parsed_grammar.symbol_ids.at("root"))
72-
self.grammar = self.init_grammar()
79+
self._grammar = llama_cpp.llama_grammar_init(
80+
self.rules, self.n_rules, self.start_rule_index
81+
)
7382

7483
@classmethod
75-
def from_string(cls, grammar: str) -> "LlamaGrammar":
84+
def from_string(cls, grammar: str, verbose: bool = True) -> "LlamaGrammar":
7685
parsed_grammar = parse(const_char_p(grammar)) # type: parse_state
7786
if parsed_grammar.rules.empty():
7887
raise ValueError(
7988
f"{cls.from_string.__name__}: error parsing grammar file: parsed_grammar.rules is empty"
8089
)
81-
print(f"{cls.from_string.__name__} grammar:", file=sys.stderr)
82-
print_grammar(sys.stdout, parsed_grammar)
83-
print(file=sys.stderr)
90+
if verbose:
91+
print(f"{cls.from_string.__name__} grammar:", file=sys.stderr)
92+
print_grammar(sys.stdout, parsed_grammar)
93+
print(file=sys.stderr)
8494
return cls(parsed_grammar)
8595

8696
@classmethod
87-
def from_file(cls, file: Union[str, Path]) -> "LlamaGrammar":
97+
def from_file(cls, file: Union[str, Path], verbose: bool = True) -> "LlamaGrammar":
8898
try:
8999
with open(file) as f:
90100
grammar = f.read()
@@ -94,14 +104,27 @@ def from_file(cls, file: Union[str, Path]) -> "LlamaGrammar":
94104
)
95105

96106
if grammar:
97-
return cls.from_string(grammar)
107+
return cls.from_string(grammar, verbose=verbose)
98108

99109
raise ValueError(
100110
f"{cls.from_file.__name__}: error parsing grammar file: params_grammer is empty"
101111
)
102112

103-
def init_grammar(self) -> llama_grammar_p:
104-
return llama_cpp.llama_grammar_init(
113+
@property
114+
def grammar(self) -> llama_grammar_p:
115+
if self._grammar is None:
116+
raise ValueError(
117+
f"{self.__class__.__name__}.grammar: grammar is freed"
118+
)
119+
return self._grammar
120+
121+
@grammar.setter
122+
def grammar(self, value: Optional[llama_grammar_p]) -> None:
123+
self._grammar = value
124+
125+
def reset(self) -> None:
126+
llama_cpp.llama_grammar_free(self.grammar)
127+
self.grammar = llama_cpp.llama_grammar_init(
105128
self.rules, self.n_rules, self.start_rule_index
106129
)
107130

@@ -1216,82 +1239,4 @@ def print_grammar(file: TextIO, state: parse_state) -> None:
12161239
print(
12171240
f"{print_grammar.__name__}: error printing grammar: {err}",
12181241
file=sys.stderr,
1219-
)
1220-
1221-
1222-
# def convert_to_rules(
1223-
# llama_grammar_elements: std.vector[std.vector[llama_grammar_element]],
1224-
# ) -> Array[llama_grammar_element_p]:
1225-
# """Make an Array object that is used for `llama_grammer_init`"""
1226-
1227-
# # Step 1: Convert each list to llama_grammar_element array and get pointer
1228-
# element_arrays = [
1229-
# (llama_grammar_element * len(subvector))(*subvector)
1230-
# for subvector in llama_grammar_elements
1231-
# ] # type: List[Array[llama_grammar_element]]
1232-
1233-
# # Step 2: Get pointer of each array
1234-
# element_array_pointers = [
1235-
# cast(subarray, llama_grammar_element_p) for subarray in element_arrays
1236-
# ] # type: List[llama_grammar_element_p]
1237-
1238-
# # Step 3: Make array of these pointers and get its pointer
1239-
# return (llama_grammar_element_p * len(element_array_pointers))(
1240-
# *element_array_pointers
1241-
# )
1242-
1243-
1244-
if __name__ == "__main__":
1245-
parser = argparse.ArgumentParser(
1246-
description="Generate C++ parser from GBNF grammar"
1247-
)
1248-
parser.add_argument(
1249-
"-g",
1250-
"--grammar",
1251-
type=str,
1252-
default="./vendor/llama.cpp/grammars/json.gbnf",
1253-
help="path to GBNF grammar file",
1254-
)
1255-
1256-
args = parser.parse_args()
1257-
llama_grammar = LlamaGrammar.from_file(Path(args.grammar))
1258-
llama_grammar_ptr = llama_grammar.init_grammar()
1259-
1260-
# ----- USAGE:
1261-
# llama_cpp.llama_sample_grammar(ctx=..., candidates=..., grammar=llama_grammar_p)
1262-
# llama_cpp.llama_grammar_accept_token(ctx=..., grammar=llama_grammar_p, token=...)
1263-
1264-
# ----- SAMPLE OUTPUT:
1265-
# main grammar:
1266-
# root ::= object
1267-
# object ::= [{] ws object_11 [}] ws
1268-
# value ::= object | array | string | number | value_6 ws
1269-
# array ::= [[] ws array_15 []] ws
1270-
# string ::= ["] string_18 ["] ws
1271-
# number ::= number_19 number_25 number_29 ws
1272-
# value_6 ::= [t] [r] [u] [e] | [f] [a] [l] [s] [e] | [n] [u] [l] [l]
1273-
# ws ::= ws_31
1274-
# object_8 ::= string [:] ws value object_10
1275-
# object_9 ::= [,] ws string [:] ws value
1276-
# object_10 ::= object_9 object_10 |
1277-
# object_11 ::= object_8 |
1278-
# array_12 ::= value array_14
1279-
# array_13 ::= [,] ws value
1280-
# array_14 ::= array_13 array_14 |
1281-
# array_15 ::= array_12 |
1282-
# string_16 ::= [^"\] | [\] string_17
1283-
# string_17 ::= ["\/bfnrt] | [u] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F]
1284-
# string_18 ::= string_16 string_18 |
1285-
# number_19 ::= number_20 number_21
1286-
# number_20 ::= [-] |
1287-
# number_21 ::= [0-9] | [1-9] number_22
1288-
# number_22 ::= [0-9] number_22 |
1289-
# number_23 ::= [.] number_24
1290-
# number_24 ::= [0-9] number_24 | [0-9]
1291-
# number_25 ::= number_23 |
1292-
# number_26 ::= [eE] number_27 number_28
1293-
# number_27 ::= [-+] |
1294-
# number_28 ::= [0-9] number_28 | [0-9]
1295-
# number_29 ::= number_26 |
1296-
# ws_30 ::= [ <U+0009><U+000A>] ws
1297-
# ws_31 ::= ws_30 |
1242+
)

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.