Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit 7b46bb5

Browse files
committed
Re-order classes in llama.py
1 parent cc4630e commit 7b46bb5
Copy full SHA for 7b46bb5

File tree

Expand file tree / Collapse file tree

1 file changed

+42
-40
lines changed
Filter options
Expand file treeCollapse file tree

1 file changed

+42
-40
lines changed

‎llama_cpp/llama.py

Copy file name to clipboard — Expand all lines: llama_cpp/llama.py
+42-40Lines changed: 42 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from __future__ import annotations
2+
13
import os
24
import sys
35
import uuid
@@ -40,46 +42,6 @@
4042
)
4143

4244

43-
class LlamaState:
    """Plain container for a saved model state.

    Holds the Python-side token/score history together with the raw
    serialized state buffer produced by llama.cpp, so the state can be
    stored and later restored.
    """

    def __init__(
        self,
        input_ids: npt.NDArray[np.intc],
        scores: npt.NDArray[np.single],
        n_tokens: int,
        llama_state: bytes,
        llama_state_size: int,
    ):
        # Opaque serialized llama.cpp state blob; llama_state_size is
        # presumably its length in bytes — TODO confirm against callers.
        self.llama_state = llama_state
        self.llama_state_size = llama_state_size
        # Token/score history needed on the Python side.
        self.n_tokens = n_tokens
        self.input_ids = input_ids
        self.scores = scores
57-
58-
59-
# A logits processor maps (input_ids, scores) -> new scores.
LogitsProcessor = Callable[[npt.NDArray[np.intc], npt.NDArray[np.single]], npt.NDArray[np.single]]


class LogitsProcessorList(List[LogitsProcessor]):
    """List of logits processors applied as a pipeline.

    Calling the list feeds ``scores`` through each contained processor
    in order, passing every processor's output to the next one.
    """

    def __call__(
        self, input_ids: npt.NDArray[np.intc], scores: npt.NDArray[np.single]
    ) -> npt.NDArray[np.single]:
        current = scores
        for process in self:
            current = process(input_ids, current)
        return current
71-
72-
73-
# A stopping criterion maps (input_ids, logits) -> bool ("stop now?").
StoppingCriteria = Callable[[npt.NDArray[np.intc], npt.NDArray[np.single]], bool]


class StoppingCriteriaList(List[StoppingCriteria]):
    """List of stopping criteria checked together.

    Calling the list returns True if any contained criterion returns
    True for the given ``input_ids`` / ``logits``.
    """

    def __call__(
        self, input_ids: npt.NDArray[np.intc], logits: npt.NDArray[np.single]
    ) -> bool:
        # Generator expression (not a list) so evaluation short-circuits
        # on the first criterion that fires, instead of always calling
        # every criterion and materializing an intermediate list.
        return any(
            stopping_criteria(input_ids, logits) for stopping_criteria in self
        )
81-
82-
8345
class Llama:
8446
"""High-level Python wrapper for a llama.cpp model."""
8547

@@ -1733,3 +1695,43 @@ def decode(self, tokens: List[int]) -> str:
17331695
@classmethod
def from_ggml_file(cls, path: str) -> "LlamaTokenizer":
    # Alternate constructor: builds a tokenizer from a ggml model file.
    # vocab_only=True loads just the vocabulary, so — presumably — the
    # full model weights are never loaded; confirm against Llama.__init__.
    return cls(Llama(model_path=path, vocab_only=True))
1698+
1699+
1700+
class LlamaState:
    """Plain container for a saved model state.

    Holds the Python-side token/score history together with the raw
    serialized state buffer produced by llama.cpp, so the state can be
    stored and later restored.
    """

    def __init__(
        self,
        input_ids: npt.NDArray[np.intc],
        scores: npt.NDArray[np.single],
        n_tokens: int,
        llama_state: bytes,
        llama_state_size: int,
    ):
        # Opaque serialized llama.cpp state blob; llama_state_size is
        # presumably its length in bytes — TODO confirm against callers.
        self.llama_state = llama_state
        self.llama_state_size = llama_state_size
        # Token/score history needed on the Python side.
        self.n_tokens = n_tokens
        self.input_ids = input_ids
        self.scores = scores
1714+
1715+
1716+
# A logits processor maps (input_ids, scores) -> new scores.
LogitsProcessor = Callable[[npt.NDArray[np.intc], npt.NDArray[np.single]], npt.NDArray[np.single]]


class LogitsProcessorList(List[LogitsProcessor]):
    """List of logits processors applied as a pipeline.

    Calling the list feeds ``scores`` through each contained processor
    in order, passing every processor's output to the next one.
    """

    def __call__(
        self, input_ids: npt.NDArray[np.intc], scores: npt.NDArray[np.single]
    ) -> npt.NDArray[np.single]:
        current = scores
        for process in self:
            current = process(input_ids, current)
        return current
1728+
1729+
1730+
# A stopping criterion maps (input_ids, logits) -> bool ("stop now?").
StoppingCriteria = Callable[[npt.NDArray[np.intc], npt.NDArray[np.single]], bool]


class StoppingCriteriaList(List[StoppingCriteria]):
    """List of stopping criteria checked together.

    Calling the list returns True if any contained criterion returns
    True for the given ``input_ids`` / ``logits``.
    """

    def __call__(
        self, input_ids: npt.NDArray[np.intc], logits: npt.NDArray[np.single]
    ) -> bool:
        # Generator expression (not a list) so evaluation short-circuits
        # on the first criterion that fires, instead of always calling
        # every criterion and materializing an intermediate list.
        return any(
            stopping_criteria(input_ids, logits) for stopping_criteria in self
        )

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.