Skip to content

Commit 3843a7e

Browse files
committed
chore: allow updating tools in prepare step
1 parent 08fa737 commit 3843a7e

File tree

1 file changed

+17
-10
lines changed

1 file changed

+17
-10
lines changed

agent.go

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,7 @@ type PrepareStepResult struct {
103103
ToolChoice *ToolChoice
104104
ActiveTools []string
105105
DisableAllTools bool
106+
Tools []AgentTool
106107
}
107108

108109
// ToolCallRepairOptions contains the options for repairing a tool call.
@@ -376,7 +377,7 @@ func (a *agent) Generate(ctx context.Context, opts AgentCall) (*AgentResult, err
376377
stepActiveTools := opts.ActiveTools
377378
stepToolChoice := ToolChoiceAuto
378379
disableAllTools := false
379-
380+
stepTools := a.settings.tools
380381
if opts.PrepareStep != nil {
381382
updatedCtx, prepared, err := opts.PrepareStep(ctx, PrepareStepFunctionOptions{
382383
Model: stepModel,
@@ -407,6 +408,9 @@ func (a *agent) Generate(ctx context.Context, opts AgentCall) (*AgentResult, err
407408
stepActiveTools = prepared.ActiveTools
408409
}
409410
disableAllTools = prepared.DisableAllTools
411+
if prepared.Tools != nil {
412+
stepTools = prepared.Tools
413+
}
410414
}
411415

412416
// Recreate prompt with potentially modified system prompt
@@ -421,7 +425,7 @@ func (a *agent) Generate(ctx context.Context, opts AgentCall) (*AgentResult, err
421425
}
422426
}
423427

424-
preparedTools := a.prepareTools(a.settings.tools, stepActiveTools, disableAllTools)
428+
preparedTools := a.prepareTools(stepTools, stepActiveTools, disableAllTools)
425429

426430
retryOptions := DefaultRetryOptions()
427431
if opts.MaxRetries != nil {
@@ -457,12 +461,12 @@ func (a *agent) Generate(ctx context.Context, opts AgentCall) (*AgentResult, err
457461
}
458462

459463
// Validate and potentially repair the tool call
460-
validatedToolCall := a.validateAndRepairToolCall(ctx, toolCall, a.settings.tools, stepSystemPrompt, stepInputMessages, a.settings.repairToolCall)
464+
validatedToolCall := a.validateAndRepairToolCall(ctx, toolCall, stepTools, stepSystemPrompt, stepInputMessages, a.settings.repairToolCall)
461465
stepToolCalls = append(stepToolCalls, validatedToolCall)
462466
}
463467
}
464468

465-
toolResults, err := a.executeTools(ctx, a.settings.tools, stepToolCalls, nil)
469+
toolResults, err := a.executeTools(ctx, stepTools, stepToolCalls, nil)
466470

467471
// Build step content with validated tool calls and tool results
468472
stepContent := []Content{}
@@ -771,7 +775,7 @@ func (a *agent) Stream(ctx context.Context, opts AgentStreamCall) (*AgentResult,
771775
stepActiveTools := call.ActiveTools
772776
stepToolChoice := ToolChoiceAuto
773777
disableAllTools := false
774-
778+
stepTools := a.settings.tools
775779
// Apply step preparation if provided
776780
if call.PrepareStep != nil {
777781
updatedCtx, prepared, err := call.PrepareStep(ctx, PrepareStepFunctionOptions{
@@ -802,6 +806,9 @@ func (a *agent) Stream(ctx context.Context, opts AgentStreamCall) (*AgentResult,
802806
stepActiveTools = prepared.ActiveTools
803807
}
804808
disableAllTools = prepared.DisableAllTools
809+
if prepared.Tools != nil {
810+
stepTools = prepared.Tools
811+
}
805812
}
806813

807814
// Recreate prompt with potentially modified system prompt
@@ -815,7 +822,7 @@ func (a *agent) Stream(ctx context.Context, opts AgentStreamCall) (*AgentResult,
815822
}
816823
}
817824

818-
preparedTools := a.prepareTools(a.settings.tools, stepActiveTools, disableAllTools)
825+
preparedTools := a.prepareTools(stepTools, stepActiveTools, disableAllTools)
819826

820827
// Start step stream
821828
if opts.OnStepStart != nil {
@@ -852,7 +859,7 @@ func (a *agent) Stream(ctx context.Context, opts AgentStreamCall) (*AgentResult,
852859
}
853860

854861
// Process the stream
855-
result, err := a.processStepStream(ctx, stream, opts, steps)
862+
result, err := a.processStepStream(ctx, stream, opts, steps, stepTools)
856863
if err != nil {
857864
return stepExecutionResult{}, err
858865
}
@@ -1098,7 +1105,7 @@ func WithOnRetry(callback OnRetryCallback) AgentOption {
10981105
}
10991106

11001107
// processStepStream processes a single step's stream and returns the step result.
1101-
func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, opts AgentStreamCall, _ []StepResult) (stepExecutionResult, error) {
1108+
func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, opts AgentStreamCall, _ []StepResult, stepTools []AgentTool) (stepExecutionResult, error) {
11021109
var stepContent []Content
11031110
var stepToolCalls []ToolCallContent
11041111
var stepUsage Usage
@@ -1257,7 +1264,7 @@ func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, op
12571264
}
12581265

12591266
// Validate and potentially repair the tool call
1260-
validatedToolCall := a.validateAndRepairToolCall(ctx, toolCall, a.settings.tools, a.settings.systemPrompt, nil, opts.RepairToolCall)
1267+
validatedToolCall := a.validateAndRepairToolCall(ctx, toolCall, stepTools, a.settings.systemPrompt, nil, opts.RepairToolCall)
12611268
stepToolCalls = append(stepToolCalls, validatedToolCall)
12621269
stepContent = append(stepContent, validatedToolCall)
12631270

@@ -1307,7 +1314,7 @@ func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, op
13071314
var toolResults []ToolResultContent
13081315
if len(stepToolCalls) > 0 {
13091316
var err error
1310-
toolResults, err = a.executeTools(ctx, a.settings.tools, stepToolCalls, opts.OnToolResult)
1317+
toolResults, err = a.executeTools(ctx, stepTools, stepToolCalls, opts.OnToolResult)
13111318
if err != nil {
13121319
return stepExecutionResult{}, err
13131320
}

0 commit comments

Comments
 (0)