diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 9967359..6354350 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -10,15 +10,15 @@ jobs: build: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v5 - name: Set up Go - uses: actions/setup-go@v2 + uses: actions/setup-go@v5 with: - go-version: 1.16 + go-version: "1.24.x" - name: Test - run: go test -coverprofile=coverage.out ./... + run: make test - name: Convert coverage uses: jandelgado/gcov2lcov-action@v1.0.5 diff --git a/.gitignore b/.gitignore index f69a717..174d9f3 100644 --- a/.gitignore +++ b/.gitignore @@ -29,5 +29,6 @@ _testmain.go # Testing .coverprofile +coverage.out -.vscode \ No newline at end of file +.vscode diff --git a/Makefile b/Makefile index 70e816f..fe0da54 100644 --- a/Makefile +++ b/Makefile @@ -2,7 +2,7 @@ default: services test .PHONY: test test: - go test ./... + CGO_ENABLED=1 go test -benchmem -bench=. -v ./... -race -coverprofile=coverage.out -covermode=atomic && go tool cover -func=coverage.out .PHONY: lint lint: diff --git a/README.md b/README.md index 57c5511..7510773 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,5 @@ [![PkgGoDev](https://pkg.go.dev/badge/github.com/looplab/fsm)](https://pkg.go.dev/github.com/looplab/fsm) -![Bulid Status](https://github.com/looplab/fsm/actions/workflows/main.yml/badge.svg) +![Build Status](https://github.com/looplab/fsm/actions/workflows/main.yml/badge.svg) [![Coverage Status](https://img.shields.io/coveralls/looplab/fsm.svg)](https://coveralls.io/r/looplab/fsm) [![Go Report Card](https://goreportcard.com/badge/looplab/fsm)](https://goreportcard.com/report/looplab/fsm) @@ -24,17 +24,17 @@ package main import ( "fmt" - "github.com/looplab/fsm" + "github.com/looplab/fsm/v2" ) func main() { - fsm := fsm.NewFSM( + fsm := fsm.New[string, string]( "closed", - fsm.Events{ + fsm.Events[string, string]{ {Name: "open", Src: []string{"closed"}, Dst: "open"}, {Name: "close", Src: []string{"open"}, Dst: "closed"}, }, - fsm.Callbacks{}, + fsm.Callbacks[string, string]{}, ) fmt.Println(fsm.Current()) @@ -64,7 +64,7 @@ package main import ( "fmt" - "github.com/looplab/fsm" + "github.com/looplab/fsm/v2" ) type Door struct { @@ -77,14 +77,19 @@ func NewDoor(to string) *Door { To: to, } - d.FSM = fsm.NewFSM( + d.FSM = fsm.New[string, string]( "closed", - fsm.Events{ + fsm.Events[string, string]{ {Name: "open", Src: []string{"closed"}, Dst: "open"}, {Name: "close", Src: []string{"open"}, Dst: "closed"}, }, - fsm.Callbacks{ - "enter_state": func(e *fsm.Event) { d.enterState(e) }, + fsm.Callbacks[string, string]{ + fsm.Callback[string, string]{ + When: fsm.AfterAllStates, + F: func(cr *fsm.CallbackContext[MyEvent, MyState]) { + d.enterState(e) + }, + }, }, ) diff --git a/callback.go b/callback.go new file mode 100644 index 0000000..a3820dc --- /dev/null +++ b/callback.go @@ -0,0 +1,150 @@ +// Copyright (c) 2013 - Max Persson +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package fsm + +import ( + "cmp" + "fmt" +) + +// CallbackType defines at which type of Event this callback should be called. +type CallbackType string + +const ( + // BeforeEvent called before event E + BeforeEvent = CallbackType("before_event") + // BeforeAllEvents called before all events + BeforeAllEvents = CallbackType("before_all_events") + // AfterEvent called after event E + AfterEvent = CallbackType("after_event") + // AfterAllEvents called after all events + AfterAllEvents = CallbackType("after_all_events") + // EnterState called after entering state S + EnterState = CallbackType("enter_state") + // EnterAllStates called after entering all states + EnterAllStates = CallbackType("enter_all_states") + // LeaveState is called before leaving state S. + LeaveState = CallbackType("leave_state") + // LeaveAllStates is called before leaving all states. + LeaveAllStates = CallbackType("leave_all_states") +) + +// Callback defines a condition when the callback function F should be called in certain conditions. +// The order of execution for CallbackTypes in the same event or state is: +// The concrete CallbackType has precedence over a general one, e.g. +// BeforeEvent E will be fired before BeforeAllEvents. +type Callback[E cmp.Ordered, S cmp.Ordered] struct { + // When should the callback be called. + When CallbackType + // Event is the event that the callback should be called for. Only relevant for BeforeEvent and AfterEvent. + Event E + // State is the state that the callback should be called for. Only relevant for EnterState and LeaveState. + State S + // F is the callback function. + F func(*CallbackContext[E, S]) +} + +// Callbacks is a shorthand for defining the callbacks in New. +type Callbacks[E cmp.Ordered, S cmp.Ordered] []Callback[E, S] + +// CallbackContext is the info that get passed as a reference in the callbacks. +type CallbackContext[E cmp.Ordered, S cmp.Ordered] struct { + // FSM is an reference to the current FSM. + FSM *FSM[E, S] + // Event is the event name. + Event E + // Src is the state before the transition. + Src S + // Dst is the state after the transition. + Dst S + // Err is an optional error that can be returned from a callback. + Err error + // Args is an optional list of arguments passed to the callback. + Args []any + // canceled is an internal flag set if the transition is canceled. + canceled bool + // async is an internal flag set if the transition should be asynchronous + async bool +} + +// Cancel can be called in before_ or leave_ to cancel the +// current transition before it happens. It takes an optional error, which will +// overwrite e.Err if set before. +func (ctx *CallbackContext[E, S]) Cancel(err ...error) { + ctx.canceled = true + + if len(err) > 0 { + ctx.Err = err[0] + } +} + +// Async can be called in leave_ to do an asynchronous state transition. +// +// The current state transition will be on hold in the old state until a final +// call to Transition is made. This will complete the transition and possibly +// call the other callbacks. +func (ctx *CallbackContext[E, S]) Async() { + ctx.async = true +} +func (cs Callbacks[E, S]) validate() error { + for i := range cs { + cb := cs[i] + err := cb.validate() + if err != nil { + return err + } + } + return nil +} + +func (c *Callback[E, S]) validate() error { + var ( + zeroEvent E + zeroState S + ) + switch c.When { + case BeforeEvent, AfterEvent: + if c.Event == zeroEvent { + return fmt.Errorf("%v given but no event", c.When) + } + if c.State != zeroState { + return fmt.Errorf("%v given but state %v specified", c.When, c.State) + } + case BeforeAllEvents, AfterAllEvents: + if c.Event != zeroEvent { + return fmt.Errorf("%v given with event %v", c.When, c.Event) + } + if c.State != zeroState { + return fmt.Errorf("%v given with state %v", c.When, c.State) + } + case EnterState, LeaveState: + if c.State == zeroState { + return fmt.Errorf("%v given but no state", c.When) + } + if c.Event != zeroEvent { + return fmt.Errorf("%v given but event %v specified", c.When, c.Event) + } + case EnterAllStates, LeaveAllStates: + if c.State != zeroState { + return fmt.Errorf("%v given with state %v", c.When, c.State) + } + if c.Event != zeroEvent { + return fmt.Errorf("%v given with event %v", c.When, c.Event) + } + default: + return fmt.Errorf("invalid callback:%v", c) + } + return nil +} diff --git a/callback_test.go b/callback_test.go new file mode 100644 index 0000000..8f040ac --- /dev/null +++ b/callback_test.go @@ -0,0 +1,62 @@ +package fsm + +import "testing" + +func TestCallbackValidate(t *testing.T) { + tests := []struct { + name string + cb Callback[string, string] + errString string + }{ + { + name: "before_event without event", + cb: Callback[string, string]{When: BeforeEvent}, + errString: "before_event given but no event", + }, + { + name: "before_event with state", + cb: Callback[string, string]{When: BeforeEvent, Event: "open", State: "closed"}, + errString: "before_event given but state closed specified", + }, + { + name: "before_event with state", + cb: Callback[string, string]{When: BeforeAllEvents, Event: "open"}, + errString: "before_all_events given with event open", + }, + + { + name: "before_event without event", + cb: Callback[string, string]{When: EnterState}, + errString: "enter_state given but no state", + }, + { + name: "before_event with state", + cb: Callback[string, string]{When: EnterState, Event: "open", State: "closed"}, + errString: "enter_state given but event open specified", + }, + { + name: "before_event with state", + cb: Callback[string, string]{When: EnterAllStates, State: "closed"}, + errString: "enter_all_states given with state closed", + }, + } + + for i := range tests { + tt := tests[i] + t.Run(tt.name, func(t *testing.T) { + err := tt.cb.validate() + + if tt.errString == "" && err != nil { + t.Errorf("err:%v", err) + } + if tt.errString != "" && err == nil { + t.Errorf("errstring:%s but err is nil", tt.errString) + } + + if tt.errString != "" && err.Error() != tt.errString { + t.Errorf("transition failed %v", err) + } + }) + } + +} diff --git a/errors.go b/errors.go index 9c32a49..e572e9d 100644 --- a/errors.go +++ b/errors.go @@ -14,15 +14,20 @@ package fsm +import ( + "cmp" + "fmt" +) + // InvalidEventError is returned by FSM.Event() when the event cannot be called // in the current state. -type InvalidEventError struct { - Event string - State string +type InvalidEventError[E cmp.Ordered, S cmp.Ordered] struct { + Event E + State S } -func (e InvalidEventError) Error() string { - return "event " + e.Event + " inappropriate in current state " + e.State +func (e InvalidEventError[E, S]) Error() string { + return fmt.Sprintf("event %v inappropriate in current state %v", e.Event, e.State) } // UnknownEventError is returned by FSM.Event() when the event is not defined. @@ -31,7 +36,7 @@ type UnknownEventError struct { } func (e UnknownEventError) Error() string { - return "event " + e.Event + " does not exist" + return fmt.Sprintf("event %s does not exist", e.Event) } // InTransitionError is returned by FSM.Event() when an asynchronous transition @@ -41,7 +46,7 @@ type InTransitionError struct { } func (e InTransitionError) Error() string { - return "event " + e.Event + " inappropriate because previous transition did not complete" + return fmt.Sprintf("event %s inappropriate because previous transition did not complete", e.Event) } // NotInTransitionError is returned by FSM.Transition() when an asynchronous @@ -60,7 +65,7 @@ type NoTransitionError struct { func (e NoTransitionError) Error() string { if e.Err != nil { - return "no transition with error: " + e.Err.Error() + return fmt.Sprintf("no transition with error: %s", e.Err.Error()) } return "no transition" } @@ -73,7 +78,7 @@ type CanceledError struct { func (e CanceledError) Error() string { if e.Err != nil { - return "transition canceled with error: " + e.Err.Error() + return fmt.Sprintf("transition canceled with error: %s", e.Err.Error()) } return "transition canceled" } @@ -86,7 +91,7 @@ type AsyncError struct { func (e AsyncError) Error() string { if e.Err != nil { - return "async started with error: " + e.Err.Error() + return fmt.Sprintf("async started with error: %s", e.Err.Error()) } return "async started" } diff --git a/errors_test.go b/errors_test.go index ba384ee..e53259c 100644 --- a/errors_test.go +++ b/errors_test.go @@ -22,7 +22,7 @@ import ( func TestInvalidEventError(t *testing.T) { event := "invalid event" state := "state" - e := InvalidEventError{Event: event, State: state} + e := InvalidEventError[string, string]{Event: event, State: state} if e.Error() != "event "+e.Event+" inappropriate in current state "+e.State { t.Error("InvalidEventError string mismatch") } diff --git a/event.go b/event.go deleted file mode 100644 index 6707198..0000000 --- a/event.go +++ /dev/null @@ -1,62 +0,0 @@ -// Copyright (c) 2013 - Max Persson -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package fsm - -// Event is the info that get passed as a reference in the callbacks. -type Event struct { - // FSM is an reference to the current FSM. - FSM *FSM - - // Event is the event name. - Event string - - // Src is the state before the transition. - Src string - - // Dst is the state after the transition. - Dst string - - // Err is an optional error that can be returned from a callback. - Err error - - // Args is an optional list of arguments passed to the callback. - Args []interface{} - - // canceled is an internal flag set if the transition is canceled. - canceled bool - - // async is an internal flag set if the transition should be asynchronous - async bool -} - -// Cancel can be called in before_ or leave_ to cancel the -// current transition before it happens. It takes an optional error, which will -// overwrite e.Err if set before. -func (e *Event) Cancel(err ...error) { - e.canceled = true - - if len(err) > 0 { - e.Err = err[0] - } -} - -// Async can be called in leave_ to do an asynchronous state transition. -// -// The current state transition will be on hold in the old state until a final -// call to Transition is made. This will complete the transition and possibly -// call the other callbacks. -func (e *Event) Async() { - e.async = true -} diff --git a/examples/alternate.go b/examples/alternate.go index 8de3fe7..a4282e9 100644 --- a/examples/alternate.go +++ b/examples/alternate.go @@ -1,66 +1,78 @@ +//go:build ignore // +build ignore package main import ( "fmt" - "github.com/looplab/fsm" + + "github.com/looplab/fsm/v2" ) func main() { - fsm := fsm.NewFSM( + f, err := fsm.New( "idle", - fsm.Events{ - {Name: "scan", Src: []string{"idle"}, Dst: "scanning"}, - {Name: "working", Src: []string{"scanning"}, Dst: "scanning"}, - {Name: "situation", Src: []string{"scanning"}, Dst: "scanning"}, - {Name: "situation", Src: []string{"idle"}, Dst: "idle"}, - {Name: "finish", Src: []string{"scanning"}, Dst: "idle"}, + fsm.Transitions[string, string]{ + {Event: "scan", Src: []string{"idle"}, Dst: "scanning"}, + {Event: "working", Src: []string{"scanning"}, Dst: "scanning"}, + {Event: "situation", Src: []string{"scanning"}, Dst: "scanning"}, + {Event: "situation", Src: []string{"idle"}, Dst: "idle"}, + {Event: "finish", Src: []string{"scanning"}, Dst: "idle"}, }, - fsm.Callbacks{ - "scan": func(e *fsm.Event) { - fmt.Println("after_scan: " + e.FSM.Current()) + fsm.Callbacks[string, string]{ + fsm.Callback[string, string]{When: fsm.BeforeEvent, Event: "scan", + F: func(e *fsm.CallbackContext[string, string]) { + fmt.Println("after_scan: " + e.FSM.Current()) + }, }, - "working": func(e *fsm.Event) { - fmt.Println("working: " + e.FSM.Current()) + fsm.Callback[string, string]{When: fsm.BeforeEvent, Event: "working", + F: func(e *fsm.CallbackContext[string, string]) { + fmt.Println("working: " + e.FSM.Current()) + }, }, - "situation": func(e *fsm.Event) { - fmt.Println("situation: " + e.FSM.Current()) + fsm.Callback[string, string]{When: fsm.BeforeEvent, Event: "situation", + F: func(e *fsm.CallbackContext[string, string]) { + fmt.Println("situation: " + e.FSM.Current()) + }, }, - "finish": func(e *fsm.Event) { - fmt.Println("finish: " + e.FSM.Current()) + fsm.Callback[string, string]{When: fsm.BeforeEvent, Event: "finish", + F: func(e *fsm.CallbackContext[string, string]) { + fmt.Println("finish: " + e.FSM.Current()) + }, }, }, ) + if err != nil { + fmt.Println(err) + } + fmt.Println(f.Current()) - fmt.Println(fsm.Current()) - - err := fsm.Event("scan") + err = f.Event("scan") if err != nil { fmt.Println(err) } - fmt.Println("1:" + fsm.Current()) + fmt.Println("1:" + f.Current()) - err = fsm.Event("working") + err = f.Event("working") if err != nil { fmt.Println(err) } - fmt.Println("2:" + fsm.Current()) + fmt.Println("2:" + f.Current()) - err = fsm.Event("situation") + err = f.Event("situation") if err != nil { fmt.Println(err) } - fmt.Println("3:" + fsm.Current()) + fmt.Println("3:" + f.Current()) - err = fsm.Event("finish") + err = f.Event("finish") if err != nil { fmt.Println(err) } - fmt.Println("4:" + fsm.Current()) + fmt.Println("4:" + f.Current()) } diff --git a/examples/data.go b/examples/data.go index 26aa1f0..3f27ada 100644 --- a/examples/data.go +++ b/examples/data.go @@ -1,3 +1,4 @@ +//go:build ignore // +build ignore package main @@ -5,34 +6,39 @@ package main import ( "fmt" - "github.com/looplab/fsm" + "github.com/looplab/fsm/v2" ) func main() { - fsm := fsm.NewFSM( + fsm, err := fsm.New( "idle", - fsm.Events{ - {Name: "produce", Src: []string{"idle"}, Dst: "idle"}, - {Name: "consume", Src: []string{"idle"}, Dst: "idle"}, + fsm.Transitions[string, string]{ + {Event: "produce", Src: []string{"idle"}, Dst: "idle"}, + {Event: "consume", Src: []string{"idle"}, Dst: "idle"}, }, - fsm.Callbacks{ - "produce": func(e *fsm.Event) { - e.FSM.SetMetadata("message", "hii") - fmt.Println("produced data") + fsm.Callbacks[string, string]{ + fsm.Callback[string, string]{When: fsm.BeforeEvent, Event: "sproduce", + F: func(e *fsm.CallbackContext[string, string]) { + e.FSM.SetMetadata("message", "hii") + fmt.Println("produced data") + }, }, - "consume": func(e *fsm.Event) { - message, ok := e.FSM.Metadata("message") - if ok { - fmt.Println("message = " + message.(string)) - } - + fsm.Callback[string, string]{When: fsm.BeforeEvent, Event: "consume", + F: func(e *fsm.CallbackContext[string, string]) { + message, ok := e.FSM.Metadata("message") + if ok { + fmt.Println("message = " + message.(string)) + } + }, }, }, ) - + if err != nil { + fmt.Println(err) + } fmt.Println(fsm.Current()) - err := fsm.Event("produce") + err = fsm.Event("produce") if err != nil { fmt.Println(err) } diff --git a/examples/generic.go b/examples/generic.go new file mode 100644 index 0000000..a1f6c2f --- /dev/null +++ b/examples/generic.go @@ -0,0 +1,66 @@ +//go:build ignore +// +build ignore + +package main + +import ( + "fmt" + + "github.com/looplab/fsm/v2" +) + +type MyEvent string +type MyState string + +const ( + Close MyEvent = "close" + Open MyEvent = "open" + Any MyEvent = "" + + IsClosed MyState = "closed" + IsOpen MyState = "open" +) + +func main() { + fsm, err := fsm.New( + IsClosed, + fsm.Transitions[MyEvent, MyState]{ + {Event: Open, Src: []MyState{IsClosed}, Dst: IsOpen}, + {Event: Close, Src: []MyState{IsOpen}, Dst: IsClosed}, + }, + fsm.Callbacks[MyEvent, MyState]{ + fsm.Callback[MyEvent, MyState]{ + When: fsm.AfterEvent, Event: Open, + F: func(cr *fsm.CallbackContext[MyEvent, MyState]) { + fmt.Printf("callback: event:%s src:%s dst:%s\n", cr.Event, cr.Src, cr.Dst) + }, + }, + fsm.Callback[MyEvent, MyState]{ + When: fsm.BeforeEvent, + Event: Open, + + F: func(cr *fsm.CallbackContext[MyEvent, MyState]) { + fmt.Printf("callback after all: event:%s src:%s dst:%s\n", cr.Event, cr.Src, cr.Dst) + }, + }, + }, + ) + if err != nil { + fmt.Println(err) + } + fmt.Println(fsm.Current()) + err = fsm.Event(Open) + if err != nil { + fmt.Println(err) + } + fmt.Println(fsm.Current()) + err = fsm.Event(Close) + if err != nil { + fmt.Println(err) + } + fmt.Println(fsm.Current()) + // Output: + // closed + // open + // closed +} diff --git a/examples/simple.go b/examples/simple.go index 740e4d9..ae5016b 100644 --- a/examples/simple.go +++ b/examples/simple.go @@ -1,25 +1,29 @@ +//go:build ignore // +build ignore package main import ( "fmt" - "github.com/looplab/fsm" + + "github.com/looplab/fsm/v2" ) func main() { - fsm := fsm.NewFSM( + fsm, err := fsm.New( "closed", - fsm.Events{ - {Name: "open", Src: []string{"closed"}, Dst: "open"}, - {Name: "close", Src: []string{"open"}, Dst: "closed"}, + fsm.Transitions[string, string]{ + {Event: "open", Src: []string{"closed"}, Dst: "open"}, + {Event: "close", Src: []string{"open"}, Dst: "closed"}, }, - fsm.Callbacks{}, + fsm.Callbacks[string, string]{}, ) - + if err != nil { + fmt.Println(err) + } fmt.Println(fsm.Current()) - err := fsm.Event("open") + err = fsm.Event("open") if err != nil { fmt.Println(err) } diff --git a/examples/struct.go b/examples/struct.go index 17fa712..cf0c31b 100644 --- a/examples/struct.go +++ b/examples/struct.go @@ -1,15 +1,17 @@ +//go:build ignore // +build ignore package main import ( "fmt" - "github.com/looplab/fsm" + + "github.com/looplab/fsm/v2" ) type Door struct { To string - FSM *fsm.FSM + FSM *fsm.FSM[string, string] } func NewDoor(to string) *Door { @@ -17,21 +19,26 @@ func NewDoor(to string) *Door { To: to, } - d.FSM = fsm.NewFSM( + var err error + d.FSM, err = fsm.New( "closed", - fsm.Events{ - {Name: "open", Src: []string{"closed"}, Dst: "open"}, - {Name: "close", Src: []string{"open"}, Dst: "closed"}, + fsm.Transitions[string, string]{ + {Event: "open", Src: []string{"closed"}, Dst: "open"}, + {Event: "close", Src: []string{"open"}, Dst: "closed"}, }, - fsm.Callbacks{ - "enter_state": func(e *fsm.Event) { d.enterState(e) }, + fsm.Callbacks[string, string]{ + fsm.Callback[string, string]{When: fsm.EnterAllStates, + F: func(e *fsm.CallbackContext[string, string]) { d.enterState(e) }, + }, }, ) - + if err != nil { + fmt.Println(err) + } return d } -func (d *Door) enterState(e *fsm.Event) { +func (d *Door) enterState(e *fsm.CallbackContext[string, string]) { fmt.Printf("The door to %s is %s\n", d.To, e.Dst) } diff --git a/fsm.go b/fsm.go index f5c8cf8..59c9353 100644 --- a/fsm.go +++ b/fsm.go @@ -21,198 +21,109 @@ // // Fysom for Python // https://github.com/oxplot/fysom (forked at https://github.com/mriehl/fysom) -// package fsm import ( - "strings" + "cmp" + "errors" + "fmt" "sync" ) // transitioner is an interface for the FSM's transition function. -type transitioner interface { - transition(*FSM) error +type transitioner[E cmp.Ordered, S cmp.Ordered] interface { + transition(*FSM[E, S]) error } // FSM is the state machine that holds the current state. -// -// It has to be created with NewFSM to function properly. -type FSM struct { +// E is the event +// S is the state +// It has to be created with New to function properly. +type FSM[E cmp.Ordered, S cmp.Ordered] struct { // current is the state that the FSM is currently in. - current string + current S // transitions maps events and source states to destination states. - transitions map[eKey]string + transitions map[eKey[E, S]]S // callbacks maps events and targets to callback functions. - callbacks map[cKey]Callback + callbacks Callbacks[E, S] // transition is the internal transition functions used either directly // or when Transition is called in an asynchronous state transition. transition func() - // transitionerObj calls the FSM's transition() function. - transitionerObj transitioner + // transitioner calls the FSM's transition() function. + transitioner transitioner[E, S] // stateMu guards access to the current state. stateMu sync.RWMutex // eventMu guards access to Event() and Transition(). eventMu sync.Mutex + // metadata can be used to store and load data that maybe used across events // use methods SetMetadata() and Metadata() to store and load data - metadata map[string]interface{} - + metadata map[string]any + // metadataMu guards access to the metadata. metadataMu sync.RWMutex } -// EventDesc represents an event when initializing the FSM. +// Transition represents an event when initializing the FSM. // // The event can have one or more source states that is valid for performing // the transition. If the FSM is in one of the source states it will end up in // the specified destination state, calling all defined callbacks as it goes. -type EventDesc struct { - // Name is the event name used when calling for a transition. - Name string +type Transition[E cmp.Ordered, S cmp.Ordered] struct { + // Event is the event used when calling for a transition. + Event E // Src is a slice of source states that the FSM must be in to perform a // state transition. - Src []string + Src []S // Dst is the destination state that the FSM will be in if the transition // succeeds. - Dst string + Dst S } -// Callback is a function type that callbacks should use. Event is the current -// event info as the callback happens. -type Callback func(*Event) +// Transitions is a shorthand for defining the transition map in NewFSM. +type Transitions[E cmp.Ordered, S cmp.Ordered] []Transition[E, S] -// Events is a shorthand for defining the transition map in NewFSM. -type Events []EventDesc - -// Callbacks is a shorthand for defining the callbacks in NewFSM. -type Callbacks map[string]Callback - -// NewFSM constructs a FSM from events and callbacks. -// -// The events and transitions are specified as a slice of Event structs -// specified as Events. Each Event is mapped to one or more internal -// transitions from Event.Src to Event.Dst. -// -// Callbacks are added as a map specified as Callbacks where the key is parsed -// as the callback event as follows, and called in the same order: -// -// 1. before_ - called before event named -// -// 2. before_event - called before all events -// -// 3. leave_ - called before leaving -// -// 4. leave_state - called before leaving all states -// -// 5. enter_ - called after entering -// -// 6. enter_state - called after entering all states -// -// 7. after_ - called after event named +// New constructs a generic FSM with a initial state S, for events E. +// E is the event type, S is the state type. // -// 8. after_event - called after all events +// Transitions define the state transitions that can be performed for a given event +// and a slice of source states, the destination state and the callback function. // -// There are also two short form versions for the most commonly used callbacks. -// They are simply the name of the event or state: -// -// 1. - called after entering -// -// 2. - called after event named -// -// If both a shorthand version and a full version is specified it is undefined -// which version of the callback will end up in the internal map. This is due -// to the pseudo random nature of Go maps. No checking for multiple keys is -// currently performed. -func NewFSM(initial string, events []EventDesc, callbacks map[string]Callback) *FSM { - f := &FSM{ - transitionerObj: &transitionerStruct{}, - current: initial, - transitions: make(map[eKey]string), - callbacks: make(map[cKey]Callback), - metadata: make(map[string]interface{}), +// Callbacks are added as a slice specified as Callbacks and called in the same order. +func New[E cmp.Ordered, S cmp.Ordered](initial S, transitions Transitions[E, S], callbacks Callbacks[E, S]) (*FSM[E, S], error) { + f := &FSM[E, S]{ + current: initial, + transitioner: &defaultTransitioner[E, S]{}, + transitions: map[eKey[E, S]]S{}, + callbacks: callbacks, + metadata: map[string]any{}, } // Build transition map and store sets of all events and states. - allEvents := make(map[string]bool) - allStates := make(map[string]bool) - for _, e := range events { + for _, e := range transitions { for _, src := range e.Src { - f.transitions[eKey{e.Name, src}] = e.Dst - allStates[src] = true - allStates[e.Dst] = true + // FIXME eKey still required? + f.transitions[eKey[E, S]{e.Event, src}] = e.Dst } - allEvents[e.Name] = true } - - // Map all callbacks to events/states. - for name, fn := range callbacks { - var target string - var callbackType int - - switch { - case strings.HasPrefix(name, "before_"): - target = strings.TrimPrefix(name, "before_") - if target == "event" { - target = "" - callbackType = callbackBeforeEvent - } else if _, ok := allEvents[target]; ok { - callbackType = callbackBeforeEvent - } - case strings.HasPrefix(name, "leave_"): - target = strings.TrimPrefix(name, "leave_") - if target == "state" { - target = "" - callbackType = callbackLeaveState - } else if _, ok := allStates[target]; ok { - callbackType = callbackLeaveState - } - case strings.HasPrefix(name, "enter_"): - target = strings.TrimPrefix(name, "enter_") - if target == "state" { - target = "" - callbackType = callbackEnterState - } else if _, ok := allStates[target]; ok { - callbackType = callbackEnterState - } - case strings.HasPrefix(name, "after_"): - target = strings.TrimPrefix(name, "after_") - if target == "event" { - target = "" - callbackType = callbackAfterEvent - } else if _, ok := allEvents[target]; ok { - callbackType = callbackAfterEvent - } - default: - target = name - if _, ok := allStates[target]; ok { - callbackType = callbackEnterState - } else if _, ok := allEvents[target]; ok { - callbackType = callbackAfterEvent - } - } - - if callbackType != callbackNone { - f.callbacks[cKey{target, callbackType}] = fn - } - } - - return f + err := callbacks.validate() + return f, err } // Current returns the current state of the FSM. -func (f *FSM) Current() string { +func (f *FSM[E, S]) Current() S { f.stateMu.RLock() defer f.stateMu.RUnlock() return f.current } // Is returns true if state is the current state. -func (f *FSM) Is(state string) bool { +func (f *FSM[E, S]) Is(state S) bool { f.stateMu.RLock() defer f.stateMu.RUnlock() return state == f.current @@ -220,26 +131,32 @@ func (f *FSM) Is(state string) bool { // SetState allows the user to move to the given state from current state. // The call does not trigger any callbacks, if defined. -func (f *FSM) SetState(state string) { +func (f *FSM[E, S]) SetState(state S) { f.stateMu.Lock() defer f.stateMu.Unlock() f.current = state } // Can returns true if event can occur in the current state. -func (f *FSM) Can(event string) bool { +func (f *FSM[E, S]) Can(event E) bool { f.stateMu.RLock() defer f.stateMu.RUnlock() - _, ok := f.transitions[eKey{event, f.current}] + _, ok := f.transitions[eKey[E, S]{event, f.current}] return ok && (f.transition == nil) } +// Cannot returns true if event can not occur in the current state. +// It is a convenience method to help code read nicely. +func (f *FSM[E, S]) Cannot(event E) bool { + return !f.Can(event) +} + // AvailableTransitions returns a list of transitions available in the // current state. -func (f *FSM) AvailableTransitions() []string { +func (f *FSM[E, S]) AvailableTransitions() []E { f.stateMu.RLock() defer f.stateMu.RUnlock() - var transitions []string + var transitions []E for key := range f.transitions { if key.src == f.current { transitions = append(transitions, key.event) @@ -248,14 +165,8 @@ func (f *FSM) AvailableTransitions() []string { return transitions } -// Cannot returns true if event can not occur in the current state. -// It is a convenience method to help code read nicely. -func (f *FSM) Cannot(event string) bool { - return !f.Can(event) -} - // Metadata returns the value stored in metadata -func (f *FSM) Metadata(key string) (interface{}, bool) { +func (f *FSM[E, S]) Metadata(key string) (any, bool) { f.metadataMu.RLock() defer f.metadataMu.RUnlock() dataElement, ok := f.metadata[key] @@ -263,10 +174,10 @@ func (f *FSM) Metadata(key string) (interface{}, bool) { } // SetMetadata stores the dataValue in metadata indexing it with key -func (f *FSM) SetMetadata(key string, dataValue interface{}) { +func (f *FSM[E, S]) SetMetadata(key string, value any) { f.metadataMu.Lock() defer f.metadataMu.Unlock() - f.metadata[key] = dataValue + f.metadata[key] = value } // Event initiates a state transition with the named event. @@ -286,7 +197,7 @@ func (f *FSM) SetMetadata(key string, dataValue interface{}) { // // The last error should never occur in this situation and is a sign of an // internal bug. -func (f *FSM) Event(event string, args ...interface{}) error { +func (f *FSM[E, S]) Event(event E, args ...any) error { f.eventMu.Lock() defer f.eventMu.Unlock() @@ -294,20 +205,29 @@ func (f *FSM) Event(event string, args ...interface{}) error { defer f.stateMu.RUnlock() if f.transition != nil { - return InTransitionError{event} + return InTransitionError{fmt.Sprintf("%v", event)} } - dst, ok := f.transitions[eKey{event, f.current}] + dst, ok := f.transitions[eKey[E, S]{event, f.current}] if !ok { for ekey := range f.transitions { if ekey.event == event { - return InvalidEventError{event, f.current} + return InvalidEventError[E, S]{event, f.current} } } - return UnknownEventError{event} + return UnknownEventError{fmt.Sprintf("%v", event)} } - e := &Event{f, event, f.current, dst, nil, args, false, false} + e := &CallbackContext[E, S]{ + FSM: f, + Event: event, + Src: f.current, + Dst: dst, + Err: nil, + Args: args, + canceled: false, + async: false, + } err := f.beforeEventCallbacks(e) if err != nil { @@ -330,7 +250,8 @@ func (f *FSM) Event(event string, args ...interface{}) error { } if err = f.leaveStateCallbacks(e); err != nil { - if _, ok := err.(CanceledError); ok { + var ce *CanceledError + if errors.As(err, &ce) { f.transition = nil } return err @@ -348,26 +269,26 @@ func (f *FSM) Event(event string, args ...interface{}) error { } // Transition wraps transitioner.transition. -func (f *FSM) Transition() error { +func (f *FSM[E, S]) Transition() error { f.eventMu.Lock() defer f.eventMu.Unlock() return f.doTransition() } // doTransition wraps transitioner.transition. -func (f *FSM) doTransition() error { - return f.transitionerObj.transition(f) +func (f *FSM[E, S]) doTransition() error { + return f.transitioner.transition(f) } -// transitionerStruct is the default implementation of the transitioner +// defaultTransitioner is the default implementation of the transitioner // interface. Other implementations can be swapped in for testing. -type transitionerStruct struct{} +type defaultTransitioner[E cmp.Ordered, S cmp.Ordered] struct{} // Transition completes an asynchronous state change. // // The callback for leave_ must previously have called Async on its // event to have initiated an asynchronous state transition. -func (t transitionerStruct) transition(f *FSM) error { +func (t defaultTransitioner[E, S]) transition(f *FSM[E, S]) error { if f.transition == nil { return NotInTransitionError{} } @@ -376,92 +297,89 @@ func (t transitionerStruct) transition(f *FSM) error { return nil } -// beforeEventCallbacks calls the before_ callbacks, first the named then the +// beforeEventCallbacks calls the before callbacks, first the named then the // general version. -func (f *FSM) beforeEventCallbacks(e *Event) error { - if fn, ok := f.callbacks[cKey{e.Event, callbackBeforeEvent}]; ok { - fn(e) - if e.canceled { - return CanceledError{e.Err} +func (f *FSM[E, S]) beforeEventCallbacks(cc *CallbackContext[E, S]) error { + for _, cb := range f.callbacks { + if cb.When == BeforeEvent { + if cb.Event == cc.Event { + cb.F(cc) + if cc.canceled { + return CanceledError{cc.Err} + } + } } - } - if fn, ok := f.callbacks[cKey{"", callbackBeforeEvent}]; ok { - fn(e) - if e.canceled { - return CanceledError{e.Err} + if cb.When == BeforeAllEvents { + cb.F(cc) + if cc.canceled { + return CanceledError{cc.Err} + } } } return nil } -// leaveStateCallbacks calls the leave_ callbacks, first the named then the +// leaveStateCallbacks calls the leave callbacks, first the named then the // general version. -func (f *FSM) leaveStateCallbacks(e *Event) error { - if fn, ok := f.callbacks[cKey{f.current, callbackLeaveState}]; ok { - fn(e) - if e.canceled { - return CanceledError{e.Err} - } else if e.async { - return AsyncError{e.Err} +func (f *FSM[E, S]) leaveStateCallbacks(cc *CallbackContext[E, S]) error { + for _, cb := range f.callbacks { + if cb.When == LeaveState { + if cb.State == cc.Src { + cb.F(cc) + if cc.canceled { + return CanceledError{cc.Err} + } else if cc.async { + return AsyncError{cc.Err} + } + } } - } - if fn, ok := f.callbacks[cKey{"", callbackLeaveState}]; ok { - fn(e) - if e.canceled { - return CanceledError{e.Err} - } else if e.async { - return AsyncError{e.Err} + if cb.When == LeaveAllStates { + cb.F(cc) + if cc.canceled { + return CanceledError{cc.Err} + } else if cc.async { + return AsyncError{cc.Err} + } } } return nil } -// enterStateCallbacks calls the enter_ callbacks, first the named then the +// enterStateCallbacks calls the enter callbacks, first the named then the // general version. -func (f *FSM) enterStateCallbacks(e *Event) { - if fn, ok := f.callbacks[cKey{f.current, callbackEnterState}]; ok { - fn(e) - } - if fn, ok := f.callbacks[cKey{"", callbackEnterState}]; ok { - fn(e) +func (f *FSM[E, S]) enterStateCallbacks(cc *CallbackContext[E, S]) { + for _, cb := range f.callbacks { + if cb.When == EnterState { + if cb.State == cc.Dst { + cb.F(cc) + } + } + if cb.When == EnterAllStates { + cb.F(cc) + } } } -// afterEventCallbacks calls the after_ callbacks, first the named then the +// afterEventCallbacks calls the after callbacks, first the named then the // general version. -func (f *FSM) afterEventCallbacks(e *Event) { - if fn, ok := f.callbacks[cKey{e.Event, callbackAfterEvent}]; ok { - fn(e) - } - if fn, ok := f.callbacks[cKey{"", callbackAfterEvent}]; ok { - fn(e) +func (f *FSM[E, S]) afterEventCallbacks(cc *CallbackContext[E, S]) { + for _, cb := range f.callbacks { + if cb.When == AfterEvent { + if cb.Event == cc.Event { + cb.F(cc) + } + } + if cb.When == AfterAllEvents { + cb.F(cc) + } } } -const ( - callbackNone int = iota - callbackBeforeEvent - callbackLeaveState - callbackEnterState - callbackAfterEvent -) - -// cKey is a struct key used for keeping the callbacks mapped to a target. -type cKey struct { - // target is either the name of a state or an event depending on which - // callback type the key refers to. It can also be "" for a non-targeted - // callback like before_event. - target string - - // callbackType is the situation when the callback will be run. - callbackType int -} - // eKey is a struct key used for storing the transition map. -type eKey struct { +type eKey[E cmp.Ordered, S cmp.Ordered] struct { // event is the name of the event that the keys refers to. - event string + event E // src is the source from where the event can transition. - src string + src S } diff --git a/fsm_test.go b/fsm_test.go index 431fd65..833e353 100644 --- a/fsm_test.go +++ b/fsm_test.go @@ -15,6 +15,7 @@ package fsm import ( + "cmp" "fmt" "sort" "sync" @@ -22,21 +23,24 @@ import ( "time" ) -type fakeTransitionerObj struct { +type fakeTransitioner[E cmp.Ordered, S cmp.Ordered] struct { } -func (t fakeTransitionerObj) transition(f *FSM) error { +func (t fakeTransitioner[E, S]) transition(f *FSM[E, S]) error { return &InternalError{} } func TestSameState(t *testing.T) { - fsm := NewFSM( + fsm, err := New( "start", - Events{ - {Name: "run", Src: []string{"start"}, Dst: "start"}, + Transitions[string, string]{ + {Event: "run", Src: []string{"start"}, Dst: "start"}, }, - Callbacks{}, + Callbacks[string, string]{}, ) + if err != nil { + t.Errorf("constructor failed:%s", err) + } _ = fsm.Event("run") if fsm.Current() != "start" { t.Error("expected state to be 'start'") @@ -44,80 +48,99 @@ func TestSameState(t *testing.T) { } func TestSetState(t *testing.T) { - fsm := NewFSM( + fsm, err := New( "walking", - Events{ - {Name: "walk", Src: []string{"start"}, Dst: "walking"}, + Transitions[string, string]{ + {Event: "walk", Src: []string{"start"}, Dst: "walking"}, }, - Callbacks{}, + Callbacks[string, string]{}, ) + if err != nil { + t.Errorf("constructor failed:%s", err) + } + fsm.SetState("start") if fsm.Current() != "start" { t.Error("expected state to be 'walking'") } - err := fsm.Event("walk") + err = fsm.Event("walk") if err != nil { t.Error("transition is expected no error") } } func TestBadTransition(t *testing.T) { - fsm := NewFSM( + fsm, err := New( "start", - Events{ - {Name: "run", Src: []string{"start"}, Dst: "running"}, + Transitions[string, string]{ + {Event: "run", Src: []string{"start"}, Dst: "running"}, }, - Callbacks{}, + Callbacks[string, string]{}, ) - fsm.transitionerObj = new(fakeTransitionerObj) - err := fsm.Event("run") + if err != nil { + t.Errorf("constructor failed:%s", err) + } + + fsm.transitioner = new(fakeTransitioner[string, string]) + err = fsm.Event("run") if err == nil { t.Error("bad transition should give an error") } } func TestInappropriateEvent(t *testing.T) { - fsm := NewFSM( + fsm, err := New( "closed", - Events{ - {Name: "open", Src: []string{"closed"}, Dst: "open"}, - {Name: "close", Src: []string{"open"}, Dst: "closed"}, + Transitions[string, string]{ + {Event: "open", Src: []string{"closed"}, Dst: "open"}, + {Event: "close", Src: []string{"open"}, Dst: "closed"}, }, - Callbacks{}, + Callbacks[string, string]{}, ) - err := fsm.Event("close") - if e, ok := err.(InvalidEventError); !ok && e.Event != "close" && e.State != "closed" { + if err != nil { + t.Errorf("constructor failed:%s", err) + } + + err = fsm.Event("close") + if e, ok := err.(InvalidEventError[string, string]); !ok && e.Event != "close" && e.State != "closed" { t.Error("expected 'InvalidEventError' with correct state and event") } } func TestInvalidEvent(t *testing.T) { - fsm := NewFSM( + fsm, err := New( "closed", - Events{ - {Name: "open", Src: []string{"closed"}, Dst: "open"}, - {Name: "close", Src: []string{"open"}, Dst: "closed"}, + Transitions[string, string]{ + {Event: "open", Src: []string{"closed"}, Dst: "open"}, + {Event: "close", Src: []string{"open"}, Dst: "closed"}, }, - Callbacks{}, + Callbacks[string, string]{}, ) - err := fsm.Event("lock") + if err != nil { + t.Errorf("constructor failed:%s", err) + } + + err = fsm.Event("lock") if e, ok := err.(UnknownEventError); !ok && e.Event != "close" { t.Error("expected 'UnknownEventError' with correct event") } } func TestMultipleSources(t *testing.T) { - fsm := NewFSM( + fsm, err := New( "one", - Events{ - {Name: "first", Src: []string{"one"}, Dst: "two"}, - {Name: "second", Src: []string{"two"}, Dst: "three"}, - {Name: "reset", Src: []string{"one", "two", "three"}, Dst: "one"}, + Transitions[string, string]{ + {Event: "first", Src: []string{"one"}, Dst: "two"}, + {Event: "second", Src: []string{"two"}, Dst: "three"}, + {Event: "reset", Src: []string{"one", "two", "three"}, Dst: "one"}, }, - Callbacks{}, + Callbacks[string, string]{}, ) + if err != nil { + t.Errorf("constructor failed:%s", err) + } - err := fsm.Event("first") + err = fsm.Event("first") if err != nil { t.Errorf("transition failed %v", err) } @@ -152,19 +175,22 @@ func TestMultipleSources(t *testing.T) { } func TestMultipleEvents(t *testing.T) { - fsm := NewFSM( + fsm, err := New( "start", - Events{ - {Name: "first", Src: []string{"start"}, Dst: "one"}, - {Name: "second", Src: []string{"start"}, Dst: "two"}, - {Name: "reset", Src: []string{"one"}, Dst: "reset_one"}, - {Name: "reset", Src: []string{"two"}, Dst: "reset_two"}, - {Name: "reset", Src: []string{"reset_one", "reset_two"}, Dst: "start"}, + Transitions[string, string]{ + {Event: "first", Src: []string{"start"}, Dst: "one"}, + {Event: "second", Src: []string{"start"}, Dst: "two"}, + {Event: "reset", Src: []string{"one"}, Dst: "reset_one"}, + {Event: "reset", Src: []string{"two"}, Dst: "reset_two"}, + {Event: "reset", Src: []string{"reset_one", "reset_two"}, Dst: "start"}, }, - Callbacks{}, + Callbacks[string, string]{}, ) + if err != nil { + t.Errorf("constructor failed:%s", err) + } - err := fsm.Event("first") + err = fsm.Event("first") if err != nil { t.Errorf("transition failed %v", err) } @@ -209,28 +235,39 @@ func TestGenericCallbacks(t *testing.T) { enterState := false afterEvent := false - fsm := NewFSM( + fsm, err := New( "start", - Events{ - {Name: "run", Src: []string{"start"}, Dst: "end"}, + Transitions[string, string]{ + {Event: "run", Src: []string{"start"}, Dst: "end"}, }, - Callbacks{ - "before_event": func(e *Event) { - beforeEvent = true + Callbacks[string, string]{ + Callback[string, string]{When: BeforeAllEvents, + F: func(e *CallbackContext[string, string]) { + beforeEvent = true + }, }, - "leave_state": func(e *Event) { - leaveState = true + Callback[string, string]{When: LeaveAllStates, + F: func(e *CallbackContext[string, string]) { + leaveState = true + }, }, - "enter_state": func(e *Event) { - enterState = true + Callback[string, string]{When: EnterAllStates, + F: func(e *CallbackContext[string, string]) { + enterState = true + }, }, - "after_event": func(e *Event) { - afterEvent = true + Callback[string, string]{When: AfterAllEvents, + F: func(e *CallbackContext[string, string]) { + afterEvent = true + }, }, }, ) + if err != nil { + t.Errorf("constructor failed:%s", err) + } - err := fsm.Event("run") + err = fsm.Event("run") if err != nil { t.Errorf("transition failed %v", err) } @@ -245,28 +282,39 @@ func TestSpecificCallbacks(t *testing.T) { enterState := false afterEvent := false - fsm := NewFSM( + fsm, err := New( "start", - Events{ - {Name: "run", Src: []string{"start"}, Dst: "end"}, + Transitions[string, string]{ + {Event: "run", Src: []string{"start"}, Dst: "end"}, }, - Callbacks{ - "before_run": func(e *Event) { - beforeEvent = true + Callbacks[string, string]{ + Callback[string, string]{When: BeforeEvent, Event: "run", + F: func(e *CallbackContext[string, string]) { + beforeEvent = true + }, }, - "leave_start": func(e *Event) { - leaveState = true + Callback[string, string]{When: LeaveState, State: "start", + F: func(e *CallbackContext[string, string]) { + leaveState = true + }, }, - "enter_end": func(e *Event) { - enterState = true + Callback[string, string]{When: EnterState, State: "end", + F: func(e *CallbackContext[string, string]) { + enterState = true + }, }, - "after_run": func(e *Event) { - afterEvent = true + Callback[string, string]{When: AfterEvent, Event: "run", + F: func(e *CallbackContext[string, string]) { + afterEvent = true + }, }, }, ) + if err != nil { + t.Errorf("constructor failed:%s", err) + } - err := fsm.Event("run") + err = fsm.Event("run") if err != nil { t.Errorf("transition failed %v", err) } @@ -279,22 +327,29 @@ func TestSpecificCallbacksShortform(t *testing.T) { enterState := false afterEvent := false - fsm := NewFSM( + fsm, err := New( "start", - Events{ - {Name: "run", Src: []string{"start"}, Dst: "end"}, + Transitions[string, string]{ + {Event: "run", Src: []string{"start"}, Dst: "end"}, }, - Callbacks{ - "end": func(e *Event) { - enterState = true + Callbacks[string, string]{ + Callback[string, string]{When: EnterState, State: "end", + F: func(e *CallbackContext[string, string]) { + enterState = true + }, }, - "run": func(e *Event) { - afterEvent = true + Callback[string, string]{When: AfterEvent, Event: "run", + F: func(e *CallbackContext[string, string]) { + afterEvent = true + }, }, }, ) + if err != nil { + t.Errorf("constructor failed:%s", err) + } - err := fsm.Event("run") + err = fsm.Event("run") if err != nil { t.Errorf("transition failed %v", err) } @@ -306,19 +361,24 @@ func TestSpecificCallbacksShortform(t *testing.T) { func TestBeforeEventWithoutTransition(t *testing.T) { beforeEvent := true - fsm := NewFSM( + fsm, err := New( "start", - Events{ - {Name: "dontrun", Src: []string{"start"}, Dst: "start"}, + Transitions[string, string]{ + {Event: "dontrun", Src: []string{"start"}, Dst: "start"}, }, - Callbacks{ - "before_event": func(e *Event) { - beforeEvent = true + Callbacks[string, string]{ + Callback[string, string]{When: BeforeAllEvents, + F: func(e *CallbackContext[string, string]) { + beforeEvent = true + }, }, }, ) + if err != nil { + t.Errorf("constructor failed:%s", err) + } - err := fsm.Event("dontrun") + err = fsm.Event("dontrun") if e, ok := err.(NoTransitionError); !ok && e.Err != nil { t.Error("expected 'NoTransitionError' without custom error") } @@ -332,17 +392,22 @@ func TestBeforeEventWithoutTransition(t *testing.T) { } func TestCancelBeforeGenericEvent(t *testing.T) { - fsm := NewFSM( + fsm, err := New( "start", - Events{ - {Name: "run", Src: []string{"start"}, Dst: "end"}, + Transitions[string, string]{ + {Event: "run", Src: []string{"start"}, Dst: "end"}, }, - Callbacks{ - "before_event": func(e *Event) { - e.Cancel() + Callbacks[string, string]{ + Callback[string, string]{When: BeforeAllEvents, + F: func(e *CallbackContext[string, string]) { + e.Cancel() + }, }, }, ) + if err != nil { + t.Errorf("constructor failed:%s", err) + } _ = fsm.Event("run") if fsm.Current() != "start" { t.Error("expected state to be 'start'") @@ -350,17 +415,22 @@ func TestCancelBeforeGenericEvent(t *testing.T) { } func TestCancelBeforeSpecificEvent(t *testing.T) { - fsm := NewFSM( + fsm, err := New( "start", - Events{ - {Name: "run", Src: []string{"start"}, Dst: "end"}, + Transitions[string, string]{ + {Event: "run", Src: []string{"start"}, Dst: "end"}, }, - Callbacks{ - "before_run": func(e *Event) { - e.Cancel() + Callbacks[string, string]{ + Callback[string, string]{When: BeforeEvent, Event: "run", + F: func(e *CallbackContext[string, string]) { + e.Cancel() + }, }, }, ) + if err != nil { + t.Errorf("constructor failed:%s", err) + } _ = fsm.Event("run") if fsm.Current() != "start" { t.Error("expected state to be 'start'") @@ -368,17 +438,23 @@ func TestCancelBeforeSpecificEvent(t *testing.T) { } func TestCancelLeaveGenericState(t *testing.T) { - fsm := NewFSM( + fsm, err := New( "start", - Events{ - {Name: "run", Src: []string{"start"}, Dst: "end"}, + Transitions[string, string]{ + {Event: "run", Src: []string{"start"}, Dst: "end"}, }, - Callbacks{ - "leave_state": func(e *Event) { - e.Cancel() + Callbacks[string, string]{ + Callback[string, string]{When: LeaveState, State: "start", + F: func(e *CallbackContext[string, string]) { + e.Cancel() + }, }, }, ) + if err != nil { + t.Errorf("constructor failed:%s", err) + } + _ = fsm.Event("run") if fsm.Current() != "start" { t.Error("expected state to be 'start'") @@ -386,17 +462,23 @@ func TestCancelLeaveGenericState(t *testing.T) { } func TestCancelLeaveSpecificState(t *testing.T) { - fsm := NewFSM( + fsm, err := New( "start", - Events{ - {Name: "run", Src: []string{"start"}, Dst: "end"}, + Transitions[string, string]{ + {Event: "run", Src: []string{"start"}, Dst: "end"}, }, - Callbacks{ - "leave_start": func(e *Event) { - e.Cancel() + Callbacks[string, string]{ + Callback[string, string]{When: LeaveState, State: "start", + F: func(e *CallbackContext[string, string]) { + e.Cancel() + }, }, }, ) + if err != nil { + t.Errorf("constructor failed:%s", err) + } + _ = fsm.Event("run") if fsm.Current() != "start" { t.Error("expected state to be 'start'") @@ -404,18 +486,23 @@ func TestCancelLeaveSpecificState(t *testing.T) { } func TestCancelWithError(t *testing.T) { - fsm := NewFSM( + fsm, err := New( "start", - Events{ - {Name: "run", Src: []string{"start"}, Dst: "end"}, + Transitions[string, string]{ + {Event: "run", Src: []string{"start"}, Dst: "end"}, }, - Callbacks{ - "before_event": func(e *Event) { - e.Cancel(fmt.Errorf("error")) + Callbacks[string, string]{ + Callback[string, string]{When: BeforeAllEvents, + F: func(e *CallbackContext[string, string]) { + e.Cancel(fmt.Errorf("error")) + }, }, }, ) - err := fsm.Event("run") + if err != nil { + t.Errorf("constructor failed:%s", err) + } + err = fsm.Event("run") if _, ok := err.(CanceledError); !ok { t.Error("expected only 'CanceledError'") } @@ -430,22 +517,27 @@ func TestCancelWithError(t *testing.T) { } func TestAsyncTransitionGenericState(t *testing.T) { - fsm := NewFSM( + fsm, err := New( "start", - Events{ - {Name: "run", Src: []string{"start"}, Dst: "end"}, + Transitions[string, string]{ + {Event: "run", Src: []string{"start"}, Dst: "end"}, }, - Callbacks{ - "leave_state": func(e *Event) { - e.Async() + Callbacks[string, string]{ + Callback[string, string]{When: LeaveState, State: "start", + F: func(e *CallbackContext[string, string]) { + e.Async() + }, }, }, ) + if err != nil { + t.Errorf("constructor failed:%s", err) + } _ = fsm.Event("run") if fsm.Current() != "start" { t.Error("expected state to be 'start'") } - err := fsm.Transition() + err = fsm.Transition() if err != nil { t.Errorf("transition failed %v", err) } @@ -455,22 +547,27 @@ func TestAsyncTransitionGenericState(t *testing.T) { } func TestAsyncTransitionSpecificState(t *testing.T) { - fsm := NewFSM( + fsm, err := New( "start", - Events{ - {Name: "run", Src: []string{"start"}, Dst: "end"}, + Transitions[string, string]{ + {Event: "run", Src: []string{"start"}, Dst: "end"}, }, - Callbacks{ - "leave_start": func(e *Event) { - e.Async() + Callbacks[string, string]{ + Callback[string, string]{When: LeaveState, State: "start", + F: func(e *CallbackContext[string, string]) { + e.Async() + }, }, }, ) + if err != nil { + t.Errorf("constructor failed:%s", err) + } _ = fsm.Event("run") if fsm.Current() != "start" { t.Error("expected state to be 'start'") } - err := fsm.Transition() + err = fsm.Transition() if err != nil { t.Errorf("transition failed %v", err) } @@ -480,20 +577,25 @@ func TestAsyncTransitionSpecificState(t *testing.T) { } func TestAsyncTransitionInProgress(t *testing.T) { - fsm := NewFSM( + fsm, err := New( "start", - Events{ - {Name: "run", Src: []string{"start"}, Dst: "end"}, - {Name: "reset", Src: []string{"end"}, Dst: "start"}, + Transitions[string, string]{ + {Event: "run", Src: []string{"start"}, Dst: "end"}, + {Event: "reset", Src: []string{"end"}, Dst: "start"}, }, - Callbacks{ - "leave_start": func(e *Event) { - e.Async() + Callbacks[string, string]{ + Callback[string, string]{When: LeaveState, State: "start", + F: func(e *CallbackContext[string, string]) { + e.Async() + }, }, }, ) + if err != nil { + t.Errorf("constructor failed:%s", err) + } _ = fsm.Event("run") - err := fsm.Event("reset") + err = fsm.Event("reset") if e, ok := err.(InTransitionError); !ok && e.Event != "reset" { t.Error("expected 'InTransitionError' with correct state") } @@ -511,31 +613,39 @@ func TestAsyncTransitionInProgress(t *testing.T) { } func TestAsyncTransitionNotInProgress(t *testing.T) { - fsm := NewFSM( + fsm, err := New( "start", - Events{ - {Name: "run", Src: []string{"start"}, Dst: "end"}, - {Name: "reset", Src: []string{"end"}, Dst: "start"}, + Transitions[string, string]{ + {Event: "run", Src: []string{"start"}, Dst: "end"}, + {Event: "reset", Src: []string{"end"}, Dst: "start"}, }, - Callbacks{}, + Callbacks[string, string]{}, ) - err := fsm.Transition() + if err != nil { + t.Errorf("constructor failed:%s", err) + } + err = fsm.Transition() if _, ok := err.(NotInTransitionError); !ok { t.Error("expected 'NotInTransitionError'") } } func TestCallbackNoError(t *testing.T) { - fsm := NewFSM( + fsm, err := New( "start", - Events{ - {Name: "run", Src: []string{"start"}, Dst: "end"}, + Transitions[string, string]{ + {Event: "run", Src: []string{"start"}, Dst: "end"}, }, - Callbacks{ - "run": func(e *Event) { + Callbacks[string, string]{ + Callback[string, string]{When: BeforeEvent, Event: "run", + F: func(e *CallbackContext[string, string]) { + }, }, }, ) + if err != nil { + t.Errorf("constructor failed:%s", err) + } e := fsm.Event("run") if e != nil { t.Error("expected no error") @@ -543,17 +653,22 @@ func TestCallbackNoError(t *testing.T) { } func TestCallbackError(t *testing.T) { - fsm := NewFSM( + fsm, err := New( "start", - Events{ - {Name: "run", Src: []string{"start"}, Dst: "end"}, + Transitions[string, string]{ + {Event: "run", Src: []string{"start"}, Dst: "end"}, }, - Callbacks{ - "run": func(e *Event) { - e.Err = fmt.Errorf("error") + Callbacks[string, string]{ + Callback[string, string]{When: BeforeEvent, Event: "run", + F: func(e *CallbackContext[string, string]) { + e.Err = fmt.Errorf("error") + }, }, }, ) + if err != nil { + t.Errorf("constructor failed:%s", err) + } e := fsm.Event("run") if e.Error() != "error" { t.Error("expected error to be 'error'") @@ -561,27 +676,62 @@ func TestCallbackError(t *testing.T) { } func TestCallbackArgs(t *testing.T) { - fsm := NewFSM( + fsm, err := New( + "start", + Transitions[string, string]{ + {Event: "run", Src: []string{"start"}, Dst: "end"}, + }, + Callbacks[string, string]{ + Callback[string, string]{When: BeforeEvent, Event: "run", + F: func(e *CallbackContext[string, string]) { + if len(e.Args) != 1 { + t.Error("too few arguments") + } + arg, ok := e.Args[0].(string) + if !ok { + t.Error("not a string argument") + } + if arg != "test" { + t.Error("incorrect argument") + } + }, + }, + }, + ) + if err != nil { + t.Errorf("constructor failed:%s", err) + } + err = fsm.Event("run", "test") + if err != nil { + t.Errorf("transition failed %v", err) + } +} + +func TestCallbackMeta(t *testing.T) { + fsm, err := New( "start", - Events{ - {Name: "run", Src: []string{"start"}, Dst: "end"}, + Transitions[string, string]{ + {Event: "run", Src: []string{"start"}, Dst: "end"}, }, - Callbacks{ - "run": func(e *Event) { - if len(e.Args) != 1 { - t.Error("too few arguments") - } - arg, ok := e.Args[0].(string) - if !ok { - t.Error("not a string argument") - } - if arg != "test" { - t.Error("incorrect argument") - } + Callbacks[string, string]{ + Callback[string, string]{When: BeforeEvent, Event: "run", + F: func(e *CallbackContext[string, string]) { + value, ok := e.FSM.Metadata("key") + if !ok { + t.Error("no metadata with `key` found") + } + if value != "value" { + t.Error("incorrect value") + } + }, }, }, ) - err := fsm.Event("run", "test") + if err != nil { + t.Errorf("constructor failed:%s", err) + } + fsm.SetMetadata("key", "value") + err = fsm.Event("run") if err != nil { t.Errorf("transition failed %v", err) } @@ -595,17 +745,22 @@ func TestCallbackPanic(t *testing.T) { t.Errorf("expected panic message to be '%s', got %v", panicMsg, r) } }() - fsm := NewFSM( + fsm, err := New( "start", - Events{ - {Name: "run", Src: []string{"start"}, Dst: "end"}, + Transitions[string, string]{ + {Event: "run", Src: []string{"start"}, Dst: "end"}, }, - Callbacks{ - "run": func(e *Event) { - panic(panicMsg) + Callbacks[string, string]{ + Callback[string, string]{When: BeforeEvent, Event: "run", + F: func(e *CallbackContext[string, string]) { + panic(panicMsg) + }, }, }, ) + if err != nil { + t.Errorf("constructor failed:%s", err) + } e := fsm.Event("run") if e.Error() != "error" { t.Error("expected error to be 'error'") @@ -613,42 +768,52 @@ func TestCallbackPanic(t *testing.T) { } func TestNoDeadLock(t *testing.T) { - var fsm *FSM - fsm = NewFSM( + var fsm *FSM[string, string] + fsm, err := New( "start", - Events{ - {Name: "run", Src: []string{"start"}, Dst: "end"}, + Transitions[string, string]{ + {Event: "run", Src: []string{"start"}, Dst: "end"}, }, - Callbacks{ - "run": func(e *Event) { - fsm.Current() // Should not result in a panic / deadlock + Callbacks[string, string]{ + Callback[string, string]{When: BeforeEvent, Event: "run", + F: func(e *CallbackContext[string, string]) { + fsm.Current() // Should not result in a panic / deadlock + }, }, }, ) - err := fsm.Event("run") + if err != nil { + t.Errorf("constructor failed:%s", err) + } + err = fsm.Event("run") if err != nil { t.Errorf("transition failed %v", err) } } func TestThreadSafetyRaceCondition(t *testing.T) { - fsm := NewFSM( + fsm, err := New( "start", - Events{ - {Name: "run", Src: []string{"start"}, Dst: "end"}, + Transitions[string, string]{ + {Event: "run", Src: []string{"start"}, Dst: "end"}, }, - Callbacks{ - "run": func(e *Event) { + Callbacks[string, string]{ + Callback[string, string]{When: BeforeEvent, Event: "run", + F: func(e *CallbackContext[string, string]) { + }, }, }, ) + if err != nil { + t.Errorf("constructor failed:%s", err) + } var wg sync.WaitGroup wg.Add(1) go func() { defer wg.Done() _ = fsm.Current() }() - err := fsm.Event("run") + err = fsm.Event("run") if err != nil { t.Errorf("transition failed %v", err) } @@ -656,97 +821,124 @@ func TestThreadSafetyRaceCondition(t *testing.T) { } func TestDoubleTransition(t *testing.T) { - var fsm *FSM + var fsm *FSM[string, string] var wg sync.WaitGroup wg.Add(2) - fsm = NewFSM( + fsm, err := New( "start", - Events{ - {Name: "run", Src: []string{"start"}, Dst: "end"}, - }, - Callbacks{ - "before_run": func(e *Event) { - wg.Done() - // Imagine a concurrent event coming in of the same type while - // the data access mutex is unlocked because the current transition - // is running its event callbacks, getting around the "active" - // transition checks - if len(e.Args) == 0 { - // Must be concurrent so the test may pass when we add a mutex that synchronizes - // calls to Event(...). It will then fail as an inappropriate transition as we - // have changed state. - go func() { - if err := fsm.Event("run", "second run"); err != nil { - fmt.Println(err) - wg.Done() // It should fail, and then we unfreeze the test. - } - }() - time.Sleep(20 * time.Millisecond) - } else { - panic("Was able to reissue an event mid-transition") - } - }, - }, - ) - if err := fsm.Event("run"); err != nil { + Transitions[string, string]{ + {Event: "run", Src: []string{"start"}, Dst: "end"}, + }, + Callbacks[string, string]{ + Callback[string, string]{When: BeforeEvent, Event: "run", + F: func(e *CallbackContext[string, string]) { + wg.Done() + // Imagine a concurrent event coming in of the same type while + // the data access mutex is unlocked because the current transition + // is running its event callbacks, getting around the "active" + // transition checks + if len(e.Args) == 0 { + // Must be concurrent so the test may pass when we add a mutex that synchronizes + // calls to Event(...). It will then fail as an inappropriate transition as we + // have changed state. + go func() { + if err := fsm.Event("run", "second run"); err != nil { + fmt.Println(err) + wg.Done() // It should fail, and then we unfreeze the test. + } + }() + time.Sleep(20 * time.Millisecond) + } else { + panic("Was able to reissue an event mid-transition") + } + }, + }, + }, + ) + if err != nil { + t.Errorf("constructor failed:%s", err) + } + if err = fsm.Event("run"); err != nil { fmt.Println(err) } wg.Wait() } func TestNoTransition(t *testing.T) { - fsm := NewFSM( + fsm, err := New( "start", - Events{ - {Name: "run", Src: []string{"start"}, Dst: "start"}, + Transitions[string, string]{ + {Event: "run", Src: []string{"start"}, Dst: "start"}, }, - Callbacks{}, + Callbacks[string, string]{}, ) - err := fsm.Event("run") + if err != nil { + t.Errorf("constructor failed:%s", err) + } + err = fsm.Event("run") if _, ok := err.(NoTransitionError); !ok { t.Error("expected 'NoTransitionError'") } } -func ExampleNewFSM() { - fsm := NewFSM( +func ExampleNew() { + fsm, err := New( "green", - Events{ - {Name: "warn", Src: []string{"green"}, Dst: "yellow"}, - {Name: "panic", Src: []string{"yellow"}, Dst: "red"}, - {Name: "panic", Src: []string{"green"}, Dst: "red"}, - {Name: "calm", Src: []string{"red"}, Dst: "yellow"}, - {Name: "clear", Src: []string{"yellow"}, Dst: "green"}, + Transitions[string, string]{ + {Event: "warn", Src: []string{"green"}, Dst: "yellow"}, + {Event: "panic", Src: []string{"yellow"}, Dst: "red"}, + {Event: "panic", Src: []string{"green"}, Dst: "red"}, + {Event: "calm", Src: []string{"red"}, Dst: "yellow"}, + {Event: "clear", Src: []string{"yellow"}, Dst: "green"}, }, - Callbacks{ - "before_warn": func(e *Event) { - fmt.Println("before_warn") + Callbacks[string, string]{ + Callback[string, string]{When: BeforeEvent, Event: "warn", + F: func(cc *CallbackContext[string, string]) { + fmt.Println("before_warn") + }, }, - "before_event": func(e *Event) { - fmt.Println("before_event") + Callback[string, string]{When: BeforeAllEvents, + F: func(cc *CallbackContext[string, string]) { + fmt.Println("before_event") + }, }, - "leave_green": func(e *Event) { - fmt.Println("leave_green") + Callback[string, string]{When: LeaveState, State: "green", + F: func(cc *CallbackContext[string, string]) { + fmt.Println("leave_green") + }, }, - "leave_state": func(e *Event) { - fmt.Println("leave_state") + Callback[string, string]{When: LeaveAllStates, + F: func(cc *CallbackContext[string, string]) { + fmt.Println("leave_state") + }, }, - "enter_yellow": func(e *Event) { - fmt.Println("enter_yellow") + Callback[string, string]{When: EnterState, State: "yellow", + F: func(cc *CallbackContext[string, string]) { + fmt.Println("enter_yellow") + }, }, - "enter_state": func(e *Event) { - fmt.Println("enter_state") + Callback[string, string]{When: EnterAllStates, + F: func(cc *CallbackContext[string, string]) { + fmt.Println("enter_state") + }, }, - "after_warn": func(e *Event) { - fmt.Println("after_warn") + Callback[string, string]{When: AfterEvent, Event: "warn", + F: func(cc *CallbackContext[string, string]) { + fmt.Println("after_warn") + }, }, - "after_event": func(e *Event) { - fmt.Println("after_event") + Callback[string, string]{When: AfterAllEvents, + F: func(cc *CallbackContext[string, string]) { + fmt.Println("after_event") + }, }, }, ) + if err != nil { + fmt.Println(err) + } fmt.Println(fsm.Current()) - err := fsm.Event("warn") + err = fsm.Event("warn") if err != nil { fmt.Println(err) } @@ -765,27 +957,33 @@ func ExampleNewFSM() { } func ExampleFSM_Current() { - fsm := NewFSM( + fsm, err := New( "closed", - Events{ - {Name: "open", Src: []string{"closed"}, Dst: "open"}, - {Name: "close", Src: []string{"open"}, Dst: "closed"}, + Transitions[string, string]{ + {Event: "open", Src: []string{"closed"}, Dst: "open"}, + {Event: "close", Src: []string{"open"}, Dst: "closed"}, }, - Callbacks{}, + Callbacks[string, string]{}, ) + if err != nil { + fmt.Println(err) + } fmt.Println(fsm.Current()) // Output: closed } func ExampleFSM_Is() { - fsm := NewFSM( + fsm, err := New( "closed", - Events{ - {Name: "open", Src: []string{"closed"}, Dst: "open"}, - {Name: "close", Src: []string{"open"}, Dst: "closed"}, + Transitions[string, string]{ + {Event: "open", Src: []string{"closed"}, Dst: "open"}, + {Event: "close", Src: []string{"open"}, Dst: "closed"}, }, - Callbacks{}, + Callbacks[string, string]{}, ) + if err != nil { + fmt.Println(err) + } fmt.Println(fsm.Is("closed")) fmt.Println(fsm.Is("open")) // Output: @@ -794,14 +992,17 @@ func ExampleFSM_Is() { } func ExampleFSM_Can() { - fsm := NewFSM( + fsm, err := New( "closed", - Events{ - {Name: "open", Src: []string{"closed"}, Dst: "open"}, - {Name: "close", Src: []string{"open"}, Dst: "closed"}, + Transitions[string, string]{ + {Event: "open", Src: []string{"closed"}, Dst: "open"}, + {Event: "close", Src: []string{"open"}, Dst: "closed"}, }, - Callbacks{}, + Callbacks[string, string]{}, ) + if err != nil { + fmt.Println(err) + } fmt.Println(fsm.Can("open")) fmt.Println(fsm.Can("close")) // Output: @@ -810,15 +1011,18 @@ func ExampleFSM_Can() { } func ExampleFSM_AvailableTransitions() { - fsm := NewFSM( + fsm, err := New( "closed", - Events{ - {Name: "open", Src: []string{"closed"}, Dst: "open"}, - {Name: "close", Src: []string{"open"}, Dst: "closed"}, - {Name: "kick", Src: []string{"closed"}, Dst: "broken"}, + Transitions[string, string]{ + {Event: "open", Src: []string{"closed"}, Dst: "open"}, + {Event: "close", Src: []string{"open"}, Dst: "closed"}, + {Event: "kick", Src: []string{"closed"}, Dst: "broken"}, }, - Callbacks{}, + Callbacks[string, string]{}, ) + if err != nil { + fmt.Println(err) + } // sort the results ordering is consistent for the output checker transitions := fsm.AvailableTransitions() sort.Strings(transitions) @@ -828,14 +1032,17 @@ func ExampleFSM_AvailableTransitions() { } func ExampleFSM_Cannot() { - fsm := NewFSM( + fsm, err := New( "closed", - Events{ - {Name: "open", Src: []string{"closed"}, Dst: "open"}, - {Name: "close", Src: []string{"open"}, Dst: "closed"}, + Transitions[string, string]{ + {Event: "open", Src: []string{"closed"}, Dst: "open"}, + {Event: "close", Src: []string{"open"}, Dst: "closed"}, }, - Callbacks{}, + Callbacks[string, string]{}, ) + if err != nil { + fmt.Println(err) + } fmt.Println(fsm.Cannot("open")) fmt.Println(fsm.Cannot("close")) // Output: @@ -844,16 +1051,19 @@ func ExampleFSM_Cannot() { } func ExampleFSM_Event() { - fsm := NewFSM( + fsm, err := New( "closed", - Events{ - {Name: "open", Src: []string{"closed"}, Dst: "open"}, - {Name: "close", Src: []string{"open"}, Dst: "closed"}, + Transitions[string, string]{ + {Event: "open", Src: []string{"closed"}, Dst: "open"}, + {Event: "close", Src: []string{"open"}, Dst: "closed"}, }, - Callbacks{}, + Callbacks[string, string]{}, ) + if err != nil { + fmt.Println(err) + } fmt.Println(fsm.Current()) - err := fsm.Event("open") + err = fsm.Event("open") if err != nil { fmt.Println(err) } @@ -870,19 +1080,25 @@ func ExampleFSM_Event() { } func ExampleFSM_Transition() { - fsm := NewFSM( + fsm, err := New( "closed", - Events{ - {Name: "open", Src: []string{"closed"}, Dst: "open"}, - {Name: "close", Src: []string{"open"}, Dst: "closed"}, + Transitions[string, string]{ + {Event: "open", Src: []string{"closed"}, Dst: "open"}, + {Event: "close", Src: []string{"open"}, Dst: "closed"}, }, - Callbacks{ - "leave_closed": func(e *Event) { - e.Async() + Callbacks[string, string]{ + Callback[string, string]{ + When: LeaveState, State: "closed", + F: func(cc *CallbackContext[string, string]) { + cc.Async() + }, }, }, ) - err := fsm.Event("open") + if err != nil { + fmt.Println(err) + } + err = fsm.Event("open") if e, ok := err.(AsyncError); !ok && e.Err != nil { fmt.Println(err) } @@ -896,3 +1112,130 @@ func ExampleFSM_Transition() { // closed // open } + +type MyEvent string +type MyState string + +const ( + Close MyEvent = "close" + Open MyEvent = "open" + Any MyEvent = "" + + IsClosed MyState = "closed" + IsOpen MyState = "open" +) + +func ExampleFSM_Event_generic() { + fsm, err := New( + IsClosed, + Transitions[MyEvent, MyState]{ + {Event: Open, Src: []MyState{IsClosed}, Dst: IsOpen}, + {Event: Close, Src: []MyState{IsOpen}, Dst: IsClosed}, + }, + Callbacks[MyEvent, MyState]{ + Callback[MyEvent, MyState]{ + When: BeforeEvent, + Event: Close, + F: func(cc *CallbackContext[MyEvent, MyState]) { + + }, + }, + }, + ) + if err != nil { + fmt.Println(err) + } + fmt.Println(fsm.Current()) + err = fsm.Event(Open) + if err != nil { + fmt.Println(err) + } + fmt.Println(fsm.Current()) + err = fsm.Event(Close) + if err != nil { + fmt.Println(err) + } + fmt.Println(fsm.Current()) + // Output: + // closed + // open + // closed +} + +func BenchmarkGenericFSM(b *testing.B) { + fsm, err := New( + IsClosed, + Transitions[MyEvent, MyState]{ + {Event: Open, Src: []MyState{IsClosed}, Dst: IsOpen}, + {Event: Close, Src: []MyState{IsOpen}, Dst: IsClosed}, + }, + Callbacks[MyEvent, MyState]{ + + Callback[MyEvent, MyState]{ + When: BeforeEvent, + Event: Open, + F: func(cc *CallbackContext[MyEvent, MyState]) { + + }, + }, + }, + ) + if err != nil { + fmt.Println(err) + } + for i := 0; i < b.N; i++ { + _ = fsm.Event(Open) + } +} +func BenchmarkFSM(b *testing.B) { + fsm, err := New( + "closed", + Transitions[string, string]{ + {Event: "open", Src: []string{"closed"}, Dst: "open"}, + {Event: "close", Src: []string{"open"}, Dst: "closed"}, + }, + Callbacks[string, string]{ + Callback[string, string]{ + When: BeforeEvent, + Event: "open", + F: func(cc *CallbackContext[string, string]) { + + }, + }, + }, + ) + if err != nil { + fmt.Println(err) + } + for i := 0; i < b.N; i++ { + _ = fsm.Event("open") + } +} + +func BenchmarkGenericFSMManyEvents(b *testing.B) { + transitions := Transitions[int, int]{} + for i := 0; i < 100; i++ { + transitions = append(transitions, Transition[int, int]{Event: i, Src: []int{i}, Dst: i + 1}) + } + callbacks := Callbacks[int, int]{} + for i := 0; i < 100; i++ { + callbacks = append(callbacks, Callback[int, int]{ + When: BeforeAllEvents, + F: func(cc *CallbackContext[int, int]) { + fmt.Print(cc.Event) + }, + }) + } + + fsm, err := New( + 0, + transitions, + callbacks, + ) + if err != nil { + fmt.Println(err) + } + for i := 0; i < b.N; i++ { + _ = fsm.Event(1) + } +} diff --git a/go.mod b/go.mod index 1af2b1c..d59dada 100644 --- a/go.mod +++ b/go.mod @@ -1,3 +1,3 @@ -module github.com/looplab/fsm +module github.com/looplab/fsm/v2 -go 1.16 +go 1.24.0 diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..e69de29 diff --git a/graphviz_visualizer.go b/graphviz_visualizer.go index 5a5b641..3f5148c 100644 --- a/graphviz_visualizer.go +++ b/graphviz_visualizer.go @@ -2,11 +2,12 @@ package fsm import ( "bytes" + "cmp" "fmt" ) // Visualize outputs a visualization of a FSM in Graphviz format. -func Visualize(fsm *FSM) string { +func Visualize[E cmp.Ordered, S cmp.Ordered](fsm *FSM[E, S]) string { var buf bytes.Buffer // we sort the key alphabetically to have a reproducible graph output @@ -26,19 +27,19 @@ func writeHeaderLine(buf *bytes.Buffer) { buf.WriteString("\n") } -func writeTransitions(buf *bytes.Buffer, current string, sortedEKeys []eKey, transitions map[eKey]string) { +func writeTransitions[E cmp.Ordered, S cmp.Ordered](buf *bytes.Buffer, current S, sortedEKeys []eKey[E, S], transitions map[eKey[E, S]]S) { // make sure the current state is at top for _, k := range sortedEKeys { if k.src == current { v := transitions[k] - buf.WriteString(fmt.Sprintf(` "%s" -> "%s" [ label = "%s" ];`, k.src, v, k.event)) + fmt.Fprintf(buf, ` "%v" -> "%v" [ label = "%v" ];`, k.src, v, k.event) buf.WriteString("\n") } } for _, k := range sortedEKeys { if k.src != current { v := transitions[k] - buf.WriteString(fmt.Sprintf(` "%s" -> "%s" [ label = "%s" ];`, k.src, v, k.event)) + fmt.Fprintf(buf, ` "%v" -> "%v" [ label = "%v" ];`, k.src, v, k.event) buf.WriteString("\n") } } @@ -46,13 +47,13 @@ func writeTransitions(buf *bytes.Buffer, current string, sortedEKeys []eKey, tra buf.WriteString("\n") } -func writeStates(buf *bytes.Buffer, sortedStateKeys []string) { +func writeStates[S cmp.Ordered](buf *bytes.Buffer, sortedStateKeys []S) { for _, k := range sortedStateKeys { - buf.WriteString(fmt.Sprintf(` "%s";`, k)) + fmt.Fprintf(buf, ` "%v";`, k) buf.WriteString("\n") } } func writeFooter(buf *bytes.Buffer) { - buf.WriteString(fmt.Sprintln("}")) + fmt.Fprintln(buf, "}") } diff --git a/graphviz_visualizer_test.go b/graphviz_visualizer_test.go index b28c476..d260860 100644 --- a/graphviz_visualizer_test.go +++ b/graphviz_visualizer_test.go @@ -7,16 +7,18 @@ import ( ) func TestGraphvizOutput(t *testing.T) { - fsmUnderTest := NewFSM( + fsmUnderTest, err := New( "closed", - Events{ - {Name: "open", Src: []string{"closed"}, Dst: "open"}, - {Name: "close", Src: []string{"open"}, Dst: "closed"}, - {Name: "part-close", Src: []string{"intermediate"}, Dst: "closed"}, + Transitions[string, string]{ + {Event: "open", Src: []string{"closed"}, Dst: "open"}, + {Event: "close", Src: []string{"open"}, Dst: "closed"}, + {Event: "part-close", Src: []string{"intermediate"}, Dst: "closed"}, }, - Callbacks{}, + Callbacks[string, string]{}, ) - + if err != nil { + t.Errorf("constructor failed:%s", err) + } got := Visualize(fsmUnderTest) wanted := ` digraph fsm { @@ -32,7 +34,7 @@ digraph fsm { normalizedWanted := strings.ReplaceAll(wanted, "\n", "") if normalizedGot != normalizedWanted { t.Errorf("build graphivz graph failed. \nwanted \n%s\nand got \n%s\n", wanted, got) - fmt.Println([]byte(normalizedGot)) - fmt.Println([]byte(normalizedWanted)) + fmt.Println(normalizedGot) + fmt.Println(normalizedWanted) } } diff --git a/mermaid_visualizer.go b/mermaid_visualizer.go index d9b089e..16748a8 100644 --- a/mermaid_visualizer.go +++ b/mermaid_visualizer.go @@ -2,6 +2,7 @@ package fsm import ( "bytes" + "cmp" "fmt" ) @@ -18,7 +19,7 @@ const ( ) // VisualizeForMermaidWithGraphType outputs a visualization of a FSM in Mermaid format as specified by the graphType. -func VisualizeForMermaidWithGraphType(fsm *FSM, graphType MermaidDiagramType) (string, error) { +func VisualizeForMermaidWithGraphType[E cmp.Ordered, S cmp.Ordered](fsm *FSM[E, S], graphType MermaidDiagramType) (string, error) { switch graphType { case FlowChart: return visualizeForMermaidAsFlowChart(fsm), nil @@ -29,7 +30,7 @@ func VisualizeForMermaidWithGraphType(fsm *FSM, graphType MermaidDiagramType) (s } } -func visualizeForMermaidAsStateDiagram(fsm *FSM) string { +func visualizeForMermaidAsStateDiagram[E cmp.Ordered, S cmp.Ordered](fsm *FSM[E, S]) string { var buf bytes.Buffer sortedTransitionKeys := getSortedTransitionKeys(fsm.transitions) @@ -39,7 +40,7 @@ func visualizeForMermaidAsStateDiagram(fsm *FSM) string { for _, k := range sortedTransitionKeys { v := fsm.transitions[k] - buf.WriteString(fmt.Sprintf(` %s --> %s: %s`, k.src, v, k.event)) + buf.WriteString(fmt.Sprintf(` %v --> %v: %v`, k.src, v, k.event)) buf.WriteString("\n") } @@ -47,7 +48,7 @@ func visualizeForMermaidAsStateDiagram(fsm *FSM) string { } // visualizeForMermaidAsFlowChart outputs a visualization of a FSM in Mermaid format (including highlighting of current state). -func visualizeForMermaidAsFlowChart(fsm *FSM) string { +func visualizeForMermaidAsFlowChart[E cmp.Ordered, S cmp.Ordered](fsm *FSM[E, S]) string { var buf bytes.Buffer sortedTransitionKeys := getSortedTransitionKeys(fsm.transitions) @@ -65,25 +66,25 @@ func writeFlowChartGraphType(buf *bytes.Buffer) { buf.WriteString("graph LR\n") } -func writeFlowChartStates(buf *bytes.Buffer, sortedStates []string, statesToIDMap map[string]string) { +func writeFlowChartStates[S cmp.Ordered](buf *bytes.Buffer, sortedStates []S, statesToIDMap map[S]string) { for _, state := range sortedStates { - buf.WriteString(fmt.Sprintf(` %s[%s]`, statesToIDMap[state], state)) + fmt.Fprintf(buf, ` %s[%v]`, statesToIDMap[state], state) buf.WriteString("\n") } buf.WriteString("\n") } -func writeFlowChartTransitions(buf *bytes.Buffer, transitions map[eKey]string, sortedTransitionKeys []eKey, statesToIDMap map[string]string) { +func writeFlowChartTransitions[E cmp.Ordered, S cmp.Ordered](buf *bytes.Buffer, transitions map[eKey[E, S]]S, sortedTransitionKeys []eKey[E, S], statesToIDMap map[S]string) { for _, transition := range sortedTransitionKeys { target := transitions[transition] - buf.WriteString(fmt.Sprintf(` %s --> |%s| %s`, statesToIDMap[transition.src], transition.event, statesToIDMap[target])) + fmt.Fprintf(buf, ` %s --> |%v| %s`, statesToIDMap[transition.src], transition.event, statesToIDMap[target]) buf.WriteString("\n") } buf.WriteString("\n") } -func writeFlowChartHighlightCurrent(buf *bytes.Buffer, current string, statesToIDMap map[string]string) { - buf.WriteString(fmt.Sprintf(` style %s fill:%s`, statesToIDMap[current], highlightingColor)) +func writeFlowChartHighlightCurrent[S cmp.Ordered](buf *bytes.Buffer, current S, statesToIDMap map[S]string) { + fmt.Fprintf(buf, ` style %s fill:%s`, statesToIDMap[current], highlightingColor) buf.WriteString("\n") } diff --git a/mermaid_visualizer_test.go b/mermaid_visualizer_test.go index f922ba5..6bfa6cd 100644 --- a/mermaid_visualizer_test.go +++ b/mermaid_visualizer_test.go @@ -7,16 +7,18 @@ import ( ) func TestMermaidOutput(t *testing.T) { - fsmUnderTest := NewFSM( + fsmUnderTest, err := New( "closed", - Events{ - {Name: "open", Src: []string{"closed"}, Dst: "open"}, - {Name: "close", Src: []string{"open"}, Dst: "closed"}, - {Name: "part-close", Src: []string{"intermediate"}, Dst: "closed"}, + Transitions[string, string]{ + {Event: "open", Src: []string{"closed"}, Dst: "open"}, + {Event: "close", Src: []string{"open"}, Dst: "closed"}, + {Event: "part-close", Src: []string{"intermediate"}, Dst: "closed"}, }, - Callbacks{}, + Callbacks[string, string]{}, ) - + if err != nil { + t.Errorf("constructor failed:%s", err) + } got, err := VisualizeForMermaidWithGraphType(fsmUnderTest, StateDiagram) if err != nil { t.Errorf("got error for visualizing with type MERMAID: %s", err) @@ -32,24 +34,26 @@ stateDiagram-v2 normalizedWanted := strings.ReplaceAll(wanted, "\n", "") if normalizedGot != normalizedWanted { t.Errorf("build mermaid graph failed. \nwanted \n%s\nand got \n%s\n", wanted, got) - fmt.Println([]byte(normalizedGot)) - fmt.Println([]byte(normalizedWanted)) + fmt.Println(normalizedGot) + fmt.Println(normalizedWanted) } } func TestMermaidFlowChartOutput(t *testing.T) { - fsmUnderTest := NewFSM( + fsmUnderTest, err := New( "closed", - Events{ - {Name: "open", Src: []string{"closed"}, Dst: "open"}, - {Name: "part-open", Src: []string{"closed"}, Dst: "intermediate"}, - {Name: "part-open", Src: []string{"intermediate"}, Dst: "open"}, - {Name: "close", Src: []string{"open"}, Dst: "closed"}, - {Name: "part-close", Src: []string{"intermediate"}, Dst: "closed"}, + Transitions[string, string]{ + {Event: "open", Src: []string{"closed"}, Dst: "open"}, + {Event: "part-open", Src: []string{"closed"}, Dst: "intermediate"}, + {Event: "part-open", Src: []string{"intermediate"}, Dst: "open"}, + {Event: "close", Src: []string{"open"}, Dst: "closed"}, + {Event: "part-close", Src: []string{"intermediate"}, Dst: "closed"}, }, - Callbacks{}, + Callbacks[string, string]{}, ) - + if err != nil { + t.Errorf("constructor failed:%s", err) + } got, err := VisualizeForMermaidWithGraphType(fsmUnderTest, FlowChart) if err != nil { t.Errorf("got error for visualizing with type MERMAID: %s", err) @@ -72,7 +76,7 @@ graph LR normalizedWanted := strings.ReplaceAll(wanted, "\n", "") if normalizedGot != normalizedWanted { t.Errorf("build mermaid graph failed. \nwanted \n%s\nand got \n%s\n", wanted, got) - fmt.Println([]byte(normalizedGot)) - fmt.Println([]byte(normalizedWanted)) + fmt.Println(normalizedGot) + fmt.Println(normalizedWanted) } } diff --git a/visualizer.go b/visualizer.go index 04cc872..71fac67 100644 --- a/visualizer.go +++ b/visualizer.go @@ -1,7 +1,9 @@ package fsm import ( + "cmp" "fmt" + "slices" "sort" ) @@ -21,7 +23,7 @@ const ( // VisualizeWithType outputs a visualization of a FSM in the desired format. // If the type is not given it defaults to GRAPHVIZ -func VisualizeWithType(fsm *FSM, visualizeType VisualizeType) (string, error) { +func VisualizeWithType[E cmp.Ordered, S cmp.Ordered](fsm *FSM[E, S], visualizeType VisualizeType) (string, error) { switch visualizeType { case GRAPHVIZ: return Visualize(fsm), nil @@ -36,9 +38,9 @@ func VisualizeWithType(fsm *FSM, visualizeType VisualizeType) (string, error) { } } -func getSortedTransitionKeys(transitions map[eKey]string) []eKey { +func getSortedTransitionKeys[E cmp.Ordered, S cmp.Ordered](transitions map[eKey[E, S]]S) []eKey[E, S] { // we sort the key alphabetically to have a reproducible graph output - sortedTransitionKeys := make([]eKey, 0) + sortedTransitionKeys := make([]eKey[E, S], 0) for transition := range transitions { sortedTransitionKeys = append(sortedTransitionKeys, transition) @@ -53,8 +55,8 @@ func getSortedTransitionKeys(transitions map[eKey]string) []eKey { return sortedTransitionKeys } -func getSortedStates(transitions map[eKey]string) ([]string, map[string]string) { - statesToIDMap := make(map[string]string) +func getSortedStates[E cmp.Ordered, S cmp.Ordered](transitions map[eKey[E, S]]S) ([]S, map[S]string) { + statesToIDMap := make(map[S]string) for transition, target := range transitions { if _, ok := statesToIDMap[transition.src]; !ok { statesToIDMap[transition.src] = "" @@ -64,11 +66,12 @@ func getSortedStates(transitions map[eKey]string) ([]string, map[string]string) } } - sortedStates := make([]string, 0, len(statesToIDMap)) + sortedStates := make([]S, 0, len(statesToIDMap)) for state := range statesToIDMap { sortedStates = append(sortedStates, state) } - sort.Strings(sortedStates) + + slices.Sort(sortedStates) for i, state := range sortedStates { statesToIDMap[state] = fmt.Sprintf("id%d", i)