diff --git a/docs/docs/forms.mdx b/docs/docs/forms.mdx index 364652a8188b..ec5210d8a9c4 100644 --- a/docs/docs/forms.mdx +++ b/docs/docs/forms.mdx @@ -263,7 +263,7 @@ When the form is executed it will run your custom action. In your custom action you can either - validate already extracted slots. You can retrieve them from the tracker by running - `tracker.get_extracted_slots`. + `tracker.form_slots_to_validate`. - use [Custom Slot Mappings](./forms.mdx#slot-mappings) to extract slot values . After validating the extracted slots, return `SlotSet` events for them. If you want @@ -291,7 +291,7 @@ class ValidateSlots(Action): def run( self, dispatcher: CollectingDispatcher, tracker: Tracker, domain: Dict ) -> List[EventType]: - extracted_slots: Dict[Text, Any] = tracker.get_extracted_slots() + extracted_slots: Dict[Text, Any] = tracker.form_slots_to_validate() validation_events = [] diff --git a/docs/docs/migration-guide.mdx b/docs/docs/migration-guide.mdx index 5d2fa66e3d4c..83f5ba983197 100644 --- a/docs/docs/migration-guide.mdx +++ b/docs/docs/migration-guide.mdx @@ -303,7 +303,7 @@ class RestaurantFormValidator(Action): def run( self, dispatcher: CollectingDispatcher, tracker: Tracker, domain: Dict ) -> List[EventType]: - extracted_slots: Dict[Text, Any] = tracker.get_extracted_slots() + extracted_slots: Dict[Text, Any] = tracker.form_slots_to_validate() cuisine_slot_value = extracted_slots.get("cuisine") validated_slot_event = self.validate_cuisine( diff --git a/rasa/core/test.py b/rasa/core/test.py index 92cda8dc2c04..619aebb70f92 100644 --- a/rasa/core/test.py +++ b/rasa/core/test.py @@ -236,7 +236,7 @@ async def _generate_trackers( augmentation_factor=0, tracker_limit=max_stories, ) - return g.generate() + return g.generate_story_trackers() def _clean_entity_results( diff --git a/rasa/core/training/generator.py b/rasa/core/training/generator.py index 938953ebfd44..dee0fed91432 100644 --- a/rasa/core/training/generator.py +++ b/rasa/core/training/generator.py @@ -222,7 +222,20 @@ def _phase_name(everything_reachable_is_reached, phase): else: return f"data generation round {phase}" - def _generate_ml_trackers(self) -> List[TrackerWithCachedStates]: + def generate(self) -> List[TrackerWithCachedStates]: + """Generate trackers from stories and rules. + + Returns: + The generated trackers. + """ + return self.generate_story_trackers() + self._generate_rule_trackers() + + def generate_story_trackers(self) -> List[TrackerWithCachedStates]: + """Generate trackers from stories (exclude rule trackers). + + Returns: + The generated story trackers. + """ steps = [step for step in self.story_graph.ordered_steps() if not step.is_rule] return self._generate(steps, is_rule_data=False) @@ -232,9 +245,6 @@ def _generate_rule_trackers(self) -> List[TrackerWithCachedStates]: return self._generate(steps, is_rule_data=True) - def generate(self) -> List[TrackerWithCachedStates]: - return self._generate_ml_trackers() + self._generate_rule_trackers() - def _generate( self, story_steps: List[StoryStep], is_rule_data: bool = False ) -> List[TrackerWithCachedStates]: diff --git a/tests/test_test.py b/tests/test_test.py index 9e72e009e6ca..74a3ab91b6ed 100644 --- a/tests/test_test.py +++ b/tests/test_test.py @@ -9,6 +9,7 @@ from _pytest.monkeypatch import MonkeyPatch import rasa.utils.io +from rasa.core.agent import Agent from rasa.core.events import UserUttered from rasa.core.test import ( EvaluationStore, @@ -201,3 +202,29 @@ def test_log_failed_stories(tmp_path: Path): assert dump.startswith("#") assert len(dump.split("\n")) == 1 + + +async def test_test_does_not_use_rules(tmp_path: Path, default_agent: Agent): + from rasa.core.test import _generate_trackers + + test_file = tmp_path / "test.yml" + test_name = "my test story" + tests = f""" +stories: +- story: {test_name} + steps: + - intent: greet + - action: utter_greet + +rules: +- rule: rule which is ignored + steps: + - intent: greet + - action: utter_greet + """ + + test_file.write_text(tests) + + test_trackers = await _generate_trackers(str(test_file), default_agent) + assert len(test_trackers) == 1 + assert test_trackers[0].sender_id == test_name