Skip to content

Add tool call parameters for on_tool_start hook #253

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions examples/basic/agent_lifecycle_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from pydantic import BaseModel

from agents import Agent, AgentHooks, RunContextWrapper, Runner, Tool, function_tool
from agents import Agent, AgentHooks, RunContextWrapper, Runner, Tool, function_tool, ToolCallItem


class CustomAgentHooks(AgentHooks):
Expand All @@ -28,7 +28,7 @@ async def on_handoff(self, context: RunContextWrapper, agent: Agent, source: Age
f"### ({self.display_name}) {self.event_counter}: Agent {source.name} handed off to {agent.name}"
)

async def on_tool_start(self, context: RunContextWrapper, agent: Agent, tool: Tool) -> None:
async def on_tool_start(self, context: RunContextWrapper, agent: Agent, tool: Tool, tool_call: ToolCallItem) -> None:
self.event_counter += 1
print(
f"### ({self.display_name}) {self.event_counter}: Agent {agent.name} started tool {tool.name}"
Expand Down
24 changes: 20 additions & 4 deletions examples/basic/lifecycle_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,16 @@

from pydantic import BaseModel

from agents import Agent, RunContextWrapper, RunHooks, Runner, Tool, Usage, function_tool

from agents import (
Agent,
RunContextWrapper,
RunHooks,
Runner,
Tool,
Usage,
function_tool,
ToolCallItem
)

class ExampleHooks(RunHooks):
def __init__(self):
Expand All @@ -20,13 +28,21 @@ async def on_agent_start(self, context: RunContextWrapper, agent: Agent) -> None
f"### {self.event_counter}: Agent {agent.name} started. Usage: {self._usage_to_str(context.usage)}"
)

async def on_agent_end(self, context: RunContextWrapper, agent: Agent, output: Any) -> None:
async def on_agent_end(
self, context: RunContextWrapper, agent: Agent, output: Any
) -> None:
self.event_counter += 1
print(
f"### {self.event_counter}: Agent {agent.name} ended with output {output}. Usage: {self._usage_to_str(context.usage)}"
)

