Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add dialogue state tracker #247

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 52 additions & 0 deletions dialoguekit/dialogue_manager/dialogue_state_tracker.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
"""A module for tracking the state of a dialogue."""

from collections import defaultdict
from dataclasses import dataclass, field
from typing import Dict, List

from dialoguekit.core.annotated_utterance import AnnotatedUtterance
from dialoguekit.core.annotation import Annotation
from dialoguekit.participant.participant import DialogueParticipant


@dataclass
class DialogueState:
"""A class to represent the state of a dialogue."""

history: List[AnnotatedUtterance] = field(default_factory=list)
last_user_intent: str = None
slots: Dict[str, List[Annotation]] = field(
default_factory=lambda: defaultdict(list)
)
turn_count: int = 0


class DialogueStateTracker:
def __init__(self) -> None:
"""Initializes the dialogue state tracker."""
self._dialogue_state = DialogueState()

def get_state(self) -> DialogueState:
"""Returns the current state of the dialogue.

Returns:
The current state of the dialogue.
"""
return self._dialogue_state

def update(self, annotated_utterance: AnnotatedUtterance) -> None:
"""Updates the dialogue state with the annotated utterance.

Args:
annotated_utterance: The annotated utterance.
"""
self._dialogue_state.history.append(annotated_utterance)
if annotated_utterance.participant is not DialogueParticipant.USER:
return

self._dialogue_state.last_user_intent = annotated_utterance.intent

for annotation in annotated_utterance.annotations:
self._dialogue_state.slots[annotation.slot].append(annotation)

self._dialogue_state.turn_count += 1
114 changes: 114 additions & 0 deletions tests/dialogue_manager/test_dialogue_state_tracker.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
"""Tests for the DialogueStateTracker class."""


import pytest

from dialoguekit.core.intent import Intent
from dialoguekit.dialogue_manager.dialogue_state_tracker import (
AnnotatedUtterance,
Annotation,
DialogueParticipant,
DialogueStateTracker,
)


@pytest.fixture
def annotated_utterance() -> AnnotatedUtterance:
"""Return an annotated utterance."""
return AnnotatedUtterance(
"Hello",
DialogueParticipant.USER,
intent=Intent("greeting"),
annotations=[Annotation("name", "John")],
)


def test_initial_state() -> None:
"""Test that the initial state is correct."""
tracker = DialogueStateTracker()
state = tracker.get_state()
assert state.history == []
assert state.last_user_intent is None
assert state.slots == {}
assert state.turn_count == 0


def test_agent_participant(annotated_utterance: AnnotatedUtterance) -> None:
"""Test that the agent participant is updated when the user utterance
contains annotations.

Args:
annotated_utterance: Annotated utterance.
"""
tracker = DialogueStateTracker()
agent_utterance = AnnotatedUtterance(
"Hi, how can I assist you?",
DialogueParticipant.AGENT,
intent=Intent("offer_help"),
annotations=[],
)

tracker.update(annotated_utterance)
assert tracker.get_state().last_user_intent == Intent("greeting")
assert tracker.get_state().turn_count == 1

tracker.update(agent_utterance)
assert tracker.get_state().last_user_intent == Intent("greeting")
assert tracker.get_state().turn_count == 1
assert tracker.get_state().history[-1] == agent_utterance


def test_update_history(annotated_utterance: AnnotatedUtterance) -> None:
"""Test that the history is updated when the user utterance contains
annotations.

Args:
annotated_utterance: Annotated utterance.
"""
tracker = DialogueStateTracker()
tracker.update(annotated_utterance)
assert tracker.get_state().history == [annotated_utterance]


def test_update_intent(annotated_utterance: AnnotatedUtterance) -> None:
"""Test that the last user intent is updated when the user utterance
contains annotations.

Args:
annotated_utterance: Annotated utterance.
"""
tracker = DialogueStateTracker()
tracker.update(annotated_utterance)
assert tracker.get_state().last_user_intent == Intent("greeting")


def test_update_slots(annotated_utterance: AnnotatedUtterance) -> None:
"""Test that the slots are updated when the user utterance contains
annotations.

Args:
annotated_utterance: Annotated utterance.
"""
tracker = DialogueStateTracker()
tracker.update(annotated_utterance)
assert tracker.get_state().slots == {"name": [Annotation("name", "John")]}


def test_turn_count(annotated_utterance: AnnotatedUtterance) -> None:
"""Test that the turn count is incremented when the user and agent have
both acted.

Args:
annotated_utterance: Annotated utterance.
"""
tracker = DialogueStateTracker()

annotated_utterance_2 = AnnotatedUtterance(
"How are you?",
DialogueParticipant.USER,
Intent("ask_health"),
annotations=[],
)
tracker.update(annotated_utterance)
tracker.update(annotated_utterance_2)
assert tracker.get_state().turn_count == 2