import json
import multiprocessing
+ from re import compile, Match, Pattern
from threading import Lock
from functools import partial
- from typing import Iterator, List, Optional, Union, Dict
+ from typing import Callable, Coroutine, Iterator, List, Optional, Union, Dict
from typing_extensions import TypedDict, Literal

import llama_cpp

import anyio
from anyio.streams.memory import MemoryObjectSendStream
from starlette.concurrency import run_in_threadpool, iterate_in_threadpool
- from fastapi import Depends, FastAPI, APIRouter, Request
+ from fastapi import Depends, FastAPI, APIRouter, Request, Response
from fastapi.middleware.cors import CORSMiddleware
+ from fastapi.responses import JSONResponse
+ from fastapi.routing import APIRoute
from pydantic import BaseModel, Field
from pydantic_settings import BaseSettings
from sse_starlette.sse import EventSourceResponse
@@ -94,7 +97,190 @@ class Settings(BaseSettings):
    )


- router = APIRouter()
+ class ErrorResponse(TypedDict):
+     """OpenAI style error response"""
+
+     message: str
+     type: str
+     param: Optional[str]
+     code: Optional[str]
+
+
+ class ErrorResponseFormatters:
+     """Collection of formatters for error responses.
+
+     Args:
+         request (Union[CreateCompletionRequest, CreateChatCompletionRequest]):
+             Request body
+         match (Match[str]): Match object from regex pattern
+
+     Returns:
+         tuple[int, ErrorResponse]: Status code and error response
+     """
+
+     @staticmethod
+     def context_length_exceeded(
+         request: Union[
+             "CreateCompletionRequest", "CreateChatCompletionRequest"
+         ],
+         match: Match[str],
+     ) -> tuple[int, ErrorResponse]:
+         """Formatter for context length exceeded error"""
+
+         context_window = int(match.group(2))
+         prompt_tokens = int(match.group(1))
+         completion_tokens = request.max_tokens
+         if hasattr(request, "messages"):
+             # Chat completion
+             message = (
+                 "This model's maximum context length is {} tokens. "
+                 "However, you requested {} tokens "
+                 "({} in the messages, {} in the completion). "
+                 "Please reduce the length of the messages or completion."
+             )
+         else:
+             # Text completion
+             message = (
+                 "This model's maximum context length is {} tokens, "
+                 "however you requested {} tokens "
+                 "({} in your prompt; {} for the completion). "
+                 "Please reduce your prompt; or completion length."
+             )
+         return 400, ErrorResponse(
+             message=message.format(
+                 context_window,
+                 completion_tokens + prompt_tokens,
+                 prompt_tokens,
+                 completion_tokens,
+             ),
+             type="invalid_request_error",
+             param="messages",
+             code="context_length_exceeded",
+         )
+
+     @staticmethod
+     def model_not_found(
+         request: Union[
+             "CreateCompletionRequest", "CreateChatCompletionRequest"
+         ],
+         match: Match[str],
+     ) -> tuple[int, ErrorResponse]:
+         """Formatter for model_not_found error"""
+
+         model_path = str(match.group(1))
+         message = f"The model `{model_path}` does not exist"
+         return 400, ErrorResponse(
+             message=message,
+             type="invalid_request_error",
+             param=None,
+             code="model_not_found",
+         )
+
+
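A minimal sketch of how one of these formatters behaves in isolation, assuming the class above is importable; the SimpleNamespace request is a hypothetical stand-in for a CreateCompletionRequest (only max_tokens is read on this path):

from re import compile
from types import SimpleNamespace

pattern = compile(r"Requested tokens \((\d+)\) exceed context window of (\d+)")
match = pattern.search("Requested tokens (2048) exceed context window of 512")
assert match is not None
request = SimpleNamespace(max_tokens=16)  # hypothetical stand-in; no "messages" attr, so the text-completion branch runs
status, payload = ErrorResponseFormatters.context_length_exceeded(request, match)
# status == 400, payload["code"] == "context_length_exceeded"; the RouteErrorHandler
# below returns this payload to the client wrapped as {"error": payload}.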
+ class RouteErrorHandler(APIRoute):
+     """Custom APIRoute that handles application errors and exceptions"""
+
+     # key: regex pattern for original error message from llama_cpp
+     # value: formatter function
+     pattern_and_formatters: dict[
+         "Pattern",
+         Callable[
+             [
+                 Union["CreateCompletionRequest", "CreateChatCompletionRequest"],
+                 Match[str],
+             ],
+             tuple[int, ErrorResponse],
+         ],
+     ] = {
+         compile(
+             r"Requested tokens \((\d+)\) exceed context window of (\d+)"
+         ): ErrorResponseFormatters.context_length_exceeded,
+         compile(
+             r"Model path does not exist: (.+)"
+         ): ErrorResponseFormatters.model_not_found,
+     }
+
+     def error_message_wrapper(
+         self,
+         error: Exception,
+         body: Optional[
+             Union[
+                 "CreateChatCompletionRequest",
+                 "CreateCompletionRequest",
+                 "CreateEmbeddingRequest",
+             ]
+         ] = None,
+     ) -> tuple[int, ErrorResponse]:
+         """Wraps error message in OpenAI style error response"""
+
+         if body is not None and isinstance(
+             body,
+             (
+                 CreateCompletionRequest,
+                 CreateChatCompletionRequest,
+             ),
+         ):
+             # When text completion or chat completion
+             for pattern, callback in self.pattern_and_formatters.items():
+                 match = pattern.search(str(error))
+                 if match is not None:
+                     return callback(body, match)
+
+         # Wrap other errors as internal server error
+         return 500, ErrorResponse(
+             message=str(error),
+             type="internal_server_error",
+             param=None,
+             code=None,
+         )
+
+     def get_route_handler(
+         self,
+     ) -> Callable[[Request], Coroutine[None, None, Response]]:
+         """Defines custom route handler that catches exceptions and formats
+         them in OpenAI style error response"""
+
+         original_route_handler = super().get_route_handler()
+
+         async def custom_route_handler(request: Request) -> Response:
+             try:
+                 return await original_route_handler(request)
+             except Exception as exc:
+                 json_body = await request.json()
+                 try:
+                     if "messages" in json_body:
+                         # Chat completion
+                         body: Optional[
+                             Union[
+                                 CreateChatCompletionRequest,
+                                 CreateCompletionRequest,
+                                 CreateEmbeddingRequest,
+                             ]
+                         ] = CreateChatCompletionRequest(**json_body)
+                     elif "prompt" in json_body:
+                         # Text completion
+                         body = CreateCompletionRequest(**json_body)
+                     else:
+                         # Embedding
+                         body = CreateEmbeddingRequest(**json_body)
+                 except Exception:
+                     # Invalid request body
+                     body = None
+
+                 # Get proper error message from the exception
+                 (
+                     status_code,
+                     error_message,
+                 ) = self.error_message_wrapper(error=exc, body=body)
+                 return JSONResponse(
+                     {"error": error_message},
+                     status_code=status_code,
+                 )
+
+         return custom_route_handler
+
+
+ router = APIRouter(route_class=RouteErrorHandler)
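Because the behaviour hangs entirely off route_class, any other router or app in the server could opt in the same way. A rough sketch under that assumption (demo_router, /v1/echo and demo_app are made-up names, not part of this diff):

from fastapi import FastAPI, APIRouter

demo_router = APIRouter(route_class=RouteErrorHandler)  # hypothetical second router

@demo_router.post("/v1/echo")  # hypothetical endpoint, for illustration only
async def echo(payload: dict):
    return payload

demo_app = FastAPI()
demo_app.include_router(demo_router)
# Any exception raised while handling /v1/echo now comes back as an
# OpenAI-style {"error": ...} JSON body (status 500 for unrecognized errors).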

settings: Optional[Settings] = None
llama: Optional[llama_cpp.Llama] = None
@@ -183,10 +369,33 @@ def get_settings():
    yield settings


+ async def get_event_publisher(
+     request: Request,
+     inner_send_chan: MemoryObjectSendStream,
+     iterator: Iterator,
+ ):
+     async with inner_send_chan:
+         try:
+             async for chunk in iterate_in_threadpool(iterator):
+                 await inner_send_chan.send(dict(data=json.dumps(chunk)))
+                 if await request.is_disconnected():
+                     raise anyio.get_cancelled_exc_class()()
+                 if settings.interrupt_requests and llama_outer_lock.locked():
+                     await inner_send_chan.send(dict(data="[DONE]"))
+                     raise anyio.get_cancelled_exc_class()()
+             await inner_send_chan.send(dict(data="[DONE]"))
+         except anyio.get_cancelled_exc_class() as e:
+             print("disconnected")
+             with anyio.move_on_after(1, shield=True):
+                 print(
+                     f"Disconnected from client (via refresh/close) {request.client}"
+                 )
+             raise e
+
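get_event_publisher is the producer half of an anyio memory-object-stream pair: the endpoints below hand it the send channel while EventSourceResponse drains the receive channel. Roughly the same pattern reduced to a standalone script, with the Llama iterator replaced by a hard-coded tuple (illustrative only):

import json
import anyio

async def main():
    send_chan, recv_chan = anyio.create_memory_object_stream(10)

    async def publish():
        # Producer: mirrors get_event_publisher's shape without Llama or FastAPI.
        async with send_chan:
            for chunk in ({"text": "Hello"}, {"text": " world"}):
                await send_chan.send(dict(data=json.dumps(chunk)))
            await send_chan.send(dict(data="[DONE]"))

    async with anyio.create_task_group() as tg:
        tg.start_soon(publish)
        async with recv_chan:
            async for event in recv_chan:  # Consumer: what EventSourceResponse drains.
                print(event["data"])

anyio.run(main)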
model_field = Field(description="The model to use for generating completions.", default=None)

max_tokens_field = Field(
-     default=16, ge=1, le=2048, description="The maximum number of tokens to generate."
+     default=16, ge=1, description="The maximum number of tokens to generate."
)

temperature_field = Field(
@@ -374,35 +583,31 @@ async def create_completion(
            make_logit_bias_processor(llama, body.logit_bias, body.logit_bias_type),
        ])

-     if body.stream:
-         send_chan, recv_chan = anyio.create_memory_object_stream(10)
+     iterator_or_completion: Union[llama_cpp.Completion, Iterator[
+         llama_cpp.CompletionChunk
+     ]] = await run_in_threadpool(llama, **kwargs)

-         async def event_publisher(inner_send_chan: MemoryObjectSendStream):
-             async with inner_send_chan:
-                 try:
-                     iterator: Iterator[llama_cpp.CompletionChunk] = await run_in_threadpool(llama, **kwargs)  # type: ignore
-                     async for chunk in iterate_in_threadpool(iterator):
-                         await inner_send_chan.send(dict(data=json.dumps(chunk)))
-                         if await request.is_disconnected():
-                             raise anyio.get_cancelled_exc_class()()
-                         if settings.interrupt_requests and llama_outer_lock.locked():
-                             await inner_send_chan.send(dict(data="[DONE]"))
-                             raise anyio.get_cancelled_exc_class()()
-                     await inner_send_chan.send(dict(data="[DONE]"))
-                 except anyio.get_cancelled_exc_class() as e:
-                     print("disconnected")
-                     with anyio.move_on_after(1, shield=True):
-                         print(
-                             f"Disconnected from client (via refresh/close) {request.client}"
-                         )
-                     raise e
+     if isinstance(iterator_or_completion, Iterator):
+         # EAFP: It's easier to ask for forgiveness than permission
+         first_response = await run_in_threadpool(next, iterator_or_completion)

+         # If no exception was raised from first_response, we can assume that
+         # the iterator is valid and we can use it to stream the response.
+         def iterator() -> Iterator[llama_cpp.CompletionChunk]:
+             yield first_response
+             yield from iterator_or_completion
+
+         send_chan, recv_chan = anyio.create_memory_object_stream(10)
        return EventSourceResponse(
-             recv_chan, data_sender_callable=partial(event_publisher, send_chan)
-         )  # type: ignore
+             recv_chan, data_sender_callable=partial(  # type: ignore
+                 get_event_publisher,
+                 request=request,
+                 inner_send_chan=send_chan,
+                 iterator=iterator(),
+             )
+         )
    else:
-         completion: llama_cpp.Completion = await run_in_threadpool(llama, **kwargs)  # type: ignore
-         return completion
+         return iterator_or_completion
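The EAFP block pulls the first chunk before any SSE bytes are sent, so a llama_cpp error surfaces inside the route handler (where RouteErrorHandler can format it) rather than mid-stream; the consumed chunk is then stitched back onto the iterator. The generator trick in isolation, as a small sketch:

from typing import Iterator

def peek_then_chain(source: Iterator[str]) -> Iterator[str]:
    first = next(source)   # any error raises here, before streaming begins
    def chained() -> Iterator[str]:
        yield first        # re-emit the chunk consumed while peeking
        yield from source  # then pass the rest through untouched
    return chained()

print(list(peek_then_chain(iter(["Hello", " world"]))))  # ['Hello', ' world']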


class CreateEmbeddingRequest(BaseModel):
@@ -505,38 +710,31 @@ async def create_chat_completion(
            make_logit_bias_processor(llama, body.logit_bias, body.logit_bias_type),
        ])

-     if body.stream:
-         send_chan, recv_chan = anyio.create_memory_object_stream(10)
+     iterator_or_completion: Union[llama_cpp.ChatCompletion, Iterator[
+         llama_cpp.ChatCompletionChunk
+     ]] = await run_in_threadpool(llama.create_chat_completion, **kwargs)

-         async def event_publisher(inner_send_chan: MemoryObjectSendStream):
-             async with inner_send_chan:
-                 try:
-                     iterator: Iterator[llama_cpp.ChatCompletionChunk] = await run_in_threadpool(llama.create_chat_completion, **kwargs)  # type: ignore
-                     async for chat_chunk in iterate_in_threadpool(iterator):
-                         await inner_send_chan.send(dict(data=json.dumps(chat_chunk)))
-                         if await request.is_disconnected():
-                             raise anyio.get_cancelled_exc_class()()
-                         if settings.interrupt_requests and llama_outer_lock.locked():
-                             await inner_send_chan.send(dict(data="[DONE]"))
-                             raise anyio.get_cancelled_exc_class()()
-                     await inner_send_chan.send(dict(data="[DONE]"))
-                 except anyio.get_cancelled_exc_class() as e:
-                     print("disconnected")
-                     with anyio.move_on_after(1, shield=True):
-                         print(
-                             f"Disconnected from client (via refresh/close) {request.client}"
-                         )
-                     raise e
+     if isinstance(iterator_or_completion, Iterator):
+         # EAFP: It's easier to ask for forgiveness than permission
+         first_response = await run_in_threadpool(next, iterator_or_completion)
+
+         # If no exception was raised from first_response, we can assume that
+         # the iterator is valid and we can use it to stream the response.
+         def iterator() -> Iterator[llama_cpp.ChatCompletionChunk]:
+             yield first_response
+             yield from iterator_or_completion

+         send_chan, recv_chan = anyio.create_memory_object_stream(10)
        return EventSourceResponse(
-             recv_chan,
-             data_sender_callable=partial(event_publisher, send_chan),
-         )  # type: ignore
-     else:
-         completion: llama_cpp.ChatCompletion = await run_in_threadpool(
-             llama.create_chat_completion, **kwargs  # type: ignore
+             recv_chan, data_sender_callable=partial(  # type: ignore
+                 get_event_publisher,
+                 request=request,
+                 inner_send_chan=send_chan,
+                 iterator=iterator(),
+             )
        )
-         return completion
+     else:
+         return iterator_or_completion


class ModelData(TypedDict):