async def on_tool_start(self, context: RunContextWrapper, agent: Agent, tool: Tool) -> None:
async def on_tool_start(
self,
context: RunContextWrapper,
agent: Agent,
tool: Tool,
tool_call: ToolCallItem,
) -> None:
self.event_counter += 1
print(
f"### {self.event_counter}: Tool {tool.name} started. Usage: {self._usage_to_str(context.usage)}"
Expand Down
76 changes: 55 additions & 21 deletions src/agents/_run_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,12 @@
from .agent_output import AgentOutputSchema
from .computer import AsyncComputer, Computer
from .exceptions import AgentsException, ModelBehaviorError, UserError
from .guardrail import InputGuardrail, InputGuardrailResult, OutputGuardrail, OutputGuardrailResult
from .guardrail import (
InputGuardrail,
InputGuardrailResult,
OutputGuardrail,
OutputGuardrailResult,
)
from .handoffs import Handoff, HandoffInputData
from .items import (
HandoffCallItem,
Expand Down Expand Up @@ -271,17 +276,25 @@ async def execute_tools_and_side_effects(
)

# Now we can check if the model also produced a final output
message_items = [item for item in new_step_items if isinstance(item, MessageOutputItem)]
message_items = [
item for item in new_step_items if isinstance(item, MessageOutputItem)
]

# We'll use the last content output as the final output
potential_final_output_text = (
ItemHelpers.extract_last_text(message_items[-1].raw_item) if message_items else None
ItemHelpers.extract_last_text(message_items[-1].raw_item)
if message_items
else None
)

# There are two possibilities that lead to a final output:
# 1. Structured output schema => always leads to a final output
# 2. Plain text output schema => only leads to a final output if there are no tool calls
if output_schema and not output_schema.is_plain_text() and potential_final_output_text:
if (
output_schema
and not output_schema.is_plain_text()
and potential_final_output_text
):
final_output = output_schema.validate_json(potential_final_output_text)
return await cls.execute_final_output(
agent=agent,
Expand Down Expand Up @@ -402,7 +415,9 @@ def process_model_response(
data={"tool_name": output.name},
)
)
raise ModelBehaviorError(f"Tool {output.name} not found in agent {agent.name}")
raise ModelBehaviorError(
f"Tool {output.name} not found in agent {agent.name}"
)
items.append(ToolCallItem(raw_item=output, agent=agent))
functions.append(
ToolRunFunction(
Expand Down Expand Up @@ -437,9 +452,13 @@ async def run_single_tool(
span_fn.span_data.input = tool_call.arguments
try:
_, _, result = await asyncio.gather(
hooks.on_tool_start(context_wrapper, agent, func_tool),
hooks.on_tool_start(
context_wrapper, agent, func_tool, ToolCallItem(raw_item=tool_call)
),
(
agent.hooks.on_tool_start(context_wrapper, agent, func_tool)
agent.hooks.on_tool_start(
context_wrapper, agent, func_tool, ToolCallItem(raw_item=tool_call)
)
if agent.hooks

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Passing only the arguments as a parameter is too granular, and adding tool_call.id later would inflate the parameter list at the root level; besides, tool_call.id is actually needed. Would it be better to pass the entire tool_call object as a parameter instead?

else _coro.noop_coroutine()
),
Expand All @@ -449,7 +468,9 @@ async def run_single_tool(
await asyncio.gather(
hooks.on_tool_end(context_wrapper, agent, func_tool, result),
(
agent.hooks.on_tool_end(context_wrapper, agent, func_tool, result)
agent.hooks.on_tool_end(
context_wrapper, agent, func_tool, result
)
if agent.hooks
else _coro.noop_coroutine()
),
Expand Down Expand Up @@ -482,7 +503,9 @@ async def run_single_tool(
output=result,
run_item=ToolCallOutputItem(
output=result,
raw_item=ItemHelpers.tool_call_output_item(tool_run.tool_call, str(result)),
raw_item=ItemHelpers.tool_call_output_item(
tool_run.tool_call, str(result)
),
agent=agent,
),
)
Expand Down Expand Up @@ -590,9 +613,11 @@ async def execute_handoffs(
if input_filter:
logger.debug("Filtering inputs for handoff")
handoff_input_data = HandoffInputData(
input_history=tuple(original_input)
if isinstance(original_input, list)
else original_input,
input_history=(
tuple(original_input)
if isinstance(original_input, list)
else original_input
),
pre_handoff_items=tuple(pre_step_items),
new_items=tuple(new_step_items),
)
Expand Down Expand Up @@ -666,9 +691,11 @@ async def run_final_output_hooks(
):
await asyncio.gather(
hooks.on_agent_end(context_wrapper, agent, final_output),
agent.hooks.on_end(context_wrapper, agent, final_output)
if agent.hooks
else _coro.noop_coroutine(),
(
agent.hooks.on_end(context_wrapper, agent, final_output)
if agent.hooks
else _coro.noop_coroutine()
),
)

@classmethod
Expand All @@ -693,7 +720,9 @@ async def run_single_output_guardrail(
context: RunContextWrapper[TContext],
) -> OutputGuardrailResult:
with guardrail_span(guardrail.get_name()) as span_guardrail:
result = await guardrail.run(agent=agent, agent_output=agent_output, context=context)
result = await guardrail.run(
agent=agent, agent_output=agent_output, context=context
)
span_guardrail.span_data.triggered = result.output.tripwire_triggered
return result

Expand Down Expand Up @@ -758,7 +787,8 @@ async def _check_for_final_output_from_tools(
)
else:
return cast(
ToolsToFinalOutputResult, agent.tool_use_behavior(context_wrapper, tool_results)
ToolsToFinalOutputResult,
agent.tool_use_behavior(context_wrapper, tool_results),
)

logger.error(f"Invalid tool_use_behavior: {agent.tool_use_behavior}")
Expand Down Expand Up @@ -816,13 +846,15 @@ async def execute(
output_func = (
cls._get_screenshot_async(action.computer_tool.computer, action.tool_call)
if isinstance(action.computer_tool.computer, AsyncComputer)
else cls._get_screenshot_sync(action.computer_tool.computer, action.tool_call)
else cls._get_screenshot_sync(
action.computer_tool.computer, action.tool_call
)
)

_, _, output = await asyncio.gather(
hooks.on_tool_start(context_wrapper, agent, action.computer_tool),
hooks.on_tool_start(context_wrapper, agent, action.computer_tool, ToolCallItem(raw_item=action.tool_call)),
(
agent.hooks.on_tool_start(context_wrapper, agent, action.computer_tool)
agent.hooks.on_tool_start(context_wrapper, agent, action.computer_tool, ToolCallItem(raw_item=action.tool_call))
if agent.hooks
else _coro.noop_coroutine()
),
Expand All @@ -832,7 +864,9 @@ async def execute(
await asyncio.gather(
hooks.on_tool_end(context_wrapper, agent, action.computer_tool, output),
(
agent.hooks.on_tool_end(context_wrapper, agent, action.computer_tool, output)
agent.hooks.on_tool_end(
context_wrapper, agent, action.computer_tool, output
)
if agent.hooks
else _coro.noop_coroutine()
),
Expand Down
7 changes: 6 additions & 1 deletion src/agents/lifecycle.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from .agent import Agent
from .run_context import RunContextWrapper, TContext
from .tool import Tool
from .items import ToolCallItem


class RunHooks(Generic[TContext]):
Expand Down Expand Up @@ -39,6 +40,7 @@ async def on_tool_start(
context: RunContextWrapper[TContext],
agent: Agent[TContext],
tool: Tool,
tool_call: ToolCallItem,
) -> None:
"""Called before a tool is invoked."""
pass
Expand All @@ -61,7 +63,9 @@ class AgentHooks(Generic[TContext]):
Subclass and override the methods you need.
"""

async def on_start(self, context: RunContextWrapper[TContext], agent: Agent[TContext]) -> None:
async def on_start(
self, context: RunContextWrapper[TContext], agent: Agent[TContext]
) -> None:
"""Called before the agent is invoked. Called each time the running agent is changed to this
agent."""
pass
Expand Down Expand Up @@ -90,6 +94,7 @@ async def on_tool_start(
context: RunContextWrapper[TContext],
agent: Agent[TContext],
tool: Tool,
tool_call: ToolCallItem,
) -> None:
"""Called before a tool is invoked."""
pass
Expand Down
8 changes: 6 additions & 2 deletions src/agents/tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,9 @@
ToolFunctionWithoutContext = Callable[ToolParams, Any]
ToolFunctionWithContext = Callable[Concatenate[RunContextWrapper[Any], ToolParams], Any]

ToolFunction = Union[ToolFunctionWithoutContext[ToolParams], ToolFunctionWithContext[ToolParams]]
ToolFunction = Union[
ToolFunctionWithoutContext[ToolParams], ToolFunctionWithContext[ToolParams]
]


@dataclass
Expand Down Expand Up @@ -244,7 +246,9 @@ async def _on_invoke_tool_impl(ctx: RunContextWrapper[Any], input: str) -> Any:
else schema.params_pydantic_model()
)
except ValidationError as e:
raise ModelBehaviorError(f"Invalid JSON input for tool {schema.name}: {e}") from e
raise ModelBehaviorError(
f"Invalid JSON input for tool {schema.name}: {e}"
) from e

args, kwargs_dict = schema.to_call_args(parsed)

Expand Down
2 changes: 2 additions & 0 deletions tests/test_agent_hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from agents.run_context import RunContextWrapper, TContext
from agents.tool import Tool

from src.agents.items import ToolCallItem
from .fake_model import FakeModel
from .test_responses import (
get_final_output_message,
Expand Down Expand Up @@ -54,6 +55,7 @@ async def on_tool_start(
context: RunContextWrapper[TContext],
agent: Agent[TContext],
tool: Tool,
tool_call: ToolCallItem,
) -> None:
self.events["on_tool_start"] += 1

Expand Down
37 changes: 28 additions & 9 deletions tests/test_computer_action.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,10 +127,18 @@ async def drag(self, path: list[tuple[int, int]]) -> None:
@pytest.mark.parametrize(
"action,expected_call",
[
(ActionClick(type="click", x=10, y=21, button="left"), ("click", (10, 21, "left"))),
(ActionDoubleClick(type="double_click", x=42, y=47), ("double_click", (42, 47))),
(
ActionDrag(type="drag", path=[ActionDragPath(x=1, y=2), ActionDragPath(x=3, y=4)]),
ActionClick(type="click", x=10, y=21, button="left"),
("click", (10, 21, "left")),
),
(
ActionDoubleClick(type="double_click", x=42, y=47),
("double_click", (42, 47)),
),
(
ActionDrag(
type="drag", path=[ActionDragPath(x=1, y=2), ActionDragPath(x=3, y=4)]
),
("drag", (((1, 2), (3, 4)),)),
),
(ActionKeypress(type="keypress", keys=["a", "b"]), ("keypress", (["a", "b"],))),
Expand Down Expand Up @@ -172,13 +180,24 @@ async def test_get_screenshot_sync_executes_action_and_takes_screenshot(
@pytest.mark.parametrize(
"action,expected_call",
[
(ActionClick(type="click", x=2, y=3, button="right"), ("click", (2, 3, "right"))),
(ActionDoubleClick(type="double_click", x=12, y=13), ("double_click", (12, 13))),
(
ActionDrag(type="drag", path=[ActionDragPath(x=5, y=6), ActionDragPath(x=6, y=7)]),
ActionClick(type="click", x=2, y=3, button="right"),
("click", (2, 3, "right")),
),
(
ActionDoubleClick(type="double_click", x=12, y=13),
("double_click", (12, 13)),
),
(
ActionDrag(
type="drag", path=[ActionDragPath(x=5, y=6), ActionDragPath(x=6, y=7)]
),
("drag", (((5, 6), (6, 7)),)),
),
(ActionKeypress(type="keypress", keys=["ctrl", "c"]), ("keypress", (["ctrl", "c"],))),
(
ActionKeypress(type="keypress", keys=["ctrl", "c"]),
("keypress", (["ctrl", "c"],)),
),
(ActionMove(type="move", x=8, y=9), ("move", (8, 9))),
(ActionScreenshot(type="screenshot"), ("screenshot", ())),
(
Expand Down Expand Up @@ -222,7 +241,7 @@ def __init__(self) -> None:
self.ended: list[tuple[Agent[Any], Any, str]] = []

async def on_tool_start(
self, context: RunContextWrapper[Any], agent: Agent[Any], tool: Any
self, context: RunContextWrapper[Any], agent: Agent[Any], tool: Any, tool_call: Any
) -> None:
self.started.append((agent, tool))

Expand All @@ -241,7 +260,7 @@ def __init__(self) -> None:
self.ended: list[tuple[Agent[Any], Any, str]] = []

async def on_tool_start(
self, context: RunContextWrapper[Any], agent: Agent[Any], tool: Any
self, context: RunContextWrapper[Any], agent: Agent[Any], tool: Any, tool_call: Any,
) -> None:
self.started.append((agent, tool))

Expand Down
2 changes: 2 additions & 0 deletions tests/test_global_hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

from agents import Agent, RunContextWrapper, RunHooks, Runner, TContext, Tool

from src.agents.items import ToolCallItem
from .fake_model import FakeModel
from .test_responses import (
get_final_output_message,
Expand Down Expand Up @@ -52,6 +53,7 @@ async def on_tool_start(
context: RunContextWrapper[TContext],
agent: Agent[TContext],
tool: Tool,
tool_call: ToolCallItem,
) -> None:
self.events["on_tool_start"] += 1

Expand Down