diff --git a/changelog/8258.bugfix.md b/changelog/8258.bugfix.md new file mode 100644 index 000000000000..7f5b0efd76d2 --- /dev/null +++ b/changelog/8258.bugfix.md @@ -0,0 +1 @@ +Fixed the bug that events from previous conversation sessions would be re-saved in the [`SQLTrackerStore`](tracker-stores.mdx#sqltrackerstore) or [`MongoTrackerStore`](tracker-stores.mdx#mongotrackerstore) when `retrieve_events_from_previous_conversation_sessions` was true. diff --git a/rasa/core/tracker_store.py b/rasa/core/tracker_store.py index 8b861f49b236..c367499cedc2 100644 --- a/rasa/core/tracker_store.py +++ b/rasa/core/tracker_store.py @@ -625,9 +625,12 @@ def _additional_events(self, tracker: DialogueStateTracker) -> Iterator: stored = self.conversations.find_one({"sender_id": tracker.sender_id}) or {} all_events = self._events_from_serialized_tracker(stored) - number_events_since_last_session = len( - self._events_since_last_session_start(all_events) - ) + if self.retrieve_events_from_previous_conversation_sessions: + number_events_since_last_session = len(all_events) + else: + number_events_since_last_session = len( + self._events_since_last_session_start(all_events) + ) return itertools.islice( tracker.events, number_events_since_last_session, len(tracker.events) @@ -1126,10 +1129,12 @@ def _additional_events( self, session: "Session", tracker: DialogueStateTracker ) -> Iterator: """Return events from the tracker which aren't currently stored.""" - number_of_events_since_last_session = self._event_query( - session, tracker.sender_id, fetch_events_from_all_sessions=False + session, + tracker.sender_id, + fetch_events_from_all_sessions=self.retrieve_events_from_previous_conversation_sessions, ).count() + return itertools.islice( tracker.events, number_of_events_since_last_session, len(tracker.events) ) diff --git a/tests/core/conftest.py b/tests/core/conftest.py index 83f918d36abd..85af2840cdaa 100644 --- a/tests/core/conftest.py +++ b/tests/core/conftest.py @@ -84,11 +84,18 @@ def __init__(self, *args, **kwargs): class MockedMongoTrackerStore(MongoTrackerStore): """In-memory mocked version of `MongoTrackerStore`.""" - def __init__(self, _domain: Domain): + def __init__( + self, + _domain: Domain, + retrieve_events_from_previous_conversation_sessions: bool = False, + ) -> None: from mongomock import MongoClient self.db = MongoClient().rasa self.collection = "conversations" + self.retrieve_events_from_previous_conversation_sessions = ( + retrieve_events_from_previous_conversation_sessions + ) # skipcq: PYL-E1003 # Skip `MongoTrackerStore` constructor to avoid that actual Mongo connection diff --git a/tests/core/test_tracker_stores.py b/tests/core/test_tracker_stores.py index a59d2aa0931c..4f8639a24421 100644 --- a/tests/core/test_tracker_stores.py +++ b/tests/core/test_tracker_stores.py @@ -539,8 +539,16 @@ def _saved_tracker_with_multiple_session_starts( return tracker_store.retrieve(sender_id) -def test_mongo_additional_events(default_domain: Domain): - tracker_store = MockedMongoTrackerStore(default_domain) +@pytest.mark.parametrize( + "retrieve_events_from_previous_conversation_sessions", [True, False], +) +def test_mongo_additional_events( + default_domain: Domain, retrieve_events_from_previous_conversation_sessions +): + tracker_store = MockedMongoTrackerStore( + default_domain, + retrieve_events_from_previous_conversation_sessions=retrieve_events_from_previous_conversation_sessions, + ) events, tracker = create_tracker_with_partially_saved_events(tracker_store) # make sure only new events are returned @@ -548,9 +556,17 @@ def test_mongo_additional_events(default_domain: Domain): assert list(tracker_store._additional_events(tracker)) == events -def test_mongo_additional_events_with_session_start(default_domain: Domain): +@pytest.mark.parametrize( + "retrieve_events_from_previous_conversation_sessions", [True, False], +) +def test_mongo_additional_events_with_session_start( + default_domain: Domain, retrieve_events_from_previous_conversation_sessions +): sender = "test_mongo_additional_events_with_session_start" - tracker_store = MockedMongoTrackerStore(default_domain) + tracker_store = MockedMongoTrackerStore( + default_domain, + retrieve_events_from_previous_conversation_sessions=retrieve_events_from_previous_conversation_sessions, + ) tracker = _saved_tracker_with_multiple_session_starts(tracker_store, sender) tracker.update(UserUttered("hi2")) @@ -562,10 +578,18 @@ def test_mongo_additional_events_with_session_start(default_domain: Domain): assert isinstance(additional_events[0], UserUttered) +@pytest.mark.parametrize( + "retrieve_events_from_previous_conversation_sessions", [True, False], +) # we cannot parametrise over this and the previous test due to the different ways of # calling _additional_events() -def test_sql_additional_events(default_domain: Domain): - tracker_store = SQLTrackerStore(default_domain) +def test_sql_additional_events( + default_domain: Domain, retrieve_events_from_previous_conversation_sessions +): + tracker_store = SQLTrackerStore( + default_domain, + retrieve_events_from_previous_conversation_sessions=retrieve_events_from_previous_conversation_sessions, + ) additional_events, tracker = create_tracker_with_partially_saved_events( tracker_store ) @@ -579,9 +603,17 @@ def test_sql_additional_events(default_domain: Domain): ) -def test_sql_additional_events_with_session_start(default_domain: Domain): +@pytest.mark.parametrize( + "retrieve_events_from_previous_conversation_sessions", [True, False], +) +def test_sql_additional_events_with_session_start( + default_domain: Domain, retrieve_events_from_previous_conversation_sessions +): sender = "test_sql_additional_events_with_session_start" - tracker_store = SQLTrackerStore(default_domain) + tracker_store = SQLTrackerStore( + default_domain, + retrieve_events_from_previous_conversation_sessions=retrieve_events_from_previous_conversation_sessions, + ) tracker = _saved_tracker_with_multiple_session_starts(tracker_store, sender) tracker.update(UserUttered("hi2"), default_domain)