diff --git a/src/agents/result.py b/src/agents/result.py index 0d8372c8..1f1c7832 100644 --- a/src/agents/result.py +++ b/src/agents/result.py @@ -75,7 +75,9 @@ def final_output_as(self, cls: type[T], raise_if_incorrect_type: bool = False) - def to_input_list(self) -> list[TResponseInputItem]: """Creates a new input list, merging the original input with all the new items generated.""" - original_items: list[TResponseInputItem] = ItemHelpers.input_to_new_input_list(self.input) + original_items: list[TResponseInputItem] = ItemHelpers.input_to_new_input_list( + self.input + ) new_items = [item.to_input_item() for item in self.new_items] return original_items + new_items @@ -152,6 +154,18 @@ def last_agent(self) -> Agent[Any]: """ return self.current_agent + def cancel(self) -> None: + """Cancels the streaming run, stopping all background tasks and marking the run as + complete.""" + self._cleanup_tasks() # Cancel all running tasks + self.is_complete = True # Mark the run as complete to stop event streaming + + # Optionally, clear the event queue to prevent processing stale events + while not self._event_queue.empty(): + self._event_queue.get_nowait() + while not self._input_guardrail_queue.empty(): + self._input_guardrail_queue.get_nowait() + async def stream_events(self) -> AsyncIterator[StreamEvent]: """Stream deltas for new items as they are generated. We're using the types from the OpenAI Responses API, so these are semantic events: each event has a `type` field that @@ -192,13 +206,17 @@ async def stream_events(self) -> AsyncIterator[StreamEvent]: def _check_errors(self): if self.current_turn > self.max_turns: - self._stored_exception = MaxTurnsExceeded(f"Max turns ({self.max_turns}) exceeded") + self._stored_exception = MaxTurnsExceeded( + f"Max turns ({self.max_turns}) exceeded" + ) # Fetch all the completed guardrail results from the queue and raise if needed while not self._input_guardrail_queue.empty(): guardrail_result = self._input_guardrail_queue.get_nowait() if guardrail_result.output.tripwire_triggered: - self._stored_exception = InputGuardrailTripwireTriggered(guardrail_result) + self._stored_exception = InputGuardrailTripwireTriggered( + guardrail_result + ) # Check the tasks for any exceptions if self._run_impl_task and self._run_impl_task.done(): diff --git a/tests/test_cancel_streaming.py b/tests/test_cancel_streaming.py new file mode 100644 index 00000000..6d1807d7 --- /dev/null +++ b/tests/test_cancel_streaming.py @@ -0,0 +1,22 @@ +import pytest + +from agents import Agent, Runner + +from .fake_model import FakeModel + + +@pytest.mark.asyncio +async def test_joker_streamed_jokes_with_cancel(): + model = FakeModel() + agent = Agent(name="Joker", model=model) + + result = Runner.run_streamed(agent, input="Please tell me 5 jokes.") + num_events = 0 + stop_after = 1 # There are two that the model gives back. + + async for _event in result.stream_events(): + num_events += 1 + if num_events == 1: + result.cancel() + + assert num_events == 1, f"Expected {stop_after} visible events, but got {num_events}"