Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit 4046fcb

Browse filesBrowse files
authored
Add is_enabled to FunctionTool (#808)
### Summary: Allows a user to do `function_tool(is_enabled=<some_callable>)`; the callable is called when the agent runs. This allows you to dynamically enable/disable a tool based on the context/env. The meta-goal is to allow `Agent` to be effectively immutable. That enables some nice things down the line, and this allows you to dynamically modify the tools list without mutating the agent. ### Test Plan: Unit tests
1 parent 995af4d commit 4046fcb
Copy full SHA for 4046fcb

File tree

Expand file treeCollapse file tree

6 files changed

+102
-24
lines changed
Filter options
Expand file treeCollapse file tree

6 files changed

+102
-24
lines changed

‎src/agents/agent.py

Copy file name to clipboardExpand all lines: src/agents/agent.py
+19-3Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
import asyncio
34
import dataclasses
45
import inspect
56
from collections.abc import Awaitable
@@ -17,7 +18,7 @@
1718
from .model_settings import ModelSettings
1819
from .models.interface import Model
1920
from .run_context import RunContextWrapper, TContext
20-
from .tool import FunctionToolResult, Tool, function_tool
21+
from .tool import FunctionTool, FunctionToolResult, Tool, function_tool
2122
from .util import _transforms
2223
from .util._types import MaybeAwaitable
2324

@@ -246,7 +247,22 @@ async def get_mcp_tools(self) -> list[Tool]:
246247
convert_schemas_to_strict = self.mcp_config.get("convert_schemas_to_strict", False)
247248
return await MCPUtil.get_all_function_tools(self.mcp_servers, convert_schemas_to_strict)
248249

249-
async def get_all_tools(self) -> list[Tool]:
250+
async def get_all_tools(self, run_context: RunContextWrapper[Any]) -> list[Tool]:
250251
"""All agent tools, including MCP tools and function tools."""
251252
mcp_tools = await self.get_mcp_tools()
252-
return mcp_tools + self.tools
253+
254+
async def _check_tool_enabled(tool: Tool) -> bool:
255+
if not isinstance(tool, FunctionTool):
256+
return True
257+
258+
attr = tool.is_enabled
259+
if isinstance(attr, bool):
260+
return attr
261+
res = attr(run_context, self)
262+
if inspect.isawaitable(res):
263+
return bool(await res)
264+
return bool(res)
265+
266+
results = await asyncio.gather(*(_check_tool_enabled(t) for t in self.tools))
267+
enabled: list[Tool] = [t for t, ok in zip(self.tools, results) if ok]
268+
return [*mcp_tools, *enabled]

‎src/agents/run.py

Copy file name to clipboardExpand all lines: src/agents/run.py
+6-4Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,7 @@ async def run(
181181

182182
try:
183183
while True:
184-
all_tools = await cls._get_all_tools(current_agent)
184+
all_tools = await cls._get_all_tools(current_agent, context_wrapper)
185185

186186
# Start an agent span if we don't have one. This span is ended if the current
187187
# agent changes, or if the agent loop ends.
@@ -525,7 +525,7 @@ async def _run_streamed_impl(
525525
if streamed_result.is_complete:
526526
break
527527

528-
all_tools = await cls._get_all_tools(current_agent)
528+
all_tools = await cls._get_all_tools(current_agent, context_wrapper)
529529

530530
# Start an agent span if we don't have one. This span is ended if the current
531531
# agent changes, or if the agent loop ends.
@@ -980,8 +980,10 @@ def _get_handoffs(cls, agent: Agent[Any]) -> list[Handoff]:
980980
return handoffs
981981

982982
@classmethod
983-
async def _get_all_tools(cls, agent: Agent[Any]) -> list[Tool]:
984-
return await agent.get_all_tools()
983+
async def _get_all_tools(
984+
cls, agent: Agent[Any], context_wrapper: RunContextWrapper[Any]
985+
) -> list[Tool]:
986+
return await agent.get_all_tools(context_wrapper)
985987

986988
@classmethod
987989
def _get_model(cls, agent: Agent[Any], run_config: RunConfig) -> Model:

‎src/agents/tool.py

Copy file name to clipboardExpand all lines: src/agents/tool.py
+16-1Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import json
55
from collections.abc import Awaitable
66
from dataclasses import dataclass
7-
from typing import Any, Callable, Literal, Union, overload
7+
from typing import TYPE_CHECKING, Any, Callable, Literal, Union, overload
88

99
from openai.types.responses.file_search_tool_param import Filters, RankingOptions
1010
from openai.types.responses.response_output_item import LocalShellCall, McpApprovalRequest
@@ -24,6 +24,9 @@
2424
from .util import _error_tracing
2525
from .util._types import MaybeAwaitable
2626

27+
if TYPE_CHECKING:
28+
from .agent import Agent
29+
2730
ToolParams = ParamSpec("ToolParams")
2831

2932
ToolFunctionWithoutContext = Callable[ToolParams, Any]
@@ -74,6 +77,11 @@ class FunctionTool:
7477
"""Whether the JSON schema is in strict mode. We **strongly** recommend setting this to True,
7578
as it increases the likelihood of correct JSON input."""
7679

80+
is_enabled: bool | Callable[[RunContextWrapper[Any], Agent[Any]], MaybeAwaitable[bool]] = True
81+
"""Whether the tool is enabled. Either a bool or a Callable that takes the run context and agent
82+
and returns whether the tool is enabled. You can use this to dynamically enable/disable a tool
83+
based on your context/state."""
84+
7785

7886
@dataclass
7987
class FileSearchTool:
@@ -262,6 +270,7 @@ def function_tool(
262270
use_docstring_info: bool = True,
263271
failure_error_function: ToolErrorFunction | None = None,
264272
strict_mode: bool = True,
273+
is_enabled: bool | Callable[[RunContextWrapper[Any], Agent[Any]], MaybeAwaitable[bool]] = True,
265274
) -> FunctionTool:
266275
"""Overload for usage as @function_tool (no parentheses)."""
267276
...
@@ -276,6 +285,7 @@ def function_tool(
276285
use_docstring_info: bool = True,
277286
failure_error_function: ToolErrorFunction | None = None,
278287
strict_mode: bool = True,
288+
is_enabled: bool | Callable[[RunContextWrapper[Any], Agent[Any]], MaybeAwaitable[bool]] = True,
279289
) -> Callable[[ToolFunction[...]], FunctionTool]:
280290
"""Overload for usage as @function_tool(...)."""
281291
...
@@ -290,6 +300,7 @@ def function_tool(
290300
use_docstring_info: bool = True,
291301
failure_error_function: ToolErrorFunction | None = default_tool_error_function,
292302
strict_mode: bool = True,
303+
is_enabled: bool | Callable[[RunContextWrapper[Any], Agent[Any]], MaybeAwaitable[bool]] = True,
293304
) -> FunctionTool | Callable[[ToolFunction[...]], FunctionTool]:
294305
"""
295306
Decorator to create a FunctionTool from a function. By default, we will:
@@ -318,6 +329,9 @@ def function_tool(
318329
If False, it allows non-strict JSON schemas. For example, if a parameter has a default
319330
value, it will be optional, additional properties are allowed, etc. See here for more:
320331
https://platform.openai.com/docs/guides/structured-outputs?api-mode=responses#supported-schemas
332+
is_enabled: Whether the tool is enabled. Can be a bool or a callable that takes the run
333+
context and agent and returns whether the tool is enabled. Disabled tools are hidden
334+
from the LLM at runtime.
321335
"""
322336

323337
def _create_function_tool(the_func: ToolFunction[...]) -> FunctionTool:
@@ -407,6 +421,7 @@ async def _on_invoke_tool(ctx: RunContextWrapper[Any], input: str) -> Any:
407421
params_json_schema=schema.params_json_schema,
408422
on_invoke_tool=_on_invoke_tool,
409423
strict_json_schema=strict_mode,
424+
is_enabled=is_enabled,
410425
)
411426

412427
# If func is actually a callable, we were used as @function_tool with no parentheses

‎tests/test_function_tool.py

Copy file name to clipboardExpand all lines: tests/test_function_tool.py
+42-1Lines changed: 42 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from pydantic import BaseModel
66
from typing_extensions import TypedDict
77

8-
from agents import FunctionTool, ModelBehaviorError, RunContextWrapper, function_tool
8+
from agents import Agent, FunctionTool, ModelBehaviorError, RunContextWrapper, function_tool
99
from agents.tool import default_tool_error_function
1010

1111

@@ -255,3 +255,44 @@ def custom_sync_error_function(ctx: RunContextWrapper[Any], error: Exception) ->
255255

256256
result = await tool.on_invoke_tool(ctx, '{"a": 1, "b": 2}')
257257
assert result == "error_ValueError"
258+
259+
260+
class BoolCtx(BaseModel):
261+
enable_tools: bool
262+
263+
264+
@pytest.mark.asyncio
265+
async def test_is_enabled_bool_and_callable():
266+
@function_tool(is_enabled=False)
267+
def disabled_tool():
268+
return "nope"
269+
270+
async def cond_enabled(ctx: RunContextWrapper[BoolCtx], agent: Agent[Any]) -> bool:
271+
return ctx.context.enable_tools
272+
273+
@function_tool(is_enabled=cond_enabled)
274+
def another_tool():
275+
return "hi"
276+
277+
async def third_tool_on_invoke_tool(ctx: RunContextWrapper[Any], args: str) -> str:
278+
return "third"
279+
280+
third_tool = FunctionTool(
281+
name="third_tool",
282+
description="third tool",
283+
on_invoke_tool=third_tool_on_invoke_tool,
284+
is_enabled=lambda ctx, agent: ctx.context.enable_tools,
285+
params_json_schema={},
286+
)
287+
288+
agent = Agent(name="t", tools=[disabled_tool, another_tool, third_tool])
289+
context_1 = RunContextWrapper(BoolCtx(enable_tools=False))
290+
context_2 = RunContextWrapper(BoolCtx(enable_tools=True))
291+
292+
tools_with_ctx = await agent.get_all_tools(context_1)
293+
assert tools_with_ctx == []
294+
295+
tools_with_ctx = await agent.get_all_tools(context_2)
296+
assert len(tools_with_ctx) == 2
297+
assert tools_with_ctx[0].name == "another_tool"
298+
assert tools_with_ctx[1].name == "third_tool"

‎tests/test_run_step_execution.py

Copy file name to clipboardExpand all lines: tests/test_run_step_execution.py
+1-1Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -290,7 +290,7 @@ async def get_execute_result(
290290

291291
processed_response = RunImpl.process_model_response(
292292
agent=agent,
293-
all_tools=await agent.get_all_tools(),
293+
all_tools=await agent.get_all_tools(context_wrapper or RunContextWrapper(None)),
294294
response=response,
295295
output_schema=output_schema,
296296
handoffs=handoffs,

‎tests/test_run_step_processing.py

Copy file name to clipboardExpand all lines: tests/test_run_step_processing.py
+18-14Lines changed: 18 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,10 @@
3434
)
3535

3636

37+
def _dummy_ctx() -> RunContextWrapper[None]:
38+
return RunContextWrapper(context=None)
39+
40+
3741
def test_empty_response():
3842
agent = Agent(name="test")
3943
response = ModelResponse(
@@ -83,7 +87,7 @@ async def test_single_tool_call():
8387
response=response,
8488
output_schema=None,
8589
handoffs=[],
86-
all_tools=await agent.get_all_tools(),
90+
all_tools=await agent.get_all_tools(_dummy_ctx()),
8791
)
8892
assert not result.handoffs
8993
assert result.functions and len(result.functions) == 1
@@ -111,7 +115,7 @@ async def test_missing_tool_call_raises_error():
111115
response=response,
112116
output_schema=None,
113117
handoffs=[],
114-
all_tools=await agent.get_all_tools(),
118+
all_tools=await agent.get_all_tools(_dummy_ctx()),
115119
)
116120

117121

@@ -140,7 +144,7 @@ async def test_multiple_tool_calls():
140144
response=response,
141145
output_schema=None,
142146
handoffs=[],
143-
all_tools=await agent.get_all_tools(),
147+
all_tools=await agent.get_all_tools(_dummy_ctx()),
144148
)
145149
assert not result.handoffs
146150
assert result.functions and len(result.functions) == 2
@@ -169,7 +173,7 @@ async def test_handoffs_parsed_correctly():
169173
response=response,
170174
output_schema=None,
171175
handoffs=[],
172-
all_tools=await agent_3.get_all_tools(),
176+
all_tools=await agent_3.get_all_tools(_dummy_ctx()),
173177
)
174178
assert not result.handoffs, "Shouldn't have a handoff here"
175179

