Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit e3698f3

Browse filesBrowse files
authored
Enable non-strict output types (#539)
See #528, some folks are having issues because their output types are not strict-compatible. My approach was: 1. Create `AgentOutputSchemaBase`, which represents the base methods for an output type - the json schema + validation 2. Make the existing `AgentOutputSchema` subclass `AgentOutputSchemaBase` 3. Allow users to pass a `AgentOutputSchemaBase` to `Agent(output_type=...)`
1 parent 4b8472d commit e3698f3
Copy full SHA for e3698f3

18 files changed

+256
-61
lines changed
+81Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
import asyncio
2+
import json
3+
from dataclasses import dataclass
4+
from typing import Any
5+
6+
from agents import Agent, AgentOutputSchema, AgentOutputSchemaBase, Runner
7+
8+
"""This example demonstrates how to use an output type that is not in strict mode. Strict mode
9+
allows us to guarantee valid JSON output, but some schemas are not strict-compatible.
10+
11+
In this example, we define an output type that is not strict-compatible, and then we run the
12+
agent with strict_json_schema=False.
13+
14+
We also demonstrate a custom output type.
15+
16+
To understand which schemas are strict-compatible, see:
17+
https://platform.openai.com/docs/guides/structured-outputs?api-mode=responses#supported-schemas
18+
"""
19+
20+
21+
@dataclass
22+
class OutputType:
23+
jokes: dict[int, str]
24+
"""A list of jokes, indexed by joke number."""
25+
26+
27+
class CustomOutputSchema(AgentOutputSchemaBase):
28+
"""A demonstration of a custom output schema."""
29+
30+
def is_plain_text(self) -> bool:
31+
return False
32+
33+
def name(self) -> str:
34+
return "CustomOutputSchema"
35+
36+
def json_schema(self) -> dict[str, Any]:
37+
return {
38+
"type": "object",
39+
"properties": {"jokes": {"type": "object", "properties": {"joke": {"type": "string"}}}},
40+
}
41+
42+
def is_strict_json_schema(self) -> bool:
43+
return False
44+
45+
def validate_json(self, json_str: str) -> Any:
46+
json_obj = json.loads(json_str)
47+
# Just for demonstration, we'll return a list.
48+
return list(json_obj["jokes"].values())
49+
50+
51+
async def main():
52+
agent = Agent(
53+
name="Assistant",
54+
instructions="You are a helpful assistant.",
55+
output_type=OutputType,
56+
)
57+
58+
input = "Tell me 3 short jokes."
59+
60+
# First, let's try with a strict output type. This should raise an exception.
61+
try:
62+
result = await Runner.run(agent, input)
63+
raise AssertionError("Should have raised an exception")
64+
except Exception as e:
65+
print(f"Error (expected): {e}")
66+
67+
# Now let's try again with a non-strict output type. This should work.
68+
# In some cases, it will raise an error - the schema isn't strict, so the model may
69+
# produce an invalid JSON object.
70+
agent.output_type = AgentOutputSchema(OutputType, strict_json_schema=False)
71+
result = await Runner.run(agent, input)
72+
print(result.final_output)
73+
74+
# Finally, let's try a custom output type.
75+
agent.output_type = CustomOutputSchema()
76+
result = await Runner.run(agent, input)
77+
print(result.final_output)
78+
79+
80+
if __name__ == "__main__":
81+
asyncio.run(main())

‎src/agents/__init__.py

Copy file name to clipboardExpand all lines: src/agents/__init__.py
+2-1Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
from . import _config
88
from .agent import Agent, ToolsToFinalOutputFunction, ToolsToFinalOutputResult
9-
from .agent_output import AgentOutputSchema
9+
from .agent_output import AgentOutputSchema, AgentOutputSchemaBase
1010
from .computer import AsyncComputer, Button, Computer, Environment
1111
from .exceptions import (
1212
AgentsException,
@@ -158,6 +158,7 @@ def enable_verbose_stdout_logging():
158158
"OpenAIProvider",
159159
"OpenAIResponsesModel",
160160
"AgentOutputSchema",
161+
"AgentOutputSchemaBase",
161162
"Computer",
162163
"AsyncComputer",
163164
"Environment",

‎src/agents/_run_impl.py

Copy file name to clipboardExpand all lines: src/agents/_run_impl.py
+3-3Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
from openai.types.responses.response_reasoning_item import ResponseReasoningItem
3030

3131
from .agent import Agent, ToolsToFinalOutputResult
32-
from .agent_output import AgentOutputSchema
32+
from .agent_output import AgentOutputSchemaBase
3333
from .computer import AsyncComputer, Computer
3434
from .exceptions import AgentsException, ModelBehaviorError, UserError
3535
from .guardrail import InputGuardrail, InputGuardrailResult, OutputGuardrail, OutputGuardrailResult
@@ -195,7 +195,7 @@ async def execute_tools_and_side_effects(
195195
pre_step_items: list[RunItem],
196196
new_response: ModelResponse,
197197
processed_response: ProcessedResponse,
198-
output_schema: AgentOutputSchema | None,
198+
output_schema: AgentOutputSchemaBase | None,
199199
hooks: RunHooks[TContext],
200200
context_wrapper: RunContextWrapper[TContext],
201201
run_config: RunConfig,
@@ -335,7 +335,7 @@ def process_model_response(
335335
agent: Agent[Any],
336336
all_tools: list[Tool],
337337
response: ModelResponse,
338-
output_schema: AgentOutputSchema | None,
338+
output_schema: AgentOutputSchemaBase | None,
339339
handoffs: list[Handoff],
340340
) -> ProcessedResponse:
341341
items: list[RunItem] = []

‎src/agents/agent.py

Copy file name to clipboardExpand all lines: src/agents/agent.py
+9-2Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
from typing_extensions import NotRequired, TypeAlias, TypedDict
1010

11+
from .agent_output import AgentOutputSchemaBase
1112
from .guardrail import InputGuardrail, OutputGuardrail
1213
from .handoffs import Handoff
1314
from .items import ItemHelpers
@@ -141,8 +142,14 @@ class Agent(Generic[TContext]):
141142
Runs only if the agent produces a final output.
142143
"""
143144

144-
output_type: type[Any] | None = None
145-
"""The type of the output object. If not provided, the output will be `str`."""
145+
output_type: type[Any] | AgentOutputSchemaBase | None = None
146+
"""The type of the output object. If not provided, the output will be `str`. In most cases,
147+
you should pass a regular Python type (e.g. a dataclass, Pydantic model, TypedDict, etc).
148+
You can customize this in two ways:
149+
1. If you want non-strict schemas, pass `AgentOutputSchema(MyClass, strict_json_schema=False)`.
150+
2. If you want to use a custom JSON schema (i.e. without using the SDK's automatic schema)
151+
creation, subclass and pass an `AgentOutputSchemaBase` subclass.
152+
"""
146153

147154
hooks: AgentHooks[TContext] | None = None
148155
"""A class that receives callbacks on various lifecycle events for this agent.

‎src/agents/agent_output.py

Copy file name to clipboardExpand all lines: src/agents/agent_output.py
+58-8Lines changed: 58 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import abc
12
from dataclasses import dataclass
23
from typing import Any
34

@@ -12,8 +13,46 @@
1213
_WRAPPER_DICT_KEY = "response"
1314

1415

16+
class AgentOutputSchemaBase(abc.ABC):
17+
"""An object that captures the JSON schema of the output, as well as validating/parsing JSON
18+
produced by the LLM into the output type.
19+
"""
20+
21+
@abc.abstractmethod
22+
def is_plain_text(self) -> bool:
23+
"""Whether the output type is plain text (versus a JSON object)."""
24+
pass
25+
26+
@abc.abstractmethod
27+
def name(self) -> str:
28+
"""The name of the output type."""
29+
pass
30+
31+
@abc.abstractmethod
32+
def json_schema(self) -> dict[str, Any]:
33+
"""Returns the JSON schema of the output. Will only be called if the output type is not
34+
plain text.
35+
"""
36+
pass
37+
38+
@abc.abstractmethod
39+
def is_strict_json_schema(self) -> bool:
40+
"""Whether the JSON schema is in strict mode. Strict mode constrains the JSON schema
41+
features, but guarantees valis JSON. See here for details:
42+
https://platform.openai.com/docs/guides/structured-outputs#supported-schemas
43+
"""
44+
pass
45+
46+
@abc.abstractmethod
47+
def validate_json(self, json_str: str) -> Any:
48+
"""Validate a JSON string against the output type. You must return the validated object,
49+
or raise a `ModelBehaviorError` if the JSON is invalid.
50+
"""
51+
pass
52+
53+
1554
@dataclass(init=False)
16-
class AgentOutputSchema:
55+
class AgentOutputSchema(AgentOutputSchemaBase):
1756
"""An object that captures the JSON schema of the output, as well as validating/parsing JSON
1857
produced by the LLM into the output type.
1958
"""
@@ -32,7 +71,7 @@ class AgentOutputSchema:
3271
_output_schema: dict[str, Any]
3372
"""The JSON schema of the output."""
3473

35-
strict_json_schema: bool
74+
_strict_json_schema: bool
3675
"""Whether the JSON schema is in strict mode. We **strongly** recommend setting this to True,
3776
as it increases the likelihood of correct JSON input.
3877
"""
@@ -45,7 +84,7 @@ def __init__(self, output_type: type[Any], strict_json_schema: bool = True):
4584
setting this to True, as it increases the likelihood of correct JSON input.
4685
"""
4786
self.output_type = output_type
48-
self.strict_json_schema = strict_json_schema
87+
self._strict_json_schema = strict_json_schema
4988

5089
if output_type is None or output_type is str:
5190
self._is_wrapped = False
@@ -70,24 +109,35 @@ def __init__(self, output_type: type[Any], strict_json_schema: bool = True):
70109
self._type_adapter = TypeAdapter(output_type)
71110
self._output_schema = self._type_adapter.json_schema()
72111

73-
if self.strict_json_schema:
74-
self._output_schema = ensure_strict_json_schema(self._output_schema)
112+
if self._strict_json_schema:
113+
try:
114+
self._output_schema = ensure_strict_json_schema(self._output_schema)
115+
except UserError as e:
116+
raise UserError(
117+
"Strict JSON schema is enabled, but the output type is not valid. "
118+
"Either make the output type strict, or pass output_schema_strict=False to "
119+
"your Agent()"
120+
) from e
75121

76122
def is_plain_text(self) -> bool:
77123
"""Whether the output type is plain text (versus a JSON object)."""
78124
return self.output_type is None or self.output_type is str
79125

126+
def is_strict_json_schema(self) -> bool:
127+
"""Whether the JSON schema is in strict mode."""
128+
return self._strict_json_schema
129+
80130
def json_schema(self) -> dict[str, Any]:
81131
"""The JSON schema of the output type."""
82132
if self.is_plain_text():
83133
raise UserError("Output type is plain text, so no JSON schema is available")
84134
return self._output_schema
85135

86-
def validate_json(self, json_str: str, partial: bool = False) -> Any:
136+
def validate_json(self, json_str: str) -> Any:
87137
"""Validate a JSON string against the output type. Returns the validated object, or raises
88138
a `ModelBehaviorError` if the JSON is invalid.
89139
"""
90-
validated = _json.validate_json(json_str, self._type_adapter, partial)
140+
validated = _json.validate_json(json_str, self._type_adapter, partial=False)
91141
if self._is_wrapped:
92142
if not isinstance(validated, dict):
93143
_error_tracing.attach_error_to_current_span(
@@ -113,7 +163,7 @@ def validate_json(self, json_str: str, partial: bool = False) -> Any:
113163
return validated[_WRAPPER_DICT_KEY]
114164
return validated
115165

116-
def output_type_name(self) -> str:
166+
def name(self) -> str:
117167
"""The name of the output type."""
118168
return _type_to_str(self.output_type)
119169

‎src/agents/extensions/models/litellm_model.py

Copy file name to clipboardExpand all lines: src/agents/extensions/models/litellm_model.py
+6-6Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
from openai.types.responses import Response
3030

3131
from ... import _debug
32-
from ...agent_output import AgentOutputSchema
32+
from ...agent_output import AgentOutputSchemaBase
3333
from ...handoffs import Handoff
3434
from ...items import ModelResponse, TResponseInputItem, TResponseStreamEvent
3535
from ...logger import logger
@@ -68,7 +68,7 @@ async def get_response(
6868
input: str | list[TResponseInputItem],
6969
model_settings: ModelSettings,
7070
tools: list[Tool],
71-
output_schema: AgentOutputSchema | None,
71+
output_schema: AgentOutputSchemaBase | None,
7272
handoffs: list[Handoff],
7373
tracing: ModelTracing,
7474
previous_response_id: str | None,
@@ -139,7 +139,7 @@ async def stream_response(
139139
input: str | list[TResponseInputItem],
140140
model_settings: ModelSettings,
141141
tools: list[Tool],
142-
output_schema: AgentOutputSchema | None,
142+
output_schema: AgentOutputSchemaBase | None,
143143
handoffs: list[Handoff],
144144
tracing: ModelTracing,
145145
*,
@@ -186,7 +186,7 @@ async def _fetch_response(
186186
input: str | list[TResponseInputItem],
187187
model_settings: ModelSettings,
188188
tools: list[Tool],
189-
output_schema: AgentOutputSchema | None,
189+
output_schema: AgentOutputSchemaBase | None,
190190
handoffs: list[Handoff],
191191
span: Span[GenerationSpanData],
192192
tracing: ModelTracing,
@@ -200,7 +200,7 @@ async def _fetch_response(
200200
input: str | list[TResponseInputItem],
201201
model_settings: ModelSettings,
202202
tools: list[Tool],
203-
output_schema: AgentOutputSchema | None,
203+
output_schema: AgentOutputSchemaBase | None,
204204
handoffs: list[Handoff],
205205
span: Span[GenerationSpanData],
206206
tracing: ModelTracing,
@@ -213,7 +213,7 @@ async def _fetch_response(
213213
input: str | list[TResponseInputItem],
214214
model_settings: ModelSettings,
215215
tools: list[Tool],
216-
output_schema: AgentOutputSchema | None,
216+
output_schema: AgentOutputSchemaBase | None,
217217
handoffs: list[Handoff],
218218
span: Span[GenerationSpanData],
219219
tracing: ModelTracing,

‎src/agents/models/chatcmpl_converter.py

Copy file name to clipboardExpand all lines: src/agents/models/chatcmpl_converter.py
+3-3Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@
3636
)
3737
from openai.types.responses.response_input_param import FunctionCallOutput, ItemReference, Message
3838

39-
from ..agent_output import AgentOutputSchema
39+
from ..agent_output import AgentOutputSchemaBase
4040
from ..exceptions import AgentsException, UserError
4141
from ..handoffs import Handoff
4242
from ..items import TResponseInputItem, TResponseOutputItem
@@ -67,7 +67,7 @@ def convert_tool_choice(
6767

6868
@classmethod
6969
def convert_response_format(
70-
cls, final_output_schema: AgentOutputSchema | None
70+
cls, final_output_schema: AgentOutputSchemaBase | None
7171
) -> ResponseFormat | NotGiven:
7272
if not final_output_schema or final_output_schema.is_plain_text():
7373
return NOT_GIVEN
@@ -76,7 +76,7 @@ def convert_response_format(
7676
"type": "json_schema",
7777
"json_schema": {
7878
"name": "final_output",
79-
"strict": final_output_schema.strict_json_schema,
79+
"strict": final_output_schema.is_strict_json_schema(),
8080
"schema": final_output_schema.json_schema(),
8181
},
8282
}

‎src/agents/models/interface.py

Copy file name to clipboardExpand all lines: src/agents/models/interface.py
+3-3Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from collections.abc import AsyncIterator
66
from typing import TYPE_CHECKING
77

8-
from ..agent_output import AgentOutputSchema
8+
from ..agent_output import AgentOutputSchemaBase
99
from ..handoffs import Handoff
1010
from ..items import ModelResponse, TResponseInputItem, TResponseStreamEvent
1111
from ..tool import Tool
@@ -41,7 +41,7 @@ async def get_response(
4141
input: str | list[TResponseInputItem],
4242
model_settings: ModelSettings,
4343
tools: list[Tool],
44-
output_schema: AgentOutputSchema | None,
44+
output_schema: AgentOutputSchemaBase | None,
4545
handoffs: list[Handoff],
4646
tracing: ModelTracing,
4747
*,
@@ -72,7 +72,7 @@ def stream_response(
7272
input: str | list[TResponseInputItem],
7373
model_settings: ModelSettings,
7474
tools: list[Tool],
75-
output_schema: AgentOutputSchema | None,
75+
output_schema: AgentOutputSchemaBase | None,
7676
handoffs: list[Handoff],
7777
tracing: ModelTracing,
7878
*,

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.