Skip to content

Commit

Permalink
DOC: How to configure new transformers models (#1195)
Browse files Browse the repository at this point in the history
I believe that new transformers architectures could be the most common
case of users wanting to apply PEFT on a model that is not supported out
of the box. Thus I added a section specifically to help users configure
their configs for new transformers models.

As I wanted to point users to a single file that contains all the
existing transformers models, I added a new file
`src/peft/utils/constants.py`, which contains all the mappings that
previously lived in `src/peft/utils/other.py`. LMK if that makes sense.

Notes

To be absolutely backwards compatible, I re-imported the moved constants
into `other.py`. This way, if there is code that imports them directly
from there, it should continue to work.

To avoid getting a linter error for unused imports, I added those
constants to the `__all__` list in `other.py`.

---------

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
  • Loading branch information
BenjaminBossan and younesbelkada authored Dec 5, 2023
1 parent 1a7433b commit c22a8e5
Show file tree
Hide file tree
Showing 3 changed files with 202 additions and 134 deletions.
25 changes: 24 additions & 1 deletion docs/source/developer_guides/custom_models.md
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ Finally, we can use any training framework we like, or write our own fit loop, t

For a complete example, check out [this notebook](https://github.com/huggingface/peft/blob/main/examples/multilayer_perceptron/multilayer_perceptron_lora.ipynb).

## timm model
## timm models

The [timm](https://huggingface.co/docs/timm/index) library contains a large number of pretrained computer vision models.
Those can also be fine-tuned with PEFT. Let's check out how this works in practice.
Expand Down Expand Up @@ -199,3 +199,26 @@ peft_model.print_trainable_parameters()
This shows us that we only need to train less than 2% of all parameters, which is a huge efficiency gain.

For a complete example, check out [this notebook](https://github.com/huggingface/peft/blob/main/examples/image_classification/image_classification_timm_peft_lora.ipynb).

## New transformers architectures

When new popular transformers architectures are released, we do our best to quickly add them to PEFT. If you come across a transformers model that is not supported out of the box, don't worry, it will most likely still work if the config is set correctly. Specifically, you have to identify the layers that should be adapted and set them correctly when initializing the corresponding config class, e.g. `LoraConfig`. Here are some tips to help with this.

As a first step, it is a good idea is to check the existing models for inspiration. You can find them inside of [constants.py](https://github.com/huggingface/peft/blob/main/src/peft/utils/constants.py) in the PEFT repository. Often, you'll find a similar architecture that uses the same names. For example, if the new model architecture is a variation of the "mistral" model and you want to apply LoRA, you can see that the entry for "mistral" in `TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING` contains `["q_proj", "v_proj"]`. This tells you that for "mistral" models, the `target_modules` for LoRA should be `["q_proj", "v_proj"]`:

```python
from peft import LoraConfig, get_peft_model

my_mistral_model = ...
config = LoraConfig(
target_modules=["q_proj", "v_proj"],
..., # other LoRA arguments
)
peft_model = get_peft_model(my_mistral_model, config)
```

If that doesn't help, check the existing modules in your model architecture with the `named_modules` method and try to identify the attention layers, especially the key, query, and value layers. Those will often have names such as `c_attn`, `query`, `q_proj`, etc. The key layer is not always adapted, and ideally, you should check whether including it results in better performance.

Additionally, linear layers are common targets to be adapted (e.g. in [QLoRA paper](https://arxiv.org/abs/2305.14314), authors suggest to adapt them as well). Their names will often contain the strings `fc` or `dense`.

If you want to add a new model to PEFT, please create an entry in [constants.py](https://github.com/huggingface/peft/blob/main/src/peft/utils/constants.py) and open a pull request on the [repository](https://github.com/huggingface/peft/pulls). Don't forget to update the [README](https://github.com/huggingface/peft#models-support-matrix) as well.
148 changes: 148 additions & 0 deletions src/peft/utils/constants.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
# coding=utf-8
# Copyright 2023-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch


# needed for prefix-tuning of bloom model
def bloom_model_postprocess_past_key_value(past_key_values):
past_key_values = torch.cat(past_key_values)
total_layers, batch_size, num_attention_heads, num_virtual_tokens, head_dim = past_key_values.shape
keys = past_key_values[: total_layers // 2]
keys = keys.transpose(2, 3).reshape(
total_layers // 2, batch_size * num_attention_heads, head_dim, num_virtual_tokens
)
values = past_key_values[total_layers // 2 :]
values = values.reshape(total_layers // 2, batch_size * num_attention_heads, num_virtual_tokens, head_dim)

return tuple(zip(keys, values))


# needed for prefix-tuning of StarCoder models
def starcoder_model_postprocess_past_key_value(past_key_values):
result = []
for k in past_key_values:
k = k[:, :, 0]
k = k.permute([1, 2, 0, 3])
k = k.reshape(*k.shape[:-2], -1)
result.append(k)
return tuple(result)


TRANSFORMERS_MODELS_TO_PREFIX_TUNING_POSTPROCESS_MAPPING = {
"bloom": bloom_model_postprocess_past_key_value,
"gpt_bigcode": starcoder_model_postprocess_past_key_value,
}


TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING = {
"t5": ["q", "v"],
"mt5": ["q", "v"],
"bart": ["q_proj", "v_proj"],
"gpt2": ["c_attn"],
"bloom": ["query_key_value"],
"blip-2": ["q", "v", "q_proj", "v_proj"],
"opt": ["q_proj", "v_proj"],
"gptj": ["q_proj", "v_proj"],
"gpt_neox": ["query_key_value"],
"gpt_neo": ["q_proj", "v_proj"],
"bert": ["query", "value"],
"roberta": ["query", "value"],
"xlm-roberta": ["query", "value"],
"electra": ["query", "value"],
"deberta-v2": ["query_proj", "value_proj"],
"deberta": ["in_proj"],
"layoutlm": ["query", "value"],
"llama": ["q_proj", "v_proj"],
"chatglm": ["query_key_value"],
"gpt_bigcode": ["c_attn"],
"mpt": ["Wqkv"],
"RefinedWebModel": ["query_key_value"],
"RefinedWeb": ["query_key_value"],
"falcon": ["query_key_value"],
"btlm": ["c_proj", "c_attn"],
"codegen": ["qkv_proj"],
"mistral": ["q_proj", "v_proj"],
"stablelm": ["q_proj", "v_proj"],
"phi": ["Wqkv", "out_proj", "fc1", "fc2"],
}

TRANSFORMERS_MODELS_TO_IA3_TARGET_MODULES_MAPPING = {
"t5": ["k", "v", "wo"],
"mt5": ["k", "v", "wi_1"],
"gpt2": ["c_attn", "mlp.c_proj"],
"bloom": ["query_key_value", "mlp.dense_4h_to_h"],
"roberta": ["key", "value", "output.dense"],
"opt": ["q_proj", "k_proj", "fc2"],
"gptj": ["q_proj", "v_proj", "fc_out"],
"gpt_neox": ["query_key_value", "dense_4h_to_h"],
"gpt_neo": ["q_proj", "v_proj", "c_proj"],
"bart": ["q_proj", "v_proj", "fc2"],
"gpt_bigcode": ["c_attn", "mlp.c_proj"],
"llama": ["k_proj", "v_proj", "down_proj"],
"bert": ["key", "value", "output.dense"],
"deberta-v2": ["key_proj", "value_proj", "output.dense"],
"deberta": ["in_proj", "output.dense"],
"RefinedWebModel": ["query_key_value", "dense_4h_to_h"],
"RefinedWeb": ["query_key_value", "dense_4h_to_h"],
"falcon": ["query_key_value", "dense_4h_to_h"],
}

TRANSFORMERS_MODELS_TO_IA3_FEEDFORWARD_MODULES_MAPPING = {
"t5": ["wo"],
"mt5": [],
"gpt2": ["mlp.c_proj"],
"bloom": ["mlp.dense_4h_to_h"],
"roberta": ["output.dense"],
"opt": ["fc2"],
"gptj": ["fc_out"],
"gpt_neox": ["dense_4h_to_h"],
"gpt_neo": ["c_proj"],
"bart": ["fc2"],
"gpt_bigcode": ["mlp.c_proj"],
"llama": ["down_proj"],
"bert": ["output.dense"],
"deberta-v2": ["output.dense"],
"deberta": ["output.dense"],
"RefinedWeb": ["dense_4h_to_h"],
"RefinedWebModel": ["dense_4h_to_h"],
"falcon": ["dense_4h_to_h"],
}

TRANSFORMERS_MODELS_TO_ADALORA_TARGET_MODULES_MAPPING = {
"t5": ["q", "k", "v", "o", "wi", "wo"],
"mt5": ["q", "k", "v", "o", "wi_0", "wi_1", "wo"],
"bart": ["q_proj", "k_proj", "v_proj", "out_proj", "fc1", "fc2"],
"gpt2": ["c_attn"],
"bloom": ["query_key_value"],
"opt": ["q_proj", "k_proj", "v_proj", "out_proj", "fc1", "fc2"],
"gptj": ["q_proj", "v_proj"],
"gpt_neox": ["query_key_value"],
"gpt_neo": ["q_proj", "v_proj"],
"llama": ["q_proj", "v_proj"],
"bert": ["query", "value"],
"roberta": ["query", "key", "value", "dense"],
# "xlm-roberta": ["query", "value"],
# "electra": ["query", "value"],
"deberta-v2": ["query_proj", "key_proj", "value_proj", "dense"],
"gpt_bigcode": ["c_attn"],
"deberta": ["in_proj"],
# "layoutlm": ["query", "value"],
}

COMMON_LAYERS_PATTERN = ["layers", "h", "block", "blocks", "layer"]
WEIGHTS_NAME = "adapter_model.bin"
SAFETENSORS_WEIGHTS_NAME = "adapter_model.safetensors"
CONFIG_NAME = "adapter_config.json"
EMBEDDING_LAYER_NAMES = ["embed_tokens", "lm_head"]
163 changes: 30 additions & 133 deletions src/peft/utils/other.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,36 @@
from safetensors.torch import storage_ptr, storage_size

from ..import_utils import is_auto_gptq_available, is_torch_tpu_available
from .constants import (
COMMON_LAYERS_PATTERN,
CONFIG_NAME,
EMBEDDING_LAYER_NAMES,
SAFETENSORS_WEIGHTS_NAME,
TRANSFORMERS_MODELS_TO_ADALORA_TARGET_MODULES_MAPPING,
TRANSFORMERS_MODELS_TO_IA3_FEEDFORWARD_MODULES_MAPPING,
TRANSFORMERS_MODELS_TO_IA3_TARGET_MODULES_MAPPING,
TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING,
TRANSFORMERS_MODELS_TO_PREFIX_TUNING_POSTPROCESS_MAPPING,
WEIGHTS_NAME,
bloom_model_postprocess_past_key_value,
starcoder_model_postprocess_past_key_value,
)


__all__ = [
"COMMON_LAYERS_PATTERN",
"CONFIG_NAME",
"EMBEDDING_LAYER_NAMES",
"SAFETENSORS_WEIGHTS_NAME",
"TRANSFORMERS_MODELS_TO_ADALORA_TARGET_MODULES_MAPPING",
"TRANSFORMERS_MODELS_TO_IA3_FEEDFORWARD_MODULES_MAPPING",
"TRANSFORMERS_MODELS_TO_IA3_TARGET_MODULES_MAPPING",
"TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING",
"TRANSFORMERS_MODELS_TO_PREFIX_TUNING_POSTPROCESS_MAPPING",
"WEIGHTS_NAME",
"bloom_model_postprocess_past_key_value",
"starcoder_model_postprocess_past_key_value",
]


# Get current device name based on available devices
Expand All @@ -39,31 +69,6 @@ def infer_device():
return torch_device


# needed for prefix-tuning of bloom model
def bloom_model_postprocess_past_key_value(past_key_values):
past_key_values = torch.cat(past_key_values)
total_layers, batch_size, num_attention_heads, num_virtual_tokens, head_dim = past_key_values.shape
keys = past_key_values[: total_layers // 2]
keys = keys.transpose(2, 3).reshape(
total_layers // 2, batch_size * num_attention_heads, head_dim, num_virtual_tokens
)
values = past_key_values[total_layers // 2 :]
values = values.reshape(total_layers // 2, batch_size * num_attention_heads, num_virtual_tokens, head_dim)

return tuple(zip(keys, values))


# needed for prefix-tuning of StarCoder models
def starcoder_model_postprocess_past_key_value(past_key_values):
result = []
for k in past_key_values:
k = k[:, :, 0]
k = k.permute([1, 2, 0, 3])
k = k.reshape(*k.shape[:-2], -1)
result.append(k)
return tuple(result)


def prepare_model_for_kbit_training(model, use_gradient_checkpointing=True, gradient_checkpointing_kwargs=None):
r"""
Note this method only works for `transformers` models.
Expand Down Expand Up @@ -476,111 +481,3 @@ def id_tensor_storage(tensor: torch.Tensor) -> Tuple[torch.device, int, int]:
unique_id = storage_ptr(tensor)

return tensor.device, unique_id, storage_size(tensor)


TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING = {
"t5": ["q", "v"],
"mt5": ["q", "v"],
"bart": ["q_proj", "v_proj"],
"gpt2": ["c_attn"],
"bloom": ["query_key_value"],
"blip-2": ["q", "v", "q_proj", "v_proj"],
"opt": ["q_proj", "v_proj"],
"gptj": ["q_proj", "v_proj"],
"gpt_neox": ["query_key_value"],
"gpt_neo": ["q_proj", "v_proj"],
"bert": ["query", "value"],
"roberta": ["query", "value"],
"xlm-roberta": ["query", "value"],
"electra": ["query", "value"],
"deberta-v2": ["query_proj", "value_proj"],
"deberta": ["in_proj"],
"layoutlm": ["query", "value"],
"llama": ["q_proj", "v_proj"],
"chatglm": ["query_key_value"],
"gpt_bigcode": ["c_attn"],
"mpt": ["Wqkv"],
"RefinedWebModel": ["query_key_value"],
"RefinedWeb": ["query_key_value"],
"falcon": ["query_key_value"],
"btlm": ["c_proj", "c_attn"],
"codegen": ["qkv_proj"],
"mistral": ["q_proj", "v_proj"],
"stablelm": ["q_proj", "v_proj"],
"phi": ["Wqkv", "out_proj", "fc1", "fc2"],
}

TRANSFORMERS_MODELS_TO_IA3_TARGET_MODULES_MAPPING = {
"t5": ["k", "v", "wo"],
"mt5": ["k", "v", "wi_1"],
"gpt2": ["c_attn", "mlp.c_proj"],
"bloom": ["query_key_value", "mlp.dense_4h_to_h"],
"roberta": ["key", "value", "output.dense"],
"opt": ["q_proj", "k_proj", "fc2"],
"gptj": ["q_proj", "v_proj", "fc_out"],
"gpt_neox": ["query_key_value", "dense_4h_to_h"],
"gpt_neo": ["q_proj", "v_proj", "c_proj"],
"bart": ["q_proj", "v_proj", "fc2"],
"gpt_bigcode": ["c_attn", "mlp.c_proj"],
"llama": ["k_proj", "v_proj", "down_proj"],
"bert": ["key", "value", "output.dense"],
"deberta-v2": ["key_proj", "value_proj", "output.dense"],
"deberta": ["in_proj", "output.dense"],
"RefinedWebModel": ["query_key_value", "dense_4h_to_h"],
"RefinedWeb": ["query_key_value", "dense_4h_to_h"],
"falcon": ["query_key_value", "dense_4h_to_h"],
}

TRANSFORMERS_MODELS_TO_IA3_FEEDFORWARD_MODULES_MAPPING = {
"t5": ["wo"],
"mt5": [],
"gpt2": ["mlp.c_proj"],
"bloom": ["mlp.dense_4h_to_h"],
"roberta": ["output.dense"],
"opt": ["fc2"],
"gptj": ["fc_out"],
"gpt_neox": ["dense_4h_to_h"],
"gpt_neo": ["c_proj"],
"bart": ["fc2"],
"gpt_bigcode": ["mlp.c_proj"],
"llama": ["down_proj"],
"bert": ["output.dense"],
"deberta-v2": ["output.dense"],
"deberta": ["output.dense"],
"RefinedWeb": ["dense_4h_to_h"],
"RefinedWebModel": ["dense_4h_to_h"],
"falcon": ["dense_4h_to_h"],
}

COMMON_LAYERS_PATTERN = ["layers", "h", "block", "blocks", "layer"]

TRANSFORMERS_MODELS_TO_ADALORA_TARGET_MODULES_MAPPING = {
"t5": ["q", "k", "v", "o", "wi", "wo"],
"mt5": ["q", "k", "v", "o", "wi_0", "wi_1", "wo"],
"bart": ["q_proj", "k_proj", "v_proj", "out_proj", "fc1", "fc2"],
"gpt2": ["c_attn"],
"bloom": ["query_key_value"],
"opt": ["q_proj", "k_proj", "v_proj", "out_proj", "fc1", "fc2"],
"gptj": ["q_proj", "v_proj"],
"gpt_neox": ["query_key_value"],
"gpt_neo": ["q_proj", "v_proj"],
"llama": ["q_proj", "v_proj"],
"bert": ["query", "value"],
"roberta": ["query", "key", "value", "dense"],
# "xlm-roberta": ["query", "value"],
# "electra": ["query", "value"],
"deberta-v2": ["query_proj", "key_proj", "value_proj", "dense"],
"gpt_bigcode": ["c_attn"],
"deberta": ["in_proj"],
# "layoutlm": ["query", "value"],
}

TRANSFORMERS_MODELS_TO_PREFIX_TUNING_POSTPROCESS_MAPPING = {
"bloom": bloom_model_postprocess_past_key_value,
"gpt_bigcode": starcoder_model_postprocess_past_key_value,
}

WEIGHTS_NAME = "adapter_model.bin"
SAFETENSORS_WEIGHTS_NAME = "adapter_model.safetensors"
CONFIG_NAME = "adapter_config.json"
EMBEDDING_LAYER_NAMES = ["embed_tokens", "lm_head"]

0 comments on commit c22a8e5

Please sign in to comment.