@@ -19,6 +19,7 @@ package host
19
19
import (
20
20
"context"
21
21
"io"
22
+ "sync"
22
23
"testing"
23
24
24
25
"github.com/stretchr/testify/assert"
@@ -46,6 +47,15 @@ func TestHostMultiAgent(t *testing.T) {
46
47
},
47
48
}
48
49
50
+ specialist2Msg1 := & schema.Message {
51
+ Role : schema .Assistant ,
52
+ Content : "specialist2" ,
53
+ }
54
+ specialist2Msg2 := & schema.Message {
55
+ Role : schema .Assistant ,
56
+ Content : " stream answer" ,
57
+ }
58
+
49
59
specialist2 := & Specialist {
50
60
Invokable : func (ctx context.Context , input []* schema.Message , opts ... agent.AgentOption ) (* schema.Message , error ) {
51
61
return & schema.Message {
@@ -54,15 +64,7 @@ func TestHostMultiAgent(t *testing.T) {
54
64
}, nil
55
65
},
56
66
Streamable : func (ctx context.Context , input []* schema.Message , opts ... agent.AgentOption ) (* schema.StreamReader [* schema.Message ], error ) {
57
- sr , sw := schema.Pipe [* schema.Message ](0 )
58
- go func () {
59
- sw .Send (& schema.Message {
60
- Role : schema .Assistant ,
61
- Content : "specialist2 stream answer" ,
62
- }, nil )
63
- sw .Close ()
64
- }()
65
- return sr , nil
67
+ return schema .StreamReaderFromArray ([]* schema.Message {specialist2Msg1 , specialist2Msg2 }), nil
66
68
},
67
69
AgentMeta : AgentMeta {
68
70
Name : "specialist 2" ,
@@ -94,7 +96,7 @@ func TestHostMultiAgent(t *testing.T) {
94
96
95
97
mockHostLLM .EXPECT ().Generate (gomock .Any (), gomock .Any ()).Return (directAnswerMsg , nil ).Times (1 )
96
98
97
- mockCallback := & mockAgentCallback {}
99
+ mockCallback := newMockAgentCallback ( 0 )
98
100
99
101
out , err := hostMA .Generate (ctx , nil , WithAgentCallbacks (mockCallback ))
100
102
assert .NoError (t , err )
@@ -122,7 +124,7 @@ func TestHostMultiAgent(t *testing.T) {
122
124
123
125
mockHostLLM .EXPECT ().Stream (gomock .Any (), gomock .Any ()).Return (sr , nil ).Times (1 )
124
126
125
- mockCallback := & mockAgentCallback {}
127
+ mockCallback := newMockAgentCallback ( 0 )
126
128
outStream , err := hostMA .Stream (ctx , nil , WithAgentCallbacks (mockCallback ))
127
129
assert .NoError (t , err )
128
130
assert .Empty (t , mockCallback .infos )
@@ -139,9 +141,8 @@ func TestHostMultiAgent(t *testing.T) {
139
141
140
142
outStream .Close ()
141
143
142
- msg , err := schema .ConcatMessages (msgs )
143
- assert .NoError (t , err )
144
- assert .Equal (t , "direct answer" , msg .Content )
144
+ assert .Equal (t , directAnswerMsg1 , msgs [0 ])
145
+ assert .Equal (t , directAnswerMsg2 , msgs [1 ])
145
146
})
146
147
147
148
t .Run ("generate hand off" , func (t * testing.T ) {
@@ -166,11 +167,12 @@ func TestHostMultiAgent(t *testing.T) {
166
167
mockHostLLM .EXPECT ().Generate (gomock .Any (), gomock .Any ()).Return (handOffMsg , nil ).Times (1 )
167
168
mockSpecialistLLM1 .EXPECT ().Generate (gomock .Any (), gomock .Any ()).Return (specialistMsg , nil ).Times (1 )
168
169
169
- mockCallback := & mockAgentCallback {}
170
+ mockCallback := newMockAgentCallback ( 1 )
170
171
171
172
out , err := hostMA .Generate (ctx , nil , WithAgentCallbacks (mockCallback ))
172
173
assert .NoError (t , err )
173
174
assert .Equal (t , "specialist 1 answer" , out .Content )
175
+ mockCallback .wg .Wait ()
174
176
assert .Equal (t , []* HandOffInfo {
175
177
{
176
178
ToAgentName : specialist1 .Name ,
@@ -182,11 +184,12 @@ func TestHostMultiAgent(t *testing.T) {
182
184
handOffMsg .ToolCalls [0 ].Function .Arguments = `{"reason": "specialist 2 is even better"}`
183
185
mockHostLLM .EXPECT ().Generate (gomock .Any (), gomock .Any ()).Return (handOffMsg , nil ).Times (1 )
184
186
185
- mockCallback = & mockAgentCallback {}
187
+ mockCallback = newMockAgentCallback ( 1 )
186
188
187
189
out , err = hostMA .Generate (ctx , nil , WithAgentCallbacks (mockCallback ))
188
190
assert .NoError (t , err )
189
191
assert .Equal (t , "specialist2 invoke answer" , out .Content )
192
+ mockCallback .wg .Wait ()
190
193
assert .Equal (t , []* HandOffInfo {
191
194
{
192
195
ToAgentName : specialist2 .Name ,
@@ -297,7 +300,7 @@ func TestHostMultiAgent(t *testing.T) {
297
300
mockHostLLM .EXPECT ().Stream (gomock .Any (), gomock .Any ()).Return (sr , nil ).Times (1 )
298
301
mockSpecialistLLM1 .EXPECT ().Stream (gomock .Any (), gomock .Any ()).Return (sr1 , nil ).Times (1 )
299
302
300
- mockCallback := & mockAgentCallback {}
303
+ mockCallback := newMockAgentCallback ( 1 )
301
304
outStream , err := hostMA .Stream (ctx , nil , WithAgentCallbacks (mockCallback ))
302
305
assert .NoError (t , err )
303
306
@@ -313,9 +316,10 @@ func TestHostMultiAgent(t *testing.T) {
313
316
314
317
outStream .Close ()
315
318
316
- msg , err := schema .ConcatMessages (msgs )
317
- assert .NoError (t , err )
318
- assert .Equal (t , "specialist 1 answer" , msg .Content )
319
+ assert .Equal (t , specialistMsg1 , msgs [0 ])
320
+ assert .Equal (t , specialistMsg2 , msgs [1 ])
321
+
322
+ mockCallback .wg .Wait ()
319
323
320
324
assert .Equal (t , []* HandOffInfo {
321
325
{
@@ -337,7 +341,7 @@ func TestHostMultiAgent(t *testing.T) {
337
341
338
342
mockHostLLM .EXPECT ().Stream (gomock .Any (), gomock .Any ()).Return (sr , nil ).Times (1 )
339
343
340
- mockCallback = & mockAgentCallback {}
344
+ mockCallback = newMockAgentCallback ( 1 )
341
345
outStream , err = hostMA .Stream (ctx , nil , WithAgentCallbacks (mockCallback ))
342
346
assert .NoError (t , err )
343
347
@@ -353,9 +357,10 @@ func TestHostMultiAgent(t *testing.T) {
353
357
354
358
outStream .Close ()
355
359
356
- msg , err = schema .ConcatMessages (msgs )
357
- assert .NoError (t , err )
358
- assert .Equal (t , "specialist2 stream answer" , msg .Content )
360
+ assert .Equal (t , specialist2Msg1 , msgs [0 ])
361
+ assert .Equal (t , specialist2Msg2 , msgs [1 ])
362
+
363
+ mockCallback .wg .Wait ()
359
364
360
365
assert .Equal (t , []* HandOffInfo {
361
366
{
@@ -387,7 +392,7 @@ func TestHostMultiAgent(t *testing.T) {
387
392
mockHostLLM .EXPECT ().Generate (gomock .Any (), gomock .Any ()).Return (handOffMsg , nil ).Times (1 )
388
393
mockSpecialistLLM1 .EXPECT ().Generate (gomock .Any (), gomock .Any ()).Return (specialistMsg , nil ).Times (1 )
389
394
390
- mockCallback := & mockAgentCallback {}
395
+ mockCallback := newMockAgentCallback ( 1 )
391
396
392
397
hostMA , err := NewMultiAgent (ctx , & MultiAgentConfig {
393
398
Host : Host {
@@ -412,6 +417,8 @@ func TestHostMultiAgent(t *testing.T) {
412
417
out , err := fullGraph .Invoke (ctx , map [string ]any {"country_name" : "China" }, compose .WithCallbacks (ConvertCallbackHandlers (mockCallback )).DesignateNodeWithPath (compose .NewNodePath ("host_ma_node" , hostMA .HostNodeKey ())))
413
418
assert .NoError (t , err )
414
419
assert .Equal (t , "Beijing" , out .Content )
420
+
421
+ mockCallback .wg .Wait ()
415
422
assert .Equal (t , []* HandOffInfo {
416
423
{
417
424
ToAgentName : specialist1 .Name ,
@@ -458,20 +465,32 @@ func TestHostMultiAgent(t *testing.T) {
458
465
Index : generic .PtrOf (1 ),
459
466
Function : schema.FunctionCall {
460
467
Name : specialist2 .Name ,
461
- Arguments : `{"reason": "specialist 2 is also good"} ` ,
468
+ Arguments : `{"reason": "specialist 2` ,
462
469
},
463
470
},
464
471
},
465
472
}
466
473
467
- sr , sw := schema.Pipe [* schema.Message ](0 )
468
- go func () {
469
- sw .Send (handOffMsg1 , nil )
470
- sw .Send (handOffMsg2 , nil )
471
- sw .Send (handOffMsg3 , nil )
472
- sw .Send (handOffMsg4 , nil )
473
- sw .Close ()
474
- }()
474
+ handOffMsg5 := & schema.Message {
475
+ Role : schema .Assistant ,
476
+ ToolCalls : []schema.ToolCall {
477
+ {
478
+ Index : generic .PtrOf (1 ),
479
+ Function : schema.FunctionCall {
480
+ Name : specialist2 .Name ,
481
+ Arguments : ` is also good"}` ,
482
+ },
483
+ },
484
+ },
485
+ }
486
+
487
+ sr := schema .StreamReaderFromArray ([]* schema.Message {
488
+ handOffMsg1 ,
489
+ handOffMsg2 ,
490
+ handOffMsg3 ,
491
+ handOffMsg4 ,
492
+ handOffMsg5 ,
493
+ })
475
494
476
495
specialist1Msg1 := & schema.Message {
477
496
Role : schema .Assistant ,
@@ -483,12 +502,10 @@ func TestHostMultiAgent(t *testing.T) {
483
502
Content : "1 answer" ,
484
503
}
485
504
486
- sr1 , sw1 := schema.Pipe [* schema.Message ](0 )
487
- go func () {
488
- sw1 .Send (specialist1Msg1 , nil )
489
- sw1 .Send (specialist1Msg2 , nil )
490
- sw1 .Close ()
491
- }()
505
+ sr1 := schema .StreamReaderFromArray ([]* schema.Message {
506
+ specialist1Msg1 ,
507
+ specialist1Msg2 ,
508
+ })
492
509
493
510
streamToolCallChecker := func (ctx context.Context , modelOutput * schema.StreamReader [* schema.Message ]) (bool , error ) {
494
511
defer modelOutput .Close ()
@@ -528,7 +545,7 @@ func TestHostMultiAgent(t *testing.T) {
528
545
mockHostLLM .EXPECT ().Stream (gomock .Any (), gomock .Any ()).Return (sr , nil ).Times (1 )
529
546
mockSpecialistLLM1 .EXPECT ().Stream (gomock .Any (), gomock .Any ()).Return (sr1 , nil ).Times (1 )
530
547
531
- mockCallback := & mockAgentCallback {}
548
+ mockCallback := newMockAgentCallback ( 2 )
532
549
outStream , err := hostMA .Stream (ctx , nil , WithAgentCallbacks (mockCallback ))
533
550
assert .NoError (t , err )
534
551
@@ -551,6 +568,7 @@ func TestHostMultiAgent(t *testing.T) {
551
568
t .Errorf ("Unexpected message content: %s" , msg .Content )
552
569
}
553
570
571
+ mockCallback .wg .Wait ()
554
572
assert .Equal (t , []* HandOffInfo {
555
573
{
556
574
ToAgentName : specialist1 .Name ,
@@ -566,9 +584,21 @@ func TestHostMultiAgent(t *testing.T) {
566
584
567
585
type mockAgentCallback struct {
568
586
infos []* HandOffInfo
587
+ wg sync.WaitGroup
569
588
}
570
589
571
590
func (m * mockAgentCallback ) OnHandOff (ctx context.Context , info * HandOffInfo ) context.Context {
572
591
m .infos = append (m .infos , info )
592
+ m .wg .Done ()
573
593
return ctx
574
594
}
595
+
596
+ func newMockAgentCallback (expects int ) * mockAgentCallback {
597
+ m := & mockAgentCallback {
598
+ infos : make ([]* HandOffInfo , 0 ),
599
+ wg : sync.WaitGroup {},
600
+ }
601
+
602
+ m .wg .Add (expects )
603
+ return m
604
+ }
0 commit comments