@@ -16,6 +16,12 @@ type StepResult struct {
1616 Messages []Message
1717}
1818
19+ // stepExecutionResult encapsulates the result of executing a step with stream processing.
20+ type stepExecutionResult struct {
21+ StepResult StepResult
22+ ShouldContinue bool
23+ }
24+
1925// StopCondition defines a function that determines when an agent should stop executing.
2026type StopCondition = func (steps []StepResult ) bool
2127
@@ -736,6 +742,7 @@ func (a *agent) Stream(ctx context.Context, opts AgentStreamCall) (*AgentResult,
736742 ActiveTools : opts .ActiveTools ,
737743 ProviderOptions : opts .ProviderOptions ,
738744 MaxRetries : opts .MaxRetries ,
745+ OnRetry : opts .OnRetry ,
739746 StopWhen : opts .StopWhen ,
740747 PrepareStep : opts .PrepareStep ,
741748 RepairToolCall : opts .RepairToolCall ,
@@ -829,48 +836,51 @@ func (a *agent) Stream(ctx context.Context, opts AgentStreamCall) (*AgentResult,
829836 ProviderOptions : call .ProviderOptions ,
830837 }
831838
832- // Get streaming response with retry logic
839+ // Execute step with retry logic wrapping both stream creation and processing
833840 retryOptions := DefaultRetryOptions ()
834841 if call .MaxRetries != nil {
835842 retryOptions .MaxRetries = * call .MaxRetries
836843 }
837844 retryOptions .OnRetry = call .OnRetry
838- retry := RetryWithExponentialBackoffRespectingRetryHeaders [StreamResponse ](retryOptions )
845+ retry := RetryWithExponentialBackoffRespectingRetryHeaders [stepExecutionResult ](retryOptions )
839846
840- stream , err := retry (ctx , func () (StreamResponse , error ) {
841- return stepModel .Stream (ctx , streamCall )
842- })
843- if err != nil {
844- if opts .OnError != nil {
845- opts .OnError (err )
847+ result , err := retry (ctx , func () (stepExecutionResult , error ) {
848+ // Create the stream
849+ stream , err := stepModel .Stream (ctx , streamCall )
850+ if err != nil {
851+ return stepExecutionResult {}, err
852+ }
853+
854+ // Process the stream
855+ result , err := a .processStepStream (ctx , stream , opts , steps )
856+ if err != nil {
857+ return stepExecutionResult {}, err
846858 }
847- return nil , err
848- }
849859
850- // Process stream with tool execution
851- stepResult , shouldContinue , err := a . processStepStream ( ctx , stream , opts , steps )
860+ return result , nil
861+ } )
852862 if err != nil {
853863 if opts .OnError != nil {
854864 opts .OnError (err )
855865 }
856866 return nil , err
857867 }
858868
859- steps = append (steps , stepResult )
860- totalUsage = addUsage (totalUsage , stepResult .Usage )
869+ steps = append (steps , result . StepResult )
870+ totalUsage = addUsage (totalUsage , result . StepResult .Usage )
861871
862872 // Call step finished callback
863873 if opts .OnStepFinish != nil {
864- _ = opts .OnStepFinish (stepResult )
874+ _ = opts .OnStepFinish (result . StepResult )
865875 }
866876
867877 // Add step messages to response messages
868- stepMessages := toResponseMessages (stepResult .Content )
878+ stepMessages := toResponseMessages (result . StepResult .Content )
869879 responseMessages = append (responseMessages , stepMessages ... )
870880
871881 // Check stop conditions
872882 shouldStop := isStopConditionMet (call .StopWhen , steps )
873- if shouldStop || ! shouldContinue {
883+ if shouldStop || ! result . ShouldContinue {
874884 break
875885 }
876886 }
@@ -1088,7 +1098,7 @@ func WithOnRetry(callback OnRetryCallback) AgentOption {
10881098}
10891099
10901100// processStepStream processes a single step's stream and returns the step result.
1091- func (a * agent ) processStepStream (ctx context.Context , stream StreamResponse , opts AgentStreamCall , _ []StepResult ) (StepResult , bool , error ) {
1101+ func (a * agent ) processStepStream (ctx context.Context , stream StreamResponse , opts AgentStreamCall , _ []StepResult ) (stepExecutionResult , error ) {
10921102 var stepContent []Content
10931103 var stepToolCalls []ToolCallContent
10941104 var stepUsage Usage
@@ -1110,7 +1120,7 @@ func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, op
11101120 if opts .OnChunk != nil {
11111121 err := opts .OnChunk (part )
11121122 if err != nil {
1113- return StepResult {}, false , err
1123+ return stepExecutionResult {} , err
11141124 }
11151125 }
11161126
@@ -1120,7 +1130,7 @@ func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, op
11201130 if opts .OnWarnings != nil {
11211131 err := opts .OnWarnings (part .Warnings )
11221132 if err != nil {
1123- return StepResult {}, false , err
1133+ return stepExecutionResult {} , err
11241134 }
11251135 }
11261136
@@ -1129,7 +1139,7 @@ func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, op
11291139 if opts .OnTextStart != nil {
11301140 err := opts .OnTextStart (part .ID )
11311141 if err != nil {
1132- return StepResult {}, false , err
1142+ return stepExecutionResult {} , err
11331143 }
11341144 }
11351145
@@ -1140,7 +1150,7 @@ func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, op
11401150 if opts .OnTextDelta != nil {
11411151 err := opts .OnTextDelta (part .ID , part .Delta )
11421152 if err != nil {
1143- return StepResult {}, false , err
1153+ return stepExecutionResult {} , err
11441154 }
11451155 }
11461156
@@ -1155,7 +1165,7 @@ func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, op
11551165 if opts .OnTextEnd != nil {
11561166 err := opts .OnTextEnd (part .ID )
11571167 if err != nil {
1158- return StepResult {}, false , err
1168+ return stepExecutionResult {} , err
11591169 }
11601170 }
11611171
@@ -1168,7 +1178,7 @@ func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, op
11681178 }
11691179 err := opts .OnReasoningStart (part .ID , content )
11701180 if err != nil {
1171- return StepResult {}, false , err
1181+ return stepExecutionResult {} , err
11721182 }
11731183 }
11741184
@@ -1181,7 +1191,7 @@ func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, op
11811191 if opts .OnReasoningDelta != nil {
11821192 err := opts .OnReasoningDelta (part .ID , part .Delta )
11831193 if err != nil {
1184- return StepResult {}, false , err
1194+ return stepExecutionResult {} , err
11851195 }
11861196 }
11871197
@@ -1198,7 +1208,7 @@ func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, op
11981208 if opts .OnReasoningEnd != nil {
11991209 err := opts .OnReasoningEnd (part .ID , content )
12001210 if err != nil {
1201- return StepResult {}, false , err
1211+ return stepExecutionResult {} , err
12021212 }
12031213 }
12041214 delete (activeReasoningContent , part .ID )
@@ -1214,7 +1224,7 @@ func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, op
12141224 if opts .OnToolInputStart != nil {
12151225 err := opts .OnToolInputStart (part .ID , part .ToolCallName )
12161226 if err != nil {
1217- return StepResult {}, false , err
1227+ return stepExecutionResult {} , err
12181228 }
12191229 }
12201230
@@ -1225,15 +1235,15 @@ func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, op
12251235 if opts .OnToolInputDelta != nil {
12261236 err := opts .OnToolInputDelta (part .ID , part .Delta )
12271237 if err != nil {
1228- return StepResult {}, false , err
1238+ return stepExecutionResult {} , err
12291239 }
12301240 }
12311241
12321242 case StreamPartTypeToolInputEnd :
12331243 if opts .OnToolInputEnd != nil {
12341244 err := opts .OnToolInputEnd (part .ID )
12351245 if err != nil {
1236- return StepResult {}, false , err
1246+ return stepExecutionResult {} , err
12371247 }
12381248 }
12391249
@@ -1254,7 +1264,7 @@ func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, op
12541264 if opts .OnToolCall != nil {
12551265 err := opts .OnToolCall (validatedToolCall )
12561266 if err != nil {
1257- return StepResult {}, false , err
1267+ return stepExecutionResult {} , err
12581268 }
12591269 }
12601270
@@ -1273,7 +1283,7 @@ func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, op
12731283 if opts .OnSource != nil {
12741284 err := opts .OnSource (sourceContent )
12751285 if err != nil {
1276- return StepResult {}, false , err
1286+ return stepExecutionResult {} , err
12771287 }
12781288 }
12791289
@@ -1284,15 +1294,12 @@ func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, op
12841294 if opts .OnStreamFinish != nil {
12851295 err := opts .OnStreamFinish (part .Usage , part .FinishReason , part .ProviderMetadata )
12861296 if err != nil {
1287- return StepResult {}, false , err
1297+ return stepExecutionResult {} , err
12881298 }
12891299 }
12901300
12911301 case StreamPartTypeError :
1292- if opts .OnError != nil {
1293- opts .OnError (part .Error )
1294- }
1295- return StepResult {}, false , part .Error
1302+ return stepExecutionResult {}, part .Error
12961303 }
12971304 }
12981305
@@ -1302,7 +1309,7 @@ func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, op
13021309 var err error
13031310 toolResults , err = a .executeTools (ctx , a .settings .tools , stepToolCalls , opts .OnToolResult )
13041311 if err != nil {
1305- return StepResult {}, false , err
1312+ return stepExecutionResult {} , err
13061313 }
13071314 // Add tool results to content
13081315 for _ , result := range toolResults {
@@ -1324,7 +1331,10 @@ func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, op
13241331 // Determine if we should continue (has tool calls and not stopped)
13251332 shouldContinue := len (stepToolCalls ) > 0 && stepFinishReason == FinishReasonToolCalls
13261333
1327- return stepResult , shouldContinue , nil
1334+ return stepExecutionResult {
1335+ StepResult : stepResult ,
1336+ ShouldContinue : shouldContinue ,
1337+ }, nil
13281338}
13291339
13301340func addUsage (a , b Usage ) Usage {
0 commit comments