
Commit 062262b

refactor(flow): rework parallel processing of stream elements (#178)
1 parent e4d8dcf commit 062262b

6 files changed: 109 additions and 53 deletions

flow/filter.go

Lines changed: 31 additions & 17 deletions
@@ -2,6 +2,7 @@ package flow
 
 import (
     "fmt"
+    "sync"
 
     "github.com/reugn/go-streams"
 )
@@ -31,20 +32,27 @@ var _ streams.Flow = (*Filter[any])(nil)
 // NewFilter returns a new Filter operator.
 // T specifies the incoming and the outgoing element type.
 //
-// filterPredicate is the boolean-valued filter function.
-// parallelism is the flow parallelism factor. In case the events order matters, use parallelism = 1.
-// If the parallelism argument is not positive, NewFilter will panic.
+// filterPredicate is a function that accepts an element of type T and returns true
+// if the element should be included in the output stream, and false if it should be
+// filtered out.
+// parallelism specifies the number of goroutines to use for parallel processing. If
+// the order of elements in the output stream must be preserved, set parallelism to 1.
+//
+// NewFilter will panic if parallelism is less than 1.
 func NewFilter[T any](filterPredicate FilterPredicate[T], parallelism int) *Filter[T] {
     if parallelism < 1 {
         panic(fmt.Sprintf("nonpositive Filter parallelism: %d", parallelism))
     }
+
     filter := &Filter[T]{
         filterPredicate: filterPredicate,
         in:              make(chan any),
         out:             make(chan any),
         parallelism:     parallelism,
     }
-    go filter.doStream()
+
+    // start processing stream elements
+    go filter.stream()
 
     return filter
 }
@@ -79,20 +87,26 @@ func (f *Filter[T]) transmit(inlet streams.Inlet) {
     close(inlet.In())
 }
 
-// doStream discards items that don't match the filter predicate.
-func (f *Filter[T]) doStream() {
-    sem := make(chan struct{}, f.parallelism)
-    for elem := range f.in {
-        sem <- struct{}{}
-        go func(element T) {
-            defer func() { <-sem }()
-            if f.filterPredicate(element) {
-                f.out <- element
-            }
-        }(elem.(T))
-    }
+// stream reads elements from the input channel, filters them using the
+// filterPredicate, and sends the filtered elements to the output channel.
+// It uses a pool of goroutines to process elements in parallel.
+func (f *Filter[T]) stream() {
+    var wg sync.WaitGroup
+    // create a pool of worker goroutines
     for i := 0; i < f.parallelism; i++ {
-        sem <- struct{}{}
+        wg.Add(1)
+        go func() {
+            defer wg.Done()
+            for element := range f.in {
+                if f.filterPredicate(element.(T)) {
+                    f.out <- element
+                }
+            }
+        }()
     }
+
+    // wait for worker goroutines to finish processing inbound elements
+    wg.Wait()
+    // close the output channel
     close(f.out)
 }
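
The old doStream bounded concurrency with a buffered semaphore channel and spawned one goroutine per element; the reworked stream keeps a fixed pool of workers that all range over the shared input channel and closes the output only after every worker has returned. Below is a minimal standalone sketch of the same worker-pool pattern; filterStage, pred, and the example values are illustrative names, not part of the repository, and output order is not guaranteed when parallelism is greater than 1.

package main

import (
    "fmt"
    "sync"
)

// filterStage drains in with a fixed pool of worker goroutines, forwarding only
// the elements for which pred returns true. The output channel is closed exactly
// once, after all workers have finished.
func filterStage[T any](in <-chan T, pred func(T) bool, parallelism int) <-chan T {
    out := make(chan T)
    var wg sync.WaitGroup
    for i := 0; i < parallelism; i++ {
        wg.Add(1)
        go func() {
            defer wg.Done()
            for element := range in {
                if pred(element) {
                    out <- element
                }
            }
        }()
    }
    go func() {
        wg.Wait()
        close(out)
    }()
    return out
}

func main() {
    in := make(chan int)
    go func() {
        for i := 1; i <= 10; i++ {
            in <- i
        }
        close(in)
    }()
    even := func(n int) bool { return n%2 == 0 }
    for n := range filterStage[int](in, even, 4) {
        fmt.Println(n) // even numbers; order may vary with parallelism > 1
    }
}

In the library, stream() itself already runs in its own goroutine, so it can call wg.Wait() and close(f.out) directly; the sketch returns immediately and therefore moves that step into a separate closer goroutine.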

flow/flat_map.go

Lines changed: 29 additions & 16 deletions
@@ -2,6 +2,7 @@ package flow
 
 import (
     "fmt"
+    "sync"
 
     "github.com/reugn/go-streams"
 )
@@ -30,19 +31,24 @@ var _ streams.Flow = (*FlatMap[any, any])(nil)
 // T specifies the incoming element type, and the outgoing element type is []R.
 //
 // flatMapFunction is the FlatMap transformation function.
-// parallelism is the flow parallelism factor. In case the events order matters, use parallelism = 1.
-// If the parallelism argument is not positive, NewFlatMap will panic.
+// parallelism specifies the number of goroutines to use for parallel processing. If
+// the order of elements in the output stream must be preserved, set parallelism to 1.
+//
+// NewFlatMap will panic if parallelism is less than 1.
 func NewFlatMap[T, R any](flatMapFunction FlatMapFunction[T, R], parallelism int) *FlatMap[T, R] {
     if parallelism < 1 {
         panic(fmt.Sprintf("nonpositive FlatMap parallelism: %d", parallelism))
     }
+
     flatMap := &FlatMap[T, R]{
         flatMapFunction: flatMapFunction,
         in:              make(chan any),
         out:             make(chan any),
         parallelism:     parallelism,
     }
-    go flatMap.doStream()
+
+    // start processing stream elements
+    go flatMap.stream()
 
     return flatMap
 }
@@ -77,20 +83,27 @@ func (fm *FlatMap[T, R]) transmit(inlet streams.Inlet) {
     close(inlet.In())
 }
 
-func (fm *FlatMap[T, R]) doStream() {
-    sem := make(chan struct{}, fm.parallelism)
-    for elem := range fm.in {
-        sem <- struct{}{}
-        go func(element T) {
-            defer func() { <-sem }()
-            result := fm.flatMapFunction(element)
-            for _, item := range result {
-                fm.out <- item
-            }
-        }(elem.(T))
-    }
+// stream reads elements from the input channel, applies the flatMapFunction
+// to each element, and sends the resulting elements to the output channel.
+// It uses a pool of goroutines to process elements in parallel.
+func (fm *FlatMap[T, R]) stream() {
+    var wg sync.WaitGroup
+    // create a pool of worker goroutines
     for i := 0; i < fm.parallelism; i++ {
-        sem <- struct{}{}
+        wg.Add(1)
+        go func() {
+            defer wg.Done()
+            for element := range fm.in {
+                result := fm.flatMapFunction(element.(T))
+                for _, item := range result {
+                    fm.out <- item
+                }
+            }
+        }()
     }
+
+    // wait for worker goroutines to finish processing inbound elements
+    wg.Wait()
+    // close the output channel
    close(fm.out)
 }
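
FlatMap gets the same pool structure, with the difference that a worker may emit zero or more outputs per input element. A self-contained sketch of that one-to-many emission under the same assumptions (flatMapStage and the word-splitting example are illustrative, not part of the repository):

package main

import (
    "fmt"
    "strings"
    "sync"
)

// flatMapStage applies fn to every element read from in and forwards each item
// of the resulting slice downstream, using a fixed pool of worker goroutines.
func flatMapStage[T, R any](in <-chan T, fn func(T) []R, parallelism int) <-chan R {
    out := make(chan R)
    var wg sync.WaitGroup
    for i := 0; i < parallelism; i++ {
        wg.Add(1)
        go func() {
            defer wg.Done()
            for element := range in {
                for _, item := range fn(element) {
                    out <- item
                }
            }
        }()
    }
    go func() {
        wg.Wait()
        close(out)
    }()
    return out
}

func main() {
    in := make(chan string)
    go func() {
        in <- "rework parallel processing"
        in <- "of stream elements"
        close(in)
    }()
    for word := range flatMapStage[string, string](in, strings.Fields, 2) {
        fmt.Println(word)
    }
}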

flow/fold.go

Lines changed: 9 additions & 2 deletions
@@ -38,7 +38,9 @@ func NewFold[T, R any](init R, foldFunction FoldFunction[T, R]) *Fold[T, R] {
         in:           make(chan any),
         out:          make(chan any),
     }
-    go foldFlow.doStream()
+
+    // start processing stream elements
+    go foldFlow.stream()
 
     return foldFlow
 }
@@ -73,7 +75,12 @@ func (m *Fold[T, R]) transmit(inlet streams.Inlet) {
     close(inlet.In())
 }
 
-func (m *Fold[T, R]) doStream() {
+// stream consumes elements from the input channel, applies the foldFunction to
+// each element along with the previously accumulated value, and emits the updated
+// value into the output channel. All input elements are assumed to be of type T.
+// The processing is done sequentially, ensuring that the order of accumulation is
+// maintained.
+func (m *Fold[T, R]) stream() {
     lastFolded := m.init
     for element := range m.in {
         lastFolded = m.foldFunction(element.(T), lastFolded)
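
Fold, in contrast, stays on a single goroutine: each step depends on the previously accumulated value, so there is nothing to parallelize without serializing on the accumulator anyway. A small sketch of the sequential behaviour described by the new doc comment (foldStage is an illustrative name, not part of the repository):

package main

import "fmt"

// foldStage applies f to each incoming element together with the previously
// accumulated value and emits the updated value after every element, so the
// accumulation order matches the input order.
func foldStage[T, R any](in <-chan T, init R, f func(T, R) R) <-chan R {
    out := make(chan R)
    go func() {
        last := init
        for element := range in {
            last = f(element, last)
            out <- last
        }
        close(out)
    }()
    return out
}

func main() {
    in := make(chan int)
    go func() {
        for _, n := range []int{1, 2, 3, 4} {
            in <- n
        }
        close(in)
    }()
    sum := func(n, acc int) int { return acc + n }
    for v := range foldStage[int, int](in, 0, sum) {
        fmt.Println(v) // 1 3 6 10
    }
}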

flow/map.go

Lines changed: 26 additions & 14 deletions
@@ -2,6 +2,7 @@ package flow
 
 import (
     "fmt"
+    "sync"
 
     "github.com/reugn/go-streams"
 )
@@ -30,19 +31,24 @@ var _ streams.Flow = (*Map[any, any])(nil)
 // T specifies the incoming element type, and the outgoing element type is R.
 //
 // mapFunction is the Map transformation function.
-// parallelism is the flow parallelism factor. In case the events order matters, use parallelism = 1.
-// If the parallelism argument is not positive, NewMap will panic.
+// parallelism specifies the number of goroutines to use for parallel processing. If
+// the order of elements in the output stream must be preserved, set parallelism to 1.
+//
+// NewMap will panic if parallelism is less than 1.
 func NewMap[T, R any](mapFunction MapFunction[T, R], parallelism int) *Map[T, R] {
     if parallelism < 1 {
         panic(fmt.Sprintf("nonpositive Map parallelism: %d", parallelism))
     }
+
     mapFlow := &Map[T, R]{
         mapFunction: mapFunction,
         in:          make(chan any),
         out:         make(chan any),
         parallelism: parallelism,
     }
-    go mapFlow.doStream()
+
+    // start processing stream elements
+    go mapFlow.stream()
 
     return mapFlow
 }
@@ -77,18 +83,24 @@ func (m *Map[T, R]) transmit(inlet streams.Inlet) {
     close(inlet.In())
 }
 
-func (m *Map[T, R]) doStream() {
-    sem := make(chan struct{}, m.parallelism)
-    for elem := range m.in {
-        sem <- struct{}{}
-        go func(element T) {
-            defer func() { <-sem }()
-            result := m.mapFunction(element)
-            m.out <- result
-        }(elem.(T))
-    }
+// stream reads elements from the input channel, applies the mapFunction
+// to each element, and sends the resulting element to the output channel.
+// It uses a pool of goroutines to process elements in parallel.
+func (m *Map[T, R]) stream() {
+    var wg sync.WaitGroup
+    // create a pool of worker goroutines
     for i := 0; i < m.parallelism; i++ {
-        sem <- struct{}{}
+        wg.Add(1)
+        go func() {
+            defer wg.Done()
+            for element := range m.in {
+                m.out <- m.mapFunction(element.(T))
+            }
+        }()
     }
+
+    // wait for worker goroutines to finish processing inbound elements
+    wg.Wait()
+    // close the output channel
     close(m.out)
 }
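
Map follows the same worker-pool layout as Filter and FlatMap. The sketch below illustrates the practical effect of the parallelism argument mentioned in the doc comment: with one worker the input order is preserved, with several it generally is not (mapStage and feed are illustrative names, not part of the repository).

package main

import (
    "fmt"
    "sync"
)

// mapStage applies fn to each element from in using a pool of parallelism workers.
func mapStage[T, R any](in <-chan T, fn func(T) R, parallelism int) <-chan R {
    out := make(chan R)
    var wg sync.WaitGroup
    for i := 0; i < parallelism; i++ {
        wg.Add(1)
        go func() {
            defer wg.Done()
            for element := range in {
                out <- fn(element)
            }
        }()
    }
    go func() {
        wg.Wait()
        close(out)
    }()
    return out
}

// feed produces the integers 1..n on a fresh channel.
func feed(n int) chan int {
    ch := make(chan int)
    go func() {
        for i := 1; i <= n; i++ {
            ch <- i
        }
        close(ch)
    }()
    return ch
}

func main() {
    square := func(n int) int { return n * n }
    for v := range mapStage[int, int](feed(5), square, 1) {
        fmt.Print(v, " ") // 1 4 9 16 25, in input order
    }
    fmt.Println()
    for v := range mapStage[int, int](feed(5), square, 4) {
        fmt.Print(v, " ") // same values, but order is not guaranteed
    }
    fmt.Println()
}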

flow/pass_through.go

Lines changed: 4 additions & 2 deletions
@@ -23,7 +23,9 @@ func NewPassThrough() *PassThrough {
         in:  make(chan any),
         out: make(chan any),
     }
-    go passThrough.doStream()
+
+    // start processing stream elements
+    go passThrough.stream()
 
     return passThrough
 }
@@ -58,7 +60,7 @@ func (pt *PassThrough) transmit(inlet streams.Inlet) {
     close(inlet.In())
 }
 
-func (pt *PassThrough) doStream() {
+func (pt *PassThrough) stream() {
     for element := range pt.in {
         pt.out <- element
     }

flow/reduce.go

Lines changed: 10 additions & 2 deletions
@@ -34,7 +34,9 @@ func NewReduce[T any](reduceFunction ReduceFunction[T]) *Reduce[T] {
         in:             make(chan any),
         out:            make(chan any),
     }
-    go reduce.doStream()
+
+    // start processing stream elements
+    go reduce.stream()
 
     return reduce
 }
@@ -69,7 +71,13 @@ func (r *Reduce[T]) transmit(inlet streams.Inlet) {
     close(inlet.In())
 }
 
-func (r *Reduce[T]) doStream() {
+// stream consumes elements from the input channel, applies the reduceFunction to
+// each element along with the previously reduced value, and emits the updated
+// value into the output channel. The first element received becomes the initial
+// reduced value. Subsequent elements are combined with the accumulated result.
+// All input elements are assumed to be of type T. The processing is done sequentially,
+// ensuring that the order of accumulation is maintained.
+func (r *Reduce[T]) stream() {
     var lastReduced any
     for element := range r.in {
         if lastReduced == nil {
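
Reduce is likewise sequential; unlike Fold there is no explicit initial value, so the first element received seeds the accumulator and every later element is combined with the running result. An illustrative sketch of that behaviour (reduceStage is not part of the repository):

package main

import "fmt"

// reduceStage combines incoming elements with f, seeding the accumulator with
// the first element and emitting the running result after every element.
func reduceStage[T any](in <-chan T, f func(T, T) T) <-chan T {
    out := make(chan T)
    go func() {
        var acc T
        seeded := false
        for element := range in {
            if !seeded {
                acc, seeded = element, true
            } else {
                acc = f(element, acc)
            }
            out <- acc
        }
        close(out)
    }()
    return out
}

func main() {
    in := make(chan int)
    go func() {
        for _, n := range []int{3, 7, 2, 9} {
            in <- n
        }
        close(in)
    }()
    maxOf := func(a, b int) int {
        if a > b {
            return a
        }
        return b
    }
    for v := range reduceStage[int](in, maxOf) {
        fmt.Println(v) // 3 7 7 9
    }
}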
