Skip to content

Commit

Permalink
Merge branch '2.1.x' into 3900_domain_dump_order
Browse files Browse the repository at this point in the history
  • Loading branch information
rasabot authored Dec 4, 2020
2 parents 573b9dc + 2f8d4c2 commit 75b8116
Show file tree
Hide file tree
Showing 5 changed files with 52 additions and 27 deletions.
1 change: 1 addition & 0 deletions changelog/7390.bugfix.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Make sure the `responses` are synced between NLU training data and the Domain even if there're no retrieval intents in the NLU training data.
3 changes: 3 additions & 0 deletions data/test_nlg/test_responses.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
responses:
utter_rasa:
- text: this is utter_rasa!
42 changes: 21 additions & 21 deletions rasa/shared/importers/importer.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ def load_from_dict(
)
]

return E2EImporter(RetrievalModelsDataImporter(CombinedDataImporter(importers)))
return E2EImporter(ResponsesSyncImporter(CombinedDataImporter(importers)))

@staticmethod
def _importer_from_dict(
Expand Down Expand Up @@ -293,8 +293,8 @@ async def get_nlu_data(self, language: Optional[Text] = "en") -> TrainingData:
)


class RetrievalModelsDataImporter(TrainingDataImporter):
"""A `TrainingDataImporter` that sets up the data for training retrieval models.
class ResponsesSyncImporter(TrainingDataImporter):
"""Importer that syncs `responses` between Domain and NLU training data.
Synchronizes response templates between Domain and NLU
and adds retrieval intent properties from the NLU training data
Expand All @@ -314,19 +314,18 @@ async def get_domain(self) -> Domain:
existing_domain = await self._importer.get_domain()
existing_nlu_data = await self._importer.get_nlu_data()

# Check if NLU data has any retrieval intents, if yes
# add corresponding retrieval actions with `utter_` prefix automatically
# to an empty domain, update the properties of existing retrieval intents
# and merge response templates
if existing_nlu_data.retrieval_intents:

domain_with_retrieval_intents = self._get_domain_with_retrieval_intents(
existing_nlu_data.retrieval_intents,
existing_nlu_data.responses,
existing_domain,
)
# Merge responses from NLU data with responses in the domain.
# If NLU data has any retrieval intents, then add corresponding
# retrieval actions with `utter_` prefix automatically to the
# final domain, update the properties of existing retrieval intents.
domain_with_retrieval_intents = self._get_domain_with_retrieval_intents(
existing_nlu_data.retrieval_intents,
existing_nlu_data.responses,
existing_domain,
)

existing_domain = existing_domain.merge(domain_with_retrieval_intents)
existing_domain = existing_domain.merge(domain_with_retrieval_intents)
existing_domain.check_missing_templates()

return existing_domain

Expand All @@ -351,16 +350,19 @@ def _get_domain_with_retrieval_intents(
response_templates: Dict[Text, List[Dict[Text, Any]]],
existing_domain: Domain,
) -> Domain:
"""Construct a domain consisting of retrieval intents listed in the NLU training data.
"""Construct a domain consisting of retrieval intents.
The result domain will have retrieval intents that are listed
in the NLU training data.
Args:
retrieval_intents: Set of retrieval intents defined in NLU training data.
response_templates: Response templates defined in NLU training data.
existing_domain: Domain which is already loaded from the domain file.
Returns: Domain with retrieval actions added to action names and properties
for retrieval intents updated.
for retrieval intents updated.
"""

# Get all the properties already defined
# for each retrieval intent in other domains
# and add the retrieval intent property to them
Expand All @@ -379,9 +381,7 @@ def _get_domain_with_retrieval_intents(
[],
[],
response_templates,
RetrievalModelsDataImporter._construct_retrieval_action_names(
retrieval_intents
),
ResponsesSyncImporter._construct_retrieval_action_names(retrieval_intents),
{},
)

Expand Down
1 change: 0 additions & 1 deletion rasa/shared/importers/rasa.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,6 @@ async def get_domain(self) -> Domain:
return domain
try:
domain = Domain.load(self._domain_path)
domain.check_missing_templates()
except InvalidDomain as e:
rasa.shared.utils.io.raise_warning(
f"Loading domain from '{self._domain_path}' failed. Using "
Expand Down
32 changes: 27 additions & 5 deletions tests/shared/importers/test_importer.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
NluDataImporter,
CoreDataImporter,
E2EImporter,
RetrievalModelsDataImporter,
ResponsesSyncImporter,
)
from rasa.shared.importers.multi_project import MultiProjectImporter
from rasa.shared.importers.rasa import RasaFileImporter
Expand Down Expand Up @@ -113,7 +113,7 @@ def test_load_from_dict(
)

assert isinstance(actual, E2EImporter)
assert isinstance(actual.importer, RetrievalModelsDataImporter)
assert isinstance(actual.importer, ResponsesSyncImporter)

actual_importers = [i.__class__ for i in actual.importer._importer._importers]
assert actual_importers == expected
Expand All @@ -128,7 +128,7 @@ def test_load_from_config(tmpdir: Path):

importer = TrainingDataImporter.load_from_config(config_path)
assert isinstance(importer, E2EImporter)
assert isinstance(importer.importer, RetrievalModelsDataImporter)
assert isinstance(importer.importer, ResponsesSyncImporter)
assert isinstance(importer.importer._importer._importers[0], MultiProjectImporter)


Expand All @@ -140,7 +140,7 @@ async def test_nlu_only(project: Text):
)

assert isinstance(actual, NluDataImporter)
assert isinstance(actual._importer, RetrievalModelsDataImporter)
assert isinstance(actual._importer, ResponsesSyncImporter)

stories = await actual.get_stories()
assert stories.is_empty()
Expand Down Expand Up @@ -350,7 +350,7 @@ async def test_nlu_data_domain_sync_with_retrieval_intents(project: Text):
nlu_importer = NluDataImporter(base_data_importer)
core_importer = CoreDataImporter(base_data_importer)

importer = RetrievalModelsDataImporter(
importer = ResponsesSyncImporter(
CombinedDataImporter([nlu_importer, core_importer])
)
domain = await importer.get_domain()
Expand All @@ -361,3 +361,25 @@ async def test_nlu_data_domain_sync_with_retrieval_intents(project: Text):
assert domain.retrieval_intent_templates == nlu_data.responses
assert domain.templates != nlu_data.responses
assert "utter_chitchat" in domain.action_names


async def test_nlu_data_domain_sync_responses(project: Text):
config_path = os.path.join(project, DEFAULT_CONFIG_PATH)
domain_path = "data/test_domains/default.yml"
data_paths = ["data/test_nlg/test_responses.yml"]

base_data_importer = TrainingDataImporter.load_from_dict(
{}, config_path, domain_path, data_paths
)

nlu_importer = NluDataImporter(base_data_importer)
core_importer = CoreDataImporter(base_data_importer)

importer = ResponsesSyncImporter(
CombinedDataImporter([nlu_importer, core_importer])
)
with pytest.warns(None):
domain = await importer.get_domain()

# Responses were sync between "test_responses.yml" and the "domain.yml"
assert "utter_rasa" in domain.templates.keys()

0 comments on commit 75b8116

Please sign in to comment.