This repository contains the data and code for the paper "SideControl: Controlled Open-domain Dialogue Generation via Additive Side Networks" (Findings of EMNLP 2021).
In a server environment with Python 3.6 and CUDA 11.1, install the required packages:
pip install -r requirements.txt
Training DialoGPT-SideNet on the full DailyDialog training set takes two steps.
First, pretrain a dialogue act (DA) classifier:
python gpt2_da_sidenet.py -d dailydialog_raw -t train -f base --pretrain_clf
Second, train a SideNet (remember to replace the --timestamp and --ckpt values with those of your own model checkpoint):
python gpt2_da_sidenet.py -d dailydialog_raw -t ft -f sidenet --timestamp 2021-05-11-07-08-10 --ckpt 2500
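The timestamp and checkpoint step in the second command have to match a run you trained yourself. The following is a minimal sketch, using only the flags shown above, of assembling that command in Python so both values are filled in from one place. The helper name build_sidenet_cmd is hypothetical and not part of this repository.

```python
import shlex


def build_sidenet_cmd(timestamp, ckpt, task="ft"):
    """Return the argument list for gpt2_da_sidenet.py.

    Hypothetical helper: it only reuses the flags shown in this README
    (-d, -t, -f, --timestamp, --ckpt) and does not run anything.
    """
    return [
        "python", "gpt2_da_sidenet.py",
        "-d", "dailydialog_raw",
        "-t", task,
        "-f", "sidenet",
        "--timestamp", timestamp,
        "--ckpt", str(ckpt),
    ]


# Example values copied from the command above; replace with your own run.
cmd = build_sidenet_cmd("2021-05-11-07-08-10", 2500)
cmd_line = " ".join(shlex.quote(arg) for arg in cmd)
print(cmd_line)
```

The list form can be passed to subprocess.run(cmd, check=True) if you prefer to launch the step from a script; passing task="eval" yields the corresponding prediction command.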
Training DialoGPT-SideNet on the full ConvAI2 training set takes one step:
python gpt2_kb_sidenet.py -d convai2_raw -t train -f sidenet
Get the DialoGPT-SideNet predictions on the full DailyDialog test set (remember to replace the --timestamp and --ckpt values with those of your own model checkpoint):
python gpt2_da_sidenet.py -d dailydialog_raw -t eval -f sidenet --timestamp 2021-05-11-07-08-10 --ckpt 2500
Get the DialoGPT-SideNet predictions on the full ConvAI2 test set (remember to replace the --timestamp and --ckpt values with those of your own model checkpoint):
python gpt2_kb_sidenet.py -d convai2_raw -t eval -f sidenet --timestamp 2021-04-26-10-21-06 --ckpt 47839
Compute text quality metrics for DialoGPT-SideNet predictions:
python evaluation.py \
--mode sent \
--reference_file {args.model_name}_{args.dataset}_{args.flag}_{args.timestamp}/refs.json \
--output_file {args.model_name}_{args.dataset}_{args.flag}_{args.timestamp}/outs.json
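The {args.model_name}_{args.dataset}_{args.flag}_{args.timestamp} part of the paths above is a template that has to be filled in with the values of your own run. A minimal sketch of expanding it in Python follows. The function name eval_files and the model name "my_model" are hypothetical, and whether the directory sits under some parent folder is not stated here, so treat the result as a relative path.

```python
def eval_files(model_name, dataset, flag, timestamp):
    """Expand the run-directory template into the two JSON paths.

    Hypothetical helper: mirrors the template
    {model_name}_{dataset}_{flag}_{timestamp}/refs.json (and outs.json).
    """
    run_dir = "{}_{}_{}_{}".format(model_name, dataset, flag, timestamp)
    return run_dir + "/refs.json", run_dir + "/outs.json"


# "my_model" is a placeholder; the other values come from the commands above.
refs_path, outs_path = eval_files(
    "my_model", "dailydialog_raw", "sidenet", "2021-05-11-07-08-10"
)
print(refs_path)
print(outs_path)
```

The two strings can then be passed as the --reference_file and --output_file arguments of evaluation.py.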
Compute text controllability metrics (knowledge document control) for DialoGPT-SideNet predictions:
python evaluation.py \
--mode kb \
--reference_file {args.model_name}_{args.dataset}_{args.flag}_{args.timestamp}/refs.json \
--output_file {args.model_name}_{args.dataset}_{args.flag}_{args.timestamp}/outs.json
Compute text controllability metrics (semantic label control) for DialoGPT-SideNet predictions. First, train an independent DA classifier:
python bert_da_eval.py -d dailydialog_dis -t train -f clf
Second, compute the accuracy with the independent DA classifier (remember to replace the --timestamp, --ckpt and --output_file values accordingly):
python bert_da_eval.py -d dailydialog_dis -t pred -f clf \
--timestamp 2021-04-11-05-57-18 --ckpt 10000 \
--output_file {args.model_name}_{args.dataset}_{args.flag}_{args.timestamp}/outs.json
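For intuition, the accuracy in this step can be read as the fraction of generated responses whose DA label, as predicted by the independent classifier, matches the label the generator was asked to produce. The sketch below illustrates that reading only; it is an assumption about the metric's definition, and bert_da_eval.py may compute it differently. The function name control_accuracy is hypothetical.

```python
def control_accuracy(predicted_labels, target_labels):
    """Fraction of positions where the predicted label equals the target.

    Assumed definition for illustration; not taken from bert_da_eval.py.
    """
    if len(predicted_labels) != len(target_labels):
        raise ValueError("label lists must have the same length")
    if not target_labels:
        return 0.0
    correct = sum(p == t for p, t in zip(predicted_labels, target_labels))
    return correct / len(target_labels)


# Toy example using DailyDialog's four dialogue act labels.
predicted = ["inform", "question", "directive", "inform"]
target = ["inform", "question", "commissive", "inform"]
acc = control_accuracy(predicted, target)
print(acc)
```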
Please cite our work if you find it useful:
@inproceedings{du-ji-2021-sidecontrol-controlled,
title = "{S}ide{C}ontrol: Controlled Open-domain Dialogue Generation via Additive Side Networks",
author = "Du, Wanyu and
Ji, Yangfeng",
booktitle = "Findings of the Association for Computational Linguistics: EMNLP 2021",
month = nov,
year = "2021",
address = "Punta Cana, Dominican Republic",
publisher = "Association for Computational Linguistics",
url = "https://aclanthology.org/2021.findings-emnlp.188",
pages = "2175--2194",
}