diff --git a/fsm/fsm.go b/fsm/fsm.go index f1f1649..6d36812 100644 --- a/fsm/fsm.go +++ b/fsm/fsm.go @@ -13,7 +13,10 @@ var ( ErrWaitForStateTimedOut = errors.New( "timed out while waiting for event", ) - ErrInvalidContextType = errors.New("invalid context") + ErrInvalidContextType = errors.New("invalid context") + ErrWaitingForStateEarlyAbortError = errors.New( + "waiting for state early abort", + ) ) const ( @@ -73,6 +76,8 @@ type Notification struct { NextState StateType // Event is the event that was processed. Event EventType + // LastActionError is the error returned by the last action executed. + LastActionError error } // Observer is an interface that can be implemented by types that want to @@ -214,9 +219,10 @@ func (s *StateMachine) SendEvent(event EventType, eventCtx EventContext) error { // Notify the state machine's observers. s.observerMutex.Lock() notification := Notification{ - PreviousState: s.previous, - NextState: s.current, - Event: event, + PreviousState: s.previous, + NextState: s.current, + Event: event, + LastActionError: s.LastActionError, } for _, observer := range s.observers { diff --git a/fsm/observer.go b/fsm/observer.go index 2d5e3fc..8677d51 100644 --- a/fsm/observer.go +++ b/fsm/observer.go @@ -55,7 +55,8 @@ type WaitForStateOption interface { // fsmOptions is a struct that holds all options that can be passed to the // WaitForState function. type fsmOptions struct { - initialWait time.Duration + initialWait time.Duration + abortEarlyOnError bool } // InitialWaitOption is an option that can be passed to the WaitForState @@ -76,6 +77,24 @@ func (w *InitialWaitOption) apply(o *fsmOptions) { o.initialWait = w.initialWait } +// AbortEarlyOnErrorOption is an option that can be passed to the WaitForState +// function to abort early if an error occurs. +type AbortEarlyOnErrorOption struct { + abortEarlyOnError bool +} + +// apply implements the WaitForStateOption interface. +func (a *AbortEarlyOnErrorOption) apply(o *fsmOptions) { + o.abortEarlyOnError = a.abortEarlyOnError +} + +// WithAbortEarlyOnErrorOption creates a new AbortEarlyOnErrorOption. +func WithAbortEarlyOnErrorOption() WaitForStateOption { + return &AbortEarlyOnErrorOption{ + abortEarlyOnError: true, + } +} + // WaitForState waits for the state machine to reach the given state. // If the optional initialWait parameter is set, the function will wait for // the given duration before checking the state. This is useful if the @@ -105,7 +124,8 @@ func (s *CachedObserver) WaitForState(ctx context.Context, defer cancel() // Channel to notify when the desired state is reached - ch := make(chan struct{}) + // or an error occurred. + ch := make(chan error) // Goroutine to wait on condition variable go func() { @@ -115,8 +135,26 @@ func (s *CachedObserver) WaitForState(ctx context.Context, for { // Check if the last state is the desired state if s.lastNotification.NextState == state { - ch <- struct{}{} - return + select { + case <-timeoutCtx.Done(): + return + + case ch <- nil: + return + } + } + + // Check if an error occurred + if s.lastNotification.Event == OnError { + if options.abortEarlyOnError { + select { + case <-timeoutCtx.Done(): + return + + case ch <- s.lastNotification.LastActionError: + return + } + } } // Otherwise, wait for the next notification @@ -130,7 +168,11 @@ func (s *CachedObserver) WaitForState(ctx context.Context, return NewErrWaitingForStateTimeout( state, s.lastNotification.NextState, ) - case <-ch: + + case lastActionErr := <-ch: + if lastActionErr != nil { + return lastActionErr + } return nil } }