diff --git a/agent.go b/agent.go index 0a81daa3e..fe70ba067 100644 --- a/agent.go +++ b/agent.go @@ -8,6 +8,7 @@ import ( "fmt" "maps" "slices" + "time" ) // StepResult represents the result of a single step in an agent execution. @@ -228,9 +229,20 @@ type ( // OnToolCallFunc is called when tool call is complete. OnToolCallFunc func(toolCall ToolCallContent) error + // PreToolExecuteFunc is called before tool execution. + // Can modify the tool call or return an error to skip execution. + // Returning a modified ToolCall allows changing input parameters. + // Returning an error creates an error result without executing the tool. + PreToolExecuteFunc func(ctx context.Context, toolCall ToolCall) (context.Context, *ToolCall, error) + // OnToolResultFunc is called when tool execution completes. OnToolResultFunc func(result ToolResultContent) error + // PostToolExecuteFunc is called after tool execution, before sending result to LLM. + // Can modify the tool response or return an error to replace the response. + // Returning a modified ToolResponse allows filtering or redacting output. + PostToolExecuteFunc func(ctx context.Context, toolCall ToolCall, response ToolResponse, executionTimeMs int64) (*ToolResponse, error) + // OnSourceFunc is called for source references. OnSourceFunc func(source SourceContent) error @@ -280,7 +292,9 @@ type AgentStreamCall struct { OnToolInputDelta OnToolInputDeltaFunc // Called for tool input deltas OnToolInputEnd OnToolInputEndFunc // Called when tool input ends OnToolCall OnToolCallFunc // Called when tool call is complete + PreToolExecute PreToolExecuteFunc // Called before tool execution (can modify input or block) OnToolResult OnToolResultFunc // Called when tool execution completes + PostToolExecute PostToolExecuteFunc // Called after tool execution (can modify output) OnSource OnSourceFunc // Called for source references OnStreamFinish OnStreamFinishFunc // Called when stream finishes } @@ -462,7 +476,7 @@ func (a *agent) Generate(ctx context.Context, opts AgentCall) (*AgentResult, err } } - toolResults, err := a.executeTools(ctx, a.settings.tools, stepToolCalls, nil) + toolResults, err := a.executeTools(ctx, a.settings.tools, stepToolCalls, nil, nil, nil) // Build step content with validated tool calls and tool results stepContent := []Content{} @@ -616,7 +630,7 @@ func toResponseMessages(content []Content) []Message { return messages } -func (a *agent) executeTools(ctx context.Context, allTools []AgentTool, toolCalls []ToolCallContent, toolResultCallback func(result ToolResultContent) error) ([]ToolResultContent, error) { +func (a *agent) executeTools(ctx context.Context, allTools []AgentTool, toolCalls []ToolCallContent, toolResultCallback func(result ToolResultContent) error, preToolExecute PreToolExecuteFunc, postToolExecute PostToolExecuteFunc) ([]ToolResultContent, error) { if len(toolCalls) == 0 { return nil, nil } @@ -669,12 +683,78 @@ func (a *agent) executeTools(ctx context.Context, allTools []AgentTool, toolCall continue } - // Execute the tool - toolResult, err := tool.Run(ctx, ToolCall{ + // Prepare tool call for execution + executionToolCall := ToolCall{ ID: toolCall.ToolCallID, Name: toolCall.ToolName, Input: toolCall.Input, - }) + } + + // Call pre-tool execute hook + var preHookErr error + toolCtx := ctx + if preToolExecute != nil { + updatedCtx, modifiedCall, err := preToolExecute(ctx, executionToolCall) + if err != nil { + preHookErr = err + } else { + toolCtx = updatedCtx + if modifiedCall != nil { + executionToolCall = *modifiedCall + } + } + } + + // If pre-hook returned error, create error result and skip execution + if preHookErr != nil { + result := ToolResultContent{ + ToolCallID: toolCall.ToolCallID, + ToolName: toolCall.ToolName, + Result: ToolResultOutputContentError{ + Error: preHookErr, + }, + ProviderExecuted: false, + } + results = append(results, result) + if toolResultCallback != nil { + if err := toolResultCallback(result); err != nil { + return nil, err + } + } + // Continue to next tool call instead of returning error + continue + } + + // Execute the tool with timing + startTime := time.Now() + toolResult, err := tool.Run(toolCtx, executionToolCall) + executionTimeMs := time.Since(startTime).Milliseconds() + + // Call post-tool execute hook + if postToolExecute != nil && err == nil { + modifiedResponse, postErr := postToolExecute(ctx, executionToolCall, toolResult, executionTimeMs) + if postErr != nil { + // Post-hook error stops execution + result := ToolResultContent{ + ToolCallID: toolCall.ToolCallID, + ToolName: toolCall.ToolName, + Result: ToolResultOutputContentError{ + Error: postErr, + }, + ClientMetadata: toolResult.Metadata, + ProviderExecuted: false, + } + if toolResultCallback != nil { + if cbErr := toolResultCallback(result); cbErr != nil { + return nil, cbErr + } + } + return nil, postErr + } else if modifiedResponse != nil { + toolResult = *modifiedResponse + } + } + if err != nil { result := ToolResultContent{ ToolCallID: toolCall.ToolCallID, @@ -1307,7 +1387,7 @@ func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, op var toolResults []ToolResultContent if len(stepToolCalls) > 0 { var err error - toolResults, err = a.executeTools(ctx, a.settings.tools, stepToolCalls, opts.OnToolResult) + toolResults, err = a.executeTools(ctx, a.settings.tools, stepToolCalls, opts.OnToolResult, opts.PreToolExecute, opts.PostToolExecute) if err != nil { return stepExecutionResult{}, err }