From 42fff1e0fe6ec0e3bf4f60c457a35d4edaadb402 Mon Sep 17 00:00:00 2001 From: calpt Date: Sat, 18 Nov 2023 14:57:18 +0100 Subject: [PATCH 1/2] =?UTF-8?q?Add=20Composition=20Support=20to=20LoRA=20a?= =?UTF-8?q?nd=20(IA)=C2=B3=20(#598)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Follow-up to #591. This PR provides initial support for adapter composition in LoRA & (IA)³ modules. Currently LoRA & (IA)³ don't support composition. With this PR, the following blocks will be supported: **Stack, BatchSplit, Average, Parallel** Additionally, the LoRA implementation is refactored a bit in an effort to make it cleaner. ### Limitations - Split & Fuse compositions are **not** supported - LoRA/ (IA)³ composition is **not** supported for models using the `LoRAMergedLinear` implementation. These currently are: **GPT-2, DeBERTa (v1)** --- docs/adapter_composition.md | 10 +- docs/index.rst | 1 - src/adapters/composition.py | 17 +- src/adapters/methods/adapter_layer_base.py | 17 +- src/adapters/methods/lora.py | 309 ++++++++++++------ src/adapters/models/albert/mixin_albert.py | 2 +- src/adapters/models/albert/modeling_albert.py | 4 +- src/adapters/models/bart/mixin_bart.py | 2 +- src/adapters/models/bart/modeling_bart.py | 7 +- src/adapters/models/beit/mixin_beit.py | 2 +- src/adapters/models/bert/mixin_bert.py | 2 +- src/adapters/models/bert/modeling_bert.py | 4 +- .../modeling_bert_generation.py | 4 +- src/adapters/models/clip/mixin_clip.py | 2 +- src/adapters/models/deberta/mixin_deberta.py | 2 +- .../models/deberta/modeling_deberta.py | 5 +- .../models/deberta_v2/mixin_deberta_v2.py | 2 +- .../models/deberta_v2/modeling_deberta_v2.py | 5 +- .../models/distilbert/mixin_distilbert.py | 2 +- .../models/distilbert/modeling_distilbert.py | 5 +- .../models/electra/modeling_electra.py | 4 +- src/adapters/models/gpt2/mixin_gpt2.py | 3 +- src/adapters/models/gptj/mixin_gptj.py | 2 +- src/adapters/models/gptj/modeling_gptj.py | 5 +- src/adapters/models/llama/mixin_llama.py | 2 +- src/adapters/models/llama/modeling_llama.py | 11 +- src/adapters/models/mbart/modeling_mbart.py | 7 +- .../models/roberta/modeling_roberta.py | 4 +- src/adapters/models/t5/mixin_t5.py | 2 +- src/adapters/models/t5/modeling_t5.py | 7 +- src/adapters/models/vit/mixin_vit.py | 2 +- src/adapters/models/vit/modeling_vit.py | 4 +- .../xlm_roberta/modeling_xlm_roberta.py | 4 +- src/adapters/models/xmod/modeling_xmod.py | 4 +- .../composition/test_adapter_composition.py | 22 +- 35 files changed, 344 insertions(+), 143 deletions(-) diff --git a/docs/adapter_composition.md b/docs/adapter_composition.md index 05f85f3fdb..e8b6bf3c10 100644 --- a/docs/adapter_composition.md +++ b/docs/adapter_composition.md @@ -42,14 +42,16 @@ The following table gives an overview on the supported composition blocks and th | Block | Bottleneck
Adapters | Prefix
Tuning | Compacter | LoRA | (IA)³ |
| --- | --- | --- | --- | --- | --- |
-| [`Stack`](#stack) | ✅ | ✅ | ✅ | | |
+| [`Stack`](#stack) | ✅ | ✅ | ✅ | ✅(*) | ✅(*) |
 | [`Fuse`](#fuse) | ✅ | | ✅ | | |
 | [`Split`](#split) | ✅ | | ✅ | | |
-| [`BatchSplit`](#batchsplit) | ✅ | ✅ | ✅ | | |
-| [`Parallel`](#parallel) | ✅ | ✅ | ✅ | | |
-| [Output averaging](#output-averaging) | ✅ | | ✅ | | |
+| [`BatchSplit`](#batchsplit) | ✅ | ✅ | ✅ | ✅(*) | ✅(*) |
+| [`Parallel`](#parallel) | ✅ | ✅ | ✅ | ✅(*) | ✅(*) |
+| [Output averaging](#output-averaging) | ✅ | | ✅ | ✅(*) | ✅(*) |
 | [Parameter averaging](#parameter-averaging) | ✅ | ✅ | ✅ | ✅ | ✅ |
 
+(*) except for models using the `LoRAMergedLinear` implementation: DeBERTa (v1) and GPT-2.
+
 Next, we present all composition blocks in more detail.
 
 ## `Stack`
diff --git a/docs/index.rst b/docs/index.rst
index fdddf228ec..b3685d8d28 100644
--- a/docs/index.rst
+++ b/docs/index.rst
@@ -94,7 +94,6 @@ Currently, we support the PyTorch versions of all models as listed on the `Model
    classes/adapter_config
    classes/model_adapters_config
-   classes/adapter_modules
    classes/adapter_layer
    classes/model_mixins
    classes/adapter_training
diff --git a/src/adapters/composition.py b/src/adapters/composition.py
index 5899b113d6..937fae2685 100644
--- a/src/adapters/composition.py
+++ b/src/adapters/composition.py
@@ -1,6 +1,8 @@
 import itertools
 from collections.abc import Sequence
-from typing import List, Optional, Set, Union
+from typing import List, Optional, Set, Tuple, Union
+
+import torch
 
 
 class AdapterCompositionBlock(Sequence):
@@ -242,3 +244,16 @@ def adjust_tensors_for_parallel_(hidden_states, *tensors):
             repeats[0] = hidden_states.shape[0] // tensor.shape[0]
             new_tensor = tensor.repeat(*repeats)
             tensor.set_(new_tensor)
+
+
+def match_attn_matrices_for_parallel(query, key, value) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    """
+    Matches the shapes of query, key and value matrices for parallel composition.
+    """
+    max_bsz = max(query.shape[0], key.shape[0], value.shape[0])
+
+    query = query.repeat(max_bsz // query.shape[0], *([1] * len(query.shape[1:])))
+    key = key.repeat(max_bsz // key.shape[0], *([1] * len(key.shape[1:])))
+    value = value.repeat(max_bsz // value.shape[0], *([1] * len(value.shape[1:])))
+
+    return query, key, value
diff --git a/src/adapters/methods/adapter_layer_base.py b/src/adapters/methods/adapter_layer_base.py
index b89b75cb14..79d18500ec 100644
--- a/src/adapters/methods/adapter_layer_base.py
+++ b/src/adapters/methods/adapter_layer_base.py
@@ -150,10 +150,13 @@ class ComposableAdapterLayerBase(AdapterLayerBase):
     Base class for all adapter methods that support composition.
 
     Make sure the 'adapter_modules_name' and 'supported_compositions' attributes as well as all abstract methods are
-    overriden in derived classes.
+    overridden in derived classes. 'allow_multi_parallelize' can be set to True to allow inputs to be parallelized
+    independently multiple times. This is useful when there are multiple parallel input flows through an adapter layer
+    (e.g. in LoRA).
     """
 
     supported_compositions = []
+    allow_multi_parallelize = False
 
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
@@ -382,15 +385,23 @@ def compose_parallel(self, adapter_setup: Parallel, state: NamedTuple, lvl: int
             orig_batch_size = self._bsz(state)
             state = self.repeat(state, adapter_setup.parallel_channels)
             context.adapters_parallelized = True
+            context.original_batch_size = orig_batch_size
         else:
+            bsz = self._bsz(state)
+            # If the input was already parallelized, we can parallelize it again.
+            # This is useful e.g.
for LoRA, where attention matrices are parallelized independently. + if self.allow_multi_parallelize and bsz == getattr(context, "original_batch_size", -1): + state = self.repeat(state, adapter_setup.parallel_channels) + orig_batch_size = bsz # The base model should handle replication of input. # Therefore, we assume the (replicated) input batch to be divisible by the number of parallel channels. - if self._bsz(state) % adapter_setup.parallel_channels != 0: + elif bsz % adapter_setup.parallel_channels != 0: raise ValueError( "The total input batch size in a Parallel adapter block must be divisible by the number of" " parallel channels." ) - orig_batch_size = self._bsz(state) // adapter_setup.parallel_channels + else: + orig_batch_size = bsz // adapter_setup.parallel_channels state = self.pre_block(adapter_setup, state) diff --git a/src/adapters/methods/lora.py b/src/adapters/methods/lora.py index a4c66c830c..db987a7853 100644 --- a/src/adapters/methods/lora.py +++ b/src/adapters/methods/lora.py @@ -3,8 +3,9 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License (MIT). See LICENSE in the repo root for license information. # ------------------------------------------------------------------------------------------ +import logging import math -from typing import Dict, List, Union +from typing import Dict, List, NamedTuple, Optional, Union import torch import torch.nn as nn @@ -13,9 +14,12 @@ from transformers.configuration_utils import PretrainedConfig from transformers.pytorch_utils import Conv1D -from ..composition import AdapterCompositionBlock +from ..composition import AdapterCompositionBlock, Average, BatchSplit, Parallel, Stack from ..configuration import LoRAConfig, ModelAdaptersConfig -from .adapter_layer_base import AdapterLayerBase +from .adapter_layer_base import AdapterLayerBase, ComposableAdapterLayerBase + + +logger = logging.getLogger(__name__) class LoRA(nn.Module): @@ -27,6 +31,7 @@ def __init__( gating_heads: int = 1, ): super().__init__() + assert config.composition_mode == "add", "LoRA module only supports composition_mode='add'." 
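+        # Only the additive update remains in this class: delta_w = lora_B @ lora_A, scaled by lora_alpha / r.
+        # The multiplicative composition_mode="scale" is handled by the separate IA3 class below.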
self.r = config.r self.lora_alpha = config.alpha self.composition_mode = config.composition_mode @@ -39,58 +44,126 @@ def __init__( self.lora_dropout = lambda x: x # Actual trainable parameters - if self.r > 1 and self.composition_mode == "scale": - raise ValueError("Can only use composition_mode='scale' when r == 1.") - if self.r > 0: - if self.composition_mode == "add": - self.lora_A = nn.Parameter(torch.zeros(lora_A_shape)) - self.lora_B = nn.Parameter(torch.zeros(lora_B_shape)) - self.scaling = self.lora_alpha / self.r - - if self.use_gating: - self.gate = nn.Linear(lora_A_shape[-1], gating_heads) - - if config.init_weights == "lora": - # initialize A the same way as the default for nn.Linear and B to zero - if self.composition_mode == "add": - nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5)) - nn.init.zeros_(self.lora_B) - if self.use_gating: - nn.init.normal_(self.gate.weight, std=0.02) - elif config.init_weights == "bert": - if self.composition_mode == "add": - nn.init.normal_(self.lora_A, std=0.02) - nn.init.normal_(self.lora_B, std=0.02) - if self.use_gating: - nn.init.normal_(self.gate.weight, std=0.02) - elif config.init_weights == "ia3": - if self.composition_mode == "add": - nn.init.ones_(self.lora_A) - nn.init.ones_(self.lora_B) - if self.use_gating: - nn.init.normal_(self.gate.weight, std=0.02) - else: - raise ValueError("Unknown init_weights type: {}".format(config.init_weights)) + self.lora_A = nn.Parameter(torch.zeros(lora_A_shape)) + self.lora_B = nn.Parameter(torch.zeros(lora_B_shape)) + self.scaling = self.lora_alpha / self.r + + # For compatibility with (IA)^3, allow all init_weights types here. + # Usually should be "lora". + if config.init_weights == "lora": + # initialize A the same way as the default for nn.Linear and B to zero + nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5)) + nn.init.zeros_(self.lora_B) + elif config.init_weights == "bert": + nn.init.normal_(self.lora_A, std=0.02) + nn.init.normal_(self.lora_B, std=0.02) + elif config.init_weights == "ia3": + nn.init.ones_(self.lora_A) + nn.init.ones_(self.lora_B) + else: + raise ValueError("Unknown init_weights type: {}".format(config.init_weights)) + + if self.use_gating: + self.gate = nn.Linear(lora_A_shape[-1], gating_heads) + nn.init.normal_(self.gate.weight, std=0.02) + + @property + def delta_w(self) -> torch.Tensor: + return self.lora_B @ self.lora_A def com(self, weights: torch.Tensor, added: torch.Tensor, scaling=None) -> torch.Tensor: """Performs the composition operation between existing and injected weights.""" if scaling is None: scaling = self.scaling - if self.composition_mode == "add": - return weights + added * scaling - elif self.composition_mode == "scale": - return weights * (added * scaling) + return weights + added * scaling + + def com_inv(self, weights: torch.Tensor, added: torch.Tensor) -> torch.Tensor: + """Inverts the composition operation between existing and injected weights.""" + return weights - added * self.scaling + + def forward(self, hidden_states: Optional[torch.Tensor], layer_input: torch.Tensor): + if hidden_states is None: + hidden_states = layer_input + hidden_states = self.lora_dropout(hidden_states) @ torch.t(self.lora_A) @ torch.t(self.lora_B) + if self.use_gating: + gate = torch.sigmoid(self.gate(layer_input)) + gate = torch.mean(gate, dim=1).unsqueeze(-1) + hidden_states = hidden_states * gate + else: + gate = None + + return hidden_states, gate + + +class IA3(nn.Module): + def __init__( + self, + lora_A_shape, + lora_B_shape, + config: LoRAConfig, + 
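+        # (IA)^3 trains no A matrix; lora_A_shape is accepted for interface parity with LoRA and only
+        # sizes the input of the optional gating layer created below.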
gating_heads: int = 1,
+    ):
+        super().__init__()
+        assert config.composition_mode == "scale", "IA3 module only supports composition_mode='scale'."
+        if config.r > 1:
+            raise ValueError("Can only use composition_mode='scale' when r == 1.")
+        self.r = config.r
+        self.lora_alpha = config.alpha
+        self.composition_mode = config.composition_mode
+        self.attn_matrices = config.attn_matrices
+        self.use_gating = config.use_gating
+        # Optional dropout
+        if config.dropout > 0.0:
+            raise ValueError("IA3 module does not support dropout.")
+
+        # Actual trainable parameters
+        self.lora_B = nn.Parameter(torch.zeros(lora_B_shape))
+        self.scaling = self.lora_alpha
+
+        # For compatibility with LoRA, allow all init_weights types here.
+        # Usually should be "ia3".
+        if config.init_weights == "lora":
+            logger.warning("(IA)^3 module initialized with LoRA zero init. Ignore if this is intended.")
+            nn.init.zeros_(self.lora_B)
+        elif config.init_weights == "bert":
+            nn.init.normal_(self.lora_B, std=0.02)
+        elif config.init_weights == "ia3":
+            nn.init.ones_(self.lora_B)
         else:
-            raise ValueError("Invalid composition mode.")
+            raise ValueError("Unknown init_weights type: {}".format(config.init_weights))
+
+        if self.use_gating:
+            self.gate = nn.Linear(lora_A_shape[-1], gating_heads)
+            nn.init.normal_(self.gate.weight, std=0.02)
+
+    @property
+    def delta_w(self) -> torch.Tensor:
+        return self.lora_B
+
+    def com(self, weights: torch.Tensor, added: torch.Tensor, scaling=None) -> torch.Tensor:
+        """Performs the composition operation between existing and injected weights."""
+        if scaling is None:
+            scaling = self.scaling
+        return weights * (added * scaling)
 
     def com_inv(self, weights: torch.Tensor, added: torch.Tensor) -> torch.Tensor:
         """Inverts the composition operation between existing and injected weights."""
-        if self.composition_mode == "add":
-            return weights - added * self.scaling
-        elif self.composition_mode == "scale":
-            return weights / (added * self.scaling)
+        return weights / (added * self.scaling)
+
+    def forward(self, hidden_states: Optional[torch.Tensor], layer_input: torch.Tensor):
+        scaling_vector = self.lora_B.view(1, 1, -1).repeat(layer_input.shape[0], 1, 1)
+        if hidden_states is None:
+            hidden_states = scaling_vector
+        else:
+            hidden_states = hidden_states * scaling_vector
+        if self.use_gating:
+            gate = torch.sigmoid(self.gate(layer_input))
+            gate = torch.mean(gate, dim=1).unsqueeze(-1)
+            hidden_states = hidden_states * gate
         else:
-            raise ValueError("Invalid composition mode.")
+            gate = None
+
+        return hidden_states, gate
 
 
 class LoRALayer(AdapterLayerBase):
@@ -107,7 +180,7 @@ def __init__(
 
         self.merged = False
 
-    def get_n_heads(self, lora: Union[LoRA, LoRAConfig]):
+    def get_n_heads(self, lora: Union[LoRA, IA3, LoRAConfig]):
         return 1
 
     def _check_lora_location(self, config: LoRAConfig):
@@ -125,7 +198,13 @@ def add_adapter(self, adapter_name: str, layer_idx: int) -> bool:
             location_key=self.location_key,
         )
         if lora_config is not None and self._check_lora_location(lora_config):
-            lora = LoRA(
+            if lora_config.composition_mode == "add":
+                lora_cls = LoRA
+            elif lora_config.composition_mode == "scale":
+                lora_cls = IA3
+            else:
+                raise ValueError(f"Unknown composition_mode: {lora_config.composition_mode}")
+            lora = lora_cls(
                 *self._get_lora_shapes(lora_config),
                 lora_config,
                 gating_heads=self.get_n_heads(lora_config),
@@ -188,9 +267,25 @@ def get_adapter(self, adapter_name: str) -> nn.Module:
             return None
 
 
-class Linear(LoRALayer, nn.Linear):
+class LoRAState(NamedTuple):
+    """Models the input and output states of a
LoRA layer. + + Args: + layer_input (torch.Tensor): The input states to the adapted layer. + hidden_states (Optional[torch.Tensor]): + The hidden states of the adaptation module. These can be None before passing through the first LoRA/ IA3 + module. + layer_output (torch.Tensor): The output states of the original layer without adaptation. """ - LoRA implementation for Linear layer. + + layer_input: torch.Tensor + hidden_states: Optional[torch.Tensor] + layer_output: torch.Tensor + + +class LoRALinear(LoRALayer, ComposableAdapterLayerBase, nn.Linear): + """ + LoRA implementation for Linear layer. This layer supports composition. Args: fan_in_fan_out (bool, optional): @@ -199,6 +294,9 @@ class Linear(LoRALayer, nn.Linear): """ + supported_compositions = [Stack, BatchSplit, Average, Parallel] + allow_multi_parallelize = True + def __init__( self, in_features: int, @@ -267,36 +365,17 @@ def _check_lora_location(self, config: LoRAConfig): def _get_lora_shapes(self, config: LoRAConfig): return (config.r, self.in_features), (self.out_features, config.r) - def reset_adapter(self): - def T(w): - return torch.t(w) if self.fan_in_fan_out else w + def maybe_t(self, w): + return torch.t(w) if self.fan_in_fan_out else w + def reset_adapter(self): if self.merged: lora = self.loras[self.merged] # Make sure that the weights are not merged - if lora.r > 0: - if lora.composition_mode == "scale": - delta_w = T(lora.lora_B) - else: - delta_w = T(lora.lora_B @ lora.lora_A) - self.weight.data = lora.com_inv(self.weight.data, delta_w) + delta_w = self.maybe_t(lora.delta_w) + self.weight.data = lora.com_inv(self.weight.data, delta_w) self.merged = None - def _compute_adapted_weight(self, lora, scaling=None): - def T(w): - return torch.t(w) if self.fan_in_fan_out else w - - weight = self.weight - # Merge the weights and mark it - if lora.r > 0: - if lora.composition_mode == "scale": - delta_w = T(lora.lora_B) - else: - delta_w = T(lora.lora_B @ lora.lora_A) - weight = lora.com(weight, delta_w, scaling=scaling) - - return weight - def merge_adapter(self, name: str): if name in self.loras: if self.merged == name: @@ -305,44 +384,74 @@ def merge_adapter(self, name: str): lora = self.loras[name] if lora.use_gating: raise ValueError("Cannot merge LoRA layer with gating.") - self.weight.data = self._compute_adapted_weight(lora) + delta_w = self.maybe_t(lora.delta_w) + self.weight.data = lora.com(self.weight.data, delta_w) self.merged = name elif self.merged != name: raise ValueError("LoRALayer already has a merged LoRA module. 
Please reset it first.")
 
-    def forward(self, x: torch.Tensor):
-        def T(w):
-            return torch.transpose(w, -2, -1) if self.fan_in_fan_out else w
+    def vslice(self, state: LoRAState, slice_obj: slice) -> LoRAState:
+        return LoRAState(
+            state.layer_input[slice_obj],
+            state.hidden_states[slice_obj] if state.hidden_states is not None else None,
+            state.layer_output[slice_obj],
+        )
+
+    def pad_and_concat(self, states: List[LoRAState]) -> LoRAState:
+        return LoRAState(
+            torch.cat([s.layer_input for s in states], dim=0),
+            torch.cat([s.hidden_states for s in states], dim=0) if states[0].hidden_states is not None else None,
+            torch.cat([s.layer_output for s in states], dim=0),
+        )
+
+    def repeat(self, state: LoRAState, channels: int) -> LoRAState:
+        return LoRAState(
+            state.layer_input.repeat(channels, 1, 1),
+            state.hidden_states.repeat(channels, 1, 1) if state.hidden_states is not None else None,
+            state.layer_output.repeat(channels, 1, 1),
+        )
+
+    def mean(self, states: List[LoRAState], weights: torch.Tensor) -> LoRAState:
+        return LoRAState(
+            states[0].layer_input,
+            torch.mean(torch.stack([s.hidden_states for s in states], dim=0) * weights, dim=0)
+            if states[0].hidden_states is not None
+            else None,
+            states[0].layer_output,
+        )
+
+    def compose_single(self, adapter_setup: str, state: LoRAState, lvl: int = 0) -> LoRAState:
+        lora = self.loras[adapter_setup]
+        hidden_states, gate = lora(state.hidden_states, state.layer_input)
+        if gate is not None:
+            self._store_gating_score(adapter_setup, gate)
+
+        return state._replace(hidden_states=hidden_states)
+
+    def forward(self, input_states: torch.Tensor):
+        weight = torch.transpose(self.weight, -2, -1) if self.fan_in_fan_out else self.weight
+        # result shape: <batch_size> x <seq_len> x <hidden_dim>
+        layer_output = F.linear(input_states, weight, bias=self.bias)
 
         if not self.merged:
             adapter_setup = self.get_active_setup()
             if adapter_setup is not None:
-                if len(adapter_setup) == 1:
-                    lora = self.loras[adapter_setup[0]]
-                    # result shape: <batch_size> x <seq_len> x <hidden_dim>
-                    result = F.linear(x, T(self.weight), bias=self.bias)
-                    if lora.r > 0:
-                        if lora.composition_mode == "scale":
-                            delta_w = lora.lora_B.view(1, 1, -1)
-                        else:
-                            delta_w = lora.lora_dropout(x) @ torch.t(lora.lora_A) @ torch.t(lora.lora_B)
-                        if lora.use_gating:
-                            gate = torch.sigmoid(lora.gate(x))
-                            gate = torch.mean(gate, dim=1).unsqueeze(-1)
-                            self._store_gating_score(adapter_setup[0], gate)
-                        else:
-                            gate = None
-                        result = lora.com(result, delta_w, scaling=gate)
-                    return result
-                else:
-                    raise ValueError(f"Invalid adapter setup. Cannot use {adapter_setup} with LoRA.")
+                state = LoRAState(input_states, None, layer_output)
+                state = self.compose(adapter_setup, state)
+                _, hidden_states, layer_output = state
 
-        return F.linear(x, T(self.weight), bias=self.bias)
+                last_lora = self.loras[adapter_setup.last()]
+                layer_output = last_lora.com(
+                    layer_output, hidden_states, scaling=1.0
+                )  # scaling already applied in compose
+
+        return layer_output
 
 
-class MergedLinear(LoRALayer, nn.Linear):
+class LoRAMergedLinear(LoRALayer, nn.Linear):
     """
-    LoRA implementation for merged attention layer layer.
+    LoRA implementation for merged attention layer, as used by some model implementations (e.g. GPT-2). This layer
+    currently does not support composition.
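+    The models relying on it (currently GPT-2 and DeBERTa v1) therefore support activating only a single
+    LoRA/ (IA)³ adapter at a time.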
Args: fan_in_fan_out (bool, optional): @@ -395,7 +504,7 @@ def wrap( return new_module - def get_n_heads(self, lora: Union[LoRA, LoRAConfig]): + def get_n_heads(self, lora: Union[LoRA, IA3, LoRAConfig]): return len(set(lora.attn_matrices)) def _get_lora_shapes(self, config: LoRAConfig): diff --git a/src/adapters/models/albert/mixin_albert.py b/src/adapters/models/albert/mixin_albert.py index 21534980af..ff9ef19fe3 100644 --- a/src/adapters/models/albert/mixin_albert.py +++ b/src/adapters/models/albert/mixin_albert.py @@ -4,7 +4,7 @@ from ...composition import adjust_tensors_for_parallel_ from ...methods.bottleneck import BottleneckLayer -from ...methods.lora import Linear as LoRALinear +from ...methods.lora import LoRALinear from ...methods.prefix_tuning import PrefixTuningLayer from ...model_mixin import EmbeddingAdaptersMixin, InvertibleAdaptersMixin, ModelBaseAdaptersMixin diff --git a/src/adapters/models/albert/modeling_albert.py b/src/adapters/models/albert/modeling_albert.py index df3e7523f0..7f5294cad7 100644 --- a/src/adapters/models/albert/modeling_albert.py +++ b/src/adapters/models/albert/modeling_albert.py @@ -23,7 +23,7 @@ from transformers.models.albert.modeling_albert import AlbertAttention, AlbertLayer from transformers.pytorch_utils import apply_chunking_to_forward -from ...composition import adjust_tensors_for_parallel +from ...composition import adjust_tensors_for_parallel, match_attn_matrices_for_parallel from .mixin_albert import AlbertAttentionAdaptersMixin, AlbertEncoderLayerAdaptersMixin @@ -42,6 +42,8 @@ def forward( query_layer = self.transpose_for_scores(mixed_query_layer) key_layer = self.transpose_for_scores(mixed_key_layer) value_layer = self.transpose_for_scores(mixed_value_layer) + query_layer, key_layer, value_layer = match_attn_matrices_for_parallel(query_layer, key_layer, value_layer) + (attention_mask,) = adjust_tensors_for_parallel(query_layer, attention_mask) key_layer, value_layer, attention_mask = self.prefix_tuning( key_layer, value_layer, hidden_states, attention_mask diff --git a/src/adapters/models/bart/mixin_bart.py b/src/adapters/models/bart/mixin_bart.py index 5ef20aaa86..28e7b3ac77 100644 --- a/src/adapters/models/bart/mixin_bart.py +++ b/src/adapters/models/bart/mixin_bart.py @@ -5,7 +5,7 @@ from ...composition import adjust_tensors_for_parallel from ...methods.bottleneck import BottleneckLayer -from ...methods.lora import Linear as LoRALinear +from ...methods.lora import LoRALinear from ...methods.prefix_tuning import PrefixTuningLayer from ...model_mixin import ( EmbeddingAdaptersMixin, diff --git a/src/adapters/models/bart/modeling_bart.py b/src/adapters/models/bart/modeling_bart.py index cb15b385bd..28bf37bd7c 100644 --- a/src/adapters/models/bart/modeling_bart.py +++ b/src/adapters/models/bart/modeling_bart.py @@ -21,7 +21,7 @@ from transformers.models.bart.modeling_bart import BartAttention, BartDecoderLayer, BartEncoderLayer -from ...composition import adjust_tensors_for_parallel, adjust_tensors_for_parallel_ +from ...composition import adjust_tensors_for_parallel, adjust_tensors_for_parallel_, match_attn_matrices_for_parallel from .mixin_bart import BartAttentionAdaptersMixin, BartDecoderLayerAdaptersMixin, BartEncoderLayerAdaptersMixin @@ -74,6 +74,11 @@ def forward( key_states = self._shape(self.k_proj(hidden_states), -1, bsz) value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + query_states, key_states, value_states = match_attn_matrices_for_parallel( + query_states, key_states, value_states + ) + (attention_mask,) = 
adjust_tensors_for_parallel(query_states, attention_mask) + if self.is_decoder: # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. # Further calls to cross_attention layer can then reuse all cross-attention diff --git a/src/adapters/models/beit/mixin_beit.py b/src/adapters/models/beit/mixin_beit.py index 2c129f085c..536048e669 100644 --- a/src/adapters/models/beit/mixin_beit.py +++ b/src/adapters/models/beit/mixin_beit.py @@ -3,7 +3,7 @@ import torch.nn as nn from ...methods.bottleneck import BottleneckLayer -from ...methods.lora import Linear as LoRALinear +from ...methods.lora import LoRALinear from ...methods.prefix_tuning import PrefixTuningLayer from ...model_mixin import ModelBaseAdaptersMixin diff --git a/src/adapters/models/bert/mixin_bert.py b/src/adapters/models/bert/mixin_bert.py index e97c9dd988..3cf5a6e1ff 100644 --- a/src/adapters/models/bert/mixin_bert.py +++ b/src/adapters/models/bert/mixin_bert.py @@ -5,7 +5,7 @@ from ...composition import adjust_tensors_for_parallel_ from ...methods.bottleneck import BottleneckLayer -from ...methods.lora import Linear as LoRALinear +from ...methods.lora import LoRALinear from ...methods.prefix_tuning import PrefixTuningLayer from ...model_mixin import EmbeddingAdaptersMixin, InvertibleAdaptersMixin, ModelBaseAdaptersMixin diff --git a/src/adapters/models/bert/modeling_bert.py b/src/adapters/models/bert/modeling_bert.py index 539dc74ebf..692605610a 100644 --- a/src/adapters/models/bert/modeling_bert.py +++ b/src/adapters/models/bert/modeling_bert.py @@ -25,7 +25,7 @@ from transformers.models.bert.modeling_bert import BertOutput, BertSelfAttention, BertSelfOutput -from ...composition import adjust_tensors_for_parallel +from ...composition import adjust_tensors_for_parallel, match_attn_matrices_for_parallel from .mixin_bert import BertOutputAdaptersMixin, BertSelfAttentionAdaptersMixin, BertSelfOutputAdaptersMixin @@ -66,6 +66,8 @@ def forward( value_layer = self.transpose_for_scores(self.value(hidden_states)) query_layer = self.transpose_for_scores(mixed_query_layer) + query_layer, key_layer, value_layer = match_attn_matrices_for_parallel(query_layer, key_layer, value_layer) + (attention_mask,) = adjust_tensors_for_parallel(query_layer, attention_mask) use_cache = past_key_value is not None if self.is_decoder: diff --git a/src/adapters/models/bert_generation/modeling_bert_generation.py b/src/adapters/models/bert_generation/modeling_bert_generation.py index 8f083fe295..8381ccf2bb 100644 --- a/src/adapters/models/bert_generation/modeling_bert_generation.py +++ b/src/adapters/models/bert_generation/modeling_bert_generation.py @@ -27,7 +27,7 @@ BertGenerationSelfOutput, ) -from ...composition import adjust_tensors_for_parallel +from ...composition import adjust_tensors_for_parallel, match_attn_matrices_for_parallel from ..bert.mixin_bert import BertOutputAdaptersMixin, BertSelfAttentionAdaptersMixin, BertSelfOutputAdaptersMixin @@ -78,6 +78,8 @@ def forward( value_layer = self.transpose_for_scores(self.value(hidden_states)) query_layer = self.transpose_for_scores(mixed_query_layer) + query_layer, key_layer, value_layer = match_attn_matrices_for_parallel(query_layer, key_layer, value_layer) + (attention_mask,) = adjust_tensors_for_parallel(query_layer, attention_mask) use_cache = past_key_value is not None if self.is_decoder: diff --git a/src/adapters/models/clip/mixin_clip.py b/src/adapters/models/clip/mixin_clip.py index 36eae84b0f..02469974f5 100644 --- a/src/adapters/models/clip/mixin_clip.py 
+++ b/src/adapters/models/clip/mixin_clip.py @@ -4,7 +4,7 @@ from ...composition import adjust_tensors_for_parallel_ from ...methods.bottleneck import BottleneckLayer -from ...methods.lora import Linear as LoRALinear +from ...methods.lora import LoRALinear from ...methods.prefix_tuning import PrefixTuningLayer from ...model_mixin import ( EmbeddingAdaptersMixin, diff --git a/src/adapters/models/deberta/mixin_deberta.py b/src/adapters/models/deberta/mixin_deberta.py index cee8530f02..d9907de36d 100644 --- a/src/adapters/models/deberta/mixin_deberta.py +++ b/src/adapters/models/deberta/mixin_deberta.py @@ -1,4 +1,4 @@ -from ...methods.lora import MergedLinear as LoRAMergedLinear +from ...methods.lora import LoRAMergedLinear from ...methods.prefix_tuning import PrefixTuningLayer diff --git a/src/adapters/models/deberta/modeling_deberta.py b/src/adapters/models/deberta/modeling_deberta.py index 8197c19fb6..71b7f9dc2a 100644 --- a/src/adapters/models/deberta/modeling_deberta.py +++ b/src/adapters/models/deberta/modeling_deberta.py @@ -24,7 +24,7 @@ XSoftmax, ) -from ...composition import adjust_tensors_for_parallel +from ...composition import adjust_tensors_for_parallel, match_attn_matrices_for_parallel from ..bert.mixin_bert import BertOutputAdaptersMixin, BertSelfOutputAdaptersMixin from .mixin_deberta import DebertaSelfAttentionAdaptersMixin @@ -113,6 +113,9 @@ def linear(w, b, x): k, v = [linear(qkvw[i], qkvb[i], hidden_states.to(dtype=qkvw[i].dtype)) for i in range(1, 3)] query_layer, key_layer, value_layer = [self.transpose_for_scores(x) for x in [q, k, v]] + query_layer, key_layer, value_layer = match_attn_matrices_for_parallel(query_layer, key_layer, value_layer) + (attention_mask,) = adjust_tensors_for_parallel(query_layer, attention_mask) + query_layer = query_layer + self.transpose_for_scores(self.q_bias[None, None, :]) value_layer = value_layer + self.transpose_for_scores(self.v_bias[None, None, :]) diff --git a/src/adapters/models/deberta_v2/mixin_deberta_v2.py b/src/adapters/models/deberta_v2/mixin_deberta_v2.py index f60e8788fb..3a33fdf84c 100644 --- a/src/adapters/models/deberta_v2/mixin_deberta_v2.py +++ b/src/adapters/models/deberta_v2/mixin_deberta_v2.py @@ -1,4 +1,4 @@ -from ...methods.lora import Linear as LoRALinear +from ...methods.lora import LoRALinear from ...methods.prefix_tuning import PrefixTuningLayer diff --git a/src/adapters/models/deberta_v2/modeling_deberta_v2.py b/src/adapters/models/deberta_v2/modeling_deberta_v2.py index 082e77a721..aa8945000f 100644 --- a/src/adapters/models/deberta_v2/modeling_deberta_v2.py +++ b/src/adapters/models/deberta_v2/modeling_deberta_v2.py @@ -24,7 +24,7 @@ XSoftmax, ) -from ...composition import adjust_tensors_for_parallel +from ...composition import adjust_tensors_for_parallel, match_attn_matrices_for_parallel from ..bert.mixin_bert import BertOutputAdaptersMixin, BertSelfOutputAdaptersMixin from .mixin_deberta_v2 import DebertaV2SelfAttentionAdaptersMixin @@ -97,6 +97,9 @@ def forward( key_layer = self.transpose_for_scores_extended(self.key_proj(hidden_states), self.num_attention_heads) value_layer = self.transpose_for_scores_extended(self.value_proj(hidden_states), self.num_attention_heads) + query_layer, key_layer, value_layer = match_attn_matrices_for_parallel(query_layer, key_layer, value_layer) + (attention_mask,) = adjust_tensors_for_parallel(query_layer, attention_mask) + orig_key_layer = key_layer.contiguous() # save this for relative attention key_layer, value_layer, attention_mask = self.prefix_tuning( key_layer, 
value_layer, hidden_states, attention_mask, False diff --git a/src/adapters/models/distilbert/mixin_distilbert.py b/src/adapters/models/distilbert/mixin_distilbert.py index 44bcbb0b16..111733c2f0 100644 --- a/src/adapters/models/distilbert/mixin_distilbert.py +++ b/src/adapters/models/distilbert/mixin_distilbert.py @@ -3,7 +3,7 @@ import torch.nn as nn from ...methods.bottleneck import BottleneckLayer -from ...methods.lora import Linear as LoRALinear +from ...methods.lora import LoRALinear from ...methods.prefix_tuning import PrefixTuningLayer from ...model_mixin import EmbeddingAdaptersMixin, InvertibleAdaptersMixin, ModelBaseAdaptersMixin diff --git a/src/adapters/models/distilbert/modeling_distilbert.py b/src/adapters/models/distilbert/modeling_distilbert.py index 6dfb62eb1c..e0aee4e1b9 100644 --- a/src/adapters/models/distilbert/modeling_distilbert.py +++ b/src/adapters/models/distilbert/modeling_distilbert.py @@ -27,7 +27,7 @@ from transformers.models.distilbert.modeling_distilbert import MultiHeadSelfAttention, TransformerBlock -from ...composition import adjust_tensors_for_parallel, adjust_tensors_for_parallel_ +from ...composition import adjust_tensors_for_parallel, adjust_tensors_for_parallel_, match_attn_matrices_for_parallel from .mixin_distilbert import DistilBertMultiHeadSelfAttentionMixin, DistilBertTransfomerBlockAdaptersMixin @@ -70,6 +70,9 @@ def unshape(x: torch.Tensor) -> torch.Tensor: k = shape(self.k_lin(key)) # (bs, n_heads, k_length, dim_per_head) v = shape(self.v_lin(value)) # (bs, n_heads, k_length, dim_per_head) + q, k, v = match_attn_matrices_for_parallel(q, k, v) + (mask,) = adjust_tensors_for_parallel(q, mask) + k, v, mask = self.prefix_tuning(k, v, value, mask, invert_mask=False) bs = k.size(0) # reset for Parallel block (q,) = adjust_tensors_for_parallel(k, q) diff --git a/src/adapters/models/electra/modeling_electra.py b/src/adapters/models/electra/modeling_electra.py index 35552782ce..cbe4277ec9 100644 --- a/src/adapters/models/electra/modeling_electra.py +++ b/src/adapters/models/electra/modeling_electra.py @@ -6,7 +6,7 @@ from transformers.models.electra.modeling_electra import ElectraOutput, ElectraSelfAttention, ElectraSelfOutput -from ...composition import adjust_tensors_for_parallel +from ...composition import adjust_tensors_for_parallel, match_attn_matrices_for_parallel from ..bert.mixin_bert import BertOutputAdaptersMixin, BertSelfAttentionAdaptersMixin, BertSelfOutputAdaptersMixin @@ -47,6 +47,8 @@ def forward( value_layer = self.transpose_for_scores(self.value(hidden_states)) query_layer = self.transpose_for_scores(mixed_query_layer) + query_layer, key_layer, value_layer = match_attn_matrices_for_parallel(query_layer, key_layer, value_layer) + (attention_mask,) = adjust_tensors_for_parallel(query_layer, attention_mask) use_cache = past_key_value is not None if self.is_decoder: diff --git a/src/adapters/models/gpt2/mixin_gpt2.py b/src/adapters/models/gpt2/mixin_gpt2.py index e86c2967a9..ce88136a92 100644 --- a/src/adapters/models/gpt2/mixin_gpt2.py +++ b/src/adapters/models/gpt2/mixin_gpt2.py @@ -3,8 +3,7 @@ import torch.nn as nn from ...methods.bottleneck import BottleneckLayer -from ...methods.lora import Linear as LoRALinear -from ...methods.lora import MergedLinear as LoRAMergedLinear +from ...methods.lora import LoRALinear, LoRAMergedLinear from ...methods.prefix_tuning import PrefixTuningLayer from ...model_mixin import EmbeddingAdaptersMixin, InvertibleAdaptersMixin, ModelBaseAdaptersMixin diff --git a/src/adapters/models/gptj/mixin_gptj.py 
b/src/adapters/models/gptj/mixin_gptj.py index 333c1b9358..7e4e771cba 100644 --- a/src/adapters/models/gptj/mixin_gptj.py +++ b/src/adapters/models/gptj/mixin_gptj.py @@ -3,7 +3,7 @@ import torch.nn as nn from ...methods.bottleneck import BottleneckLayer -from ...methods.lora import Linear as LoRALinear +from ...methods.lora import LoRALinear from ...methods.prefix_tuning import PrefixTuningLayer from ...model_mixin import EmbeddingAdaptersMixin, InvertibleAdaptersMixin, ModelBaseAdaptersMixin diff --git a/src/adapters/models/gptj/modeling_gptj.py b/src/adapters/models/gptj/modeling_gptj.py index 453f0c9b6d..700e919a17 100644 --- a/src/adapters/models/gptj/modeling_gptj.py +++ b/src/adapters/models/gptj/modeling_gptj.py @@ -22,7 +22,7 @@ from transformers.models.gptj.modeling_gptj import GPTJAttention, GPTJBlock, apply_rotary_pos_emb, get_embed_positions from transformers.utils.import_utils import is_torch_fx_proxy -from ...composition import adjust_tensors_for_parallel, adjust_tensors_for_parallel_ +from ...composition import adjust_tensors_for_parallel, adjust_tensors_for_parallel_, match_attn_matrices_for_parallel from .mixin_gptj import GPTJAttentionAdaptersMixin, GPTJDecoderBlockAdaptersMixin @@ -44,6 +44,9 @@ def forward( key = self.k_proj(hidden_states) value = self.v_proj(hidden_states) + query, key, value = match_attn_matrices_for_parallel(query, key, value) + (attention_mask,) = adjust_tensors_for_parallel(query, attention_mask) + query = self._split_heads(query, self.num_attention_heads, self.head_dim, True) key = self._split_heads(key, self.num_attention_heads, self.head_dim, True) value = self._split_heads(value, self.num_attention_heads, self.head_dim, False) diff --git a/src/adapters/models/llama/mixin_llama.py b/src/adapters/models/llama/mixin_llama.py index 22223edaf4..3caf66e544 100644 --- a/src/adapters/models/llama/mixin_llama.py +++ b/src/adapters/models/llama/mixin_llama.py @@ -3,7 +3,7 @@ import torch.nn as nn from ...methods.bottleneck import BottleneckLayer -from ...methods.lora import Linear as LoRALinear +from ...methods.lora import LoRALinear from ...methods.prefix_tuning import PrefixTuningLayer from ...model_mixin import EmbeddingAdaptersMixin, InvertibleAdaptersMixin, ModelBaseAdaptersMixin diff --git a/src/adapters/models/llama/modeling_llama.py b/src/adapters/models/llama/modeling_llama.py index f16b65e9c6..3b22e5ae13 100644 --- a/src/adapters/models/llama/modeling_llama.py +++ b/src/adapters/models/llama/modeling_llama.py @@ -25,7 +25,11 @@ import torch.utils.checkpoint from torch import nn -from adapters.composition import adjust_tensors_for_parallel, adjust_tensors_for_parallel_ +from adapters.composition import ( + adjust_tensors_for_parallel, + adjust_tensors_for_parallel_, + match_attn_matrices_for_parallel, +) from transformers.models.llama.modeling_llama import apply_rotary_pos_emb from transformers.utils import logging @@ -53,6 +57,11 @@ def forward( key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + query_states, key_states, value_states = match_attn_matrices_for_parallel( + query_states, key_states, value_states + ) + (attention_mask,) = adjust_tensors_for_parallel(query_states, attention_mask) + kv_seq_len = key_states.shape[-2] if past_key_value is not None: kv_seq_len += past_key_value[0].shape[-2] diff --git a/src/adapters/models/mbart/modeling_mbart.py 
b/src/adapters/models/mbart/modeling_mbart.py index 5c43212a28..0f8f0d5335 100644 --- a/src/adapters/models/mbart/modeling_mbart.py +++ b/src/adapters/models/mbart/modeling_mbart.py @@ -21,7 +21,7 @@ from transformers.models.mbart.modeling_mbart import MBartAttention, MBartDecoderLayer, MBartEncoderLayer -from ...composition import adjust_tensors_for_parallel, adjust_tensors_for_parallel_ +from ...composition import adjust_tensors_for_parallel, adjust_tensors_for_parallel_, match_attn_matrices_for_parallel from ..bart.mixin_bart import BartAttentionAdaptersMixin, BartDecoderLayerAdaptersMixin, BartEncoderLayerAdaptersMixin @@ -74,6 +74,11 @@ def forward( key_states = self._shape(self.k_proj(hidden_states), -1, bsz) value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + query_states, key_states, value_states = match_attn_matrices_for_parallel( + query_states, key_states, value_states + ) + (attention_mask,) = adjust_tensors_for_parallel(query_states, attention_mask) + if self.is_decoder: # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. # Further calls to cross_attention layer can then reuse all cross-attention diff --git a/src/adapters/models/roberta/modeling_roberta.py b/src/adapters/models/roberta/modeling_roberta.py index 47a8ed35a9..e33b7e7ca3 100644 --- a/src/adapters/models/roberta/modeling_roberta.py +++ b/src/adapters/models/roberta/modeling_roberta.py @@ -24,7 +24,7 @@ from transformers.models.roberta.modeling_roberta import RobertaOutput, RobertaSelfAttention, RobertaSelfOutput -from ...composition import adjust_tensors_for_parallel +from ...composition import adjust_tensors_for_parallel, match_attn_matrices_for_parallel from ..bert.mixin_bert import BertOutputAdaptersMixin, BertSelfAttentionAdaptersMixin, BertSelfOutputAdaptersMixin @@ -66,6 +66,8 @@ def forward( value_layer = self.transpose_for_scores(self.value(hidden_states)) query_layer = self.transpose_for_scores(mixed_query_layer) + query_layer, key_layer, value_layer = match_attn_matrices_for_parallel(query_layer, key_layer, value_layer) + (attention_mask,) = adjust_tensors_for_parallel(query_layer, attention_mask) use_cache = past_key_value is not None if self.is_decoder: diff --git a/src/adapters/models/t5/mixin_t5.py b/src/adapters/models/t5/mixin_t5.py index 832dfd185d..244f5d4335 100644 --- a/src/adapters/models/t5/mixin_t5.py +++ b/src/adapters/models/t5/mixin_t5.py @@ -4,7 +4,7 @@ import torch.nn as nn from ...methods.bottleneck import BottleneckLayer -from ...methods.lora import Linear as LoRALinear +from ...methods.lora import LoRALinear from ...methods.prefix_tuning import PrefixTuningLayer from ...model_mixin import ( EmbeddingAdaptersMixin, diff --git a/src/adapters/models/t5/modeling_t5.py b/src/adapters/models/t5/modeling_t5.py index 3440a4bb73..19064f58b2 100644 --- a/src/adapters/models/t5/modeling_t5.py +++ b/src/adapters/models/t5/modeling_t5.py @@ -28,7 +28,7 @@ ) from transformers.utils import logging -from ...composition import adjust_tensors_for_parallel +from ...composition import adjust_tensors_for_parallel, match_attn_matrices_for_parallel from .mixin_t5 import ( T5AttentionAdaptersMixin, T5CrossAttentionLayerAdaptersMixin, @@ -128,6 +128,11 @@ def project(hidden_states, proj_layer, key_value_states, past_key_value): hidden_states, self.v, key_value_states, past_key_value[1] if past_key_value is not None else None ) + query_states, key_states, value_states = match_attn_matrices_for_parallel( + query_states, key_states, value_states + ) 
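+        # The attention mask is expanded to the batch size of the (possibly replicated) query states.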
+ (mask,) = adjust_tensors_for_parallel(query_states, mask) + present_key_value_state = (key_states, value_states) if (self.is_decoder and use_cache) else None key_states, value_states, mask = self.prefix_tuning(key_states, value_states, hidden_states, mask) diff --git a/src/adapters/models/vit/mixin_vit.py b/src/adapters/models/vit/mixin_vit.py index 07598ad8ae..2f9962a9d8 100644 --- a/src/adapters/models/vit/mixin_vit.py +++ b/src/adapters/models/vit/mixin_vit.py @@ -3,7 +3,7 @@ import torch.nn as nn from ...methods.bottleneck import BottleneckLayer -from ...methods.lora import Linear as LoRALinear +from ...methods.lora import LoRALinear from ...methods.prefix_tuning import PrefixTuningLayer from ...model_mixin import ModelBaseAdaptersMixin diff --git a/src/adapters/models/vit/modeling_vit.py b/src/adapters/models/vit/modeling_vit.py index bb0fadd2ca..f8c02bd931 100644 --- a/src/adapters/models/vit/modeling_vit.py +++ b/src/adapters/models/vit/modeling_vit.py @@ -22,7 +22,7 @@ import torch.utils.checkpoint from torch import nn -from adapters.composition import adjust_tensors_for_parallel +from adapters.composition import adjust_tensors_for_parallel, match_attn_matrices_for_parallel from transformers.models.vit.modeling_vit import ViTLayer, ViTOutput, ViTSelfAttention from .mixin_vit import ViTLayerAdaptersMixin, ViTOutputAdaptersMixin, ViTSelfAttentionAdaptersMixin @@ -38,6 +38,8 @@ def forward( value_layer = self.transpose_for_scores(self.value(hidden_states)) query_layer = self.transpose_for_scores(mixed_query_layer) + query_layer, key_layer, value_layer = match_attn_matrices_for_parallel(query_layer, key_layer, value_layer) + key_layer, value_layer, _ = self.prefix_tuning(key_layer, value_layer, hidden_states) (query_layer,) = adjust_tensors_for_parallel(key_layer, query_layer) diff --git a/src/adapters/models/xlm_roberta/modeling_xlm_roberta.py b/src/adapters/models/xlm_roberta/modeling_xlm_roberta.py index a8d22284b7..5f18c9f70e 100644 --- a/src/adapters/models/xlm_roberta/modeling_xlm_roberta.py +++ b/src/adapters/models/xlm_roberta/modeling_xlm_roberta.py @@ -28,7 +28,7 @@ XLMRobertaSelfOutput, ) -from ...composition import adjust_tensors_for_parallel +from ...composition import adjust_tensors_for_parallel, match_attn_matrices_for_parallel from ..bert.mixin_bert import BertOutputAdaptersMixin, BertSelfAttentionAdaptersMixin, BertSelfOutputAdaptersMixin @@ -70,6 +70,8 @@ def forward( value_layer = self.transpose_for_scores(self.value(hidden_states)) query_layer = self.transpose_for_scores(mixed_query_layer) + query_layer, key_layer, value_layer = match_attn_matrices_for_parallel(query_layer, key_layer, value_layer) + (attention_mask,) = adjust_tensors_for_parallel(query_layer, attention_mask) use_cache = past_key_value is not None if self.is_decoder: diff --git a/src/adapters/models/xmod/modeling_xmod.py b/src/adapters/models/xmod/modeling_xmod.py index b772321667..4a2269fbae 100644 --- a/src/adapters/models/xmod/modeling_xmod.py +++ b/src/adapters/models/xmod/modeling_xmod.py @@ -23,7 +23,7 @@ from transformers.models.xmod.modeling_xmod import XmodOutput, XmodSelfAttention, XmodSelfOutput -from ...composition import adjust_tensors_for_parallel +from ...composition import adjust_tensors_for_parallel, match_attn_matrices_for_parallel from ..bert.mixin_bert import BertOutputAdaptersMixin, BertSelfAttentionAdaptersMixin, BertSelfOutputAdaptersMixin @@ -65,6 +65,8 @@ def forward( value_layer = self.transpose_for_scores(self.value(hidden_states)) query_layer = 
self.transpose_for_scores(mixed_query_layer) + query_layer, key_layer, value_layer = match_attn_matrices_for_parallel(query_layer, key_layer, value_layer) + (attention_mask,) = adjust_tensors_for_parallel(query_layer, attention_mask) use_cache = past_key_value is not None if self.is_decoder: diff --git a/tests_adapters/composition/test_adapter_composition.py b/tests_adapters/composition/test_adapter_composition.py index 2670488cb9..42e1f64c1b 100644 --- a/tests_adapters/composition/test_adapter_composition.py +++ b/tests_adapters/composition/test_adapter_composition.py @@ -3,7 +3,7 @@ import torch import adapters -from adapters import PrefixTuningConfig, SeqBnConfig +from adapters import IA3Config, LoRAConfig, PrefixTuningConfig, SeqBnConfig from adapters.composition import Average, BatchSplit, Fuse, Parallel, Split, Stack, parse_composition from tests.test_modeling_common import ids_tensor from transformers import BertConfig, BertForSequenceClassification @@ -140,9 +140,9 @@ def test_parallel(self): model.set_active_adapters(Parallel("a", "b", "c", "d")) inputs = {} - inputs["input_ids"] = ids_tensor((1, 128), 1000) + inputs["input_ids"] = ids_tensor((2, 10), 1000) logits = model(**inputs).logits - self.assertEqual(logits.shape, (4, 2)) + self.assertEqual(logits.shape, (8, 2)) def test_nested_parallel(self): if Parallel in self.unsupported_blocks or Stack in self.unsupported_blocks: @@ -152,7 +152,7 @@ def test_nested_parallel(self): model.set_active_adapters(Stack("a", Parallel(Stack("b", "c"), "d"))) inputs = {} - inputs["input_ids"] = ids_tensor((1, 128), 1000) + inputs["input_ids"] = ids_tensor((1, 10), 1000) logits = model(**inputs).logits self.assertEqual(logits.shape, (2, 2)) @@ -234,3 +234,17 @@ class PrefixTuningCompositionTest(AdapterCompositionTest): def get_adapter_config(self): return PrefixTuningConfig() + + +class LoRACompositionTest(AdapterCompositionTest): + unsupported_blocks = [Split, Fuse] + + def get_adapter_config(self): + return LoRAConfig(init_weights="bert") + + +class IA3CompositionTest(AdapterCompositionTest): + unsupported_blocks = [Split, Fuse] + + def get_adapter_config(self): + return IA3Config() From d45d9511b70e6b78abe4866d86af3af2a8bdba31 Mon Sep 17 00:00:00 2001 From: calpt Date: Sun, 19 Nov 2023 14:49:56 +0100 Subject: [PATCH 2/2] Move contributing guide --- CONTRIBUTING.md | 78 ------------------------------------------- README.md | 4 +++ docs/contributing.md | 79 +++++++++++++++++++++++++++++++++++++++++++- 3 files changed, 82 insertions(+), 79 deletions(-) delete mode 100644 CONTRIBUTING.md mode change 120000 => 100644 docs/contributing.md diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md deleted file mode 100644 index c9fd05e828..0000000000 --- a/CONTRIBUTING.md +++ /dev/null @@ -1,78 +0,0 @@ -# Contributing to AdapterHub - -There are many ways in which you can contribute to AdapterHub and the `adapters` library. -This includes code contributions such as: -- implementing new adapter methods -- adding support for new Transformer -- fixing open issues - -as well as non-code contributions such as: -- training and uploading adapters to the Hub -- writing documentation and blog posts -- helping others with their issues and questions - -Whichever way you'd like to contribute, you're very welcome to do so! - -## Contributing to the `adapters` codebase - -### Setting up your dev environment - -To get started with writing code for `adapters`, you'd want to set up the project on a local development environment. 
- -`adapters` closely follows the original Hugging Face Transformers repository in many aspects. -This guide assumes that you want to set up your dev environment on a local machine and that you have basic knowledge of `git`. -Additionally, you require **Python 3.8** or above pre-installed to get started. - -In the following, we go through the setup procedure step by step: - -1. Fork [the `adapters` repository](https://github.com/adapter-hub/adapters) to get a local copy of the code under your user account. -2. Clone your fork to your local machine: - ``` - git clone --recursive git@github.com:/adapters.git - cd adapters - ``` - **Note:** The `--recursive` flag is important to initialize git submodules. -3. Create a virtual environment, e.g. via `virtualenv` or `conda`. -4. Install PyTorch, following the installation command for your environment [on their website](https://pytorch.org/get-started/locally/). -5. Install Hugging Face Transformers from the local git submodule: - ``` - pip install ./hf_transformers - ``` -6. Install `adapters` and required dev dependencies: - ``` - pip install -e ".[dev]" - ``` - -### Adding Adapter Methods - -How to integrate new efficient fine-tuning/ adapter methods to `adapters` is described at [https://docs.adapterhub.ml/contributing/adding_adapter_methods.html](https://docs.adapterhub.ml/contributing/adding_adapter_methods.html). - -### Adding Adapters to a Model - -How to add adapter support to a model type already supported by Hugging Face Transformers is described at [https://docs.adapterhub.ml/contributing/adding_adapters_to_a_model.html](https://docs.adapterhub.ml/contributing/adding_adapters_to_a_model.html). - -### Testing your changes to the codebase - -`adapters` provides multiple Makefile targets for easily running tests and repo checks. -Make sure these checks run without errors to pass the CI pipeline tasks when you open a pull request. - -To **run all tests** in the repository: -``` -make test -``` - -To **auto format code and imports** in the whole codebase: -``` -make style -``` -This will run `black` and `isort`. - -To **run all quality checks** ensuring code style and repo consistency: -``` -make quality -``` -This will run checks with `black`, `isort` and `flake8` as well as additional custom checks. - -## Contributing Adapters to the Hub - -How to make your own trained adapters accessible via AdapterHub is described at [https://docs.adapterhub.ml/hub_contributing.html](https://docs.adapterhub.ml/hub_contributing.html). diff --git a/README.md b/README.md index 8ec212943c..799fc8438f 100644 --- a/README.md +++ b/README.md @@ -153,6 +153,10 @@ Currently, adapters integrates all architectures and methods listed below: We currently support the PyTorch versions of all models listed on the **[Model Overview](https://docs.adapterhub.ml/model_overview.html) page** in our documentation. +## Developing & Contributing + +To get started with developing on _Adapters_ yourself and learn more about ways to contribute, please see https://docs.adapterhub.ml/contributing.html. 
+ ## Citation If you use this library for your work, please consider citing our paper [AdapterHub: A Framework for Adapting Transformers](https://arxiv.org/abs/2007.07779): diff --git a/docs/contributing.md b/docs/contributing.md deleted file mode 120000 index 4daa8ec339..0000000000 --- a/docs/contributing.md +++ /dev/null @@ -1 +0,0 @@ -../contributing.md \ No newline at end of file diff --git a/docs/contributing.md b/docs/contributing.md new file mode 100644 index 0000000000..c9fd05e828 --- /dev/null +++ b/docs/contributing.md @@ -0,0 +1,78 @@ +# Contributing to AdapterHub + +There are many ways in which you can contribute to AdapterHub and the `adapters` library. +This includes code contributions such as: +- implementing new adapter methods +- adding support for new Transformer +- fixing open issues + +as well as non-code contributions such as: +- training and uploading adapters to the Hub +- writing documentation and blog posts +- helping others with their issues and questions + +Whichever way you'd like to contribute, you're very welcome to do so! + +## Contributing to the `adapters` codebase + +### Setting up your dev environment + +To get started with writing code for `adapters`, you'd want to set up the project on a local development environment. + +`adapters` closely follows the original Hugging Face Transformers repository in many aspects. +This guide assumes that you want to set up your dev environment on a local machine and that you have basic knowledge of `git`. +Additionally, you require **Python 3.8** or above pre-installed to get started. + +In the following, we go through the setup procedure step by step: + +1. Fork [the `adapters` repository](https://github.com/adapter-hub/adapters) to get a local copy of the code under your user account. +2. Clone your fork to your local machine: + ``` + git clone --recursive git@github.com:/adapters.git + cd adapters + ``` + **Note:** The `--recursive` flag is important to initialize git submodules. +3. Create a virtual environment, e.g. via `virtualenv` or `conda`. +4. Install PyTorch, following the installation command for your environment [on their website](https://pytorch.org/get-started/locally/). +5. Install Hugging Face Transformers from the local git submodule: + ``` + pip install ./hf_transformers + ``` +6. Install `adapters` and required dev dependencies: + ``` + pip install -e ".[dev]" + ``` + +### Adding Adapter Methods + +How to integrate new efficient fine-tuning/ adapter methods to `adapters` is described at [https://docs.adapterhub.ml/contributing/adding_adapter_methods.html](https://docs.adapterhub.ml/contributing/adding_adapter_methods.html). + +### Adding Adapters to a Model + +How to add adapter support to a model type already supported by Hugging Face Transformers is described at [https://docs.adapterhub.ml/contributing/adding_adapters_to_a_model.html](https://docs.adapterhub.ml/contributing/adding_adapters_to_a_model.html). + +### Testing your changes to the codebase + +`adapters` provides multiple Makefile targets for easily running tests and repo checks. +Make sure these checks run without errors to pass the CI pipeline tasks when you open a pull request. + +To **run all tests** in the repository: +``` +make test +``` + +To **auto format code and imports** in the whole codebase: +``` +make style +``` +This will run `black` and `isort`. 
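+(Both tools also support check-only runs, e.g. `black --check .`, to list violations without rewriting files.)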
+ +To **run all quality checks** ensuring code style and repo consistency: +``` +make quality +``` +This will run checks with `black`, `isort` and `flake8` as well as additional custom checks. + +## Contributing Adapters to the Hub + +How to make your own trained adapters accessible via AdapterHub is described at [https://docs.adapterhub.ml/hub_contributing.html](https://docs.adapterhub.ml/hub_contributing.html).