Skip to content

Commit ddf3ee0

Browse files
authored
fix: retry logic for stream processing errors in agent Stream function (#75)
1 parent 0132f7c commit ddf3ee0

File tree

1 file changed

+48
-38
lines changed

1 file changed

+48
-38
lines changed

agent.go

Lines changed: 48 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,12 @@ type StepResult struct {
1616
Messages []Message
1717
}
1818

19+
// stepExecutionResult encapsulates the result of executing a step with stream processing.
20+
type stepExecutionResult struct {
21+
StepResult StepResult
22+
ShouldContinue bool
23+
}
24+
1925
// StopCondition defines a function that determines when an agent should stop executing.
2026
type StopCondition = func(steps []StepResult) bool
2127

@@ -736,6 +742,7 @@ func (a *agent) Stream(ctx context.Context, opts AgentStreamCall) (*AgentResult,
736742
ActiveTools: opts.ActiveTools,
737743
ProviderOptions: opts.ProviderOptions,
738744
MaxRetries: opts.MaxRetries,
745+
OnRetry: opts.OnRetry,
739746
StopWhen: opts.StopWhen,
740747
PrepareStep: opts.PrepareStep,
741748
RepairToolCall: opts.RepairToolCall,
@@ -829,48 +836,51 @@ func (a *agent) Stream(ctx context.Context, opts AgentStreamCall) (*AgentResult,
829836
ProviderOptions: call.ProviderOptions,
830837
}
831838

832-
// Get streaming response with retry logic
839+
// Execute step with retry logic wrapping both stream creation and processing
833840
retryOptions := DefaultRetryOptions()
834841
if call.MaxRetries != nil {
835842
retryOptions.MaxRetries = *call.MaxRetries
836843
}
837844
retryOptions.OnRetry = call.OnRetry
838-
retry := RetryWithExponentialBackoffRespectingRetryHeaders[StreamResponse](retryOptions)
845+
retry := RetryWithExponentialBackoffRespectingRetryHeaders[stepExecutionResult](retryOptions)
839846

840-
stream, err := retry(ctx, func() (StreamResponse, error) {
841-
return stepModel.Stream(ctx, streamCall)
842-
})
843-
if err != nil {
844-
if opts.OnError != nil {
845-
opts.OnError(err)
847+
result, err := retry(ctx, func() (stepExecutionResult, error) {
848+
// Create the stream
849+
stream, err := stepModel.Stream(ctx, streamCall)
850+
if err != nil {
851+
return stepExecutionResult{}, err
852+
}
853+
854+
// Process the stream
855+
result, err := a.processStepStream(ctx, stream, opts, steps)
856+
if err != nil {
857+
return stepExecutionResult{}, err
846858
}
847-
return nil, err
848-
}
849859

850-
// Process stream with tool execution
851-
stepResult, shouldContinue, err := a.processStepStream(ctx, stream, opts, steps)
860+
return result, nil
861+
})
852862
if err != nil {
853863
if opts.OnError != nil {
854864
opts.OnError(err)
855865
}
856866
return nil, err
857867
}
858868

859-
steps = append(steps, stepResult)
860-
totalUsage = addUsage(totalUsage, stepResult.Usage)
869+
steps = append(steps, result.StepResult)
870+
totalUsage = addUsage(totalUsage, result.StepResult.Usage)
861871

862872
// Call step finished callback
863873
if opts.OnStepFinish != nil {
864-
_ = opts.OnStepFinish(stepResult)
874+
_ = opts.OnStepFinish(result.StepResult)
865875
}
866876

867877
// Add step messages to response messages
868-
stepMessages := toResponseMessages(stepResult.Content)
878+
stepMessages := toResponseMessages(result.StepResult.Content)
869879
responseMessages = append(responseMessages, stepMessages...)
870880

871881
// Check stop conditions
872882
shouldStop := isStopConditionMet(call.StopWhen, steps)
873-
if shouldStop || !shouldContinue {
883+
if shouldStop || !result.ShouldContinue {
874884
break
875885
}
876886
}
@@ -1088,7 +1098,7 @@ func WithOnRetry(callback OnRetryCallback) AgentOption {
10881098
}
10891099

10901100
// processStepStream processes a single step's stream and returns the step result.
1091-
func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, opts AgentStreamCall, _ []StepResult) (StepResult, bool, error) {
1101+
func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, opts AgentStreamCall, _ []StepResult) (stepExecutionResult, error) {
10921102
var stepContent []Content
10931103
var stepToolCalls []ToolCallContent
10941104
var stepUsage Usage
@@ -1110,7 +1120,7 @@ func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, op
11101120
if opts.OnChunk != nil {
11111121
err := opts.OnChunk(part)
11121122
if err != nil {
1113-
return StepResult{}, false, err
1123+
return stepExecutionResult{}, err
11141124
}
11151125
}
11161126

@@ -1120,7 +1130,7 @@ func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, op
11201130
if opts.OnWarnings != nil {
11211131
err := opts.OnWarnings(part.Warnings)
11221132
if err != nil {
1123-
return StepResult{}, false, err
1133+
return stepExecutionResult{}, err
11241134
}
11251135
}
11261136

@@ -1129,7 +1139,7 @@ func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, op
11291139
if opts.OnTextStart != nil {
11301140
err := opts.OnTextStart(part.ID)
11311141
if err != nil {
1132-
return StepResult{}, false, err
1142+
return stepExecutionResult{}, err
11331143
}
11341144
}
11351145

@@ -1140,7 +1150,7 @@ func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, op
11401150
if opts.OnTextDelta != nil {
11411151
err := opts.OnTextDelta(part.ID, part.Delta)
11421152
if err != nil {
1143-
return StepResult{}, false, err
1153+
return stepExecutionResult{}, err
11441154
}
11451155
}
11461156

@@ -1155,7 +1165,7 @@ func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, op
11551165
if opts.OnTextEnd != nil {
11561166
err := opts.OnTextEnd(part.ID)
11571167
if err != nil {
1158-
return StepResult{}, false, err
1168+
return stepExecutionResult{}, err
11591169
}
11601170
}
11611171

@@ -1168,7 +1178,7 @@ func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, op
11681178
}
11691179
err := opts.OnReasoningStart(part.ID, content)
11701180
if err != nil {
1171-
return StepResult{}, false, err
1181+
return stepExecutionResult{}, err
11721182
}
11731183
}
11741184

@@ -1181,7 +1191,7 @@ func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, op
11811191
if opts.OnReasoningDelta != nil {
11821192
err := opts.OnReasoningDelta(part.ID, part.Delta)
11831193
if err != nil {
1184-
return StepResult{}, false, err
1194+
return stepExecutionResult{}, err
11851195
}
11861196
}
11871197

@@ -1198,7 +1208,7 @@ func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, op
11981208
if opts.OnReasoningEnd != nil {
11991209
err := opts.OnReasoningEnd(part.ID, content)
12001210
if err != nil {
1201-
return StepResult{}, false, err
1211+
return stepExecutionResult{}, err
12021212
}
12031213
}
12041214
delete(activeReasoningContent, part.ID)
@@ -1214,7 +1224,7 @@ func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, op
12141224
if opts.OnToolInputStart != nil {
12151225
err := opts.OnToolInputStart(part.ID, part.ToolCallName)
12161226
if err != nil {
1217-
return StepResult{}, false, err
1227+
return stepExecutionResult{}, err
12181228
}
12191229
}
12201230

@@ -1225,15 +1235,15 @@ func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, op
12251235
if opts.OnToolInputDelta != nil {
12261236
err := opts.OnToolInputDelta(part.ID, part.Delta)
12271237
if err != nil {
1228-
return StepResult{}, false, err
1238+
return stepExecutionResult{}, err
12291239
}
12301240
}
12311241

12321242
case StreamPartTypeToolInputEnd:
12331243
if opts.OnToolInputEnd != nil {
12341244
err := opts.OnToolInputEnd(part.ID)
12351245
if err != nil {
1236-
return StepResult{}, false, err
1246+
return stepExecutionResult{}, err
12371247
}
12381248
}
12391249

@@ -1254,7 +1264,7 @@ func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, op
12541264
if opts.OnToolCall != nil {
12551265
err := opts.OnToolCall(validatedToolCall)
12561266
if err != nil {
1257-
return StepResult{}, false, err
1267+
return stepExecutionResult{}, err
12581268
}
12591269
}
12601270

@@ -1273,7 +1283,7 @@ func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, op
12731283
if opts.OnSource != nil {
12741284
err := opts.OnSource(sourceContent)
12751285
if err != nil {
1276-
return StepResult{}, false, err
1286+
return stepExecutionResult{}, err
12771287
}
12781288
}
12791289

@@ -1284,15 +1294,12 @@ func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, op
12841294
if opts.OnStreamFinish != nil {
12851295
err := opts.OnStreamFinish(part.Usage, part.FinishReason, part.ProviderMetadata)
12861296
if err != nil {
1287-
return StepResult{}, false, err
1297+
return stepExecutionResult{}, err
12881298
}
12891299
}
12901300

12911301
case StreamPartTypeError:
1292-
if opts.OnError != nil {
1293-
opts.OnError(part.Error)
1294-
}
1295-
return StepResult{}, false, part.Error
1302+
return stepExecutionResult{}, part.Error
12961303
}
12971304
}
12981305

@@ -1302,7 +1309,7 @@ func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, op
13021309
var err error
13031310
toolResults, err = a.executeTools(ctx, a.settings.tools, stepToolCalls, opts.OnToolResult)
13041311
if err != nil {
1305-
return StepResult{}, false, err
1312+
return stepExecutionResult{}, err
13061313
}
13071314
// Add tool results to content
13081315
for _, result := range toolResults {
@@ -1324,7 +1331,10 @@ func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, op
13241331
// Determine if we should continue (has tool calls and not stopped)
13251332
shouldContinue := len(stepToolCalls) > 0 && stepFinishReason == FinishReasonToolCalls
13261333

1327-
return stepResult, shouldContinue, nil
1334+
return stepExecutionResult{
1335+
StepResult: stepResult,
1336+
ShouldContinue: shouldContinue,
1337+
}, nil
13281338
}
13291339

13301340
func addUsage(a, b Usage) Usage {

0 commit comments

Comments
 (0)