diff --git a/agent.go b/agent.go index 0a81daa3e..747ac86de 100644 --- a/agent.go +++ b/agent.go @@ -103,6 +103,7 @@ type PrepareStepResult struct { ToolChoice *ToolChoice ActiveTools []string DisableAllTools bool + Tools []AgentTool } // ToolCallRepairOptions contains the options for repairing a tool call. @@ -376,7 +377,7 @@ func (a *agent) Generate(ctx context.Context, opts AgentCall) (*AgentResult, err stepActiveTools := opts.ActiveTools stepToolChoice := ToolChoiceAuto disableAllTools := false - + stepTools := a.settings.tools if opts.PrepareStep != nil { updatedCtx, prepared, err := opts.PrepareStep(ctx, PrepareStepFunctionOptions{ Model: stepModel, @@ -407,6 +408,9 @@ func (a *agent) Generate(ctx context.Context, opts AgentCall) (*AgentResult, err stepActiveTools = prepared.ActiveTools } disableAllTools = prepared.DisableAllTools + if prepared.Tools != nil { + stepTools = prepared.Tools + } } // Recreate prompt with potentially modified system prompt @@ -421,7 +425,7 @@ func (a *agent) Generate(ctx context.Context, opts AgentCall) (*AgentResult, err } } - preparedTools := a.prepareTools(a.settings.tools, stepActiveTools, disableAllTools) + preparedTools := a.prepareTools(stepTools, stepActiveTools, disableAllTools) retryOptions := DefaultRetryOptions() if opts.MaxRetries != nil { @@ -457,12 +461,12 @@ func (a *agent) Generate(ctx context.Context, opts AgentCall) (*AgentResult, err } // Validate and potentially repair the tool call - validatedToolCall := a.validateAndRepairToolCall(ctx, toolCall, a.settings.tools, stepSystemPrompt, stepInputMessages, a.settings.repairToolCall) + validatedToolCall := a.validateAndRepairToolCall(ctx, toolCall, stepTools, stepSystemPrompt, stepInputMessages, a.settings.repairToolCall) stepToolCalls = append(stepToolCalls, validatedToolCall) } } - toolResults, err := a.executeTools(ctx, a.settings.tools, stepToolCalls, nil) + toolResults, err := a.executeTools(ctx, stepTools, stepToolCalls, nil) // Build step content with validated tool calls and tool results stepContent := []Content{} @@ -771,7 +775,7 @@ func (a *agent) Stream(ctx context.Context, opts AgentStreamCall) (*AgentResult, stepActiveTools := call.ActiveTools stepToolChoice := ToolChoiceAuto disableAllTools := false - + stepTools := a.settings.tools // Apply step preparation if provided if call.PrepareStep != nil { updatedCtx, prepared, err := call.PrepareStep(ctx, PrepareStepFunctionOptions{ @@ -802,6 +806,9 @@ func (a *agent) Stream(ctx context.Context, opts AgentStreamCall) (*AgentResult, stepActiveTools = prepared.ActiveTools } disableAllTools = prepared.DisableAllTools + if prepared.Tools != nil { + stepTools = prepared.Tools + } } // Recreate prompt with potentially modified system prompt @@ -815,7 +822,7 @@ func (a *agent) Stream(ctx context.Context, opts AgentStreamCall) (*AgentResult, } } - preparedTools := a.prepareTools(a.settings.tools, stepActiveTools, disableAllTools) + preparedTools := a.prepareTools(stepTools, stepActiveTools, disableAllTools) // Start step stream if opts.OnStepStart != nil { @@ -852,7 +859,7 @@ func (a *agent) Stream(ctx context.Context, opts AgentStreamCall) (*AgentResult, } // Process the stream - result, err := a.processStepStream(ctx, stream, opts, steps) + result, err := a.processStepStream(ctx, stream, opts, steps, stepTools) if err != nil { return stepExecutionResult{}, err } @@ -1098,7 +1105,7 @@ func WithOnRetry(callback OnRetryCallback) AgentOption { } // processStepStream processes a single step's stream and returns the step result. -func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, opts AgentStreamCall, _ []StepResult) (stepExecutionResult, error) { +func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, opts AgentStreamCall, _ []StepResult, stepTools []AgentTool) (stepExecutionResult, error) { var stepContent []Content var stepToolCalls []ToolCallContent var stepUsage Usage @@ -1257,7 +1264,7 @@ func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, op } // Validate and potentially repair the tool call - validatedToolCall := a.validateAndRepairToolCall(ctx, toolCall, a.settings.tools, a.settings.systemPrompt, nil, opts.RepairToolCall) + validatedToolCall := a.validateAndRepairToolCall(ctx, toolCall, stepTools, a.settings.systemPrompt, nil, opts.RepairToolCall) stepToolCalls = append(stepToolCalls, validatedToolCall) stepContent = append(stepContent, validatedToolCall) @@ -1307,7 +1314,7 @@ func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, op var toolResults []ToolResultContent if len(stepToolCalls) > 0 { var err error - toolResults, err = a.executeTools(ctx, a.settings.tools, stepToolCalls, opts.OnToolResult) + toolResults, err = a.executeTools(ctx, stepTools, stepToolCalls, opts.OnToolResult) if err != nil { return stepExecutionResult{}, err }