Commit 66fb034

Move grammar to function call argument
1 parent 1e844d3 commit 66fb034
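
This commit removes the constructor-level grammar option and instead threads a grammar argument through the sampling and completion calls. A minimal usage sketch of the resulting call-level API (the model and grammar file paths are hypothetical, and LlamaGrammar is assumed to be importable from the package top level):

    # Sketch only: file paths are placeholders, not part of this commit.
    from llama_cpp import Llama, LlamaGrammar

    llm = Llama(model_path="./models/llama-2-7b.ggmlv3.q4_0.bin")  # hypothetical model file
    grammar = LlamaGrammar.from_file("./grammars/json.gbnf")       # hypothetical GBNF grammar

    # The grammar is no longer passed to Llama(); it is supplied per call instead.
    output = llm("JSON describing a cat:", max_tokens=64, grammar=grammar)
    print(output["choices"][0]["text"])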

File tree

1 file changed: +19 -14 lines changed

llama_cpp/llama.py

Lines changed: 19 additions & 14 deletions
--- a/llama_cpp/llama.py
+++ b/llama_cpp/llama.py
@@ -227,7 +227,6 @@ def __init__(
         tensor_split: Optional[List[float]] = None,
         rope_freq_base: float = 10000.0,
         rope_freq_scale: float = 1.0,
-        grammar: Optional[Union[str, Path]] = None,
         n_gqa: Optional[int] = None,  # (TEMPORARY) must be 8 for llama2 70b
         rms_norm_eps: Optional[float] = None,  # (TEMPORARY)
         mul_mat_q: Optional[bool] = None,  # (TEMPORARY)
@@ -254,7 +253,6 @@ def __init__(
             tensor_split: List of floats to split the model across multiple GPUs. If None, the model is not split.
             rope_freq_base: Base frequency for rope sampling.
             rope_freq_scale: Scale factor for rope sampling.
-            grammar: Path to a BNF grammar file to use for grammar based sampling.
             verbose: Print verbose output to stderr.

         Raises:
@@ -383,12 +381,6 @@ def __init__(
         self.scores: npt.NDArray[np.single] = np.ndarray(
             (n_ctx, self._n_vocab), dtype=np.single
         )
-        if grammar is not None:
-            self.grammar = LlamaGrammar.from_file(
-                grammar, verbose=verbose
-            )  # type: Optional[LlamaGrammar]
-        else:
-            self.grammar = None

     @property
     def _input_ids(self) -> npt.NDArray[np.intc]:
@@ -527,6 +519,7 @@ def _sample(
         mirostat_eta: llama_cpp.c_float,
         penalize_nl: bool = True,
         logits_processor: Optional[LogitsProcessorList] = None,
+        grammar: Optional[LlamaGrammar] = None,
     ):
         assert self.ctx is not None
         assert self.n_tokens > 0
@@ -574,11 +567,11 @@ def _sample(
         if not penalize_nl:
             candidates.data[self._token_nl].logit = llama_cpp.c_float(nl_logit)

-        if self.grammar is not None:
+        if grammar is not None:
             llama_cpp.llama_sample_grammar(
                 ctx=self.ctx,
                 candidates=llama_cpp.ctypes.byref(candidates),  # type: ignore
-                grammar=self.grammar.grammar,
+                grammar=grammar.grammar,
             )

         if temp.value == 0.0:
@@ -650,10 +643,10 @@ def _sample(
                 ctx=self.ctx,
                 candidates=llama_cpp.ctypes.byref(candidates),  # type: ignore
             )
-        if self.grammar is not None:
+        if grammar is not None:
             llama_cpp.llama_grammar_accept_token(
                 ctx=self.ctx,
-                grammar=self.grammar.grammar,
+                grammar=grammar.grammar,
                 token=llama_cpp.ctypes.c_int(id),
             )
         return id
@@ -672,6 +665,7 @@ def sample(
         mirostat_tau: float = 5.0,
         penalize_nl: bool = True,
         logits_processor: Optional[LogitsProcessorList] = None,
+        grammar: Optional[LlamaGrammar] = None,
     ):
         """Sample a token from the model.

@@ -705,6 +699,7 @@ def sample(
             mirostat_eta=llama_cpp.c_float(mirostat_eta),
             penalize_nl=penalize_nl,
             logits_processor=logits_processor,
+            grammar=grammar,
         )

     def generate(
@@ -723,6 +718,7 @@ def generate(
         mirostat_eta: float = 0.1,
         logits_processor: Optional[LogitsProcessorList] = None,
         stopping_criteria: Optional[StoppingCriteriaList] = None,
+        grammar: Optional[LlamaGrammar] = None,
     ) -> Generator[int, Optional[Sequence[int]], None]:
         """Create a generator of tokens from a prompt.

@@ -761,8 +757,8 @@ def generate(
         if reset:
             self.reset()

-        if self.grammar is not None:
-            self.grammar.reset()
+        if grammar is not None:
+            grammar.reset()

         while True:
             self.eval(tokens)
@@ -778,6 +774,7 @@ def generate(
                 mirostat_tau=mirostat_tau,
                 mirostat_eta=mirostat_eta,
                 logits_processor=logits_processor,
+                grammar=grammar,
             )
             if stopping_criteria is not None and stopping_criteria(
                 self._input_ids.tolist(), self._scores[-1, :].tolist()
@@ -880,6 +877,7 @@ def _create_completion(
         model: Optional[str] = None,
         stopping_criteria: Optional[StoppingCriteriaList] = None,
         logits_processor: Optional[LogitsProcessorList] = None,
+        grammar: Optional[LlamaGrammar] = None,
     ) -> Union[Iterator[Completion], Iterator[CompletionChunk]]:
         assert self.ctx is not None

@@ -957,6 +955,7 @@ def _create_completion(
             repeat_penalty=repeat_penalty,
             stopping_criteria=stopping_criteria,
             logits_processor=logits_processor,
+            grammar=grammar,
         ):
             if token == self._token_eos:
                 text = self.detokenize(completion_tokens)
@@ -1301,6 +1300,7 @@ def create_completion(
         model: Optional[str] = None,
         stopping_criteria: Optional[StoppingCriteriaList] = None,
         logits_processor: Optional[LogitsProcessorList] = None,
+        grammar: Optional[LlamaGrammar] = None,
     ) -> Union[Completion, Iterator[CompletionChunk]]:
         """Generate text from a prompt.

@@ -1345,6 +1345,7 @@ def create_completion(
             model=model,
             stopping_criteria=stopping_criteria,
             logits_processor=logits_processor,
+            grammar=grammar
         )
         if stream:
             chunks: Iterator[CompletionChunk] = completion_or_chunks
@@ -1374,6 +1375,7 @@ def __call__(
         model: Optional[str] = None,
         stopping_criteria: Optional[StoppingCriteriaList] = None,
         logits_processor: Optional[LogitsProcessorList] = None,
+        grammar: Optional[LlamaGrammar] = None,
     ) -> Union[Completion, Iterator[CompletionChunk]]:
         """Generate text from a prompt.

@@ -1418,6 +1420,7 @@ def __call__(
             model=model,
             stopping_criteria=stopping_criteria,
             logits_processor=logits_processor,
+            grammar=grammar,
         )

     def _convert_text_completion_to_chat(
@@ -1498,6 +1501,7 @@ def create_chat_completion(
         mirostat_eta: float = 0.1,
         model: Optional[str] = None,
         logits_processor: Optional[LogitsProcessorList] = None,
+        grammar: Optional[LlamaGrammar] = None,
     ) -> Union[ChatCompletion, Iterator[ChatCompletionChunk]]:
         """Generate a chat completion from a list of messages.

@@ -1540,6 +1544,7 @@ def create_chat_completion(
             mirostat_eta=mirostat_eta,
             model=model,
             logits_processor=logits_processor,
+            grammar=grammar,
         )
         if stream:
             chunks: Iterator[CompletionChunk] = completion_or_chunks  # type: ignore

0 commit comments
