Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 17 additions & 10 deletions agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ type PrepareStepResult struct {
ToolChoice *ToolChoice
ActiveTools []string
DisableAllTools bool
Tools []AgentTool
}

// ToolCallRepairOptions contains the options for repairing a tool call.
Expand Down Expand Up @@ -376,7 +377,7 @@ func (a *agent) Generate(ctx context.Context, opts AgentCall) (*AgentResult, err
stepActiveTools := opts.ActiveTools
stepToolChoice := ToolChoiceAuto
disableAllTools := false

stepTools := a.settings.tools
if opts.PrepareStep != nil {
updatedCtx, prepared, err := opts.PrepareStep(ctx, PrepareStepFunctionOptions{
Model: stepModel,
Expand Down Expand Up @@ -407,6 +408,9 @@ func (a *agent) Generate(ctx context.Context, opts AgentCall) (*AgentResult, err
stepActiveTools = prepared.ActiveTools
}
disableAllTools = prepared.DisableAllTools
if prepared.Tools != nil {
stepTools = prepared.Tools
}
}

// Recreate prompt with potentially modified system prompt
Expand All @@ -421,7 +425,7 @@ func (a *agent) Generate(ctx context.Context, opts AgentCall) (*AgentResult, err
}
}

preparedTools := a.prepareTools(a.settings.tools, stepActiveTools, disableAllTools)
preparedTools := a.prepareTools(stepTools, stepActiveTools, disableAllTools)

retryOptions := DefaultRetryOptions()
if opts.MaxRetries != nil {
Expand Down Expand Up @@ -457,12 +461,12 @@ func (a *agent) Generate(ctx context.Context, opts AgentCall) (*AgentResult, err
}

// Validate and potentially repair the tool call
validatedToolCall := a.validateAndRepairToolCall(ctx, toolCall, a.settings.tools, stepSystemPrompt, stepInputMessages, a.settings.repairToolCall)
validatedToolCall := a.validateAndRepairToolCall(ctx, toolCall, stepTools, stepSystemPrompt, stepInputMessages, a.settings.repairToolCall)
stepToolCalls = append(stepToolCalls, validatedToolCall)
}
}

toolResults, err := a.executeTools(ctx, a.settings.tools, stepToolCalls, nil)
toolResults, err := a.executeTools(ctx, stepTools, stepToolCalls, nil)

// Build step content with validated tool calls and tool results
stepContent := []Content{}
Expand Down Expand Up @@ -771,7 +775,7 @@ func (a *agent) Stream(ctx context.Context, opts AgentStreamCall) (*AgentResult,
stepActiveTools := call.ActiveTools
stepToolChoice := ToolChoiceAuto
disableAllTools := false

stepTools := a.settings.tools
// Apply step preparation if provided
if call.PrepareStep != nil {
updatedCtx, prepared, err := call.PrepareStep(ctx, PrepareStepFunctionOptions{
Expand Down Expand Up @@ -802,6 +806,9 @@ func (a *agent) Stream(ctx context.Context, opts AgentStreamCall) (*AgentResult,
stepActiveTools = prepared.ActiveTools
}
disableAllTools = prepared.DisableAllTools
if prepared.Tools != nil {
stepTools = prepared.Tools
}
}

// Recreate prompt with potentially modified system prompt
Expand All @@ -815,7 +822,7 @@ func (a *agent) Stream(ctx context.Context, opts AgentStreamCall) (*AgentResult,
}
}

preparedTools := a.prepareTools(a.settings.tools, stepActiveTools, disableAllTools)
preparedTools := a.prepareTools(stepTools, stepActiveTools, disableAllTools)

// Start step stream
if opts.OnStepStart != nil {
Expand Down Expand Up @@ -852,7 +859,7 @@ func (a *agent) Stream(ctx context.Context, opts AgentStreamCall) (*AgentResult,
}

// Process the stream
result, err := a.processStepStream(ctx, stream, opts, steps)
result, err := a.processStepStream(ctx, stream, opts, steps, stepTools)
if err != nil {
return stepExecutionResult{}, err
}
Expand Down Expand Up @@ -1098,7 +1105,7 @@ func WithOnRetry(callback OnRetryCallback) AgentOption {
}

// processStepStream processes a single step's stream and returns the step result.
func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, opts AgentStreamCall, _ []StepResult) (stepExecutionResult, error) {
func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, opts AgentStreamCall, _ []StepResult, stepTools []AgentTool) (stepExecutionResult, error) {
var stepContent []Content
var stepToolCalls []ToolCallContent
var stepUsage Usage
Expand Down Expand Up @@ -1257,7 +1264,7 @@ func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, op
}

// Validate and potentially repair the tool call
validatedToolCall := a.validateAndRepairToolCall(ctx, toolCall, a.settings.tools, a.settings.systemPrompt, nil, opts.RepairToolCall)
validatedToolCall := a.validateAndRepairToolCall(ctx, toolCall, stepTools, a.settings.systemPrompt, nil, opts.RepairToolCall)
stepToolCalls = append(stepToolCalls, validatedToolCall)
stepContent = append(stepContent, validatedToolCall)

Expand Down Expand Up @@ -1307,7 +1314,7 @@ func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, op
var toolResults []ToolResultContent
if len(stepToolCalls) > 0 {
var err error
toolResults, err = a.executeTools(ctx, a.settings.tools, stepToolCalls, opts.OnToolResult)
toolResults, err = a.executeTools(ctx, stepTools, stepToolCalls, opts.OnToolResult)
if err != nil {
return stepExecutionResult{}, err
}
Expand Down
Loading