Skip to content

Commit 3843c39

Browse files
authored
Merge pull request #741 from bhandras/fsm-observer-fixup
fsm: add WaitForStateAsync to the cached observer
2 parents 7a8c052 + 811e9df commit 3843c39

File tree

3 files changed

+159
-52
lines changed

3 files changed

+159
-52
lines changed

fsm/example_fsm_test.go

+79
Original file line numberDiff line numberDiff line change
@@ -243,3 +243,82 @@ func TestExampleFSMFlow(t *testing.T) {
243243
})
244244
}
245245
}
246+
247+
// TestObserverAsyncWait tests the observer's WaitForStateAsync function.
248+
func TestObserverAsyncWait(t *testing.T) {
249+
testCases := []struct {
250+
name string
251+
waitTime time.Duration
252+
blockTime time.Duration
253+
expectTimeout bool
254+
}{
255+
{
256+
name: "success",
257+
waitTime: time.Second,
258+
blockTime: time.Millisecond,
259+
expectTimeout: false,
260+
},
261+
{
262+
name: "timeout",
263+
waitTime: time.Millisecond,
264+
blockTime: time.Second,
265+
expectTimeout: true,
266+
},
267+
}
268+
269+
for _, tc := range testCases {
270+
tc := tc
271+
272+
t.Run(tc.name, func(t *testing.T) {
273+
service := &mockService{
274+
respondChan: make(chan bool),
275+
}
276+
277+
store := &mockStore{}
278+
279+
exampleContext := NewExampleFSMContext(service, store)
280+
cachedObserver := NewCachedObserver(100)
281+
exampleContext.RegisterObserver(cachedObserver)
282+
283+
t0 := time.Now()
284+
timeoutCtx, cancel := context.WithTimeout(
285+
context.Background(), tc.waitTime,
286+
)
287+
defer cancel()
288+
289+
// Wait for the final state.
290+
errChan := cachedObserver.WaitForStateAsync(
291+
timeoutCtx, StuffSuccess, true,
292+
)
293+
294+
go func() {
295+
err := exampleContext.SendEvent(
296+
OnRequestStuff,
297+
newInitStuffRequest(),
298+
)
299+
300+
require.NoError(t, err)
301+
302+
time.Sleep(tc.blockTime)
303+
service.respondChan <- true
304+
}()
305+
306+
timeout := false
307+
select {
308+
case <-timeoutCtx.Done():
309+
timeout = true
310+
311+
case <-errChan:
312+
}
313+
require.Equal(t, tc.expectTimeout, timeout)
314+
315+
t1 := time.Now()
316+
diff := t1.Sub(t0)
317+
if tc.expectTimeout {
318+
require.Less(t, diff, tc.blockTime)
319+
} else {
320+
require.Less(t, diff, tc.waitTime)
321+
}
322+
})
323+
}
324+
}

fsm/fsm.go

+23-10
Original file line numberDiff line numberDiff line change
@@ -306,23 +306,36 @@ func NoOpAction(_ EventContext) EventType {
306306
}
307307

308308
// ErrConfigError is an error returned when the state machine is misconfigured.
309-
type ErrConfigError error
309+
type ErrConfigError struct {
310+
msg string
311+
}
312+
313+
// Error returns the error message.
314+
func (e ErrConfigError) Error() string {
315+
return fmt.Sprintf("config error: %s", e.msg)
316+
}
310317

311318
// NewErrConfigError creates a new ErrConfigError.
312319
func NewErrConfigError(msg string) ErrConfigError {
313-
return (ErrConfigError)(fmt.Errorf("config error: %s", msg))
320+
return ErrConfigError{
321+
msg: msg,
322+
}
314323
}
315324

316325
// ErrWaitingForStateTimeout is an error returned when the state machine times
317326
// out while waiting for a state.
318-
type ErrWaitingForStateTimeout error
327+
type ErrWaitingForStateTimeout struct {
328+
expected StateType
329+
}
319330

