From 2560ad0664cbb805bc3c607d69cff327f0c07425 Mon Sep 17 00:00:00 2001
From: Andrew Han
Date: Wed, 23 Apr 2025 10:28:40 -0700
Subject: [PATCH 1/5] allow cancel out of the streaming result

---
 src/agents/result.py | 12 ++++++++++++
 1 file changed, 12 insertions(+)

diff --git a/src/agents/result.py b/src/agents/result.py
index 0d8372c8..59341f7c 100644
--- a/src/agents/result.py
+++ b/src/agents/result.py
@@ -152,6 +152,17 @@ def last_agent(self) -> Agent[Any]:
         """
         return self.current_agent
 
+    def cancel(self) -> None:
+        """Cancels the streaming run, stopping all background tasks and marking the run as complete."""
+        self._cleanup_tasks()  # Cancel all running tasks
+        self.is_complete = True  # Mark the run as complete to stop event streaming
+
+        # Optionally, clear the event queue to prevent processing stale events
+        while not self._event_queue.empty():
+            self._event_queue.get_nowait()
+        while not self._input_guardrail_queue.empty():
+            self._input_guardrail_queue.get_nowait()
+
     async def stream_events(self) -> AsyncIterator[StreamEvent]:
         """Stream deltas for new items as they are generated. We're using the types from the
         OpenAI Responses API, so these are semantic events: each event has a `type` field that
@@ -174,6 +185,7 @@ async def stream_events(self) -> AsyncIterator[StreamEvent]:
             try:
                 item = await self._event_queue.get()
             except asyncio.CancelledError:
+                self.cancel()  # Ensure tasks are cleaned up if the coroutine is cancelled
                 break
 
             if isinstance(item, QueueCompleteSentinel):

From 58aaa6a7c38bde503a5bbc2f35459479e2b1951b Mon Sep 17 00:00:00 2001
From: Andrew Han
Date: Wed, 23 Apr 2025 10:44:04 -0700
Subject: [PATCH 2/5] add test

---
 tests/test_cancel_streaming.py | 22 ++++++++++++++++++++++
 1 file changed, 22 insertions(+)
 create mode 100644 tests/test_cancel_streaming.py

diff --git a/tests/test_cancel_streaming.py b/tests/test_cancel_streaming.py
new file mode 100644
index 00000000..a805b08b
--- /dev/null
+++ b/tests/test_cancel_streaming.py
@@ -0,0 +1,22 @@
+import pytest
+from openai.types.responses import ResponseTextDeltaEvent
+from agents import Agent, Runner
+
+
+@pytest.mark.asyncio
+async def test_joker_streamed_jokes_with_cancel():
+    agent = Agent(
+        name="Joker",
+        instructions="You are a helpful assistant.",
+    )
+
+    result = Runner.run_streamed(agent, input="Please tell me 5 jokes.")
+    num_visible_event = 0
+
+    async for event in result.stream_events():
+        if event.type == "raw_response_event" and isinstance(event.data, ResponseTextDeltaEvent):
+            num_visible_event += 1
+            if num_visible_event == 3:
+                result.cancel()
+
+    assert num_visible_event == 3, f"Expected 3 visible events, but got {num_visible_event}"

From 0cd5d7b5aefee1b72b28213c95065483a2a33e00 Mon Sep 17 00:00:00 2001
From: Andrew Han
Date: Wed, 23 Apr 2025 16:35:29 -0700
Subject: [PATCH 3/5] linted & fixed test

---
 src/agents/result.py           | 15 +++++++++++----
 tests/test_cancel_streaming.py | 24 ++++++++++++------------
 2 files changed, 23 insertions(+), 16 deletions(-)

diff --git a/src/agents/result.py b/src/agents/result.py
index 59341f7c..5c7f4d24 100644
--- a/src/agents/result.py
+++ b/src/agents/result.py
@@ -75,7 +75,9 @@ def final_output_as(self, cls: type[T], raise_if_incorrect_type: bool = False) -
 
     def to_input_list(self) -> list[TResponseInputItem]:
         """Creates a new input list, merging the original input with all the new items generated."""
-        original_items: list[TResponseInputItem] = ItemHelpers.input_to_new_input_list(self.input)
+        original_items: list[TResponseInputItem] = ItemHelpers.input_to_new_input_list(
+            self.input
+        )
         new_items = [item.to_input_item() for item in self.new_items]
 
         return original_items + new_items
@@ -153,7 +155,8 @@ def last_agent(self) -> Agent[Any]:
         return self.current_agent
 
     def cancel(self) -> None:
-        """Cancels the streaming run, stopping all background tasks and marking the run as complete."""
+        """Cancels the streaming run, stopping all background tasks and marking the run as
+        complete."""
         self._cleanup_tasks()  # Cancel all running tasks
         self.is_complete = True  # Mark the run as complete to stop event streaming
 
@@ -204,13 +207,17 @@ async def stream_events(self) -> AsyncIterator[StreamEvent]:
 
     def _check_errors(self):
         if self.current_turn > self.max_turns:
-            self._stored_exception = MaxTurnsExceeded(f"Max turns ({self.max_turns}) exceeded")
+            self._stored_exception = MaxTurnsExceeded(
+                f"Max turns ({self.max_turns}) exceeded"
+            )
 
         # Fetch all the completed guardrail results from the queue and raise if needed
         while not self._input_guardrail_queue.empty():
             guardrail_result = self._input_guardrail_queue.get_nowait()
             if guardrail_result.output.tripwire_triggered:
-                self._stored_exception = InputGuardrailTripwireTriggered(guardrail_result)
+                self._stored_exception = InputGuardrailTripwireTriggered(
+                    guardrail_result
+                )
 
         # Check the tasks for any exceptions
         if self._run_impl_task and self._run_impl_task.done():
diff --git a/tests/test_cancel_streaming.py b/tests/test_cancel_streaming.py
index a805b08b..c52d145c 100644
--- a/tests/test_cancel_streaming.py
+++ b/tests/test_cancel_streaming.py
@@ -1,22 +1,22 @@
 import pytest
 from openai.types.responses import ResponseTextDeltaEvent
+from .fake_model import FakeModel
+
 from agents import Agent, Runner
 
 
 @pytest.mark.asyncio
 async def test_joker_streamed_jokes_with_cancel():
-    agent = Agent(
-        name="Joker",
-        instructions="You are a helpful assistant.",
-    )
+    model = FakeModel()
+    agent = Agent(name="Joker", model=model)
 
     result = Runner.run_streamed(agent, input="Please tell me 5 jokes.")
-    num_visible_event = 0
-
+    num_events = 0
+    stop_after = 1  # There are two that the model gives back.
+
     async for event in result.stream_events():
-        if event.type == "raw_response_event" and isinstance(event.data, ResponseTextDeltaEvent):
-            num_visible_event += 1
-            if num_visible_event == 3:
-                result.cancel()
-
-    assert num_visible_event == 3, f"Expected 3 visible events, but got {num_visible_event}"
+        num_events += 1
+        if num_events == 1:
+            result.cancel()
+
+    assert num_events == 1, f"Expected {stop_after} visible events, but got {num_events}"

From a7567205897b920e145e44508b2813d2cef1a740 Mon Sep 17 00:00:00 2001
From: Andrew Han
Date: Wed, 23 Apr 2025 16:36:48 -0700
Subject: [PATCH 4/5] re-lint the test

---
 tests/test_cancel_streaming.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/tests/test_cancel_streaming.py b/tests/test_cancel_streaming.py
index c52d145c..6d1807d7 100644
--- a/tests/test_cancel_streaming.py
+++ b/tests/test_cancel_streaming.py
@@ -1,9 +1,9 @@
 import pytest
-from openai.types.responses import ResponseTextDeltaEvent
-from .fake_model import FakeModel
 
 from agents import Agent, Runner
 
+from .fake_model import FakeModel
+
 
 @pytest.mark.asyncio
 async def test_joker_streamed_jokes_with_cancel():
@@ -14,7 +14,7 @@ async def test_joker_streamed_jokes_with_cancel():
     result = Runner.run_streamed(agent, input="Please tell me 5 jokes.")
     num_events = 0
     stop_after = 1  # There are two that the model gives back.
-    async for event in result.stream_events():
+    async for _event in result.stream_events():
         num_events += 1
         if num_events == 1:
             result.cancel()

From 1c8cf972c6195d269fa44835610a36b5f32e14e7 Mon Sep 17 00:00:00 2001
From: Andrew Han
Date: Wed, 23 Apr 2025 16:42:00 -0700
Subject: [PATCH 5/5] unnecessary cancel before the break, since it will be
 handled outside of the loop anyway

---
 src/agents/result.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/src/agents/result.py b/src/agents/result.py
index 5c7f4d24..1f1c7832 100644
--- a/src/agents/result.py
+++ b/src/agents/result.py
@@ -188,7 +188,6 @@ async def stream_events(self) -> AsyncIterator[StreamEvent]:
             try:
                 item = await self._event_queue.get()
             except asyncio.CancelledError:
-                self.cancel()  # Ensure tasks are cleaned up if the coroutine is cancelled
                 break
 
             if isinstance(item, QueueCompleteSentinel):
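
Usage sketch (illustrative, not part of the patch series): with these patches applied, a caller can stop a streamed run early by calling cancel() on the RunResultStreaming returned by Runner.run_streamed(). The snippet below assumes the agents package from this repo plus a configured default model and API key; the agent name, instructions, and the cut-off of three events are arbitrary choices for the example.

import asyncio

from agents import Agent, Runner


async def main() -> None:
    agent = Agent(name="Joker", instructions="You are a helpful assistant.")

    # run_streamed returns a RunResultStreaming immediately; the run proceeds in the background.
    result = Runner.run_streamed(agent, input="Please tell me 5 jokes.")

    num_events = 0
    async for _event in result.stream_events():
        num_events += 1
        if num_events == 3:
            # Stop the background run; the stream winds down shortly after cancel().
            result.cancel()

    print(f"Stopped after {num_events} events; is_complete={result.is_complete}")


if __name__ == "__main__":
    asyncio.run(main())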