Skip to content

Commit fe54e84

Browse files
fix: supports direct streaming output from host while using host multi-agent callback (#236)
1 parent a19ae2a commit fe54e84

File tree

3 files changed

+114
-83
lines changed

3 files changed

+114
-83
lines changed

flow/agent/multiagent/host/callback.go

Lines changed: 42 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,14 @@ package host
1818

1919
import (
2020
"context"
21+
"fmt"
2122
"io"
23+
"runtime/debug"
2224

2325
"github.com/cloudwego/eino/callbacks"
2426
"github.com/cloudwego/eino/components/model"
2527
"github.com/cloudwego/eino/flow/agent"
28+
"github.com/cloudwego/eino/internal/safe"
2629
"github.com/cloudwego/eino/schema"
2730
template "github.com/cloudwego/eino/utils/callbacks"
2831
)
@@ -41,10 +44,6 @@ type HandOffInfo struct {
4144
// ConvertCallbackHandlers converts []host.MultiAgentCallback to callbacks.Handler.
4245
func ConvertCallbackHandlers(handlers ...MultiAgentCallback) callbacks.Handler {
4346
onChatModelEnd := func(ctx context.Context, info *callbacks.RunInfo, output *model.CallbackOutput) context.Context {
44-
if output == nil || info == nil {
45-
return ctx
46-
}
47-
4847
msg := output.Message
4948
if msg == nil || msg.Role != schema.Assistant || len(msg.ToolCalls) == 0 {
5049
return ctx
@@ -63,47 +62,48 @@ func ConvertCallbackHandlers(handlers ...MultiAgentCallback) callbacks.Handler {
6362
}
6463

6564
onChatModelEndWithStreamOutput := func(ctx context.Context, info *callbacks.RunInfo, output *schema.StreamReader[*model.CallbackOutput]) context.Context {
66-
if output == nil || info == nil {
67-
return ctx
68-
}
69-
70-
defer output.Close()
71-
72-
var msgs []*schema.Message
73-
for {
74-
oneOutput, err := output.Recv()
75-
if err == io.EOF {
76-
break
77-
}
78-
if err != nil {
79-
return ctx
65+
go func() {
66+
defer func() {
67+
panicInfo := recover()
68+
if panicInfo != nil {
69+
fmt.Println(safe.NewPanicErr(panicInfo, debug.Stack()))
70+
}
71+
output.Close()
72+
}()
73+
74+
handOffs := make(map[string]string)
75+
var handOffOrder []string
76+
for {
77+
oneOutput, err := output.Recv()
78+
if err != nil {
79+
if err == io.EOF {
80+
break
81+
}
82+
return
83+
}
84+
85+
for _, toolCall := range oneOutput.Message.ToolCalls {
86+
if len(toolCall.Function.Name) > 0 {
87+
if existing, ok := handOffs[toolCall.Function.Name]; !ok {
88+
handOffOrder = append(handOffOrder, toolCall.Function.Name)
89+
handOffs[toolCall.Function.Name] = toolCall.Function.Arguments
90+
} else {
91+
handOffs[toolCall.Function.Name] = existing + toolCall.Function.Arguments
92+
}
93+
}
94+
}
8095
}
8196

82-
msg := oneOutput.Message
83-
if msg == nil {
84-
continue
97+
for _, cb := range handlers {
98+
for _, name := range handOffOrder {
99+
args := handOffs[name]
100+
_ = cb.OnHandOff(ctx, &HandOffInfo{
101+
ToAgentName: name,
102+
Argument: args,
103+
})
104+
}
85105
}
86-
87-
msgs = append(msgs, msg)
88-
}
89-
90-
msg, err := schema.ConcatMessages(msgs)
91-
if err != nil {
92-
return ctx
93-
}
94-
95-
if msg.Role != schema.Assistant || len(msg.ToolCalls) == 0 {
96-
return ctx
97-
}
98-
99-
for _, cb := range handlers {
100-
for _, toolCall := range msg.ToolCalls {
101-
ctx = cb.OnHandOff(ctx, &HandOffInfo{
102-
ToAgentName: toolCall.Function.Name,
103-
Argument: toolCall.Function.Arguments,
104-
})
105-
}
106-
}
106+
}()
107107

108108
return ctx
109109
}

flow/agent/multiagent/host/compose_test.go

Lines changed: 71 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ package host
1919
import (
2020
"context"
2121
"io"
22+
"sync"
2223
"testing"
2324

2425
"github.com/stretchr/testify/assert"
@@ -46,6 +47,15 @@ func TestHostMultiAgent(t *testing.T) {
4647
},
4748
}
4849

50+
specialist2Msg1 := &schema.Message{
51+
Role: schema.Assistant,
52+
Content: "specialist2",
53+
}
54+
specialist2Msg2 := &schema.Message{
55+
Role: schema.Assistant,
56+
Content: " stream answer",
57+
}
58+
4959
specialist2 := &Specialist{
5060
Invokable: func(ctx context.Context, input []*schema.Message, opts ...agent.AgentOption) (*schema.Message, error) {
5161
return &schema.Message{
@@ -54,15 +64,7 @@ func TestHostMultiAgent(t *testing.T) {
5464
}, nil
5565
},
5666
Streamable: func(ctx context.Context, input []*schema.Message, opts ...agent.AgentOption) (*schema.StreamReader[*schema.Message], error) {
57-
sr, sw := schema.Pipe[*schema.Message](0)
58-
go func() {
59-
sw.Send(&schema.Message{
60-
Role: schema.Assistant,
61-
Content: "specialist2 stream answer",
62-
}, nil)
63-
sw.Close()
64-
}()
65-
return sr, nil
67+
return schema.StreamReaderFromArray([]*schema.Message{specialist2Msg1, specialist2Msg2}), nil
6668
},
6769
AgentMeta: AgentMeta{
6870
Name: "specialist 2",
@@ -94,7 +96,7 @@ func TestHostMultiAgent(t *testing.T) {
9496

9597
mockHostLLM.EXPECT().Generate(gomock.Any(), gomock.Any()).Return(directAnswerMsg, nil).Times(1)
9698

97-
mockCallback := &mockAgentCallback{}
99+
mockCallback := newMockAgentCallback(0)
98100

99101
out, err := hostMA.Generate(ctx, nil, WithAgentCallbacks(mockCallback))
100102
assert.NoError(t, err)
@@ -122,7 +124,7 @@ func TestHostMultiAgent(t *testing.T) {
122124

123125
mockHostLLM.EXPECT().Stream(gomock.Any(), gomock.Any()).Return(sr, nil).Times(1)
124126

125-
mockCallback := &mockAgentCallback{}
127+
mockCallback := newMockAgentCallback(0)
126128
outStream, err := hostMA.Stream(ctx, nil, WithAgentCallbacks(mockCallback))
127129
assert.NoError(t, err)
128130
assert.Empty(t, mockCallback.infos)
@@ -139,9 +141,8 @@ func TestHostMultiAgent(t *testing.T) {
139141

140142
outStream.Close()
141143

142-
msg, err := schema.ConcatMessages(msgs)
143-
assert.NoError(t, err)
144-
assert.Equal(t, "direct answer", msg.Content)
144+
assert.Equal(t, directAnswerMsg1, msgs[0])
145+
assert.Equal(t, directAnswerMsg2, msgs[1])
145146
})
146147

147148
t.Run("generate hand off", func(t *testing.T) {
@@ -166,11 +167,12 @@ func TestHostMultiAgent(t *testing.T) {
166167
mockHostLLM.EXPECT().Generate(gomock.Any(), gomock.Any()).Return(handOffMsg, nil).Times(1)
167168
mockSpecialistLLM1.EXPECT().Generate(gomock.Any(), gomock.Any()).Return(specialistMsg, nil).Times(1)
168169

169-
mockCallback := &mockAgentCallback{}
170+
mockCallback := newMockAgentCallback(1)
170171

171172
out, err := hostMA.Generate(ctx, nil, WithAgentCallbacks(mockCallback))
172173
assert.NoError(t, err)
173174
assert.Equal(t, "specialist 1 answer", out.Content)
175+
mockCallback.wg.Wait()
174176
assert.Equal(t, []*HandOffInfo{
175177
{
176178
ToAgentName: specialist1.Name,
@@ -182,11 +184,12 @@ func TestHostMultiAgent(t *testing.T) {
182184
handOffMsg.ToolCalls[0].Function.Arguments = `{"reason": "specialist 2 is even better"}`
183185
mockHostLLM.EXPECT().Generate(gomock.Any(), gomock.Any()).Return(handOffMsg, nil).Times(1)
184186

185-
mockCallback = &mockAgentCallback{}
187+
mockCallback = newMockAgentCallback(1)
186188

187189
out, err = hostMA.Generate(ctx, nil, WithAgentCallbacks(mockCallback))
188190
assert.NoError(t, err)
189191
assert.Equal(t, "specialist2 invoke answer", out.Content)
192+
mockCallback.wg.Wait()
190193
assert.Equal(t, []*HandOffInfo{
191194
{
192195
ToAgentName: specialist2.Name,
@@ -297,7 +300,7 @@ func TestHostMultiAgent(t *testing.T) {
297300
mockHostLLM.EXPECT().Stream(gomock.Any(), gomock.Any()).Return(sr, nil).Times(1)
298301
mockSpecialistLLM1.EXPECT().Stream(gomock.Any(), gomock.Any()).Return(sr1, nil).Times(1)
299302

300-
mockCallback := &mockAgentCallback{}
303+
mockCallback := newMockAgentCallback(1)
301304
outStream, err := hostMA.Stream(ctx, nil, WithAgentCallbacks(mockCallback))
302305
assert.NoError(t, err)
303306

@@ -313,9 +316,10 @@ func TestHostMultiAgent(t *testing.T) {
313316

314317
outStream.Close()
315318

316-
msg, err := schema.ConcatMessages(msgs)
317-
assert.NoError(t, err)
318-
assert.Equal(t, "specialist 1 answer", msg.Content)
319+
assert.Equal(t, specialistMsg1, msgs[0])
320+
assert.Equal(t, specialistMsg2, msgs[1])
321+
322+
mockCallback.wg.Wait()
319323

320324
assert.Equal(t, []*HandOffInfo{
321325
{
@@ -337,7 +341,7 @@ func TestHostMultiAgent(t *testing.T) {
337341

338342
mockHostLLM.EXPECT().Stream(gomock.Any(), gomock.Any()).Return(sr, nil).Times(1)
339343

340-
mockCallback = &mockAgentCallback{}
344+
mockCallback = newMockAgentCallback(1)
341345
outStream, err = hostMA.Stream(ctx, nil, WithAgentCallbacks(mockCallback))
342346
assert.NoError(t, err)
343347

@@ -353,9 +357,10 @@ func TestHostMultiAgent(t *testing.T) {
353357

354358
outStream.Close()
355359

356-
msg, err = schema.ConcatMessages(msgs)
357-
assert.NoError(t, err)
358-
assert.Equal(t, "specialist2 stream answer", msg.Content)
360+
assert.Equal(t, specialist2Msg1, msgs[0])
361+
assert.Equal(t, specialist2Msg2, msgs[1])
362+
363+
mockCallback.wg.Wait()
359364

360365
assert.Equal(t, []*HandOffInfo{
361366
{
@@ -387,7 +392,7 @@ func TestHostMultiAgent(t *testing.T) {
387392
mockHostLLM.EXPECT().Generate(gomock.Any(), gomock.Any()).Return(handOffMsg, nil).Times(1)
388393
mockSpecialistLLM1.EXPECT().Generate(gomock.Any(), gomock.Any()).Return(specialistMsg, nil).Times(1)
389394

390-
mockCallback := &mockAgentCallback{}
395+
mockCallback := newMockAgentCallback(1)
391396

392397
hostMA, err := NewMultiAgent(ctx, &MultiAgentConfig{
393398
Host: Host{
@@ -412,6 +417,8 @@ func TestHostMultiAgent(t *testing.T) {
412417
out, err := fullGraph.Invoke(ctx, map[string]any{"country_name": "China"}, compose.WithCallbacks(ConvertCallbackHandlers(mockCallback)).DesignateNodeWithPath(compose.NewNodePath("host_ma_node", hostMA.HostNodeKey())))
413418
assert.NoError(t, err)
414419
assert.Equal(t, "Beijing", out.Content)
420+
421+
mockCallback.wg.Wait()
415422
assert.Equal(t, []*HandOffInfo{
416423
{
417424
ToAgentName: specialist1.Name,
@@ -458,20 +465,32 @@ func TestHostMultiAgent(t *testing.T) {
458465
Index: generic.PtrOf(1),
459466
Function: schema.FunctionCall{
460467
Name: specialist2.Name,
461-
Arguments: `{"reason": "specialist 2 is also good"}`,
468+
Arguments: `{"reason": "specialist 2`,
462469
},
463470
},
464471
},
465472
}
466473

467-
sr, sw := schema.Pipe[*schema.Message](0)
468-
go func() {
469-
sw.Send(handOffMsg1, nil)
470-
sw.Send(handOffMsg2, nil)
471-
sw.Send(handOffMsg3, nil)
472-
sw.Send(handOffMsg4, nil)
473-
sw.Close()
474-
}()
474+
handOffMsg5 := &schema.Message{
475+
Role: schema.Assistant,
476+
ToolCalls: []schema.ToolCall{
477+
{
478+
Index: generic.PtrOf(1),
479+
Function: schema.FunctionCall{
480+
Name: specialist2.Name,
481+
Arguments: ` is also good"}`,
482+
},
483+
},
484+
},
485+
}
486+
487+
sr := schema.StreamReaderFromArray([]*schema.Message{
488+
handOffMsg1,
489+
handOffMsg2,
490+
handOffMsg3,
491+
handOffMsg4,
492+
handOffMsg5,
493+
})
475494

476495
specialist1Msg1 := &schema.Message{
477496
Role: schema.Assistant,
@@ -483,12 +502,10 @@ func TestHostMultiAgent(t *testing.T) {
483502
Content: "1 answer",
484503
}
485504

486-
sr1, sw1 := schema.Pipe[*schema.Message](0)
487-
go func() {
488-
sw1.Send(specialist1Msg1, nil)
489-
sw1.Send(specialist1Msg2, nil)
490-
sw1.Close()
491-
}()
505+
sr1 := schema.StreamReaderFromArray([]*schema.Message{
506+
specialist1Msg1,
507+
specialist1Msg2,
508+
})
492509

493510
streamToolCallChecker := func(ctx context.Context, modelOutput *schema.StreamReader[*schema.Message]) (bool, error) {
494511
defer modelOutput.Close()
@@ -528,7 +545,7 @@ func TestHostMultiAgent(t *testing.T) {
528545
mockHostLLM.EXPECT().Stream(gomock.Any(), gomock.Any()).Return(sr, nil).Times(1)
529546
mockSpecialistLLM1.EXPECT().Stream(gomock.Any(), gomock.Any()).Return(sr1, nil).Times(1)
530547

531-
mockCallback := &mockAgentCallback{}
548+
mockCallback := newMockAgentCallback(2)
532549
outStream, err := hostMA.Stream(ctx, nil, WithAgentCallbacks(mockCallback))
533550
assert.NoError(t, err)
534551

@@ -551,6 +568,7 @@ func TestHostMultiAgent(t *testing.T) {
551568
t.Errorf("Unexpected message content: %s", msg.Content)
552569
}
553570

571+
mockCallback.wg.Wait()
554572
assert.Equal(t, []*HandOffInfo{
555573
{
556574
ToAgentName: specialist1.Name,
@@ -566,9 +584,21 @@ func TestHostMultiAgent(t *testing.T) {
566584

567585
type mockAgentCallback struct {
568586
infos []*HandOffInfo
587+
wg sync.WaitGroup
569588
}
570589

571590
func (m *mockAgentCallback) OnHandOff(ctx context.Context, info *HandOffInfo) context.Context {
572591
m.infos = append(m.infos, info)
592+
m.wg.Done()
573593
return ctx
574594
}
595+
596+
func newMockAgentCallback(expects int) *mockAgentCallback {
597+
m := &mockAgentCallback{
598+
infos: make([]*HandOffInfo, 0),
599+
wg: sync.WaitGroup{},
600+
}
601+
602+
m.wg.Add(expects)
603+
return m
604+
}

flow/agent/multiagent/host/types.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,7 @@ type MultiAgentConfig struct {
9494
// Summarizer is the summarizer agent that will summarize the outputs of all the chosen specialist agents.
9595
// Only when the Host agent picks multiple Specialist will this be called.
9696
// If you do not provide a summarizer, a default summarizer that simply concatenates all the output messages into one message will be used.
97+
// Note: the default summarizer do not support streaming.
9798
Summarizer *Summarizer
9899
}
99100

0 commit comments

Comments
 (0)