Commit a113fea

Allow cancel out of the streaming result (openai#579)
Fix for openai#574. @rm-openai I'm not sure how to add a test within the repo, but I have pasted a test script below that seems to work:

```python
import asyncio

from openai.types.responses import ResponseTextDeltaEvent

from agents import Agent, Runner


async def main():
    agent = Agent(
        name="Joker",
        instructions="You are a helpful assistant.",
    )

    result = Runner.run_streamed(agent, input="Please tell me 5 jokes.")
    num_visible_event = 0
    async for event in result.stream_events():
        if event.type == "raw_response_event" and isinstance(event.data, ResponseTextDeltaEvent):
            print(event.data.delta, end="", flush=True)
            num_visible_event += 1
            print(num_visible_event)
            if num_visible_event == 3:
                result.cancel()


if __name__ == "__main__":
    asyncio.run(main())
```
1 parent 178020e · commit a113fea

2 files changed: +43 −3 lines

src/agents/result.py (+21 −3)

```diff
@@ -75,7 +75,9 @@ def final_output_as(self, cls: type[T], raise_if_incorrect_type: bool = False) -
 
     def to_input_list(self) -> list[TResponseInputItem]:
         """Creates a new input list, merging the original input with all the new items generated."""
-        original_items: list[TResponseInputItem] = ItemHelpers.input_to_new_input_list(self.input)
+        original_items: list[TResponseInputItem] = ItemHelpers.input_to_new_input_list(
+            self.input
+        )
         new_items = [item.to_input_item() for item in self.new_items]
 
         return original_items + new_items
@@ -152,6 +154,18 @@ def last_agent(self) -> Agent[Any]:
         """
         return self.current_agent
 
+    def cancel(self) -> None:
+        """Cancels the streaming run, stopping all background tasks and marking the run as
+        complete."""
+        self._cleanup_tasks()  # Cancel all running tasks
+        self.is_complete = True  # Mark the run as complete to stop event streaming
+
+        # Optionally, clear the event queue to prevent processing stale events
+        while not self._event_queue.empty():
+            self._event_queue.get_nowait()
+        while not self._input_guardrail_queue.empty():
+            self._input_guardrail_queue.get_nowait()
+
     async def stream_events(self) -> AsyncIterator[StreamEvent]:
         """Stream deltas for new items as they are generated. We're using the types from the
         OpenAI Responses API, so these are semantic events: each event has a `type` field that
@@ -192,13 +206,17 @@ async def stream_events(self) -> AsyncIterator[StreamEvent]:
 
     def _check_errors(self):
         if self.current_turn > self.max_turns:
-            self._stored_exception = MaxTurnsExceeded(f"Max turns ({self.max_turns}) exceeded")
+            self._stored_exception = MaxTurnsExceeded(
+                f"Max turns ({self.max_turns}) exceeded"
+            )
 
         # Fetch all the completed guardrail results from the queue and raise if needed
         while not self._input_guardrail_queue.empty():
             guardrail_result = self._input_guardrail_queue.get_nowait()
             if guardrail_result.output.tripwire_triggered:
-                self._stored_exception = InputGuardrailTripwireTriggered(guardrail_result)
+                self._stored_exception = InputGuardrailTripwireTriggered(
+                    guardrail_result
+                )
 
         # Check the tasks for any exceptions
         if self._run_impl_task and self._run_impl_task.done():
```
tests/test_cancel_streaming.py (+22 −0)

```diff
@@ -0,0 +1,22 @@
+import pytest
+
+from agents import Agent, Runner
+
+from .fake_model import FakeModel
+
+
+@pytest.mark.asyncio
+async def test_joker_streamed_jokes_with_cancel():
+    model = FakeModel()
+    agent = Agent(name="Joker", model=model)
+
+    result = Runner.run_streamed(agent, input="Please tell me 5 jokes.")
+    num_events = 0
+    stop_after = 1  # There are two that the model gives back.
+
+    async for _event in result.stream_events():
+        num_events += 1
+        if num_events == 1:
+            result.cancel()
+
+    assert num_events == 1, f"Expected {stop_after} visible events, but got {num_events}"
```
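To run just this new test locally, pytest can be invoked from the command line (`python -m pytest tests/test_cancel_streaming.py`) or programmatically. A minimal sketch, assuming `pytest` and `pytest-asyncio` are installed (the `@pytest.mark.asyncio` marker implies the latter):

```python
import sys

import pytest

# Run only the new cancellation test; pytest.main returns an exit code
# (0 means the test passed).
sys.exit(pytest.main(["tests/test_cancel_streaming.py", "-q"]))
```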

0 commit comments