diff --git a/docs/source/developer_guides/custom_models.md b/docs/source/developer_guides/custom_models.md
index e8ce592cca..9a0941011a 100644
--- a/docs/source/developer_guides/custom_models.md
+++ b/docs/source/developer_guides/custom_models.md
@@ -238,3 +238,73 @@ peft_model.print_trainable_parameters()
 ```python
 print(peft_model.targeted_module_names)
 ```
+
+## Unsupported module types
+
+Methods like LoRA only work if the target modules are supported by PEFT. For example, it's possible to apply LoRA to `nn.Linear` and `nn.Conv2d` layers, but not, for instance, to `nn.LSTM`. If you find that a layer class you want to apply PEFT to is not supported, you can:
+
+ - define a custom mapping to dynamically dispatch custom modules in LoRA
+ - open an [issue](https://github.com/huggingface/peft/issues) and request the feature; if demand for this module type is sufficiently high, the maintainers will implement it or guide you on how to implement it yourself
+
+### Experimental support for dynamic dispatch of custom modules in LoRA
+
+> [!WARNING]
+> This feature is experimental and subject to change, depending on its reception by the community. We will introduce a public and stable API if there is significant demand for it.
+
+PEFT supports an experimental API for custom module types for LoRA. Let's assume you have a LoRA implementation for LSTMs. Normally, you would not be able to tell PEFT to use it, even if it would theoretically work with PEFT. However, this is possible with dynamic dispatch of custom layers.
+
+The experimental API currently looks like this:
+
+```python
+class MyLoraLSTMLayer:
+    ...
+
+base_model = ...  # load the base model that uses LSTMs
+
+# add the LSTM layer names to target_modules
+config = LoraConfig(..., target_modules=["lstm"])
+# define a mapping from base layer type to LoRA layer type
+custom_module_mapping = {nn.LSTM: MyLoraLSTMLayer}
+# register the new mapping
+config._register_custom_module(custom_module_mapping)
+# after registration, create the PEFT model
+peft_model = get_peft_model(base_model, config)
+# do training
+```
+
+> [!TIP]
+> When you call [`get_peft_model`], you will see a warning because PEFT does not recognize the targeted module type. In this case, you can ignore this warning.
+
+By supplying a custom mapping, PEFT first checks the base model's layers against the custom mapping and dispatches to the custom LoRA layer type if there is a match. If there is no match, PEFT checks the built-in LoRA layer types for a match.
+
+Therefore, this feature can also be used to override the existing dispatch logic, e.g. if you want to use your own LoRA layer for `nn.Linear` instead of the one provided by PEFT.
+
+When creating your custom LoRA module, please follow the same rules as the [existing LoRA modules](https://github.com/huggingface/peft/blob/main/src/peft/tuners/lora/layer.py). Some important constraints to consider (a minimal sketch follows the list):
+
+- The custom module should inherit from `nn.Module` and `peft.tuners.lora.layer.LoraLayer`.
+- The `__init__` method of the custom module should have the positional arguments `base_layer` and `adapter_name`. After this, there are additional `**kwargs` that you are free to use or ignore.
+- The learnable parameters should be stored in an `nn.ModuleDict` or `nn.ParameterDict`, where the key corresponds to the name of the specific adapter (remember that a model can have more than one adapter at a time).
+- The name of these learnable parameter attributes should start with `"lora_"`, e.g. `self.lora_new_param = ...`.
+- Some methods are optional, e.g. you only need to implement `merge` and `unmerge` if you want to support weight merging.
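+
+To make these rules concrete, here is a minimal sketch of a custom LoRA module for a linear-like layer, i.e. one that exposes `in_features` and `out_features` and maps inputs the same way as `nn.Linear`. It is only an illustration: it hard-codes a plain LoRA update and omits optional methods such as `merge` and `unmerge`:
+
+```python
+from torch import nn
+
+from peft.tuners.lora.layer import LoraLayer
+
+
+class MyLoraLinearLikeLayer(nn.Module, LoraLayer):
+    # illustrative sketch only; assumes the wrapped base layer behaves like nn.Linear
+    def __init__(self, base_layer, adapter_name, r=8, lora_alpha=16, lora_dropout=0.0, **kwargs):
+        super().__init__()
+        LoraLayer.__init__(self, base_layer, **kwargs)
+        self._active_adapter = adapter_name
+
+        # LoraLayer.__init__ created the empty containers (self.lora_A, self.lora_B, ... as nn.ModuleDict);
+        # populate them for this adapter, keyed by the adapter name
+        self.r[adapter_name] = r
+        self.lora_alpha[adapter_name] = lora_alpha
+        self.scaling[adapter_name] = lora_alpha / r
+        self.lora_dropout[adapter_name] = nn.Dropout(p=lora_dropout) if lora_dropout > 0.0 else nn.Identity()
+        self.lora_A[adapter_name] = nn.Linear(self.in_features, r, bias=False)
+        self.lora_B[adapter_name] = nn.Linear(r, self.out_features, bias=False)
+        nn.init.zeros_(self.lora_B[adapter_name].weight)
+
+    def forward(self, x, *args, **kwargs):
+        result = self.base_layer(x, *args, **kwargs)
+        if self.disable_adapters:
+            return result
+
+        for adapter_name in self.active_adapters:
+            if adapter_name not in self.lora_A:
+                continue
+            lora_A = self.lora_A[adapter_name]
+            lora_B = self.lora_B[adapter_name]
+            dropout = self.lora_dropout[adapter_name]
+            scaling = self.scaling[adapter_name]
+            result = result + lora_B(lora_A(dropout(x))) * scaling
+        return result
+```
+
+If the targeted layer behaves exactly like `nn.Linear`, you can also subclass the existing `lora.Linear` class (`from peft.tuners import lora`) and reuse its implementation, including weight merging.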
+
+Currently, the information about the custom module does not persist when you save the model. When loading the model, you have to register the custom modules again.
+
+```python
+# saving works as always and includes the parameters of the custom modules
+peft_model.save_pretrained("lora-custom-module")
+
+# loading the model later:
+base_model = ...  # load the base model that uses LSTMs
+# load the LoRA config that you saved earlier
+config = LoraConfig.from_pretrained("lora-custom-module")
+# register the custom module again, the same way as the first time
+custom_module_mapping = {nn.LSTM: MyLoraLSTMLayer}
+config._register_custom_module(custom_module_mapping)
+# pass the config instance to from_pretrained:
+peft_model = PeftModel.from_pretrained(base_model, "lora-custom-module", config=config)
+```
+
+If you use this feature and find it useful, or if you encounter problems, let us know by creating an issue or a discussion on GitHub. This allows us to estimate the demand for this feature and add a public API if it is sufficiently high.
diff --git a/src/peft/tuners/lora/config.py b/src/peft/tuners/lora/config.py
index 6e2b7b7452..a94165100c 100644
--- a/src/peft/tuners/lora/config.py
+++ b/src/peft/tuners/lora/config.py
@@ -17,6 +17,8 @@
 from dataclasses import dataclass, field
 from typing import Literal, Optional, Union
 
+from torch import nn
+
 from peft.config import PeftConfig
 from peft.utils import PeftType
 
@@ -309,3 +311,24 @@ def __post_init__(self):
         # convert loftq_config to dict
         if self.loftq_config and not isinstance(self.loftq_config, dict):
             self.loftq_config = vars(self.loftq_config)
+
+        self._custom_modules: Optional[dict[type[nn.Module], type[nn.Module]]] = None
+
+    def _register_custom_module(self, mapping: dict[type[nn.Module], type[nn.Module]]) -> None:
+        """
+        Experimental API to support providing custom LoRA layers.
+
+        This API is subject to change, you should carefully read the docs before deciding to use it:
+
+        https://huggingface.co/docs/peft/developer_guides/custom_models
+
+        To register custom LoRA module types, call this method with a `mapping` argument that is a dict that maps from
+        the target layer type to the custom LoRA layer type. The dict can contain multiple items if you wish to target
+        multiple layer types. The target layer type can be any nn.Module that we currently don't support in PEFT,
+        whether that is an official PyTorch layer type or a custom layer type. The custom LoRA module class has to be
+        implemented by the user and follow the PEFT conventions for LoRA layers.
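+
+        For example, assuming `MyLoraLSTMLayer` is a user-implemented LoRA layer for `nn.LSTM` that follows these
+        conventions, registration could look like this:
+
+        ```py
+        >>> from torch import nn
+        >>> from peft import LoraConfig, get_peft_model
+        >>> # MyLoraLSTMLayer is implemented by the user, it is not provided by PEFT
+        >>> base_model = ...  # a model that contains nn.LSTM layers
+        >>> config = LoraConfig(target_modules=["lstm"])
+        >>> config._register_custom_module({nn.LSTM: MyLoraLSTMLayer})
+        >>> peft_model = get_peft_model(base_model, config)
+        ```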
+
+        """
+        if self._custom_modules is None:
+            self._custom_modules = {}
+        self._custom_modules.update(mapping)
diff --git a/src/peft/tuners/lora/layer.py b/src/peft/tuners/lora/layer.py
index 54acb5cf16..f46f9e5f0d 100644
--- a/src/peft/tuners/lora/layer.py
+++ b/src/peft/tuners/lora/layer.py
@@ -86,7 +86,14 @@ def __init__(self, base_layer: nn.Module, **kwargs) -> None:
             # HQQ layers
             in_features, out_features = base_layer.in_features, base_layer.out_features
         else:
-            raise ValueError(f"Unsupported layer type {type(base_layer)}")
+            # possibly support user provided custom layer types using dynamic dispatch
+            if hasattr(base_layer, "in_features") and hasattr(base_layer, "out_features"):
+                in_features, out_features = base_layer.in_features, base_layer.out_features
+            else:
+                in_features, out_features = None, None
+            warnings.warn(
+                f"Unsupported layer type '{type(base_layer)}' encountered, proceed at your own risk.", UserWarning
+            )
 
         self.in_features = in_features
         self.out_features = out_features
diff --git a/src/peft/tuners/lora/model.py b/src/peft/tuners/lora/model.py
index 6b54bcb1ea..f3be8d95d1 100644
--- a/src/peft/tuners/lora/model.py
+++ b/src/peft/tuners/lora/model.py
@@ -259,6 +259,8 @@ def _replace_module(self, parent, child_name, new_module, child):
                     else child.W_q
                     if hasattr(child, "W_q")
                     else child.weight
+                    if hasattr(child, "weight")
+                    else next(child.parameters())
                 )
                 module.to(weight.device)
 
@@ -289,6 +291,26 @@ def _create_new_module(lora_config, adapter_name, target, **kwargs):
         # because the first match is always used. Therefore, the default layers should be checked last.
         dispatchers = []
 
+        if lora_config._custom_modules:
+            # Experimental custom LoRA module support. Allows users to pass a custom mapping for unsupported layer
+            # types by implementing their own LoRA layers.
+            def dynamic_dispatch_func(target, adapter_name, lora_config, **kwargs):
+                new_module = None
+
+                if isinstance(target, BaseTunerLayer):
+                    target_base_layer = target.get_base_layer()
+                else:
+                    target_base_layer = target
+
+                for key, custom_cls in lora_config._custom_modules.items():
+                    if isinstance(target_base_layer, key):
+                        new_module = custom_cls(target, adapter_name, **kwargs)
+                        break
+
+                return new_module
+
+            dispatchers.append(dynamic_dispatch_func)
+
         # avoid eager bnb import
         if is_bnb_available():
             from .bnb import dispatch_bnb_8bit
diff --git a/tests/test_custom_models.py b/tests/test_custom_models.py
index 06eb097b36..e21405fae5 100644
--- a/tests/test_custom_models.py
+++ b/tests/test_custom_models.py
@@ -3105,3 +3105,191 @@ def timed():
         times_separate = logs[-3:]
         time_separate = sum(times_separate) / 3
         assert time_separate > time_mixed
+
+
+class TestDynamicDispatch:
+    # These are tests for the dynamic dispatch feature for LoRA. We create a custom module and a custom LoRA layer
+    # that targets it.
+
+    @pytest.fixture(scope="class")
+    def custom_module_cls(self):
+        class MyModule(nn.Module):
+            # A custom layer that just behaves like an nn.Linear layer but is not an instance of nn.Linear. Therefore,
+            # it would normally fail to be targeted.
+            def __init__(self):
+                super().__init__()
+                self.in_features = 10
+                self.out_features = 20
+                self.weight = nn.Parameter(torch.randn(20, 10))
+
+            def forward(self, x):
+                return nn.functional.linear(x, self.weight)
+
+        return MyModule
+
+    @pytest.fixture(scope="class")
+    def custom_lora_cls(self):
+        from peft.tuners import lora
+
+        class MyLora(lora.Linear):
+            # just re-use the lora.Linear code here
+            pass
+
+        return MyLora
+
+    @pytest.fixture(scope="class")
+    def model_cls(self, custom_module_cls):
+        class MyModel(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.lin0 = nn.Linear(10, 10)
+                self.relu = nn.ReLU()
+                self.my_module = custom_module_cls()
+                self.lin1 = nn.Linear(20, 2)
+
+            def forward(self, x):
+                x = self.relu(self.lin0(x))
+                x = self.relu(self.my_module(x))
+                x = self.lin1(x)
+                return x
+
+        return MyModel
+
+    def test_custom_lora_layer_used(self, custom_module_cls, custom_lora_cls, model_cls):
+        # check that when we register custom lora layers, they are indeed being used for the intended module
+        model = model_cls()
+        config = LoraConfig(target_modules=["lin0", "my_module", "lin1"])
+        config._register_custom_module({custom_module_cls: custom_lora_cls})
+
+        peft_model = get_peft_model(model, config)
+        assert isinstance(peft_model.base_model.model.my_module, custom_lora_cls)
+        assert isinstance(peft_model.base_model.model.my_module.base_layer, custom_module_cls)
+        # sanity check that the other lora layer types are still the default ones
+        assert not isinstance(peft_model.base_model.model.lin0.base_layer, custom_module_cls)
+        assert not isinstance(peft_model.base_model.model.lin1.base_layer, custom_module_cls)
+
+    def test_training_works(self, model_cls, custom_module_cls, custom_lora_cls):
+        # check that when we train with custom lora layers, they are indeed updated
+        model = model_cls()
+        config = LoraConfig(target_modules=["lin0", "my_module", "lin1"])
+        config._register_custom_module({custom_module_cls: custom_lora_cls})
+
+        peft_model = get_peft_model(model, config)
+        # clone the tensors, otherwise the in-place optimizer updates would also be reflected in sd_before
+        sd_before = {k: v.clone() for k, v in peft_model.state_dict().items()}
+        inputs = torch.randn(16, 10)
+        optimizer = torch.optim.SGD(peft_model.parameters(), lr=1e-1)
+
+        for _ in range(5):
+            optimizer.zero_grad()
+            output = peft_model(inputs)
+            loss = output.sum() ** 2
+            loss.backward()
+            optimizer.step()
+
+        sd_after = peft_model.state_dict()
+        assert not torch.allclose(
+            sd_before["base_model.model.my_module.lora_A.default.weight"],
+            sd_after["base_model.model.my_module.lora_A.default.weight"],
+        )
+        assert not torch.allclose(
+            sd_before["base_model.model.my_module.lora_B.default.weight"],
+            sd_after["base_model.model.my_module.lora_B.default.weight"],
+        )
+
+    def test_saving_and_loading(self, custom_module_cls, custom_lora_cls, model_cls, tmp_path):
+        # check that we can successfully save and load the custom lora cls
+        torch.manual_seed(0)
+        model = model_cls()
+        config = LoraConfig(target_modules=["lin0", "my_module", "lin1"])
+        config._register_custom_module({custom_module_cls: custom_lora_cls})
+
+        torch.manual_seed(1)
+        peft_model = get_peft_model(model, config)
+
+        inputs = torch.randn(5, 10)
+        outputs_before = peft_model(inputs)  # does not raise
+
+        sd_before = peft_model.state_dict()
+        peft_model.save_pretrained(tmp_path / "lora-custom-module")
+        del model, peft_model
+
+        torch.manual_seed(0)  # same seed for base model
+        model = model_cls()
+
+        # custom lora mapping is not persisted at the moment, so as a workaround this is needed
+        config = LoraConfig.from_pretrained(tmp_path / "lora-custom-module")
+        config._register_custom_module({custom_module_cls: custom_lora_cls})
+
+        # different seed for adapter to ensure it is not identical just because of seed
+        torch.manual_seed(123)
+        peft_model = PeftModel.from_pretrained(model, tmp_path / "lora-custom-module", config=config)
+        assert isinstance(peft_model.base_model.model.my_module, custom_lora_cls)
+        assert isinstance(peft_model.base_model.model.my_module.base_layer, custom_module_cls)
+
+        outputs_after = peft_model(inputs)  # does not raise
+        assert torch.allclose(outputs_before, outputs_after)
+
+        sd_after = peft_model.state_dict()
+        assert sd_before.keys() == sd_after.keys()
+        for key in sd_before.keys():
+            assert torch.allclose(sd_before[key], sd_after[key])
+
+    def test_override_lora_linear(self, custom_lora_cls):
+        # in this test, we check if users can override default PEFT behavior by supplying a custom lora class that is
+        # being used instead of lora.Linear
+        model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
+        config = LoraConfig(task_type=TaskType.CAUSAL_LM)
+        config._register_custom_module({nn.Linear: custom_lora_cls})
+        peft_model = get_peft_model(model, config)
+        layers = peft_model.base_model.model.model.decoder.layers
+        for layer in layers:
+            assert isinstance(layer.self_attn.v_proj, custom_lora_cls)
+            assert isinstance(layer.self_attn.q_proj, custom_lora_cls)
+
+    def test_custom_lora_layer_issues_warning(self, custom_module_cls, custom_lora_cls, model_cls, recwarn):
+        # users will get a warning if they target a layer type that is not officially supported
+        model = model_cls()
+        config = LoraConfig(target_modules=["lin0", "my_module", "lin1"])
+        config._register_custom_module({custom_module_cls: custom_lora_cls})
+
+        get_peft_model(model, config)
+        # check warning message; format the class so the test does not hard-code the full qualified name
+        msg = (
+            f"Unsupported layer type '{custom_module_cls}' encountered, proceed at your own risk."
+        )
+        assert str(recwarn.list[-1].message) == msg
+
+    def test_target_layer_without_in_features_out_features(self, recwarn):
+        # It should be possible for users to target layers even if we cannot determine in_features and out_features.
+        # Those are only needed to initialize the LoRA layer via update_layer, so as long as users take care of that,
+        # they should be good and not require those attributes to exist
+        from peft.tuners import lora
+
+        class MyModel(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.lstm = nn.LSTM(10, 20)
+
+        class MyLora(nn.Module, lora.LoraLayer):
+            def __init__(self, base_layer, adapter_name, **kwargs):
+                super().__init__()
+                lora.LoraLayer.__init__(self, base_layer, **kwargs)
+                self._active_adapter = adapter_name
+
+        model = MyModel()
+        # check that in_features and out_features attributes don't exist on LSTM
+        assert not hasattr(model.lstm, "in_features")
+        assert not hasattr(model.lstm, "out_features")
+
+        config = LoraConfig(target_modules=["lstm"])
+        config._register_custom_module({nn.LSTM: MyLora})
+        peft_model = get_peft_model(model, config)
+
+        # check that custom LoRA layer is correctly applied
+        assert isinstance(peft_model.base_model.lstm, MyLora)
+        assert isinstance(peft_model.base_model.lstm.base_layer, nn.LSTM)
+
+        # we should still get a warning message; nn.LSTM formats as "<class 'torch.nn.modules.rnn.LSTM'>"
+        msg = f"Unsupported layer type '{nn.LSTM}' encountered, proceed at your own risk."
+        assert str(recwarn.list[-1].message) == msg