Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/158-TUS-training #169

Draft
wants to merge 31 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
5a49833
Create class to represent an information need
NoB0 May 13, 2024
634b2c7
Add tests for information need
NoB0 May 13, 2024
ea422e5
Fix version rasa
NoB0 May 13, 2024
025da95
Update requirements to reduce backtracking
NoB0 May 13, 2024
3e7e45f
Update test
NoB0 May 13, 2024
e4315ec
Format with black
NoB0 May 13, 2024
a1a82b9
Format docstring
NoB0 May 13, 2024
3ac5192
Add dialogue state and dialogue state tracker
NoB0 May 29, 2024
b7208c5
Add tests for DST
NoB0 May 29, 2024
1d2dacd
Black
NoB0 May 29, 2024
41e6fba
Implement transformer model
NoB0 May 29, 2024
853cfbb
Fix pre-commit
NoB0 May 29, 2024
11247e7
Merge branch 'feature/TUS-feature-handler' into feature/158-TUS-featu…
NoB0 May 29, 2024
32506a1
Merge branch 'feature/core-neural-us' into feature/158-TUS-feature-ha…
NoB0 May 29, 2024
7b8ba16
Add TUS feature handler
NoB0 May 29, 2024
9ec13e4
Merge branch 'feature/163-Add-dialogue-management-components' into fe…
NoB0 May 29, 2024
b511be1
Update tests
NoB0 May 29, 2024
7daa047
Black
NoB0 May 29, 2024
763008b
Update feature handler
NoB0 May 30, 2024
bb0f715
Merge branch 'main' into feature/158-TUS-feature-handler
NoB0 Jun 4, 2024
6901be4
Reorganize modules
NoB0 Jun 4, 2024
381fd32
Update feature handler
NoB0 Jun 25, 2024
8dcb725
Add missing requirement
NoB0 Jun 25, 2024
554b826
Add TUS
NoB0 Jun 25, 2024
03d5300
Add training script for TUS
NoB0 Jun 25, 2024
2cae500
Merge branch 'main' into feature/158-TUS-training
NoB0 Sep 26, 2024
4ee5abe
Fix tests
NoB0 Sep 26, 2024
46e299a
Quick fixes
NoB0 Sep 27, 2024
0676c21
Merge branch 'main' into feature/158-TUS-training
NoB0 Oct 1, 2024
f6c9bd4
Merge branch 'main' into feature/158-TUS-training
NoB0 Oct 2, 2024
1bdedf2
Update TUS feature handler
NoB0 Oct 2, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ Your should be formatted as a JSON object with two fields: constraints and reque
Constraints are represented as a dictionary where the keys are the slots, and the values are the values that the movie must have for that slot. The possible slots are:
- GENRE: Movie genre
- ACTOR: Actor starring in the movie
- KEYWORD: Keyword associated with the movie
- KEYWORDS: Keywords associated with the movie
- DIRECTOR: Director of the movie
Requests are represented as a list of slots. The possible slots are:
- PLOT: Movie plot
Expand Down
240 changes: 240 additions & 0 deletions tests/data/tus_annotated_dialogues.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,240 @@
[
{
"conversation_id": "1",
"conversation": [
{
"participant": "USER",
"utterance": "Hello, recommend me an action movie.",
"dialogue_acts": [
{
"intent": "DISCLOSE",
"slot_values": [
[
"GENRE",
"action",
null,
null
]
]
}
]
},
{
"participant": "AGENT",
"utterance": "Sure, how about Mission Impossible?",
"dialogue_acts": [
{
"intent": "RECOMMEND",
"slot_values": [
[
"TITLE",
"Mission Impossible",
null,
null
]
]
}
]
},
{
"participant": "USER",
"utterance": "Sounds good. What's it about?",
"dialogue_acts": [
{
"intent": "INQUIRE",
"slot_values": [
[
"PLOT",
null,
null,
null
]
]
}
]
},
{
"participant": "AGENT",
"utterance": "A secret agent is sent on a mission to save the world.",
"dialogue_acts": [
{
"intent": "INFORM",
"slot_values": [
[
"PLOT",
"A secret agent is sent on a mission to save the world.",
null,
null
]
]
}
]
}
],
"agent": {
"id": "Agent",
"type": "AGENT"
},
"user": {
"id": "User",
"type": "USER"
},
"metadata": {
"information_need": {
"constraints": {
"GENRE": "action",
"ACTOR": "Tom Cruise"
},
"requests": [
"PLOT"
],
"target_items": [
{
"item_id": "1",
"properties": {
"GENRE": "action",
"ACTOR": "Tom Cruise",
"PLOT": "A secret agent is sent on a mission to save the world."
}
}
]
}
}
},
{
"conversation_id": "2",
"conversation": [
{
"participant": "USER",
"utterance": "I want to watch a comedy movie.",
"dialogue_acts": [
{
"intent": "DISCLOSE",
"slot_values": [
[
"GENRE",
"comedy",
null,
null
]
]
}
]
},
{
"participant": "AGENT",
"utterance": "Do you have a favorite director?",
"dialogue_acts": [
{
"intent": "ELICIT",
"slot_values": [
[
"DIRECTOR",
null,
null,
null
]
]
}
]
},
{
"participant": "USER",
"utterance": "I love Steven Spielberg",
"dialogue_acts": [
{
"intent": "DISCLOSE",
"slot_values": [
[
"DIRECTOR",
"Steven Spielberg",
null,
null
]
]
}
]
},
{
"participant": "AGENT",
"utterance": "How about The Terminal?",
"dialogue_acts": [
{
"intent": "RECOMMEND",
"slot_values": [
[
"TITLE",
"The Terminal",
null,
null
]
]
}
]
},
{
"participant": "USER",
"utterance": "What's the rating?",
"dialogue_acts": [
{
"intent": "INQUIRE",
"slot_values": [
[
"RATING",
null,
null,
null
]
]
}
]
},
{
"participant": "AGENT",
"utterance": "The Terminal has a rating of 4.5.",
"dialogue_acts": [
{
"intent": "INFORM",
"slot_values": [
[
"RATING",
"4.5",
null,
null
]
]
}
]
}
],
"agent": {
"id": "Agent",
"type": "AGENT"
},
"user": {
"id": "User",
"type": "USER"
},
"metadata": {
"information_need": {
"constraints": {
"GENRE": "Comedy",
"DIRECTOR": "Steven Spielberg"
},
"requests": [
"RATING"
],
"target_items": [
{
"item_id": "1",
"properties": {
"GENRE": "Comedy",
"DIRECTOR": "Steven Spielberg",
"RATING": 4.5
}
}
]
}
}
}
]
77 changes: 77 additions & 0 deletions tests/simulator/tus/test_tus_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
"""Tests for preparing TUS dataset for training."""

