Skip to content

Commit

Permalink
Merge pull request #8258 from RasaHQ/additional_events
Browse files Browse the repository at this point in the history
add condition in sql/mongo tracker store in additional_events when retrieve_events_from_previous_conversation_sessions is true
  • Loading branch information
indam23 authored Mar 24, 2021
2 parents 35651b9 + 792ebdd commit 0024358
Show file tree
Hide file tree
Showing 4 changed files with 59 additions and 14 deletions.
1 change: 1 addition & 0 deletions changelog/8258.bugfix.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fixed the bug that events from previous conversation sessions would be re-saved in the [`SQLTrackerStore`](tracker-stores.mdx#sqltrackerstore) or [`MongoTrackerStore`](tracker-stores.mdx#mongotrackerstore) when `retrieve_events_from_previous_conversation_sessions` was true.
15 changes: 10 additions & 5 deletions rasa/core/tracker_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -625,9 +625,12 @@ def _additional_events(self, tracker: DialogueStateTracker) -> Iterator:

stored = self.conversations.find_one({"sender_id": tracker.sender_id}) or {}
all_events = self._events_from_serialized_tracker(stored)
number_events_since_last_session = len(
self._events_since_last_session_start(all_events)
)
if self.retrieve_events_from_previous_conversation_sessions:
number_events_since_last_session = len(all_events)
else:
number_events_since_last_session = len(
self._events_since_last_session_start(all_events)
)

return itertools.islice(
tracker.events, number_events_since_last_session, len(tracker.events)
Expand Down Expand Up @@ -1126,10 +1129,12 @@ def _additional_events(
self, session: "Session", tracker: DialogueStateTracker
) -> Iterator:
"""Return events from the tracker which aren't currently stored."""

number_of_events_since_last_session = self._event_query(
session, tracker.sender_id, fetch_events_from_all_sessions=False
session,
tracker.sender_id,
fetch_events_from_all_sessions=self.retrieve_events_from_previous_conversation_sessions,
).count()

return itertools.islice(
tracker.events, number_of_events_since_last_session, len(tracker.events)
)
Expand Down
9 changes: 8 additions & 1 deletion tests/core/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,11 +84,18 @@ def __init__(self, *args, **kwargs):
class MockedMongoTrackerStore(MongoTrackerStore):
"""In-memory mocked version of `MongoTrackerStore`."""

def __init__(self, _domain: Domain):
def __init__(
self,
_domain: Domain,
retrieve_events_from_previous_conversation_sessions: bool = False,
) -> None:
from mongomock import MongoClient

self.db = MongoClient().rasa
self.collection = "conversations"
self.retrieve_events_from_previous_conversation_sessions = (
retrieve_events_from_previous_conversation_sessions
)

# skipcq: PYL-E1003
# Skip `MongoTrackerStore` constructor to avoid that actual Mongo connection
Expand Down
48 changes: 40 additions & 8 deletions tests/core/test_tracker_stores.py
Original file line number Diff line number Diff line change
Expand Up @@ -539,18 +539,34 @@ def _saved_tracker_with_multiple_session_starts(
return tracker_store.retrieve(sender_id)


def test_mongo_additional_events(default_domain: Domain):
tracker_store = MockedMongoTrackerStore(default_domain)
@pytest.mark.parametrize(
"retrieve_events_from_previous_conversation_sessions", [True, False],
)
def test_mongo_additional_events(
default_domain: Domain, retrieve_events_from_previous_conversation_sessions
):
tracker_store = MockedMongoTrackerStore(
default_domain,
retrieve_events_from_previous_conversation_sessions=retrieve_events_from_previous_conversation_sessions,
)
events, tracker = create_tracker_with_partially_saved_events(tracker_store)

# make sure only new events are returned
# noinspection PyProtectedMember
assert list(tracker_store._additional_events(tracker)) == events


def test_mongo_additional_events_with_session_start(default_domain: Domain):
@pytest.mark.parametrize(
"retrieve_events_from_previous_conversation_sessions", [True, False],
)
def test_mongo_additional_events_with_session_start(
default_domain: Domain, retrieve_events_from_previous_conversation_sessions
):
sender = "test_mongo_additional_events_with_session_start"
tracker_store = MockedMongoTrackerStore(default_domain)
tracker_store = MockedMongoTrackerStore(
default_domain,
retrieve_events_from_previous_conversation_sessions=retrieve_events_from_previous_conversation_sessions,
)
tracker = _saved_tracker_with_multiple_session_starts(tracker_store, sender)

tracker.update(UserUttered("hi2"))
Expand All @@ -562,10 +578,18 @@ def test_mongo_additional_events_with_session_start(default_domain: Domain):
assert isinstance(additional_events[0], UserUttered)


@pytest.mark.parametrize(
"retrieve_events_from_previous_conversation_sessions", [True, False],
)
# we cannot parametrise over this and the previous test due to the different ways of
# calling _additional_events()
def test_sql_additional_events(default_domain: Domain):
tracker_store = SQLTrackerStore(default_domain)
def test_sql_additional_events(
default_domain: Domain, retrieve_events_from_previous_conversation_sessions
):
tracker_store = SQLTrackerStore(
default_domain,
retrieve_events_from_previous_conversation_sessions=retrieve_events_from_previous_conversation_sessions,
)
additional_events, tracker = create_tracker_with_partially_saved_events(
tracker_store
)
Expand All @@ -579,9 +603,17 @@ def test_sql_additional_events(default_domain: Domain):
)


def test_sql_additional_events_with_session_start(default_domain: Domain):
@pytest.mark.parametrize(
"retrieve_events_from_previous_conversation_sessions", [True, False],
)
def test_sql_additional_events_with_session_start(
default_domain: Domain, retrieve_events_from_previous_conversation_sessions
):
sender = "test_sql_additional_events_with_session_start"
tracker_store = SQLTrackerStore(default_domain)
tracker_store = SQLTrackerStore(
default_domain,
retrieve_events_from_previous_conversation_sessions=retrieve_events_from_previous_conversation_sessions,
)
tracker = _saved_tracker_with_multiple_session_starts(tracker_store, sender)

tracker.update(UserUttered("hi2"), default_domain)
Expand Down

0 comments on commit 0024358

Please sign in to comment.