This repo contains our code for the paper "No Parameters Left Behind: Sensitivity Guided Adaptive Learning Rate for Training Large Transformer Models" (ICLR 2022).
- Pull and run the Docker image
pytorch/pytorch:1.5.1-cuda10.1-cudnn7-devel
- Install requirements
pip install -r requirements.txt
- Download data and pre-trained models
./download.sh
For details on the GLUE benchmark, please refer to the official GLUE website.
- Preprocess data
./experiments/glue/prepro.sh
For the most updated data processing details, please refer to the mt-dnn repo.
We provide an example script for fine-tuning a pre-trained BERT-base model on MNLI using Adamax-SAGE:
./scripts/train_mnli_usadamax.sh GPUID
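Once inside the container, the steps above can be chained from Python. This is a minimal sketch that assumes the working directory is the repository root and that the scripts are executable; `pipeline_commands`, `run_pipeline` and its `dry_run` flag are hypothetical helpers, not part of the repo.

```python
import subprocess
from typing import List


def pipeline_commands(gpu_id: int) -> List[List[str]]:
    """Return the setup and training commands from this README, in order."""
    return [
        ["pip", "install", "-r", "requirements.txt"],
        ["./download.sh"],
        ["./experiments/glue/prepro.sh"],
        # The training script receives the GPU index in place of GPUID.
        ["./scripts/train_mnli_usadamax.sh", str(gpu_id)],
    ]


def run_pipeline(gpu_id: int, dry_run: bool = False) -> List[List[str]]:
    """Run each command in turn, stopping at the first failure.

    With dry_run=True the commands are only returned, not executed.
    """
    cmds = pipeline_commands(gpu_id)
    if not dry_run:
        for cmd in cmds:
            # check=True raises CalledProcessError on a non-zero exit code,
            # so a failed download or preprocessing step halts the pipeline.
            subprocess.run(cmd, check=True)
    return cmds


if __name__ == "__main__":
    # Print the commands for GPU 0 without running them.
    for cmd in run_pipeline(gpu_id=0, dry_run=True):
        print(" ".join(cmd))
```

Running the steps with `check=True` mirrors what a careful manual run would do: there is no point preprocessing if the download failed.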
A few notes:
- `learning_rate` and `beta3` are two of the most important hyper-parameters. The `learning_rate` that works well for Adamax/AdamW-SAGE is usually 2 to 5 times larger than the one that works well for Adamax/AdamW, depending on the task. The `beta3` that works well for Adamax/AdamW-SAGE usually lies between 0.6 and 0.9, depending on the task.
- To use AdamW-SAGE, set the argument `--optim=usadamw`. The current codebase contains only the implementations of Adamax-SAGE and AdamW-SAGE; please refer to `module/bert_optim.py` for details. Please refer to our paper for integrating SAGE into other optimizers.
- To fine-tune a pre-trained RoBERTa-base model, set `--init_checkpoint` to the model path and set `--encoder_type` to 2. Other supported models are listed in `pretrained_models.py`.
- To fine-tune on other tasks, set `--train_datasets` and `--test_datasets` to the corresponding task names.
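The notes above can be turned into a small helper that proposes a SAGE learning-rate interval and assembles the corresponding arguments. This is a sketch under stated assumptions: `sage_lr_range` and `build_args` are hypothetical names; the flag names `--learning_rate` and `--beta3` and the optimizer value `usadamax` are assumed (only `--optim=usadamw`, `--init_checkpoint`, `--encoder_type`, `--train_datasets` and `--test_datasets` appear in this README), so check them against the training scripts before use.

```python
import warnings
from typing import List, Optional, Tuple

# Rules of thumb from the notes above; both are task-dependent.
SAGE_LR_FACTOR = (2.0, 5.0)  # SAGE lr is usually 2-5x the plain Adamax/AdamW lr
SAGE_BETA3 = (0.6, 0.9)      # beta3 usually lies between 0.6 and 0.9


def sage_lr_range(baseline_lr: float) -> Tuple[float, float]:
    """Candidate lr interval for *-SAGE, given an lr that works for plain Adamax/AdamW."""
    return baseline_lr * SAGE_LR_FACTOR[0], baseline_lr * SAGE_LR_FACTOR[1]


def build_args(optim: str, learning_rate: float, beta3: float,
               train_datasets: str, test_datasets: str,
               init_checkpoint: Optional[str] = None,
               encoder_type: Optional[int] = None) -> List[str]:
    """Assemble command-line overrides in --name=value form."""
    if optim not in ("usadamax", "usadamw"):  # "usadamax" is an assumed value
        raise ValueError(f"unsupported optimizer: {optim}")
    if not SAGE_BETA3[0] <= beta3 <= SAGE_BETA3[1]:
        # Outside the usual range is allowed, but worth flagging.
        warnings.warn(f"beta3={beta3} is outside the usual range {SAGE_BETA3}")
    args = [
        f"--optim={optim}",
        f"--learning_rate={learning_rate}",  # assumed flag name
        f"--beta3={beta3}",                  # assumed flag name
        f"--train_datasets={train_datasets}",
        f"--test_datasets={test_datasets}",
    ]
    if init_checkpoint is not None:
        args.append(f"--init_checkpoint={init_checkpoint}")
    if encoder_type is not None:
        # encoder_type=2 selects RoBERTa, per the notes above.
        args.append(f"--encoder_type={encoder_type}")
    return args
```

For example, if a learning rate of 2e-5 works for plain AdamW, `sage_lr_range(2e-5)` suggests trying values between 4e-5 and 1e-4. A RoBERTa-base run with AdamW-SAGE could then use `build_args("usadamw", 1e-4, 0.8, "mnli", "mnli_matched,mnli_mismatched", init_checkpoint="path/to/roberta_base.pt", encoder_type=2)`, where the checkpoint path and task names are placeholders.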
@inproceedings{
liang2022no,
title={No Parameters Left Behind: Sensitivity Guided Adaptive Learning Rate for Training Large Transformer Models},
author={Chen Liang and Haoming Jiang and Simiao Zuo and Pengcheng He and Xiaodong Liu and Jianfeng Gao and Weizhu Chen and Tuo Zhao},
booktitle={International Conference on Learning Representations},
year={2022},
url={https://openreview.net/forum?id=cuvga_CiVND}
}
For help or issues related to this package, please submit a GitHub issue. For personal questions related to this paper, please contact Chen Liang (cliang73@gatech.edu).