from typing import List

import pytest

from dialoguekit.core.dialogue import Dialogue
from dialoguekit.utils.dialogue_reader import json_to_dialogues
from usersimcrs.simulator.neural.tus.tus_dataset import TUSDataset
from usersimcrs.simulator.neural.tus.tus_feature_handler import (
TUSFeatureHandler,
)

TEST_DATA_PATH = "tests/data/tus_annotated_dialogues.json"


@pytest.fixture
def dialogues() -> List[Dialogue]:
    """Provides the dialogues parsed from the test data file.

    Returns:
        Dialogues read from TEST_DATA_PATH with json_to_dialogues.
    """
    loaded_dialogues = json_to_dialogues(TEST_DATA_PATH)
    return loaded_dialogues


@pytest.fixture
def tus_dataset(feature_handler: TUSFeatureHandler) -> TUSDataset:
    """Builds a TUSDataset over the test data file.

    Args:
        feature_handler: Feature handler handed to the dataset.

    Returns:
        Dataset created from TEST_DATA_PATH.
    """
    dataset = TUSDataset(
        data_path=TEST_DATA_PATH,
        feature_handler=feature_handler,
    )
    return dataset


def test_init_failure_file_path(feature_handler: TUSFeatureHandler) -> None:
    """Checks that a non-existent data path raises FileNotFoundError.

    Args:
        feature_handler: Feature handler handed to the dataset.
    """
    missing_path = "non_existent_file.json"

    with pytest.raises(FileNotFoundError):
        TUSDataset(data_path=missing_path, feature_handler=feature_handler)


def test_init_failure_dialogue_format(
    feature_handler: TUSFeatureHandler,
) -> None:
    """Checks that an unsuitable dialogue file raises ValueError.

    Args:
        feature_handler: Feature handler handed to the dataset.
    """
    # This file exists but is expected to be rejected by TUSDataset.
    unsupported_path = "tests/data/annotated_dialogues.json"

    with pytest.raises(ValueError):
        TUSDataset(data_path=unsupported_path, feature_handler=feature_handler)


