diff --git a/src/adapters/__init__.py b/src/adapters/__init__.py
index a917828e7..d6a0b17fd 100644
--- a/src/adapters/__init__.py
+++ b/src/adapters/__init__.py
@@ -64,6 +64,7 @@
         "SeqBnInvConfig",
         "StaticAdapterFusionConfig",
         "UniPELTConfig",
+        "VeraConfig",
     ],
     "context": [
         "AdapterSetup",
@@ -181,6 +182,7 @@
         SeqBnInvConfig,
         StaticAdapterFusionConfig,
         UniPELTConfig,
+        VeraConfig,
     )
     from .context import AdapterSetup, ForwardContext
     from .heads import (
diff --git a/src/adapters/configuration/adapter_config.py b/src/adapters/configuration/adapter_config.py
index b5249cb9f..a6cc14917 100644
--- a/src/adapters/configuration/adapter_config.py
+++ b/src/adapters/configuration/adapter_config.py
@@ -478,11 +478,20 @@ class LoRAConfig(AdapterConfig):
             (addition of decomposed matrix, as in LoRA) or "scale" (element-wise multiplication of vector, as in
             (IA)^3). "scale" can only be used together with r=1. Defaults to "add".
         init_weights (:obj:`str`, optional): Initialization method for the weights of the LoRA modules.
-            Currently, this can be either "lora" (default) or "bert".
+            Currently, this can be either "lora" (default), "bert", or "vera".
         use_gating (:obj:`bool`, optional):
             Place a trainable gating module besides the added parameter module to control module activation. This is
             e.g. used for UniPELT. Defaults to False. Note that modules with use_gating=True cannot be merged using
             `merge_adapter()`.
+        vera_d (:obj:`float`, optional):
+            The initial value of d used by the VeraConfig. Defaults to None. Places a trainable
+            scaling parameter `d` before the decomposition matrix A to allow scaling of the
+            internal weights.
+
+        vera_b (:obj:`float`, optional):
+            The initial value of b used by the VeraConfig. Defaults to None. Places a trainable
+            scaling parameter `b` before the decomposition matrix B to allow scaling of the
+            internal weights.
         dtype (str, optional): torch dtype for reparametrization tensors. Defaults to None.
     """

@@ -500,6 +509,8 @@ class LoRAConfig(AdapterConfig):
     composition_mode: str = "add"
     init_weights: str = "lora"
     use_gating: bool = False
+    vera_d: float = None
+    vera_b: float = None
     dtype: Optional[str] = None


@@ -526,6 +537,29 @@ class IA3Config(LoRAConfig):
     dtype: Optional[str] = None


+@dataclass(eq=False)
+class VeraConfig(LoRAConfig):
+    """
+    LoRA config that applies Vector-based Random Matrix Adaptation (VeRA). It adds
+    trainable scaling parameters `d` and `b` while keeping the LoRA decomposition
+    matrices A and B frozen, random, and shared across layers. See
+    https://arxiv.org/pdf/2310.11454 for details. Note that `r` must still be set,
+    since the shared decomposition matrices A and B are initialized with it, and
+    `composition_mode` must be set to "add".
+ """ + + selfattn_lora: bool = True + intermediate_lora: bool = False + output_lora: bool = False + + r: int = 8 + vera_d: float = 0.1 + vera_b: float = 0.0 + init_weights: str = "vera" + composition_mode: str = "add" + dtype: Optional[str] = None + + @dataclass(eq=False) class ReftConfig(AdapterConfig): """ diff --git a/src/adapters/methods/lora.py b/src/adapters/methods/lora.py index d56a11a91..fe6730058 100644 --- a/src/adapters/methods/lora.py +++ b/src/adapters/methods/lora.py @@ -16,6 +16,7 @@ from ..composition import Average, BatchSplit, Parallel, Stack from ..configuration import LoRAConfig, ModelAdaptersConfig +from ..context import ForwardContext from .adapter_layer_base import AdapterLayerBase, ComposableAdapterLayerBase from .utils import dequantize_bnb_weight @@ -37,6 +38,7 @@ def __init__( lora_B_shape, config: LoRAConfig, gating_heads: int = 1, + name: str = None, ): super().__init__() assert config.composition_mode == "add", "LoRA module only supports composition_mode='add'." @@ -45,6 +47,7 @@ def __init__( self.composition_mode = config.composition_mode self.attn_matrices = config.attn_matrices self.use_gating = config.use_gating + self.name = name # Optional dropout if config.dropout > 0.0: self.lora_dropout = nn.Dropout(p=config.dropout) @@ -69,6 +72,9 @@ def __init__( elif config.init_weights == "ia3": nn.init.ones_(self.lora_A) nn.init.ones_(self.lora_B) + elif config.init_weights == "vera": + nn.init.kaiming_uniform_(self.lora_A) + nn.init.kaiming_uniform_(self.lora_B) else: raise ValueError("Unknown init_weights type: {}".format(config.init_weights)) @@ -111,6 +117,7 @@ def __init__( lora_B_shape, config: LoRAConfig, gating_heads: int = 1, + name: str = None, ): super().__init__() assert config.composition_mode == "scale", "IA3 module only supports composition_mode='scale'." @@ -121,6 +128,7 @@ def __init__( self.composition_mode = config.composition_mode self.attn_matrices = config.attn_matrices self.use_gating = config.use_gating + self.name = name # Optional dropout if config.dropout > 0.0: raise ValueError("IA3 module does not support dropout.") @@ -132,7 +140,7 @@ def __init__( # For compatibility with LoRA, allow all init_weights types here. # Usually should be "ia3". if config.init_weights == "lora": - logger.warning("(IA)^3 module initialized with LoRA zeo init. Ignore if this is intended.") + logger.warning("(IA)^3 module initialized with LoRA zero init. 
Ignore if this is intended.") nn.init.zeros_(self.lora_B) elif config.init_weights == "bert": nn.init.normal_(self.lora_B, std=0.02) @@ -175,6 +183,116 @@ def forward(self, hidden_states: Optional[torch.Tensor], layer_input: torch.Tens return hidden_states, gate +class Vera(nn.Module): + def __init__( + self, + lora_A_shape, + lora_B_shape, + config: LoRAConfig, + gating_heads: int = 1, + name: str = None, + ): + super().__init__() + self.d = config.vera_d + self.b = config.vera_b + self.r = config.r + self.alpha = config.alpha + self.use_gating = config.use_gating + self.name = name + + # check to make sure that the `composition_mode` is set to `add` + self.composition_mode = config.composition_mode + if self.composition_mode != "add": + raise ValueError("Vera module only supports composition_mode='add'.") + + # Optional dropout + if config.dropout > 0.0: + self.lora_dropout = nn.Dropout(p=config.dropout) + + self.lora_A_shape = lora_A_shape + self.lora_B_shape = lora_B_shape + self.d_shape = self.lora_A_shape[0] + self.b_shape = self.lora_B_shape[0] + + # Actual trainable parameters + self.vera_D = nn.Parameter(torch.diag(torch.ones(self.d_shape) * self.d)) + self.vera_B = nn.Parameter(torch.diag(torch.ones(self.b_shape) * self.b)) + self.scaling = self.alpha / self.r + + if self.use_gating: + self.gate = nn.Linear(lora_A_shape[-1], gating_heads) + nn.init.normal_(self.gate.weight, std=0.02) + + @property + def delta_w(self) -> torch.Tensor: + parameters = ForwardContext.get_context().shared_parameters[self.name] + lora_A = parameters["lora_A"] + lora_B = parameters["lora_B"] + return self.vera_B @ lora_B @ self.vera_D @ lora_A + + def com(self, weights: torch.Tensor, added: torch.Tensor, scaling=None) -> torch.Tensor: + """Performs the composition operation between existing and injected weights.""" + if scaling is None: + scaling = self.scaling + return weights + added * scaling + + def com_inv(self, weights: torch.Tensor, added: torch.Tensor) -> torch.Tensor: + """Inverts the composition operation between existing and injected weights.""" + return weights - added * self.scaling + + def forward(self, hidden_states: Optional[torch.Tensor], layer_input: torch.Tensor): + parameters = ForwardContext.get_context().shared_parameters[self.name] + lora_A = parameters["lora_A"] + lora_B = parameters["lora_B"] + + if hidden_states is None: + hidden_states = layer_input + + if getattr(self, "lora_dropout"): + hidden_states = self.lora_dropout(hidden_states) + + hidden_states = hidden_states @ torch.t(self.vera_B @ lora_B @ self.vera_D @ lora_A) + + if self.use_gating: + gate = torch.sigmoid(self.gate(layer_input)) + gate = torch.mean(gate, dim=1).unsqueeze(-1) + hidden_states = hidden_states * gate + else: + gate = None + hidden_states = hidden_states * self.scaling + + return hidden_states, gate + + +def init_shared_vera_parameters(model_config, adapter_config, device): + hidden_size = model_config.hidden_size + r = adapter_config["r"] + + parameters = nn.ParameterDict() + + # initialize frozen, random tensors A, B + parameters["lora_A"] = torch.zeros(r, hidden_size).to(device) + parameters["lora_B"] = torch.zeros(hidden_size, r).to(device) + + if adapter_config["init_weights"] == "lora": + # initialize A the same way as the default for nn.Linear and B to zero + nn.init.kaiming_uniform_(parameters["lora_A"], a=math.sqrt(5)) + nn.init.zeros_(parameters["lora_B"]) + elif adapter_config["init_weights"] == "bert": + nn.init.normal_(parameters["lora_A"], std=0.02) + 
nn.init.normal_(parameters["lora_B"], std=0.02) + elif adapter_config["init_weights"] == "ia3": + nn.init.ones_(parameters["lora_A"]) + nn.init.ones_(parameters["lora_B"]) + elif adapter_config["init_weights"] == "vera": + nn.init.kaiming_uniform_(parameters["lora_A"]) + nn.init.kaiming_uniform_(parameters["lora_B"]) + else: + raise ValueError("Unknown init_weights type: {}".format(adapter_config["init_weights"])) + + return parameters + + class LoRALayer(AdapterLayerBase): adapter_modules_name = "loras" @@ -200,6 +318,7 @@ def _get_lora_shapes(self, config: LoRAConfig): def add_adapter(self, adapter_name: str, layer_idx: int) -> bool: self.layer_idx = layer_idx + lora_config = self.adapters_config.match( adapter_name, config_type=LoRAConfig, @@ -208,7 +327,10 @@ def add_adapter(self, adapter_name: str, layer_idx: int) -> bool: ) if lora_config is not None and self._check_lora_location(lora_config): if lora_config.composition_mode == "add": - lora_cls = LoRA + if isinstance(lora_config.vera_d, float) or isinstance(lora_config.vera_b, float): + lora_cls = Vera + else: + lora_cls = LoRA elif lora_config.composition_mode == "scale": lora_cls = IA3 else: @@ -217,7 +339,9 @@ def add_adapter(self, adapter_name: str, layer_idx: int) -> bool: *self._get_lora_shapes(lora_config), lora_config, gating_heads=self.get_n_heads(lora_config), + name=adapter_name, ) + lora.train(self.training) lora = lora.to(self.weight.device) self.loras[adapter_name] = lora diff --git a/src/adapters/model_mixin.py b/src/adapters/model_mixin.py index 659a6cfcf..5e142ad4c 100644 --- a/src/adapters/model_mixin.py +++ b/src/adapters/model_mixin.py @@ -10,7 +10,7 @@ import torch from torch import nn -from adapters.configuration.adapter_config import ConfigUnion, LoRAConfig +from adapters.configuration.adapter_config import ConfigUnion, LoRAConfig, VeraConfig from transformers import GenerationConfig from transformers.modeling_outputs import ModelOutput from transformers.utils import is_accelerate_available @@ -22,7 +22,7 @@ from .loading import AdapterFusionLoader, AdapterLoader, PredictionHeadLoader, WeightsLoader from .methods.adapter_layer_base import AdapterLayerBase from .methods.bottleneck import BottleneckLayer -from .methods.lora import LoRALayer +from .methods.lora import LoRALayer, init_shared_vera_parameters from .methods.modeling import Adapter, GLOWCouplingBlock, NICECouplingBlock, init_shared_parameters from .methods.prefix_tuning import PrefixTuningLayer, PrefixTuningPool from .methods.prompt_tuning import PromptTuningLayer @@ -610,13 +610,21 @@ def _add_adapter_weights(self, adapter_name: str): ) else: raise ValueError( - "The model has different hidden sizes {}. Sharing comapcter weights is only possible if" + "The model has different hidden sizes {}. Sharing compacter weights is only possible if" " the hidden_sizes match.".format(hidden_sizes) ) else: self.base_model.shared_parameters[adapter_name] = init_shared_parameters( adapter_config, self.config.hidden_size, self.device ) + + # Vera Initialization + if self.adapters_config.match(adapter_name, VeraConfig): + adapter_config = self.adapters_config.match(adapter_name, VeraConfig) + self.base_model.shared_parameters[adapter_name] = init_shared_vera_parameters( + self.config, adapter_config, self.device + ) + # Prefix Tuning for module in self.modules(): if isinstance(module, PrefixTuningPool):
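
Usage sketch (reviewer note, not part of the diff): a minimal example of how the new VeraConfig could be used with the existing adapters workflow, assuming the standard adapters.init() + add_adapter() flow. The checkpoint name, adapter name, and hyperparameter values below are illustrative placeholders, not something this PR prescribes.

    # Minimal sketch: train only the VeRA scaling parameters on top of a frozen base model.
    from transformers import AutoModelForSequenceClassification

    import adapters
    from adapters import VeraConfig

    model = AutoModelForSequenceClassification.from_pretrained("roberta-base", num_labels=2)
    adapters.init(model)  # wrap the HF model so adapter methods become available

    # r is still required: the frozen, shared matrices A (r x hidden) and B (hidden x r)
    # are created once per adapter via init_shared_vera_parameters(), so that only the
    # scaling parameters d and b (plus the optional gate) are trained.
    config = VeraConfig(r=8, vera_d=0.1, vera_b=0.0)
    model.add_adapter("vera_adapter", config=config)
    model.train_adapter("vera_adapter")  # freezes base weights, activates the new adapter

    print(model.adapter_summary())

Because vera_d and vera_b are set to floats on the config, LoRALayer.add_adapter routes this adapter to the new Vera module instead of the plain LoRA module.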