Commit 2c0d9b1

Fix session loading and saving in low level example chat
1 parent ed66a46 commit 2c0d9b1

1 file changed: +14 -14 lines changed

examples/low_level_api/low_level_api_chat_cpp.py
@@ -112,16 +112,17 @@ def __init__(self, params: GptParams) -> None:
 
         if (path.exists(self.params.path_session)):
             _session_tokens = (llama_cpp.llama_token * (self.params.n_ctx))()
-            _n_token_count_out = llama_cpp.c_int()
+            _n_token_count_out = llama_cpp.c_size_t()
             if (llama_cpp.llama_load_session_file(
                 self.ctx,
                 self.params.path_session.encode("utf8"),
                 _session_tokens,
                 self.params.n_ctx,
                 ctypes.byref(_n_token_count_out)
-            ) != 0):
+            ) != 1):
                 print(f"error: failed to load session file '{self.params.path_session}'", file=sys.stderr)
                 return
+            _n_token_count_out = _n_token_count_out.value
             self.session_tokens = _session_tokens[:_n_token_count_out]
             print(f"loaded a session with prompt size of {_n_token_count_out} tokens", file=sys.stderr)
         else:
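
Why the load-side changes matter: in the llama.cpp C API, llama_load_session_file returns a bool, i.e. 1 on success, and writes the token count through a size_t pointer. Checking "!= 0" therefore treated a successful load as a failure, c_int did not match the size_t out-parameter, and the count must be read back via .value before it can be used as a slice bound. A minimal standalone sketch of the corrected pattern (the load_session helper is hypothetical, not part of the example file):

import ctypes

import llama_cpp

def load_session(ctx, path: str, n_ctx: int) -> list:
    # Buffer sized to the context window; the API fills it with cached tokens.
    tokens = (llama_cpp.llama_token * n_ctx)()
    n_out = ctypes.c_size_t()  # matches the C size_t* out-parameter; c_int does not
    if llama_cpp.llama_load_session_file(
        ctx,
        path.encode("utf8"),
        tokens,
        n_ctx,
        ctypes.byref(n_out),
    ) != 1:  # the C function returns true (1) on success
        raise RuntimeError(f"failed to load session file '{path}'")
    return tokens[:n_out.value]  # slicing a ctypes array yields a plain Python list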
@@ -135,19 +136,21 @@ def __init__(self, params: GptParams) -> None:
             raise RuntimeError(f"error: prompt is too long ({len(self.embd_inp)} tokens, max {self.params.n_ctx - 4})")
 
         # debug message about similarity of saved session, if applicable
-        n_matching_session_tokens = 0
+        self.n_matching_session_tokens = 0
         if len(self.session_tokens) > 0:
             for id in self.session_tokens:
-                if n_matching_session_tokens >= len(self.embd_inp) or id != self.embd_inp[n_matching_session_tokens]:
+                if self.n_matching_session_tokens >= len(self.embd_inp) or id != self.embd_inp[self.n_matching_session_tokens]:
                     break
-                n_matching_session_tokens += 1
+                self.n_matching_session_tokens += 1
 
-            if n_matching_session_tokens >= len(self.embd_inp):
+            if self.n_matching_session_tokens >= len(self.embd_inp):
                 print(f"session file has exact match for prompt!")
-            elif n_matching_session_tokens < (len(self.embd_inp) / 2):
-                print(f"warning: session file has low similarity to prompt ({n_matching_session_tokens} / {len(self.embd_inp)} tokens); will mostly be reevaluated")
+            elif self.n_matching_session_tokens < (len(self.embd_inp) / 2):
+                print(f"warning: session file has low similarity to prompt ({self.n_matching_session_tokens} / {len(self.embd_inp)} tokens); will mostly be reevaluated")
             else:
-                print(f"session file matches {n_matching_session_tokens} / {len(self.embd_inp)} tokens of prompt")
+                print(f"session file matches {self.n_matching_session_tokens} / {len(self.embd_inp)} tokens of prompt")
+
+        self.need_to_save_session = len(self.params.path_session) > 0 and self.n_matching_session_tokens < (len(self.embd_inp) * 3 / 4)
 
         # number of tokens to keep when resetting context
         if (self.params.n_keep < 0 or self.params.n_keep > len(self.embd_inp) or self.params.instruct):
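
The matching-token counter is promoted from a local to self.n_matching_session_tokens, and need_to_save_session is now assigned right after the similarity check instead of near the end of __init__ (the later copy is removed in the next hunk), so the flag is always set before generate() reads it. The heuristic itself is a simple prefix match; a sketch in isolation, with a hypothetical count_matching_prefix name:

def count_matching_prefix(session_tokens: list, prompt_tokens: list) -> int:
    # Count how many leading tokens of the cached session agree with the prompt.
    n = 0
    for tok in session_tokens:
        if n >= len(prompt_tokens) or tok != prompt_tokens[n]:
            break
        n += 1
    return n

# count_matching_prefix([1, 2, 3, 9], [1, 2, 3, 4]) == 3: exactly 3/4 of the
# prompt matches, so under the strict "< 3/4" rule above the session would
# not need to be re-saved.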
@@ -232,9 +235,6 @@ def __init__(self, params: GptParams) -> None:
 """, file=sys.stderr)
         self.set_color(util.CONSOLE_COLOR_PROMPT)
 
-        self.need_to_save_session = len(self.params.path_session) > 0 and n_matching_session_tokens < (len(self.embd_inp) * 3 / 4)
-
-
     # tokenize a prompt
     def _tokenize(self, prompt, bos=True):
         _arr = (llama_cpp.llama_token * ((len(prompt) + 1) * 4))()
@@ -302,7 +302,7 @@ def generate(self):
             ) != 0):
                 raise Exception("Failed to llama_eval!")
 
-            if len(self.embd) > 0 and not len(self.params.path_session) > 0:
+            if len(self.embd) > 0 and len(self.params.path_session) > 0:
                 self.session_tokens.extend(self.embd)
                 self.n_session_consumed = len(self.session_tokens)
 
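
The stray "not" had inverted this guard: freshly evaluated tokens are meant to be appended to self.session_tokens only when a session path is configured, whereas the old condition accumulated them only when no path was set, so an active session never grew past its loaded prefix.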
@@ -319,7 +319,7 @@ def generate(self):
                 llama_cpp.llama_save_session_file(
                     self.ctx,
                     self.params.path_session.encode("utf8"),
-                    self.session_tokens,
+                    (llama_cpp.llama_token * len(self.session_tokens))(*self.session_tokens),
                     len(self.session_tokens)
                 )
 
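
Finally, the save call previously passed a Python list where the binding expects a C array of llama_token; constructing a ctypes array from the list fixes the call. A standalone sketch of the same conversion (save_session is a hypothetical wrapper):

import llama_cpp

def save_session(ctx, path: str, session_tokens: list) -> None:
    # Build a ctypes array of llama_token from the Python list before handing
    # it to the C API; a bare list is not a valid llama_token* argument.
    arr = (llama_cpp.llama_token * len(session_tokens))(*session_tokens)
    llama_cpp.llama_save_session_file(
        ctx,
        path.encode("utf8"),
        arr,
        len(session_tokens),
    )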