This repository contains the dataset and the PyTorch implementation of the models from the paper "Less is More: Mitigate Spurious Correlations for Open-Domain Dialogue Response Generation Models by Causal Discovery".
We build CGDIALOG, a corpus annotated with the direct causes of responses, from two public dialogue corpora (ESConv and MSC).
The original annotated dataset can be found in datasets/CGDIALOG.
| Number of Items | ESConv | MSC | Total |
| --- | --- | --- | --- |
| Dialogues | 80 | 80 | 160 |
| History-response pairs | 694 | 800 | 922 |
| Utterances | 2301 | 3807 | 6108 |
| Utterances containing direct causes | 1347 | 1525 | 2872 |
| Average token length of direct causes | 24.01 (std=16.61) | 22.22 (std=13.79) | 23.05 (std=15.20) |
| Proportion of direct causes in original utterances | 0.86 (std=0.22) | 0.72 (std=0.27) | 0.79 (std=0.26) |
Each entry in the dataset has the following format:
{
  "WorkerId": "ARH3NPT7GUFQ6",
  "history": [ # dialogue history
    "seeker: Hi!",
    "supporter: Hello, how are you doing today?",
    "seeker: Not so good. I have conspiracy theorist as a friend who is now mad at me because I told her to pull up her mask while talking to me.",
    "seeker: We have been friends for 13 years",
    "seeker: I am hurt and confused that she still thinks this is a game.",
    "seeker: She thought Corona was fake until someone we know caught it.",
    "seeker: It is like she is mad she was wrong and is taking it out and lashing out at those who have been trying to persuade her the whole time...",
    "seeker: what do you think?"
  ],
  "response": "It sounds like you care a lot about your friend and others. How old is your friend?",
  "entities": [ # direct causes of the response, annotated by workers
    "Not so good. I have conspiracy theorist as a friend who is now mad at me because I told her to pull up her mask while talking to me.",
    "I am hurt and confused that she still thinks this is a game.",
    "It is like she is mad she was wrong and is taking it out and lashing out at those who have been trying to persuade her the whole time..."
  ]
}
The code is based on PyTorch and HuggingFace Transformers. Set up the environment as follows:
conda create --prefix env/ python=3.6
conda activate env/
pip install -r requirements.txt
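Before training, it can help to inspect the annotations. The following is a minimal sketch, using only the Python standard library, of how an entry in the format shown earlier might be consumed: it reports which history utterances contain an annotated direct cause. The helpers `split_speaker` and `utterances_with_causes`, the shortened inline entry, and the rule that a cause is a verbatim substring of an utterance are assumptions made for illustration; they are not part of this repository's code.

```python
def split_speaker(utterance):
    """Split "speaker: text" into (speaker, text).

    Utterances without a "speaker: " prefix are returned with an empty speaker.
    """
    speaker, sep, text = utterance.partition(": ")
    if not sep:
        return "", utterance
    return speaker, text


def utterances_with_causes(entry):
    """Return (index, speaker, causes) for each history utterance that
    contains at least one annotated direct cause.

    Assumption: each cause in entry["entities"] appears verbatim
    inside the text of some utterance in entry["history"].
    """
    matches = []
    for idx, utterance in enumerate(entry["history"]):
        speaker, text = split_speaker(utterance)
        causes = [cause for cause in entry["entities"] if cause in text]
        if causes:
            matches.append((idx, speaker, causes))
    return matches


# Shortened, hypothetical entry that follows the documented field names.
entry = {
    "history": [
        "seeker: Hi!",
        "supporter: Hello, how are you doing today?",
        "seeker: We have been friends for 13 years",
        "seeker: I am hurt and confused that she still thinks this is a game.",
        "seeker: what do you think?",
    ],
    "response": "It sounds like you care a lot about your friend and others. How old is your friend?",
    "entities": [
        "I am hurt and confused that she still thinks this is a game.",
    ],
}

for idx, speaker, causes in utterances_with_causes(entry):
    print(idx, speaker, causes)
```

Substring matching is the simplest possible alignment; if the annotated causes turn out to be spans that were edited or trimmed by workers, a fuzzier matching rule would be needed.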
Train the causal generator on each corpus:
python --train_file datasets/CGDIALOG/ESConv_causal_generator_train.csv --model_name_or_path models/blenderbot_400M_distill/ --output_dir models/ESConv_causal_generator_model_new
python --train_file datasets/CGDIALOG/msc_causal_generator_train.csv --model_name_or_path models/blenderbot_400M_distill/ --output_dir models/msc_causal_generator_model_new
Generate responses on the test sets with the trained generators and classifiers:
python --validation_file datasets/ESConv/test_dataset.json --model_name_or_path models/ESConv_causal_generator_model/ --tokenizer_name models/blenderbot_400M_distill/ --twoCondition_tc_model_name_or_path models/ESConv_classifier/ --tc_tokenizer_name models/roberta_base/ --output_dir outputs/ESConv_causal_generator_test_result
python --validation_file datasets/msc/msc_dialogue/session_4/test.json --model_name_or_path models/msc_causal_generator_model/ --tokenizer_name models/blenderbot_400M_distill/ --twoCondition_tc_model_name_or_path models/msc_classifier/ --tc_tokenizer_name models/roberta_base/ --output_dir outputs/msc_causal_generator_test_result
Select the highest-scoring response from the generated outputs:
from select_best_response_ourModel import select_highScore_response
select_highScore_response(ourModel_file="outputs/ESConv_causal_generator_test_result", save_file="outputs/ESConv_test_result_in_testset_highestScore.json")
select_highScore_response(ourModel_file="outputs/msc_causal_generator_test_result", save_file="outputs/msc_test_result_in_testset_highestScore.json")
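The internals of `select_highScore_response` are not shown in this README. As a rough sketch of the general idea its name suggests, keeping the best-scored candidate for each test example, the snippet below picks the candidate with the highest score. The record layout (`candidates`, `text`, `score`) and the helper `pick_highest_score` are hypothetical and may not match the files the script actually reads or writes.

```python
def pick_highest_score(candidates):
    """Return the candidate dict with the largest "score".

    On ties, the candidate that appears first is kept.
    """
    if not candidates:
        raise ValueError("no candidates to choose from")
    return max(candidates, key=lambda candidate: candidate["score"])


# Hypothetical records: each test example carries several scored candidates.
examples = [
    {
        "candidates": [
            {"text": "How old is your friend?", "score": 0.41},
            {"text": "It sounds like you care a lot about your friend.", "score": 0.87},
        ]
    },
    {
        "candidates": [
            {"text": "Hello, how are you doing today?", "score": 0.65},
        ]
    },
]

# Keep only the text of the best candidate for each example.
best = [pick_highest_score(example["candidates"])["text"] for example in examples]
print(best)
```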