@@ -2,6 +2,7 @@ package flow
22
33import (
44 "fmt"
5+ "sync"
56
67 "github.com/reugn/go-streams"
78)
@@ -30,19 +31,24 @@ var _ streams.Flow = (*FlatMap[any, any])(nil)
3031// T specifies the incoming element type, and the outgoing element type is []R.
3132//
3233// flatMapFunction is the FlatMap transformation function.
33- // parallelism is the flow parallelism factor. In case the events order matters, use parallelism = 1.
34- // If the parallelism argument is not positive, NewFlatMap will panic.
34+ // parallelism specifies the number of goroutines to use for parallel processing. If
35+ // the order of elements in the output stream must be preserved, set parallelism to 1.
36+ //
37+ // NewFlatMap will panic if parallelism is less than 1.
3538func NewFlatMap [T , R any ](flatMapFunction FlatMapFunction [T , R ], parallelism int ) * FlatMap [T , R ] {
3639 if parallelism < 1 {
3740 panic (fmt .Sprintf ("nonpositive FlatMap parallelism: %d" , parallelism ))
3841 }
42+
3943 flatMap := & FlatMap [T , R ]{
4044 flatMapFunction : flatMapFunction ,
4145 in : make (chan any ),
4246 out : make (chan any ),
4347 parallelism : parallelism ,
4448 }
45- go flatMap .doStream ()
49+
50+ // start processing stream elements
51+ go flatMap .stream ()
4652
4753 return flatMap
4854}
@@ -77,20 +83,27 @@ func (fm *FlatMap[T, R]) transmit(inlet streams.Inlet) {
7783 close (inlet .In ())
7884}
7985
80- func (fm * FlatMap [T , R ]) doStream () {
81- sem := make (chan struct {}, fm .parallelism )
82- for elem := range fm .in {
83- sem <- struct {}{}
84- go func (element T ) {
85- defer func () { <- sem }()
86- result := fm .flatMapFunction (element )
87- for _ , item := range result {
88- fm .out <- item
89- }
90- }(elem .(T ))
91- }
86+ // stream reads elements from the input channel, applies the flatMapFunction
87+ // to each element, and sends the resulting elements to the output channel.
88+ // It uses a pool of goroutines to process elements in parallel.
89+ func (fm * FlatMap [T , R ]) stream () {
90+ var wg sync.WaitGroup
91+ // create a pool of worker goroutines
9292 for i := 0 ; i < fm .parallelism ; i ++ {
93- sem <- struct {}{}
93+ wg .Add (1 )
94+ go func () {
95+ defer wg .Done ()
96+ for element := range fm .in {
97+ result := fm .flatMapFunction (element .(T ))
98+ for _ , item := range result {
99+ fm .out <- item
100+ }
101+ }
102+ }()
94103 }
104+
105+ // wait for worker goroutines to finish processing inbound elements
106+ wg .Wait ()
107+ // close the output channel
95108 close (fm .out )
96109}
0 commit comments