Skip to content

Commit

Permalink
Merge pull request #6514 from RasaHQ/rules-test
Browse files Browse the repository at this point in the history
exclude rules from testing
  • Loading branch information
wochinge committed Aug 31, 2020
2 parents 5eb86b3 + 699886a commit d0e59e3
Show file tree
Hide file tree
Showing 5 changed files with 45 additions and 8 deletions.
4 changes: 2 additions & 2 deletions docs/docs/forms.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,7 @@ When the form is executed it will run your custom action. In your custom action
you can either

- validate already extracted slots. You can retrieve them from the tracker by running
`tracker.get_extracted_slots`.
`tracker.form_slots_to_validate`.
- use [Custom Slot Mappings](./forms.mdx#slot-mappings) to extract slot values .

After validating the extracted slots, return `SlotSet` events for them. If you want
Expand Down Expand Up @@ -291,7 +291,7 @@ class ValidateSlots(Action):
def run(
self, dispatcher: CollectingDispatcher, tracker: Tracker, domain: Dict
) -> List[EventType]:
extracted_slots: Dict[Text, Any] = tracker.get_extracted_slots()
extracted_slots: Dict[Text, Any] = tracker.form_slots_to_validate()

validation_events = []

Expand Down
2 changes: 1 addition & 1 deletion docs/docs/migration-guide.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -303,7 +303,7 @@ class RestaurantFormValidator(Action):
def run(
self, dispatcher: CollectingDispatcher, tracker: Tracker, domain: Dict
) -> List[EventType]:
extracted_slots: Dict[Text, Any] = tracker.get_extracted_slots()
extracted_slots: Dict[Text, Any] = tracker.form_slots_to_validate()

cuisine_slot_value = extracted_slots.get("cuisine")
validated_slot_event = self.validate_cuisine(
Expand Down
2 changes: 1 addition & 1 deletion rasa/core/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,7 @@ async def _generate_trackers(
augmentation_factor=0,
tracker_limit=max_stories,
)
return g.generate()
return g.generate_story_trackers()


def _clean_entity_results(
Expand Down
18 changes: 14 additions & 4 deletions rasa/core/training/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,20 @@ def _phase_name(everything_reachable_is_reached, phase):
else:
return f"data generation round {phase}"

def _generate_ml_trackers(self) -> List[TrackerWithCachedStates]:
def generate(self) -> List[TrackerWithCachedStates]:
"""Generate trackers from stories and rules.
Returns:
The generated trackers.
"""
return self.generate_story_trackers() + self._generate_rule_trackers()

def generate_story_trackers(self) -> List[TrackerWithCachedStates]:
"""Generate trackers from stories (exclude rule trackers).
Returns:
The generated story trackers.
"""
steps = [step for step in self.story_graph.ordered_steps() if not step.is_rule]

return self._generate(steps, is_rule_data=False)
Expand All @@ -232,9 +245,6 @@ def _generate_rule_trackers(self) -> List[TrackerWithCachedStates]:

return self._generate(steps, is_rule_data=True)

def generate(self) -> List[TrackerWithCachedStates]:
return self._generate_ml_trackers() + self._generate_rule_trackers()

def _generate(
self, story_steps: List[StoryStep], is_rule_data: bool = False
) -> List[TrackerWithCachedStates]:
Expand Down
27 changes: 27 additions & 0 deletions tests/test_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from _pytest.monkeypatch import MonkeyPatch

import rasa.utils.io
from rasa.core.agent import Agent
from rasa.core.events import UserUttered
from rasa.core.test import (
EvaluationStore,
Expand Down Expand Up @@ -201,3 +202,29 @@ def test_log_failed_stories(tmp_path: Path):

assert dump.startswith("#")
assert len(dump.split("\n")) == 1


async def test_test_does_not_use_rules(tmp_path: Path, default_agent: Agent):
from rasa.core.test import _generate_trackers

test_file = tmp_path / "test.yml"
test_name = "my test story"
tests = f"""
stories:
- story: {test_name}
steps:
- intent: greet
- action: utter_greet
rules:
- rule: rule which is ignored
steps:
- intent: greet
- action: utter_greet
"""

test_file.write_text(tests)

test_trackers = await _generate_trackers(str(test_file), default_agent)
assert len(test_trackers) == 1
assert test_trackers[0].sender_id == test_name

0 comments on commit d0e59e3

Please sign in to comment.