State-of-the-art Conversational AI

This code is based on the transfer-learning-conv-ai repo from Hugging Face. Please check the accompanying blog post How to build a State-of-the-Art Conversational AI with Transfer Learning.

The major differences are that we use PyTorch Lightning instead of Ignite and a more up-to-date version of Transformers. We also made an effort to keep everything well documented and easy to understand.

Presentation Slides

Model Architecture

[Figure: model architecture]

Our model is built on top of a pretrained GPT2 model and is trained in a multi-task setting where we minimize the following losses (a code sketch of the combined loss follows the list):

  • Language modeling: we project the hidden-state on the word embedding matrix to get logits and apply a cross-entropy loss on the portion of the target corresponding to the gold reply (green labels on the above figure).
  • Next-sentence prediction: we pass the hidden-state of the last token (the end-of-sequence token) through a linear layer to get a score and apply a cross-entropy loss to correctly classify the gold answer among distractors.
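
To make the two objectives concrete, here is a minimal sketch using Hugging Face's GPT2DoubleHeadsModel, which bundles the language-modeling head and the candidate-scoring head. This is an illustration only, not the repository's training code: the dialogue context, candidate replies, and shapes below are made up, and argument names can differ slightly between Transformers versions.

import torch
from transformers import GPT2DoubleHeadsModel, GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2DoubleHeadsModel.from_pretrained("gpt2")

# One dialogue context with two candidate replies; index 1 is the gold reply.
context = tokenizer.encode("hello , what do you like to do ?")
candidates = [tokenizer.encode(" i hate cooking ."),
              tokenizer.encode(" i love cooking vegan food .")]

seq_len = max(len(context) + len(c) for c in candidates)
input_ids = torch.full((1, 2, seq_len), tokenizer.eos_token_id, dtype=torch.long)
lm_labels = torch.full((1, 2, seq_len), -100, dtype=torch.long)   # -100 = ignored by the LM loss
mc_token_ids = torch.zeros((1, 2), dtype=torch.long)

for i, cand in enumerate(candidates):
    ids = context + cand
    input_ids[0, i, :len(ids)] = torch.tensor(ids)
    mc_token_ids[0, i] = len(ids) - 1                 # hidden state of the last token gets scored

# Only the gold reply tokens receive language-modeling labels (the "green labels").
gold = candidates[1]
lm_labels[0, 1, len(context):len(context) + len(gold)] = torch.tensor(gold)
mc_labels = torch.tensor([1])                         # index of the gold candidate

outputs = model(input_ids, mc_token_ids=mc_token_ids, labels=lm_labels, mc_labels=mc_labels)
loss = outputs.loss + outputs.mc_loss                 # language-modeling loss + next-sentence loss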

Install:

virtualenv -p python3.6 convai-env
source convai-env/bin/activate

git clone https://github.com/HLT-MAIA/lightning-convai
cd lightning-convai
pip install -r requirements.txt

Command Line Interface:

Train:

To set up your training you have to define your model configs. Take a look at the example.yaml in the configs folder, where all hyperparameters are briefly described.

After defining your hyperparameters, run the following command:

python cli.py train -f configs/example.yaml

Monitor training with Tensorboard:

Launch tensorboard with:

tensorboard --logdir="experiments/"

Test:

To test your model's ability to rank candidate answers and reply to user questions, run the following command:

python cli.py test --experiment experiments/{experiment_id}/ --test_set data/personachat_val.json

where experiment_id is the name of the experiment folder containing the model you want to test.

Options:
  --experiment PATH    Path to the experiment folder containing the checkpoint
                       we want to interact with.  [required]

  --test_set PATH      Path to the json file containing the testset.
                       [required]

  --cuda / --cpu       Flag that either runs inference on cuda or in cpu.
                       [default: True]

  --seed INTEGER       Seed value used during inference. This influences
                       results only when using sampling.

  --sample / --search  Flag that either runs Nucleus-Sampling or Beam search.
                       [default: True]

  --top_p FLOAT        Nucleus filtering (top-p) before sampling (<=0.0: no
                       filtering)

  --temperature FLOAT  Use temperature to decrease the sensitivity to low
                       probability candidates when sampling.

  --num_beams INTEGER  Number of beams during search.
  --to_json TEXT       Creates and exports model predictions to a JSON file.
                       [default: False]

  --help               Show this message and exit.
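
The --sample, --top_p and --temperature options control nucleus sampling at decoding time. As a generic reference for what these knobs do (this sketch is not the repository's decoding code), top-p filtering with temperature can be implemented as:

import torch
import torch.nn.functional as F

def nucleus_sample(logits, top_p=0.9, temperature=1.0):
    """Sample one token id from a (vocab_size,) logits vector using top-p filtering."""
    logits = logits / temperature
    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
    cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
    # Drop tokens once the cumulative probability exceeds top_p,
    # always keeping at least the most likely token.
    to_remove = cumulative_probs > top_p
    to_remove[1:] = to_remove[:-1].clone()
    to_remove[0] = False
    logits[sorted_indices[to_remove]] = float("-inf")
    probs = F.softmax(logits, dim=-1)
    return torch.multinomial(probs, num_samples=1).item()

# e.g. next_token = nucleus_sample(lm_logits[0, -1], top_p=0.9, temperature=0.7)

With --search, beam search is used instead, with the beam width set by --num_beams.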

Interact:

Fun command where we can interact with a trained model that impersonates a vegan who likes cooking and radical activities such as sky-diving.

python cli.py interact --experiment experiments/{experiment_id}/

Benchmarks:

Training with the example.yaml config should result in the following:

Metric        GPT2     DialoGPT-small
Hits@1↑       0.8023   0.8231
Hits@5↑       0.9721   0.9771
Hits@10↑      0.9948   0.9960
BLEU↑         2.7799   2.9633
TER↓          1.0497   1.0528
BERTScore↑    0.8548   0.8548
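
For reference, Hits@k measures how often the gold reply is ranked among the k highest-scoring candidates by the next-sentence prediction head. A generic sketch of that computation (not the repository's evaluation code; the scores below are made up):

import torch

def hits_at_k(scores, gold_index, k):
    # scores: (num_examples, num_candidates) candidate scores from the ranking head
    # gold_index: (num_examples,) position of the gold reply among the candidates
    topk = scores.topk(k, dim=-1).indices                    # (num_examples, k)
    hits = (topk == gold_index.unsqueeze(-1)).any(dim=-1)    # is the gold reply in the top k?
    return hits.float().mean().item()

# Tiny made-up example: 2 dialogues with 4 candidate replies each, gold at index 0.
scores = torch.tensor([[3.2, 0.1, -1.0, 0.5],
                       [0.8, 1.7,  0.3, 0.5]])
gold = torch.tensor([0, 0])
print(hits_at_k(scores, gold, k=1))   # 0.5
print(hits_at_k(scores, gold, k=3))   # 1.0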

Download DialoGPT2-small trained with PersonaChat:

cd experiments
wget https://unbabel-experimental-models.s3.amazonaws.com/maia/persona/dialogpt2-small.zip
unzip dialogpt2-small.zip

Test the model:

python cli.py test --experiment experiments/dialogpt2-small/ --test_set data/personachat_val.json --to_json

Code Style:

All the code follows the same style: we use Black.
