Commit 80066f0

Use async routes

1 parent c2b59a5 · commit 80066f0
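
The change converts the server's FastAPI route handlers from plain def to async def, pushing the blocking llama.cpp calls onto worker threads and streaming results through an anyio memory channel instead of a plain generator. A minimal before/after sketch of that pattern (hypothetical route paths and function names, not code from this repository):

# Hypothetical sketch of the sync-to-async route pattern applied by this
# commit; slow_generate and the /v1/example paths are illustrative only.
from fastapi import FastAPI
from starlette.concurrency import run_in_threadpool

app = FastAPI()


def slow_generate(prompt: str) -> dict:
    # Stands in for a blocking llama.cpp call.
    return {"text": prompt.upper()}


# Before: a sync handler. FastAPI runs it in a threadpool for us, but the
# handler itself cannot await anything (e.g. request.is_disconnected()).
@app.post("/v1/example-sync")
def example_sync(prompt: str) -> dict:
    return slow_generate(prompt)


# After: an async handler. The blocking call is explicitly moved to a worker
# thread, keeping the event loop free while still allowing awaits.
@app.post("/v1/example-async")
async def example_async(prompt: str) -> dict:
    return await run_in_threadpool(slow_generate, prompt)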

File tree

1 file changed (+88, -57 lines)

llama_cpp/server/app.py: 88 additions, 57 deletions

@@ -1,12 +1,16 @@
 import json
 import multiprocessing
 from threading import Lock
-from typing import List, Optional, Union, Iterator, Dict
+from functools import partial
+from typing import Iterator, List, Optional, Union, Dict
 from typing_extensions import TypedDict, Literal
 
 import llama_cpp
 
-from fastapi import Depends, FastAPI, APIRouter
+import anyio
+from anyio.streams.memory import MemoryObjectSendStream
+from starlette.concurrency import run_in_threadpool, iterate_in_threadpool
+from fastapi import Depends, FastAPI, APIRouter, Request
 from fastapi.middleware.cors import CORSMiddleware
 from pydantic import BaseModel, BaseSettings, Field, create_model_from_typeddict
 from sse_starlette.sse import EventSourceResponse
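
The new imports bring in anyio's memory object streams, used later to ferry completion chunks from a worker thread to the SSE response, plus Starlette's run_in_threadpool and iterate_in_threadpool helpers. A self-contained sketch of the producer/consumer pattern those imports enable, assuming anyio is installed; none of these names come from app.py:

# Standalone sketch of anyio's memory object stream: one task publishes
# items into the send half, another consumes the receive half.
import anyio
from anyio.streams.memory import MemoryObjectSendStream


async def publish(send_chan: MemoryObjectSendStream) -> None:
    async with send_chan:  # closing the send half ends the consumer's loop
        for i in range(3):
            await send_chan.send({"data": i})


async def main() -> None:
    send_chan, recv_chan = anyio.create_memory_object_stream(10)
    async with anyio.create_task_group() as tg:
        tg.start_soon(publish, send_chan)
        async with recv_chan:
            async for item in recv_chan:
                print(item)


anyio.run(main)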

@@ -241,35 +245,49 @@ class Config:
     "/v1/completions",
     response_model=CreateCompletionResponse,
 )
-def create_completion(
-    request: CreateCompletionRequest, llama: llama_cpp.Llama = Depends(get_llama)
+async def create_completion(
+    request: Request,
+    body: CreateCompletionRequest,
+    llama: llama_cpp.Llama = Depends(get_llama),
 ):
-    if isinstance(request.prompt, list):
-        assert len(request.prompt) <= 1
-        request.prompt = request.prompt[0] if len(request.prompt) > 0 else ""
-
-    completion_or_chunks = llama(
-        **request.dict(
-            exclude={
-                "n",
-                "best_of",
-                "logit_bias",
-                "user",
-            }
-        )
-    )
-    if request.stream:
-
-        async def server_sent_events(
-            chunks: Iterator[llama_cpp.CompletionChunk],
-        ):
-            for chunk in chunks:
-                yield dict(data=json.dumps(chunk))
+    if isinstance(body.prompt, list):
+        assert len(body.prompt) <= 1
+        body.prompt = body.prompt[0] if len(body.prompt) > 0 else ""
+
+    exclude = {
+        "n",
+        "best_of",
+        "logit_bias",
+        "user",
+    }
+    kwargs = body.dict(exclude=exclude)
+    if body.stream:
+        send_chan, recv_chan = anyio.create_memory_object_stream(10)
+
+        async def event_publisher(inner_send_chan: MemoryObjectSendStream):
+            async with inner_send_chan:
+                try:
+                    iterator: Iterator[llama_cpp.CompletionChunk] = await run_in_threadpool(llama, **kwargs)  # type: ignore
+                    async for chunk in iterate_in_threadpool(iterator):
+                        await inner_send_chan.send(dict(data=json.dumps(chunk)))
+                        if await request.is_disconnected():
+                            raise anyio.get_cancelled_exc_class()()
+                    await inner_send_chan.send(dict(data="[DONE]"))
+                except anyio.get_cancelled_exc_class() as e:
+                    print("disconnected")
+                    with anyio.move_on_after(1, shield=True):
+                        print(
+                            f"Disconnected from client (via refresh/close) {request.client}"
+                        )
+                        await inner_send_chan.send(dict(closing=True))
+                        raise e
 
-        chunks: Iterator[llama_cpp.CompletionChunk] = completion_or_chunks  # type: ignore
-        return EventSourceResponse(server_sent_events(chunks))
-    completion: llama_cpp.Completion = completion_or_chunks  # type: ignore
-    return completion
+        return EventSourceResponse(
+            recv_chan, data_sender_callable=partial(event_publisher, send_chan)
+        )
+    else:
+        completion: llama_cpp.Completion = await run_in_threadpool(llama, **kwargs)  # type: ignore
+        return completion
 
 
 class CreateEmbeddingRequest(BaseModel):
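
With streaming enabled, the /v1/completions route now feeds SSE events through the memory channel until a final [DONE] marker. A hypothetical client sketch for consuming that stream, assuming the server is reachable at http://localhost:8000 and the requests package is available:

# Hypothetical streaming client for /v1/completions; the host/port and the
# OpenAI-style chunk fields are assumptions, not taken from this commit.
import json
import requests

resp = requests.post(
    "http://localhost:8000/v1/completions",
    json={"prompt": "Hello", "stream": True},
    stream=True,
)
for raw in resp.iter_lines():
    if not raw:
        continue
    line = raw.decode()
    if not line.startswith("data: "):
        continue  # skip SSE comments such as keep-alive pings
    payload = line[len("data: "):]
    if payload == "[DONE]":
        break
    chunk = json.loads(payload)
    print(chunk["choices"][0]["text"], end="", flush=True)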

@@ -292,10 +310,12 @@ class Config:
     "/v1/embeddings",
     response_model=CreateEmbeddingResponse,
 )
-def create_embedding(
+async def create_embedding(
     request: CreateEmbeddingRequest, llama: llama_cpp.Llama = Depends(get_llama)
 ):
-    return llama.create_embedding(**request.dict(exclude={"user"}))
+    return await run_in_threadpool(
+        llama.create_embedding, **request.dict(exclude={"user"})
+    )
 
 
 class ChatCompletionRequestMessage(BaseModel):
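
The embeddings route has no streaming path, so the whole blocking call is simply wrapped in run_in_threadpool. A minimal sketch of what that helper does, using a stand-in function instead of a real model:

# Sketch of run_in_threadpool: a blocking function is executed on a worker
# thread and awaited, so the event loop is not stalled. blocking_embed is a
# stand-in for llama.create_embedding, not real code from this project.
import time

import anyio
from starlette.concurrency import run_in_threadpool


def blocking_embed(text: str) -> list:
    time.sleep(0.1)  # pretend CPU-bound work
    return [0.0, 1.0, 2.0]


async def main() -> None:
    vector = await run_in_threadpool(blocking_embed, "hello")
    print(vector)


anyio.run(main)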

@@ -349,36 +369,47 @@ class Config:
     "/v1/chat/completions",
     response_model=CreateChatCompletionResponse,
 )
-def create_chat_completion(
-    request: CreateChatCompletionRequest,
+async def create_chat_completion(
+    request: Request,
+    body: CreateChatCompletionRequest,
     llama: llama_cpp.Llama = Depends(get_llama),
 ) -> Union[llama_cpp.ChatCompletion, EventSourceResponse]:
-    completion_or_chunks = llama.create_chat_completion(
-        **request.dict(
-            exclude={
-                "n",
-                "logit_bias",
-                "user",
-            }
-        ),
-    )
-
-    if request.stream:
-
-        async def server_sent_events(
-            chat_chunks: Iterator[llama_cpp.ChatCompletionChunk],
-        ):
-            for chat_chunk in chat_chunks:
-                yield dict(data=json.dumps(chat_chunk))
-            yield dict(data="[DONE]")
-
-        chunks: Iterator[llama_cpp.ChatCompletionChunk] = completion_or_chunks  # type: ignore
+    exclude = {
+        "n",
+        "logit_bias",
+        "user",
+    }
+    kwargs = body.dict(exclude=exclude)
+    if body.stream:
+        send_chan, recv_chan = anyio.create_memory_object_stream(10)
+
+        async def event_publisher(inner_send_chan: MemoryObjectSendStream):
+            async with inner_send_chan:
+                try:
+                    iterator: Iterator[llama_cpp.ChatCompletionChunk] = await run_in_threadpool(llama.create_chat_completion, **kwargs)  # type: ignore
+                    async for chat_chunk in iterate_in_threadpool(iterator):
+                        await inner_send_chan.send(dict(data=json.dumps(chat_chunk)))
+                        if await request.is_disconnected():
+                            raise anyio.get_cancelled_exc_class()()
+                    await inner_send_chan.send(dict(data="[DONE]"))
+                except anyio.get_cancelled_exc_class() as e:
+                    print("disconnected")
+                    with anyio.move_on_after(1, shield=True):
+                        print(
+                            f"Disconnected from client (via refresh/close) {request.client}"
+                        )
+                        await inner_send_chan.send(dict(closing=True))
+                        raise e
 
         return EventSourceResponse(
-            server_sent_events(chunks),
+            recv_chan,
+            data_sender_callable=partial(event_publisher, send_chan),
+        )
+    else:
+        completion: llama_cpp.ChatCompletion = await run_in_threadpool(
+            llama.create_chat_completion, **kwargs  # type: ignore
         )
-    completion: llama_cpp.ChatCompletion = completion_or_chunks  # type: ignore
-    return completion
+        return completion
 
 
 class ModelData(TypedDict):
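
Both event publishers use the same disconnect handling: when request.is_disconnected() turns true they raise the backend's cancellation exception, then a shielded move_on_after block gets up to one second to send a closing event before re-raising. A standalone sketch of that cancellation-plus-shielded-cleanup pattern, with no FastAPI involved:

# Standalone sketch of the cancellation pattern: catch the backend's
# cancellation exception, do shielded cleanup under a deadline, re-raise.
import anyio


async def worker() -> None:
    try:
        await anyio.sleep(10)  # stands in for streaming chunks to a client
    except anyio.get_cancelled_exc_class():
        with anyio.move_on_after(1, shield=True):
            # Shielding lets this cleanup await even though we are cancelled.
            print("client gone, sending a closing event")
            await anyio.sleep(0.1)
        raise  # cancellation must always be re-raised


async def main() -> None:
    async with anyio.create_task_group() as tg:
        tg.start_soon(worker)
        await anyio.sleep(0.5)
        tg.cancel_scope.cancel()  # simulate the client disconnecting


anyio.run(main)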

@@ -397,7 +428,7 @@ class ModelList(TypedDict):
 
 
 @router.get("/v1/models", response_model=GetModelResponse)
-def get_models(
+async def get_models(
     settings: Settings = Depends(get_settings),
     llama: llama_cpp.Llama = Depends(get_llama),
 ) -> ModelList:

0 commit comments