Skip to content

Commit 9205bcb

Browse files
committed
chore: add hooks for pre/post tool call manipulation
1 parent 08fa737 commit 9205bcb

File tree

1 file changed

+86
-6
lines changed

1 file changed

+86
-6
lines changed

agent.go

Lines changed: 86 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import (
88
"fmt"
99
"maps"
1010
"slices"
11+
"time"
1112
)
1213

1314
// StepResult represents the result of a single step in an agent execution.
@@ -228,9 +229,20 @@ type (
228229
// OnToolCallFunc is called when tool call is complete.
229230
OnToolCallFunc func(toolCall ToolCallContent) error
230231

232+
// PreToolExecuteFunc is called before tool execution.
233+
// Can modify the tool call or return an error to skip execution.
234+
// Returning a modified ToolCall allows changing input parameters.
235+
// Returning an error creates an error result without executing the tool.
236+
PreToolExecuteFunc func(ctx context.Context, toolCall ToolCall) (context.Context, *ToolCall, error)
237+
231238
// OnToolResultFunc is called when tool execution completes.
232239
OnToolResultFunc func(result ToolResultContent) error
233240

241+
// PostToolExecuteFunc is called after tool execution, before sending result to LLM.
242+
// Can modify the tool response or return an error to replace the response.
243+
// Returning a modified ToolResponse allows filtering or redacting output.
244+
PostToolExecuteFunc func(ctx context.Context, toolCall ToolCall, response ToolResponse, executionTimeMs int64) (*ToolResponse, error)
245+
234246
// OnSourceFunc is called for source references.
235247
OnSourceFunc func(source SourceContent) error
236248

@@ -280,7 +292,9 @@ type AgentStreamCall struct {
280292
OnToolInputDelta OnToolInputDeltaFunc // Called for tool input deltas
281293
OnToolInputEnd OnToolInputEndFunc // Called when tool input ends
282294
OnToolCall OnToolCallFunc // Called when tool call is complete
295+
PreToolExecute PreToolExecuteFunc // Called before tool execution (can modify input or block)
283296
OnToolResult OnToolResultFunc // Called when tool execution completes
297+
PostToolExecute PostToolExecuteFunc // Called after tool execution (can modify output)
284298
OnSource OnSourceFunc // Called for source references
285299
OnStreamFinish OnStreamFinishFunc // Called when stream finishes
286300
}
@@ -462,7 +476,7 @@ func (a *agent) Generate(ctx context.Context, opts AgentCall) (*AgentResult, err
462476
}
463477
}
464478

465-
toolResults, err := a.executeTools(ctx, a.settings.tools, stepToolCalls, nil)
479+
toolResults, err := a.executeTools(ctx, a.settings.tools, stepToolCalls, nil, nil, nil)
466480

467481
// Build step content with validated tool calls and tool results
468482
stepContent := []Content{}
@@ -616,7 +630,7 @@ func toResponseMessages(content []Content) []Message {
616630
return messages
617631
}
618632

619-
func (a *agent) executeTools(ctx context.Context, allTools []AgentTool, toolCalls []ToolCallContent, toolResultCallback func(result ToolResultContent) error) ([]ToolResultContent, error) {
633+
func (a *agent) executeTools(ctx context.Context, allTools []AgentTool, toolCalls []ToolCallContent, toolResultCallback func(result ToolResultContent) error, preToolExecute PreToolExecuteFunc, postToolExecute PostToolExecuteFunc) ([]ToolResultContent, error) {
620634
if len(toolCalls) == 0 {
621635
return nil, nil
622636
}
@@ -669,12 +683,78 @@ func (a *agent) executeTools(ctx context.Context, allTools []AgentTool, toolCall
669683
continue
670684
}
671685

672-
// Execute the tool
673-
toolResult, err := tool.Run(ctx, ToolCall{
686+
// Prepare tool call for execution
687+
executionToolCall := ToolCall{
674688
ID: toolCall.ToolCallID,
675689
Name: toolCall.ToolName,
676690
Input: toolCall.Input,
677-
})
691+
}
692+
693+
// Call pre-tool execute hook
694+
var preHookErr error
695+
toolCtx := ctx
696+
if preToolExecute != nil {
697+
updatedCtx, modifiedCall, err := preToolExecute(ctx, executionToolCall)
698+
if err != nil {
699+
preHookErr = err
700+
} else {
701+
toolCtx = updatedCtx
702+
if modifiedCall != nil {
703+
executionToolCall = *modifiedCall
704+
}
705+
}
706+
}
707+
708+
// If pre-hook returned error, create error result and skip execution
709+
if preHookErr != nil {
710+
result := ToolResultContent{
711+
ToolCallID: toolCall.ToolCallID,
712+
ToolName: toolCall.ToolName,
713+
Result: ToolResultOutputContentError{
714+
Error: preHookErr,
715+
},
716+
ProviderExecuted: false,
717+
}
718+
results = append(results, result)
719+
if toolResultCallback != nil {
720+
if err := toolResultCallback(result); err != nil {
721+
return nil, err
722+
}
723+
}
724+
// Continue to next tool call instead of returning error
725+
continue
726+
}
727+
728+
// Execute the tool with timing
729+
startTime := time.Now()
730+
toolResult, err := tool.Run(toolCtx, executionToolCall)
731+
executionTimeMs := time.Since(startTime).Milliseconds()
732+
733+
// Call post-tool execute hook
734+
if postToolExecute != nil && err == nil {
735+
modifiedResponse, postErr := postToolExecute(ctx, executionToolCall, toolResult, executionTimeMs)
736+
if postErr != nil {
737+
// Post-hook error stops execution
738+
result := ToolResultContent{
739+
ToolCallID: toolCall.ToolCallID,
740+
ToolName: toolCall.ToolName,
741+
Result: ToolResultOutputContentError{
742+
Error: postErr,
743+
},
744+
ClientMetadata: toolResult.Metadata,
745+
ProviderExecuted: false,
746+
}
747+
if toolResultCallback != nil {
748+
if cbErr := toolResultCallback(result); cbErr != nil {
749+
return nil, cbErr
750+
}
751+
}
752+
return nil, postErr
753+
} else if modifiedResponse != nil {
754+
toolResult = *modifiedResponse
755+
}
756+
}
757+
678758
if err != nil {
679759
result := ToolResultContent{
680760
ToolCallID: toolCall.ToolCallID,
@@ -1307,7 +1387,7 @@ func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, op
13071387
var toolResults []ToolResultContent
13081388
if len(stepToolCalls) > 0 {
13091389
var err error
1310-
toolResults, err = a.executeTools(ctx, a.settings.tools, stepToolCalls, opts.OnToolResult)
1390+
toolResults, err = a.executeTools(ctx, a.settings.tools, stepToolCalls, opts.OnToolResult, opts.PreToolExecute, opts.PostToolExecute)
13111391
if err != nil {
13121392
return stepExecutionResult{}, err
13131393
}

0 commit comments

Comments
 (0)