Code for "LayerSkip: Enabling Early Exit Inference and Self-Speculative Decoding", ACL 2024

LayerSkip

This code base is the implementation of LayerSkip: Enabling Early Exit Inference and Self-Speculative Decoding.

News

Getting Started

  • Clone repo:
$ git clone git@github.com:facebookresearch/LayerSkip.git
$ cd LayerSkip
  • Set up the environment:
$ conda create --name layer_skip python=3.10
$ conda activate layer_skip

$ pip install -r requirements.txt

In order to access each model:

  1. Visit the model's corresponding link above and make sure you are logged in to the HuggingFace website with your account.
  2. Fill out the request form and submit it. Approval may take a while; you should receive an email notification once permission to access the model is granted.
  3. Follow the steps here to obtain a user access token.
  4. On the command line, run huggingface-cli login; you will be prompted to provide the token you obtained in Step 3.

Once you have completed those steps, the commands below for running the LayerSkip checkpoints should work.

Generate

To run one of our models in interactive mode using regular autoregressive decoding:

$ torchrun generate.py --model facebook/layerskip-llama2-7B \
    --sample True \
    --max_steps 512

To observe a speedup, you need to generate tokens with self-speculative decoding and specify --exit_layer, the layer at which the draft stage exits, and --num_speculations, the number of draft tokens:

$ torchrun generate.py --model facebook/layerskip-llama2-7B \
    --sample True \
    --max_steps 512 \
    --generation_strategy self_speculative \
    --exit_layer 8 \
    --num_speculations 6
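The draft/verify loop that --exit_layer and --num_speculations control can be sketched as follows. This is a toy, pure-Python illustration under stated assumptions, not the repository's implementation: the "model" (embed, make_layer, lm_head) is a hypothetical stand-in built from deterministic functions, decoding is greedy, and verification recomputes every layer instead of reusing the draft stage's computation and KV cache, which the main implementation shares between stages (see Other Implementations). It does show the key property of greedy self-speculative decoding: the output matches plain autoregressive decoding token for token.

```python
import math
from typing import Callable, List, Tuple

# Hypothetical toy model: sizes, embedding, layers, and head are illustrative stand-ins.
VOCAB = 16
DIM = 4
Layer = Callable[[List[float]], List[float]]


def embed(tokens: List[int]) -> List[float]:
    """Fold the whole context into one hidden vector (toy stand-in for attention)."""
    h = [0.0] * DIM
    for pos, tok in enumerate(tokens):
        for d in range(DIM):
            h[d] += math.sin((tok + 1) * (d + 1) + pos)
    return h


def make_layer(seed: int) -> Layer:
    """A toy residual layer computing h + f(h)."""
    def layer(h: List[float]) -> List[float]:
        return [x + 0.5 * math.sin(seed * (i + 1) + x) for i, x in enumerate(h)]
    return layer


def lm_head(h: List[float]) -> List[float]:
    """One unembedding shared by every exit point, so any layer can emit tokens."""
    return [sum(math.cos((v + 1) * (d + 1)) * x for d, x in enumerate(h)) for v in range(VOCAB)]


def next_token(tokens: List[int], layers: List[Layer], exit_layer: int) -> int:
    """Greedy next token using only the first `exit_layer` layers (early exit)."""
    h = embed(tokens)
    for layer in layers[:exit_layer]:
        h = layer(h)
    logits = lm_head(h)
    return max(range(VOCAB), key=logits.__getitem__)


def autoregressive(prompt: List[int], layers: List[Layer], max_new: int) -> List[int]:
    """Baseline: one full-depth step per generated token."""
    out = list(prompt)
    for _ in range(max_new):
        out.append(next_token(out, layers, len(layers)))
    return out[len(prompt):]


def self_speculative(prompt: List[int], layers: List[Layer], exit_layer: int,
                     num_speculations: int, max_new: int) -> Tuple[List[int], int]:
    """Returns (generated tokens, number of verification rounds)."""
    out = list(prompt)
    rounds = 0
    while len(out) - len(prompt) < max_new:
        # Draft stage: propose tokens one by one, exiting after `exit_layer` layers.
        draft: List[int] = []
        for _ in range(num_speculations):
            draft.append(next_token(out + draft, layers, exit_layer))
        # Verification stage: full-depth prediction at every draft position.
        # A real implementation scores all positions in one batched forward pass.
        full = [next_token(out + draft[:i], layers, len(layers))
                for i in range(num_speculations + 1)]
        rounds += 1
        # Keep the longest draft prefix the full model agrees with, then append the
        # full model's token at the first mismatch (or a bonus token if all matched).
        n = 0
        while n < num_speculations and draft[n] == full[n]:
            n += 1
        out += draft[:n] + [full[n]]
    return out[len(prompt):len(prompt) + max_new], rounds


if __name__ == "__main__":
    layers = [make_layer(seed) for seed in range(1, 9)]  # 8 toy layers
    tokens, rounds = self_speculative([3, 1, 4], layers, exit_layer=2, num_speculations=4, max_new=20)
    assert tokens == autoregressive([3, 1, 4], layers, max_new=20)
    print(f"20 tokens in {rounds} verification rounds")
```

Each round costs one full-depth verification but can emit up to num_speculations + 1 tokens, which is where the speedup comes from; with exit_layer equal to the full depth, every draft is accepted and 20 tokens take only 3 rounds at num_speculations=6.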

Tips:

  • You may change --model to any HuggingFace model, but to observe a speedup with self-speculative decoding, use a model trained with the LayerSkip recipe, such as those we have open-sourced on HuggingFace.
  • Sampling is enabled by default. You may change the sampling behaviour using the --sample, --temperature, --top_p, and --top_k arguments.
  • You may run python generate.py --help for details on different command-line arguments.

Benchmark

To benchmark on a dataset:

$ torchrun benchmark.py --model facebook/layerskip-llama2-7B \
    --dataset cnn_dm_summarization \
    --num_samples 100 \
    --generation_strategy self_speculative \
    --exit_layer 8 \
    --num_speculations 6 \
    --output_dir ./logs

Tips:

  • You can specify different tasks by modifying the --dataset argument:
    • cnn_dm_summarization: CNN/DM Summarization
    • xsum_summarization: XSUM Summarization
    • cnn_dm_lm: CNN/DM Language Modeling (given the first few words of an article, generate the remaining article)
    • human_eval: HumanEval Coding
  • By default, the tasks run as 0-shot. To run n-shot instead, set the --n_shot argument.
  • Sampling is enabled by default, whereas the results reported in the paper used greedy decoding without sampling; to match that setting, pass --sample False. You may change the sampling behaviour using the --sample, --temperature, --top_p, and --top_k arguments.
  • You may run python benchmark.py --help for details on different command-line arguments.

Evaluate

We have integrated our generation scripts with EleutherAI's Language Model Evaluation Harness to support a large number of tasks and to properly post-process generated text.

$ torchrun eval.py --model facebook/layerskip-llama2-7B \
    --tasks gsm8k \
    --limit 10 \
    --generation_strategy self_speculative \
    --exit_layer 8 \
    --num_speculations 6 \
    --output_dir ./logs

Tips:

  • Note that speculative decoding only yields speedups on generation tasks (e.g., gsm8k or cnn_dailymail). Classification tasks, i.e., multiple-choice question tasks (e.g., piqa, social_iqa) or True/False question tasks (e.g., boolq), will not see a speedup, since they score the likelihood of fixed answer options rather than generating new tokens.
  • You can specify any number of tasks supported by the Evaluation Harness using the --tasks argument. For a list of all possible tasks, check this link.
  • Similar to the generate.py and benchmark.py scripts, you may specify different models, datasets, and sampling parameters.
  • You may run python eval.py --help for details on different command-line arguments.

Sweep

Our inference hyperparameters, exit_layer and num_speculations, determine the speedup during inference:

  • exit_layer:
    • smaller means a faster but less accurate draft stage
    • larger means a more accurate but slower draft stage
  • num_speculations:
    • smaller means a higher acceptance rate, but the verification stage amortizes the cost of the draft stage less
    • larger means the verification stage better amortizes the cost of the draft stage, but the acceptance rate decreases
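To make the num_speculations trade-off concrete, the sketch below uses the standard speed-up estimate from the general speculative decoding literature (Leviathan et al., 2023), not a model taken from this repository. It assumes every draft token is accepted independently with the same probability alpha, and that one draft step costs a fraction draft_cost of a full forward pass (roughly exit_layer divided by the total number of layers, ignoring the LM head and LayerSkip's reuse of draft computation during verification). The function names are hypothetical, and the alpha and draft_cost values in the example are made up for illustration.

```python
def expected_tokens_per_iteration(alpha: float, num_speculations: int) -> float:
    """Expected tokens emitted per draft+verify round: accepted draft tokens plus the
    one token the full model always contributes, assuming i.i.d. acceptance rate alpha."""
    if alpha >= 1.0:
        return num_speculations + 1.0  # every draft accepted, plus the bonus token
    return (1.0 - alpha ** (num_speculations + 1)) / (1.0 - alpha)


def estimated_speedup(alpha: float, num_speculations: int, draft_cost: float) -> float:
    """Tokens per round divided by the cost of a round, in units of one full
    forward pass: num_speculations draft steps plus one verification pass."""
    round_cost = num_speculations * draft_cost + 1.0
    return expected_tokens_per_iteration(alpha, num_speculations) / round_cost


if __name__ == "__main__":
    # Hypothetical values: 80% acceptance, draft exits at layer 8 of 32 (draft_cost = 0.25).
    for k in range(1, 9):
        print(k, round(estimated_speedup(0.8, k, 0.25), 3))
    best = max(range(1, 11), key=lambda k: estimated_speedup(0.8, k, 0.25))
    print("best num_speculations:", best)  # → 3
```

Past the optimum, each extra draft token is less likely to be accepted, so its cost is no longer repaid. In a real model alpha itself also changes with exit_layer, which is why an empirical sweep is still needed.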

The optimal combination of exit_layer and num_speculations may change with the model, dataset, and sampling parameters. Hence, we provide a script to sweep over a grid of different exit_layer and num_speculations values:

$ torchrun sweep.py --model facebook/layerskip-llama2-7B \
    --dataset human_eval \
    --generation_strategy self_speculative \
    --num_samples 150 \
    --max_steps 256 \
    --output_dir ./logs/ \
    --sample False

This will create a CSV file in the directory specified by the --output_dir argument.
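Picking the winning combination from that CSV can be scripted. The following is a minimal sketch using Python's csv module; it assumes, without having verified it against sweep.py, that the file has exit_layer and num_speculations columns plus a numeric metric column whose name you pass in. The helper names best_config and best_config_from_file are hypothetical; check the actual CSV header and adjust EXIT_COL, SPEC_COL, and metric before use.

```python
import csv
from typing import Dict, Iterable, Tuple

# Assumed (hypothetical) column names: verify against the header sweep.py writes.
EXIT_COL = "exit_layer"
SPEC_COL = "num_speculations"


def best_config(rows: Iterable[Dict[str, str]], metric: str,
                higher_is_better: bool = True) -> Tuple[int, int, float]:
    """Return (exit_layer, num_speculations, metric value) for the best sweep row."""
    rows = list(rows)
    if not rows:
        raise ValueError("sweep CSV has no data rows")
    pick = max if higher_is_better else min
    best = pick(rows, key=lambda r: float(r[metric]))
    return int(best[EXIT_COL]), int(best[SPEC_COL]), float(best[metric])


def best_config_from_file(path: str, metric: str,
                          higher_is_better: bool = True) -> Tuple[int, int, float]:
    """Read a sweep CSV from disk and return its best row (see best_config)."""
    with open(path, newline="") as f:
        return best_config(csv.DictReader(f), metric, higher_is_better)
```

For a throughput-style column keep higher_is_better=True; for a time-based column (e.g. latency per token) pass higher_is_better=False.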

Tips:

  • Similar to the generate.py and benchmark.py scripts, you may specify different models, datasets, and sampling parameters.
  • You may run python sweep.py --help for details on different command-line arguments.

Correctness

To verify that the tokens generated by our self-speculative decoding algorithm are correct, we have created a script that compares the outputs of autoregressive decoding with those of self-speculative decoding. Note that we can only guarantee equivalent outputs when there is no sampling (i.e., --sample False):

$ torchrun correctness.py --model facebook/layerskip-llama2-7B \
    --dataset human_eval \
    --generation_strategy self_speculative \
    --num_speculations 6 \
    --exit_layer 4 \
    --num_samples 10 \
    --sample False \
    --output_dir ./logs

Using Docker

Please check DOCKER.md to set up the project using Docker.

Other Implementations

We also have other implementations of LayerSkip inference:

  • gpt-fast: gpt-fast is a simple and efficient PyTorch-native implementation of transformer text generation. We have implemented LayerSkip in the gpt-fast codebase so that it can be compounded with other optimizations such as torch.compile(), quantization, and tensor parallelism.
  • Native HuggingFace: in the model card of each of our HuggingFace models, we have provided simple code snippets that leverage HuggingFace speculative decoding capabilities, using a simple trick to clone the earlier layers of the main model without cloning their weights. Although this implementation is simple and does not require implementing other functions or importing other libraries, it does not share the KV cache or execution between the draft and verification stages.

Training

Our training implementation is work-in-progress. You can check this pull request for details and discussions.

License

LayerSkip is licensed under the CC BY-NC license. Refer to the LICENSE file in the top-level directory.

Contributing

We welcome contributions to LayerSkip. If you are interested in contributing, please see this document.

Citation

If you use LayerSkip in your research, please use the following BibTeX entry:

@inproceedings{layerskip,
    title={LayerSkip: Enabling Early Exit Inference and Self-Speculative Decoding},
    author={Mostafa Elhoushi and Akshat Shrivastava and Diana Liskovich and Basil Hosmer and Bram Wasti and Liangzhen Lai and Anas Mahmoud and Bilge Acun and Saurabh Agarwal and Ahmed Roman and Ahmed A Aly and Beidi Chen and Carole-Jean Wu},
    booktitle = "Proceedings of the 62nd Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers)",
    month = aug,
    year = "2024",
    address = "Bangkok, Thailand",
    publisher = "Association for Computational Linguistics",
    url = "https://aclanthology.org/2024.acl-long.681",
    doi = "10.18653/v1/2024.acl-long.681",
    pages = "12622--12642",
}