From 1702515640a79e870871837bba4523c80f2752c8 Mon Sep 17 00:00:00 2001 From: Tobias Wochinger Date: Mon, 31 Aug 2020 15:39:34 +0200 Subject: [PATCH 1/3] exclude rules when using `rasa test` --- rasa/core/test.py | 2 +- tests/test_test.py | 27 +++++++++++++++++++++++++++ 2 files changed, 28 insertions(+), 1 deletion(-) diff --git a/rasa/core/test.py b/rasa/core/test.py index 92cda8dc2c04..f0116e775a41 100644 --- a/rasa/core/test.py +++ b/rasa/core/test.py @@ -236,7 +236,7 @@ async def _generate_trackers( augmentation_factor=0, tracker_limit=max_stories, ) - return g.generate() + return g._generate_ml_trackers() def _clean_entity_results( diff --git a/tests/test_test.py b/tests/test_test.py index 9e72e009e6ca..b3199165ef37 100644 --- a/tests/test_test.py +++ b/tests/test_test.py @@ -9,6 +9,7 @@ from _pytest.monkeypatch import MonkeyPatch import rasa.utils.io +from rasa.core.agent import Agent from rasa.core.events import UserUttered from rasa.core.test import ( EvaluationStore, @@ -201,3 +202,29 @@ def test_log_failed_stories(tmp_path: Path): assert dump.startswith("#") assert len(dump.split("\n")) == 1 + + +async def test_test_does_not_use_rules(tmp_path: Path, default_agent: Agent): + from rasa.core.test import _generate_trackers + + test_file = tmp_path / "test.yml" + test_name = "my test story" + tests = f""" +stories: +- story: {test_name} + steps: + - intent: greet + - action: utter_greet + +rules: +- rule: rule which is ignored + steps: + - intent: greet + - action: utter_greet + """ + + test_file.write_text(tests) + + test_trackers = await _generate_trackers(str(test_file), default_agent) + assert len(test_trackers) == 1 + assert test_trackers[0].sender_id == test_name From 24c8a930c3896a1aa8e3fb8686fb35b3f0d96627 Mon Sep 17 00:00:00 2001 From: Tobias Wochinger Date: Mon, 31 Aug 2020 15:40:53 +0200 Subject: [PATCH 2/3] rename `_generate_ml_trackers` to `_generate_story_trackeers` --- rasa/core/test.py | 2 +- rasa/core/training/generator.py | 18 ++++++++++++++---- tests/test_test.py | 2 +- 3 files changed, 16 insertions(+), 6 deletions(-) diff --git a/rasa/core/test.py b/rasa/core/test.py index f0116e775a41..619aebb70f92 100644 --- a/rasa/core/test.py +++ b/rasa/core/test.py @@ -236,7 +236,7 @@ async def _generate_trackers( augmentation_factor=0, tracker_limit=max_stories, ) - return g._generate_ml_trackers() + return g.generate_story_trackers() def _clean_entity_results( diff --git a/rasa/core/training/generator.py b/rasa/core/training/generator.py index 938953ebfd44..dee0fed91432 100644 --- a/rasa/core/training/generator.py +++ b/rasa/core/training/generator.py @@ -222,7 +222,20 @@ def _phase_name(everything_reachable_is_reached, phase): else: return f"data generation round {phase}" - def _generate_ml_trackers(self) -> List[TrackerWithCachedStates]: + def generate(self) -> List[TrackerWithCachedStates]: + """Generate trackers from stories and rules. + + Returns: + The generated trackers. + """ + return self.generate_story_trackers() + self._generate_rule_trackers() + + def generate_story_trackers(self) -> List[TrackerWithCachedStates]: + """Generate trackers from stories (exclude rule trackers). + + Returns: + The generated story trackers. + """ steps = [step for step in self.story_graph.ordered_steps() if not step.is_rule] return self._generate(steps, is_rule_data=False) @@ -232,9 +245,6 @@ def _generate_rule_trackers(self) -> List[TrackerWithCachedStates]: return self._generate(steps, is_rule_data=True) - def generate(self) -> List[TrackerWithCachedStates]: - return self._generate_ml_trackers() + self._generate_rule_trackers() - def _generate( self, story_steps: List[StoryStep], is_rule_data: bool = False ) -> List[TrackerWithCachedStates]: diff --git a/tests/test_test.py b/tests/test_test.py index b3199165ef37..74a3ab91b6ed 100644 --- a/tests/test_test.py +++ b/tests/test_test.py @@ -215,7 +215,7 @@ async def test_test_does_not_use_rules(tmp_path: Path, default_agent: Agent): steps: - intent: greet - action: utter_greet - + rules: - rule: rule which is ignored steps: From 699886a4a99421443cf904ec74490f550198e0c8 Mon Sep 17 00:00:00 2001 From: Tobias Wochinger Date: Mon, 31 Aug 2020 16:00:18 +0200 Subject: [PATCH 3/3] rename helper function to actual implementation --- docs/docs/forms.mdx | 4 ++-- docs/docs/migration-guide.mdx | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/docs/forms.mdx b/docs/docs/forms.mdx index 364652a8188b..ec5210d8a9c4 100644 --- a/docs/docs/forms.mdx +++ b/docs/docs/forms.mdx @@ -263,7 +263,7 @@ When the form is executed it will run your custom action. In your custom action you can either - validate already extracted slots. You can retrieve them from the tracker by running - `tracker.get_extracted_slots`. + `tracker.form_slots_to_validate`. - use [Custom Slot Mappings](./forms.mdx#slot-mappings) to extract slot values . After validating the extracted slots, return `SlotSet` events for them. If you want @@ -291,7 +291,7 @@ class ValidateSlots(Action): def run( self, dispatcher: CollectingDispatcher, tracker: Tracker, domain: Dict ) -> List[EventType]: - extracted_slots: Dict[Text, Any] = tracker.get_extracted_slots() + extracted_slots: Dict[Text, Any] = tracker.form_slots_to_validate() validation_events = [] diff --git a/docs/docs/migration-guide.mdx b/docs/docs/migration-guide.mdx index 5d2fa66e3d4c..83f5ba983197 100644 --- a/docs/docs/migration-guide.mdx +++ b/docs/docs/migration-guide.mdx @@ -303,7 +303,7 @@ class RestaurantFormValidator(Action): def run( self, dispatcher: CollectingDispatcher, tracker: Tracker, domain: Dict ) -> List[EventType]: - extracted_slots: Dict[Text, Any] = tracker.get_extracted_slots() + extracted_slots: Dict[Text, Any] = tracker.form_slots_to_validate() cuisine_slot_value = extracted_slots.get("cuisine") validated_slot_event = self.validate_cuisine(