Skip to content

Navigation Menu

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit 65cae71

Browse files
authored
Extract chat completions streaming helpers (#523)
Small refactor. --- [//]: # (BEGIN SAPLING FOOTER) * #524 * __->__ #523
1 parent 80de53e commit 65cae71
Copy full SHA for 65cae71

File tree

2 files changed

+301
-275
lines changed
Filter options

2 files changed

+301
-275
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,290 @@
1+
from __future__ import annotations
2+
3+
from collections.abc import AsyncIterator
4+
from dataclasses import dataclass, field
5+
6+
from openai import AsyncStream
7+
from openai.types.chat import ChatCompletionChunk
8+
from openai.types.completion_usage import CompletionUsage
9+
from openai.types.responses import (
10+
Response,
11+
ResponseCompletedEvent,
12+
ResponseContentPartAddedEvent,
13+
ResponseContentPartDoneEvent,
14+
ResponseCreatedEvent,
15+
ResponseFunctionCallArgumentsDeltaEvent,
16+
ResponseFunctionToolCall,
17+
ResponseOutputItem,
18+
ResponseOutputItemAddedEvent,
19+
ResponseOutputItemDoneEvent,
20+
ResponseOutputMessage,
21+
ResponseOutputRefusal,
22+
ResponseOutputText,
23+
ResponseRefusalDeltaEvent,
24+
ResponseTextDeltaEvent,
25+
ResponseUsage,
26+
)
27+
from openai.types.responses.response_usage import InputTokensDetails, OutputTokensDetails
28+
29+
from ..items import TResponseStreamEvent
30+
from .fake_id import FAKE_RESPONSES_ID
31+
32+
33+
@dataclass
class StreamingState:
    """Mutable bookkeeping carried across chunks while a chat completion streams."""

    # Flipped to True once the first chunk has been seen.
    started: bool = False

    # (content index, accumulating text part), or None until text first arrives.
    text_content_index_and_output: tuple[int, ResponseOutputText] | None = None

    # (content index, accumulating refusal part), or None until a refusal first arrives.
    refusal_content_index_and_output: tuple[int, ResponseOutputRefusal] | None = None

    # Partially assembled function calls, keyed by the tool-call delta index.
    function_calls: dict[int, ResponseFunctionToolCall] = field(default_factory=dict)
39+
40+
41+
class ChatCmplStreamHandler:
    """Translate a Chat Completions chunk stream into Responses-style stream events."""

    @staticmethod
    def _new_in_progress_message() -> ResponseOutputMessage:
        """Build the empty, in-progress assistant message announced to consumers."""
        return ResponseOutputMessage(
            id=FAKE_RESPONSES_ID,
            content=[],
            role="assistant",
            type="message",
            status="in_progress",
        )

    @staticmethod
    def _to_response_usage(usage: CompletionUsage | None) -> ResponseUsage | None:
        """Convert chat-completions usage into ``ResponseUsage`` (``None`` passes through)."""
        if usage is None:
            return None

        completion_details = usage.completion_tokens_details
        prompt_details = usage.prompt_tokens_details
        # Missing detail objects or missing counts are reported as zero.
        reasoning_tokens = (
            completion_details.reasoning_tokens if completion_details else None
        ) or 0
        cached_tokens = (prompt_details.cached_tokens if prompt_details else None) or 0

        return ResponseUsage(
            input_tokens=usage.prompt_tokens,
            output_tokens=usage.completion_tokens,
            total_tokens=usage.total_tokens,
            output_tokens_details=OutputTokensDetails(reasoning_tokens=reasoning_tokens),
            input_tokens_details=InputTokensDetails(cached_tokens=cached_tokens),
        )

    @classmethod
    async def handle_stream(
        cls,
        response: Response,
        stream: AsyncStream[ChatCompletionChunk],
    ) -> AsyncIterator[TResponseStreamEvent]:
        """Consume ``stream`` and yield the equivalent Responses stream events.

        Text and refusal deltas are forwarded as they arrive. Tool calls are
        buffered and emitted only after the stream ends. A final
        ``response.completed`` event carries a copy of ``response`` with the
        collected output items and converted usage.

        Args:
            response: Template response; sent as-is in ``response.created`` and
                copied for the final ``response.completed`` event.
            stream: Chat completion chunks to translate.

        Yields:
            Response stream events, in the order described above.
        """
        usage: CompletionUsage | None = None
        state = StreamingState()

        async for chunk in stream:
            if not state.started:
                state.started = True
                yield ResponseCreatedEvent(
                    response=response,
                    type="response.created",
                )

            # Keep the latest usage payload; a chunk without usage must not erase it.
            if chunk.usage is not None:
                usage = chunk.usage

            if not chunk.choices or not chunk.choices[0].delta:
                continue

            delta = chunk.choices[0].delta

            # Handle text
            if delta.content:
                if not state.text_content_index_and_output:
                    # The refusal branch may already have opened the assistant message.
                    message_open = state.refusal_content_index_and_output is not None
                    # Initialize a content tracker for streaming text
                    state.text_content_index_and_output = (
                        1 if message_open else 0,
                        ResponseOutputText(
                            text="",
                            type="output_text",
                            annotations=[],
                        ),
                    )
                    if not message_open:
                        # Announce the assistant message exactly once.
                        yield ResponseOutputItemAddedEvent(
                            item=cls._new_in_progress_message(),
                            output_index=0,
                            type="response.output_item.added",
                        )
                    yield ResponseContentPartAddedEvent(
                        content_index=state.text_content_index_and_output[0],
                        item_id=FAKE_RESPONSES_ID,
                        output_index=0,
                        part=ResponseOutputText(
                            text="",
                            type="output_text",
                            annotations=[],
                        ),
                        type="response.content_part.added",
                    )
                # Emit the delta for this segment of content
                yield ResponseTextDeltaEvent(
                    content_index=state.text_content_index_and_output[0],
                    delta=delta.content,
                    item_id=FAKE_RESPONSES_ID,
                    output_index=0,
                    type="response.output_text.delta",
                )
                # Accumulate the text into the response part
                state.text_content_index_and_output[1].text += delta.content

            # Handle refusals (model declines to answer)
            if delta.refusal:
                if not state.refusal_content_index_and_output:
                    # The text branch may already have opened the assistant message.
                    message_open = state.text_content_index_and_output is not None
                    # Initialize a content tracker for streaming refusal text
                    state.refusal_content_index_and_output = (
                        1 if message_open else 0,
                        ResponseOutputRefusal(refusal="", type="refusal"),
                    )
                    if not message_open:
                        # Announce the assistant message exactly once.
                        yield ResponseOutputItemAddedEvent(
                            item=cls._new_in_progress_message(),
                            output_index=0,
                            type="response.output_item.added",
                        )
                    # The announced part is a refusal part, matching the deltas that follow.
                    yield ResponseContentPartAddedEvent(
                        content_index=state.refusal_content_index_and_output[0],
                        item_id=FAKE_RESPONSES_ID,
                        output_index=0,
                        part=ResponseOutputRefusal(refusal="", type="refusal"),
                        type="response.content_part.added",
                    )
                # Emit the delta for this segment of refusal
                yield ResponseRefusalDeltaEvent(
                    content_index=state.refusal_content_index_and_output[0],
                    delta=delta.refusal,
                    item_id=FAKE_RESPONSES_ID,
                    output_index=0,
                    type="response.refusal.delta",
                )
                # Accumulate the refusal string in the output part
                state.refusal_content_index_and_output[1].refusal += delta.refusal

            # Handle tool calls
            # Because we don't know the name of the function until the end of the stream, we'll
            # save everything and yield events at the end
            if delta.tool_calls:
                for tc_delta in delta.tool_calls:
                    if tc_delta.index not in state.function_calls:
                        state.function_calls[tc_delta.index] = ResponseFunctionToolCall(
                            id=FAKE_RESPONSES_ID,
                            arguments="",
                            name="",
                            type="function_call",
                            call_id="",
                        )
                    pending_call = state.function_calls[tc_delta.index]
                    tc_function = tc_delta.function

                    pending_call.arguments += (
                        tc_function.arguments if tc_function else ""
                    ) or ""
                    pending_call.name += (tc_function.name if tc_function else "") or ""
                    pending_call.call_id += tc_delta.id or ""

        if state.text_content_index_and_output:
            # Send end event for this content part
            yield ResponseContentPartDoneEvent(
                content_index=state.text_content_index_and_output[0],
                item_id=FAKE_RESPONSES_ID,
                output_index=0,
                part=state.text_content_index_and_output[1],
                type="response.content_part.done",
            )

        if state.refusal_content_index_and_output:
            # Send end event for this content part
            yield ResponseContentPartDoneEvent(
                content_index=state.refusal_content_index_and_output[0],
                item_id=FAKE_RESPONSES_ID,
                output_index=0,
                part=state.refusal_content_index_and_output[1],
                type="response.content_part.done",
            )

        # Text and refusal share one assistant message at output index 0, so function
        # calls start at index 1 when that message exists and at 0 otherwise.
        has_message = (
            state.text_content_index_and_output is not None
            or state.refusal_content_index_and_output is not None
        )
        function_call_starting_index = 1 if has_message else 0

        # Actually send events for the function calls
        for offset, function_call in enumerate(state.function_calls.values()):
            # Each call gets its own index, matching its position in the final output list.
            output_index = function_call_starting_index + offset
            # First, a ResponseOutputItemAdded for the function call
            yield ResponseOutputItemAddedEvent(
                item=ResponseFunctionToolCall(
                    id=FAKE_RESPONSES_ID,
                    call_id=function_call.call_id,
                    arguments=function_call.arguments,
                    name=function_call.name,
                    type="function_call",
                ),
                output_index=output_index,
                type="response.output_item.added",
            )
            # Then, yield the args
            yield ResponseFunctionCallArgumentsDeltaEvent(
                delta=function_call.arguments,
                item_id=FAKE_RESPONSES_ID,
                output_index=output_index,
                type="response.function_call_arguments.delta",
            )
            # Finally, the ResponseOutputItemDone
            yield ResponseOutputItemDoneEvent(
                item=ResponseFunctionToolCall(
                    id=FAKE_RESPONSES_ID,
                    call_id=function_call.call_id,
                    arguments=function_call.arguments,
                    name=function_call.name,
                    type="function_call",
                ),
                output_index=output_index,
                type="response.output_item.done",
            )

        # Finally, send the Response completed event
        outputs: list[ResponseOutputItem] = []
        if has_message:
            assistant_msg = ResponseOutputMessage(
                id=FAKE_RESPONSES_ID,
                content=[],
                role="assistant",
                type="message",
                status="completed",
            )
            # Order parts by the content index they were streamed under, so the final
            # message agrees with the content_index values already sent to consumers.
            tracked_parts: list[tuple[int, ResponseOutputText | ResponseOutputRefusal]] = [
                tracked
                for tracked in (
                    state.text_content_index_and_output,
                    state.refusal_content_index_and_output,
                )
                if tracked is not None
            ]
            tracked_parts.sort(key=lambda tracked: tracked[0])
            for _, part in tracked_parts:
                assistant_msg.content.append(part)
            outputs.append(assistant_msg)

            # send a ResponseOutputItemDone for the assistant message
            yield ResponseOutputItemDoneEvent(
                item=assistant_msg,
                output_index=0,
                type="response.output_item.done",
            )

        outputs.extend(state.function_calls.values())

        final_response = response.model_copy()
        final_response.output = outputs
        final_response.usage = cls._to_response_usage(usage)

        yield ResponseCompletedEvent(
            response=final_response,
            type="response.completed",
        )

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.