Skip to content

Commit 161c4a5

Browse files
committed
fsm: add WaitForStateAsync to the cached observer
By adding WaitForStateAsync to the observer we can always observe state changes in an atomic way without relying on the observer's internal cache.
1 parent 6ac6ee0 commit 161c4a5

File tree

3 files changed

+124
-118
lines changed

3 files changed

+124
-118
lines changed

fsm/example_fsm_test.go

Lines changed: 49 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -140,71 +140,52 @@ func getTestContext() (*ExampleFSM, *CachedObserver) {
140140
return exampleContext, cachedObserver
141141
}
142142

143-
// TestExampleFSMFlow tests different flows that the example FSM can go through.
144-
func TestExampleFSMFlow(t *testing.T) {
143+
// TestObserverAsyncWait tests the observer's WaitForStateAsync function.
144+
func TestObserverAsyncWait(t *testing.T) {
145145
testCases := []struct {
146-
name string
147-
expectedStateFlow []StateType
148-
expectedEventFlow []EventType
149-
storeError error
150-
serviceError error
146+
name string
147+
waitTime time.Duration
148+
blockTime time.Duration
149+
expectTimeout bool
151150
}{
152151
{
153-
name: "success",
154-
expectedStateFlow: []StateType{
155-
InitFSM,
156-
StuffSentOut,
157-
StuffSuccess,
158-
},
159-
expectedEventFlow: []EventType{
160-
OnRequestStuff,
161-
OnStuffSentOut,
162-
OnStuffSuccess,
163-
},
164-
},
165-
{
166-
name: "failure on store",
167-
expectedStateFlow: []StateType{
168-
InitFSM,
169-
StuffFailed,
170-
},
171-
expectedEventFlow: []EventType{
172-
OnRequestStuff,
173-
OnError,
174-
},
175-
storeError: errStore,
152+
name: "success",
153+
waitTime: time.Second,
154+
blockTime: time.Millisecond,
155+
expectTimeout: false,
176156
},
177157
{
178-
name: "failure on service",
179-
expectedStateFlow: []StateType{
180-
InitFSM,
181-
StuffSentOut,
182-
StuffFailed,
183-
},
184-
expectedEventFlow: []EventType{
185-
OnRequestStuff,
186-
OnStuffSentOut,
187-
OnError,
188-
},
189-
serviceError: errService,
158+
name: "timeout",
159+
waitTime: time.Millisecond,
160+
blockTime: time.Second,
161+
expectTimeout: true,
190162
},
191163
}
192164

193165
for _, tc := range testCases {
194166
tc := tc
195167

196168
t.Run(tc.name, func(t *testing.T) {
197-
exampleContext, cachedObserver := getTestContext()
198-
199-
if tc.storeError != nil {
200-
exampleContext.store.(*mockStore).
201-
storeErr = tc.storeError
169+
service := &mockService{
170+
respondChan: make(chan bool),
202171
}
203172

204-
if tc.serviceError != nil {
205-
exampleContext.service.(*mockService).
206-
respondErr = tc.serviceError
207-
}
173+
store := &mockStore{}
174+
175+
exampleContext := NewExampleFSMContext(service, store)
176+
cachedObserver := NewCachedObserver(100)
177+
exampleContext.RegisterObserver(cachedObserver)
178+
179+
t0 := time.Now()
180+
timeoutCtx, cancel := context.WithTimeout(
181+
context.Background(), tc.waitTime,
182+
)
183+
defer cancel()
184+
185+
// Wait for the final state.
186+
errChan := cachedObserver.WaitForStateAsync(
187+
timeoutCtx, StuffSuccess, true,
188+
)
208189

209190
go func() {
210191
err := exampleContext.SendEvent(
@@ -213,32 +194,26 @@ func TestExampleFSMFlow(t *testing.T) {
213194
)
214195

215196
require.NoError(t, err)
216-
}()
217197

218-
// Wait for the final state.
219-
err := cachedObserver.WaitForState(
220-
context.Background(),
221-
time.Second,
222-
tc.expectedStateFlow[len(
223-
tc.expectedStateFlow,
224-
)-1],
225-
)
226-
require.NoError(t, err)
198+
time.Sleep(tc.blockTime)
199+
service.respondChan <- true
200+
}()
227201

228-
allNotifications := cachedObserver.
229-
GetCachedNotifications()
202+
timeout := false
203+
select {
204+
case <-timeoutCtx.Done():
205+
timeout = true
230206

231-
for index, notification := range allNotifications {
232-
require.Equal(
233-
t,
234-
tc.expectedStateFlow[index],
235-
notification.NextState,
236-
)
237-
require.Equal(
238-
t,
239-
tc.expectedEventFlow[index],
240-
notification.Event,
241-
)
207+
case <-errChan:
208+
}
209+
require.Equal(t, tc.expectTimeout, timeout)
210+
211+
t1 := time.Now()
212+
diff := t1.Sub(t0)
213+
if tc.expectTimeout {
214+
require.Less(t, diff, tc.blockTime)
215+
} else {
216+
require.Less(t, diff, tc.waitTime)
242217
}
243218
})
244219
}

fsm/fsm.go

Lines changed: 23 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -306,23 +306,36 @@ func NoOpAction(_ EventContext) EventType {
306306
}
307307

308308
// ErrConfigError is an error returned when the state machine is misconfigured.
309-
type ErrConfigError error
309+
type ErrConfigError struct {
310+
msg string
311+
}
312+
313+
// Error returns the error message.
314+
func (e ErrConfigError) Error() string {
315+
return fmt.Sprintf("config error: %s", e.msg)
316+
}
310317

311318
// NewErrConfigError creates a new ErrConfigError.
312319
func NewErrConfigError(msg string) ErrConfigError {
313-
return (ErrConfigError)(fmt.Errorf("config error: %s", msg))
320+
return ErrConfigError{
321+
msg: msg,
322+
}
314323
}
315324

316325
// ErrWaitingForStateTimeout is an error returned when the state machine times
317326
// out while waiting for a state.
318-
type ErrWaitingForStateTimeout error
327+
type ErrWaitingForStateTimeout struct {
328+
expected StateType
329+
}
319330

320-
// NewErrWaitingForStateTimeout creates a new ErrWaitingForStateTimeout.
321-
func NewErrWaitingForStateTimeout(expected,
322-
actual StateType) ErrWaitingForStateTimeout {
331+
// Error returns the error message.
332+
func (e ErrWaitingForStateTimeout) Error() string {
333+
return fmt.Sprintf("waiting for state timed out: %s", e.expected)
334+
}
323335

324-
return (ErrWaitingForStateTimeout)(fmt.Errorf(
325-
"waiting for state timeout: expected %s, actual: %s",
326-
expected, actual,
327-
))
336+
// NewErrWaitingForStateTimeout creates a new ErrWaitingForStateTimeout.
337+
func NewErrWaitingForStateTimeout(expected StateType) ErrWaitingForStateTimeout {
338+
return ErrWaitingForStateTimeout{
339+
expected: expected,
340+
}
328341
}

fsm/observer.go

Lines changed: 52 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,6 @@ func (s *CachedObserver) WaitForState(ctx context.Context,
105105
opts ...WaitForStateOption) error {
106106

107107
var options fsmOptions
108-
109108
for _, opt := range opts {
110109
opt.apply(&options)
111110
}
@@ -120,61 +119,80 @@ func (s *CachedObserver) WaitForState(ctx context.Context,
120119
}
121120
}
122121

122+
// Create a new context with a timeout.
123123
timeoutCtx, cancel := context.WithTimeout(ctx, timeout)
124124
defer cancel()
125125

126-
// Channel to notify when the desired state is reached
127-
// or an error occurred.
126+
ch := s.WaitForStateAsync(timeoutCtx, state, options.abortEarlyOnError)
127+
128+
// Wait for either the condition to be met or for a timeout.
129+
select {
130+
case <-timeoutCtx.Done():
131+
return NewErrWaitingForStateTimeout(state)
132+
133+
case err := <-ch:
134+
return err
135+
}
136+
}
137+
138+
// WaitForStateAsync waits asynchronously until the passed context is canceled
139+
// or the expected state is reached. The function returns a channel that will
140+
// receive an error if the expected state is reached or an error occurred. If
141+
// the context is canceled before the expected state is reached, the channel
142+
// will receive an ErrWaitingForStateTimeout error.
143+
func (s *CachedObserver) WaitForStateAsync(ctx context.Context, state StateType,
144+
abortOnEarlyError bool) chan error {
145+
146+
// Channel to notify when the desired state is reached or an error
147+
// occurred.
128148
ch := make(chan error)
129149

130-
// Goroutine to wait on condition variable
150+
// Wait on the notification condition variable asynchronously to avoid
151+
// blocking the caller.
131152
go func() {
132153
s.notificationMx.Lock()
133154
defer s.notificationMx.Unlock()
134155

135-
for {
136-
// Check if the last state is the desired state
137-
if s.lastNotification.NextState == state {
156+
notifyAsync := func(err error) {
157+
// To avoid potential deadlock when sending to the
158+
// channel, while still holding the notificationMx
159+
// lock, we use a goroutine to send the error to the
160+
// channel.
161+
go func() {
138162
select {
139-
case <-timeoutCtx.Done():
140-
return
163+
case <-ctx.Done():
164+
ch <- NewErrWaitingForStateTimeout(
165+
state,
166+
)
141167

142-
case ch <- nil:
143-
return
168+
case ch <- err:
144169
}
170+
}()
171+
}
172+
173+
for {
174+
// Check if the last state is the desired state.
175+
if s.lastNotification.NextState == state {
176+
notifyAsync(nil)
177+
return
145178
}
146179

147-
// Check if an error occurred
180+
// Check if an error has occurred.
148181
if s.lastNotification.Event == OnError {
149-
if options.abortEarlyOnError {
150-
select {
151-
case <-timeoutCtx.Done():
152-
return
153-
154-
case ch <- s.lastNotification.LastActionError:
155-
return
156-
}
182+
lastErr := s.lastNotification.LastActionError
183+
if abortOnEarlyError {
184+
notifyAsync(lastErr)
185+
return
157186
}
158187
}
159188

160-
// Otherwise, wait for the next notification
189+
// Otherwise use the conditonal variable to wait for
190+
// the next notification.
161191
s.notificationCond.Wait()
162192
}
163193
}()
164194

165-
// Wait for either the condition to be met or for a timeout
166-
select {
167-
case <-timeoutCtx.Done():
168-
return NewErrWaitingForStateTimeout(
169-
state, s.lastNotification.NextState,
170-
)
171-
172-
case lastActionErr := <-ch:
173-
if lastActionErr != nil {
174-
return lastActionErr
175-
}
176-
return nil
177-
}
195+
return ch
178196
}
179197

180198
// FixedSizeSlice is a slice with a fixed size.

0 commit comments

Comments
 (0)