Commit 418aa83

Added grammar based sampling

1 parent ac188a2 commit 418aa83

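This commit threads a new `grammar` argument through `Llama.__init__` and applies it during sampling (diff below). A minimal usage sketch, assuming hypothetical model and grammar file paths; the completion-dict shape is existing llama-cpp-python behavior:

from llama_cpp import Llama

# Both paths below are hypothetical placeholders.
llm = Llama(
    model_path="./models/7B/ggml-model.bin",
    grammar="./grammars/answer.gbnf",  # new in this commit: a str or pathlib.Path
)
output = llm("Q: Is water wet? A:", max_tokens=4)
print(output["choices"][0]["text"])  # generation is constrained by the grammar
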
File tree

2 files changed: +512 -518 lines changed

llama_cpp/llama.py: 32 additions & 4 deletions
@@ -1,4 +1,5 @@
 import os
+from pathlib import Path
 import sys
 import uuid
 import time
@@ -23,6 +24,7 @@
 
 from . import llama_cpp
 from .llama_types import *
+from .llama_grammar import LlamaGrammar
 
 import numpy as np
 import numpy.typing as npt
@@ -223,6 +225,7 @@ def __init__(
         tensor_split: Optional[List[float]] = None,
         rope_freq_base: float = 10000.0,
         rope_freq_scale: float = 1.0,
+        grammar: Optional[Union[str, Path]] = None,
         n_gqa: Optional[int] = None,  # (TEMPORARY) must be 8 for llama2 70b
         rms_norm_eps: Optional[float] = None,  # (TEMPORARY)
         verbose: bool = True,
@@ -248,6 +251,7 @@ def __init__(
             tensor_split: List of floats to split the model across multiple GPUs. If None, the model is not split.
             rope_freq_base: Base frequency for rope sampling.
             rope_freq_scale: Scale factor for rope sampling.
+            grammar: Path to a BNF grammar file to use for grammar based sampling.
             verbose: Print verbose output to stderr.
 
         Raises:
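For illustration, the kind of file the new `grammar` argument points at is a llama.cpp GBNF grammar. A hypothetical answer.gbnf that restricts generation to a one-word yes/no answer:

root ::= "yes" | "no"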
@@ -358,6 +362,12 @@ def __init__(
         self.scores: npt.NDArray[np.single] = np.ndarray(
             (n_ctx, self._n_vocab), dtype=np.single
         )
+        if grammar is not None:
+            self.grammar = LlamaGrammar.from_file(
+                grammar
+            )  # type: Optional[LlamaGrammar]
+        else:
+            self.grammar = None
 
     @property
     def _input_ids(self) -> npt.NDArray[np.intc]:
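The helper used above can presumably also be called standalone. A minimal sketch of building the grammar object directly, mirroring the __init__ code in this hunk (the file path is a placeholder):

from pathlib import Path
from llama_cpp.llama_grammar import LlamaGrammar

grammar = LlamaGrammar.from_file(Path("answer.gbnf"))  # placeholder path
# grammar.grammar is the underlying llama.cpp grammar handle passed to the
# C API calls in the hunks below.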
@@ -542,8 +552,16 @@ def _sample(
         )
         if not penalize_nl:
             candidates.data[self._token_nl].logit = llama_cpp.c_float(nl_logit)
+
+        if self.grammar is not None:
+            llama_cpp.llama_sample_grammar(
+                ctx=self.ctx,
+                candidates=llama_cpp.ctypes.byref(candidates),  # type: ignore
+                grammar=self.grammar.grammar,
+            )
+
         if temp.value == 0.0:
-            return llama_cpp.llama_sample_token_greedy(
+            id = llama_cpp.llama_sample_token_greedy(
                 ctx=self.ctx,
                 candidates=llama_cpp.ctypes.byref(candidates),  # type: ignore
             )
@@ -555,7 +573,7 @@ def _sample(
                 candidates=llama_cpp.ctypes.byref(candidates),  # type: ignore
                 temp=temp,
             )
-            return llama_cpp.llama_sample_token_mirostat(
+            id = llama_cpp.llama_sample_token_mirostat(
                 ctx=self.ctx,
                 candidates=llama_cpp.ctypes.byref(candidates),  # type: ignore
                 tau=mirostat_tau,
@@ -570,7 +588,7 @@ def _sample(
                 candidates=llama_cpp.ctypes.byref(candidates),  # type: ignore
                 temp=temp,
             )
-            return llama_cpp.llama_sample_token_mirostat_v2(
+            id = llama_cpp.llama_sample_token_mirostat_v2(
                 ctx=self.ctx,
                 candidates=llama_cpp.ctypes.byref(candidates),  # type: ignore
                 tau=mirostat_tau,
@@ -607,10 +625,17 @@ def _sample(
                 candidates=llama_cpp.ctypes.byref(candidates),  # type: ignore
                 temp=temp,
             )
-            return llama_cpp.llama_sample_token(
+            id = llama_cpp.llama_sample_token(
                 ctx=self.ctx,
                 candidates=llama_cpp.ctypes.byref(candidates),  # type: ignore
             )
+        if self.grammar is not None:
+            llama_cpp.llama_grammar_accept_token(
+                ctx=self.ctx,
+                grammar=self.grammar.grammar,
+                token=llama_cpp.ctypes.c_int(id),
+            )
+        return id
 
     def sample(
         self,
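Note the pattern in the hunks above: each sampler branch now assigns to `id` instead of returning immediately, so the chosen token can be reported back to the grammar before `_sample` returns. A minimal sketch of one grammar-constrained sampling step using the same three bindings as the diff (assumes `ctx`, a prepared `candidates` array, and a loaded `grammar` object as above):

# 1. Mask out candidate tokens that cannot continue a valid parse.
llama_cpp.llama_sample_grammar(
    ctx=ctx,
    candidates=llama_cpp.ctypes.byref(candidates),
    grammar=grammar.grammar,
)
# 2. Pick a token with any of the usual samplers (greedy shown here).
id = llama_cpp.llama_sample_token_greedy(
    ctx=ctx,
    candidates=llama_cpp.ctypes.byref(candidates),
)
# 3. Advance the grammar's parse state past the accepted token.
llama_cpp.llama_grammar_accept_token(
    ctx=ctx,
    grammar=grammar.grammar,
    token=llama_cpp.ctypes.c_int(id),
)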
@@ -1509,6 +1534,9 @@ def __del__(self):
         if self.ctx is not None:
             llama_cpp.llama_free(self.ctx)
             self.ctx = None
+        if self.grammar is not None:
+            llama_cpp.llama_grammar_free(self.grammar.grammar)
+            self.grammar = None
 
     def __getstate__(self):
         return dict(
