PrefixTuningShim -> PrefixTuningLayer
calpt committed Oct 9, 2023
1 parent 7770f43 commit 55fdc0c
Showing 16 changed files with 34 additions and 32 deletions.
2 changes: 1 addition & 1 deletion docs/contributing/adding_adapters_to_a_model.md
@@ -27,7 +27,7 @@ Now that we have discussed the purpose of every file in `src/adapters/models/<mo
- To figure out which classes to change, think about where to insert LoRA, Prefix Tuning, and bottleneck adapters.
- You can use similar model implementations for guidance.
- Often, existing mixins of another class can be reused. E.g. `BertLayer`, `RobertaLayer`, `XLMRobertaLayer`, `DebertaLayer`, `DebertaV2Layer` and `BertGenerationLayer` (all models derived from BERT) use the `BertLayerAdaptersMixin`.
- - To additionally support Prefix Tuning, it's necessary to apply the forward call to the `PrefixTuningShim` module in the respective attention layer (see step 3 for how to modify the code of an Hugging Face class).
+ - To additionally support Prefix Tuning, it's necessary to apply the forward call to the `PrefixTuningLayer` module in the respective attention layer (see step 3 for how to modify the code of an Hugging Face class).
- Make sure the calls to `adapter_layer_forward()` are added in the right places.
- The mixin for the whole base model class (e.g., `BertModel`) should derive from `ModelBaseAdaptersMixin` and (if possible) `EmbeddingAdaptersMixin` and/or `InvertibleAdaptersMixin`. This mixin should at least implement the `iter_layers()` method but might require additional modifications depending on the architecture.
- If the model is a combination of different models, such as the EncoderDecoderModel, use `ModelUsingSubmodelsAdaptersMixin` instead of `ModelBaseAdaptersMixin`.
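For context on the Prefix Tuning bullet above, here is a minimal sketch of what "applying the forward call" to the renamed `PrefixTuningLayer` can look like inside an attention module. This is illustrative only: the projection names (`q_proj`/`k_proj`/`v_proj`) are placeholders, and the forward signature of the injected layer (key states, value states, residual input, attention mask) is an assumption, not code from this repository.

```python
import torch
import torch.nn as nn


class ToyAttentionWithPrefix(nn.Module):
    """Illustrative only: shows where the prefix-tuning forward call sits in an
    attention layer. Projection names and the signature of the injected layer
    are assumptions, not code from this repository."""

    def __init__(self, hidden_size: int, prefix_tuning_layer: nn.Module):
        super().__init__()
        self.q_proj = nn.Linear(hidden_size, hidden_size)
        self.k_proj = nn.Linear(hidden_size, hidden_size)
        self.v_proj = nn.Linear(hidden_size, hidden_size)
        # In the real mixins below, this module is created in init_adapters().
        self.prefix_tuning = prefix_tuning_layer

    def forward(self, hidden_states, attention_mask=None):
        query = self.q_proj(hidden_states)
        key = self.k_proj(hidden_states)
        value = self.v_proj(hidden_states)
        # The adapter hook: keys/values get the prefix prepended and the
        # attention mask is extended to cover the prefix positions.
        key, value, attention_mask = self.prefix_tuning(
            key, value, hidden_states, attention_mask
        )
        scores = torch.matmul(query, key.transpose(-1, -2)) / key.size(-1) ** 0.5
        if attention_mask is not None:
            scores = scores + attention_mask
        return torch.matmul(torch.softmax(scores, dim=-1), value)
```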
6 changes: 3 additions & 3 deletions src/adapters/methods/prefix_tuning.py
@@ -126,7 +126,7 @@ class PrefixTuningPool(nn.Module):
How it works:
- 1. A `PrefixTuningShim` module that sets this module as pool module is added to each layer.
+ 1. A `PrefixTuningLayer` module that sets this module as pool module is added to each layer.
2. On adding a prefix, each shim module where a prefix should be added increments a counter in `prefix_counts`.
3. Finally, the base model class confirms adding a new prefix by calling `confirm_prefix()`.
4. This module adds a prefix layer that produces outputs corresponding to the indicated number of layers.
@@ -135,7 +135,7 @@ class PrefixTuningPool(nn.Module):
- The forward call to this layer is executed in the ForwardContext of each model pass.
- All other methods of this class (except for `confirm_prefix()`) should be called exclusively by
- `PrefixTuningShim`.
+ `PrefixTuningLayer`.
Args:
config (:class:`~transformers.PretrainedConfig`): The model config.
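The four-step mechanism described above can be illustrated with a small self-contained sketch. Apart from `confirm_prefix()`, which the docstring itself mentions, the method and attribute names below are simplified stand-ins rather than the library's real API.

```python
import torch
import torch.nn as nn


class ToyPrefixPool(nn.Module):
    """Simplified illustration of the pooling mechanism described above."""

    def __init__(self, prefix_len=30, n_heads=12, head_dim=64):
        super().__init__()
        self.prefix_len, self.n_heads, self.head_dim = prefix_len, n_heads, head_dim
        self.prefix_counts = {}             # prefix name -> number of layers requesting it
        self.prefixes = nn.ParameterDict()  # prefix name -> pooled prefix parameters

    def register_layer(self, prefix_name):
        # Step 2: each per-layer module increments the counter and keeps the
        # index of the slice it will later read from the pooled tensor.
        idx = self.prefix_counts.get(prefix_name, 0)
        self.prefix_counts[prefix_name] = idx + 1
        return idx

    def confirm_prefix(self, prefix_name):
        # Steps 3-4: the base model confirms the prefix, and the pool allocates
        # parameters for exactly as many layers as registered for it.
        n_layers = self.prefix_counts[prefix_name]
        self.prefixes[prefix_name] = nn.Parameter(
            torch.randn(n_layers, 2, self.prefix_len, self.n_heads, self.head_dim) * 0.02
        )

    def forward(self, prefix_name, layer_idx):
        # In the real class this runs inside the ForwardContext of a model pass;
        # each layer fetches its own key/value prefix slice.
        prefix_keys, prefix_values = self.prefixes[prefix_name][layer_idx]
        return prefix_keys, prefix_values


pool = ToyPrefixPool()
layer_slots = [pool.register_layer("demo_prefix") for _ in range(12)]  # one per layer
pool.confirm_prefix("demo_prefix")
prefix_k, prefix_v = pool("demo_prefix", layer_slots[0])
```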
@@ -265,7 +265,7 @@ class PrefixTuningState(NamedTuple):
idx_slice: Optional[slice] = None


- class PrefixTuningShim(ComposableAdapterLayerBase, nn.Module):
+ class PrefixTuningLayer(ComposableAdapterLayerBase, nn.Module):
"""
Representation of a Prefix Tuning layer within one Transformer layer. This class implements `AdapterLayerBase` for
compatibility with adapters. It uses `PrefixTuningPool` in the background and `set_pool()` must be called after
4 changes: 2 additions & 2 deletions src/adapters/model_mixin.py
@@ -21,7 +21,7 @@
from .methods.bottleneck import BottleneckLayer
from .methods.lora import LoRALayer
from .methods.modeling import Adapter, GLOWCouplingBlock, NICECouplingBlock, init_shared_parameters
- from .methods.prefix_tuning import PrefixTuningPool, PrefixTuningShim
+ from .methods.prefix_tuning import PrefixTuningLayer, PrefixTuningPool
from .utils import EMBEDDING_FILE, TOKENIZER_PATH, get_adapter_config_hash, inherit_doc
from .wrappers.configuration import SUBMODEL_NAMES, init_adapters_config

@@ -368,7 +368,7 @@ def __init__(self, config, *args, **kwargs):
super().__init__(config, *args, **kwargs)

def _link_prefix_to_pool(self, layer):
- if isinstance(layer, PrefixTuningShim):
+ if isinstance(layer, PrefixTuningLayer):
layer.set_pool(self.base_model.prefix_tuning)

@property
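The hunk above only shows the per-layer check; the call site that walks the model lies outside this diff. Below is a hedged sketch of the overall linking step, assuming the base-model mixin applies the helper to every submodule during adapter initialization (the toy classes and `link_all_prefixes()` are illustrative, not the library's API).

```python
import torch.nn as nn


class ToyPrefixLayer(nn.Module):
    def set_pool(self, pool):
        self.pool = pool


class ToyBaseModel(nn.Module):
    """Illustrative stand-in for a model using ModelBaseAdaptersMixin."""

    def __init__(self):
        super().__init__()
        self.prefix_tuning = object()  # stands in for the shared PrefixTuningPool
        self.layers = nn.ModuleList(ToyPrefixLayer() for _ in range(4))

    def _link_prefix_to_pool(self, module):
        # Same shape as the method in the diff above: only prefix-tuning
        # layers are touched, every other submodule is ignored.
        if isinstance(module, ToyPrefixLayer):
            module.set_pool(self.prefix_tuning)

    def link_all_prefixes(self):
        # nn.Module.apply() visits this module and all submodules recursively,
        # so every per-layer prefix module ends up pointing at the one pool.
        self.apply(self._link_prefix_to_pool)


model = ToyBaseModel()
model.link_all_prefixes()
assert all(layer.pool is model.prefix_tuning for layer in model.layers)
```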
4 changes: 2 additions & 2 deletions src/adapters/models/albert/mixin_albert.py
@@ -5,7 +5,7 @@
from ...composition import adjust_tensors_for_parallel_
from ...methods.bottleneck import BottleneckLayer
from ...methods.lora import Linear as LoRALinear
- from ...methods.prefix_tuning import PrefixTuningShim
+ from ...methods.prefix_tuning import PrefixTuningLayer
from ...model_mixin import EmbeddingAdaptersMixin, InvertibleAdaptersMixin, ModelBaseAdaptersMixin


@@ -20,7 +20,7 @@ def init_adapters(self, model_config, adapters_config):

self.attention_adapters = BottleneckLayer("mh_adapter")

- self.prefix_tuning = PrefixTuningShim(
+ self.prefix_tuning = PrefixTuningLayer(
self.location_key + "_prefix" if self.location_key else None, model_config, adapters_config
)

4 changes: 2 additions & 2 deletions src/adapters/models/bart/mixin_bart.py
@@ -6,7 +6,7 @@
from ...composition import adjust_tensors_for_parallel
from ...methods.bottleneck import BottleneckLayer
from ...methods.lora import Linear as LoRALinear
- from ...methods.prefix_tuning import PrefixTuningShim
+ from ...methods.prefix_tuning import PrefixTuningLayer
from ...model_mixin import (
EmbeddingAdaptersMixin,
EmbeddingAdaptersWrapperMixin,
@@ -25,7 +25,7 @@ def init_adapters(self, model_config, adapters_config):
self.v_proj = LoRALinear.wrap(self.v_proj, "selfattn", model_config, adapters_config, attn_key="v")
self.q_proj = LoRALinear.wrap(self.q_proj, "selfattn", model_config, adapters_config, attn_key="q")

- self.prefix_tuning = PrefixTuningShim(
+ self.prefix_tuning = PrefixTuningLayer(
self.location_key + "_prefix" if self.location_key else None, model_config, adapters_config
)

4 changes: 2 additions & 2 deletions src/adapters/models/beit/mixin_beit.py
@@ -4,7 +4,7 @@

from ...methods.bottleneck import BottleneckLayer
from ...methods.lora import Linear as LoRALinear
- from ...methods.prefix_tuning import PrefixTuningShim
+ from ...methods.prefix_tuning import PrefixTuningLayer
from ...model_mixin import ModelBaseAdaptersMixin


@@ -17,7 +17,7 @@ def init_adapters(self, model_config, adapters_config):
self.key = LoRALinear.wrap(self.key, "selfattn", model_config, adapters_config, attn_key="k")
self.value = LoRALinear.wrap(self.value, "selfattn", model_config, adapters_config, attn_key="v")

- self.prefix_tuning = PrefixTuningShim(
+ self.prefix_tuning = PrefixTuningLayer(
self.location_key + "_prefix" if self.location_key else None, model_config, adapters_config
)

4 changes: 2 additions & 2 deletions src/adapters/models/bert/mixin_bert.py
@@ -6,7 +6,7 @@
from ...composition import adjust_tensors_for_parallel_
from ...methods.bottleneck import BottleneckLayer
from ...methods.lora import Linear as LoRALinear
- from ...methods.prefix_tuning import PrefixTuningShim
+ from ...methods.prefix_tuning import PrefixTuningLayer
from ...model_mixin import EmbeddingAdaptersMixin, InvertibleAdaptersMixin, ModelBaseAdaptersMixin


@@ -22,7 +22,7 @@ def init_adapters(self, model_config, adapters_config):
self.key = LoRALinear.wrap(self.key, "selfattn", model_config, adapters_config, attn_key="k")
self.value = LoRALinear.wrap(self.value, "selfattn", model_config, adapters_config, attn_key="v")

- self.prefix_tuning = PrefixTuningShim(
+ self.prefix_tuning = PrefixTuningLayer(
self.location_key + "_prefix" if self.location_key else None, model_config, adapters_config
)

6 changes: 4 additions & 2 deletions src/adapters/models/clip/mixin_clip.py
@@ -5,7 +5,7 @@
from ...composition import adjust_tensors_for_parallel_
from ...methods.bottleneck import BottleneckLayer
from ...methods.lora import Linear as LoRALinear
- from ...methods.prefix_tuning import PrefixTuningShim
+ from ...methods.prefix_tuning import PrefixTuningLayer
from ...model_mixin import (
EmbeddingAdaptersMixin,
EmbeddingAdaptersWrapperMixin,
@@ -24,7 +24,9 @@ def init_adapters(self, model_config, adapters_config):
self.k_proj = LoRALinear.wrap(self.k_proj, "selfattn", model_config, adapters_config, attn_key="k")
self.v_proj = LoRALinear.wrap(self.v_proj, "selfattn", model_config, adapters_config, attn_key="v")

self.prefix_tuning = PrefixTuningShim("self_prefix", model_config, adapters_config, add_model_type_to_key=True)
self.prefix_tuning = PrefixTuningLayer(
"self_prefix", model_config, adapters_config, add_model_type_to_key=True
)


class CLIPEncoderLayerAdaptersMixin:
4 changes: 2 additions & 2 deletions src/adapters/models/deberta/mixin_deberta.py
@@ -1,5 +1,5 @@
from ...methods.lora import MergedLinear as LoRAMergedLinear
- from ...methods.prefix_tuning import PrefixTuningShim
+ from ...methods.prefix_tuning import PrefixTuningLayer


class DebertaSelfAttentionAdaptersMixin:
@@ -9,6 +9,6 @@ def init_adapters(self, model_config, adapters_config):
# Wrap layers for LoRA
self.in_proj = LoRAMergedLinear.wrap(self.in_proj, "selfattn", model_config, adapters_config)

- self.prefix_tuning = PrefixTuningShim(
+ self.prefix_tuning = PrefixTuningLayer(
self.location_key + "_prefix" if self.location_key else None, model_config, adapters_config
)
4 changes: 2 additions & 2 deletions src/adapters/models/deberta_v2/mixin_deberta_v2.py
@@ -1,5 +1,5 @@
from ...methods.lora import Linear as LoRALinear
- from ...methods.prefix_tuning import PrefixTuningShim
+ from ...methods.prefix_tuning import PrefixTuningLayer


class DebertaV2SelfAttentionAdaptersMixin:
@@ -11,6 +11,6 @@ def init_adapters(self, model_config, adapters_config):
self.key_proj = LoRALinear.wrap(self.key_proj, "selfattn", model_config, adapters_config, attn_key="k")
self.value_proj = LoRALinear.wrap(self.value_proj, "selfattn", model_config, adapters_config, attn_key="v")

- self.prefix_tuning = PrefixTuningShim(
+ self.prefix_tuning = PrefixTuningLayer(
self.location_key + "_prefix" if self.location_key else None, model_config, adapters_config
)
4 changes: 2 additions & 2 deletions src/adapters/models/distilbert/mixin_distilbert.py
@@ -4,7 +4,7 @@

from ...methods.bottleneck import BottleneckLayer
from ...methods.lora import Linear as LoRALinear
- from ...methods.prefix_tuning import PrefixTuningShim
+ from ...methods.prefix_tuning import PrefixTuningLayer
from ...model_mixin import EmbeddingAdaptersMixin, InvertibleAdaptersMixin, ModelBaseAdaptersMixin


@@ -17,7 +17,7 @@ def init_adapters(self, model_config, adapters_config):
self.k_lin = LoRALinear.wrap(self.k_lin, "selfattn", model_config, adapters_config, attn_key="k")
self.v_lin = LoRALinear.wrap(self.v_lin, "selfattn", model_config, adapters_config, attn_key="v")

self.prefix_tuning = PrefixTuningShim("self", model_config, adapters_config)
self.prefix_tuning = PrefixTuningLayer("self", model_config, adapters_config)


class DistilBertTransfomerBlockAdaptersMixin:
4 changes: 2 additions & 2 deletions src/adapters/models/gpt2/mixin_gpt2.py
@@ -5,7 +5,7 @@
from ...methods.bottleneck import BottleneckLayer
from ...methods.lora import Linear as LoRALinear
from ...methods.lora import MergedLinear as LoRAMergedLinear
- from ...methods.prefix_tuning import PrefixTuningShim
+ from ...methods.prefix_tuning import PrefixTuningLayer
from ...model_mixin import EmbeddingAdaptersMixin, InvertibleAdaptersMixin, ModelBaseAdaptersMixin


@@ -25,7 +25,7 @@ def init_adapters(self, model_config, adapters_config):
)

location_key = "cross_prefix" if self.is_cross_attention else "self_prefix"
- self.prefix_tuning = PrefixTuningShim(location_key, model_config, adapters_config)
+ self.prefix_tuning = PrefixTuningLayer(location_key, model_config, adapters_config)


class GPT2DecoderBlockAdaptersMixin:
4 changes: 2 additions & 2 deletions src/adapters/models/gptj/mixin_gptj.py
@@ -4,7 +4,7 @@

from ...methods.bottleneck import BottleneckLayer
from ...methods.lora import Linear as LoRALinear
- from ...methods.prefix_tuning import PrefixTuningShim
+ from ...methods.prefix_tuning import PrefixTuningLayer
from ...model_mixin import EmbeddingAdaptersMixin, InvertibleAdaptersMixin, ModelBaseAdaptersMixin


@@ -17,7 +17,7 @@ def init_adapters(self, model_config, adapters_config):
self.k_proj = LoRALinear.wrap(self.k_proj, "selfattn", model_config, adapters_config, attn_key="k")
self.v_proj = LoRALinear.wrap(self.v_proj, "selfattn", model_config, adapters_config, attn_key="v")

- self.prefix_tuning = PrefixTuningShim(
+ self.prefix_tuning = PrefixTuningLayer(
self.location_key + "_prefix" if self.location_key else None, model_config, adapters_config
)

4 changes: 2 additions & 2 deletions src/adapters/models/llama/mixin_llama.py
@@ -4,7 +4,7 @@

from ...methods.bottleneck import BottleneckLayer
from ...methods.lora import Linear as LoRALinear
- from ...methods.prefix_tuning import PrefixTuningShim
+ from ...methods.prefix_tuning import PrefixTuningLayer
from ...model_mixin import EmbeddingAdaptersMixin, InvertibleAdaptersMixin, ModelBaseAdaptersMixin


@@ -14,7 +14,7 @@ def init_adapters(self, model_config, adapters_config):
self.k_proj = LoRALinear.wrap(self.k_proj, "selfattn", model_config, adapters_config, attn_key="k")
self.v_proj = LoRALinear.wrap(self.v_proj, "selfattn", model_config, adapters_config, attn_key="v")

self.prefix_tuning = PrefixTuningShim("self_prefix", model_config, adapters_config)
self.prefix_tuning = PrefixTuningLayer("self_prefix", model_config, adapters_config)


class LlamaDecoderLayerMixin:
4 changes: 2 additions & 2 deletions src/adapters/models/t5/mixin_t5.py
@@ -4,7 +4,7 @@

from ...methods.bottleneck import BottleneckLayer
from ...methods.lora import Linear as LoRALinear
- from ...methods.prefix_tuning import PrefixTuningShim
+ from ...methods.prefix_tuning import PrefixTuningLayer
from ...model_mixin import (
EmbeddingAdaptersMixin,
InvertibleAdaptersMixin,
@@ -23,7 +23,7 @@ def init_adapters(self, model_config, adapters_config):
self.k = LoRALinear.wrap(self.k, "selfattn", model_config, adapters_config, attn_key="k", bias=False)
self.v = LoRALinear.wrap(self.v, "selfattn", model_config, adapters_config, attn_key="v", bias=False)

- self.prefix_tuning = PrefixTuningShim(
+ self.prefix_tuning = PrefixTuningLayer(
self.location_key + "_prefix" if self.location_key else None, model_config, adapters_config
)

4 changes: 2 additions & 2 deletions src/adapters/models/vit/mixin_vit.py
@@ -4,7 +4,7 @@

from ...methods.bottleneck import BottleneckLayer
from ...methods.lora import Linear as LoRALinear
- from ...methods.prefix_tuning import PrefixTuningShim
+ from ...methods.prefix_tuning import PrefixTuningLayer
from ...model_mixin import ModelBaseAdaptersMixin


@@ -17,7 +17,7 @@ def init_adapters(self, model_config, adapters_config):
self.key = LoRALinear.wrap(self.key, "selfattn", model_config, adapters_config, attn_key="k")
self.value = LoRALinear.wrap(self.value, "selfattn", model_config, adapters_config, attn_key="v")

- self.prefix_tuning = PrefixTuningShim(
+ self.prefix_tuning = PrefixTuningLayer(
self.location_key + "_prefix" if self.location_key else None, model_config, adapters_config
)

