diff --git a/src/adapters/methods/lora.py b/src/adapters/methods/lora.py
index 85e84478b5..a8cf017955 100644
--- a/src/adapters/methods/lora.py
+++ b/src/adapters/methods/lora.py
@@ -4,7 +4,7 @@
 # Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
 # ------------------------------------------------------------------------------------------
 import math
-from typing import Dict, List, Union
+from typing import Dict, List, NamedTuple, Optional, Union
 
 import torch
 import torch.nn as nn
@@ -13,9 +13,9 @@
 from transformers.configuration_utils import PretrainedConfig
 from transformers.pytorch_utils import Conv1D
 
-from ..composition import AdapterCompositionBlock
+from ..composition import AdapterCompositionBlock, Average, BatchSplit, Stack
 from ..configuration import LoRAConfig, ModelAdaptersConfig
-from .adapter_layer_base import AdapterLayerBase
+from .adapter_layer_base import AdapterLayerBase, ComposableAdapterLayerBase
 
 
 class LoRA(nn.Module):
@@ -75,14 +75,16 @@ def com_inv(self, weights: torch.Tensor, added: torch.Tensor) -> torch.Tensor:
         """Inverts the composition operation between existing and injected weights."""
         return weights - added * self.scaling
 
-    def forward(self, hidden_states: torch.Tensor, input_states: torch.Tensor):
-        delta_w = self.lora_dropout(input_states) @ torch.t(self.lora_A) @ torch.t(self.lora_B)
+    def forward(self, hidden_states: Optional[torch.Tensor], layer_input: torch.Tensor):
+        if hidden_states is None:
+            hidden_states = layer_input
+        hidden_states = self.lora_dropout(hidden_states) @ torch.t(self.lora_A) @ torch.t(self.lora_B)
         if self.use_gating:
-            gate = torch.sigmoid(self.gate(input_states))
+            gate = torch.sigmoid(self.gate(layer_input))
             gate = torch.mean(gate, dim=1).unsqueeze(-1)
+            hidden_states = hidden_states * gate
         else:
             gate = None
-        hidden_states = self.com(hidden_states, delta_w, scaling=gate)
 
         return hidden_states, gate
 
@@ -139,14 +141,18 @@ def com_inv(self, weights: torch.Tensor, added: torch.Tensor) -> torch.Tensor:
         """Inverts the composition operation between existing and injected weights."""
         return weights / (added * self.scaling)
 
-    def forward(self, hidden_states: torch.Tensor, input_states: torch.Tensor):
-        delta_w = self.lora_B.view(1, 1, -1)
+    def forward(self, hidden_states: Optional[torch.Tensor], layer_input: torch.Tensor):
+        scaling_vector = self.lora_B.view(1, 1, -1).repeat(layer_input.shape[0], 1, 1)
+        if hidden_states is None:
+            hidden_states = scaling_vector
+        else:
+            hidden_states = hidden_states * scaling_vector
         if self.use_gating:
-            gate = torch.sigmoid(self.gate(input_states))
+            gate = torch.sigmoid(self.gate(layer_input))
             gate = torch.mean(gate, dim=1).unsqueeze(-1)
+            hidden_states = hidden_states * gate
         else:
             gate = None
-        hidden_states = self.com(hidden_states, delta_w, scaling=gate)
 
         return hidden_states, gate
 
@@ -252,9 +258,25 @@ def get_adapter(self, adapter_name: str) -> nn.Module:
         return None
 
 
-class LoRALinear(LoRALayer, nn.Linear):
+class LoRAState(NamedTuple):
+    """Models the input and output states of a LoRA layer.
+
+    Args:
+        layer_input (torch.Tensor): The input states to the adapted layer.
+        hidden_states (Optional[torch.Tensor]):
+            The hidden states of the adaptation module. These can be None before passing through the first LoRA/ IA3
+            module.
+        layer_output (torch.Tensor): The output states of the original layer without adaptation.
+    """
+
+    layer_input: torch.Tensor
+    hidden_states: Optional[torch.Tensor]
+    layer_output: torch.Tensor
+
+
+class LoRALinear(LoRALayer, ComposableAdapterLayerBase, nn.Linear):
     """
-    LoRA implementation for Linear layer.
+    LoRA implementation for Linear layer. This layer supports composition.
 
     Args:
         fan_in_fan_out (bool, optional):
@@ -263,6 +285,9 @@ class LoRALinear(LoRALayer, nn.Linear):
 
     """
 
+    # TODO: enable parallel composition for LoRA
+    supported_compositions = [Stack, BatchSplit, Average]
+
     def __init__(
         self,
         in_features: int,
@@ -356,29 +381,68 @@ def merge_adapter(self, name: str):
             elif self.merged != name:
                 raise ValueError("LoRALayer already has a merged LoRA module. Please reset it first.")
 
-    def forward(self, x: torch.Tensor):
-        input_states = x
-        # result shape: <batch_size> x <seq_len> x <head_dim>
+    def vslice(self, state: LoRAState, slice_obj: slice) -> LoRAState:
+        return LoRAState(
+            state.layer_input[slice_obj],
+            state.hidden_states[slice_obj] if state.hidden_states is not None else None,
+            state.layer_output[slice_obj],
+        )
+
+    def pad_and_concat(self, states: List[LoRAState]) -> LoRAState:
+        return LoRAState(
+            torch.cat([s.layer_input for s in states], dim=0),
+            torch.cat([s.hidden_states for s in states], dim=0) if states[0].hidden_states is not None else None,
+            torch.cat([s.layer_output for s in states], dim=0),
+        )
+
+    def repeat(self, state: LoRAState, channels: int) -> LoRAState:
+        return LoRAState(
+            state.layer_input.repeat(channels, 1, 1),
+            state.hidden_states.repeat(channels, 1, 1) if state.hidden_states is not None else None,
+            state.layer_output.repeat(channels, 1, 1),
+        )
+
+    def mean(self, states: List[LoRAState], weights: torch.Tensor) -> LoRAState:
+        return LoRAState(
+            states[0].layer_input,
+            torch.mean(torch.stack([s.hidden_states for s in states], dim=0) * weights, dim=0)
+            if states[0].hidden_states is not None
+            else None,
+            states[0].layer_output,
+        )
+
+    def compose_single(self, adapter_setup: str, state: LoRAState, lvl: int = 0) -> LoRAState:
+        lora = self.loras[adapter_setup]
+        hidden_states, gate = lora(state.hidden_states, state.layer_input)
+        if gate is not None:
+            self._store_gating_score(adapter_setup, gate)
+
+        return state._replace(hidden_states=hidden_states)
+
+    def forward(self, input_states: torch.Tensor):
         weight = torch.transpose(self.weight, -2, -1) if self.fan_in_fan_out else self.weight
-        x = F.linear(x, weight, bias=self.bias)
+        # result shape: <batch_size> x <seq_len> x <head_dim>
+        layer_output = F.linear(input_states, weight, bias=self.bias)
 
         if not self.merged:
             adapter_setup = self.get_active_setup()
             if adapter_setup is not None:
-                if len(adapter_setup) == 1:
-                    lora = self.loras[adapter_setup[0]]
-                    x, gate = lora(x, input_states)
-                    if gate is not None:
-                        self._store_gating_score(adapter_setup[0], gate)
-                else:
-                    raise ValueError(f"Invalid adapter setup. Cannot use {adapter_setup} with LoRA.")
+                state = LoRAState(input_states, None, layer_output)
+                state = self.compose(adapter_setup, state)
+                _, hidden_states, layer_output = state
+
+                last_lora = self.loras[adapter_setup.last()]
+                layer_output = last_lora.com(
+                    layer_output, hidden_states, scaling=1.0
+                )  # scaling already applied in compose
 
-        return x
+        return layer_output
 
 
 class LoRAMergedLinear(LoRALayer, nn.Linear):
     """
-    LoRA implementation for merged attention layer layer.
+    LoRA implementation for merged attention layer, as used by some model implementations (e.g. GPT-2). This layer
+    currently does not support composition.
 
     Args:
         fan_in_fan_out (bool, optional):
diff --git a/tests_adapters/composition/test_adapter_composition.py b/tests_adapters/composition/test_adapter_composition.py
index 2670488cb9..5374732f35 100644
--- a/tests_adapters/composition/test_adapter_composition.py
+++ b/tests_adapters/composition/test_adapter_composition.py
@@ -3,7 +3,7 @@
 import torch
 
 import adapters
-from adapters import PrefixTuningConfig, SeqBnConfig
+from adapters import IA3Config, LoRAConfig, PrefixTuningConfig, SeqBnConfig
 from adapters.composition import Average, BatchSplit, Fuse, Parallel, Split, Stack, parse_composition
 from tests.test_modeling_common import ids_tensor
 from transformers import BertConfig, BertForSequenceClassification
@@ -234,3 +234,17 @@ class PrefixTuningCompositionTest(AdapterCompositionTest):
 
     def get_adapter_config(self):
         return PrefixTuningConfig()
+
+
+class LoRACompositionTest(AdapterCompositionTest):
+    unsupported_blocks = [Split, Fuse, Parallel]
+
+    def get_adapter_config(self):
+        return LoRAConfig(init_weights="bert")
+
+
+class IA3CompositionTest(AdapterCompositionTest):
+    unsupported_blocks = [Split, Fuse, Parallel]
+
+    def get_adapter_config(self):
+        return IA3Config()
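
Usage sketch (illustrative, not part of the diff): with this change applied, LoRA and IA3 adapters can be activated inside Stack, BatchSplit, and Average blocks, mirroring LoRALinear.supported_compositions and the new composition tests. The model size, adapter names, and inputs below are assumptions chosen for demonstration only.

    import torch
    import adapters
    import adapters.composition as ac
    from adapters import LoRAConfig
    from transformers import BertConfig, BertModel

    # Small BERT model; the sizes are arbitrary for this sketch.
    model = BertModel(
        BertConfig(hidden_size=32, num_hidden_layers=2, num_attention_heads=4, intermediate_size=64)
    )
    adapters.init(model)

    # Two LoRA adapters; init_weights="bert" matches LoRACompositionTest above.
    model.add_adapter("a", config=LoRAConfig(init_weights="bert"))
    model.add_adapter("b", config=LoRAConfig(init_weights="bert"))

    # Stack is one of the composition blocks enabled by this PR ([Stack, BatchSplit, Average]).
    model.active_adapters = ac.Stack("a", "b")

    input_ids = torch.randint(0, model.config.vocab_size, (2, 16))
    outputs = model(input_ids)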