def test_len(tus_dataset: TUSDataset, dialogues: List[Dialogue]) -> None:
    """Checks that the dataset has one entry per loaded dialogue.

    Args:
        tus_dataset: Dataset built from the test data file.
        dialogues: Dialogues loaded from the same test data file.
    """
    expected_size = len(dialogues)
    assert len(tus_dataset) == expected_size


def test_process_dialogue(
    tus_dataset: TUSDataset, dialogues: List[Dialogue]
) -> None:
    """Tests dialogue processing.

    Processes the first test dialogue and checks the dialogue id, the number
    of entries in "input", "mask" and "label", and the length of each entry.

    Args:
        tus_dataset: Dataset built from the test data file.
        dialogues: Dialogues loaded from the same test data file.
    """
    # Expected lengths of a single entry, as hard-coded by this test.
    expected_input_length = 80
    expected_label_length = 40

    dialogue = dialogues[0]

    input_features = tus_dataset.process_dialogue(dialogue)
    inputs = input_features.get("input")
    masks = input_features.get("mask")
    labels = input_features.get("label")

    assert input_features.get("dialogue_id") == dialogue.conversation_id

    # One entry is expected for every two utterances (presumably one per
    # user turn; confirm against TUSDataset.process_dialogue).
    expected_entries = len(dialogue.utterances) // 2
    assert len(inputs) == len(masks) == len(labels) == expected_entries

    # The lengths were just asserted equal, so zip() pairs every input with
    # its mask without dropping any entry.
    assert all(
        len(features) == len(mask) == expected_input_length
        for features, mask in zip(inputs, masks)
    )
    assert all(len(label) == expected_label_length for label in labels)
26 changes: 17 additions & 9 deletions usersimcrs/simulator/neural/core/transformer.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,11 @@
"""Encoder-only transformer model for neural user simulator."""
"""Encoder-only transformer model for neural user simulator.

Implementation inspired by PyTorch documentation and TUS's transformer model.

Sources:
https://colab.research.google.com/github/pytorch/tutorials/blob/gh-pages/_downloads/dca13261bbb4e9809d1a3aa521d22dd7/transformer_tutorial.ipynb#scrollTo=R8veciavth40
https://gitlab.cs.uni-duesseldorf.de/general/dsml/tus_public/-/blob/master/convlab2/policy/tus/multiwoz/transformer.py?ref_type=heads
"""

import math

Expand Down Expand Up @@ -54,7 +61,6 @@ def __init__(
nhead: int,
hidden_dim: int,
num_encoder_layers: int,
num_token: int,
dropout: float = 0.5,
) -> None:
"""Initializes a encoder-only transformer model.
Expand All @@ -69,15 +75,15 @@ def __init__(
dropout: Dropout rate. Defaults to 0.5.
"""
super(TransformerEncoderModel, self).__init__()
self.d_model = input_dim
self.d_model = hidden_dim

self.pos_encoder = PositionalEncoding(input_dim, dropout)
self.embedding = nn.Embedding(num_token, input_dim)
self.pos_encoder = PositionalEncoding(hidden_dim, dropout)
self.embedding = nn.Linear(input_dim, hidden_dim)

# Encoder layers
norm_layer = nn.LayerNorm(input_dim)
norm_layer = nn.LayerNorm(hidden_dim)
encoder_layer = nn.TransformerEncoderLayer(
d_model=input_dim,
d_model=self.d_model,
nhead=nhead,
dim_feedforward=hidden_dim,
)
Expand All @@ -87,7 +93,7 @@ def __init__(
norm=norm_layer,
)

self.linear = nn.Linear(input_dim, output_dim)
self.linear = nn.Linear(hidden_dim, output_dim)
self.softmax = nn.Softmax(dim=-1)

self.init_weights()
Expand All @@ -113,6 +119,8 @@ def forward(
"""
src = self.embedding(src) * math.sqrt(self.d_model)
src = self.pos_encoder(src)
output = self.encoder(src, mask=src_mask)
src = src.permute(1, 0, 2)
output = self.encoder(src, src_key_padding_mask=src_mask)
output = self.linear(output)
output = output.permute(1, 0, 2)
return output
Loading
Loading