Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit 3755ea8

Browse filesBrowse files
authored
Create to_json_dict for ModelSettings (openai#582)
Now that `ModelSettings` has `Reasoning`, a non-primitive object, `dataclasses.as_dict()` wont work. It will raise an error when you try to serialize (e.g. for tracing). This ensures the object is actually serializable.
1 parent a113fea commit 3755ea8
Copy full SHA for 3755ea8

File tree

7 files changed

+84
-15
lines changed
Filter options

7 files changed

+84
-15
lines changed

‎pyproject.toml

Copy file name to clipboardExpand all lines: pyproject.toml
+1-1Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ requires-python = ">=3.9"
77
license = "MIT"
88
authors = [{ name = "OpenAI", email = "support@openai.com" }]
99
dependencies = [
10-
"openai>=1.66.5",
10+
"openai>=1.76.0",
1111
"pydantic>=2.10, <3",
1212
"griffe>=1.5.6, <2",
1313
"typing-extensions>=4.12.2, <5",

‎src/agents/extensions/models/litellm_model.py

Copy file name to clipboardExpand all lines: src/agents/extensions/models/litellm_model.py
+2-3Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
from __future__ import annotations
22

3-
import dataclasses
43
import json
54
import time
65
from collections.abc import AsyncIterator
@@ -75,7 +74,7 @@ async def get_response(
7574
) -> ModelResponse:
7675
with generation_span(
7776
model=str(self.model),
78-
model_config=dataclasses.asdict(model_settings)
77+
model_config=model_settings.to_json_dict()
7978
| {"base_url": str(self.base_url or ""), "model_impl": "litellm"},
8079
disabled=tracing.is_disabled(),
8180
) as span_generation:
@@ -147,7 +146,7 @@ async def stream_response(
147146
) -> AsyncIterator[TResponseStreamEvent]:
148147
with generation_span(
149148
model=str(self.model),
150-
model_config=dataclasses.asdict(model_settings)
149+
model_config=model_settings.to_json_dict()
151150
| {"base_url": str(self.base_url or ""), "model_impl": "litellm"},
152151
disabled=tracing.is_disabled(),
153152
) as span_generation:

‎src/agents/model_settings.py

Copy file name to clipboardExpand all lines: src/agents/model_settings.py
+16-1Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
from __future__ import annotations
22

3+
import dataclasses
34
from dataclasses import dataclass, fields, replace
4-
from typing import Literal
5+
from typing import Any, Literal
56

67
from openai._types import Body, Headers, Query
78
from openai.types.shared import Reasoning
9+
from pydantic import BaseModel
810

911

1012
@dataclass
@@ -83,3 +85,16 @@ def resolve(self, override: ModelSettings | None) -> ModelSettings:
8385
if getattr(override, field.name) is not None
8486
}
8587
return replace(self, **changes)
88+
89+
def to_json_dict(self) -> dict[str, Any]:
90+
dataclass_dict = dataclasses.asdict(self)
91+
92+
json_dict: dict[str, Any] = {}
93+
94+
for field_name, value in dataclass_dict.items():
95+
if isinstance(value, BaseModel):
96+
json_dict[field_name] = value.model_dump(mode="json")
97+
else:
98+
json_dict[field_name] = value
99+
100+
return json_dict

‎src/agents/models/openai_chatcompletions.py

Copy file name to clipboardExpand all lines: src/agents/models/openai_chatcompletions.py
+2-5Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
from __future__ import annotations
22

3-
import dataclasses
43
import json
54
import time
65
from collections.abc import AsyncIterator
@@ -56,8 +55,7 @@ async def get_response(
5655
) -> ModelResponse:
5756
with generation_span(
5857
model=str(self.model),
59-
model_config=dataclasses.asdict(model_settings)
60-
| {"base_url": str(self._client.base_url)},
58+
model_config=model_settings.to_json_dict() | {"base_url": str(self._client.base_url)},
6159
disabled=tracing.is_disabled(),
6260
) as span_generation:
6361
response = await self._fetch_response(
@@ -121,8 +119,7 @@ async def stream_response(
121119
"""
122120
with generation_span(
123121
model=str(self.model),
124-
model_config=dataclasses.asdict(model_settings)
125-
| {"base_url": str(self._client.base_url)},
122+
model_config=model_settings.to_json_dict() | {"base_url": str(self._client.base_url)},
126123
disabled=tracing.is_disabled(),
127124
) as span_generation:
128125
response, stream = await self._fetch_response(
+59Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
import json
2+
from dataclasses import fields
3+
4+
from openai.types.shared import Reasoning
5+
6+
from agents.model_settings import ModelSettings
7+
8+
9+
def verify_serialization(model_settings: ModelSettings) -> None:
10+
"""Verify that ModelSettings can be serialized to a JSON string."""
11+
json_dict = model_settings.to_json_dict()
12+
json_string = json.dumps(json_dict)
13+
assert json_string is not None
14+
15+
16+
def test_basic_serialization() -> None:
17+
"""Tests whether ModelSettings can be serialized to a JSON string."""
18+
19+
# First, lets create a ModelSettings instance
20+
model_settings = ModelSettings(
21+
temperature=0.5,
22+
top_p=0.9,
23+
max_tokens=100,
24+
)
25+
26+
# Now, lets serialize the ModelSettings instance to a JSON string
27+
verify_serialization(model_settings)
28+
29+
30+
def test_all_fields_serialization() -> None:
31+
"""Tests whether ModelSettings can be serialized to a JSON string."""
32+
33+
# First, lets create a ModelSettings instance
34+
model_settings = ModelSettings(
35+
temperature=0.5,
36+
top_p=0.9,
37+
frequency_penalty=0.0,
38+
presence_penalty=0.0,
39+
tool_choice="auto",
40+
parallel_tool_calls=True,
41+
truncation="auto",
42+
max_tokens=100,
43+
reasoning=Reasoning(),
44+
metadata={"foo": "bar"},
45+
store=False,
46+
include_usage=False,
47+
extra_query={"foo": "bar"},
48+
extra_body={"foo": "bar"},
49+
extra_headers={"foo": "bar"},
50+
)
51+
52+
# Verify that every single field is set to a non-None value
53+
for field in fields(model_settings):
54+
assert getattr(model_settings, field.name) is not None, (
55+
f"You must set the {field.name} field"
56+
)
57+
58+
# Now, lets serialize the ModelSettings instance to a JSON string
59+
verify_serialization(model_settings)

‎tests/voice/conftest.py

Copy file name to clipboardExpand all lines: tests/voice/conftest.py
-1Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,4 +9,3 @@ def pytest_ignore_collect(collection_path, config):
99

1010
if str(collection_path).startswith(this_dir):
1111
return True
12-

‎uv.lock

Copy file name to clipboardExpand all lines: uv.lock
+4-4Lines changed: 4 additions & 4 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.