Extended interface for more bottleneck support
calpt committed Dec 25, 2024
1 parent 3e90a3a commit ea85e43
Showing 4 changed files with 155 additions and 36 deletions.
14 changes: 13 additions & 1 deletion src/adapters/interface.py
@@ -1,5 +1,5 @@
from dataclasses import dataclass
-from typing import List
+from typing import List, Optional


class AdapterType:
@@ -32,6 +32,11 @@ class AdapterModelInterface:
attn_o_proj (str): Name of the output projection layer in an attention layer.
layer_intermediate_proj (str): Name of the intermediate projection layer in a transformer layer.
layer_output_proj (str): Name of the output projection layer in a transformer layer.
layer_pre_self_attn (Optional[str]): Hook point directly before the self attention layer. Used for extended bottleneck adapter support.
layer_pre_cross_attn (Optional[str]): Hook point directly before the cross attention layer. Used for extended bottleneck adapter support.
layer_pre_ffn (Optional[str]): Hook point directly before the feed forward layer. Used for extended bottleneck adapter support.
layer_ln_1 (Optional[str]): Layer norm *after* the self-attention layer. Used for extended bottleneck adapter support.
layer_ln_2 (Optional[str]): Layer norm *after* the feed forward layer. Used for extended bottleneck adapter support.
"""

adapter_types: List[str]
@@ -48,3 +53,10 @@ class AdapterModelInterface:

layer_intermediate_proj: str
layer_output_proj: str

# Optional attributes for extended bottleneck adapter support
layer_pre_self_attn: Optional[str] = None
layer_pre_cross_attn: Optional[str] = None
layer_pre_ffn: Optional[str] = None
layer_ln_1: Optional[str] = None
layer_ln_2: Optional[str] = None
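
The new hook fields are optional dataclass attributes defaulting to None, so existing interface definitions remain valid. For illustration only, a minimal sketch of a custom interface using the new hook points; the field values mirror the Gemma2 interface defined in tests/test_custom_interface.py below, and the adapters.init(model, interface=...) entry point plus the attn_*_proj/layer_cross_attn field names are assumptions based on the library's existing custom-interface API, not part of this commit:

import adapters
from adapters import AdapterModelInterface, SeqBnConfig
from transformers import AutoModelForCausalLM

# Values mirror the Gemma2 interface in tests/test_custom_interface.py below.
interface = AdapterModelInterface(
    adapter_types=["bottleneck", "lora", "reft"],
    model_embeddings="embed_tokens",
    model_layers="layers",
    layer_self_attn="self_attn",
    layer_cross_attn=None,
    attn_k_proj="k_proj",
    attn_q_proj="q_proj",
    attn_v_proj="v_proj",
    attn_o_proj="o_proj",
    layer_intermediate_proj="mlp.up_proj",
    layer_output_proj="mlp.down_proj",
    # new in this commit: optional hook points for extended bottleneck support
    layer_pre_self_attn="input_layernorm",
    layer_pre_cross_attn=None,
    layer_pre_ffn="pre_feedforward_layernorm",
    layer_ln_1="post_attention_layernorm",
    layer_ln_2="post_feedforward_layernorm",
)

model = AutoModelForCausalLM.from_pretrained("yujiepan/gemma-2-tiny-random")
adapters.init(model, interface=interface)  # assumed entry point
model.add_adapter("my_adapter", config=SeqBnConfig())
model.set_active_adapters("my_adapter")

Any of the new layer_* fields may be left as None, in which case the corresponding hooks are simply not registered for that model (see init_bottleneck below).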
71 changes: 56 additions & 15 deletions src/adapters/methods/bottleneck.py
@@ -21,6 +21,11 @@
from .modeling import Adapter, BertFusion, ParallelAdapter


LAYER_HOOK_UNSUPPORTED = [
("original_ln_after", False),
]


class BottleneckState(NamedTuple):
"""
Models the input and output states of a bottleneck adapter layer.
@@ -83,6 +88,15 @@ def add_adapter(self, adapter_name: str, layer_idx: int) -> bool:
'{"1": 16, "default": 16}'
)

# check unsupported configurations for layer hooking mode
if self.is_layer_hooked:
for key, value in LAYER_HOOK_UNSUPPORTED:
if adapter_config.get(key, None) == value:
raise ValueError(
f"Unsupported configuration for bottleneck layer hooking mode: {key}={value}. "
"Please set this configuration to a supported value."
)

if adapter_config.is_parallel:
adapter_class = ParallelAdapter
else:
@@ -93,6 +107,10 @@ def add_adapter(self, adapter_name: str, layer_idx: int) -> bool:
down_sample=int(self.model_config.hidden_size // reduction_factor),
config=adapter_config,
)
# for adapters hooked via interface:
# residual & LN are applied by model, so don't apply in adapters
if self.is_layer_hooked:
adapter.original_ln_after = False
adapter.train(self.training) # make sure training mode is consistent
self.adapters[adapter_name] = adapter
return True
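
Since LAYER_HOOK_UNSUPPORTED currently contains only ("original_ln_after", False), the practical effect of the new check is that bottleneck configs which disable the post-adapter residual/LayerNorm are rejected in layer hooking mode. A hedged sketch of how this surfaces to a user, reusing the hooked `model` from the interface sketch above:

from adapters import SeqBnConfig

# In layer hooking mode the host model applies the residual + LayerNorm
# itself, so disabling original_ln_after in the config is unsupported and
# triggers the ValueError raised by the check added above.
model.add_adapter("bad_adapter", config=SeqBnConfig(original_ln_after=False))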
@@ -338,15 +356,12 @@ def bottleneck_layer_forward(self, hidden_states, residual_input, layer_norm):
state = self.compose(adapter_setup, state)
hidden_states, residual_input, _, _, _, last = state

-            last_adapter = self.adapters[last]
-            hidden_states = last_adapter.post_forward(hidden_states, input_hidden_states, residual_input, layer_norm)
+            if not self.is_layer_hooked:
+                last_adapter = self.adapters[last]
+                hidden_states = last_adapter.post_forward(
+                    hidden_states, input_hidden_states, residual_input, layer_norm
+                )

-        elif layer_norm:
+        elif layer_norm is not None and not self.is_layer_hooked:
            hidden_states = layer_norm(hidden_states + residual_input)
-        elif residual_input is not None:
+        elif residual_input is not None and not self.is_layer_hooked:
hidden_states = hidden_states + residual_input

return hidden_states
@@ -365,26 +380,52 @@ def forward(self, hidden_states, residual_input, layer_norm):
return self.bottleneck_layer_forward(hidden_states, residual_input, layer_norm)


-def hook_fn(adapter_layer, module, args, output):
-    # TODO: we currently cannot reliably pass residual input and layer norm here. This means "is_parallel" and "original_ln_before" are not supported.
+def hook_fn(adapter_layer, ln_get_fn, module, args, output):
# Retrieve residual from previous hook, if existing
context = ForwardContext.get_context()
residual_input = getattr(context, f"{adapter_layer.location_key}_residual_input", None)
# Retrieve layer norm from getter fn
if ln_get_fn is not None:
layer_norm = ln_get_fn()
else:
layer_norm = None
# Call adapter layer
if isinstance(output, torch.Tensor):
-        return adapter_layer(output, None, None)
+        return adapter_layer(output, residual_input, layer_norm)
else:
-        return (adapter_layer(output[0], None, None),) + output[1:]
+        return (adapter_layer(output[0], residual_input, layer_norm),) + output[1:]


def _residual_hook_fn(location_key, module, args):
context = ForwardContext.get_context()
if context is not None:
setattr(context, f"{location_key}_residual_input", args[0])


def init_bottleneck(model):
-    for i, layer in model.iter_layers():
+    for _, layer in model.iter_layers():
if self_attn := multigetattr(layer, model.adapter_interface.layer_self_attn, None):
if o_proj := multigetattr(self_attn, model.adapter_interface.attn_o_proj, None):
if not hasattr(layer, "attention_adapters"):
layer.attention_adapters = BottleneckLayer("mh_adapter", is_layer_hooked=True)
-                o_proj.register_forward_hook(partial(hook_fn, layer.attention_adapters))
+                ln_1_get_fn = lambda: multigetattr(layer, model.adapter_interface.layer_ln_1, None)
+                o_proj.register_forward_hook(partial(hook_fn, layer.attention_adapters, ln_1_get_fn))
if layer_output_proj := multigetattr(layer, model.adapter_interface.layer_output_proj, None):
if not hasattr(layer, "output_adapters"):
layer.output_adapters = BottleneckLayer("output_adapter", is_layer_hooked=True)
-            layer_output_proj.register_forward_hook(partial(hook_fn, layer.output_adapters))
+            ln_2_get_fn = lambda: multigetattr(layer, model.adapter_interface.layer_ln_2, None)
+            layer_output_proj.register_forward_hook(partial(hook_fn, layer.output_adapters, ln_2_get_fn))
if cross_attn := multigetattr(layer, model.adapter_interface.layer_cross_attn, None):
if not hasattr(cross_attn, "cross_attention_adapters"):
layer.attention_adapters = BottleneckLayer("cross_adapter", is_layer_hooked=True)
-            cross_attn.register_forward_hook(partial(hook_fn, layer.attention_adapters))
+            cross_attn.register_forward_hook(partial(hook_fn, layer.attention_adapters, None))

if model.adapter_interface.layer_pre_self_attn is not None:
if pre_self_attn := multigetattr(layer, model.adapter_interface.layer_pre_self_attn, None):
pre_self_attn.register_forward_pre_hook(partial(_residual_hook_fn, "mh_adapter"))
if model.adapter_interface.layer_pre_cross_attn is not None:
if pre_cross_attn := multigetattr(layer, model.adapter_interface.layer_pre_cross_attn, None):
pre_cross_attn.register_forward_pre_hook(partial(_residual_hook_fn, "cross_adapter"))
if model.adapter_interface.layer_pre_ffn is not None:
if pre_ffn := multigetattr(layer, model.adapter_interface.layer_pre_ffn, None):
pre_ffn.register_forward_pre_hook(partial(_residual_hook_fn, "output_adapter"))
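
The two hooks cooperate through the ForwardContext: _residual_hook_fn runs as a forward pre-hook on the module preceding a block (layer_pre_self_attn, layer_pre_cross_attn, or layer_pre_ffn) and stashes that module's input as the residual, which hook_fn later retrieves when the projection layer's forward hook fires. A minimal self-contained sketch of this capture-and-consume pattern, with a plain dict standing in for the library's ForwardContext and a toy module pair (an illustration under those assumptions, not the library code):

import torch
from functools import partial

context = {}  # stands in for adapters' ForwardContext

def save_residual(location_key, module, args):
    # forward pre-hook: args[0] is the module's first positional input,
    # i.e. the hidden states entering the block -> the residual stream
    context[f"{location_key}_residual_input"] = args[0]

def consume_residual(location_key, module, args, output):
    residual = context.get(f"{location_key}_residual_input")
    # an adapter layer would transform `output` here; we just re-add the residual
    return output + residual

block_in = torch.nn.LayerNorm(8)  # plays the role of layer_pre_ffn
proj = torch.nn.Linear(8, 8)      # plays the role of layer_output_proj

block_in.register_forward_pre_hook(partial(save_residual, "output_adapter"))
proj.register_forward_hook(partial(consume_residual, "output_adapter"))

x = torch.randn(2, 8)
y = proj(block_in(x))  # both hooks fire automatically during forward

This also explains why the residual pre-hooks above are only registered when the corresponding layer_pre_* interface field is set: without a capture point, hook_fn falls back to a None residual.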
38 changes: 31 additions & 7 deletions tests/test_custom_interface.py
@@ -4,14 +4,22 @@
import torch

import adapters
-from adapters import AdapterModelInterface, AdapterSetup, LoRAConfig, load_model
+from adapters import AdapterModelInterface, AdapterSetup, DoubleSeqBnConfig, LoRAConfig, ParBnConfig, load_model
from transformers import Gemma2ForCausalLM, Gemma2ForSequenceClassification
from transformers.models.gemma2.configuration_gemma2 import Gemma2Config
from transformers.testing_utils import require_torch, torch_device

-from .methods import IA3TestMixin, LoRATestMixin, ReftTestMixin, create_twin_models
+from .methods import (
+    BottleneckAdapterTestMixin,
+    CompacterTestMixin,
+    IA3TestMixin,
+    LoRATestMixin,
+    ReftTestMixin,
+    create_twin_models,
+)
from .test_adapter import AdapterTestBase, make_config
from .test_adapter_embeddings import EmbeddingTestMixin
from .test_adapter_fusion_common import AdapterFusionModelTestMixin


class CustomInterfaceModelTestBase(AdapterTestBase):
@@ -28,7 +36,7 @@ class CustomInterfaceModelTestBase(AdapterTestBase):
)
tokenizer_name = "yujiepan/gemma-2-tiny-random"
adapter_interface = AdapterModelInterface(
adapter_types=["lora", "reft"],
adapter_types=["bottleneck", "lora", "reft"],
model_embeddings="embed_tokens",
model_layers="layers",
layer_self_attn="self_attn",
@@ -39,6 +47,11 @@
attn_o_proj="o_proj",
layer_intermediate_proj="mlp.up_proj",
layer_output_proj="mlp.down_proj",
layer_pre_self_attn="input_layernorm",
layer_pre_cross_attn=None,
layer_pre_ffn="pre_feedforward_layernorm",
layer_ln_1="post_attention_layernorm",
layer_ln_2="post_feedforward_layernorm",
)

def get_model(self):
@@ -50,22 +63,27 @@ def get_model(self):

@require_torch
class CustomInterfaceModelTest(
-    # BottleneckAdapterTestMixin,
-    # CompacterTestMixin,
+    BottleneckAdapterTestMixin,
+    CompacterTestMixin,
IA3TestMixin,
LoRATestMixin,
# PrefixTuningTestMixin,
# PromptTuningTestMixin,
ReftTestMixin,
# UniPELTTestMixin,
EmbeddingTestMixin,
-    # AdapterFusionModelTestMixin,
+    AdapterFusionModelTestMixin,
    # CompabilityTestMixin,
# ParallelAdapterInferenceTestMixin,
# ParallelTrainingMixin,
CustomInterfaceModelTestBase,
unittest.TestCase,
):
# Modify the list here since we don't support MAMConfig (due to prefix tuning)
adapter_configs_to_test = [
(DoubleSeqBnConfig(), ["adapters.{name}."]),
(ParBnConfig(init_weights="bert"), ["adapters.{name}."]),
]

def create_twin_models(self):
return create_twin_models(self.model_class, self.config, self.adapter_interface)

@@ -109,5 +127,11 @@ def _init_model_for_train_run(self, trained_adapter_name, frozen_adapter_name, a

return model

def test_load_mam_adapter(self):
self.skipTest("Does not support prefix tuning.")

def test_train_mam_adapter(self):
self.skipTest("Does not support prefix tuning.")

def test_merging_with_other_adapters(self):
self.skipTest("Does not support all required methods yet.")
68 changes: 55 additions & 13 deletions tests/test_custom_interface_compat.py
@@ -34,7 +34,7 @@ class CustomInterfaceCompatTest(unittest.TestCase):
hidden_act="gelu",
pad_token_id=0,
)
-    llama_adapter_interface = AdapterModelInterface(
+    llama_interface = AdapterModelInterface(
adapter_types=["bottleneck", "lora", "reft"],
model_embeddings="embed_tokens",
model_layers="layers",
@@ -46,8 +46,13 @@
attn_o_proj="o_proj",
layer_intermediate_proj="mlp.up_proj",
layer_output_proj="mlp.down_proj",
layer_pre_self_attn="input_layernorm",
layer_pre_cross_attn=None,
layer_pre_ffn="post_attention_layernorm",
layer_ln_1=None,
layer_ln_2=None,
)
-    bert_adapter_interface = AdapterModelInterface(
+    bert_interface = AdapterModelInterface(
adapter_types=["bottleneck", "lora", "reft", "prompt_tuning"],
model_embeddings="embeddings",
model_layers="encoder.layer",
@@ -59,6 +64,11 @@
attn_o_proj="output.dense",
layer_intermediate_proj="intermediate.dense",
layer_output_proj="output.dense",
layer_pre_self_attn="attention.self",
layer_pre_cross_attn=None,
layer_pre_ffn="intermediate",
layer_ln_1="attention.output.LayerNorm",
layer_ln_2="output.LayerNorm",
)
bert_bn_rewrites = [(".attention_adapters.", ".attention.output."), (".output_adapters.", ".output.")]

@@ -73,41 +83,73 @@ def create_twin_models(self, config, adapter_interface, hf_auto_model_class):

@parameterized.expand(
[
("LoRA_Llama", adapters.LoRAConfig(), llama_config, llama_adapter_interface, AutoModelForCausalLM),
("LoRA_BERT", adapters.LoRAConfig(), bert_config, bert_adapter_interface, AutoModel),
("LoReft_Llama", adapters.LoReftConfig(), llama_config, llama_adapter_interface, AutoModelForCausalLM),
("LoReft_BERT", adapters.LoReftConfig(), bert_config, bert_adapter_interface, AutoModel),
("LoRA_Llama", adapters.LoRAConfig(), llama_config, llama_interface, AutoModelForCausalLM),
("LoRA_BERT", adapters.LoRAConfig(), bert_config, bert_interface, AutoModel),
("LoReft_Llama", adapters.LoReftConfig(), llama_config, llama_interface, AutoModelForCausalLM),
("LoReft_BERT", adapters.LoReftConfig(), bert_config, bert_interface, AutoModel),
(
"BnSeq_Llama",
adapters.SeqBnConfig(original_ln_before=False),
llama_config,
-                llama_adapter_interface,
+                llama_interface,
AutoModelForCausalLM,
),
(
"BnSeqPreLN_Llama",
adapters.SeqBnConfig(original_ln_before=True),
llama_config,
llama_interface,
AutoModelForCausalLM,
),
("BnPar_Llama", adapters.ParBnConfig(), llama_config, llama_interface, AutoModelForCausalLM),
(
"Bn2Seq_Llama",
-                adapters.DoubleSeqBnConfig(),
+                adapters.DoubleSeqBnConfig(original_ln_before=True),
llama_config,
-                llama_adapter_interface,
+                llama_interface,
AutoModelForCausalLM,
),
(
"Bn2Par_Llama",
adapters.ParBnConfig(mh_adapter=True, output_adapter=True),
llama_config,
llama_interface,
AutoModelForCausalLM,
),
(
"BnSeq_BERT",
adapters.SeqBnConfig(original_ln_before=False),
bert_config,
-                bert_adapter_interface,
+                bert_interface,
AutoModel,
bert_bn_rewrites,
),
(
"BnSeqPreLN_BERT",
adapters.SeqBnConfig(original_ln_before=True),
bert_config,
bert_interface,
AutoModel,
bert_bn_rewrites,
),
("BnPar_BERT", adapters.ParBnConfig(), bert_config, bert_interface, AutoModel, bert_bn_rewrites),
(
"Bn2Seq_BERT",
-                adapters.DoubleSeqBnConfig(),
+                adapters.DoubleSeqBnConfig(original_ln_before=True),
bert_config,
bert_interface,
AutoModel,
bert_bn_rewrites,
),
(
"Bn2Par_BERT",
adapters.ParBnConfig(mh_adapter=True, output_adapter=True),
bert_config,
-                bert_adapter_interface,
+                bert_interface,
AutoModel,
bert_bn_rewrites,
),
("Prompt_BERT", adapters.PromptTuningConfig(), bert_config, bert_adapter_interface, AutoModel),
("Prompt_BERT", adapters.PromptTuningConfig(), bert_config, bert_interface, AutoModel),
]
)
def test_load_adapter(self, name, adapter_config, config, adapter_interface, hf_auto_model_class, rewrites=None):
