Commit a0254b0

RFC: automatically use litellm if possible (#534)

## Summary

This replaces the default model provider with a `MultiProvider`, which has the following logic:

- if the model name starts with `openai/` or doesn't contain "/", use OpenAI
- if the model name starts with `litellm/`, use LiteLLM to route to the appropriate model provider

It's also extensible, so users can create their own mappings. I also imagine that if we natively supported Anthropic/Gemini etc., we could add them to `MultiProvider` to make them work. The goal is that it should be really easy to use any model provider. Today, if you pass `model="gpt-4.1"`, it works great, but `model="claude-sonnet-3.7"` doesn't. If we can make it that easy, it's a win for devx.

I'm not entirely sure whether this is a good idea - is it too magical? Is the API too reliant on LiteLLM? Comments welcome.

## Test plan

For now, the example. Will add unit tests if we agree it's worth merging.

---------

Co-authored-by: Steven Heidel <steven@heidel.ca>
1 parent 0a3dfa0 commit a0254b0
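As a quick sketch of the routing described above (using only names that appear in this commit's diff; constructing a model does not require credentials, but actually running one does):

```python
from agents.models.multi_provider import MultiProvider

provider = MultiProvider()

provider.get_model("gpt-4.1")         # no "/" in the name -> OpenAIProvider
provider.get_model("openai/gpt-4.1")  # "openai/" prefix -> OpenAIProvider
# "litellm/" prefix -> LitellmProvider, which hands "anthropic/claude-3-5-sonnet-20240620"
# on to LiteLLM:
provider.get_model("litellm/anthropic/claude-3-5-sonnet-20240620")
```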

File tree: 4 files changed, +208 −2 lines changed
New file (+41 lines): an example that uses the built-in LiteLLM support.

```python
from __future__ import annotations

import asyncio

from agents import Agent, Runner, function_tool, set_tracing_disabled

"""This example uses the built-in support for LiteLLM. To use this, ensure you have the
ANTHROPIC_API_KEY environment variable set.
"""

set_tracing_disabled(disabled=True)


@function_tool
def get_weather(city: str):
    print(f"[debug] getting weather for {city}")
    return f"The weather in {city} is sunny."


async def main():
    agent = Agent(
        name="Assistant",
        instructions="You only respond in haikus.",
        # We prefix with litellm/ to tell the Runner to use the LitellmModel
        model="litellm/anthropic/claude-3-5-sonnet-20240620",
        tools=[get_weather],
    )

    result = await Runner.run(agent, "What's the weather in Tokyo?")
    print(result.final_output)


if __name__ == "__main__":
    import os

    if os.getenv("ANTHROPIC_API_KEY") is None:
        raise ValueError(
            "ANTHROPIC_API_KEY is not set. Please set the environment variable and try again."
        )

    asyncio.run(main())
```
New file (+21 lines): the `LitellmProvider` extension.

````python
from ...models.interface import Model, ModelProvider
from .litellm_model import LitellmModel

DEFAULT_MODEL: str = "gpt-4.1"


class LitellmProvider(ModelProvider):
    """A ModelProvider that uses LiteLLM to route to any model provider. You can use it via:
    ```python
    Runner.run(agent, input, run_config=RunConfig(model_provider=LitellmProvider()))
    ```
    See supported models here: [litellm models](https://docs.litellm.ai/docs/providers).

    NOTE: API keys must be set via environment variables. If you're using models that require
    additional configuration (e.g. Azure API base or version), those must also be set via the
    environment variables that LiteLLM expects. If you have more advanced needs, we recommend
    copy-pasting this class and making any modifications you need.
    """

    def get_model(self, model_name: str | None) -> Model:
        return LitellmModel(model_name or DEFAULT_MODEL)
````
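A minimal usage sketch for this provider. The import path follows the relative import in `multi_provider.py` below; `RunConfig` lives in `src/agents/run.py` per the diff further down, and `Runner.run_sync` is assumed to be available alongside `Runner.run`. LiteLLM reads credentials such as `ANTHROPIC_API_KEY` from the environment.

```python
from agents import Agent, Runner
from agents.extensions.models.litellm_provider import LitellmProvider
from agents.run import RunConfig

agent = Agent(
    name="Assistant",
    instructions="You only respond in haikus.",
    model="anthropic/claude-3-5-sonnet-20240620",  # a LiteLLM model name; no "litellm/" prefix needed here
)

result = Runner.run_sync(
    agent,
    "What's the weather in Tokyo?",
    run_config=RunConfig(model_provider=LitellmProvider()),
)
print(result.final_output)
```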

src/agents/models/multi_provider.py (new file, +144 lines)

```python
from __future__ import annotations

from openai import AsyncOpenAI

from ..exceptions import UserError
from .interface import Model, ModelProvider
from .openai_provider import OpenAIProvider


class MultiProviderMap:
    """A map of model name prefixes to ModelProviders."""

    def __init__(self):
        self._mapping: dict[str, ModelProvider] = {}

    def has_prefix(self, prefix: str) -> bool:
        """Returns True if the given prefix is in the mapping."""
        return prefix in self._mapping

    def get_mapping(self) -> dict[str, ModelProvider]:
        """Returns a copy of the current prefix -> ModelProvider mapping."""
        return self._mapping.copy()

    def set_mapping(self, mapping: dict[str, ModelProvider]):
        """Overwrites the current mapping with a new one."""
        self._mapping = mapping

    def get_provider(self, prefix: str) -> ModelProvider | None:
        """Returns the ModelProvider for the given prefix.

        Args:
            prefix: The prefix of the model name e.g. "openai" or "my_prefix".
        """
        return self._mapping.get(prefix)

    def add_provider(self, prefix: str, provider: ModelProvider):
        """Adds a new prefix -> ModelProvider mapping.

        Args:
            prefix: The prefix of the model name e.g. "openai" or "my_prefix".
            provider: The ModelProvider to use for the given prefix.
        """
        self._mapping[prefix] = provider

    def remove_provider(self, prefix: str):
        """Removes the mapping for the given prefix.

        Args:
            prefix: The prefix of the model name e.g. "openai" or "my_prefix".
        """
        del self._mapping[prefix]


class MultiProvider(ModelProvider):
    """This ModelProvider maps to a Model based on the prefix of the model name. By default, the
    mapping is:
    - "openai/" prefix or no prefix -> OpenAIProvider. e.g. "openai/gpt-4.1", "gpt-4.1"
    - "litellm/" prefix -> LitellmProvider. e.g. "litellm/openai/gpt-4.1"

    You can override or customize this mapping.
    """

    def __init__(
        self,
        *,
        provider_map: MultiProviderMap | None = None,
        openai_api_key: str | None = None,
        openai_base_url: str | None = None,
        openai_client: AsyncOpenAI | None = None,
        openai_organization: str | None = None,
        openai_project: str | None = None,
        openai_use_responses: bool | None = None,
    ) -> None:
        """Create a new MultiProvider.

        Args:
            provider_map: A MultiProviderMap that maps prefixes to ModelProviders. If not
                provided, we will use a default mapping. See the documentation for this class
                to see the default mapping.
            openai_api_key: The API key to use for the OpenAI provider. If not provided, we
                will use the default API key.
            openai_base_url: The base URL to use for the OpenAI provider. If not provided, we
                will use the default base URL.
            openai_client: An optional OpenAI client to use. If not provided, we will create a
                new OpenAI client using the api_key and base_url.
            openai_organization: The organization to use for the OpenAI provider.
            openai_project: The project to use for the OpenAI provider.
            openai_use_responses: Whether to use the OpenAI responses API.
        """
        self.provider_map = provider_map
        self.openai_provider = OpenAIProvider(
            api_key=openai_api_key,
            base_url=openai_base_url,
            openai_client=openai_client,
            organization=openai_organization,
            project=openai_project,
            use_responses=openai_use_responses,
        )

        self._fallback_providers: dict[str, ModelProvider] = {}

    def _get_prefix_and_model_name(self, model_name: str | None) -> tuple[str | None, str | None]:
        if model_name is None:
            return None, None
        elif "/" in model_name:
            prefix, model_name = model_name.split("/", 1)
            return prefix, model_name
        else:
            return None, model_name

    def _create_fallback_provider(self, prefix: str) -> ModelProvider:
        if prefix == "litellm":
            from ..extensions.models.litellm_provider import LitellmProvider

            return LitellmProvider()
        else:
            raise UserError(f"Unknown prefix: {prefix}")

    def _get_fallback_provider(self, prefix: str | None) -> ModelProvider:
        if prefix is None or prefix == "openai":
            return self.openai_provider
        elif prefix in self._fallback_providers:
            return self._fallback_providers[prefix]
        else:
            self._fallback_providers[prefix] = self._create_fallback_provider(prefix)
            return self._fallback_providers[prefix]

    def get_model(self, model_name: str | None) -> Model:
        """Returns a Model based on the model name. The model name can have a prefix, ending with
        a "/", which will be used to look up the ModelProvider. If there is no prefix, we will use
        the OpenAI provider.

        Args:
            model_name: The name of the model to get.

        Returns:
            A Model.
        """
        prefix, model_name = self._get_prefix_and_model_name(model_name)

        if prefix and self.provider_map and (provider := self.provider_map.get_provider(prefix)):
            return provider.get_model(model_name)
        else:
            return self._get_fallback_provider(prefix).get_model(model_name)
```
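The commit message calls this mapping extensible; here is a sketch of a custom mapping using the `MultiProviderMap` API above. The `my_prefix` name and the base URL are hypothetical, and the custom provider is just a second `OpenAIProvider` pointed at a different OpenAI-compatible endpoint.

```python
from agents.models.multi_provider import MultiProvider, MultiProviderMap
from agents.models.openai_provider import OpenAIProvider

# Route a hypothetical "my_prefix/" prefix to its own OpenAI-compatible endpoint.
provider_map = MultiProviderMap()
provider_map.add_provider("my_prefix", OpenAIProvider(base_url="https://example.com/v1"))

provider = MultiProvider(provider_map=provider_map)
provider.get_model("my-model")            # no prefix -> default OpenAI provider
provider.get_model("my_prefix/my-model")  # resolved via the custom mapping
```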

src/agents/run.py (+2 −2 lines)

```diff
@@ -34,7 +34,7 @@
 from .logger import logger
 from .model_settings import ModelSettings
 from .models.interface import Model, ModelProvider
-from .models.openai_provider import OpenAIProvider
+from .models.multi_provider import MultiProvider
 from .result import RunResult, RunResultStreaming
 from .run_context import RunContextWrapper, TContext
 from .stream_events import AgentUpdatedStreamEvent, RawResponsesStreamEvent
@@ -56,7 +56,7 @@ class RunConfig:
     agent. The model_provider passed in below must be able to resolve this model name.
     """

-    model_provider: ModelProvider = field(default_factory=OpenAIProvider)
+    model_provider: ModelProvider = field(default_factory=MultiProvider)
     """The model provider to use when looking up string model names. Defaults to OpenAI."""

     model_settings: ModelSettings | None = None
```
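The net effect of this two-line change: because `RunConfig.model_provider` now defaults to `MultiProvider`, prefixed model names work without any explicit `RunConfig`. A sketch, reusing the model name from the example file above and assuming `Runner.run_sync` as in the earlier sketch:

```python
from agents import Agent, Runner

agent = Agent(
    name="Assistant",
    # Routed to LitellmProvider by the default MultiProvider; no RunConfig needed.
    model="litellm/anthropic/claude-3-5-sonnet-20240620",
)

result = Runner.run_sync(agent, "What's the weather in Tokyo?")
print(result.final_output)
```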
