fsm: add early abort observer option

pull/698/head
sputn1ck 3 months ago
parent 2048b32c21
commit 1a31bbf75d
No known key found for this signature in database
GPG Key ID: 671103D881A5F0E4

@ -13,7 +13,10 @@ var (
ErrWaitForStateTimedOut = errors.New( ErrWaitForStateTimedOut = errors.New(
"timed out while waiting for event", "timed out while waiting for event",
) )
ErrInvalidContextType = errors.New("invalid context") ErrInvalidContextType = errors.New("invalid context")
ErrWaitingForStateEarlyAbortError = errors.New(
"waiting for state early abort",
)
) )
const ( const (
@ -73,6 +76,8 @@ type Notification struct {
NextState StateType NextState StateType
// Event is the event that was processed. // Event is the event that was processed.
Event EventType Event EventType
// LastActionError is the error returned by the last action executed.
LastActionError error
} }
// Observer is an interface that can be implemented by types that want to // Observer is an interface that can be implemented by types that want to
@ -214,9 +219,10 @@ func (s *StateMachine) SendEvent(event EventType, eventCtx EventContext) error {
// Notify the state machine's observers. // Notify the state machine's observers.
s.observerMutex.Lock() s.observerMutex.Lock()
notification := Notification{ notification := Notification{
PreviousState: s.previous, PreviousState: s.previous,
NextState: s.current, NextState: s.current,
Event: event, Event: event,
LastActionError: s.LastActionError,
} }
for _, observer := range s.observers { for _, observer := range s.observers {

@ -55,7 +55,8 @@ type WaitForStateOption interface {
// fsmOptions is a struct that holds all options that can be passed to the // fsmOptions is a struct that holds all options that can be passed to the
// WaitForState function. // WaitForState function.
type fsmOptions struct { type fsmOptions struct {
initialWait time.Duration initialWait time.Duration
abortEarlyOnError bool
} }
// InitialWaitOption is an option that can be passed to the WaitForState // InitialWaitOption is an option that can be passed to the WaitForState
@ -76,6 +77,24 @@ func (w *InitialWaitOption) apply(o *fsmOptions) {
o.initialWait = w.initialWait o.initialWait = w.initialWait
} }
// AbortEarlyOnErrorOption is an option that can be passed to the WaitForState
// function to abort early if an error occurs.
type AbortEarlyOnErrorOption struct {
abortEarlyOnError bool
}
// apply implements the WaitForStateOption interface.
func (a *AbortEarlyOnErrorOption) apply(o *fsmOptions) {
o.abortEarlyOnError = a.abortEarlyOnError
}
// WithAbortEarlyOnErrorOption creates a new AbortEarlyOnErrorOption.
func WithAbortEarlyOnErrorOption() WaitForStateOption {
return &AbortEarlyOnErrorOption{
abortEarlyOnError: true,
}
}
// WaitForState waits for the state machine to reach the given state. // WaitForState waits for the state machine to reach the given state.
// If the optional initialWait parameter is set, the function will wait for // If the optional initialWait parameter is set, the function will wait for
// the given duration before checking the state. This is useful if the // the given duration before checking the state. This is useful if the
@ -105,7 +124,8 @@ func (s *CachedObserver) WaitForState(ctx context.Context,
defer cancel() defer cancel()
// Channel to notify when the desired state is reached // Channel to notify when the desired state is reached
ch := make(chan struct{}) // or an error occurred.
ch := make(chan error)
// Goroutine to wait on condition variable // Goroutine to wait on condition variable
go func() { go func() {
@ -115,8 +135,26 @@ func (s *CachedObserver) WaitForState(ctx context.Context,
for { for {
// Check if the last state is the desired state // Check if the last state is the desired state
if s.lastNotification.NextState == state { if s.lastNotification.NextState == state {
ch <- struct{}{} select {
return case <-timeoutCtx.Done():
return
case ch <- nil:
return
}
}
// Check if an error occurred
if s.lastNotification.Event == OnError {
if options.abortEarlyOnError {
select {
case <-timeoutCtx.Done():
return
case ch <- s.lastNotification.LastActionError:
return
}
}
} }
// Otherwise, wait for the next notification // Otherwise, wait for the next notification
@ -130,7 +168,11 @@ func (s *CachedObserver) WaitForState(ctx context.Context,
return NewErrWaitingForStateTimeout( return NewErrWaitingForStateTimeout(
state, s.lastNotification.NextState, state, s.lastNotification.NextState,
) )
case <-ch:
case lastActionErr := <-ch:
if lastActionErr != nil {
return lastActionErr
}
return nil return nil
} }
} }

Loading…
Cancel
Save