|
8 | 8 | "fmt" |
9 | 9 | "maps" |
10 | 10 | "slices" |
| 11 | + "time" |
11 | 12 | ) |
12 | 13 |
|
13 | 14 | // StepResult represents the result of a single step in an agent execution. |
@@ -228,9 +229,20 @@ type ( |
228 | 229 | // OnToolCallFunc is called when tool call is complete. |
229 | 230 | OnToolCallFunc func(toolCall ToolCallContent) error |
230 | 231 |
|
| 232 | + // PreToolExecuteFunc is called before tool execution. |
| 233 | + // Can modify the tool call or return an error to skip execution. |
| 234 | + // Returning a modified ToolCall allows changing input parameters. |
| 235 | + // Returning an error creates an error result without executing the tool. |
| 236 | + PreToolExecuteFunc func(ctx context.Context, toolCall ToolCall) (context.Context, *ToolCall, error) |
| 237 | + |
231 | 238 | // OnToolResultFunc is called when tool execution completes. |
232 | 239 | OnToolResultFunc func(result ToolResultContent) error |
233 | 240 |
|
| 241 | + // PostToolExecuteFunc is called after tool execution, before sending result to LLM. |
| 242 | + // Can modify the tool response or return an error to replace the response. |
| 243 | + // Returning a modified ToolResponse allows filtering or redacting output. |
| 244 | + PostToolExecuteFunc func(ctx context.Context, toolCall ToolCall, response ToolResponse, executionTimeMs int64) (*ToolResponse, error) |
| 245 | + |
234 | 246 | // OnSourceFunc is called for source references. |
235 | 247 | OnSourceFunc func(source SourceContent) error |
236 | 248 |
|
@@ -280,7 +292,9 @@ type AgentStreamCall struct { |
280 | 292 | OnToolInputDelta OnToolInputDeltaFunc // Called for tool input deltas |
281 | 293 | OnToolInputEnd OnToolInputEndFunc // Called when tool input ends |
282 | 294 | OnToolCall OnToolCallFunc // Called when tool call is complete |
| 295 | + PreToolExecute PreToolExecuteFunc // Called before tool execution (can modify input or block) |
283 | 296 | OnToolResult OnToolResultFunc // Called when tool execution completes |
| 297 | + PostToolExecute PostToolExecuteFunc // Called after tool execution (can modify output) |
284 | 298 | OnSource OnSourceFunc // Called for source references |
285 | 299 | OnStreamFinish OnStreamFinishFunc // Called when stream finishes |
286 | 300 | } |
@@ -462,7 +476,7 @@ func (a *agent) Generate(ctx context.Context, opts AgentCall) (*AgentResult, err |
462 | 476 | } |
463 | 477 | } |
464 | 478 |
|
465 | | - toolResults, err := a.executeTools(ctx, a.settings.tools, stepToolCalls, nil) |
| 479 | + toolResults, err := a.executeTools(ctx, a.settings.tools, stepToolCalls, nil, nil, nil) |
466 | 480 |
|
467 | 481 | // Build step content with validated tool calls and tool results |
468 | 482 | stepContent := []Content{} |
@@ -616,7 +630,7 @@ func toResponseMessages(content []Content) []Message { |
616 | 630 | return messages |
617 | 631 | } |
618 | 632 |
|
619 | | -func (a *agent) executeTools(ctx context.Context, allTools []AgentTool, toolCalls []ToolCallContent, toolResultCallback func(result ToolResultContent) error) ([]ToolResultContent, error) { |
| 633 | +func (a *agent) executeTools(ctx context.Context, allTools []AgentTool, toolCalls []ToolCallContent, toolResultCallback func(result ToolResultContent) error, preToolExecute PreToolExecuteFunc, postToolExecute PostToolExecuteFunc) ([]ToolResultContent, error) { |
620 | 634 | if len(toolCalls) == 0 { |
621 | 635 | return nil, nil |
622 | 636 | } |
@@ -669,12 +683,78 @@ func (a *agent) executeTools(ctx context.Context, allTools []AgentTool, toolCall |
669 | 683 | continue |
670 | 684 | } |
671 | 685 |
|
672 | | - // Execute the tool |
673 | | - toolResult, err := tool.Run(ctx, ToolCall{ |
| 686 | + // Prepare tool call for execution |
| 687 | + executionToolCall := ToolCall{ |
674 | 688 | ID: toolCall.ToolCallID, |
675 | 689 | Name: toolCall.ToolName, |
676 | 690 | Input: toolCall.Input, |
677 | | - }) |
| 691 | + } |
| 692 | + |
| 693 | + // Call pre-tool execute hook |
| 694 | + var preHookErr error |
| 695 | + toolCtx := ctx |
| 696 | + if preToolExecute != nil { |
| 697 | + updatedCtx, modifiedCall, err := preToolExecute(ctx, executionToolCall) |
| 698 | + if err != nil { |
| 699 | + preHookErr = err |
| 700 | + } else { |
| 701 | + toolCtx = updatedCtx |
| 702 | + if modifiedCall != nil { |
| 703 | + executionToolCall = *modifiedCall |
| 704 | + } |
| 705 | + } |
| 706 | + } |
| 707 | + |
| 708 | + // If pre-hook returned error, create error result and skip execution |
| 709 | + if preHookErr != nil { |
| 710 | + result := ToolResultContent{ |
| 711 | + ToolCallID: toolCall.ToolCallID, |
| 712 | + ToolName: toolCall.ToolName, |
| 713 | + Result: ToolResultOutputContentError{ |
| 714 | + Error: preHookErr, |
| 715 | + }, |
| 716 | + ProviderExecuted: false, |
| 717 | + } |
| 718 | + results = append(results, result) |
| 719 | + if toolResultCallback != nil { |
| 720 | + if err := toolResultCallback(result); err != nil { |
| 721 | + return nil, err |
| 722 | + } |
| 723 | + } |
| 724 | + // Continue to next tool call instead of returning error |
| 725 | + continue |
| 726 | + } |
| 727 | + |
| 728 | + // Execute the tool with timing |
| 729 | + startTime := time.Now() |
| 730 | + toolResult, err := tool.Run(toolCtx, executionToolCall) |
| 731 | + executionTimeMs := time.Since(startTime).Milliseconds() |
| 732 | + |
| 733 | + // Call post-tool execute hook |
| 734 | + if postToolExecute != nil && err == nil { |
| 735 | + modifiedResponse, postErr := postToolExecute(ctx, executionToolCall, toolResult, executionTimeMs) |
| 736 | + if postErr != nil { |
| 737 | + // Post-hook error stops execution |
| 738 | + result := ToolResultContent{ |
| 739 | + ToolCallID: toolCall.ToolCallID, |
| 740 | + ToolName: toolCall.ToolName, |
| 741 | + Result: ToolResultOutputContentError{ |
| 742 | + Error: postErr, |
| 743 | + }, |
| 744 | + ClientMetadata: toolResult.Metadata, |
| 745 | + ProviderExecuted: false, |
| 746 | + } |
| 747 | + if toolResultCallback != nil { |
| 748 | + if cbErr := toolResultCallback(result); cbErr != nil { |
| 749 | + return nil, cbErr |
| 750 | + } |
| 751 | + } |
| 752 | + return nil, postErr |
| 753 | + } else if modifiedResponse != nil { |
| 754 | + toolResult = *modifiedResponse |
| 755 | + } |
| 756 | + } |
| 757 | + |
678 | 758 | if err != nil { |
679 | 759 | result := ToolResultContent{ |
680 | 760 | ToolCallID: toolCall.ToolCallID, |
@@ -1307,7 +1387,7 @@ func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, op |
1307 | 1387 | var toolResults []ToolResultContent |
1308 | 1388 | if len(stepToolCalls) > 0 { |
1309 | 1389 | var err error |
1310 | | - toolResults, err = a.executeTools(ctx, a.settings.tools, stepToolCalls, opts.OnToolResult) |
| 1390 | + toolResults, err = a.executeTools(ctx, a.settings.tools, stepToolCalls, opts.OnToolResult, opts.PreToolExecute, opts.PostToolExecute) |
1311 | 1391 | if err != nil { |
1312 | 1392 | return stepExecutionResult{}, err |
1313 | 1393 | } |
|
0 commit comments