This is the official repo for the Stable Alignment project. We aim to provide an alternative to RLHF that is superior in alignment performance, highly efficient in learning from data, and easy to deploy in scaled-up settings. Instead of training an extra reward model that can be gamed during optimization, we train directly on the interaction data recorded in simulated social games. We find that high-quality data plus a reliable algorithm is the secret recipe for stable alignment learning.
The repo contains:
- The code for running social simulations in Sandbox.
- The 169K interaction records used for alignment training.
- The code for training with stable alignment.
- The download links for the So(cially)-Good Language Model.
Life is a game. Play by your rules!
```bash
# install dependencies
pip install -r requirements.txt
# install the stable_alignment package in editable (development) mode
pip install -e .
```
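To confirm that the editable install worked, you can try importing the package (a quick sanity check; it assumes the package name matches the `stable_alignment` directory in this repo):

```python
# Quick sanity check that the editable install is importable.
# Assumes the package name matches the repo's stable_alignment directory.
import stable_alignment

print(stable_alignment.__file__)  # should point into your local checkout
```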
- Initial data is already stored at `assets/hh-rlhf/labeled_prior.jsonl` (tracked with Git LFS).
- After a round of simulation, the simulated interaction data and metrics will be saved at `data/cache/world_<world_id>/`.
- Place your OpenAI API key in a `.env` file inside the project root folder.
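Before launching a simulation, you can verify that the key is picked up. The sketch below is only a sanity check and assumes the variable is named `OPENAI_API_KEY` and that `python-dotenv` is installed; the variable name that `simulation.py` actually reads may differ:

```python
# Sanity-check the .env setup before launching a simulation.
# Assumes the variable is named OPENAI_API_KEY and python-dotenv is installed;
# the variable name that simulation.py actually reads may differ.
import os

from dotenv import load_dotenv

load_dotenv()  # locate and load the .env file
print("API key loaded" if os.getenv("OPENAI_API_KEY") else "No OPENAI_API_KEY found")
```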
Navigate to the project root folder and run the simulation with your customized settings:

```bash
python stable_alignment/simulation.py \
    -model_type 'text-davinci-002' \
    -obs_model_type 'gpt-3.5-turbo' \
    -world_id 1 \
    -init_setting 'all_bad' \
    -n_round '2' \
    -size '4' \
    -dataset_name 'hh-rlhf'
```
We provide an example simulation result in `assets/sample_world`. It was simulated with 100 text-davinci-003-based social agents and ChatGPT-based observer agents, run for 50 rounds of interaction.
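If you want to inspect what a run produced, the generic preview below may help. The file names and formats inside `data/cache/world_<world_id>/` are determined by the simulation code, so this sketch only lists the directory and previews any JSON files it finds:

```python
# Generic preview of a simulation cache directory.
# The file names and formats are determined by the simulation code,
# so this helper makes no assumptions about the schema.
import json
from pathlib import Path

world_dir = Path("data/cache/world_1")  # world_id 1, as in the command above
for path in sorted(world_dir.iterdir()):
    print(path.name)
    if path.suffix == ".json":
        with open(path) as f:
            data = json.load(f)
        print(json.dumps(data, indent=2)[:500])  # short preview of large files
```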
The alignment data used for training is already included at `assets/sandbox_v1.json` and `assets/sandbox_v2.json`. Note that these files are sampled from the full set of interaction data at a ratio of 5:1:1 for Alignment Imitation, Self-Critic, and Realignment data, respectively; a small loading sketch follows the tables below. The full set of interaction data is available upon request.
Statistics of the Alignment Data (Full Set)

`sandbox_v1.json`

| Data / Social Agent Type | text-davinci-002 | text-davinci-003 | ChatGPT | Total |
|---|---|---|---|---|
| Alignment Imitation | 9.8k | 10k | 10k | 29.8k |
| Self-Critic | 17k | 20k | 20k | 57k |
| Realignment | 3.3k | 3k | 0.7k | 7k |
| Total | 30.1k | 33k | 30.7k | 93.8k |

`sandbox_v2.json`

| Data / Social Agent Type | text-davinci-002 | text-davinci-003 | GPT-4 | Total |
|---|---|---|---|---|
| Alignment Imitation | 18.2k | 10.4k | 20.2k | 48.8k |
| Self-Critic | 36.3k | 18.3k | 40k | 94.6k |
| Realignment | 18.2k | 3.4k | 4.0k | 25.6k |
| Total | 72.7k | 32.1k | 64.2k | 169k |
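As referenced above, here is a minimal way to peek at the released alignment data before training. It assumes each file holds a single JSON array of records; the field names inside each record are not documented here, so the sketch only reports the count and the keys of the first entry:

```python
# Peek at the released alignment data without assuming its schema:
# report how many records there are and which fields the first one contains.
import json

with open("assets/sandbox_v1.json") as f:
    records = json.load(f)  # assumes the file is one JSON array of records

print(f"{len(records)} records")
print("fields in the first record:", sorted(records[0].keys()))
```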
To run alignment training on the sandbox data (here with 4 GPUs), launch `train_alignment.py` with `torchrun`:

```bash
torchrun --nproc_per_node=4 --master_port=36646 train_alignment.py \
    --model_name_or_path "/workspace/hhh_sft" \
    --data_path "./assets/sandbox_v1.json" \
    --bf16 True \
    --output_dir "/workspace/<your_output_lm_name>" \
    --num_train_epochs 7 \
    --per_device_train_batch_size 1 \
    --per_device_eval_batch_size 1 \
    --gradient_accumulation_steps 8 \
    --evaluation_strategy "no" \
    --save_strategy "steps" \
    --save_steps 200 \
    --save_total_limit 1 \
    --learning_rate 2e-5 \
    --weight_decay 0. \
    --warmup_ratio 0.03 \
    --lr_scheduler_type "cosine" \
    --logging_steps 1 \
    --fsdp "shard_grad_op auto_wrap" \
    --fsdp_transformer_layer_cls_to_wrap 'LlamaDecoderLayer' \
    --tf32 True \
    --model_max_length 360 \
    --rating_scale 7 \
    --margin 10 \
    --max_flow False \
    --ratio 0.2 \
    --num_comp 3
```

Notes on the arguments:

- `--model_name_or_path`: path to your SFT model.
- `--data_path`: path to the alignment data.
- `--per_device_train_batch_size`: the batch size has to be 1 for alignment training.
- `--fsdp`: change to `"full_shard auto_wrap"` if you run out of memory.
- `--model_max_length`: change to a shorter length if you run out of memory.
- `--rating_scale`: the scale of the ratings (7 for 1-7, 10 for 1-10, etc.).
- `--margin`: a constant; see the paper.
- `--max_flow`: whether to use the max (True) or the mean (False) for the penalty.
- `--ratio`: controls the ratio of the penalty.
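For intuition about what `--margin`, `--ratio`, `--max_flow`, and `--rating_scale` control, below is a rough sketch of the stable alignment objective as described in the paper: a supervised loss on the highest-rated response plus a margin penalty that pushes lower-rated responses toward a higher loss, with the required gap scaled by the rating difference. This is an illustrative reconstruction, not the code in `train_alignment.py`; consult that file for the actual implementation.

```python
# Illustrative sketch of the stable alignment objective (NOT the repo's exact loss):
# SFT loss on the top-rated response, plus a margin penalty that asks each
# lower-rated response to have a sufficiently higher loss, where "sufficiently"
# scales with the rating gap. Argument names mirror the CLI flags above.
import torch


def stable_alignment_loss(
    nll_best: torch.Tensor,        # scalar NLL of the highest-rated response
    nll_others: torch.Tensor,      # (num_comp,) NLLs of lower-rated responses
    rating_best: float,            # rating of the best response
    ratings_others: torch.Tensor,  # (num_comp,) ratings of the lower-rated responses
    margin: float = 10.0,
    ratio: float = 0.2,
    rating_scale: int = 7,
    max_flow: bool = False,
) -> torch.Tensor:
    # Required loss gap grows with how much worse a comparison response was rated.
    target_gap = margin * (rating_best - ratings_others) / rating_scale
    # Penalize comparisons whose loss does not exceed the best response's by that gap.
    violations = torch.clamp(target_gap - (nll_others - nll_best), min=0.0)
    penalty = violations.max() if max_flow else violations.mean()
    return nll_best + ratio * penalty
```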
We have released our models on Hugging Face! 🤗
Released models include:
- `better-base`: the base model, trained on LLaMA with AlpacaDataCleaned (a fixed version of the Alpaca instruction-tuning dataset) and codealpaca (the code pretraining dataset).
- `hh-rlhf-sft`: a supervised fine-tuned model on top of `better-base`, trained with the socially aligned demonstrations in the Anthropic HH-RLHF dataset (the accepted samples in the dataset).
- `socially-good-lm`: the socially aligned language model, trained on top of `hh-rlhf-sft` with the stable alignment method.
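To fetch a checkpoint locally, one option is `snapshot_download` from `huggingface_hub`. The repo id below is a placeholder: substitute the actual id from the project's Hugging Face page.

```python
# Download a released checkpoint from the Hugging Face Hub into ./models/.
# "<org>/socially-good-lm" is a placeholder repo id, not the real one.
from huggingface_hub import snapshot_download

snapshot_download(
    repo_id="<org>/socially-good-lm",       # replace with the actual repo id
    local_dir="./models/socially-good-lm",  # matches the inference command below
)
```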
After you download the model, you can run inference with the following command:
```bash
python stable_alignment/run_inference.py \
    --model_path './models/socially-good-lm' \
    --device 'cuda:0'
```
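If you prefer to load the checkpoint directly with `transformers`, a minimal sketch is shown below. It assumes a standard LLaMA-style causal LM checkpoint and a bare prompt; the prompt template that `run_inference.py` applies may differ.

```python
# Minimal direct-loading sketch with Hugging Face transformers.
# Assumes a standard LLaMA-style causal LM checkpoint; the prompt template used
# by run_inference.py may differ from this bare prompt.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_path = "./models/socially-good-lm"
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.float16).to("cuda:0")

prompt = "How should I respond when a coworker takes credit for my work?"
inputs = tokenizer(prompt, return_tensors="pt").to("cuda:0")
outputs = model.generate(**inputs, max_new_tokens=256, do_sample=True, temperature=0.7)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```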
Please cite our paper if you use the data or code in this repo:
```bibtex
@misc{liu2023sociallyaligned,
      title={Training Socially Aligned Language Models in Simulated Human Society},
      author={Ruibo Liu and Ruixin Yang and Chenyan Jia and Ge Zhang and Denny Zhou and Andrew M. Dai and Diyi Yang and Soroush Vosoughi},
      year={2023},
      eprint={2305.16960},
      archivePrefix={arXiv},
      primaryClass={cs.CL}
}
```