Commit 365d9a4

Merge pull request abetlen#481 from c0sogi/main
Added `RouteErrorHandler` for server
2 parents a9cb645 + 1551ba1 commit 365d9a4

File tree

2 files changed: +256 −58 lines

llama_cpp/llama.py

1 addition, 1 deletion (+1 −1)

@@ -845,7 +845,7 @@ def _create_completion(

         if len(prompt_tokens) >= llama_cpp.llama_n_ctx(self.ctx):
             raise ValueError(
-                f"Requested tokens exceed context window of {llama_cpp.llama_n_ctx(self.ctx)}"
+                f"Requested tokens ({len(prompt_tokens)}) exceed context window of {llama_cpp.llama_n_ctx(self.ctx)}"
             )

         if max_tokens <= 0:
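
The reworded ValueError now includes the prompt token count, which is what the server-side RouteErrorHandler added in llama_cpp/server/app.py (below) parses out of the message. A minimal sketch of that parsing, separate from the commit itself; the token counts are illustrative:

# Sketch: the regex added in llama_cpp/server/app.py applied to the new message format.
import re

pattern = re.compile(r"Requested tokens \((\d+)\) exceed context window of (\d+)")
message = "Requested tokens (5021) exceed context window of 2048"  # example values

match = pattern.search(message)
assert match is not None
prompt_tokens, context_window = int(match.group(1)), int(match.group(2))
print(prompt_tokens, context_window)  # 5021 2048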

llama_cpp/server/app.py

255 additions, 57 deletions (+255 −57)

@@ -1,17 +1,20 @@
 import json
 import multiprocessing
+from re import compile, Match, Pattern
 from threading import Lock
 from functools import partial
-from typing import Iterator, List, Optional, Union, Dict
+from typing import Callable, Coroutine, Iterator, List, Optional, Union, Dict
 from typing_extensions import TypedDict, Literal

 import llama_cpp

 import anyio
 from anyio.streams.memory import MemoryObjectSendStream
 from starlette.concurrency import run_in_threadpool, iterate_in_threadpool
-from fastapi import Depends, FastAPI, APIRouter, Request
+from fastapi import Depends, FastAPI, APIRouter, Request, Response
 from fastapi.middleware.cors import CORSMiddleware
+from fastapi.responses import JSONResponse
+from fastapi.routing import APIRoute
 from pydantic import BaseModel, Field
 from pydantic_settings import BaseSettings
 from sse_starlette.sse import EventSourceResponse
@@ -94,7 +97,190 @@ class Settings(BaseSettings):
     )


-router = APIRouter()
+class ErrorResponse(TypedDict):
+    """OpenAI style error response"""
+
+    message: str
+    type: str
+    param: Optional[str]
+    code: Optional[str]
+
+
+class ErrorResponseFormatters:
+    """Collection of formatters for error responses.
+
+    Args:
+        request (Union[CreateCompletionRequest, CreateChatCompletionRequest]):
+            Request body
+        match (Match[str]): Match object from regex pattern
+
+    Returns:
+        tuple[int, ErrorResponse]: Status code and error response
+    """
+
+    @staticmethod
+    def context_length_exceeded(
+        request: Union[
+            "CreateCompletionRequest", "CreateChatCompletionRequest"
+        ],
+        match: Match[str],
+    ) -> tuple[int, ErrorResponse]:
+        """Formatter for context length exceeded error"""
+
+        context_window = int(match.group(2))
+        prompt_tokens = int(match.group(1))
+        completion_tokens = request.max_tokens
+        if hasattr(request, "messages"):
+            # Chat completion
+            message = (
+                "This model's maximum context length is {} tokens. "
+                "However, you requested {} tokens "
+                "({} in the messages, {} in the completion). "
+                "Please reduce the length of the messages or completion."
+            )
+        else:
+            # Text completion
+            message = (
+                "This model's maximum context length is {} tokens, "
+                "however you requested {} tokens "
+                "({} in your prompt; {} for the completion). "
+                "Please reduce your prompt; or completion length."
+            )
+        return 400, ErrorResponse(
+            message=message.format(
+                context_window,
+                completion_tokens + prompt_tokens,
+                prompt_tokens,
+                completion_tokens,
+            ),
+            type="invalid_request_error",
+            param="messages",
+            code="context_length_exceeded",
+        )
+
+    @staticmethod
+    def model_not_found(
+        request: Union[
+            "CreateCompletionRequest", "CreateChatCompletionRequest"
+        ],
+        match: Match[str],
+    ) -> tuple[int, ErrorResponse]:
+        """Formatter for model_not_found error"""
+
+        model_path = str(match.group(1))
+        message = f"The model `{model_path}` does not exist"
+        return 400, ErrorResponse(
+            message=message,
+            type="invalid_request_error",
+            param=None,
+            code="model_not_found",
+        )
+
+
+class RouteErrorHandler(APIRoute):
+    """Custom APIRoute that handles application errors and exceptions"""
+
+    # key: regex pattern for original error message from llama_cpp
+    # value: formatter function
+    pattern_and_formatters: dict[
+        "Pattern",
+        Callable[
+            [
+                Union["CreateCompletionRequest", "CreateChatCompletionRequest"],
+                Match[str],
+            ],
+            tuple[int, ErrorResponse],
+        ],
+    ] = {
+        compile(
+            r"Requested tokens \((\d+)\) exceed context window of (\d+)"
+        ): ErrorResponseFormatters.context_length_exceeded,
+        compile(
+            r"Model path does not exist: (.+)"
+        ): ErrorResponseFormatters.model_not_found,
+    }
+
+    def error_message_wrapper(
+        self,
+        error: Exception,
+        body: Optional[
+            Union[
+                "CreateChatCompletionRequest",
+                "CreateCompletionRequest",
+                "CreateEmbeddingRequest",
+            ]
+        ] = None,
+    ) -> tuple[int, ErrorResponse]:
+        """Wraps error message in OpenAI style error response"""
+
+        if body is not None and isinstance(
+            body,
+            (
+                CreateCompletionRequest,
+                CreateChatCompletionRequest,
+            ),
+        ):
+            # When text completion or chat completion
+            for pattern, callback in self.pattern_and_formatters.items():
+                match = pattern.search(str(error))
+                if match is not None:
+                    return callback(body, match)
+
+        # Wrap other errors as internal server error
+        return 500, ErrorResponse(
+            message=str(error),
+            type="internal_server_error",
+            param=None,
+            code=None,
+        )
+
+    def get_route_handler(
+        self,
+    ) -> Callable[[Request], Coroutine[None, None, Response]]:
+        """Defines custom route handler that catches exceptions and formats
+        in OpenAI style error response"""
+
+        original_route_handler = super().get_route_handler()
+
+        async def custom_route_handler(request: Request) -> Response:
+            try:
+                return await original_route_handler(request)
+            except Exception as exc:
+                json_body = await request.json()
+                try:
+                    if "messages" in json_body:
+                        # Chat completion
+                        body: Optional[
+                            Union[
+                                CreateChatCompletionRequest,
+                                CreateCompletionRequest,
+                                CreateEmbeddingRequest,
+                            ]
+                        ] = CreateChatCompletionRequest(**json_body)
+                    elif "prompt" in json_body:
+                        # Text completion
+                        body = CreateCompletionRequest(**json_body)
+                    else:
+                        # Embedding
+                        body = CreateEmbeddingRequest(**json_body)
+                except Exception:
+                    # Invalid request body
+                    body = None
+
+                # Get proper error message from the exception
+                (
+                    status_code,
+                    error_message,
+                ) = self.error_message_wrapper(error=exc, body=body)
+                return JSONResponse(
+                    {"error": error_message},
+                    status_code=status_code,
+                )
+
+        return custom_route_handler
+
+
+router = APIRouter(route_class=RouteErrorHandler)

 settings: Optional[Settings] = None
 llama: Optional[llama_cpp.Llama] = None
@@ -183,10 +369,33 @@ def get_settings():
     yield settings


+async def get_event_publisher(
+    request: Request,
+    inner_send_chan: MemoryObjectSendStream,
+    iterator: Iterator,
+):
+    async with inner_send_chan:
+        try:
+            async for chunk in iterate_in_threadpool(iterator):
+                await inner_send_chan.send(dict(data=json.dumps(chunk)))
+                if await request.is_disconnected():
+                    raise anyio.get_cancelled_exc_class()()
+                if settings.interrupt_requests and llama_outer_lock.locked():
+                    await inner_send_chan.send(dict(data="[DONE]"))
+                    raise anyio.get_cancelled_exc_class()()
+            await inner_send_chan.send(dict(data="[DONE]"))
+        except anyio.get_cancelled_exc_class() as e:
+            print("disconnected")
+            with anyio.move_on_after(1, shield=True):
+                print(
+                    f"Disconnected from client (via refresh/close) {request.client}"
+                )
+            raise e
+
 model_field = Field(description="The model to use for generating completions.", default=None)

 max_tokens_field = Field(
-    default=16, ge=1, le=2048, description="The maximum number of tokens to generate."
+    default=16, ge=1, description="The maximum number of tokens to generate."
 )

 temperature_field = Field(
@@ -374,35 +583,31 @@ async def create_completion(
             make_logit_bias_processor(llama, body.logit_bias, body.logit_bias_type),
         ])

-    if body.stream:
-        send_chan, recv_chan = anyio.create_memory_object_stream(10)
+    iterator_or_completion: Union[llama_cpp.Completion, Iterator[
+        llama_cpp.CompletionChunk
+    ]] = await run_in_threadpool(llama, **kwargs)

-        async def event_publisher(inner_send_chan: MemoryObjectSendStream):
-            async with inner_send_chan:
-                try:
-                    iterator: Iterator[llama_cpp.CompletionChunk] = await run_in_threadpool(llama, **kwargs)  # type: ignore
-                    async for chunk in iterate_in_threadpool(iterator):
-                        await inner_send_chan.send(dict(data=json.dumps(chunk)))
-                        if await request.is_disconnected():
-                            raise anyio.get_cancelled_exc_class()()
-                        if settings.interrupt_requests and llama_outer_lock.locked():
-                            await inner_send_chan.send(dict(data="[DONE]"))
-                            raise anyio.get_cancelled_exc_class()()
-                    await inner_send_chan.send(dict(data="[DONE]"))
-                except anyio.get_cancelled_exc_class() as e:
-                    print("disconnected")
-                    with anyio.move_on_after(1, shield=True):
-                        print(
-                            f"Disconnected from client (via refresh/close) {request.client}"
-                        )
-                    raise e
+    if isinstance(iterator_or_completion, Iterator):
+        # EAFP: It's easier to ask for forgiveness than permission
+        first_response = await run_in_threadpool(next, iterator_or_completion)

+        # If no exception was raised from first_response, we can assume that
+        # the iterator is valid and we can use it to stream the response.
+        def iterator() -> Iterator[llama_cpp.CompletionChunk]:
+            yield first_response
+            yield from iterator_or_completion
+
+        send_chan, recv_chan = anyio.create_memory_object_stream(10)
         return EventSourceResponse(
-            recv_chan, data_sender_callable=partial(event_publisher, send_chan)
-        )  # type: ignore
+            recv_chan, data_sender_callable=partial(  # type: ignore
+                get_event_publisher,
+                request=request,
+                inner_send_chan=send_chan,
+                iterator=iterator(),
+            )
+        )
     else:
-        completion: llama_cpp.Completion = await run_in_threadpool(llama, **kwargs)  # type: ignore
-        return completion
+        return iterator_or_completion


 class CreateEmbeddingRequest(BaseModel):
@@ -505,38 +710,31 @@ async def create_chat_completion(
             make_logit_bias_processor(llama, body.logit_bias, body.logit_bias_type),
         ])

-    if body.stream:
-        send_chan, recv_chan = anyio.create_memory_object_stream(10)
+    iterator_or_completion: Union[llama_cpp.ChatCompletion, Iterator[
+        llama_cpp.ChatCompletionChunk
+    ]] = await run_in_threadpool(llama.create_chat_completion, **kwargs)

-        async def event_publisher(inner_send_chan: MemoryObjectSendStream):
-            async with inner_send_chan:
-                try:
-                    iterator: Iterator[llama_cpp.ChatCompletionChunk] = await run_in_threadpool(llama.create_chat_completion, **kwargs)  # type: ignore
-                    async for chat_chunk in iterate_in_threadpool(iterator):
-                        await inner_send_chan.send(dict(data=json.dumps(chat_chunk)))
-                        if await request.is_disconnected():
-                            raise anyio.get_cancelled_exc_class()()
-                        if settings.interrupt_requests and llama_outer_lock.locked():
-                            await inner_send_chan.send(dict(data="[DONE]"))
-                            raise anyio.get_cancelled_exc_class()()
-                    await inner_send_chan.send(dict(data="[DONE]"))
-                except anyio.get_cancelled_exc_class() as e:
-                    print("disconnected")
-                    with anyio.move_on_after(1, shield=True):
-                        print(
-                            f"Disconnected from client (via refresh/close) {request.client}"
-                        )
-                    raise e
+    if isinstance(iterator_or_completion, Iterator):
+        # EAFP: It's easier to ask for forgiveness than permission
+        first_response = await run_in_threadpool(next, iterator_or_completion)
+
+        # If no exception was raised from first_response, we can assume that
+        # the iterator is valid and we can use it to stream the response.
+        def iterator() -> Iterator[llama_cpp.ChatCompletionChunk]:
+            yield first_response
+            yield from iterator_or_completion

+        send_chan, recv_chan = anyio.create_memory_object_stream(10)
         return EventSourceResponse(
-            recv_chan,
-            data_sender_callable=partial(event_publisher, send_chan),
-        )  # type: ignore
-    else:
-        completion: llama_cpp.ChatCompletion = await run_in_threadpool(
-            llama.create_chat_completion, **kwargs  # type: ignore
+            recv_chan, data_sender_callable=partial(  # type: ignore
+                get_event_publisher,
+                request=request,
+                inner_send_chan=send_chan,
+                iterator=iterator(),
+            )
         )
-        return completion
+    else:
+        return iterator_or_completion


 class ModelData(TypedDict):