Skip to content

Commit

Permalink
strict_mode added to validation
Browse files Browse the repository at this point in the history
  • Loading branch information
mgonzs13 committed Nov 4, 2024
1 parent 4f4af19 commit 3ed8ace
Show file tree
Hide file tree
Showing 7 changed files with 88 additions and 69 deletions.
4 changes: 3 additions & 1 deletion yasmin/include/yasmin/state_machine.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -184,9 +184,11 @@ class StateMachine : public State {
/**
* @brief Validates the state machine configuration.
*
* @param strict Whether the validation is strict, which means checking if all
* state outcomes are used and all state machine outcomes are reached.
* @throws std::runtime_error If the state machine is misconfigured.
*/
void validate();
void validate(bool strict_mode = false);

/**
* @brief Executes the state machine.
Expand Down
47 changes: 27 additions & 20 deletions yasmin/src/yasmin/state_machine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -196,11 +196,11 @@ void StateMachine::call_end_cbs(
}
}

void StateMachine::validate() {
void StateMachine::validate(bool strict_mode) {

YASMIN_LOG_DEBUG("Validating state machine '%s'", this->to_string().c_str());

if (this->validated.load()) {
if (this->validated.load() && !strict_mode) {
YASMIN_LOG_DEBUG("State machine '%s' has already been validated",
this->to_string().c_str());
}
Expand All @@ -210,6 +210,7 @@ void StateMachine::validate() {
throw std::runtime_error("No initial state set");
}

// Terminal outcomes from all transitions
std::set<std::string> terminal_outcomes;

// Check all states
Expand All @@ -222,26 +223,30 @@ void StateMachine::validate() {

std::set<std::string> outcomes = state->get_outcomes();

// Check if all state outcomes are in transitions
for (const std::string &o : outcomes) {
if (strict_mode) {
// Check if all outcomes of the state are in transitions
for (const std::string &o : outcomes) {

if (transitions.find(o) == transitions.end() &&
std::find(this->get_outcomes().begin(), this->get_outcomes().end(),
o) == this->get_outcomes().end()) {
if (transitions.find(o) == transitions.end() &&
std::find(this->get_outcomes().begin(), this->get_outcomes().end(),
o) == this->get_outcomes().end()) {

throw std::runtime_error("State '" + state_name + "' outcome '" + o +
"' not registered in transitions");
throw std::runtime_error("State '" + state_name + "' outcome '" + o +
"' not registered in transitions");

} else if (std::find(this->get_outcomes().begin(),
this->get_outcomes().end(),
o) != this->get_outcomes().end()) {
terminal_outcomes.insert(o);
// Outcomes of the state that are in outcomes of the state machine
// do not need transitions
} else if (std::find(this->get_outcomes().begin(),
this->get_outcomes().end(),
o) != this->get_outcomes().end()) {
terminal_outcomes.insert(o);
}
}
}

// If state is a state machine, validate it
if (std::dynamic_pointer_cast<StateMachine>(state)) {
std::dynamic_pointer_cast<StateMachine>(state)->validate();
std::dynamic_pointer_cast<StateMachine>(state)->validate(strict_mode);
}

// Add terminal outcomes
Expand All @@ -255,15 +260,17 @@ void StateMachine::validate() {
std::set<std::string> sm_outcomes(this->get_outcomes().begin(),
this->get_outcomes().end());

// Check if all state machine outcomes are in the terminal outcomes
for (const std::string &o : this->get_outcomes()) {
if (terminal_outcomes.find(o) == terminal_outcomes.end()) {
throw std::runtime_error("Target outcome '" + o +
"' not registered in transitions");
if (strict_mode) {
// Check if all outcomes from the state machine are in the terminal outcomes
for (const std::string &o : this->get_outcomes()) {
if (terminal_outcomes.find(o) == terminal_outcomes.end()) {
throw std::runtime_error("Target outcome '" + o +
"' not registered in transitions");
}
}
}

// Check if all terminal outcomes are states or state machine outcomes
// Check if all terminal outcomes are states or outcomes of the state machine
for (const std::string &o : terminal_outcomes) {
if (this->states.find(o) == this->states.end() &&
sm_outcomes.find(o) == sm_outcomes.end()) {
Expand Down
8 changes: 4 additions & 4 deletions yasmin/tests/python/test_blackboard.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,19 +23,19 @@ class TestBlackboard(unittest.TestCase):
def setUp(self):
self.blackboard = Blackboard()

def test_blackboard_get(self):
def test_get(self):
self.blackboard["foo"] = "foo"
self.assertEqual("foo", self.blackboard["foo"])

def test_blackboard_delete(self):
def test_delete(self):
self.blackboard["foo"] = "foo"
del self.blackboard["foo"]
self.assertFalse("foo" in self.blackboard)

def test_blackboard_contains(self):
def test_contains(self):
self.blackboard["foo"] = "foo"
self.assertTrue("foo" in self.blackboard)

def test_blackboard_len(self):
def test_len(self):
self.blackboard["foo"] = "foo"
self.assertEqual(1, len(self.blackboard))
10 changes: 5 additions & 5 deletions yasmin/tests/python/test_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,19 +39,19 @@ class TestState(unittest.TestCase):
def setUp(self):
self.state = FooState()

def test_state_call(self):
def test_call(self):
self.assertEqual("outcome1", self.state())

def test_state_cancel(self):
def test_cancel(self):
self.assertFalse(self.state.is_canceled())
self.state.cancel_state()
self.assertTrue(self.state.is_canceled())

def test_state_get_outcomes(self):
def test_get_outcomes(self):
self.assertEqual("outcome1", list(self.state.get_outcomes())[0])

def test_state_str(self):
def test_str(self):
self.assertEqual("FooState", str(self.state))

def test_state_init_exception(self):
def test_init_exception(self):
self.assertRaises(Exception, BarState)
30 changes: 17 additions & 13 deletions yasmin/tests/python/test_state_machine.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def execute(self, blackboard):

class BarState(State):
def __init__(self):
super().__init__(outcomes=["outcome2"])
super().__init__(outcomes=["outcome2", "outcome3"])

def execute(self, blackboard):
return "outcome2"
Expand All @@ -46,7 +46,7 @@ class TestStateMachine(unittest.TestCase):
maxDiff = None

def setUp(self):
self.sm = StateMachine(outcomes=["outcome4"])
self.sm = StateMachine(outcomes=["outcome4", "outcome5"])

self.sm.add_state(
"FOO",
Expand All @@ -62,19 +62,19 @@ def setUp(self):
transitions={"outcome2": "FOO"},
)

def test_state_machine_str(self):
def test_str(self):
self.assertEqual("State Machine [BAR (BarState), FOO (FooState)]", str(self.sm))

def test_state_machine_get_states(self):
def test_get_states(self):
self.assertTrue(isinstance(self.sm.get_states()["FOO"]["state"], FooState))
self.assertTrue(isinstance(self.sm.get_states()["BAR"]["state"], BarState))

def test_state_machine_get_start_state(self):
def test_get_start_state(self):
self.assertEqual("FOO", self.sm.get_start_state())
self.sm.set_start_state("BAR")
self.assertEqual("BAR", self.sm.get_start_state())

def test_state_machine_get_current_state(self):
def test_get_current_state(self):
self.assertEqual("", self.sm.get_current_state())

def test_state_call(self):
Expand Down Expand Up @@ -148,7 +148,7 @@ def test_add_wrong_target_transition(self):
str(context.exception), "Transitions with empty target in state 'FOO1'"
)

def test_validate_state_machine_outcome_from_fsm_not_used(self):
def test_validate_outcome_from_fsm_not_used(self):

sm_1 = StateMachine(outcomes=["outcome4"])

Expand All @@ -163,14 +163,15 @@ def test_validate_state_machine_outcome_from_fsm_not_used(self):
"outcome2": "outcome4",
},
)

with self.assertRaises(KeyError) as context:
sm_1.validate()
sm_1.validate(True)
self.assertEqual(
str(context.exception),
"\"State 'FSM' outcome 'outcome5' not registered in transitions\"",
)

def test_validate_state_machine_outcome_from_state_not_used(self):
def test_validate_outcome_from_state_not_used(self):

sm_1 = StateMachine(outcomes=["outcome4"])

Expand All @@ -184,14 +185,15 @@ def test_validate_state_machine_outcome_from_state_not_used(self):
"outcome1": "outcome4",
},
)

with self.assertRaises(KeyError) as context:
sm_1.validate()
sm_1.validate(True)
self.assertEqual(
str(context.exception),
"\"State 'FOO' outcome 'outcome2' not registered in transitions\"",
)

def test_validate_state_machine_fsm_outcome_not_used(self):
def test_validate_fsm_outcome_not_used(self):

sm_1 = StateMachine(outcomes=["outcome4"])

Expand All @@ -212,14 +214,15 @@ def test_validate_state_machine_fsm_outcome_not_used(self):
"outcome2": "outcome4",
},
)

with self.assertRaises(KeyError) as context:
sm_1.validate()
sm_1.validate(True)
self.assertEqual(
str(context.exception),
"\"Target outcome 'outcome5' not registered in transitions\"",
)

def test_validate_state_machine_wrong_state(self):
def test_validate_wrong_state(self):

sm_1 = StateMachine(outcomes=["outcome4"])

Expand All @@ -240,6 +243,7 @@ def test_validate_state_machine_wrong_state(self):
"outcome2": "outcome4",
},
)

with self.assertRaises(KeyError) as context:
sm_1.validate()
self.assertEqual(
Expand Down
42 changes: 24 additions & 18 deletions yasmin/yasmin/state_machine.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ class StateMachine(State):
_start_state (str): The name of the initial state of the state machine.
__current_state (str): The name of the current state being executed.
__current_state_lock (Lock): A threading lock to manage access to the current state.
_validated (bool): A flag indicating whether the state machine has been validated.
__start_cbs (List[Tuple[Callable[[Blackboard, str, List[Any]], None], List[Any]]]): A list of callbacks to call when the state machine starts.
__transition_cbs (List[Tuple[Callable[[Blackboard, str, List[Any]], None], List[Any]]]): A list of callbacks to call during state transitions.
__end_cbs (List[Tuple[Callable[[Blackboard, str, List[Any]], None], List[Any]]]): A list of callbacks to call when the state machine ends.
Expand Down Expand Up @@ -276,20 +277,23 @@ def _call_end_cbs(self, blackboard: Blackboard, outcome: str) -> None:
except Exception as e:
yasmin.YASMIN_LOG_ERROR(f"Could not execute end callback: {e}")

def validate(self) -> None:
def validate(self, strict_mode: bool = False) -> None:
"""
Validates the state machine to ensure all states and transitions are correct.
Parameters:
strict_mode (bool): Whether the validation is strict, which means checking if all state outcomes are used and all state machine outcomes are reached.
Raises:
RuntimeError: If no initial state is set.
KeyError: If there are any unregistered outcomes or transitions.
"""
yasmin.YASMIN_LOG_DEBUG(f"Validating state machine '{self}'")

if self._validated:
if self._validated and not strict_mode:
yasmin.YASMIN_LOG_DEBUG("State machine '{self}' has already been validated")

# terminal outcomes
# Terminal outcomes from all transitions
terminal_outcomes = []

# Check initial state
Expand All @@ -303,33 +307,35 @@ def validate(self) -> None:

outcomes = state.get_outcomes()

# Check if all state outcomes are in transitions
for o in outcomes:
if o not in set(list(transitions.keys()) + list(self.get_outcomes())):
raise KeyError(
f"State '{state_name}' outcome '{o}' not registered in transitions"
)
if strict_mode:
# Check if all outcomes of the state are in transitions
for o in outcomes:
if o not in set(list(transitions.keys()) + list(self.get_outcomes())):
raise KeyError(
f"State '{state_name}' outcome '{o}' not registered in transitions"
)

# State outcomes that are in state machine outcomes do not need transitions
elif o in self.get_outcomes():
terminal_outcomes.append(o)
# Outcomes of the state that are in outcomes of the state machine do not need transitions
elif o in self.get_outcomes():
terminal_outcomes.append(o)

# If state is a state machine, validate it
if isinstance(state, StateMachine):
state.validate()
state.validate(strict_mode)

# Add terminal outcomes
terminal_outcomes.extend([transitions[key] for key in transitions])

# Check terminal outcomes for the state machine
terminal_outcomes = set(terminal_outcomes)

# Check if all state machine outcomes are in the terminal outcomes
for o in self.get_outcomes():
if o not in terminal_outcomes:
raise KeyError(f"Target outcome '{o}' not registered in transitions")
if strict_mode:
# Check if all outcomes of the state machine are in the terminal outcomes
for o in self.get_outcomes():
if o not in terminal_outcomes:
raise KeyError(f"Target outcome '{o}' not registered in transitions")

# Check if all terminal outcomes are states or state machine outcomes
# Check if all terminal outcomes are states or outcomes of the state machine
for o in terminal_outcomes:
if o not in set(list(self._states.keys()) + list(self.get_outcomes())):
raise KeyError(
Expand Down
Loading

0 comments on commit 3ed8ace

Please sign in to comment.