ENH: LoRA support for dynamically dispatching to custom layers (#1875)
Description

This is an experimental feature with a private API for now. If this
feature finds adoption, I will work on adding an official API.

With this PR, we allow users to register their own LoRA layer types. This
way, they can add support for hitherto unsupported layer types, such as
nn.Conv3d or nn.LSTM. Without this PR, their only option is to create a PR
on PEFT that adds support for the new type and to get it merged.

The custom dispatch mechanism also allows users to override the existing
layer type mapping. This way, they can, for instance, provide their own
lora.Linear layer type to adapt nn.Linear layers instead of using the one
from PEFT.
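
As a minimal sketch (assuming `base_model` and the custom classes `MyLoraLSTMLayer` / `MyLoraLinear` are defined by the user, following the API documented below):

```python
from torch import nn
from peft import LoraConfig, get_peft_model

base_model = ...  # a model that contains nn.LSTM layers

config = LoraConfig(target_modules=["lstm"])
# map the unsupported base layer type to the user-defined LoRA layer type
config._register_custom_module({nn.LSTM: MyLoraLSTMLayer})
# the same call can also override a built-in mapping, e.g. {nn.Linear: MyLoraLinear}
peft_model = get_peft_model(base_model, config)
```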

Implementation

The implementation required only a few changes because we already have a
mechanism for dynamic dispatching in LoRA. It is currently used, for
instance, to dynamically add quantized target layers when the corresponding
quantization library is installed.

This existing mechanism is now extended to include user-provided LoRA
layers if any were passed. These are checked before the default
PEFT-supported layers.
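
Conceptually, the dispatch works like the toy resolver below (the names `resolve_lora_cls` and `BUILTIN_MAPPING` are made up for this sketch and are not PEFT APIs): the user mapping is consulted first, and the built-in layer types only apply if it yields no match.

```python
from typing import Optional

from torch import nn

# hypothetical stand-in for PEFT's built-in LoRA layer mapping
BUILTIN_MAPPING = {nn.Linear: "lora.Linear", nn.Conv2d: "lora.Conv2d"}


def resolve_lora_cls(layer: nn.Module, custom_mapping: dict) -> Optional[str]:
    for base_cls, lora_cls in custom_mapping.items():  # user-provided mapping first
        if isinstance(layer, base_cls):
            return lora_cls
    for base_cls, lora_cls in BUILTIN_MAPPING.items():  # default PEFT layers last
        if isinstance(layer, base_cls):
            return lora_cls
    return None


print(resolve_lora_cls(nn.LSTM(10, 20), {nn.LSTM: "MyLoraLSTMLayer"}))  # MyLoraLSTMLayer
print(resolve_lora_cls(nn.Linear(10, 20), {}))                          # lora.Linear
```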

What's missing for this to become an official API?

Right now, the main reason why this cannot be an official API is the
question of how to persist the config. In the current implementation, we
add an attribute that is a mapping from target layer type to LoRA layer
type:

config._custom_modules == {CustomBaseLayer: CustomLoraLayer}

The entries of this dict are Python classes and therefore cannot be
JSON-serialized. We could think of ways to serialize and deserialize custom
Python objects, but this is not trivial and potentially a security risk.
Thus I would only start working on this if the demand is sufficiently high.
At that point, I would also add a public API instead of requiring the use
of a private API.
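
For illustration, this is the failure that blocks naive serialization (the class names mirror the snippet above):

```python
import json


class CustomBaseLayer: ...
class CustomLoraLayer: ...


try:
    # Python classes are not JSON-serializable, so the mapping cannot simply be
    # written to adapter_config.json
    json.dumps({CustomBaseLayer: CustomLoraLayer})
except TypeError as exc:
    print(exc)  # keys must be str, int, float, bool or None, not type
```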

As is, users can still save and load PEFT models with custom LoRA layers;
they only need to add two lines of code to their scripts, as documented.
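
Concretely, the two extra lines are the re-registration of the custom mapping on the loaded config, as shown in the documentation added below:

```python
custom_module_mapping = {nn.LSTM: MyLoraLSTMLayer}
config._register_custom_module(custom_module_mapping)
```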

---------

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
BenjaminBossan and stevhliu committed Jun 25, 2024
1 parent d716adf commit ef23712
Showing 5 changed files with 311 additions and 1 deletion.
70 changes: 70 additions & 0 deletions docs/source/developer_guides/custom_models.md
@@ -238,3 +238,73 @@ peft_model.print_trainable_parameters()
```python
print(peft_model.targeted_module_names)
```

## Unsupported module types

Methods like LoRA only work if the target modules are supported by PEFT. For example, it's possible to apply LoRA to `nn.Linear` and `nn.Conv2d` layers, but not, for instance, to `nn.LSTM`. If you find that a layer class you want to apply PEFT to is not supported, you can:

- define a custom mapping to dynamically dispatch custom modules in LoRA
- open an [issue](https://github.com/huggingface/peft/issues) and request the feature; if demand for this module type is sufficiently high, the maintainers will implement it or guide you on how to implement it yourself

### Experimental support for dynamic dispatch of custom modules in LoRA

> [!WARNING]
> This feature is experimental and subject to change, depending on its reception by the community. We will introduce a public and stable API if there is significant demand for it.

PEFT supports an experimental API for custom module types for LoRA. Let's assume you have a LoRA implementation for LSTMs. Normally, you would not be able to tell PEFT to use it, even though it would work with PEFT in principle. With dynamic dispatch of custom layers, this becomes possible.

The experimental API currently looks like this:

```python
from torch import nn

from peft import LoraConfig, get_peft_model


class MyLoraLSTMLayer:
    ...

base_model = ... # load the base model that uses LSTMs

# add the LSTM layer names to target_modules
config = LoraConfig(..., target_modules=["lstm"])
# define a mapping from base layer type to LoRA layer type
custom_module_mapping = {nn.LSTM: MyLoraLSTMLayer}
# register the new mapping
config._register_custom_module(custom_module_mapping)
# after registration, create the PEFT model
peft_model = get_peft_model(base_model, config)
# do training
```

<Tip>

When you call [`get_peft_model`], you will see a warning because PEFT does not recognize the targeted module type. In this case, you can ignore this warning.

</Tip>

When you supply a custom mapping, PEFT first checks the base model's layers against it and dispatches to the custom LoRA layer type if there is a match. If there is no match, PEFT checks the built-in LoRA layer types for a match.

Therefore, this feature can also be used to override existing dispatch logic, e.g. if you want to use your own LoRA layer for `nn.Linear` instead of using the one provided by PEFT.
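
For example, a minimal sketch of such an override, assuming a custom `MyLoraLinear` subclass of the built-in `lora.Linear` (similar to the class used in the tests for this feature); the target module names are illustrative:

```python
from torch import nn

from peft import LoraConfig
from peft.tuners import lora


class MyLoraLinear(lora.Linear):
    # reuse the built-in implementation, customize where needed
    pass


config = LoraConfig(target_modules=["q_proj", "v_proj"])
# nn.Linear targets now dispatch to MyLoraLinear instead of the built-in lora.Linear
config._register_custom_module({nn.Linear: MyLoraLinear})
```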

When creating your custom LoRA module, please follow the same rules as the [existing LoRA modules](https://github.com/huggingface/peft/blob/main/src/peft/tuners/lora/layer.py). Some important constraints to consider (a minimal sketch that follows these rules is shown after the list):

- The custom module should inherit from `nn.Module` and `peft.tuners.lora.layer.LoraLayer`.
- The `__init__` method of the custom module should have the positional arguments `base_layer` and `adapter_name`. After this, there are additional `**kwargs` that you are free to use or ignore.
- The learnable parameters should be stored in an `nn.ModuleDict` or `nn.ParameterDict`, where the key corresponds to the name of the specific adapter (remember that a model can have more than one adapter at a time).
- The name of these learnable parameter attributes should start with `"lora_"`, e.g. `self.lora_new_param = ...`.
- Some methods are optional, e.g. you only need to implement `merge` and `unmerge` if you want to support weight merging.
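
Putting these rules together, a minimal sketch of a custom LoRA layer for `nn.LSTM` could look like this (an illustrative simplification that assumes a unidirectional LSTM and adds a low-rank update, computed from the input, to the LSTM output; it is not the only possible design):

```python
from torch import nn

from peft.tuners.lora.layer import LoraLayer


class MyLoraLSTMLayer(nn.Module, LoraLayer):
    def __init__(self, base_layer: nn.LSTM, adapter_name: str, r: int = 8, **kwargs):
        super().__init__()
        LoraLayer.__init__(self, base_layer, **kwargs)
        self._active_adapter = adapter_name
        # learnable parameters are stored per adapter name in the "lora_" dicts
        # created by LoraLayer.__init__
        self.lora_A[adapter_name] = nn.Linear(base_layer.input_size, r, bias=False)
        self.lora_B[adapter_name] = nn.Linear(r, base_layer.hidden_size, bias=False)
        nn.init.zeros_(self.lora_B[adapter_name].weight)

    def forward(self, x, *args, **kwargs):
        output, state = self.base_layer(x, *args, **kwargs)
        # add the low-rank update of the active adapter to the LSTM output
        adapter = self._active_adapter
        output = output + self.lora_B[adapter](self.lora_A[adapter](x))
        return output, state
```

Since `in_features` and `out_features` cannot be determined for `nn.LSTM`, PEFT emits a warning when creating the model, which can be ignored in this case.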

Currently, the information about the custom module does not persist when you save the model. When loading the model, you have to register the custom modules again.

```python
# saving works as always and includes the parameters of the custom modules
peft_model.save_pretrained(<model-path>)

# loading the model later:
base_model = ...
# load the LoRA config that you saved earlier
config = LoraConfig.from_pretrained(<model-path>)
# register the custom module again, the same way as the first time
custom_module_mapping = {nn.LSTM: MyLoraLSTMLayer}
config._register_custom_module(custom_module_mapping)
# pass the config instance to from_pretrained:
peft_model = PeftModel.from_pretrained(base_model, <model-path>, config=config)
```

If you use this feature and find it useful, or if you encounter problems, let us know by creating an issue or a discussion on GitHub. This allows us to estimate the demand for this feature and add a public API if it is sufficiently high.
23 changes: 23 additions & 0 deletions src/peft/tuners/lora/config.py
@@ -17,6 +17,8 @@
from dataclasses import dataclass, field
from typing import Literal, Optional, Union

from torch import nn

from peft.config import PeftConfig
from peft.utils import PeftType

@@ -309,3 +311,24 @@ def __post_init__(self):
# convert loftq_config to dict
if self.loftq_config and not isinstance(self.loftq_config, dict):
self.loftq_config = vars(self.loftq_config)

self._custom_modules: Optional[dict[type[nn.Module], type[nn.Module]]] = None

def _register_custom_module(self, mapping: dict[type[nn.Module], type[nn.Module]]) -> None:
"""
Experimental API to support providing custom LoRA layers.
This API is subject to change; you should carefully read the docs before deciding to use it:
https://huggingface.co/docs/peft/developer_guides/custom_models
To register custom LoRA module types, call this method with a `mapping` argument that is a dict that maps from
the target layer type to the custom LoRA layer type. The dict can contain multiple items if you wish to target
multiple layer types. The target layer type can be any nn.Module that we currently don't support in PEFT,
whether that is an official PyTorch layer type or a custom layer type. The custom LoRA module class has to be
implemented by the user and follow the PEFT conventions for LoRA layers.
"""
if self._custom_modules is None:
self._custom_modules = {}
self._custom_modules.update(mapping)
9 changes: 8 additions & 1 deletion src/peft/tuners/lora/layer.py
@@ -86,7 +86,14 @@ def __init__(self, base_layer: nn.Module, **kwargs) -> None:
# HQQ layers
in_features, out_features = base_layer.in_features, base_layer.out_features
else:
raise ValueError(f"Unsupported layer type {type(base_layer)}")
# possibly support user provided custom layer types using dynamic dispatch
if hasattr(base_layer, "in_features") and hasattr(base_layer, "out_features"):
in_features, out_features = base_layer.in_features, base_layer.out_features
else:
in_features, out_features = None, None
warnings.warn(
f"Unsupported layer type '{type(base_layer)}' encountered, proceed at your own risk.", UserWarning
)

self.in_features = in_features
self.out_features = out_features
22 changes: 22 additions & 0 deletions src/peft/tuners/lora/model.py
@@ -259,6 +259,8 @@ def _replace_module(self, parent, child_name, new_module, child):
else child.W_q
if hasattr(child, "W_q")
else child.weight
if hasattr(child, "weight")
else next(child.parameters())
)
module.to(weight.device)

@@ -289,6 +291,26 @@ def _create_new_module(lora_config, adapter_name, target, **kwargs):
# because the first match is always used. Therefore, the default layers should be checked last.
dispatchers = []

if lora_config._custom_modules:
# Experimental custom LoRA module support. Allows users to pass a custom mapping for unsupported layer
# types by implementing their own LoRA layers.
def dynamic_dispatch_func(target, adapter_name, lora_config, **kwargs):
new_module = None

if isinstance(target, BaseTunerLayer):
target_base_layer = target.get_base_layer()
else:
target_base_layer = target

for key, custom_cls in lora_config._custom_modules.items():
if isinstance(target_base_layer, key):
new_module = custom_cls(target, adapter_name, **kwargs)
break

return new_module

dispatchers.append(dynamic_dispatch_func)

# avoid eager bnb import
if is_bnb_available():
from .bnb import dispatch_bnb_8bit
188 changes: 188 additions & 0 deletions tests/test_custom_models.py
@@ -3105,3 +3105,191 @@ def timed():
times_separate = logs[-3:]
time_separate = sum(times_separate) / 3
assert time_separate > time_mixed


class TestDynamicDispatch:
# These are tests for the dynamic dispatch feature for LoRA. We create a custom module and a custom LoRA layer
# that targets it.

@pytest.fixture(scope="class")
def custom_module_cls(self):
class MyModule(nn.Module):
# A custom layer that just behaves like an nn.Linear layer but is not an instance of nn.Linear. Therefore,
# it would normally fail to be targeted.
def __init__(self):
super().__init__()
self.in_features = 10
self.out_features = 20
self.weight = nn.Parameter(torch.randn(20, 10))

def forward(self, x):
return nn.functional.linear(x, self.weight)

return MyModule

@pytest.fixture(scope="class")
def custom_lora_cls(self):
from peft.tuners import lora

class MyLora(lora.Linear):
# just re-use the lora.Linear code here
pass

return MyLora

@pytest.fixture(scope="class")
def model_cls(self, custom_module_cls):
class MyModel(nn.Module):
def __init__(self):
super().__init__()
self.lin0 = nn.Linear(10, 10)
self.relu = nn.ReLU()
self.my_module = custom_module_cls()
self.lin1 = nn.Linear(20, 2)

def forward(self, x):
x = self.relu(self.lin0(x))
x = self.relu(self.my_module(x))
x = self.lin1(x)
return x

return MyModel

def test_custom_lora_layer_used(self, custom_module_cls, custom_lora_cls, model_cls):
# check that when we register custom lora layers, they are indeed being used for the intended module
model = model_cls()
config = LoraConfig(target_modules=["lin0", "my_module", "lin1"])
config._register_custom_module({custom_module_cls: custom_lora_cls})

peft_model = get_peft_model(model, config)
assert isinstance(peft_model.base_model.model.my_module, custom_lora_cls)
assert isinstance(peft_model.base_model.model.my_module.base_layer, custom_module_cls)
# sanity check that the other lora layer types are still the default ones
assert not isinstance(peft_model.base_model.model.lin0.base_layer, custom_module_cls)
assert not isinstance(peft_model.base_model.model.lin1.base_layer, custom_module_cls)

def test_training_works(self, model_cls, custom_module_cls, custom_lora_cls):
# check that when we train with custom lora layers, they are indeed updated
model = model_cls()
config = LoraConfig(target_modules=["lin0", "my_module", "lin1"])
config._register_custom_module({custom_module_cls: custom_lora_cls})

peft_model = get_peft_model(model, config)
sd_before = peft_model.state_dict()
inputs = torch.randn(16, 10)
optimizer = torch.optim.SGD(peft_model.parameters(), lr=1e-1)

for _ in range(5):
optimizer.zero_grad()
output = peft_model(inputs)
loss = output.sum() ** 2
loss.backward()
optimizer.step()

sd_after = peft_model.state_dict()
assert not torch.allclose(
sd_before["base_model.model.my_module.lora_A.default.weight"],
sd_after["base_model.model.my_module.lora_A.default.weight"],
)
assert not torch.allclose(
sd_before["base_model.model.my_module.lora_B.default.weight"],
sd_after["base_model.model.my_module.lora_B.default.weight"],
)

def test_saving_and_loading(self, custom_module_cls, custom_lora_cls, model_cls, tmp_path):
# check that we can successfully save and load the custom lora cls
torch.manual_seed(0)
model = model_cls()
config = LoraConfig(target_modules=["lin0", "my_module", "lin1"])
config._register_custom_module({custom_module_cls: custom_lora_cls})

torch.manual_seed(1)
peft_model = get_peft_model(model, config)

inputs = torch.randn(5, 10)
outputs_before = peft_model(inputs) # does not raise

sd_before = peft_model.state_dict()
peft_model.save_pretrained(tmp_path / "lora-custom-module")
del model, peft_model

torch.manual_seed(0) # same seed for base model
model = model_cls()

# custom lora mapping is not persisted at the moment, so as a workaround this is needed
config = LoraConfig.from_pretrained(tmp_path / "lora-custom-module")
config._register_custom_module({custom_module_cls: custom_lora_cls})

# different seed for adapter to ensure it is not identical just because of seed
torch.manual_seed(123)
peft_model = PeftModel.from_pretrained(model, tmp_path / "lora-custom-module", config=config)
assert isinstance(peft_model.base_model.model.my_module, custom_lora_cls)
assert isinstance(peft_model.base_model.model.my_module.base_layer, custom_module_cls)

outputs_after = peft_model(inputs) # does not raise
assert torch.allclose(outputs_before, outputs_after)

sd_after = peft_model.state_dict()
assert sd_before.keys() == sd_after.keys()
for key in sd_before.keys():
assert torch.allclose(sd_before[key], sd_after[key])

def test_override_lora_linear(self, custom_lora_cls):
# in this test, we check if users can override default PEFT behavior by supplying a custom lora class that is
# being used instead of lora.Linear
model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
config = LoraConfig(task_type=TaskType.CAUSAL_LM)
config._register_custom_module({nn.Linear: custom_lora_cls})
peft_model = get_peft_model(model, config)
layers = peft_model.base_model.model.model.decoder.layers
for layer in layers:
assert isinstance(layer.self_attn.v_proj, custom_lora_cls)
assert isinstance(layer.self_attn.q_proj, custom_lora_cls)

def test_custom_lora_layer_issues_warning(self, custom_module_cls, custom_lora_cls, model_cls, recwarn):
# users will get a warning if they target a layer type that is not officially supported
model = model_cls()
config = LoraConfig(target_modules=["lin0", "my_module", "lin1"])
config._register_custom_module({custom_module_cls: custom_lora_cls})

get_peft_model(model, config)
# check warning message
msg = (
"Unsupported layer type '<class 'tests.test_custom_models.TestDynamicDispatch.custom_module_cls."
"<locals>.MyModule'>' encountered, proceed at your own risk."
)
assert str(recwarn.list[-1].message) == msg

def test_target_layer_without_in_features_out_features(self, recwarn):
# It should be possible for users to target layers even if we cannot determine in_features and out_features.
# Those are only needed to initialize the LoRA layer via update_layer, so as long as users take care of that,
# they should be good and not require those attributes to exist
from peft.tuners import lora

class MyModel(nn.Module):
def __init__(self):
super().__init__()
self.lstm = nn.LSTM(10, 20)

class MyLora(nn.Module, lora.LoraLayer):
def __init__(self, base_layer, adapter_name, **kwargs):
super().__init__()
lora.LoraLayer.__init__(self, base_layer, **kwargs)
self._active_adapter = adapter_name

model = MyModel()
# check that in_features and out_features attributes don't exist on LSTM
assert not hasattr(model.lstm, "in_features")
assert not hasattr(model.lstm, "out_features")

config = LoraConfig(target_modules=["lstm"])
config._register_custom_module({nn.LSTM: MyLora})
peft_model = get_peft_model(model, config)

# check that custom LoRA layer is correctly applied
assert isinstance(peft_model.base_model.lstm, MyLora)
assert isinstance(peft_model.base_model.lstm.base_layer, nn.LSTM)

# we should still get a warning message
msg = "Unsupported layer type '<class 'torch.nn.modules.rnn.LSTM'>' encountered, proceed at your own risk."
assert str(recwarn.list[-1].message) == msg
