import json
import multiprocessing
+ from re import compile, Match, Pattern
from threading import Lock
from functools import partial
- from typing import Iterator, List, Optional, Union, Dict
+ from typing import Callable, Coroutine, Iterator, List, Optional, Union, Dict
from typing_extensions import TypedDict, Literal

import llama_cpp

import anyio
from anyio.streams.memory import MemoryObjectSendStream
from starlette.concurrency import run_in_threadpool, iterate_in_threadpool
- from fastapi import Depends, FastAPI, APIRouter, Request
+ from fastapi import Depends, FastAPI, APIRouter, Request, Response
from fastapi.middleware.cors import CORSMiddleware
+ from fastapi.responses import JSONResponse
+ from fastapi.routing import APIRoute
from pydantic import BaseModel, Field
from pydantic_settings import BaseSettings
from sse_starlette.sse import EventSourceResponse
@@ -92,7 +95,190 @@ class Settings(BaseSettings):
    )


- router = APIRouter()
+ class ErrorResponse(TypedDict):
+     """OpenAI style error response"""
+
+     message: str
+     type: str
+     param: Optional[str]
+     code: Optional[str]
+
+
+ class ErrorResponseFormatters:
+     """Collection of formatters for error responses.
+
+     Args:
+         request (Union[CreateCompletionRequest, CreateChatCompletionRequest]):
+             Request body
+         match (Match[str]): Match object from regex pattern
+
+     Returns:
+         tuple[int, ErrorResponse]: Status code and error response
+     """
+
+     @staticmethod
+     def context_length_exceeded(
+         request: Union[
+             "CreateCompletionRequest", "CreateChatCompletionRequest"
+         ],
+         match: Match[str],
+     ) -> tuple[int, ErrorResponse]:
+         """Formatter for context length exceeded error"""
+
+         context_window = int(match.group(2))
+         prompt_tokens = int(match.group(1))
+         completion_tokens = request.max_tokens
+         if hasattr(request, "messages"):
+             # Chat completion
+             message = (
+                 "This model's maximum context length is {} tokens. "
+                 "However, you requested {} tokens "
+                 "({} in the messages, {} in the completion). "
+                 "Please reduce the length of the messages or completion."
+             )
+         else:
+             # Text completion
+             message = (
+                 "This model's maximum context length is {} tokens, "
+                 "however you requested {} tokens "
+                 "({} in your prompt; {} for the completion). "
+                 "Please reduce your prompt; or completion length."
+             )
+         return 400, ErrorResponse(
+             message=message.format(
+                 context_window,
+                 completion_tokens + prompt_tokens,
+                 prompt_tokens,
+                 completion_tokens,
+             ),
+             type="invalid_request_error",
+             param="messages",
+             code="context_length_exceeded",
+         )
+
+     @staticmethod
+     def model_not_found(
+         request: Union[
+             "CreateCompletionRequest", "CreateChatCompletionRequest"
+         ],
+         match: Match[str],
+     ) -> tuple[int, ErrorResponse]:
+         """Formatter for model_not_found error"""
+
+         model_path = str(match.group(1))
+         message = f"The model `{model_path}` does not exist"
+         return 400, ErrorResponse(
+             message=message,
+             type="invalid_request_error",
+             param=None,
+             code="model_not_found",
+         )
+
+
+ class RouteErrorHandler(APIRoute):
+     """Custom APIRoute that handles application errors and exceptions"""
+
+     # key: regex pattern for original error message from llama_cpp
+     # value: formatter function
+     pattern_and_formatters: dict[
+         "Pattern",
+         Callable[
+             [
+                 Union["CreateCompletionRequest", "CreateChatCompletionRequest"],
+                 Match[str],
+             ],
+             tuple[int, ErrorResponse],
+         ],
+     ] = {
+         compile(
+             r"Requested tokens \((\d+)\) exceed context window of (\d+)"
+         ): ErrorResponseFormatters.context_length_exceeded,
+         compile(
+             r"Model path does not exist: (.+)"
+         ): ErrorResponseFormatters.model_not_found,
+     }
+
+     def error_message_wrapper(
+         self,
+         error: Exception,
+         body: Optional[
+             Union[
+                 "CreateChatCompletionRequest",
+                 "CreateCompletionRequest",
+                 "CreateEmbeddingRequest",
+             ]
+         ] = None,
+     ) -> tuple[int, ErrorResponse]:
+         """Wraps error message in OpenAI style error response"""
+
+         if body is not None and isinstance(
+             body,
+             (
+                 CreateCompletionRequest,
+                 CreateChatCompletionRequest,
+             ),
+         ):
+             # When text completion or chat completion
+             for pattern, callback in self.pattern_and_formatters.items():
+                 match = pattern.search(str(error))
+                 if match is not None:
+                     return callback(body, match)
+
+         # Wrap other errors as internal server error
+         return 500, ErrorResponse(
+             message=str(error),
+             type="internal_server_error",
+             param=None,
+             code=None,
+         )
+
+     def get_route_handler(
+         self,
+     ) -> Callable[[Request], Coroutine[None, None, Response]]:
+         """Defines custom route handler that catches exceptions and formats
+         in OpenAI style error response"""
+
+         original_route_handler = super().get_route_handler()
+
+         async def custom_route_handler(request: Request) -> Response:
+             try:
+                 return await original_route_handler(request)
+             except Exception as exc:
+                 json_body = await request.json()
+                 try:
+                     if "messages" in json_body:
+                         # Chat completion
+                         body: Optional[
+                             Union[
+                                 CreateChatCompletionRequest,
+                                 CreateCompletionRequest,
+                                 CreateEmbeddingRequest,
+                             ]
+                         ] = CreateChatCompletionRequest(**json_body)
+                     elif "prompt" in json_body:
+                         # Text completion
+                         body = CreateCompletionRequest(**json_body)
+                     else:
+                         # Embedding
+                         body = CreateEmbeddingRequest(**json_body)
+                 except Exception:
+                     # Invalid request body
+                     body = None
+
+                 # Get proper error message from the exception
+                 (
+                     status_code,
+                     error_message,
+                 ) = self.error_message_wrapper(error=exc, body=body)
+                 return JSONResponse(
+                     {"error": error_message},
+                     status_code=status_code,
+                 )
+
+         return custom_route_handler
+
+
+ router = APIRouter(route_class=RouteErrorHandler)

settings: Optional[Settings] = None
llama: Optional[llama_cpp.Llama] = None
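Aside (not part of the commit): `RouteErrorHandler` relies on FastAPI's `route_class` hook, in which a custom `APIRoute` subclass wraps every endpoint handler the router generates. A minimal standalone sketch of that pattern, with hypothetical names and a toy endpoint, looks roughly like this:

```python
# Minimal sketch of the route_class pattern used by RouteErrorHandler.
# Hypothetical example for illustration only, not code from this commit:
# any exception raised inside an endpoint becomes a JSON error body.
from typing import Callable, Coroutine

from fastapi import APIRouter, FastAPI, Request, Response
from fastapi.responses import JSONResponse
from fastapi.routing import APIRoute


class CatchAllRoute(APIRoute):
    def get_route_handler(self) -> Callable[[Request], Coroutine[None, None, Response]]:
        original_route_handler = super().get_route_handler()

        async def custom_route_handler(request: Request) -> Response:
            try:
                return await original_route_handler(request)
            except Exception as exc:
                # Same envelope shape the commit uses: {"error": {...}}
                return JSONResponse(
                    {
                        "error": {
                            "message": str(exc),
                            "type": "internal_server_error",
                            "param": None,
                            "code": None,
                        }
                    },
                    status_code=500,
                )

        return custom_route_handler


router = APIRouter(route_class=CatchAllRoute)


@router.get("/boom")
async def boom() -> dict:
    raise RuntimeError("something went wrong")


app = FastAPI()
app.include_router(router)
```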
@@ -179,10 +365,33 @@ def get_settings():
    yield settings


+ async def get_event_publisher(
+     request: Request,
+     inner_send_chan: MemoryObjectSendStream,
+     iterator: Iterator,
+ ):
+     async with inner_send_chan:
+         try:
+             async for chunk in iterate_in_threadpool(iterator):
+                 await inner_send_chan.send(dict(data=json.dumps(chunk)))
+                 if await request.is_disconnected():
+                     raise anyio.get_cancelled_exc_class()()
+                 if settings.interrupt_requests and llama_outer_lock.locked():
+                     await inner_send_chan.send(dict(data="[DONE]"))
+                     raise anyio.get_cancelled_exc_class()()
+             await inner_send_chan.send(dict(data="[DONE]"))
+         except anyio.get_cancelled_exc_class() as e:
+             print("disconnected")
+             with anyio.move_on_after(1, shield=True):
+                 print(
+                     f"Disconnected from client (via refresh/close) {request.client}"
+                 )
+             raise e
+
model_field = Field(description="The model to use for generating completions.", default=None)

max_tokens_field = Field(
-     default=16, ge=1, le=2048, description="The maximum number of tokens to generate."
+     default=16, ge=1, description="The maximum number of tokens to generate."
)

temperature_field = Field(
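Aside (not from the commit): `get_event_publisher` writes server-sent events whose `data:` payloads are JSON chunks, closed off by a literal `[DONE]` sentinel, mirroring OpenAI's streaming format. A rough client-side sketch, assuming the server is reachable at http://localhost:8000 and exposes the project's /v1/completions route, could consume the stream like this:

```python
# Hypothetical client for the SSE stream emitted by get_event_publisher.
# Assumes a locally running server on port 8000 and the /v1/completions route;
# adjust URL, payload, and chunk fields for your setup.
import json

import requests

payload = {"prompt": "Hello", "max_tokens": 32, "stream": True}

with requests.post(
    "http://localhost:8000/v1/completions", json=payload, stream=True
) as resp:
    for raw_line in resp.iter_lines():
        if not raw_line:
            continue  # SSE events are separated by blank lines
        line = raw_line.decode("utf-8")
        if not line.startswith("data: "):
            continue
        data = line[len("data: "):]
        if data == "[DONE]":  # sentinel sent by the publisher
            break
        chunk = json.loads(data)  # an OpenAI-style completion chunk
        print(chunk["choices"][0]["text"], end="", flush=True)
print()
```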
@@ -370,35 +579,31 @@ async def create_completion(
            make_logit_bias_processor(llama, body.logit_bias, body.logit_bias_type),
        ])

-     if body.stream:
-         send_chan, recv_chan = anyio.create_memory_object_stream(10)
+     iterator_or_completion: Union[llama_cpp.Completion, Iterator[
+         llama_cpp.CompletionChunk
+     ]] = await run_in_threadpool(llama, **kwargs)

-         async def event_publisher(inner_send_chan: MemoryObjectSendStream):
-             async with inner_send_chan:
-                 try:
-                     iterator: Iterator[llama_cpp.CompletionChunk] = await run_in_threadpool(llama, **kwargs)  # type: ignore
-                     async for chunk in iterate_in_threadpool(iterator):
-                         await inner_send_chan.send(dict(data=json.dumps(chunk)))
-                         if await request.is_disconnected():
-                             raise anyio.get_cancelled_exc_class()()
-                         if settings.interrupt_requests and llama_outer_lock.locked():
-                             await inner_send_chan.send(dict(data="[DONE]"))
-                             raise anyio.get_cancelled_exc_class()()
-                     await inner_send_chan.send(dict(data="[DONE]"))
-                 except anyio.get_cancelled_exc_class() as e:
-                     print("disconnected")
-                     with anyio.move_on_after(1, shield=True):
-                         print(
-                             f"Disconnected from client (via refresh/close) {request.client}"
-                         )
-                     raise e
+     if isinstance(iterator_or_completion, Iterator):
+         # EAFP: It's easier to ask for forgiveness than permission
+         first_response = await run_in_threadpool(next, iterator_or_completion)

+         # If no exception was raised from first_response, we can assume that
+         # the iterator is valid and we can use it to stream the response.
+         def iterator() -> Iterator[llama_cpp.CompletionChunk]:
+             yield first_response
+             yield from iterator_or_completion
+
+         send_chan, recv_chan = anyio.create_memory_object_stream(10)
        return EventSourceResponse(
-             recv_chan, data_sender_callable=partial(event_publisher, send_chan)
-         )  # type: ignore
+             recv_chan, data_sender_callable=partial(  # type: ignore
+                 get_event_publisher,
+                 request=request,
+                 inner_send_chan=send_chan,
+                 iterator=iterator(),
+             )
+         )
    else:
-         completion: llama_cpp.Completion = await run_in_threadpool(llama, **kwargs)  # type: ignore
-         return completion
+         return iterator_or_completion


class CreateEmbeddingRequest(BaseModel):
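Worth spelling out (commentary, not part of the diff): pulling the first chunk with `next()` before building the `EventSourceResponse` forces any llama_cpp error to surface while the request is still inside the normal handler, so `RouteErrorHandler` can translate it into a proper HTTP error instead of the failure happening after streaming has begun. A generic sketch of that peek-then-replay pattern:

```python
# Generic, hypothetical sketch of the peek-then-replay pattern used above:
# consume the first item eagerly so errors surface before streaming starts,
# then rebuild a generator that yields it ahead of the remaining items.
from typing import Iterator, TypeVar

T = TypeVar("T")


def peek_then_stream(source: Iterator[T]) -> Iterator[T]:
    first = next(source)  # any error raises here, before a response is committed

    def replay() -> Iterator[T]:
        yield first
        yield from source

    return replay()


if __name__ == "__main__":
    def chunks() -> Iterator[str]:
        yield from ("Hel", "lo", "!")

    for piece in peek_then_stream(chunks()):
        print(piece, end="")
    print()
```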
@@ -501,38 +706,31 @@ async def create_chat_completion(
            make_logit_bias_processor(llama, body.logit_bias, body.logit_bias_type),
        ])

-     if body.stream:
-         send_chan, recv_chan = anyio.create_memory_object_stream(10)
+     iterator_or_completion: Union[llama_cpp.ChatCompletion, Iterator[
+         llama_cpp.ChatCompletionChunk
+     ]] = await run_in_threadpool(llama.create_chat_completion, **kwargs)

-         async def event_publisher(inner_send_chan: MemoryObjectSendStream):
-             async with inner_send_chan:
-                 try:
-                     iterator: Iterator[llama_cpp.ChatCompletionChunk] = await run_in_threadpool(llama.create_chat_completion, **kwargs)  # type: ignore
-                     async for chat_chunk in iterate_in_threadpool(iterator):
-                         await inner_send_chan.send(dict(data=json.dumps(chat_chunk)))
-                         if await request.is_disconnected():
-                             raise anyio.get_cancelled_exc_class()()
-                         if settings.interrupt_requests and llama_outer_lock.locked():
-                             await inner_send_chan.send(dict(data="[DONE]"))
-                             raise anyio.get_cancelled_exc_class()()
-                     await inner_send_chan.send(dict(data="[DONE]"))
-                 except anyio.get_cancelled_exc_class() as e:
-                     print("disconnected")
-                     with anyio.move_on_after(1, shield=True):
-                         print(
-                             f"Disconnected from client (via refresh/close) {request.client}"
-                         )
-                     raise e
+     if isinstance(iterator_or_completion, Iterator):
+         # EAFP: It's easier to ask for forgiveness than permission
+         first_response = await run_in_threadpool(next, iterator_or_completion)
+
+         # If no exception was raised from first_response, we can assume that
+         # the iterator is valid and we can use it to stream the response.
+         def iterator() -> Iterator[llama_cpp.ChatCompletionChunk]:
+             yield first_response
+             yield from iterator_or_completion

+         send_chan, recv_chan = anyio.create_memory_object_stream(10)
        return EventSourceResponse(
-             recv_chan,
-             data_sender_callable=partial(event_publisher, send_chan),
-         )  # type: ignore
-     else:
-         completion: llama_cpp.ChatCompletion = await run_in_threadpool(
-             llama.create_chat_completion, **kwargs  # type: ignore
+             recv_chan, data_sender_callable=partial(  # type: ignore
+                 get_event_publisher,
+                 request=request,
+                 inner_send_chan=send_chan,
+                 iterator=iterator(),
+             )
        )
-         return completion
+     else:
+         return iterator_or_completion


class ModelData(TypedDict):
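For illustration only (values are made up, not fixtures from the repo): with these changes, a text-completion request that overflows the context window no longer surfaces as a bare 500; the client instead receives HTTP 400 with an OpenAI-style body built by `ErrorResponseFormatters.context_length_exceeded`, roughly:

```python
# Example of the error envelope a client now receives when llama_cpp reports
# "Requested tokens (600) exceed context window of 512" for a plain completion
# with max_tokens=16. Numbers are illustrative only.
import json

error_body = {
    "error": {
        "message": (
            "This model's maximum context length is 512 tokens, "
            "however you requested 616 tokens "
            "(600 in your prompt; 16 for the completion). "
            "Please reduce your prompt; or completion length."
        ),
        "type": "invalid_request_error",
        "param": "messages",
        "code": "context_length_exceeded",
    }
}
print(json.dumps(error_body, indent=2))  # served with HTTP status 400
```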