Commit f609ff4

Refactor interactive mode in main.cpp

1 parent cb58437 · commit f609ff4
File tree: 1 file changed, 96 additions and 69 deletions
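At a high level, the diff below drops the old "ingest the whole prompt, then print the embedding" path in favor of a loop that, on each iteration, either feeds any still-unconsumed input into the context or generates and prints exactly one token. The following minimal sketch mirrors only that control flow and is runnable on its own; fake_ctx, has_unconsumed_input, ingest_all_pending_input and infer are stand-ins invented for illustration, not the fork's llama_has_unconsumed_input / llama_ingest_all_pending_input / llama_infer API used in the commit.

// Hedged sketch of the new generation loop, with the llama_* calls stubbed out.
#include <cstdio>
#include <string>
#include <vector>

struct fake_ctx {
    std::vector<std::string> pending = {"Hello", ", ", "world"}; // queued prompt/user input
    int tokens_left = 5;                                         // fake generation budget
};

static bool has_unconsumed_input(const fake_ctx& c) { return !c.pending.empty(); }

static void ingest_all_pending_input(fake_ctx& c, bool echo) {
    // consume everything that is queued, optionally echoing it to the console
    for (const std::string& s : c.pending) {
        if (echo) printf("%s", s.c_str());
    }
    c.pending.clear();
}

static void infer(fake_ctx& c, std::string& out, bool& end_of_text) {
    out = " tok";                          // stand-in for one decoded token
    end_of_text = (--c.tokens_left == 0);  // pretend the model eventually emits EOS
}

int main() {
    fake_ctx ctx;
    bool input_noecho = false;
    bool is_end_of_text = false;
    while (!is_end_of_text) {
        std::string model_output;
        if (has_unconsumed_input(ctx)) {
            // feed queued input first, echoing it unless echo is suppressed
            ingest_all_pending_input(ctx, !input_noecho);
        } else {
            // no pending input: generate and print exactly one token
            infer(ctx, model_output, is_end_of_text);
            printf("%s", model_output.c_str());
            input_noecho = false;
        }
    }
    printf("\n");
    return 0;
}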

‎main.cpp

Lines changed: 96 additions & 69 deletions
@@ -55,6 +55,8 @@ void sigint_handler(int signo) {
 #endif
 
 
+void process_interactive_input(llama_context& ctx, const gpt_params& params);
+
 int main(int argc, char ** argv) {
     ggml_time_init();
     const int64_t t_main_start_us = ggml_time_us();
@@ -85,15 +87,18 @@ int main(int argc, char ** argv) {
     // params.prompt = R"(// this function checks if the number n is prime
     //bool is_prime(int n) {)";
 
-    int64_t t_load_us = 0;
-
     // load the model
-    llama_context* ctx_ptr = llama_init_from_params(params);
+    llama_context* ctx_ptr = nullptr;
+    {
+        ctx_ptr = llama_init_from_params(params);
+        if (!ctx_ptr) {
+            fprintf(stderr, "%s: failed to load model from '%s'\n", __func__, params.model.c_str());
+            return 1;
+        }
+    }
+
     llama_context & ctx = *ctx_ptr;
-    gpt_vocab & vocab = llama_context_get_vocab(ctx);
-
-    // print system information
-    llama_print_context_info(ctx);
+    const gpt_vocab & vocab = llama_context_get_vocab(ctx);
 
     // Add a space in front of the first character to match OG llama tokenizer behavior
     params.prompt.insert(0, 1, ' ');
@@ -109,8 +114,13 @@ int main(int argc, char ** argv) {
     }
 
     // tokenize the reverse prompt
-    std::vector<gpt_vocab::id> antiprompt_inp = llama_tokenize_text(ctx, params.prompt);
+    std::vector<gpt_vocab::id> antiprompt_inp = llama_tokenize_text(ctx, params.antiprompt);
+<<<<<<< HEAD
+=======
+
+>>>>>>> b30724a (Fix main)
 
+    // Setup interactive mode
     if (params.interactive) {
 #if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__))
         struct sigaction sigint_action;
@@ -146,50 +156,56 @@ int main(int argc, char ** argv) {
         is_interacting = true;
     }
 
-    bool input_noecho = false;
-
-    int remaining_tokens = params.n_predict;
+    // prompt user immediately after the starting prompt has been loaded
+    if (params.interactive_start) {
+        is_interacting = true;
+    }
 
     // set the color for the prompt which will be output initially
     if (params.use_color) {
        printf(ANSI_COLOR_YELLOW);
     }
 
-    if(!llama_ingest_input(ctx, params.prompt))
+    // Prepare the context with input
+    // Send "beginning of string"
+    llama_add_bos(ctx);
+
+    // load the input
+    llama_update_input(ctx, params.prompt);
+
+    llama_print_startup_stats(ctx);
+
+    if(!llama_prepare_context(ctx))
     {
-        fprintf(stderr, "Failed to ingest prompt\n");
+        fprintf(stderr, "%s: failed to prepare context\n", __func__);
         return 1;
-    };
-
-    // display text
-    input_noecho = false;
-    const std::vector<gpt_vocab::id>& embd = llama_context_get_embedding(ctx);
-    if (!input_noecho) {
-        for (auto id : embd) {
-            printf("%s", vocab.id_to_token[id].c_str());
-        }
-        fflush(stdout);
-    }
-
-    if (!input_noecho && params.use_color) {
-        printf(ANSI_COLOR_RESET);
     }
 
-    const std::vector<gpt_vocab::id>& last_n_tokens = llama_context_get_last_n_tokens(ctx);
-
-    while (llama_context_is_finished(ctx) != true) {
-        gpt_vocab::id model_output = 0;
-        bool response = llama_infer(ctx, model_output);
-        if (response) {
-            printf("%s", vocab.id_to_token[model_output].c_str());
-            fflush(stdout);
+    bool input_noecho = false;
+    bool is_end_of_text = false;
+    while (llama_context_is_finished(ctx) == false) {
+        std::string model_output{};
+
+        if (llama_has_unconsumed_input(ctx)) {
+            llama_ingest_all_pending_input(ctx, !input_noecho);
+            // reset color to default if we there is no pending user input
+            if (!input_noecho && params.use_color) {
+                printf(ANSI_COLOR_RESET);
+            }
+        }else{
+            // Run inference if we don't have any pending input
+            llama_infer(ctx, model_output, is_end_of_text);
+            // print the single token output
+            printf("%s", model_output.c_str());
+            input_noecho = false;
         }
 
         // in interactive mode, and not currently processing queued inputs;
         // check if we should prompt the user for more
-        if (params.interactive) {
+        if (params.interactive && !llama_has_unconsumed_input(ctx)) {
+            const std::vector<gpt_vocab::id>& last_n_tokens = llama_context_get_last_n_tokens(ctx);
             // check for reverse prompt
-            if (antiprompt_inp.size() && std::equal(antiprompt_inp.rbegin(), antiprompt_inp.rend(), last_n_tokens.rbegin())) {
+            if (antiprompt_inp.size() && llama_is_anti_prompt_present(ctx, antiprompt_inp)) {
                 // reverse prompt found
                 is_interacting = true;
             }
@@ -202,38 +218,14 @@ int main(int argc, char ** argv) {
                 }
 
                 // currently being interactive
-                bool another_line = true;
-                while (another_line) {
-                    fflush(stdout);
-                    char buf[256] = {0};
-                    int n_read;
-                    if (params.use_color) printf(ANSI_BOLD ANSI_COLOR_GREEN);
-                    if (scanf("%255[^\n]%n%*c", buf, &n_read) <= 0) {
-                        // presumable empty line, consume the newline
-                        std::ignore = scanf("%*c");
-                        n_read=0;
-                    }
-                    if (params.use_color) printf(ANSI_COLOR_RESET);
-
-                    if (n_read > 0 && buf[n_read-1]=='\\') {
-                        another_line = true;
-                        buf[n_read-1] = '\n';
-                        buf[n_read] = 0;
-                    } else {
-                        another_line = false;
-                        buf[n_read] = '\n';
-                        buf[n_read+1] = 0;
-                    }
-                    // Do not clear existing context in interactive mode
-                    llama_update_context_with_prompt(ctx, buf, false);
-                }
-
+                process_interactive_input(ctx, params);
+                input_noecho = true; // do not echo this input again
                 is_interacting = false;
             }
         }
 
         // end of text token
-        if (embd.back() == EOS_TOKEN_ID) {
+        if (is_end_of_text) {
            if (params.interactive) {
                is_interacting = true;
            } else {
@@ -243,23 +235,58 @@ int main(int argc, char ** argv) {
        }
 
        // In interactive mode, respect the maximum number of tokens and drop back to user input when reached.
-        if (params.interactive && remaining_tokens <= 0) {
-            remaining_tokens = params.n_predict;
+        if (params.interactive && llama_context_is_finished(ctx)) {
+            llama_context_reset_remaining_tokens(ctx)
            is_interacting = true;
        }
    }
 
-    // report timing from context
+
+#if defined (_WIN32)
+    signal(SIGINT, SIG_DFL);
+#endif
+
+    // report timing
    {
        const int64_t t_main_end_us = ggml_time_us();
        llama_print_end_stats(ctx);
        fprintf(stderr, "%s: total time = %8.2f ms\n", __func__, (t_main_end_us - t_main_start_us)/1000.0f);
    }
-    llama_free_context(ctx_ptr);
+
+    llama_free_context(ctx_ptr);
 
    if (params.use_color) {
        printf(ANSI_COLOR_RESET);
    }
-
    return 0;
 }
+
+void process_interactive_input(llama_context& ctx, const gpt_params& params)
+{
+    bool another_line = true;
+    while (another_line) {
+        fflush(stdout);
+        char buf[256] = {0};
+        int n_read;
+        if (params.use_color) printf(ANSI_BOLD ANSI_COLOR_GREEN);
+        if (scanf("%255[^\n]%n%*c", buf, &n_read) <= 0) {
+            // presumable empty line, consume the newline
+            std::ignore = scanf("%*c");
+            n_read=0;
+        }
+        if (params.use_color) printf(ANSI_COLOR_RESET);
+
+        if (n_read > 0 && buf[n_read-1]=='\\') {
+            another_line = true;
+            buf[n_read-1] = '\n';
+            buf[n_read] = 0;
+        } else {
+            another_line = false;
+            buf[n_read] = '\n';
+            buf[n_read+1] = 0;
+        }
+
+        // Do not clear existing context in interactive mode
+        llama_update_input(ctx, buf);
+    }
+}
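The extracted process_interactive_input keeps the original scanf-based line editing, where a trailing backslash continues the input on the next line. Below is a standalone sketch of just that reading loop, runnable without the rest of the project: the llama_update_input call is replaced with a plain string append, the ANSI color handling is dropped, and the buffer is enlarged by two bytes so the appended '\n' and terminator always fit. read_interactive_input is a name invented here for illustration, not part of the repository.

// Standalone sketch of the backslash-continuation reader used above.
#include <cstdio>
#include <string>
#include <tuple>

// Reads one user "message": lines ending in '\' continue onto the next line.
static std::string read_interactive_input() {
    std::string input;
    bool another_line = true;
    while (another_line) {
        fflush(stdout);
        char buf[258] = {0};   // 255 chars + optional '\n' + terminator
        int n_read = 0;
        if (scanf("%255[^\n]%n%*c", buf, &n_read) <= 0) {
            // presumably an empty line: consume the newline and keep n_read at 0
            std::ignore = scanf("%*c");
            n_read = 0;
        }
        if (n_read > 0 && buf[n_read - 1] == '\\') {
            // trailing backslash: replace it with a newline and read another line
            another_line = true;
            buf[n_read - 1] = '\n';
            buf[n_read] = 0;
        } else {
            // no continuation: terminate this message with a newline
            another_line = false;
            buf[n_read] = '\n';
            buf[n_read + 1] = 0;
        }
        input += buf;   // in main.cpp this is llama_update_input(ctx, buf) instead
    }
    return input;
}

int main() {
    printf("> ");
    const std::string line = read_interactive_input();
    printf("captured %zu bytes:\n%s", line.size(), line.c_str());
    return 0;
}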

0 commit comments