1
+ from __future__ import annotations
2
+
1
3
import os
2
4
import sys
3
5
import uuid
40
42
)
41
43
42
44
43
- class LlamaState :
44
- def __init__ (
45
- self ,
46
- input_ids : npt .NDArray [np .intc ],
47
- scores : npt .NDArray [np .single ],
48
- n_tokens : int ,
49
- llama_state : bytes ,
50
- llama_state_size : int ,
51
- ):
52
- self .input_ids = input_ids
53
- self .scores = scores
54
- self .n_tokens = n_tokens
55
- self .llama_state = llama_state
56
- self .llama_state_size = llama_state_size
57
-
58
-
59
- LogitsProcessor = Callable [
60
- [npt .NDArray [np .intc ], npt .NDArray [np .single ]], npt .NDArray [np .single ]
61
- ]
62
-
63
-
64
- class LogitsProcessorList (List [LogitsProcessor ]):
65
- def __call__ (
66
- self , input_ids : npt .NDArray [np .intc ], scores : npt .NDArray [np .single ]
67
- ) -> npt .NDArray [np .single ]:
68
- for processor in self :
69
- scores = processor (input_ids , scores )
70
- return scores
71
-
72
-
73
- StoppingCriteria = Callable [[npt .NDArray [np .intc ], npt .NDArray [np .single ]], bool ]
74
-
75
-
76
- class StoppingCriteriaList (List [StoppingCriteria ]):
77
- def __call__ (
78
- self , input_ids : npt .NDArray [np .intc ], logits : npt .NDArray [np .single ]
79
- ) -> bool :
80
- return any ([stopping_criteria (input_ids , logits ) for stopping_criteria in self ])
81
-
82
-
83
45
class Llama :
84
46
"""High-level Python wrapper for a llama.cpp model."""
85
47
@@ -1733,3 +1695,43 @@ def decode(self, tokens: List[int]) -> str:
1733
1695
@classmethod
1734
1696
def from_ggml_file (cls , path : str ) -> "LlamaTokenizer" :
1735
1697
return cls (Llama (model_path = path , vocab_only = True ))
1698
+
1699
+
1700
+ class LlamaState :
1701
+ def __init__ (
1702
+ self ,
1703
+ input_ids : npt .NDArray [np .intc ],
1704
+ scores : npt .NDArray [np .single ],
1705
+ n_tokens : int ,
1706
+ llama_state : bytes ,
1707
+ llama_state_size : int ,
1708
+ ):
1709
+ self .input_ids = input_ids
1710
+ self .scores = scores
1711
+ self .n_tokens = n_tokens
1712
+ self .llama_state = llama_state
1713
+ self .llama_state_size = llama_state_size
1714
+
1715
+
1716
+ LogitsProcessor = Callable [
1717
+ [npt .NDArray [np .intc ], npt .NDArray [np .single ]], npt .NDArray [np .single ]
1718
+ ]
1719
+
1720
+
1721
+ class LogitsProcessorList (List [LogitsProcessor ]):
1722
+ def __call__ (
1723
+ self , input_ids : npt .NDArray [np .intc ], scores : npt .NDArray [np .single ]
1724
+ ) -> npt .NDArray [np .single ]:
1725
+ for processor in self :
1726
+ scores = processor (input_ids , scores )
1727
+ return scores
1728
+
1729
+
1730
+ StoppingCriteria = Callable [[npt .NDArray [np .intc ], npt .NDArray [np .single ]], bool ]
1731
+
1732
+
1733
+ class StoppingCriteriaList (List [StoppingCriteria ]):
1734
+ def __call__ (
1735
+ self , input_ids : npt .NDArray [np .intc ], logits : npt .NDArray [np .single ]
1736
+ ) -> bool :
1737
+ return any ([stopping_criteria (input_ids , logits ) for stopping_criteria in self ])
0 commit comments