@@ -183,7 +187,7 @@ async def test_handoffs_parsed_correctly():
183187
response=response,
184188
output_schema=None,
185189
handoffs=Runner._get_handoffs(agent_3),
186-
all_tools=await agent_3.get_all_tools(),
190+
all_tools=await agent_3.get_all_tools(_dummy_ctx()),
187191
)
188192
assert len(result.handoffs) == 1, "Should have a handoff here"
189193
handoff = result.handoffs[0]
@@ -213,7 +217,7 @@ async def test_missing_handoff_fails():
213217
response=response,
214218
output_schema=None,
215219
handoffs=Runner._get_handoffs(agent_3),
216-
all_tools=await agent_3.get_all_tools(),
220+
all_tools=await agent_3.get_all_tools(_dummy_ctx()),
217221
)
218222

219223

@@ -236,7 +240,7 @@ async def test_multiple_handoffs_doesnt_error():
236240
response=response,
237241
output_schema=None,
238242
handoffs=Runner._get_handoffs(agent_3),
239-
all_tools=await agent_3.get_all_tools(),
243+
all_tools=await agent_3.get_all_tools(_dummy_ctx()),
240244
)
241245
assert len(result.handoffs) == 2, "Should have multiple handoffs here"
242246

@@ -262,7 +266,7 @@ async def test_final_output_parsed_correctly():
262266
response=response,
263267
output_schema=Runner._get_output_schema(agent),
264268
handoffs=[],
265-
all_tools=await agent.get_all_tools(),
269+
all_tools=await agent.get_all_tools(_dummy_ctx()),
266270
)
267271