320-
// NewErrWaitingForStateTimeout creates a new ErrWaitingForStateTimeout.
321-
func NewErrWaitingForStateTimeout(expected,
322-
actual StateType) ErrWaitingForStateTimeout {
331+
// Error returns the error message.
332+
func (e ErrWaitingForStateTimeout) Error() string {
333+
return fmt.Sprintf("waiting for state timed out: %s", e.expected)
334+
}
323335

324-
return (ErrWaitingForStateTimeout)(fmt.Errorf(
325-
"waiting for state timeout: expected %s, actual: %s",
326-
expected, actual,
327-
))
336+
// NewErrWaitingForStateTimeout creates a new ErrWaitingForStateTimeout.
337+
func NewErrWaitingForStateTimeout(expected StateType) ErrWaitingForStateTimeout {
338+
return ErrWaitingForStateTimeout{
339+
expected: expected,
340+
}
328341
}

fsm/observer.go

+57-42
Original file line numberDiff line numberDiff line change
@@ -100,12 +100,11 @@ func WithAbortEarlyOnErrorOption() WaitForStateOption {
100100
// the given duration before checking the state. This is useful if the
101101
// function is called immediately after sending an event to the state machine
102102
// and the state machine needs some time to process the event.
103-
func (s *CachedObserver) WaitForState(ctx context.Context,
103+
func (c *CachedObserver) WaitForState(ctx context.Context,
104104
timeout time.Duration, state StateType,
105105
opts ...WaitForStateOption) error {
106106

107107
var options fsmOptions
108-
109108
for _, opt := range opts {
110109
opt.apply(&options)
111110
}
@@ -120,61 +119,77 @@ func (s *CachedObserver) WaitForState(ctx context.Context,
120119
}
121120
}
122121

122+
// Create a new context with a timeout.
123123
timeoutCtx, cancel := context.WithTimeout(ctx, timeout)
124124
defer cancel()
125125

126-
// Channel to notify when the desired state is reached
127-
// or an error occurred.
128-
ch := make(chan error)
126+
ch := c.WaitForStateAsync(timeoutCtx, state, options.abortEarlyOnError)
129127

130-
// Goroutine to wait on condition variable
131-
go func() {
132-
s.notificationMx.Lock()
133-
defer s.notificationMx.Unlock()
128+
// Wait for either the condition to be met or for a timeout.
129+
select {
130+
case <-timeoutCtx.Done():
131+
return NewErrWaitingForStateTimeout(state)
134132

135-
for {
136-
// Check if the last state is the desired state
137-
if s.lastNotification.NextState == state {
138-
select {
139-
case <-timeoutCtx.Done():
140-
return
133+
case err := <-ch:
134+
return err
135+
}
136+
}
141137

142-
case ch <- nil:
143-
return
144-
}
138+
// WaitForStateAsync waits asynchronously until the passed context is canceled
139+
// or the expected state is reached. The function returns a channel that will
140+
// receive an error if the expected state is reached or an error occurred. If
141+
// the context is canceled before the expected state is reached, the channel
142+
// will receive an ErrWaitingForStateTimeout error.
143+
func (c *CachedObserver) WaitForStateAsync(ctx context.Context, state StateType,
144+
abortOnEarlyError bool) chan error {
145+
146+
// Channel to notify when the desired state is reached or an error
147+
// occurred.
148+
ch := make(chan error, 1)
149+
150+
// Wait on the notification condition variable asynchronously to avoid
151+
// blocking the caller.
152+
go func() {
153+
c.notificationMx.Lock()
154+
defer c.notificationMx.Unlock()
155+
156+
// writeResult writes the result to the channel. If the context
157+
// is canceled, an ErrWaitingForStateTimeout error is written
158+
// to the channel.
159+
writeResult := func(err error) {
160+
select {
161+
case <-ctx.Done():
162+
ch <- NewErrWaitingForStateTimeout(
163+
state,
164+
)
165+
166+
case ch <- err:
145167
}
168+
}
146169

147-
// Check if an error occurred
148-
if s.lastNotification.Event == OnError {
149-
if options.abortEarlyOnError {
150-
select {
151-
case <-timeoutCtx.Done():
152-
return
170+
for {
171+
// Check if the last state is the desired state.
172+
if c.lastNotification.NextState == state {
173+
writeResult(nil)
174+
return
175+
}
153176

154-
case ch <- s.lastNotification.LastActionError:
155-
return
156-
}
177+
// Check if an error has occurred.
178+
if c.lastNotification.Event == OnError {
179+
lastErr := c.lastNotification.LastActionError
180+
if abortOnEarlyError {
181+
writeResult(lastErr)
182+
return
157183
}
158184
}
159185

160-
// Otherwise, wait for the next notification
161-
s.notificationCond.Wait()
186+
// Otherwise use the conditional variable to wait for
187+
// the next notification.
188+
c.notificationCond.Wait()
162189
}
163190
}()
164191

165-
// Wait for either the condition to be met or for a timeout
166-
select {
167-
case <-timeoutCtx.Done():
168-
return NewErrWaitingForStateTimeout(
169-
state, s.lastNotification.NextState,
170-
)
171-
172-
case lastActionErr := <-ch:
173-
if lastActionErr != nil {
174-
return lastActionErr
175-
}
176-
return nil
177-
}
192+
return ch
178193
}
179194

180195
// FixedSizeSlice is a slice with a fixed size.

0 commit comments

Comments
 (0)