diff --git a/wf/engine.go b/wf/engine.go index 00749ae..51535de 100644 --- a/wf/engine.go +++ b/wf/engine.go @@ -3,69 +3,70 @@ package wf import "fmt" type Engine struct { + initialState *State currentState *State states map[StateName]*State } -// NewEngine - build a new wf engine with an initial state -func NewEngine() *Engine { - s := map[StateName]*State{} - - state := &State{ - name: stateInitial, - actions: map[Event]*State{}, - } +type EngineOption func(state *Engine) error - s[stateInitial] = state - return &Engine{ - currentState: state, - states: s, +// NewEngine - build a new wf engine with an initial state +func NewEngine(initialStateName StateName, opts ...EngineOption) (*Engine, error) { + initialState := &State{ + name: initialStateName, + actions: map[EventName]*State{}, } -} -// RegisterState - will add a new state -// will return the new state or error if the state was previously defined -func (e *Engine) RegisterState(name StateName) (*State, error) { - _, ok := e.states[name] - if ok { - return nil, fmt.Errorf("state %q already defined", name) + e := &Engine{ + initialState: initialState, + currentState: initialState, + states: map[StateName]*State{ + initialStateName: initialState, + }, } - s := &State{ - name: name, - actions: map[Event]*State{}, + for _, opt := range opts { + err := opt(e) + if err != nil { + return nil, err + } } - e.states[name] = s - return s, nil + return e, nil } -// RegisterEvent - will add an event to facilitate transition from current state to the next state -func (e *Engine) RegisterEvent(curState *State, event Event, nextState *State) error { - if !curState.attachEvent(event, nextState) { - return fmt.Errorf("event %q already defined for the state %q", event, curState.name) +// WithState - will add a state during build +func WithState(fromStateName StateName, eventName EventName, toStateName StateName) EngineOption { + return func(e *Engine) error { + _, err := e.RegisterState(fromStateName, eventName, toStateName) + return err } - return nil } -// RegisterStateAndEvent - will add a new state and event -func (e *Engine) RegisterStateAndEvent(fromStateName StateName, event Event, toStateName StateName) (*State, error) { - fromState, ok := e.states[fromStateName] - if !ok { - return nil, fmt.Errorf("state %q is not defined", fromStateName) +// RegisterState - will add a new state +// will return the new state or error if the state was previously defined +func (e *Engine) RegisterState(fromStateName StateName, eventName EventName, toStateName StateName) (*State, error) { + fromState := e.getOrCreateState(fromStateName) + toState := e.getOrCreateState(toStateName) + if eventOk := fromState.attachEvent(eventName, toState); !eventOk { + return nil, fmt.Errorf("event %q already defined for the state %q", eventName, fromState.name) } + return toState, nil +} - toState, err := e.RegisterState(toStateName) - if err != nil { - return nil, err +// getOrCreateState - gets or creates a state. +func (e *Engine) getOrCreateState(name StateName) *State { + state, ok := e.states[name] + if ok { + return state } - err = e.RegisterEvent(fromState, event, toState) - if err != nil { - return nil, err + newState := &State{ + name: name, + actions: map[EventName]*State{}, } - - return toState, nil + e.states[name] = newState + return newState } // GetState - returns a state by name or nil @@ -79,7 +80,7 @@ func (e *Engine) GetState(name StateName) *State { // GetInitialState - returns the initial state func (e *Engine) GetInitialState() *State { - return e.GetState(stateInitial) + return e.initialState } // GetCurrentState - returns the current state @@ -89,10 +90,10 @@ func (e *Engine) GetCurrentState() *State { // ProcessEvent - will run the event and return the next state // in case of the event was not defined, will return error -func (e *Engine) ProcessEvent(event Event) (*State, error) { - nextState, err := e.currentState.execEvent(event) - if err != nil { - return nil, err +func (e *Engine) ProcessEvent(event EventName) (*State, error) { + nextState, ok := e.currentState.execEvent(event) + if !ok { + return nil, fmt.Errorf("event %q is not defined for the current state %q", event, e.currentState.name) } e.currentState = nextState return nextState, nil diff --git a/wf/engine_test.go b/wf/engine_test.go index 3acbd2a..7ed65ca 100644 --- a/wf/engine_test.go +++ b/wf/engine_test.go @@ -10,11 +10,12 @@ import ( ) const ( - S1 = "STATE_1" - S2 = "STATE_2" - S3 = "STATE_3" - S4 = "STATE_4" - S5 = "STATE_5" + START = "START" + S1 = "STATE_1" + S2 = "STATE_2" + S3 = "STATE_3" + S4 = "STATE_4" + FINISH = "FINISH" E1 = "EVENT_1" E2 = "EVENT_2" @@ -25,151 +26,121 @@ const ( ) func TestNewEngine(t *testing.T) { - wfe := wf.NewEngine() + wfe, err := wf.NewEngine(START) + require.NoError(t, err) require.NotNil(t, wfe) - s0 := wfe.GetInitialState() - assert.NotNil(t, s0) + initial := wfe.GetInitialState() + assert.NotNil(t, initial) + assert.EqualValues(t, "START", initial.GetName()) } -func TestEngine_RegisterState(t *testing.T) { - wfe := wf.NewEngine() - s0 := wfe.GetInitialState() - s1, err := wfe.RegisterState(S1) +func TestNewEngine_WithState_WhenNoErrors(t *testing.T) { + wfe, err := wf.NewEngine( + START, + wf.WithState(START, E1, S1), + wf.WithState(S1, E2, S2), + wf.WithState(S1, E4, S4), + wf.WithState(S2, E3, S3), + wf.WithState(S3, E3, S2), + wf.WithState(S3, E5, FINISH), + wf.WithState(S4, E5, FINISH), + ) require.NoError(t, err) + require.NotNil(t, wfe) - require.NotNil(t, s1) - assert.EqualValues(t, s0, wfe.GetCurrentState()) - assert.EqualValues(t, wf.StateName("STATE_1"), s1.GetName()) -} + // test initial state + initial := wfe.GetInitialState() + assert.NotNil(t, initial) + assert.EqualValues(t, "START", initial.GetName()) -func TestEngine_RegisterState_When_Error(t *testing.T) { - wfe := wf.NewEngine() - s0 := wfe.GetInitialState() - _, err := wfe.RegisterState(s0.GetName()) - require.Error(t, err) - assert.EqualValues(t, "state \"INITIAL_STATE\" already defined", err.Error()) + // test jump to S1 + nextState, err := wfe.ProcessEvent(E1) + require.NoError(t, err) + require.NotNil(t, nextState) + assert.EqualValues(t, wf.StateName("STATE_1"), nextState.GetName()) + + // test jump to S2 + nextState, err = wfe.ProcessEvent(E2) + require.NoError(t, err) + require.NotNil(t, nextState) + assert.EqualValues(t, wf.StateName("STATE_2"), nextState.GetName()) + + // test illegal jump + nextState, err = wfe.ProcessEvent(E2) + require.Nil(t, nextState) + require.NotNil(t, err) + assert.EqualValues(t, "event \"EVENT_2\" is not defined for the current state \"STATE_2\"", err.Error()) } -func TestEngine_RegisterEvent_When_Error(t *testing.T) { - wfe := wf.NewEngine() +func TestEngine_RegisterState(t *testing.T) { + wfe, err := wf.NewEngine(START) + require.NoError(t, err) + require.NotNil(t, wfe) s0 := wfe.GetInitialState() - s1, _ := wfe.RegisterState(S1) - err := wfe.RegisterEvent(s0, E1, s1) + s1, err := wfe.RegisterState(START, E1, S1) require.NoError(t, err) - err = wfe.RegisterEvent(s0, E1, s1) - assert.Error(t, err) - assert.EqualValues(t, "event \"EVENT_1\" already defined for the state \"INITIAL_STATE\"", err.Error()) + require.NotNil(t, s1) + + s1_ := wfe.GetState(S1) + require.NotNil(t, s1_) + + assert.EqualValues(t, s0, wfe.GetCurrentState()) + assert.EqualValues(t, wf.StateName("STATE_1"), s1_.GetName()) } func TestEngine_GetCurrentState(t *testing.T) { - wfe := wf.NewEngine() + wfe, err := wf.NewEngine(START) + require.NoError(t, err) + require.NotNil(t, wfe) s0 := wfe.GetInitialState() assert.NotNil(t, s0) assert.EqualValues(t, s0, wfe.GetCurrentState()) - s1, _ := wfe.RegisterState(S1) + s1, _ := wfe.RegisterState(START, E1, S1) assert.NotNil(t, s1) - err := wfe.RegisterEvent(s0, E1, s1) - assert.NoError(t, err) - nextState, err := wfe.ProcessEvent(E1) require.NoError(t, err) assert.EqualValues(t, nextState, wfe.GetCurrentState()) } func TestEngine_GetState_When_Error(t *testing.T) { - wfe := wf.NewEngine() + wfe, err := wf.NewEngine(START) + require.NoError(t, err) + require.NotNil(t, wfe) + s1 := wfe.GetState("FAKE STATE") assert.Nil(t, s1) } func TestEngine_ProcessEvent_When_Error(t *testing.T) { - wfe := wf.NewEngine() - - s0 := wfe.GetInitialState() - s1, _ := wfe.RegisterState(S1) - - err := wfe.RegisterEvent(s0, E1, s1) + wfe, err := wf.NewEngine(START) require.NoError(t, err) + require.NotNil(t, wfe) + + s1, _ := wfe.RegisterState(START, E1, S1) + assert.NotNil(t, s1) _, err = wfe.ProcessEvent(E2) require.Error(t, err) - assert.EqualValues(t, "event \"EVENT_2\" is not defined for the current state \"INITIAL_STATE\"", err.Error()) + assert.EqualValues(t, "event \"EVENT_2\" is not defined for the current state \"START\"", err.Error()) } func TestEngine_JumpToState(t *testing.T) { - wfe := wf.NewEngine() - - wfe.GetInitialState() - _, _ = wfe.RegisterState(S1) - _, _ = wfe.RegisterState(S2) - - err := wfe.JumpToState(S2) - require.NoError(t, err) - assert.EqualValues(t, "STATE_2", wfe.GetCurrentState().GetName()) -} - -func TestEngine_RegisterFutureStateByName(t *testing.T) { - wfe := wf.NewEngine() - s1, err := wfe.RegisterStateAndEvent(wfe.GetInitialState().GetName(), E1, S1) - require.NoError(t, err) - - nextState, err := wfe.ProcessEvent(E1) + wfe, err := wf.NewEngine(START) require.NoError(t, err) - assert.EqualValues(t, S1, nextState.GetName()) - assert.EqualValues(t, s1.GetName(), nextState.GetName()) -} - -func TestFullFlow(t *testing.T) { - wfe := wf.NewEngine() - - s0 := wfe.GetInitialState() - s1, _ := wfe.RegisterState(S1) - s2, _ := wfe.RegisterState(S2) - s3, _ := wfe.RegisterState(S3) - s4, _ := wfe.RegisterState(S4) - s5, _ := wfe.RegisterState(S5) - - err := wfe.RegisterEvent(s0, E1, s1) - assert.NoError(t, err) - - err = wfe.RegisterEvent(s0, E2, s2) - assert.NoError(t, err) - - err = wfe.RegisterEvent(s1, E3, s3) - assert.NoError(t, err) - - err = wfe.RegisterEvent(s2, E3, s3) - assert.NoError(t, err) - - err = wfe.RegisterEvent(s3, E4, s4) - assert.NoError(t, err) - - err = wfe.RegisterEvent(s4, E5, s2) - assert.NoError(t, err) - - err = wfe.RegisterEvent(s4, E6, s5) - assert.NoError(t, err) - - nextState, err := wfe.ProcessEvent(E1) - require.NoError(t, err) - assert.EqualValues(t, s1, nextState) + require.NotNil(t, wfe) - nextState, err = wfe.ProcessEvent(E3) + wfe.GetInitialState() + _, err = wfe.RegisterState(START, E1, S1) require.NoError(t, err) - assert.EqualValues(t, s3, nextState) - - nextState, err = wfe.ProcessEvent(E4) + _, err = wfe.RegisterState(S1, E2, S2) require.NoError(t, err) - assert.EqualValues(t, s4, nextState) - nextState, err = wfe.ProcessEvent(E6) + err = wfe.JumpToState(S2) require.NoError(t, err) - assert.EqualValues(t, s5, nextState) - - assert.EqualValues(t, s5, wfe.GetCurrentState()) + assert.EqualValues(t, "STATE_2", wfe.GetCurrentState().GetName()) } diff --git a/wf/event.go b/wf/event.go index 36873fa..fd84428 100644 --- a/wf/event.go +++ b/wf/event.go @@ -1,3 +1,3 @@ package wf -type Event string +type EventName string diff --git a/wf/state.go b/wf/state.go index 11baf33..9526bc1 100644 --- a/wf/state.go +++ b/wf/state.go @@ -1,7 +1,5 @@ package wf -import "fmt" - type StateName string const ( @@ -10,10 +8,10 @@ const ( type State struct { name StateName - actions map[Event]*State + actions map[EventName]*State } -func (s *State) attachEvent(event Event, nextState *State) bool { +func (s *State) attachEvent(event EventName, nextState *State) bool { _, ok := s.actions[event] if !ok { s.actions[event] = nextState @@ -21,12 +19,12 @@ func (s *State) attachEvent(event Event, nextState *State) bool { return !ok } -func (s *State) execEvent(event Event) (*State, error) { +func (s *State) execEvent(event EventName) (*State, bool) { newState, ok := s.actions[event] if ok { - return newState, nil + return newState, true } - return nil, fmt.Errorf("event %q is not defined for the current state %q", event, s.name) + return nil, false } func (s *State) GetName() StateName { diff --git a/wf/state_test.go b/wf/state_test.go index 6254115..750b3ed 100644 --- a/wf/state_test.go +++ b/wf/state_test.go @@ -10,7 +10,8 @@ import ( ) func TestState_GetName(t *testing.T) { - wfe := wf.NewEngine() + wfe, err := wf.NewEngine("INITIAL_STATE") + require.NoError(t, err) require.NotNil(t, wfe) s0 := wfe.GetInitialState()