268272

@@ -288,7 +292,7 @@ async def test_file_search_tool_call_parsed_correctly():
288292
response=response,
289293
output_schema=None,
290294
handoffs=[],
291-
all_tools=await agent.get_all_tools(),
295+
all_tools=await agent.get_all_tools(_dummy_ctx()),
292296
)
293297
# The final item should be a ToolCallItem for the file search call
294298
assert any(
@@ -313,7 +317,7 @@ async def test_function_web_search_tool_call_parsed_correctly():
313317
response=response,
314318
output_schema=None,
315319
handoffs=[],
316-
all_tools=await agent.get_all_tools(),
320+
all_tools=await agent.get_all_tools(_dummy_ctx()),
317321
)
318322
assert any(
319323
isinstance(item, ToolCallItem) and item.raw_item is web_search_call
@@ -340,7 +344,7 @@ async def test_reasoning_item_parsed_correctly():
340344
response=response,
341345
output_schema=None,
342346
handoffs=[],
343-
all_tools=await Agent(name="test").get_all_tools(),
347+
all_tools=await Agent(name="test").get_all_tools(_dummy_ctx()),
344348
)
345349
assert any(
346350
isinstance(item, ReasoningItem) and item.raw_item is reasoning for item in result.new_items
@@ -409,7 +413,7 @@ async def test_computer_tool_call_without_computer_tool_raises_error():
409413
response=response,
410414
output_schema=None,
411415
handoffs=[],
412-
all_tools=await Agent(name="test").get_all_tools(),
416+
all_tools=await Agent(name="test").get_all_tools(_dummy_ctx()),
413417
)
414418

415419

@@ -437,7 +441,7 @@ async def test_computer_tool_call_with_computer_tool_parsed_correctly():
437441
response=response,
438442
output_schema=None,
439443
handoffs=[],
440-
all_tools=await agent.get_all_tools(),
444+
all_tools=await agent.get_all_tools(_dummy_ctx()),
441445
)
442446
assert any(
443447
isinstance(item, ToolCallItem) and item.raw_item is computer_call
@@ -468,7 +472,7 @@ async def test_tool_and_handoff_parsed_correctly():
468472
response=response,
469473
output_schema=None,
470474
handoffs=Runner._get_handoffs(agent_3),
471-
all_tools=await agent_3.get_all_tools(),
475+
all_tools=await agent_3.get_all_tools(_dummy_ctx()),
472476
)
473477
assert result.functions and len(result.functions) == 1
474478
assert len(result.handoffs) == 1, "Should have a handoff here"

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.