This repository contains example code for training the baselines from the paper *lilGym: Natural Language Visual Reasoning with Reinforcement Learning*. Trained models are available on Zenodo: link.
paper | TL;DR tweet | env code & data | website
Note: this code has been tested with PyTorch 1.12.1 and CUDA 11.2.
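To check your local setup against these versions, a quick sanity check:

```python
# Minimal environment check: print the installed PyTorch build, the CUDA
# version it was compiled against, and whether a GPU is visible.
import torch

print(torch.__version__)         # expect 1.12.1
print(torch.version.cuda)        # expect an 11.x build
print(torch.cuda.is_available())
```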
- Install lilgym and its dependencies by following the installation instructions. This step also installs PyTorch.
- Clone the current repo.
- Install the Python dependencies (a quick smoke test follows the commands below):
```
cd /path/to/lilgym-baselines
pip install -r requirements.txt
```
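To verify the installation, a hypothetical smoke test (the exact import path and `reset()` signature may differ across gym/lilgym versions; the environment IDs are taken from the training commands below):

```python
# Assumed: importing lilgym registers its environments with Gym.
import gym
import lilgym  # noqa: F401 -- import path is an assumption

env = gym.make("TowerScratch-v0")
obs = env.reset()  # older gym returns obs; newer gym returns (obs, info)
print(env.action_space)
print(env.observation_space)
```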
Training a C3+BERT model with PPO+SF on the TowerScratch environment:
```
python main.py --env-name TowerScratch-v0 --env-opt tower --learn-opt scratch \
    --algo ppo --stop-forcing --seed 1 --model c3bert --text-feat bertfix \
    --num-processes 1 --num-steps 2048 --lr 3e-4 --entropy-coef 0.1 \
    --value-loss-coef 0.5 --ppo-epoch 10 --num-mini-batch 32 --gamma 0.99 \
    --gae-lambda 0.95 --num-env-steps 4000000 --use-gae --optim-type adam \
    --scheduler linear --warmup-percent 0 --log-interval 1 --eval-interval 10 \
    --log-dir ${path} --save-dir ${path} --save-interval 20 \
    --wandb --wandb-run-name name-of-the-run
```
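For reference, `--use-gae --gamma 0.99 --gae-lambda 0.95` configure generalized advantage estimation (GAE). A minimal sketch of that computation (function and variable names here are illustrative, not the repo's actual API):

```python
import torch

def compute_gae(rewards, values, masks, gamma=0.99, lam=0.95):
    """rewards, masks: [T]; values: [T+1] (bootstrap value appended).

    masks[t] is 0 where the episode ended at step t, else 1.
    """
    advantages = torch.zeros_like(rewards)
    gae = 0.0
    for t in reversed(range(rewards.shape[0])):
        # One-step TD error, cut off at episode boundaries by the mask.
        delta = rewards[t] + gamma * values[t + 1] * masks[t] - values[t]
        # Exponentially weighted sum of TD errors.
        gae = delta + gamma * lam * masks[t] * gae
        advantages[t] = gae
    returns = advantages + values[:-1]
    return advantages, returns
```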
Training a ViLT model with PPO+SF on the TowerFlipIt environment:
```
python main.py --env-name TowerFlipIt-v0 --env-opt tower --learn-opt flipit \
    --algo ppo --stop-forcing --seed 1 --model vilt \
    --num-processes 1 --num-steps 2048 --lr 3e-5 --entropy-coef 0.1 \
    --value-loss-coef 0.5 --ppo-epoch 10 --num-mini-batch 32 --gamma 0.99 \
    --gae-lambda 0.95 --num-env-steps 4000000 --use-gae --optim-type adamw \
    --scheduler cosine --warmup-percent 0.01 --log-interval 1 --eval-interval 10 \
    --log-dir ${path} --save-dir ${path} --save-interval 20 \
    --wandb --wandb-run-name name-of-the-run
```
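Here `--optim-type adamw --scheduler cosine --warmup-percent 0.01` select AdamW with a cosine learning-rate decay after a short linear warmup. A minimal sketch, assuming the warmup is a fraction of the total number of updates (the helper name is hypothetical):

```python
import math
import torch

def make_optimizer_and_scheduler(model, lr=3e-5, total_updates=10_000,
                                 warmup_percent=0.01):
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    warmup = max(1, int(total_updates * warmup_percent))

    def lr_lambda(step):
        # Linear warmup from 0 to the base lr, then cosine decay to 0.
        if step < warmup:
            return step / warmup
        progress = (step - warmup) / max(1, total_updates - warmup)
        return 0.5 * (1.0 + math.cos(math.pi * progress))

    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
    return optimizer, scheduler
```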
The RL code is based on Kostrikov (2018). We thank the authors for open-sourcing their code.