Commit 1551ba1

Added RouteErrorHandler for server

1 parent 6d8892f · commit 1551ba1
2 files changed: +256 -58 lines changed

‎llama_cpp/llama.py

+1 -1 (1 addition, 1 deletion)
@@ -845,7 +845,7 @@ def _create_completion(
 
         if len(prompt_tokens) >= llama_cpp.llama_n_ctx(self.ctx):
             raise ValueError(
-                f"Requested tokens exceed context window of {llama_cpp.llama_n_ctx(self.ctx)}"
+                f"Requested tokens ({len(prompt_tokens)}) exceed context window of {llama_cpp.llama_n_ctx(self.ctx)}"
             )
 
         if max_tokens <= 0:
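The wording change matters beyond readability: the new server-side RouteErrorHandler (added in app.py below) recovers both numbers from this message with a regex. A minimal sketch of that round trip, with made-up token counts, might look like this:

# Illustrative only (not part of the commit): the reworded ValueError now carries
# the prompt token count, so the regex added in llama_cpp/server/app.py can recover
# both numbers. The 1193/512 values here are made up.
from re import compile

prompt_tokens, n_ctx = 1193, 512  # hypothetical counts
error = ValueError(
    f"Requested tokens ({prompt_tokens}) exceed context window of {n_ctx}"
)

pattern = compile(r"Requested tokens \((\d+)\) exceed context window of (\d+)")
match = pattern.search(str(error))
assert match is not None
print(int(match.group(1)), int(match.group(2)))  # 1193 512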

‎llama_cpp/server/app.py

+255 -57 (255 additions, 57 deletions)
@@ -1,17 +1,20 @@
 import json
 import multiprocessing
+from re import compile, Match, Pattern
 from threading import Lock
 from functools import partial
-from typing import Iterator, List, Optional, Union, Dict
+from typing import Callable, Coroutine, Iterator, List, Optional, Union, Dict
 from typing_extensions import TypedDict, Literal
 
 import llama_cpp
 
 import anyio
 from anyio.streams.memory import MemoryObjectSendStream
 from starlette.concurrency import run_in_threadpool, iterate_in_threadpool
-from fastapi import Depends, FastAPI, APIRouter, Request
+from fastapi import Depends, FastAPI, APIRouter, Request, Response
 from fastapi.middleware.cors import CORSMiddleware
+from fastapi.responses import JSONResponse
+from fastapi.routing import APIRoute
 from pydantic import BaseModel, Field
 from pydantic_settings import BaseSettings
 from sse_starlette.sse import EventSourceResponse
@@ -92,7 +95,190 @@ class Settings(BaseSettings):
     )
 
 
-router = APIRouter()
+class ErrorResponse(TypedDict):
+    """OpenAI style error response"""
+
+    message: str
+    type: str
+    param: Optional[str]
+    code: Optional[str]
+
+
+class ErrorResponseFormatters:
+    """Collection of formatters for error responses.
+
+    Args:
+        request (Union[CreateCompletionRequest, CreateChatCompletionRequest]):
+            Request body
+        match (Match[str]): Match object from regex pattern
+
+    Returns:
+        tuple[int, ErrorResponse]: Status code and error response
+    """
+
+    @staticmethod
+    def context_length_exceeded(
+        request: Union[
+            "CreateCompletionRequest", "CreateChatCompletionRequest"
+        ],
+        match: Match[str],
+    ) -> tuple[int, ErrorResponse]:
+        """Formatter for context length exceeded error"""
+
+        context_window = int(match.group(2))
+        prompt_tokens = int(match.group(1))
+        completion_tokens = request.max_tokens
+        if hasattr(request, "messages"):
+            # Chat completion
+            message = (
+                "This model's maximum context length is {} tokens. "
+                "However, you requested {} tokens "
+                "({} in the messages, {} in the completion). "
+                "Please reduce the length of the messages or completion."
+            )
+        else:
+            # Text completion
+            message = (
+                "This model's maximum context length is {} tokens, "
+                "however you requested {} tokens "
+                "({} in your prompt; {} for the completion). "
+                "Please reduce your prompt; or completion length."
+            )
+        return 400, ErrorResponse(
+            message=message.format(
+                context_window,
+                completion_tokens + prompt_tokens,
+                prompt_tokens,
+                completion_tokens,
+            ),
+            type="invalid_request_error",
+            param="messages",
+            code="context_length_exceeded",
+        )
+
+    @staticmethod
+    def model_not_found(
+        request: Union[
+            "CreateCompletionRequest", "CreateChatCompletionRequest"
+        ],
+        match: Match[str],
+    ) -> tuple[int, ErrorResponse]:
+        """Formatter for model_not_found error"""
+
+        model_path = str(match.group(1))
+        message = f"The model `{model_path}` does not exist"
+        return 400, ErrorResponse(
+            message=message,
+            type="invalid_request_error",
+            param=None,
+            code="model_not_found",
+        )
+
+
+class RouteErrorHandler(APIRoute):
+    """Custom APIRoute that handles application errors and exceptions"""
+
+    # key: regex pattern for original error message from llama_cpp
+    # value: formatter function
+    pattern_and_formatters: dict[
+        "Pattern",
+        Callable[
+            [
+                Union["CreateCompletionRequest", "CreateChatCompletionRequest"],
+                Match[str],
+            ],
+            tuple[int, ErrorResponse],
+        ],
+    ] = {
+        compile(
+            r"Requested tokens \((\d+)\) exceed context window of (\d+)"
+        ): ErrorResponseFormatters.context_length_exceeded,
+        compile(
+            r"Model path does not exist: (.+)"
+        ): ErrorResponseFormatters.model_not_found,
+    }
+
+    def error_message_wrapper(
+        self,
+        error: Exception,
+        body: Optional[
+            Union[
+                "CreateChatCompletionRequest",
+                "CreateCompletionRequest",
+                "CreateEmbeddingRequest",
+            ]
+        ] = None,
+    ) -> tuple[int, ErrorResponse]:
+        """Wraps error message in OpenAI style error response"""
+
+        if body is not None and isinstance(
+            body,
+            (
+                CreateCompletionRequest,
+                CreateChatCompletionRequest,
+            ),
+        ):
+            # When text completion or chat completion
+            for pattern, callback in self.pattern_and_formatters.items():
+                match = pattern.search(str(error))
+                if match is not None:
+                    return callback(body, match)
+
+        # Wrap other errors as internal server error
+        return 500, ErrorResponse(
+            message=str(error),
+            type="internal_server_error",
+            param=None,
+            code=None,
+        )
+
+    def get_route_handler(
+        self,
+    ) -> Callable[[Request], Coroutine[None, None, Response]]:
+        """Defines custom route handler that catches exceptions and formats
+        in OpenAI style error response"""
+
+        original_route_handler = super().get_route_handler()
+
+        async def custom_route_handler(request: Request) -> Response:
+            try:
+                return await original_route_handler(request)
+            except Exception as exc:
+                json_body = await request.json()
+                try:
+                    if "messages" in json_body:
+                        # Chat completion
+                        body: Optional[
+                            Union[
+                                CreateChatCompletionRequest,
+                                CreateCompletionRequest,
+                                CreateEmbeddingRequest,
+                            ]
+                        ] = CreateChatCompletionRequest(**json_body)
+                    elif "prompt" in json_body:
+                        # Text completion
+                        body = CreateCompletionRequest(**json_body)
+                    else:
+                        # Embedding
+                        body = CreateEmbeddingRequest(**json_body)
+                except Exception:
+                    # Invalid request body
+                    body = None
+
+                # Get proper error message from the exception
+                (
+                    status_code,
+                    error_message,
+                ) = self.error_message_wrapper(error=exc, body=body)
+                return JSONResponse(
+                    {"error": error_message},
+                    status_code=status_code,
+                )
+
+        return custom_route_handler
+
+
+router = APIRouter(route_class=RouteErrorHandler)
 
 settings: Optional[Settings] = None
 llama: Optional[llama_cpp.Llama] = None
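As a quick illustration of how these formatters behave, the sketch below feeds a matched llama_cpp error string through ErrorResponseFormatters.context_length_exceeded. It assumes Python 3.9+ and an installed llama-cpp-python with the server extras; the token counts are made up. The resulting dictionary is what custom_route_handler wraps in a JSONResponse under the "error" key.

# Sketch only (not part of the commit): exercising the new formatter directly.
from re import compile

from llama_cpp.server.app import CreateCompletionRequest, ErrorResponseFormatters

request = CreateCompletionRequest(prompt="Hello", max_tokens=16)
pattern = compile(r"Requested tokens \((\d+)\) exceed context window of (\d+)")
match = pattern.search("Requested tokens (1193) exceed context window of 512")
assert match is not None

status_code, error = ErrorResponseFormatters.context_length_exceeded(request, match)
print(status_code)     # 400
print(error["type"])   # invalid_request_error
print(error["code"])   # context_length_exceeded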
@@ -179,10 +365,33 @@ def get_settings():
     yield settings
 
 
+async def get_event_publisher(
+    request: Request,
+    inner_send_chan: MemoryObjectSendStream,
+    iterator: Iterator,
+):
+    async with inner_send_chan:
+        try:
+            async for chunk in iterate_in_threadpool(iterator):
+                await inner_send_chan.send(dict(data=json.dumps(chunk)))
+                if await request.is_disconnected():
+                    raise anyio.get_cancelled_exc_class()()
+                if settings.interrupt_requests and llama_outer_lock.locked():
+                    await inner_send_chan.send(dict(data="[DONE]"))
+                    raise anyio.get_cancelled_exc_class()()
+            await inner_send_chan.send(dict(data="[DONE]"))
+        except anyio.get_cancelled_exc_class() as e:
+            print("disconnected")
+            with anyio.move_on_after(1, shield=True):
+                print(
+                    f"Disconnected from client (via refresh/close) {request.client}"
+                )
+                raise e
+
 model_field = Field(description="The model to use for generating completions.", default=None)
 
 max_tokens_field = Field(
-    default=16, ge=1, le=2048, description="The maximum number of tokens to generate."
+    default=16, ge=1, description="The maximum number of tokens to generate."
)
 
 temperature_field = Field(
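For context, here is a hypothetical client-side sketch (not part of the commit) of consuming the SSE stream that get_event_publisher feeds into EventSourceResponse: each event carries a JSON-encoded chunk, and the literal [DONE] sentinel marks the end. It assumes a server running locally on port 8000 and the requests package.

# Hypothetical streaming client; endpoint and field names follow the
# OpenAI-compatible API this server exposes.
import json

import requests

with requests.post(
    "http://localhost:8000/v1/completions",
    json={"prompt": "Hello", "max_tokens": 32, "stream": True},
    stream=True,
) as resp:
    for raw_line in resp.iter_lines():
        if not raw_line:
            continue
        line = raw_line.decode("utf-8")
        if not line.startswith("data: "):
            continue
        payload = line[len("data: "):]
        if payload == "[DONE]":  # sentinel sent by get_event_publisher
            break
        chunk = json.loads(payload)
        print(chunk["choices"][0]["text"], end="", flush=True)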
@@ -370,35 +579,31 @@ async def create_completion(
             make_logit_bias_processor(llama, body.logit_bias, body.logit_bias_type),
         ])
 
-    if body.stream:
-        send_chan, recv_chan = anyio.create_memory_object_stream(10)
+    iterator_or_completion: Union[llama_cpp.Completion, Iterator[
+        llama_cpp.CompletionChunk
+    ]] = await run_in_threadpool(llama, **kwargs)
 
-        async def event_publisher(inner_send_chan: MemoryObjectSendStream):
-            async with inner_send_chan:
-                try:
-                    iterator: Iterator[llama_cpp.CompletionChunk] = await run_in_threadpool(llama, **kwargs)  # type: ignore
-                    async for chunk in iterate_in_threadpool(iterator):
-                        await inner_send_chan.send(dict(data=json.dumps(chunk)))
-                        if await request.is_disconnected():
-                            raise anyio.get_cancelled_exc_class()()
-                        if settings.interrupt_requests and llama_outer_lock.locked():
-                            await inner_send_chan.send(dict(data="[DONE]"))
-                            raise anyio.get_cancelled_exc_class()()
-                    await inner_send_chan.send(dict(data="[DONE]"))
-                except anyio.get_cancelled_exc_class() as e:
-                    print("disconnected")
-                    with anyio.move_on_after(1, shield=True):
-                        print(
-                            f"Disconnected from client (via refresh/close) {request.client}"
-                        )
-                        raise e
+    if isinstance(iterator_or_completion, Iterator):
+        # EAFP: It's easier to ask for forgiveness than permission
+        first_response = await run_in_threadpool(next, iterator_or_completion)
 
+        # If no exception was raised from first_response, we can assume that
+        # the iterator is valid and we can use it to stream the response.
+        def iterator() -> Iterator[llama_cpp.CompletionChunk]:
+            yield first_response
+            yield from iterator_or_completion
+
+        send_chan, recv_chan = anyio.create_memory_object_stream(10)
         return EventSourceResponse(
-            recv_chan, data_sender_callable=partial(event_publisher, send_chan)
-        )  # type: ignore
+            recv_chan, data_sender_callable=partial(  # type: ignore
+                get_event_publisher,
+                request=request,
+                inner_send_chan=send_chan,
+                iterator=iterator(),
+            )
+        )
     else:
-        completion: llama_cpp.Completion = await run_in_threadpool(llama, **kwargs)  # type: ignore
-        return completion
+        return iterator_or_completion
 
 
 class CreateEmbeddingRequest(BaseModel):
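The isinstance branch above leans on an EAFP trick: pull the first chunk before committing to a streaming response, so anything llama_cpp raises while starting generation still propagates as a normal exception (and reaches RouteErrorHandler) instead of failing inside an already-open SSE stream. A standalone sketch of the pattern, with a toy generator standing in for the llama call:

# Standalone sketch of the EAFP pattern (not from the commit).
from typing import Iterator


def fake_llama_stream() -> Iterator[str]:
    # Stand-in for the llama(...) iterator; a context-window error would be
    # raised here, on the first next() call, before any bytes hit the wire.
    yield "Hello"
    yield ", world"


stream = fake_llama_stream()
first_response = next(stream)      # errors would surface at this point

def iterator() -> Iterator[str]:
    yield first_response           # re-emit the chunk we already consumed
    yield from stream              # then hand over the rest of the iterator

print("".join(iterator()))         # Hello, world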
@@ -501,38 +706,31 @@ async def create_chat_completion(
             make_logit_bias_processor(llama, body.logit_bias, body.logit_bias_type),
         ])
 
-    if body.stream:
-        send_chan, recv_chan = anyio.create_memory_object_stream(10)
+    iterator_or_completion: Union[llama_cpp.ChatCompletion, Iterator[
+        llama_cpp.ChatCompletionChunk
+    ]] = await run_in_threadpool(llama.create_chat_completion, **kwargs)
 
-        async def event_publisher(inner_send_chan: MemoryObjectSendStream):
-            async with inner_send_chan:
-                try:
-                    iterator: Iterator[llama_cpp.ChatCompletionChunk] = await run_in_threadpool(llama.create_chat_completion, **kwargs)  # type: ignore
-                    async for chat_chunk in iterate_in_threadpool(iterator):
-                        await inner_send_chan.send(dict(data=json.dumps(chat_chunk)))
-                        if await request.is_disconnected():
-                            raise anyio.get_cancelled_exc_class()()
-                        if settings.interrupt_requests and llama_outer_lock.locked():
-                            await inner_send_chan.send(dict(data="[DONE]"))
-                            raise anyio.get_cancelled_exc_class()()
-                    await inner_send_chan.send(dict(data="[DONE]"))
-                except anyio.get_cancelled_exc_class() as e:
-                    print("disconnected")
-                    with anyio.move_on_after(1, shield=True):
-                        print(
-                            f"Disconnected from client (via refresh/close) {request.client}"
-                        )
-                        raise e
+    if isinstance(iterator_or_completion, Iterator):
+        # EAFP: It's easier to ask for forgiveness than permission
+        first_response = await run_in_threadpool(next, iterator_or_completion)
+
+        # If no exception was raised from first_response, we can assume that
+        # the iterator is valid and we can use it to stream the response.
+        def iterator() -> Iterator[llama_cpp.ChatCompletionChunk]:
+            yield first_response
+            yield from iterator_or_completion
 
+        send_chan, recv_chan = anyio.create_memory_object_stream(10)
         return EventSourceResponse(
-            recv_chan,
-            data_sender_callable=partial(event_publisher, send_chan),
-        )  # type: ignore
-    else:
-        completion: llama_cpp.ChatCompletion = await run_in_threadpool(
-            llama.create_chat_completion, **kwargs  # type: ignore
+            recv_chan, data_sender_callable=partial(  # type: ignore
+                get_event_publisher,
+                request=request,
+                inner_send_chan=send_chan,
+                iterator=iterator(),
+            )
         )
-        return completion
+    else:
+        return iterator_or_completion
 
 
 class ModelData(TypedDict):
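Putting the pieces together, a hypothetical end-to-end check (not part of the commit): with RouteErrorHandler installed as the router's route_class, an oversized chat request should come back as an OpenAI-style 400 error rather than a bare 500. It assumes a server running locally on port 8000 with a model loaded, and the requests package.

# Hypothetical end-to-end check of the new error behavior.
import requests

resp = requests.post(
    "http://localhost:8000/v1/chat/completions",
    json={
        "messages": [{"role": "user", "content": "word " * 100_000}],
        "max_tokens": 16,
    },
)
print(resp.status_code)          # 400
body = resp.json()
print(body["error"]["type"])     # invalid_request_error
print(body["error"]["code"])     # context_length_exceeded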
