DPO-ST: DPO-augmented Self-Training

This repository contains the official code and data for our ACL 2024 paper Self-Training with Direct Preference Optimization Improves Chain-of-Thought Reasoning.

Introduction

Teaching small language models (e.g., T5-large) chain-of-thought reasoning by distilling from larger models like GPT-4 has been shown to be effective. However, relying on such proprietary large models can be both economically and computationally costly. Our paper demonstrates that small language models can learn from their own generations in a self-training manner, starting from a limited amount of high-quality, human-annotated training data. Additionally, we present an efficient method for integrating external calculators during inference to boost performance.


Our approach demonstrates superior performance while minimizing the required compute cost.

About DPO-ST

DPO-augmented Self-Training builds upon the conventional self-training framework. Unlike traditional self-training, where pseudo-labels are generated by the SFT model, we add a DPO step to each self-training iteration and generate pseudo-labels from the resulting DPO model. We empirically found that the DPO model generates more diverse and higher-quality pseudo-labels.
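To make the iteration order concrete, here is a minimal Python sketch of one way to express the DPO-ST loop. Every callable (`sft`, `make_preference_data`, `dpo`, `pseudo_label`) is a hypothetical placeholder standing in for the scripts in the steps below, and restarting Step 2.4 from the pre-trained model is an assumption borrowed from classic self-training rather than something this README states.

```python
from typing import Callable, List, Tuple

Example = Tuple[str, str]  # (question, rationale)


def dpo_st(
    base_model: object,
    labeled: List[Example],
    num_iterations: int,
    sft: Callable[[object, List[Example]], object],
    make_preference_data: Callable[[object, List[Example]], list],
    dpo: Callable[[object, list], object],
    pseudo_label: Callable[[object, List[Example]], List[Example]],
) -> object:
    """Sketch of the DPO-ST loop; every callable is a hypothetical placeholder."""
    # Step 1 (warm-up): fine-tune the pre-trained model on labeled data.
    model = sft(base_model, labeled)
    for _ in range(num_iterations):
        # Steps 2.1-2.2: build preference pairs from the current SFT model's
        # samples and train that model with the DPO objective.
        pairs = make_preference_data(model, labeled)
        dpo_model = dpo(model, pairs)
        # Step 2.3: pseudo-labels come from the DPO model, not the SFT model.
        pseudo = pseudo_label(dpo_model, labeled)
        # Step 2.4: SFT on labeled + pseudo-labeled data. Assumption: restart
        # from the pre-trained model, as in classic self-training.
        model = sft(base_model, labeled + pseudo)
    return model
```

The key difference from conventional self-training is the single line where `pseudo_label` receives `dpo_model` instead of `model`.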


Model inference with external calculators

Integrating external calculators during model inference can enhance math reasoning performance. However, many previous implementations support only a batch size of 1, which significantly slows down inference. In this work, we present an efficient method for integrating external calculators that supports larger inference batch sizes. Specifically, we design a LogitsProcessor that modifies the model's output logits during decoding. More details about our implementation can be found in generate.py.


Inference speed-up comparison with Flan-T5-Large on a single A40 GPU.
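The sketch below illustrates the general idea of batched calculator forcing; it is not the code in generate.py. It assumes GSM8K-style calculator annotations such as `<<48/2=24>>`, a hypothetical `encode` function that maps text to token ids, and plain Python lists instead of tensors so that it runs without PyTorch. Each sequence in the batch keeps its own queue of forced tokens, so one sequence calling the calculator never blocks the others, which is what makes batch sizes larger than 1 possible. A real Hugging Face `LogitsProcessor` would instead receive `input_ids` and a `scores` tensor in `__call__` and decode the prefixes itself.

```python
import ast
import math
import operator
from typing import Callable, List

# Operators the calculator accepts; anything else is rejected.
_OPS = {
    ast.Add: operator.add,
    ast.Sub: operator.sub,
    ast.Mult: operator.mul,
    ast.Div: operator.truediv,
    ast.USub: operator.neg,
}


def safe_eval(expr: str) -> float:
    """Evaluate +, -, *, / arithmetic without calling Python's eval()."""

    def walk(node: ast.AST) -> float:
        if isinstance(node, ast.Expression):
            return walk(node.body)
        if isinstance(node, ast.Constant) and isinstance(node.value, (int, float)):
            return node.value
        if isinstance(node, ast.BinOp) and type(node.op) in _OPS:
            return _OPS[type(node.op)](walk(node.left), walk(node.right))
        if isinstance(node, ast.UnaryOp) and type(node.op) in _OPS:
            return _OPS[type(node.op)](walk(node.operand))
        raise ValueError(f"unsupported expression: {expr!r}")

    return walk(ast.parse(expr, mode="eval"))


def format_number(x: float) -> str:
    # Integral results print without ".0" (24.0 -> "24"); others are rounded
    # to 4 decimal places.
    return str(int(x)) if float(x).is_integer() else str(round(x, 4))


class CalculatorForcer:
    """Hypothetical batched calculator: one forced-token queue per sequence."""

    def __init__(self, batch_size: int, encode: Callable[[str], List[int]]):
        self.encode = encode
        self.queues: List[List[int]] = [[] for _ in range(batch_size)]

    def __call__(self, texts: List[str], logits: List[List[float]]) -> List[List[float]]:
        for i, text in enumerate(texts):
            # An open "<<expr=" span at the end of the prefix triggers the calculator.
            if not self.queues[i] and text.endswith("="):
                start = text.rfind("<<")
                if start != -1 and ">>" not in text[start:]:
                    try:
                        result = format_number(safe_eval(text[start + 2 : -1]))
                        self.queues[i] = self.encode(result + ">>")
                    except (ValueError, SyntaxError, ZeroDivisionError):
                        pass  # leave this sequence unconstrained
            if self.queues[i]:
                # Force the next queued token by masking every other logit.
                forced = self.queues[i].pop(0)
                logits[i] = [0.0 if j == forced else -math.inf for j in range(len(logits[i]))]
        return logits
```

Setting every other logit to `-inf` gives the forced token probability 1, so the same mechanism works for both greedy decoding and sampling.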

Setup

Please complete the following steps before running our code.

  1. Use Conda to create a Python virtual environment:
conda create -n dpo-st python=3.10
conda activate dpo-st
  2. Install the Python dependencies with pip:
pip install -r requirements.txt
  3. Log in to Hugging Face to download pre-trained model weights:
huggingface-cli login --token "${your_hf_token}"
  4. Set the environment variable DATA_DIR and download pre-trained model weights from Hugging Face into $DATA_DIR/hf_models. For example:
export DATA_DIR='.'
huggingface-cli download meta-llama/Llama-2-7b-hf --local-dir "$DATA_DIR/hf_models/llama-2"

We recommend using python-dotenv to define DATA_DIR in your .env file, as this environment variable is used in the subsequent steps.
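For example, a `.env` file in the repository root containing the single line `DATA_DIR=.` can be loaded from Python as sketched below. The helper `hf_model_dir` is hypothetical and only shows how the `$DATA_DIR/hf_models/<name>` layout from the setup steps resolves; `load_dotenv` and `find_dotenv` are python-dotenv functions, and the sketch falls back to the shell environment when the package is not installed.

```python
import os

try:
    # If python-dotenv is installed, load ./.env from the current directory.
    # By default, variables already set in the environment are not overridden.
    from dotenv import find_dotenv, load_dotenv

    load_dotenv(find_dotenv(usecwd=True))
except ImportError:
    pass  # fall back to variables exported in the shell


def hf_model_dir(name: str) -> str:
    """Hypothetical helper: resolve $DATA_DIR/hf_models/<name>."""
    data_dir = os.environ.get("DATA_DIR")
    if data_dir is None:
        raise RuntimeError("DATA_DIR is not set; add it to .env or export it")
    return os.path.join(data_dir, "hf_models", name)
```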

Step 1: Warm-up

The first step of DPO-ST is to warm up the pre-trained language model by fine-tuning it on the labeled dataset.

For Flan-T5-Large, run the following command:

ACC_CONFIG='acc_config/ddp8.yaml'
accelerate launch --config_file $ACC_CONFIG sft.py --config-name=sft-0

For Llama-2-7b, run the following command:

ACC_CONFIG='acc_config/fsdp.yaml'
accelerate launch --config_file $ACC_CONFIG sft.py --config-path=exp_config/llama --config-name=sft-0

Step 2.1: Prepare DPO training data

First, sample pseudo-labels from the SFT model:

ARGS='+data.split="train" eval.mode="sampling" eval.sampling.max_seed=5'
torchrun --nproc_per_node 8 generate.py --config-name=sft-0 $ARGS
python3 eval_sampling.py --config-name=sft-0 $ARGS

Then, build DPO training data from the SFT model's generations:

python3 utils/make_dpo_data.py --config-name=sft-0

Note that the commands above are for T5 models. For Llama, add --config-path=exp_config/llama to each command.
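One natural way to build preference pairs, assumed here, is to contrast a correct rationale with an incorrect one for the same question. The following is a minimal sketch of that pairing, assuming each sampled generation has already been marked correct or incorrect (for example by `eval_sampling.py`); the function name, the `max_pairs_per_question` cap, and the input format are hypothetical and not necessarily how `utils/make_dpo_data.py` works. The output fields follow the `prompt`/`chosen`/`rejected` convention used by TRL's `DPOTrainer`.

```python
import itertools
import random
from typing import Dict, List, Tuple


def make_preference_pairs(
    samples: Dict[str, List[Tuple[str, bool]]],
    max_pairs_per_question: int = 4,
    seed: int = 0,
) -> List[Dict[str, str]]:
    """Pair a correct rationale (chosen) with an incorrect one (rejected).

    `samples` maps each question to its sampled (rationale, is_correct) pairs.
    Questions whose samples are all correct or all incorrect yield no pairs.
    """
    rng = random.Random(seed)
    pairs = []
    for question, gens in samples.items():
        # Deduplicate and sort so the output is deterministic for a given seed.
        correct = sorted({r for r, ok in gens if ok})
        wrong = sorted({r for r, ok in gens if not ok})
        combos = list(itertools.product(correct, wrong))
        rng.shuffle(combos)
        for chosen, rejected in combos[:max_pairs_per_question]:
            pairs.append({"prompt": question, "chosen": chosen, "rejected": rejected})
    return pairs
```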

Step 2.2: Train the SFT model with the DPO objective

For T5:

ACC_CONFIG='acc_config/ddp8.yaml'
accelerate launch --config_file $ACC_CONFIG dpo.py --config-name=dpo-1

For Llama:

ACC_CONFIG='acc_config/fsdp.yaml'
accelerate launch --config_file $ACC_CONFIG dpo.py --config-path=exp_config/llama --config-name=dpo-1
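For reference, the DPO objective for a single preference pair can be computed from four sequence log-probabilities: the policy's and the reference model's log-likelihoods of the chosen and rejected rationales. This is a plain-Python sketch of the standard DPO loss (Rafailov et al., 2023), not the TRL implementation used for training; `beta = 0.1` is an illustrative default, and the value set in the dpo-1 config may differ.

```python
import math


def dpo_loss(
    policy_chosen_logp: float,
    policy_rejected_logp: float,
    ref_chosen_logp: float,
    ref_rejected_logp: float,
    beta: float = 0.1,
) -> float:
    """-log sigmoid(beta * [(log pi(y_w) - log pi_ref(y_w)) - (log pi(y_l) - log pi_ref(y_l))])."""
    margin = beta * (
        (policy_chosen_logp - ref_chosen_logp)
        - (policy_rejected_logp - ref_rejected_logp)
    )
    # -log(sigmoid(m)) = log(1 + exp(-m)); branch on the sign of m so that
    # exp() never overflows for large |m|.
    if margin >= 0:
        return math.log1p(math.exp(-margin))
    return -margin + math.log1p(math.exp(margin))
```

When the policy equals the reference model the margin is 0 and the loss is log 2; raising the chosen rationale's log-probability relative to the reference (or lowering the rejected one's) pushes the loss below that value.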

Step 2.3: Sample pseudo-labels from the DPO model

ARGS='+data.split="train" eval.mode="sampling" eval.sampling.max_seed=3'
torchrun --nproc_per_node 8 greedy_decode.py --config-name=dpo-1 $ARGS
python3 eval_sampling.py --config-name=dpo-1 $ARGS
python3 utils/make_rft_data.py --config-name=dpo-1

You can control the number of sampled generations per question by adjusting eval.sampling.max_seed.
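The script name `make_rft_data.py` suggests filtering in the style of rejection-sampling fine-tuning (RFT). A minimal sketch of such a filter, under that assumption: keep only correct rationales and drop those whose sequence of calculator equations duplicates one already kept for the same question. The regex for GSM8K-style `<<expr=result>>` annotations and the function name are hypothetical.

```python
import re
from typing import Dict, List, Tuple

# Hypothetical: captures "expr" from GSM8K-style annotations like "<<48/2=24>>".
_CALC = re.compile(r"<<([^=>]*)=")


def rft_filter(samples: Dict[str, List[Tuple[str, bool]]]) -> List[Tuple[str, str]]:
    """Keep correct rationales, dropping ones with a duplicate equation sequence."""
    kept = []
    for question, gens in samples.items():
        seen = set()
        for rationale, is_correct in gens:
            if not is_correct:
                continue
            # Rationales using the same calculations in the same order count as
            # duplicates even if worded differently; with no equations, fall
            # back to the exact text.
            key = tuple(e.replace(" ", "") for e in _CALC.findall(rationale)) or rationale
            if key in seen:
                continue
            seen.add(key)
            kept.append((question, rationale))
    return kept
```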

Step 2.4: SFT with labeled and pseudo-labeled data

For T5:

ACC_CONFIG='acc_config/ddp8.yaml'
accelerate launch --config_file $ACC_CONFIG sft.py --config-name=sft-1

For Llama:

ACC_CONFIG='acc_config/fsdp.yaml'
accelerate launch --config_file $ACC_CONFIG sft.py --config-path=exp_config/llama --config-name=sft-1

Evaluation

CONFIG_PATH='exp_config/t5'
SPLIT='test'
torchrun --nproc_per_node 8 generate.py --config-path=$CONFIG_PATH --config-name=dpo-1 +data.split=$SPLIT
python3 eval_greedy.py --config-path=$CONFIG_PATH --config-name=dpo-1 +data.split=$SPLIT
  • CONFIG_PATH: set it to exp_config/t5 for T5 models or exp_config/llama for Llama models
  • SPLIT: set it to dev for dev-set results or test for test-set results
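For intuition about what the greedy evaluation measures, here is a sketch of exact-match accuracy on numeric answers. The extraction rule (take the last number in the output) and the tolerance are assumptions for illustration; `eval_greedy.py` may parse answers differently.

```python
import re
from typing import List, Optional

# Integers with optional thousands separators and an optional decimal part.
_NUMBER = re.compile(r"-?\d[\d,]*(?:\.\d+)?")


def extract_answer(text: str) -> Optional[float]:
    """Hypothetical rule: the last number in the generated text is the answer."""
    matches = _NUMBER.findall(text)
    if not matches:
        return None
    return float(matches[-1].replace(",", ""))


def accuracy(predictions: List[str], golds: List[float], tol: float = 1e-4) -> float:
    """Fraction of predictions whose extracted answer matches the gold answer."""
    if len(predictions) != len(golds):
        raise ValueError("predictions and golds must have the same length")
    correct = 0
    for pred, gold in zip(predictions, golds):
        ans = extract_answer(pred)
        if ans is not None and abs(ans - gold) < tol:
            correct += 1
    return correct / len(golds)
```

A last-number heuristic can misfire on outputs such as `48-24`, where the hyphen is read as a minus sign, which is one reason evaluation scripts often look for an explicit answer marker instead.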

Citation

If you find this paper useful, please consider citing it:

@inproceedings{wang2024dpost,
      title={Self-Training with Direct Preference Optimization Improves Chain-of-Thought Reasoning}, 
      author={Tianduo Wang and Shichen Li and Wei Lu},
      year={2024},
      booktitle = {Proceedings of ACL},
}

Acknowledgement

This repo is largely inspired by GSM8K-ScRel and TRL. We are grateful to the authors for their brilliant work.