From ea85e43796e38b456714af1ca0b44ebc316f2f35 Mon Sep 17 00:00:00 2001
From: calpt
Date: Wed, 25 Dec 2024 21:23:35 +0000
Subject: [PATCH] Extended interface for more bottleneck support

---
 src/adapters/interface.py             | 14 +++++-
 src/adapters/methods/bottleneck.py    | 71 +++++++++++++++++++++------
 tests/test_custom_interface.py        | 38 +++++++++++---
 tests/test_custom_interface_compat.py | 68 ++++++++++++++++++++-----
 4 files changed, 155 insertions(+), 36 deletions(-)

diff --git a/src/adapters/interface.py b/src/adapters/interface.py
index 003480cfb..e3e38d6a3 100644
--- a/src/adapters/interface.py
+++ b/src/adapters/interface.py
@@ -1,5 +1,5 @@
 from dataclasses import dataclass
-from typing import List
+from typing import List, Optional
 
 
 class AdapterType:
@@ -32,6 +32,11 @@ class AdapterModelInterface:
         attn_o_proj (str): Name of the output projection layer in an attention layer.
         layer_intermediate_proj (str): Name of the intermediate projection layer in a transformer layer.
         layer_output_proj (str): Name of the output projection layer in a transformer layer.
+        layer_pre_self_attn (Optional[str]): Hook point directly before the self-attention layer. Used for extended bottleneck adapter support.
+        layer_pre_cross_attn (Optional[str]): Hook point directly before the cross-attention layer. Used for extended bottleneck adapter support.
+        layer_pre_ffn (Optional[str]): Hook point directly before the feed-forward layer. Used for extended bottleneck adapter support.
+        layer_ln_1 (Optional[str]): Layer norm *after* the self-attention layer. Used for extended bottleneck adapter support.
+        layer_ln_2 (Optional[str]): Layer norm *after* the feed-forward layer. Used for extended bottleneck adapter support.
     """
 
     adapter_types: List[str]
@@ -48,3 +53,10 @@ class AdapterModelInterface:
 
     layer_intermediate_proj: str
     layer_output_proj: str
+
+    # Optional attributes for extended bottleneck adapter support
+    layer_pre_self_attn: Optional[str] = None
+    layer_pre_cross_attn: Optional[str] = None
+    layer_pre_ffn: Optional[str] = None
+    layer_ln_1: Optional[str] = None
+    layer_ln_2: Optional[str] = None
diff --git a/src/adapters/methods/bottleneck.py b/src/adapters/methods/bottleneck.py
index 02d5c36aa..159059a83 100644
--- a/src/adapters/methods/bottleneck.py
+++ b/src/adapters/methods/bottleneck.py
@@ -21,6 +21,11 @@
 from .modeling import Adapter, BertFusion, ParallelAdapter
 
 
+LAYER_HOOK_UNSUPPORTED = [
+    ("original_ln_after", False),
+]
+
+
 class BottleneckState(NamedTuple):
     """
     Models the input and output states of a bottleneck adapter layer.
@@ -83,6 +88,15 @@ def add_adapter(self, adapter_name: str, layer_idx: int) -> bool:
                     '{"1": 16, "default": 16}'
                 )
 
+            # check unsupported configurations for layer hooking mode
+            if self.is_layer_hooked:
+                for key, value in LAYER_HOOK_UNSUPPORTED:
+                    if adapter_config.get(key, None) == value:
+                        raise ValueError(
+                            f"Unsupported configuration for bottleneck layer hooking mode: {key}={value}. "
+                            "Please set this configuration to a supported value."
+                        )
+
             if adapter_config.is_parallel:
                 adapter_class = ParallelAdapter
             else:
@@ -93,6 +107,10 @@ def add_adapter(self, adapter_name: str, layer_idx: int) -> bool:
                 down_sample=int(self.model_config.hidden_size // reduction_factor),
                 config=adapter_config,
             )
+            # for adapters hooked via interface:
+            # residual & LN are applied by model, so don't apply in adapters
+            if self.is_layer_hooked:
+                adapter.original_ln_after = False
             adapter.train(self.training)  # make sure training mode is consistent
             self.adapters[adapter_name] = adapter
             return True
@@ -338,15 +356,12 @@ def bottleneck_layer_forward(self, hidden_states, residual_input, layer_norm):
             state = self.compose(adapter_setup, state)
             hidden_states, residual_input, _, _, _, last = state
 
-            if not self.is_layer_hooked:
-                last_adapter = self.adapters[last]
-                hidden_states = last_adapter.post_forward(
-                    hidden_states, input_hidden_states, residual_input, layer_norm
-                )
+            last_adapter = self.adapters[last]
+            hidden_states = last_adapter.post_forward(hidden_states, input_hidden_states, residual_input, layer_norm)
 
-        elif layer_norm:
+        elif layer_norm is not None and not self.is_layer_hooked:
             hidden_states = layer_norm(hidden_states + residual_input)
-        elif residual_input is not None:
+        elif residual_input is not None and not self.is_layer_hooked:
             hidden_states = hidden_states + residual_input
 
         return hidden_states
@@ -365,26 +380,52 @@ def forward(self, hidden_states, residual_input, layer_norm):
         return self.bottleneck_layer_forward(hidden_states, residual_input, layer_norm)
 
 
-def hook_fn(adapter_layer, module, args, output):
-    # TODO: we currently cannot reliably pass residual input and layer norm here. This means "is_parallel" and "original_ln_before" are not supported.
+def hook_fn(adapter_layer, ln_get_fn, module, args, output):
+    # Retrieve residual from previous hook, if present
+    context = ForwardContext.get_context()
+    residual_input = getattr(context, f"{adapter_layer.location_key}_residual_input", None)
+    # Retrieve layer norm from getter fn
+    if ln_get_fn is not None:
+        layer_norm = ln_get_fn()
+    else:
+        layer_norm = None
+    # Call adapter layer
     if isinstance(output, torch.Tensor):
-        return adapter_layer(output, None, None)
+        return adapter_layer(output, residual_input, layer_norm)
     else:
-        return (adapter_layer(output[0], None, None),) + output[1:]
+        return (adapter_layer(output[0], residual_input, layer_norm),) + output[1:]
+
+
+def _residual_hook_fn(location_key, module, args):
+    context = ForwardContext.get_context()
+    if context is not None:
+        setattr(context, f"{location_key}_residual_input", args[0])
 
 
 def init_bottleneck(model):
-    for i, layer in model.iter_layers():
+    for _, layer in model.iter_layers():
         if self_attn := multigetattr(layer, model.adapter_interface.layer_self_attn, None):
             if o_proj := multigetattr(self_attn, model.adapter_interface.attn_o_proj, None):
                 if not hasattr(layer, "attention_adapters"):
                     layer.attention_adapters = BottleneckLayer("mh_adapter", is_layer_hooked=True)
-                o_proj.register_forward_hook(partial(hook_fn, layer.attention_adapters))
+                ln_1_get_fn = lambda layer=layer: multigetattr(layer, model.adapter_interface.layer_ln_1, None)
+                o_proj.register_forward_hook(partial(hook_fn, layer.attention_adapters, ln_1_get_fn))
         if layer_output_proj := multigetattr(layer, model.adapter_interface.layer_output_proj, None):
             if not hasattr(layer, "output_adapters"):
                 layer.output_adapters = BottleneckLayer("output_adapter", is_layer_hooked=True)
+            ln_2_get_fn = lambda layer=layer: multigetattr(layer, model.adapter_interface.layer_ln_2, None)
+            layer_output_proj.register_forward_hook(partial(hook_fn, layer.output_adapters, ln_2_get_fn))
         if cross_attn := multigetattr(layer, model.adapter_interface.layer_cross_attn, None):
             if not hasattr(cross_attn, "cross_attention_adapters"):
                 layer.attention_adapters = BottleneckLayer("cross_adapter", is_layer_hooked=True)
-            cross_attn.register_forward_hook(partial(hook_fn, layer.attention_adapters))
+            cross_attn.register_forward_hook(partial(hook_fn, layer.attention_adapters, None))
+
+        if model.adapter_interface.layer_pre_self_attn is not None:
+            if pre_self_attn := multigetattr(layer, model.adapter_interface.layer_pre_self_attn, None):
+                pre_self_attn.register_forward_pre_hook(partial(_residual_hook_fn, "mh_adapter"))
+        if model.adapter_interface.layer_pre_cross_attn is not None:
+            if pre_cross_attn := multigetattr(layer, model.adapter_interface.layer_pre_cross_attn, None):
+                pre_cross_attn.register_forward_pre_hook(partial(_residual_hook_fn, "cross_adapter"))
+        if model.adapter_interface.layer_pre_ffn is not None:
+            if pre_ffn := multigetattr(layer, model.adapter_interface.layer_pre_ffn, None):
+                pre_ffn.register_forward_pre_hook(partial(_residual_hook_fn, "output_adapter"))
diff --git a/tests/test_custom_interface.py b/tests/test_custom_interface.py
index 93cb45fa8..a75f571a6 100644
--- a/tests/test_custom_interface.py
+++ b/tests/test_custom_interface.py
@@ -4,14 +4,22 @@
 import torch
 
 import adapters
-from adapters import AdapterModelInterface, AdapterSetup, LoRAConfig, load_model
+from adapters import AdapterModelInterface, AdapterSetup, DoubleSeqBnConfig, LoRAConfig, ParBnConfig, load_model
 from transformers import Gemma2ForCausalLM, Gemma2ForSequenceClassification
 from transformers.models.gemma2.configuration_gemma2 import Gemma2Config
 from transformers.testing_utils import require_torch, torch_device
 
-from .methods import IA3TestMixin, LoRATestMixin, ReftTestMixin, create_twin_models
+from .methods import (
+    BottleneckAdapterTestMixin,
+    CompacterTestMixin,
+    IA3TestMixin,
+    LoRATestMixin,
+    ReftTestMixin,
+    create_twin_models,
+)
 from .test_adapter import AdapterTestBase, make_config
 from .test_adapter_embeddings import EmbeddingTestMixin
+from .test_adapter_fusion_common import AdapterFusionModelTestMixin
 
 
 class CustomInterfaceModelTestBase(AdapterTestBase):
@@ -28,7 +36,7 @@ class CustomInterfaceModelTestBase(AdapterTestBase):
     )
     tokenizer_name = "yujiepan/gemma-2-tiny-random"
     adapter_interface = AdapterModelInterface(
-        adapter_types=["lora", "reft"],
+        adapter_types=["bottleneck", "lora", "reft"],
         model_embeddings="embed_tokens",
         model_layers="layers",
         layer_self_attn="self_attn",
@@ -39,6 +47,11 @@ class CustomInterfaceModelTestBase(AdapterTestBase):
         attn_o_proj="o_proj",
         layer_intermediate_proj="mlp.up_proj",
         layer_output_proj="mlp.down_proj",
+        layer_pre_self_attn="input_layernorm",
+        layer_pre_cross_attn=None,
+        layer_pre_ffn="pre_feedforward_layernorm",
+        layer_ln_1="post_attention_layernorm",
+        layer_ln_2="post_feedforward_layernorm",
     )
 
     def get_model(self):
@@ -50,8 +63,8 @@ def get_model(self):
 
 @require_torch
 class CustomInterfaceModelTest(
-    # BottleneckAdapterTestMixin,
-    # CompacterTestMixin,
+    BottleneckAdapterTestMixin,
+    CompacterTestMixin,
     IA3TestMixin,
     LoRATestMixin,
     # PrefixTuningTestMixin,
@@ -59,13 +72,18 @@ class CustomInterfaceModelTest(
     ReftTestMixin,
     # UniPELTTestMixin,
     EmbeddingTestMixin,
-    # AdapterFusionModelTestMixin,
-    # CompabilityTestMixin,
+    AdapterFusionModelTestMixin,
     # ParallelAdapterInferenceTestMixin,
     # ParallelTrainingMixin,
     CustomInterfaceModelTestBase,
     unittest.TestCase,
 ):
+    # Modify the list here since we don't support MAMConfig (due to prefix tuning)
+    adapter_configs_to_test = [
+        (DoubleSeqBnConfig(), ["adapters.{name}."]),
+        (ParBnConfig(init_weights="bert"), ["adapters.{name}."]),
+    ]
+
     def create_twin_models(self):
         return create_twin_models(self.model_class, self.config, self.adapter_interface)
 
@@ -109,5 +127,11 @@ def _init_model_for_train_run(self, trained_adapter_name, frozen_adapter_name, a
 
         return model
 
+    def test_load_mam_adapter(self):
+        self.skipTest("Does not support prefix tuning.")
+
+    def test_train_mam_adapter(self):
+        self.skipTest("Does not support prefix tuning.")
+
     def test_merging_with_other_adapters(self):
         self.skipTest("Does not support all required methods yet.")
diff --git a/tests/test_custom_interface_compat.py b/tests/test_custom_interface_compat.py
index 760dbc753..922ececdd 100644
--- a/tests/test_custom_interface_compat.py
+++ b/tests/test_custom_interface_compat.py
@@ -34,7 +34,7 @@ class CustomInterfaceCompatTest(unittest.TestCase):
         hidden_act="gelu",
         pad_token_id=0,
     )
-    llama_adapter_interface = AdapterModelInterface(
+    llama_interface = AdapterModelInterface(
         adapter_types=["bottleneck", "lora", "reft"],
         model_embeddings="embed_tokens",
         model_layers="layers",
@@ -46,8 +46,13 @@ class CustomInterfaceCompatTest(unittest.TestCase):
         attn_o_proj="o_proj",
         layer_intermediate_proj="mlp.up_proj",
         layer_output_proj="mlp.down_proj",
+        layer_pre_self_attn="input_layernorm",
+        layer_pre_cross_attn=None,
+        layer_pre_ffn="post_attention_layernorm",
+        layer_ln_1=None,
+        layer_ln_2=None,
     )
-    bert_adapter_interface = AdapterModelInterface(
+    bert_interface = AdapterModelInterface(
         adapter_types=["bottleneck", "lora", "reft", "prompt_tuning"],
         model_embeddings="embeddings",
         model_layers="encoder.layer",
@@ -59,6 +64,11 @@ class CustomInterfaceCompatTest(unittest.TestCase):
         attn_o_proj="output.dense",
         layer_intermediate_proj="intermediate.dense",
         layer_output_proj="output.dense",
+        layer_pre_self_attn="attention.self",
+        layer_pre_cross_attn=None,
+        layer_pre_ffn="intermediate",
+        layer_ln_1="attention.output.LayerNorm",
+        layer_ln_2="output.LayerNorm",
     )
 
     bert_bn_rewrites = [(".attention_adapters.", ".attention.output."), (".output_adapters.", ".output.")]
@@ -73,41 +83,73 @@ def create_twin_models(self, config, adapter_interface, hf_auto_model_class):
 
     @parameterized.expand(
         [
-            ("LoRA_Llama", adapters.LoRAConfig(), llama_config, llama_adapter_interface, AutoModelForCausalLM),
-            ("LoRA_BERT", adapters.LoRAConfig(), bert_config, bert_adapter_interface, AutoModel),
-            ("LoReft_Llama", adapters.LoReftConfig(), llama_config, llama_adapter_interface, AutoModelForCausalLM),
-            ("LoReft_BERT", adapters.LoReftConfig(), bert_config, bert_adapter_interface, AutoModel),
+            ("LoRA_Llama", adapters.LoRAConfig(), llama_config, llama_interface, AutoModelForCausalLM),
+            ("LoRA_BERT", adapters.LoRAConfig(), bert_config, bert_interface, AutoModel),
+            ("LoReft_Llama", adapters.LoReftConfig(), llama_config, llama_interface, AutoModelForCausalLM),
+            ("LoReft_BERT", adapters.LoReftConfig(), bert_config, bert_interface, AutoModel),
             (
                 "BnSeq_Llama",
                 adapters.SeqBnConfig(original_ln_before=False),
                 llama_config,
-                llama_adapter_interface,
+                llama_interface,
                 AutoModelForCausalLM,
             ),
+            (
+                "BnSeqPreLN_Llama",
+                adapters.SeqBnConfig(original_ln_before=True),
+                llama_config,
+                llama_interface,
+                AutoModelForCausalLM,
+            ),
+            ("BnPar_Llama", adapters.ParBnConfig(), llama_config, llama_interface, AutoModelForCausalLM),
             (
                 "Bn2Seq_Llama",
-                adapters.DoubleSeqBnConfig(),
+                adapters.DoubleSeqBnConfig(original_ln_before=True),
                 llama_config,
-                llama_adapter_interface,
+                llama_interface,
+                AutoModelForCausalLM,
+            ),
+            (
+                "Bn2Par_Llama",
+                adapters.ParBnConfig(mh_adapter=True, output_adapter=True),
+                llama_config,
+                llama_interface,
                 AutoModelForCausalLM,
             ),
             (
                 "BnSeq_BERT",
                 adapters.SeqBnConfig(original_ln_before=False),
                 bert_config,
-                bert_adapter_interface,
+                bert_interface,
+                AutoModel,
+                bert_bn_rewrites,
+            ),
+            (
+                "BnSeqPreLN_BERT",
+                adapters.SeqBnConfig(original_ln_before=True),
+                bert_config,
+                bert_interface,
                 AutoModel,
                 bert_bn_rewrites,
             ),
+            ("BnPar_BERT", adapters.ParBnConfig(), bert_config, bert_interface, AutoModel, bert_bn_rewrites),
             (
                 "Bn2Seq_BERT",
-                adapters.DoubleSeqBnConfig(),
+                adapters.DoubleSeqBnConfig(original_ln_before=True),
+                bert_config,
+                bert_interface,
+                AutoModel,
+                bert_bn_rewrites,
+            ),
+            (
+                "Bn2Par_BERT",
+                adapters.ParBnConfig(mh_adapter=True, output_adapter=True),
                 bert_config,
-                bert_adapter_interface,
+                bert_interface,
                 AutoModel,
                 bert_bn_rewrites,
             ),
-            ("Prompt_BERT", adapters.PromptTuningConfig(), bert_config, bert_adapter_interface, AutoModel),
+            ("Prompt_BERT", adapters.PromptTuningConfig(), bert_config, bert_interface, AutoModel),
         ]
     )
     def test_load_adapter(self, name, adapter_config, config, adapter_interface, hf_auto_model_class, rewrites=None):
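
Usage sketch (illustration only, not part of the patch): with the extended interface, bottleneck adapters
can be enabled for a model without built-in adapters support roughly as follows. The attribute values
mirror the Gemma-2 test configuration added in tests/test_custom_interface.py above; the checkpoint name
is a placeholder, and adapters.init(model, interface=...) is assumed to be the library's documented
plugin-interface entry point.

    import adapters
    from adapters import AdapterModelInterface, SeqBnConfig
    from transformers import AutoModelForCausalLM

    # placeholder checkpoint; any Gemma-2-style causal LM is wired the same way
    model = AutoModelForCausalLM.from_pretrained("path/to/gemma-2-checkpoint")

    interface = AdapterModelInterface(
        adapter_types=["bottleneck", "lora", "reft"],
        model_embeddings="embed_tokens",
        model_layers="layers",
        layer_self_attn="self_attn",
        layer_cross_attn=None,
        attn_k_proj="k_proj",
        attn_q_proj="q_proj",
        attn_v_proj="v_proj",
        attn_o_proj="o_proj",
        layer_intermediate_proj="mlp.up_proj",
        layer_output_proj="mlp.down_proj",
        # optional hook points introduced by this patch
        layer_pre_self_attn="input_layernorm",
        layer_pre_cross_attn=None,
        layer_pre_ffn="pre_feedforward_layernorm",
        layer_ln_1="post_attention_layernorm",
        layer_ln_2="post_feedforward_layernorm",
    )

    adapters.init(model, interface=interface)
    # original_ln_before=True works in hooked mode because the pre-attention/pre-FFN
    # hooks capture the residual and the ln_1/ln_2 getters expose the layer norms
    model.add_adapter("bn_adapter", config=SeqBnConfig(original_ln_before=True))
    model.set_active_adapters("bn_adapter")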