Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 86 additions & 6 deletions agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"fmt"
"maps"
"slices"
"time"
)

// StepResult represents the result of a single step in an agent execution.
Expand Down Expand Up @@ -228,9 +229,20 @@ type (
// OnToolCallFunc is called when tool call is complete.
OnToolCallFunc func(toolCall ToolCallContent) error

// PreToolExecuteFunc is called before tool execution.
// Can modify the tool call or return an error to skip execution.
// Returning a modified ToolCall allows changing input parameters.
// Returning an error creates an error result without executing the tool.
PreToolExecuteFunc func(ctx context.Context, toolCall ToolCall) (context.Context, *ToolCall, error)

// OnToolResultFunc is called when tool execution completes.
OnToolResultFunc func(result ToolResultContent) error

// PostToolExecuteFunc is called after tool execution, before sending result to LLM.
// Can modify the tool response or return an error to replace the response.
// Returning a modified ToolResponse allows filtering or redacting output.
PostToolExecuteFunc func(ctx context.Context, toolCall ToolCall, response ToolResponse, executionTimeMs int64) (*ToolResponse, error)

// OnSourceFunc is called for source references.
OnSourceFunc func(source SourceContent) error

Expand Down Expand Up @@ -280,7 +292,9 @@ type AgentStreamCall struct {
OnToolInputDelta OnToolInputDeltaFunc // Called for tool input deltas
OnToolInputEnd OnToolInputEndFunc // Called when tool input ends
OnToolCall OnToolCallFunc // Called when tool call is complete
PreToolExecute PreToolExecuteFunc // Called before tool execution (can modify input or block)
OnToolResult OnToolResultFunc // Called when tool execution completes
PostToolExecute PostToolExecuteFunc // Called after tool execution (can modify output)
OnSource OnSourceFunc // Called for source references
OnStreamFinish OnStreamFinishFunc // Called when stream finishes
}
Expand Down Expand Up @@ -462,7 +476,7 @@ func (a *agent) Generate(ctx context.Context, opts AgentCall) (*AgentResult, err
}
}

toolResults, err := a.executeTools(ctx, a.settings.tools, stepToolCalls, nil)
toolResults, err := a.executeTools(ctx, a.settings.tools, stepToolCalls, nil, nil, nil)

// Build step content with validated tool calls and tool results
stepContent := []Content{}
Expand Down Expand Up @@ -616,7 +630,7 @@ func toResponseMessages(content []Content) []Message {
return messages
}

func (a *agent) executeTools(ctx context.Context, allTools []AgentTool, toolCalls []ToolCallContent, toolResultCallback func(result ToolResultContent) error) ([]ToolResultContent, error) {
func (a *agent) executeTools(ctx context.Context, allTools []AgentTool, toolCalls []ToolCallContent, toolResultCallback func(result ToolResultContent) error, preToolExecute PreToolExecuteFunc, postToolExecute PostToolExecuteFunc) ([]ToolResultContent, error) {
if len(toolCalls) == 0 {
return nil, nil
}
Expand Down Expand Up @@ -669,12 +683,78 @@ func (a *agent) executeTools(ctx context.Context, allTools []AgentTool, toolCall
continue
}

// Execute the tool
toolResult, err := tool.Run(ctx, ToolCall{
// Prepare tool call for execution
executionToolCall := ToolCall{
ID: toolCall.ToolCallID,
Name: toolCall.ToolName,
Input: toolCall.Input,
})
}

// Call pre-tool execute hook
var preHookErr error
toolCtx := ctx
if preToolExecute != nil {
updatedCtx, modifiedCall, err := preToolExecute(ctx, executionToolCall)
if err != nil {
preHookErr = err
} else {
toolCtx = updatedCtx
if modifiedCall != nil {
executionToolCall = *modifiedCall
}
}
}

// If pre-hook returned error, create error result and skip execution
if preHookErr != nil {
result := ToolResultContent{
ToolCallID: toolCall.ToolCallID,
ToolName: toolCall.ToolName,
Result: ToolResultOutputContentError{
Error: preHookErr,
},
ProviderExecuted: false,
}
results = append(results, result)
if toolResultCallback != nil {
if err := toolResultCallback(result); err != nil {
return nil, err
}
}
// Continue to next tool call instead of returning error
continue
}

// Execute the tool with timing
startTime := time.Now()
toolResult, err := tool.Run(toolCtx, executionToolCall)
executionTimeMs := time.Since(startTime).Milliseconds()

// Call post-tool execute hook
if postToolExecute != nil && err == nil {
modifiedResponse, postErr := postToolExecute(ctx, executionToolCall, toolResult, executionTimeMs)
if postErr != nil {
// Post-hook error stops execution
result := ToolResultContent{
ToolCallID: toolCall.ToolCallID,
ToolName: toolCall.ToolName,
Result: ToolResultOutputContentError{
Error: postErr,
},
ClientMetadata: toolResult.Metadata,
ProviderExecuted: false,
}
if toolResultCallback != nil {
if cbErr := toolResultCallback(result); cbErr != nil {
return nil, cbErr
}
}
return nil, postErr
} else if modifiedResponse != nil {
toolResult = *modifiedResponse
}
}

if err != nil {
result := ToolResultContent{
ToolCallID: toolCall.ToolCallID,
Expand Down Expand Up @@ -1307,7 +1387,7 @@ func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, op
var toolResults []ToolResultContent
if len(stepToolCalls) > 0 {
var err error
toolResults, err = a.executeTools(ctx, a.settings.tools, stepToolCalls, opts.OnToolResult)
toolResults, err = a.executeTools(ctx, a.settings.tools, stepToolCalls, opts.OnToolResult, opts.PreToolExecute, opts.PostToolExecute)
if err != nil {
return stepExecutionResult{}, err
}
Expand Down
Loading