From 20db07dccf0481be7a132d7dcbf251281eb9baed Mon Sep 17 00:00:00 2001 From: sputn1ck Date: Wed, 23 Aug 2023 11:27:23 +0200 Subject: [PATCH] fsm: add fsm module This commit adds a module for a finite state machine. The goal of the module is to provide a simple, easy to use, and easy to understand finite state machine. The module is designed to be used in future loop subsystems. Additionally a state visualizer is provided to help with understanding the state machine. --- .golangci.yml | 5 + Makefile | 5 +- fsm/example_fsm.go | 127 ++++++++++++++ fsm/example_fsm.md | 12 ++ fsm/example_fsm_test.go | 245 +++++++++++++++++++++++++++ fsm/fsm.go | 296 +++++++++++++++++++++++++++++++++ fsm/fsm.md | 139 ++++++++++++++++ fsm/fsm_test.go | 117 +++++++++++++ fsm/log.go | 26 +++ fsm/observer.go | 134 +++++++++++++++ fsm/stateparser/stateparser.go | 96 +++++++++++ loopd/log.go | 2 + scripts/fsm-generate.sh | 2 + 13 files changed, 1205 insertions(+), 1 deletion(-) create mode 100644 fsm/example_fsm.go create mode 100644 fsm/example_fsm.md create mode 100644 fsm/example_fsm_test.go create mode 100644 fsm/fsm.go create mode 100644 fsm/fsm.md create mode 100644 fsm/fsm_test.go create mode 100644 fsm/log.go create mode 100644 fsm/observer.go create mode 100644 fsm/stateparser/stateparser.go create mode 100755 scripts/fsm-generate.sh diff --git a/.golangci.yml b/.golangci.yml index 1af0453..3173640 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -133,5 +133,10 @@ issues: # Allow fmt.Printf() in loop - path: cmd/loop/* + linters: + - forbidigo + + # Allow fmt.Printf() in stateparser + - path: fsm/stateparser/* linters: - forbidigo \ No newline at end of file diff --git a/Makefile b/Makefile index 0f5ff9d..bb458cd 100644 --- a/Makefile +++ b/Makefile @@ -134,4 +134,7 @@ sqlc-check: sqlc @$(call print, "Verifying sql code generation.") if test -n "$$(git status --porcelain '*.go')"; then echo "SQL models not properly generated!"; git status --porcelain '*.go'; exit 1; fi - +fsm: + @$(call print, "Generating state machine docs") + ./scripts/fsm-generate.sh; +.PHONY: fsm diff --git a/fsm/example_fsm.go b/fsm/example_fsm.go new file mode 100644 index 0000000..723f9dd --- /dev/null +++ b/fsm/example_fsm.go @@ -0,0 +1,127 @@ +package fsm + +import ( + "fmt" +) + +// ExampleService is an example service that we want to wait for in the FSM. +type ExampleService interface { + WaitForStuffHappening() (<-chan bool, error) +} + +// ExampleStore is an example store that we want to use in our exitFunc. +type ExampleStore interface { + StoreStuff() error +} + +// ExampleFSM implements the FSM and uses the ExampleService and ExampleStore +// to implement the actions. +type ExampleFSM struct { + *StateMachine + + service ExampleService + store ExampleStore +} + +// NewExampleFSMContext creates a new example FSM context. +func NewExampleFSMContext(service ExampleService, + store ExampleStore) *ExampleFSM { + + exampleFSM := &ExampleFSM{ + service: service, + store: store, + } + exampleFSM.StateMachine = NewStateMachine(exampleFSM.GetStates()) + + return exampleFSM +} + +// States. +const ( + InitFSM = StateType("InitFSM") + StuffSentOut = StateType("StuffSentOut") + WaitingForStuff = StateType("WaitingForStuff") + StuffFailed = StateType("StuffFailed") + StuffSuccess = StateType("StuffSuccess") +) + +// Events. +var ( + OnRequestStuff = EventType("OnRequestStuff") + OnStuffSentOut = EventType("OnStuffSentOut") + OnStuffSuccess = EventType("OnStuffSuccess") +) + +// GetStates returns the states for the example FSM. +func (e *ExampleFSM) GetStates() States { + return States{ + Default: State{ + Transitions: Transitions{ + OnRequestStuff: InitFSM, + }, + }, + InitFSM: State{ + Action: e.initFSM, + Transitions: Transitions{ + OnStuffSentOut: StuffSentOut, + OnError: StuffFailed, + }, + }, + StuffSentOut: State{ + Action: e.waitForStuff, + Transitions: Transitions{ + OnStuffSuccess: StuffSuccess, + OnError: StuffFailed, + }, + }, + StuffFailed: State{ + Action: NoOpAction, + }, + StuffSuccess: State{ + Action: NoOpAction, + }, + } +} + +// InitStuffRequest is the event context for the InitFSM state. +type InitStuffRequest struct { + Stuff string + respondChan chan<- string +} + +// initFSM is the action for the InitFSM state. +func (e *ExampleFSM) initFSM(eventCtx EventContext) EventType { + req, ok := eventCtx.(*InitStuffRequest) + if !ok { + return e.HandleError( + fmt.Errorf("invalid event context type: %T", eventCtx), + ) + } + + err := e.store.StoreStuff() + if err != nil { + return e.HandleError(err) + } + + req.respondChan <- req.Stuff + + return OnStuffSentOut +} + +// waitForStuff is an action that waits for stuff to happen. +func (e *ExampleFSM) waitForStuff(eventCtx EventContext) EventType { + waitChan, err := e.service.WaitForStuffHappening() + if err != nil { + return e.HandleError(err) + } + + go func() { + <-waitChan + err := e.SendEvent(OnStuffSuccess, nil) + if err != nil { + log.Errorf("unable to send event: %v", err) + } + }() + + return NoOp +} diff --git a/fsm/example_fsm.md b/fsm/example_fsm.md new file mode 100644 index 0000000..7de0644 --- /dev/null +++ b/fsm/example_fsm.md @@ -0,0 +1,12 @@ +```mermaid +stateDiagram-v2 +[*] --> InitFSM: OnRequestStuff +InitFSM +InitFSM --> StuffFailed: OnError +InitFSM --> StuffSentOut: OnStuffSentOut +StuffFailed +StuffSentOut +StuffSentOut --> StuffFailed: OnError +StuffSentOut --> StuffSuccess: OnStuffSuccess +StuffSuccess +``` \ No newline at end of file diff --git a/fsm/example_fsm_test.go b/fsm/example_fsm_test.go new file mode 100644 index 0000000..5855a8e --- /dev/null +++ b/fsm/example_fsm_test.go @@ -0,0 +1,245 @@ +package fsm + +import ( + "context" + "errors" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +var ( + errService = errors.New("service error") + errStore = errors.New("store error") +) + +type mockStore struct { + storeErr error +} + +func (m *mockStore) StoreStuff() error { + return m.storeErr +} + +type mockService struct { + respondChan chan bool + respondErr error +} + +func (m *mockService) WaitForStuffHappening() (<-chan bool, error) { + return m.respondChan, m.respondErr +} + +func newInitStuffRequest() *InitStuffRequest { + return &InitStuffRequest{ + Stuff: "stuff", + respondChan: make(chan<- string, 1), + } +} + +func TestExampleFSM(t *testing.T) { + testCases := []struct { + name string + expectedState StateType + eventCtx EventContext + expectedLastActionError error + + sendEvent EventType + sendEventErr error + + serviceErr error + storeErr error + }{ + { + name: "success", + expectedState: StuffSuccess, + eventCtx: newInitStuffRequest(), + sendEvent: OnRequestStuff, + }, + { + name: "service error", + expectedState: StuffFailed, + eventCtx: newInitStuffRequest(), + sendEvent: OnRequestStuff, + serviceErr: errService, + expectedLastActionError: errService, + }, + { + name: "store error", + expectedLastActionError: errStore, + storeErr: errStore, + sendEvent: OnRequestStuff, + expectedState: StuffFailed, + eventCtx: newInitStuffRequest(), + }, + } + + for _, tc := range testCases { + tc := tc + + t.Run(tc.name, func(t *testing.T) { + respondChan := make(chan string, 1) + if req, ok := tc.eventCtx.(*InitStuffRequest); ok { + req.respondChan = respondChan + } + + serviceResponseChan := make(chan bool, 1) + serviceResponseChan <- true + + service := &mockService{ + respondChan: serviceResponseChan, + respondErr: tc.serviceErr, + } + + store := &mockStore{ + storeErr: tc.storeErr, + } + + exampleContext := NewExampleFSMContext(service, store) + cachedObserver := NewCachedObserver(100) + + exampleContext.RegisterObserver(cachedObserver) + + err := exampleContext.SendEvent( + tc.sendEvent, tc.eventCtx, + ) + require.Equal(t, tc.sendEventErr, err) + + require.Equal( + t, + tc.expectedLastActionError, + exampleContext.LastActionError, + ) + + err = cachedObserver.WaitForState( + context.Background(), + time.Second, + tc.expectedState, + ) + require.NoError(t, err) + }) + } +} + +// getTestContext returns a test context for the example FSM and a cached +// observer that can be used to verify the state transitions. +func getTestContext() (*ExampleFSM, *CachedObserver) { + service := &mockService{ + respondChan: make(chan bool, 1), + } + service.respondChan <- true + + store := &mockStore{} + + exampleContext := NewExampleFSMContext(service, store) + cachedObserver := NewCachedObserver(100) + + exampleContext.RegisterObserver(cachedObserver) + + return exampleContext, cachedObserver +} + +// TestExampleFSMFlow tests different flows that the example FSM can go through. +func TestExampleFSMFlow(t *testing.T) { + testCases := []struct { + name string + expectedStateFlow []StateType + expectedEventFlow []EventType + storeError error + serviceError error + }{ + { + name: "success", + expectedStateFlow: []StateType{ + InitFSM, + StuffSentOut, + StuffSuccess, + }, + expectedEventFlow: []EventType{ + OnRequestStuff, + OnStuffSentOut, + OnStuffSuccess, + }, + }, + { + name: "failure on store", + expectedStateFlow: []StateType{ + InitFSM, + StuffFailed, + }, + expectedEventFlow: []EventType{ + OnRequestStuff, + OnError, + }, + storeError: errStore, + }, + { + name: "failure on service", + expectedStateFlow: []StateType{ + InitFSM, + StuffSentOut, + StuffFailed, + }, + expectedEventFlow: []EventType{ + OnRequestStuff, + OnStuffSentOut, + OnError, + }, + serviceError: errService, + }, + } + + for _, tc := range testCases { + tc := tc + + t.Run(tc.name, func(t *testing.T) { + exampleContext, cachedObserver := getTestContext() + + if tc.storeError != nil { + exampleContext.store.(*mockStore). + storeErr = tc.storeError + } + + if tc.serviceError != nil { + exampleContext.service.(*mockService). + respondErr = tc.serviceError + } + + go func() { + err := exampleContext.SendEvent( + OnRequestStuff, + newInitStuffRequest(), + ) + + require.NoError(t, err) + }() + + // Wait for the final state. + err := cachedObserver.WaitForState( + context.Background(), + time.Second, + tc.expectedStateFlow[len( + tc.expectedStateFlow, + )-1], + ) + require.NoError(t, err) + + allNotifications := cachedObserver. + GetCachedNotifications() + + for index, notification := range allNotifications { + require.Equal( + t, + tc.expectedStateFlow[index], + notification.NextState, + ) + require.Equal( + t, + tc.expectedEventFlow[index], + notification.Event, + ) + } + }) + } +} diff --git a/fsm/fsm.go b/fsm/fsm.go new file mode 100644 index 0000000..479f203 --- /dev/null +++ b/fsm/fsm.go @@ -0,0 +1,296 @@ +package fsm + +import ( + "errors" + "fmt" + "sync" +) + +// ErrEventRejected is the error returned when the state machine cannot process +// an event in the state that it is in. +var ( + ErrEventRejected = errors.New("event rejected") + ErrWaitForStateTimedOut = errors.New( + "timed out while waiting for event", + ) + ErrInvalidContextType = errors.New("invalid context") +) + +const ( + // Default represents the default state of the system. + Default StateType = "" + + // NoOp represents a no-op event. + NoOp EventType = "NoOp" + + // OnError can be used when an action returns a generic error. + OnError EventType = "OnError" + + // ContextValidationFailed can be when the passed context if + // not of the expected type. + ContextValidationFailed EventType = "ContextValidationFailed" +) + +// StateType represents an extensible state type in the state machine. +type StateType string + +// EventType represents an extensible event type in the state machine. +type EventType string + +// EventContext represents the context to be passed to the action +// implementation. +type EventContext interface{} + +// Action represents the action to be executed in a given state. +type Action func(eventCtx EventContext) EventType + +// Transitions represents a mapping of events and states. +type Transitions map[EventType]StateType + +// State binds a state with an action and a set of events it can handle. +type State struct { + // EntryFunc is a function that is called when the state is entered. + EntryFunc func() + // ExitFunc is a function that is called when the state is exited. + ExitFunc func() + // Action is the action to be executed in the state. + Action Action + // Transitions is a mapping of events and states. + Transitions Transitions +} + +// States represents a mapping of states and their implementations. +type States map[StateType]State + +// Notification represents a notification sent to the state machine's +// notification channel. +type Notification struct { + // PreviousState is the state the state machine was in before the event + // was processed. + PreviousState StateType + // NextState is the state the state machine is in after the event was + // processed. + NextState StateType + // Event is the event that was processed. + Event EventType +} + +// Observer is an interface that can be implemented by types that want to +// observe the state machine. +type Observer interface { + Notify(Notification) +} + +// StateMachine represents the state machine. +type StateMachine struct { + // Context represents the state machine context. + States States + + // ActionEntryFunc is a function that is called before an action is + // executed. + ActionEntryFunc func() + + // ActionExitFunc is a function that is called after an action is + // executed. + ActionExitFunc func() + + // mutex ensures that only 1 event is processed by the state machine at + // any given time. + mutex sync.Mutex + + // LastActionError is an error set by the last action executed. + LastActionError error + + // previous represents the previous state. + previous StateType + + // current represents the current state. + current StateType + + // observers is a slice of observers that are notified when the state + // machine transitions between states. + observers []Observer + + // observerMutex ensures that observers are only added or removed + // safely. + observerMutex sync.Mutex +} + +// NewStateMachine creates a new state machine. +func NewStateMachine(states States) *StateMachine { + return &StateMachine{ + States: states, + observers: make([]Observer, 0), + } +} + +// getNextState returns the next state for the event given the machine's current +// state, or an error if the event can't be handled in the given state. +func (s *StateMachine) getNextState(event EventType) (State, error) { + var ( + state State + ok bool + ) + + stateMap := s.States + + if state, ok = stateMap[s.current]; !ok { + return State{}, NewErrConfigError("current state not found") + } + + if state.Transitions == nil { + return State{}, NewErrConfigError( + "current state has no transitions", + ) + } + + var next StateType + if next, ok = state.Transitions[event]; !ok { + return State{}, NewErrConfigError( + "event not found in current transitions", + ) + } + + // Identify the state definition for the next state. + state, ok = stateMap[next] + if !ok { + return State{}, NewErrConfigError("next state not found") + } + + if state.Action == nil { + return State{}, NewErrConfigError("next state has no action") + } + + // Transition over to the next state. + s.previous = s.current + s.current = next + + return state, nil +} + +// SendEvent sends an event to the state machine. It returns an error if the +// event cannot be processed in the current state. Otherwise, it only returns +// nil if the event for the last action is a no-op. +func (s *StateMachine) SendEvent(event EventType, eventCtx EventContext) error { + s.mutex.Lock() + defer s.mutex.Unlock() + + if s.States == nil { + return NewErrConfigError("state machine config is nil") + } + + for { + // Determine the next state for the event given the machine's + // current state. + state, err := s.getNextState(event) + if err != nil { + return ErrEventRejected + } + + // Notify the state machine's observers. + s.observerMutex.Lock() + for _, observer := range s.observers { + observer.Notify(Notification{ + PreviousState: s.previous, + NextState: s.current, + Event: event, + }) + } + s.observerMutex.Unlock() + + // Execute the state machines ActionEntryFunc. + if s.ActionEntryFunc != nil { + s.ActionEntryFunc() + } + + // Execute the current state's entry function + if state.EntryFunc != nil { + state.EntryFunc() + } + + // Execute the next state's action and loop over again if the + // event returned is not a no-op. + nextEvent := state.Action(eventCtx) + + // Execute the current state's exit function + if state.ExitFunc != nil { + state.ExitFunc() + } + + // Execute the state machines ActionExitFunc. + if s.ActionExitFunc != nil { + s.ActionExitFunc() + } + + // If the next event is a no-op, we're done. + if nextEvent == NoOp { + return nil + } + + event = nextEvent + } +} + +// RegisterObserver registers an observer with the state machine. +func (s *StateMachine) RegisterObserver(observer Observer) { + s.observerMutex.Lock() + defer s.observerMutex.Unlock() + + if observer != nil { + s.observers = append(s.observers, observer) + } +} + +// RemoveObserver removes an observer from the state machine. It returns true +// if the observer was removed, false otherwise. +func (s *StateMachine) RemoveObserver(observer Observer) bool { + s.observerMutex.Lock() + defer s.observerMutex.Unlock() + + for i, o := range s.observers { + if o == observer { + s.observers = append( + s.observers[:i], s.observers[i+1:]..., + ) + return true + } + } + + return false +} + +// HandleError is a helper function that can be used by actions to handle +// errors. +func (s *StateMachine) HandleError(err error) EventType { + log.Errorf("StateMachine error: %s", err) + s.LastActionError = err + return OnError +} + +// NoOpAction is a no-op action that can be used by states that don't need to +// execute any action. +func NoOpAction(_ EventContext) EventType { + return NoOp +} + +// ErrConfigError is an error returned when the state machine is misconfigured. +type ErrConfigError error + +// NewErrConfigError creates a new ErrConfigError. +func NewErrConfigError(msg string) ErrConfigError { + return (ErrConfigError)(fmt.Errorf("config error: %s", msg)) +} + +// ErrWaitingForStateTimeout is an error returned when the state machine times +// out while waiting for a state. +type ErrWaitingForStateTimeout error + +// NewErrWaitingForStateTimeout creates a new ErrWaitingForStateTimeout. +func NewErrWaitingForStateTimeout(expected, + actual StateType) ErrWaitingForStateTimeout { + + return (ErrWaitingForStateTimeout)(fmt.Errorf( + "waiting for state timeout: expected %s, actual: %s", + expected, actual, + )) +} diff --git a/fsm/fsm.md b/fsm/fsm.md new file mode 100644 index 0000000..5db7329 --- /dev/null +++ b/fsm/fsm.md @@ -0,0 +1,139 @@ +# Finite State Machine Module + +This module provides a simple golang finite state machine (FSM) implementation. + + +## Introduction + +The state machine uses events and actions to transition between states. The +events are used to trigger a transition and the actions are used to perform +some work when entering a state. Actions return new events which are then +used to trigger the next transition. + +## Usage + +A simple way to use the FSM is to embed it into a struct: + +```go +type LightSwitchFSM struct { + *StateMachine +} +``` + +In order to use the FSM you need to define the events, actions and statemaps +for the FSM. events are defined as constants, actions are defined as functions +on the `LightSwitchFSM` struct and statemaps are in a map of `State` to `StateMap` +where `StateMap` is a map of `Event` to `Action`. + +For the `LightSwitchFSM` we can first define the states +```go +const ( + OffState = StateType("Off") + OnState = StateType("On") +) + +const ( + SwitchOff = EventType("SwitchOff") + SwitchOn = EventType("SwitchOn") +) +``` + +Next we define the actions, here we're simply going to log from the action. +```go +func (a *LightSwitchFSM) OffAction(_ EventContext) EventType { + fmt.Println("The light has been switched off") + return NoOp +} + +func (a *LightSwitchFSM) OnAction(_ EventContext) EventType { + fmt.Println("The light has been switched on") + return NoOp +} +``` + +Next we define the statemap, here we're going to implement a getStates() +function that returns the statemap. +```go +func (l *LightSwitchFSM) getStates() States { + return States{ + OffState: State{ + Action: l.OffAction, + Transitions: Transitions{ + SwitchOn: OnState, + }, + }, + OnState: State{ + Action: l.OnAction, + Transitions: Transitions{ + SwitchOff: OffState, + }, + }, + } +} +``` + +Finally, we can create the FSM and use it. + +```go +func NewLightSwitchFSM() *LightSwitchFSM { + fsm := &LightSwitchFSM{} + fsm.StateMachine = &StateMachine{ + States: fsm.getStates(), + Current: OffState, + } + return fsm +} +``` + +This is what it would look like to use the FSM: +```go +func TestLightSwitchFSM(t *testing.T) { + // Create a new light switch FSM. + lightSwitch := NewLightSwitchFSM() + + // Expect the light to be off + require.Equal(t, lightSwitch.Current, OffState) + + // Send the On Event + err := lightSwitch.SendEvent(SwitchOn, nil) + require.NoError(t, err) + + // Expect the light to be on + require.Equal(t, lightSwitch.Current, OnState) + + // Send the Off Event + err = lightSwitch.SendEvent(SwitchOff, nil) + require.NoError(t, err) + + // Expect the light to be off + require.Equal(t, lightSwitch.Current, OffState) +} +``` + +## Observing the state machine +The state machine can be observed by registering an observer. The observer +will be called when the state machine transitions between states. The observer +is called with the old state, the new state and the event that triggered the +transition. + +An observer can be registered by calling the `RegisterObserver` function on +the state machine. The observer must implement the `Observer` interface. + +```go +type Observer interface { + Notify(Notification) +} +``` + +An example of a cached observer can be found in [observer.go](./observer.go). + + +## More Examples +A more elaborate example that uses error handling, event context and more +elaborate actions can be found in here [examples_fsm.go](./example_fsm.go). +With the tests in [examples_fsm_test.go](./example_fsm_test.go) showing how to +use the FSM. + +## Visualizing the FSM +The FSM can be visualized to mermaid markdown using the [stateparser.go](./stateparser/stateparser.go) +tool. The visualization for the exampleFSM can be found in [example_fsm.md](./example_fsm.md). \ No newline at end of file diff --git a/fsm/fsm_test.go b/fsm/fsm_test.go new file mode 100644 index 0000000..fae70e7 --- /dev/null +++ b/fsm/fsm_test.go @@ -0,0 +1,117 @@ +package fsm + +import ( + "errors" + "testing" + + "github.com/stretchr/testify/require" +) + +var ( + errAction = errors.New("action error") +) + +// TestStateMachineContext is a test context for the state machine. +type TestStateMachineContext struct { + *StateMachine +} + +// GetStates returns the states for the test state machine. +// The StateMap looks like this: +// State1 -> Event1 -> State2 . +func (c *TestStateMachineContext) GetStates() States { + return States{ + "State1": State{ + Action: func(ctx EventContext) EventType { + return "Event1" + }, + Transitions: Transitions{ + "Event1": "State2", + }, + }, + "State2": State{ + Action: func(ctx EventContext) EventType { + return "NoOp" + }, + Transitions: Transitions{}, + }, + } +} + +// errorAction returns an error. +func (c *TestStateMachineContext) errorAction(eventCtx EventContext) EventType { + return c.StateMachine.HandleError(errAction) +} + +func setupTestStateMachineContext() *TestStateMachineContext { + ctx := &TestStateMachineContext{} + + ctx.StateMachine = &StateMachine{ + States: ctx.GetStates(), + current: "State1", + previous: "", + } + + return ctx +} + +// TestStateMachine_Success tests the state machine with a successful event. +func TestStateMachine_Success(t *testing.T) { + ctx := setupTestStateMachineContext() + + // Send an event to the state machine. + err := ctx.SendEvent("Event1", nil) + require.NoError(t, err) + + // Check that the state machine has transitioned to the next state. + require.Equal(t, StateType("State2"), ctx.current) +} + +// TestStateMachine_ConfigurationError tests the state machine with a +// configuration error. +func TestStateMachine_ConfigurationError(t *testing.T) { + ctx := setupTestStateMachineContext() + ctx.StateMachine.States = nil + + err := ctx.SendEvent("Event1", nil) + require.EqualError( + t, err, + NewErrConfigError("state machine config is nil").Error(), + ) +} + +// TestStateMachine_ActionError tests the state machine with an action error. +func TestStateMachine_ActionError(t *testing.T) { + ctx := setupTestStateMachineContext() + + states := ctx.StateMachine.States + + // Add a Transition to State2 if the Action on Stat2 fails. + // The new StateMap looks like this: + // State1 -> Event1 -> State2 + // State2 -> OnError -> ErrorState + states["State2"] = State{ + Action: ctx.errorAction, + Transitions: Transitions{ + OnError: "ErrorState", + }, + } + + states["ErrorState"] = State{ + Action: func(ctx EventContext) EventType { + return "NoOp" + }, + Transitions: Transitions{}, + } + + err := ctx.SendEvent("Event1", nil) + + // Sending an event to the state machine should not return an error. + require.NoError(t, err) + + // Ensure that the last error is set. + require.Equal(t, errAction, ctx.StateMachine.LastActionError) + + // Expect the state machine to have transitioned to the ErrorState. + require.Equal(t, StateType("ErrorState"), ctx.StateMachine.current) +} diff --git a/fsm/log.go b/fsm/log.go new file mode 100644 index 0000000..d552dd0 --- /dev/null +++ b/fsm/log.go @@ -0,0 +1,26 @@ +package fsm + +import ( + "github.com/btcsuite/btclog" + "github.com/lightningnetwork/lnd/build" +) + +// Subsystem defines the sub system name of this package. +const Subsystem = "FSM" + +// log is a logger that is initialized with no output filters. This +// means the package will not perform any logging by default until the caller +// requests it. +var log btclog.Logger + +// The default amount of logging is none. +func init() { + UseLogger(build.NewSubLogger(Subsystem, nil)) +} + +// UseLogger uses a specified Logger to output package logging info. +// This should be used in preference to SetLogWriter if the caller is also +// using btclog. +func UseLogger(logger btclog.Logger) { + log = logger +} diff --git a/fsm/observer.go b/fsm/observer.go new file mode 100644 index 0000000..b9c7286 --- /dev/null +++ b/fsm/observer.go @@ -0,0 +1,134 @@ +package fsm + +import ( + "context" + "sync" + "time" +) + +// CachedObserver is an observer that caches all states and transitions of +// the observed state machine. +type CachedObserver struct { + lastNotification Notification + cachedNotifications *FixedSizeSlice[Notification] + + notificationCond *sync.Cond + notificationMx sync.Mutex +} + +// NewCachedObserver creates a new cached observer with the given maximum +// number of cached notifications. +func NewCachedObserver(maxElements int) *CachedObserver { + fixedSizeSlice := NewFixedSizeSlice[Notification](maxElements) + observer := &CachedObserver{ + cachedNotifications: fixedSizeSlice, + } + observer.notificationCond = sync.NewCond(&observer.notificationMx) + + return observer +} + +// Notify implements the Observer interface. +func (c *CachedObserver) Notify(notification Notification) { + c.notificationMx.Lock() + defer c.notificationMx.Unlock() + + c.cachedNotifications.Add(notification) + c.lastNotification = notification + c.notificationCond.Broadcast() +} + +// GetCachedNotifications returns a copy of the cached notifications. +func (c *CachedObserver) GetCachedNotifications() []Notification { + c.notificationMx.Lock() + defer c.notificationMx.Unlock() + + return c.cachedNotifications.Get() +} + +// WaitForState waits for the state machine to reach the given state. +func (s *CachedObserver) WaitForState(ctx context.Context, + timeout time.Duration, state StateType) error { + + timeoutCtx, cancel := context.WithTimeout(ctx, timeout) + defer cancel() + + // Channel to notify when the desired state is reached + ch := make(chan struct{}) + + // Goroutine to wait on condition variable + go func() { + s.notificationMx.Lock() + defer s.notificationMx.Unlock() + + for { + // Check if the last state is the desired state + if s.lastNotification.NextState == state { + ch <- struct{}{} + return + } + + // Otherwise, wait for the next notification + s.notificationCond.Wait() + } + }() + + // Wait for either the condition to be met or for a timeout + select { + case <-timeoutCtx.Done(): + return NewErrWaitingForStateTimeout( + state, s.lastNotification.NextState, + ) + case <-ch: + return nil + } +} + +// FixedSizeSlice is a slice with a fixed size. +type FixedSizeSlice[T any] struct { + data []T + maxLen int + + sync.Mutex +} + +// NewFixedSlice initializes a new FixedSlice with a given maximum length. +func NewFixedSizeSlice[T any](maxLen int) *FixedSizeSlice[T] { + return &FixedSizeSlice[T]{ + data: make([]T, 0, maxLen), + maxLen: maxLen, + } +} + +// Add appends a new element to the slice. If the slice reaches its maximum +// length, the first element is removed. +func (fs *FixedSizeSlice[T]) Add(element T) { + fs.Lock() + defer fs.Unlock() + + if len(fs.data) == fs.maxLen { + // Remove the first element + fs.data = fs.data[1:] + } + // Add the new element + fs.data = append(fs.data, element) +} + +// Get returns a copy of the slice. +func (fs *FixedSizeSlice[T]) Get() []T { + fs.Lock() + defer fs.Unlock() + + data := make([]T, len(fs.data)) + copy(data, fs.data) + + return data +} + +// GetElement returns the element at the given index. +func (fs *FixedSizeSlice[T]) GetElement(index int) T { + fs.Lock() + defer fs.Unlock() + + return fs.data[index] +} diff --git a/fsm/stateparser/stateparser.go b/fsm/stateparser/stateparser.go new file mode 100644 index 0000000..9e107d9 --- /dev/null +++ b/fsm/stateparser/stateparser.go @@ -0,0 +1,96 @@ +package main + +import ( + "bytes" + "errors" + "flag" + "fmt" + "os" + "path/filepath" + "sort" + + "github.com/lightninglabs/loop/fsm" +) + +func main() { + if err := run(); err != nil { + fmt.Println(err) + os.Exit(1) + } +} + +func run() error { + out := flag.String("out", "", "outfile") + stateMachine := flag.String("fsm", "", "the swap state machine to parse") + flag.Parse() + + if filepath.Ext(*out) != ".md" { + return errors.New("wrong argument: out must be a .md file") + } + + fp, err := filepath.Abs(*out) + if err != nil { + return err + } + + switch *stateMachine { + case "example": + exampleFSM := &fsm.ExampleFSM{} + err = writeMermaidFile(fp, exampleFSM.GetStates()) + if err != nil { + return err + } + + default: + fmt.Println("Missing or wrong argument: fsm must be one of:") + fmt.Println("\treservations") + fmt.Println("\texample") + } + + return nil +} + +func writeMermaidFile(filename string, states fsm.States) error { + f, err := os.Create(filename) + if err != nil { + return err + } + defer f.Close() + + var b bytes.Buffer + fmt.Fprint(&b, "```mermaid\nstateDiagram-v2\n") + + sortedStates := sortedKeys(states) + for _, state := range sortedStates { + edges := states[fsm.StateType(state)] + // write state name + if len(state) > 0 { + fmt.Fprintf(&b, "%s\n", state) + } else { + state = "[*]" + } + // write transitions + for edge, target := range edges.Transitions { + fmt.Fprintf(&b, "%s --> %s: %s\n", state, target, edge) + } + } + + fmt.Fprint(&b, "```") + _, err = f.Write(b.Bytes()) + if err != nil { + return err + } + + return nil +} + +func sortedKeys(m fsm.States) []string { + keys := make([]string, len(m)) + i := 0 + for k := range m { + keys[i] = string(k) + i++ + } + sort.Strings(keys) + return keys +} diff --git a/loopd/log.go b/loopd/log.go index cb96f79..7a0ce47 100644 --- a/loopd/log.go +++ b/loopd/log.go @@ -5,6 +5,7 @@ import ( "github.com/lightninglabs/aperture/lsat" "github.com/lightninglabs/lndclient" "github.com/lightninglabs/loop" + "github.com/lightninglabs/loop/fsm" "github.com/lightninglabs/loop/liquidity" "github.com/lightninglabs/loop/loopdb" "github.com/lightningnetwork/lnd" @@ -36,6 +37,7 @@ func SetupLoggers(root *build.RotatingLogWriter, intercept signal.Interceptor) { lnd.AddSubLogger( root, liquidity.Subsystem, intercept, liquidity.UseLogger, ) + lnd.AddSubLogger(root, fsm.Subsystem, intercept, fsm.UseLogger) } // genSubLogger creates a logger for a subsystem. We provide an instance of diff --git a/scripts/fsm-generate.sh b/scripts/fsm-generate.sh new file mode 100755 index 0000000..d6b1ce3 --- /dev/null +++ b/scripts/fsm-generate.sh @@ -0,0 +1,2 @@ +#!/usr/bin/env bash +go run ./fsm/stateparser/stateparser.go --out ./fsm/example_fsm.md --fsm example \ No newline at end of file