@@ -112,16 +112,17 @@ def __init__(self, params: GptParams) -> None:
112
112
113
113
if (path .exists (self .params .path_session )):
114
114
_session_tokens = (llama_cpp .llama_token * (self .params .n_ctx ))()
115
- _n_token_count_out = llama_cpp .c_int ()
115
+ _n_token_count_out = llama_cpp .c_size_t ()
116
116
if (llama_cpp .llama_load_session_file (
117
117
self .ctx ,
118
118
self .params .path_session .encode ("utf8" ),
119
119
_session_tokens ,
120
120
self .params .n_ctx ,
121
121
ctypes .byref (_n_token_count_out )
122
- ) != 0 ):
122
+ ) != 1 ):
123
123
print (f"error: failed to load session file '{ self .params .path_session } '" , file = sys .stderr )
124
124
return
125
+ _n_token_count_out = _n_token_count_out .value
125
126
self .session_tokens = _session_tokens [:_n_token_count_out ]
126
127
print (f"loaded a session with prompt size of { _n_token_count_out } tokens" , file = sys .stderr )
127
128
else :
@@ -135,19 +136,21 @@ def __init__(self, params: GptParams) -> None:
135
136
raise RuntimeError (f"error: prompt is too long ({ len (self .embd_inp )} tokens, max { self .params .n_ctx - 4 } )" )
136
137
137
138
# debug message about similarity of saved session, if applicable
138
- n_matching_session_tokens = 0
139
+ self . n_matching_session_tokens = 0
139
140
if len (self .session_tokens ) > 0 :
140
141
for id in self .session_tokens :
141
- if n_matching_session_tokens >= len (self .embd_inp ) or id != self .embd_inp [n_matching_session_tokens ]:
142
+ if self . n_matching_session_tokens >= len (self .embd_inp ) or id != self .embd_inp [self . n_matching_session_tokens ]:
142
143
break
143
- n_matching_session_tokens += 1
144
+ self . n_matching_session_tokens += 1
144
145
145
- if n_matching_session_tokens >= len (self .embd_inp ):
146
+ if self . n_matching_session_tokens >= len (self .embd_inp ):
146
147
print (f"session file has exact match for prompt!" )
147
- elif n_matching_session_tokens < (len (self .embd_inp ) / 2 ):
148
- print (f"warning: session file has low similarity to prompt ({ n_matching_session_tokens } / { len (self .embd_inp )} tokens); will mostly be reevaluated" )
148
+ elif self . n_matching_session_tokens < (len (self .embd_inp ) / 2 ):
149
+ print (f"warning: session file has low similarity to prompt ({ self . n_matching_session_tokens } / { len (self .embd_inp )} tokens); will mostly be reevaluated" )
149
150
else :
150
- print (f"session file matches { n_matching_session_tokens } / { len (self .embd_inp )} tokens of prompt" )
151
+ print (f"session file matches { self .n_matching_session_tokens } / { len (self .embd_inp )} tokens of prompt" )
152
+
153
+ self .need_to_save_session = len (self .params .path_session ) > 0 and self .n_matching_session_tokens < (len (self .embd_inp ) * 3 / 4 )
151
154
152
155
# number of tokens to keep when resetting context
153
156
if (self .params .n_keep < 0 or self .params .n_keep > len (self .embd_inp ) or self .params .instruct ):
@@ -232,9 +235,6 @@ def __init__(self, params: GptParams) -> None:
232
235
""" , file = sys .stderr )
233
236
self .set_color (util .CONSOLE_COLOR_PROMPT )
234
237
235
- self .need_to_save_session = len (self .params .path_session ) > 0 and n_matching_session_tokens < (len (self .embd_inp ) * 3 / 4 )
236
-
237
-
238
238
# tokenize a prompt
239
239
def _tokenize (self , prompt , bos = True ):
240
240
_arr = (llama_cpp .llama_token * ((len (prompt ) + 1 ) * 4 ))()
@@ -302,7 +302,7 @@ def generate(self):
302
302
) != 0 ):
303
303
raise Exception ("Failed to llama_eval!" )
304
304
305
- if len (self .embd ) > 0 and not len (self .params .path_session ) > 0 :
305
+ if len (self .embd ) > 0 and len (self .params .path_session ) > 0 :
306
306
self .session_tokens .extend (self .embd )
307
307
self .n_session_consumed = len (self .session_tokens )
308
308
@@ -319,7 +319,7 @@ def generate(self):
319
319
llama_cpp .llama_save_session_file (
320
320
self .ctx ,
321
321
self .params .path_session .encode ("utf8" ),
322
- self .session_tokens ,
322
+ ( llama_cpp . llama_token * len ( self .session_tokens ))( * self . session_tokens ) ,
323
323
len (self .session_tokens )
324
324
)
325
325
0 commit comments