Commit e8f14ce

gjpower and abetlen authored

fix: streaming resource lock (abetlen#1879)

* fix: correct issue with handling lock during streaming

  Move locking for streaming into the get_event_publisher call so that the lock is acquired and released in the correct task for the streaming response.

* fix: simplify exit stack management for create_chat_completion and create_completion

* fix: correct missing `async with` and format code

* fix: remove unnecessary explicit use of AsyncExitStack

* fix: correct type hints for body_model

---------

Co-authored-by: Andrei <abetlen@gmail.com>

1 parent 1d5f534 commit e8f14ce
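Background for the first bullet in the commit message: an anyio.Lock is owned by the task that acquires it, and releasing it from a different task raises RuntimeError. The short, self-contained sketch below (illustrative only, not code from this repository) demonstrates that constraint, which is why the proxy lock is now acquired and released inside the task that streams the response rather than in the request handler.

import anyio


async def main() -> None:
    lock = anyio.Lock()
    await lock.acquire()  # acquired by the parent task

    async def other_task() -> None:
        # anyio locks are task-bound: a task that does not own the lock cannot release it
        try:
            lock.release()
        except RuntimeError as exc:
            print("release from another task failed:", exc)

    async with anyio.create_task_group() as tg:
        tg.start_soon(other_task)

    lock.release()  # the owning task can release it normally


anyio.run(main)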

File tree

1 file changed: +103 -121 lines changed

llama_cpp/server/app.py
+103 -121 lines changed: 103 additions & 121 deletions
@@ -7,7 +7,7 @@
 
 from anyio import Lock
 from functools import partial
-from typing import Iterator, List, Optional, Union, Dict
+from typing import List, Optional, Union, Dict
 
 import llama_cpp
 
@@ -155,34 +155,71 @@ def create_app(
     return app
 
 
+def prepare_request_resources(
+    body: CreateCompletionRequest | CreateChatCompletionRequest,
+    llama_proxy: LlamaProxy,
+    body_model: str | None,
+    kwargs,
+) -> llama_cpp.Llama:
+    if llama_proxy is None:
+        raise HTTPException(
+            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
+            detail="Service is not available",
+        )
+    llama = llama_proxy(body_model)
+    if body.logit_bias is not None:
+        kwargs["logit_bias"] = (
+            _logit_bias_tokens_to_input_ids(llama, body.logit_bias)
+            if body.logit_bias_type == "tokens"
+            else body.logit_bias
+        )
+
+    if body.grammar is not None:
+        kwargs["grammar"] = llama_cpp.LlamaGrammar.from_string(body.grammar)
+
+    if body.min_tokens > 0:
+        _min_tokens_logits_processor = llama_cpp.LogitsProcessorList(
+            [llama_cpp.MinTokensLogitsProcessor(body.min_tokens, llama.token_eos())]
+        )
+        if "logits_processor" not in kwargs:
+            kwargs["logits_processor"] = _min_tokens_logits_processor
+        else:
+            kwargs["logits_processor"].extend(_min_tokens_logits_processor)
+    return llama
+
+
 async def get_event_publisher(
     request: Request,
     inner_send_chan: MemoryObjectSendStream[typing.Any],
-    iterator: Iterator[typing.Any],
-    on_complete: typing.Optional[typing.Callable[[], typing.Awaitable[None]]] = None,
+    body: CreateCompletionRequest | CreateChatCompletionRequest,
+    body_model: str | None,
+    llama_call,
+    kwargs,
 ):
     server_settings = next(get_server_settings())
     interrupt_requests = (
         server_settings.interrupt_requests if server_settings else False
     )
-    async with inner_send_chan:
-        try:
-            async for chunk in iterate_in_threadpool(iterator):
-                await inner_send_chan.send(dict(data=json.dumps(chunk)))
-                if await request.is_disconnected():
-                    raise anyio.get_cancelled_exc_class()()
-                if interrupt_requests and llama_outer_lock.locked():
-                    await inner_send_chan.send(dict(data="[DONE]"))
-                    raise anyio.get_cancelled_exc_class()()
-            await inner_send_chan.send(dict(data="[DONE]"))
-        except anyio.get_cancelled_exc_class() as e:
-            print("disconnected")
-            with anyio.move_on_after(1, shield=True):
-                print(f"Disconnected from client (via refresh/close) {request.client}")
-            raise e
-        finally:
-            if on_complete:
-                await on_complete()
+    async with contextlib.asynccontextmanager(get_llama_proxy)() as llama_proxy:
+        llama = prepare_request_resources(body, llama_proxy, body_model, kwargs)
+        async with inner_send_chan:
+            try:
+                iterator = await run_in_threadpool(llama_call, llama, **kwargs)
+                async for chunk in iterate_in_threadpool(iterator):
+                    await inner_send_chan.send(dict(data=json.dumps(chunk)))
+                    if await request.is_disconnected():
+                        raise anyio.get_cancelled_exc_class()()
+                    if interrupt_requests and llama_outer_lock.locked():
+                        await inner_send_chan.send(dict(data="[DONE]"))
+                        raise anyio.get_cancelled_exc_class()()
+                await inner_send_chan.send(dict(data="[DONE]"))
+            except anyio.get_cancelled_exc_class() as e:
+                print("disconnected")
+                with anyio.move_on_after(1, shield=True):
+                    print(
+                        f"Disconnected from client (via refresh/close) {request.client}"
+                    )
+                raise e
 
 
 def _logit_bias_tokens_to_input_ids(
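A note on the wrapper used above: contextlib.asynccontextmanager(get_llama_proxy)() turns the async-generator dependency into an async context manager, so the llama proxy (and the lock it holds) is entered and exited inside get_event_publisher, i.e. in the same task that streams the response. Below is a minimal sketch of that pattern, using an invented get_resource dependency in place of the repository's get_llama_proxy.

import contextlib

import anyio

lock = anyio.Lock()


async def get_resource():
    # async-generator "dependency": the lock is held while the caller uses the resource
    async with lock:
        yield "resource"


async def stream_response(name: str) -> None:
    # wrapping the generator yields an async context manager that is entered and
    # exited in this task, so lock ownership stays with the streaming task
    async with contextlib.asynccontextmanager(get_resource)() as resource:
        print(name, "streaming with", resource)
        await anyio.sleep(0.01)


async def main() -> None:
    async with anyio.create_task_group() as tg:
        tg.start_soon(stream_response, "request-1")
        tg.start_soon(stream_response, "request-2")


anyio.run(main)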
@@ -267,18 +304,11 @@ async def create_completion(
     request: Request,
     body: CreateCompletionRequest,
 ) -> llama_cpp.Completion:
-    exit_stack = contextlib.AsyncExitStack()
-    llama_proxy = await exit_stack.enter_async_context(contextlib.asynccontextmanager(get_llama_proxy)())
-    if llama_proxy is None:
-        raise HTTPException(
-            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
-            detail="Service is not available",
-        )
     if isinstance(body.prompt, list):
         assert len(body.prompt) <= 1
         body.prompt = body.prompt[0] if len(body.prompt) > 0 else ""
 
-    llama = llama_proxy(
+    body_model = (
         body.model
         if request.url.path != "/v1/engines/copilot-codex/completions"
         else "copilot-codex"
@@ -293,60 +323,38 @@ async def create_completion(
     }
     kwargs = body.model_dump(exclude=exclude)
 
-    if body.logit_bias is not None:
-        kwargs["logit_bias"] = (
-            _logit_bias_tokens_to_input_ids(llama, body.logit_bias)
-            if body.logit_bias_type == "tokens"
-            else body.logit_bias
-        )
-
-    if body.grammar is not None:
-        kwargs["grammar"] = llama_cpp.LlamaGrammar.from_string(body.grammar)
-
-    if body.min_tokens > 0:
-        _min_tokens_logits_processor = llama_cpp.LogitsProcessorList(
-            [llama_cpp.MinTokensLogitsProcessor(body.min_tokens, llama.token_eos())]
-        )
-        if "logits_processor" not in kwargs:
-            kwargs["logits_processor"] = _min_tokens_logits_processor
-        else:
-            kwargs["logits_processor"].extend(_min_tokens_logits_processor)
-
-    try:
-        iterator_or_completion: Union[
-            llama_cpp.CreateCompletionResponse,
-            Iterator[llama_cpp.CreateCompletionStreamResponse],
-        ] = await run_in_threadpool(llama, **kwargs)
-    except Exception as err:
-        await exit_stack.aclose()
-        raise err
-
-    if isinstance(iterator_or_completion, Iterator):
-        # EAFP: It's easier to ask for forgiveness than permission
-        first_response = await run_in_threadpool(next, iterator_or_completion)
-
-        # If no exception was raised from first_response, we can assume that
-        # the iterator is valid and we can use it to stream the response.
-        def iterator() -> Iterator[llama_cpp.CreateCompletionStreamResponse]:
-            yield first_response
-            yield from iterator_or_completion
-
+    # handle streaming request
+    if kwargs.get("stream", False):
         send_chan, recv_chan = anyio.create_memory_object_stream(10)
         return EventSourceResponse(
             recv_chan,
             data_sender_callable=partial(  # type: ignore
                 get_event_publisher,
                 request=request,
                 inner_send_chan=send_chan,
-                iterator=iterator(),
-                on_complete=exit_stack.aclose,
+                body=body,
+                body_model=body_model,
+                llama_call=llama_cpp.Llama.__call__,
+                kwargs=kwargs,
             ),
             sep="\n",
             ping_message_factory=_ping_message_factory,
         )
-    else:
-        await exit_stack.aclose()
-        return iterator_or_completion
+
+    # handle regular request
+    async with contextlib.asynccontextmanager(get_llama_proxy)() as llama_proxy:
+        llama = prepare_request_resources(body, llama_proxy, body_model, kwargs)
+
+        if await request.is_disconnected():
+            print(
+                f"Disconnected from client (via refresh/close) before llm invoked {request.client}"
+            )
+            raise HTTPException(
+                status_code=status.HTTP_400_BAD_REQUEST,
+                detail="Client closed request",
+            )
+
+        return await run_in_threadpool(llama, **kwargs)
 
 
 @router.post(
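One detail in the streaming branch above: the endpoint hands over the unbound method (llama_cpp.Llama.__call__ here, llama_cpp.Llama.create_chat_completion in the chat endpoint), and get_event_publisher binds the instance later via run_in_threadpool(llama_call, llama, **kwargs). A tiny illustration of that unbound-method call pattern, with an invented Greeter class standing in for Llama:

class Greeter:
    def __call__(self, name: str) -> str:
        return f"hello {name}"


call = Greeter.__call__  # unbound method: the instance becomes the first argument
greeter = Greeter()

# calling through the unbound method is equivalent to calling the instance directly
assert call(greeter, name="world") == greeter(name="world") == "hello world"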
@@ -474,74 +482,48 @@ async def create_chat_completion(
     # where the dependency is cleaned up before a StreamingResponse
     # is complete.
     # https://github.com/tiangolo/fastapi/issues/11143
-    exit_stack = contextlib.AsyncExitStack()
-    llama_proxy = await exit_stack.enter_async_context(contextlib.asynccontextmanager(get_llama_proxy)())
-    if llama_proxy is None:
-        raise HTTPException(
-            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
-            detail="Service is not available",
-        )
+
+    body_model = body.model
     exclude = {
         "n",
         "logit_bias_type",
         "user",
         "min_tokens",
     }
     kwargs = body.model_dump(exclude=exclude)
-    llama = llama_proxy(body.model)
-    if body.logit_bias is not None:
-        kwargs["logit_bias"] = (
-            _logit_bias_tokens_to_input_ids(llama, body.logit_bias)
-            if body.logit_bias_type == "tokens"
-            else body.logit_bias
-        )
-
-    if body.grammar is not None:
-        kwargs["grammar"] = llama_cpp.LlamaGrammar.from_string(body.grammar)
-
-    if body.min_tokens > 0:
-        _min_tokens_logits_processor = llama_cpp.LogitsProcessorList(
-            [llama_cpp.MinTokensLogitsProcessor(body.min_tokens, llama.token_eos())]
-        )
-        if "logits_processor" not in kwargs:
-            kwargs["logits_processor"] = _min_tokens_logits_processor
-        else:
-            kwargs["logits_processor"].extend(_min_tokens_logits_processor)
-
-    try:
-        iterator_or_completion: Union[
-            llama_cpp.ChatCompletion, Iterator[llama_cpp.ChatCompletionChunk]
-        ] = await run_in_threadpool(llama.create_chat_completion, **kwargs)
-    except Exception as err:
-        await exit_stack.aclose()
-        raise err
-
-    if isinstance(iterator_or_completion, Iterator):
-        # EAFP: It's easier to ask for forgiveness than permission
-        first_response = await run_in_threadpool(next, iterator_or_completion)
-
-        # If no exception was raised from first_response, we can assume that
-        # the iterator is valid and we can use it to stream the response.
-        def iterator() -> Iterator[llama_cpp.ChatCompletionChunk]:
-            yield first_response
-            yield from iterator_or_completion
 
+    # handle streaming request
+    if kwargs.get("stream", False):
         send_chan, recv_chan = anyio.create_memory_object_stream(10)
         return EventSourceResponse(
             recv_chan,
             data_sender_callable=partial(  # type: ignore
                 get_event_publisher,
                 request=request,
                 inner_send_chan=send_chan,
-                iterator=iterator(),
-                on_complete=exit_stack.aclose,
+                body=body,
+                body_model=body_model,
+                llama_call=llama_cpp.Llama.create_chat_completion,
+                kwargs=kwargs,
             ),
             sep="\n",
             ping_message_factory=_ping_message_factory,
         )
-    else:
-        await exit_stack.aclose()
-        return iterator_or_completion
+
+    # handle regular request
+    async with contextlib.asynccontextmanager(get_llama_proxy)() as llama_proxy:
+        llama = prepare_request_resources(body, llama_proxy, body_model, kwargs)
+
+        if await request.is_disconnected():
+            print(
+                f"Disconnected from client (via refresh/close) before llm invoked {request.client}"
+            )
+            raise HTTPException(
+                status_code=status.HTTP_400_BAD_REQUEST,
+                detail="Client closed request",
+            )
+
+        return await run_in_threadpool(llama.create_chat_completion, **kwargs)
 
 
 @router.get(

0 commit comments
