From 57c51310b052fe62aa6051072ed1ac4ca8314dee Mon Sep 17 00:00:00 2001 From: julian fong Date: Sun, 1 Dec 2024 12:22:12 -0500 Subject: [PATCH 01/15] initial commit --- src/adapters/configuration/adapter_config.py | 33 ++++++++++- src/adapters/methods/lora.py | 59 +++++++++++++++++++- 2 files changed, 89 insertions(+), 3 deletions(-) diff --git a/src/adapters/configuration/adapter_config.py b/src/adapters/configuration/adapter_config.py index 0f2eec2162..9ff882109e 100644 --- a/src/adapters/configuration/adapter_config.py +++ b/src/adapters/configuration/adapter_config.py @@ -478,11 +478,18 @@ class LoRAConfig(AdapterConfig): (addition of decomposed matrix, as in LoRA) or "scale" (element-wise multiplication of vector, as in (IA)^3). "scale" can only be used together with r=1. Defaults to "add". init_weights (:obj:`str`, optional): Initialization method for the weights of the LoRA modules. - Currently, this can be either "lora" (default) or "bert". + Currently, this can be either "lora" (default) or "bert", or "vera". use_gating (:obj:`bool`, optional): Place a trainable gating module besides the added parameter module to control module activation. This is e.g. used for UniPELT. Defaults to False. Note that modules with use_gating=True cannot be merged using `merge_adapter()`. + d (:obj:`bool` or :obj:`float`, optional): + The value of d used in the VeraConfig. Defaults to None + + b (:obj:`bool` or :obj:`float`, optional): + The value of b used in the VeraConfig. Defaults to None + + """ architecture: Optional[str] = "lora" @@ -499,6 +506,8 @@ class LoRAConfig(AdapterConfig): composition_mode: str = "add" init_weights: str = "lora" use_gating: bool = False + d: Union[bool, float] = None + b: Union[bool, float] = None @dataclass(eq=False) @@ -522,6 +531,28 @@ class IA3Config(LoRAConfig): init_weights: str = "ia3" use_gating: bool = False +@dataclass(eq=False) +class VeraConfig(LoRAConfig): + """ + Lora Config that applies vector-based random matrix adaptation. It adds + trainable matrices 'd' and 'b' while keeping the original LoRA matrices + frozen, random, and shared across layers. See more through their paper: + https://arxiv.org/pdf/2106.09685. Note that `r` will still be supplied + since we are still initializing decomposition matrices A and B, + however the `composition_mode` parameter along with the + `use_gating` parameter will be ignored. + """ + + selfattn_lora: bool = False + intermediate_lora: bool = False + output_lora: bool = False + + r: int = 8 + init_weights: str = "vera" + d: Union[bool, float] = 0.1 + b: Union[bool, float] = 0 + + @dataclass(eq=False) class ReftConfig(AdapterConfig): diff --git a/src/adapters/methods/lora.py b/src/adapters/methods/lora.py index c62a94f265..a5beedf72d 100644 --- a/src/adapters/methods/lora.py +++ b/src/adapters/methods/lora.py @@ -45,6 +45,7 @@ def __init__( self.composition_mode = config.composition_mode self.attn_matrices = config.attn_matrices self.use_gating = config.use_gating + # Optional dropout if config.dropout > 0.0: self.lora_dropout = nn.Dropout(p=config.dropout) @@ -54,7 +55,7 @@ def __init__( # Actual trainable parameters self.lora_A = nn.Parameter(torch.zeros(lora_A_shape)) self.lora_B = nn.Parameter(torch.zeros(lora_B_shape)) - self.scaling = self.lora_alpha / self.r + self.scaling = self.lora_alpha / self.r # For compatibility with (IA)^3, allow all init_weights types here. # Usually should be "lora". @@ -131,7 +132,7 @@ def __init__( # For compatibility with LoRA, allow all init_weights types here. 
# Usually should be "ia3". if config.init_weights == "lora": - logger.warning("(IA)^3 module initialized with LoRA zeo init. Ignore if this is intended.") + logger.warning("(IA)^3 module initialized with LoRA zero init. Ignore if this is intended.") nn.init.zeros_(self.lora_B) elif config.init_weights == "bert": nn.init.normal_(self.lora_B, std=0.02) @@ -173,7 +174,59 @@ def forward(self, hidden_states: Optional[torch.Tensor], layer_input: torch.Tens return hidden_states, gate +class Vera(nn.Module): + def __init__( + self, + lora_A_shape, + lora_B_shape, + config: LoRAConfig, + ): + super().__init__() + self.d = config.d + self.b = config.b + + self.lora_A_shape = lora_A_shape + self.lora_B_shape = lora_B_shape + self.d_shape = self.lora_A_shape[1] + self.b_shape = self.lora_B_shape[0] + + #initialize frozen, random tensors + self.lora_A = torch.tensor(torch.zeros(lora_A_shape)) + self.lora_B = torch.tensor(torch + .zeros(lora_B_shape)) + + # Actual trainable parameters + self.vera_D = nn.Parameter(torch.diag(torch.ones(self.d_shape)*self.d)) + self.vera_B = nn.Parameter(torch.diag(torch.ones(self.b_shape)*self.b)) + + # For compatibility with LoRA, allow all init_weights types here. + # Usually should be "vera" or "lora". + if config.init_weights == "lora": + # initialize A the same way as the default for nn.Linear and B to zero + nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5)) + nn.init.zeros_(self.lora_B) + elif config.init_weights == "bert": + nn.init.normal_(self.lora_A, std=0.02) + nn.init.normal_(self.lora_B, std=0.02) + elif config.init_weights == "ia3": + nn.init.ones_(self.lora_A) + nn.init.ones_(self.lora_B) + elif config.init_weights == "vera": + nn.kaiming.uniform_(self.lora_A) + nn.kaiming.uniform_(self.lora_B) + else: + raise ValueError("Unknown init_weights type: {}".format(config.init_weights)) + + @property + def delta_w(self) -> torch.Tensor: + return self.vera_B @ self.lora_B @ self.vera_D @ self.lora_A + + def forward(self, hidden_states: Optional[torch.Tensor], layer_input: torch.Tensor): + if hidden_states is None: + hidden_states = layer_input + hidden_states = self.vera_B @ self.lora_B @ self.vera_D @ self.lora_A + return hidden_states class LoRALayer(AdapterLayerBase): adapter_modules_name = "loras" @@ -212,6 +265,8 @@ def add_adapter(self, adapter_name: str, layer_idx: int) -> bool: lora_cls = IA3 else: raise ValueError(f"Unknown composition_mode: {lora_config.composition_mode}") + #figure out good criteria to load vera + # lora = lora_cls( *self._get_lora_shapes(lora_config), lora_config, From 259a2689d54d8fd57249e82fc79459a01f37fc23 Mon Sep 17 00:00:00 2001 From: julian fong Date: Sun, 1 Dec 2024 12:27:48 -0500 Subject: [PATCH 02/15] improved docstring and fixed formatting issues --- src/adapters/configuration/adapter_config.py | 12 ++++++------ src/adapters/methods/lora.py | 2 +- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/src/adapters/configuration/adapter_config.py b/src/adapters/configuration/adapter_config.py index 9ff882109e..188c38b272 100644 --- a/src/adapters/configuration/adapter_config.py +++ b/src/adapters/configuration/adapter_config.py @@ -484,12 +484,14 @@ class LoRAConfig(AdapterConfig): e.g. used for UniPELT. Defaults to False. Note that modules with use_gating=True cannot be merged using `merge_adapter()`. d (:obj:`bool` or :obj:`float`, optional): - The value of d used in the VeraConfig. Defaults to None + The value of d used in the VeraConfig. Defaults to None. 
Places a trainable + scaling parameter before the decomposition matrix A to allow scaling of the + internal weights. b (:obj:`bool` or :obj:`float`, optional): - The value of b used in the VeraConfig. Defaults to None - - + The value of b used in the VeraConfig. Defaults to None. Places a trainable + scaling parameter before the decomposition matrix B to allow scaling of the + internal weights. """ architecture: Optional[str] = "lora" @@ -552,8 +554,6 @@ class VeraConfig(LoRAConfig): d: Union[bool, float] = 0.1 b: Union[bool, float] = 0 - - @dataclass(eq=False) class ReftConfig(AdapterConfig): """ diff --git a/src/adapters/methods/lora.py b/src/adapters/methods/lora.py index a5beedf72d..d1b67fcfea 100644 --- a/src/adapters/methods/lora.py +++ b/src/adapters/methods/lora.py @@ -55,7 +55,7 @@ def __init__( # Actual trainable parameters self.lora_A = nn.Parameter(torch.zeros(lora_A_shape)) self.lora_B = nn.Parameter(torch.zeros(lora_B_shape)) - self.scaling = self.lora_alpha / self.r + self.scaling = self.lora_alpha / self.r # For compatibility with (IA)^3, allow all init_weights types here. # Usually should be "lora". From b66571c124f5901073638a103040e883a67b1aaf Mon Sep 17 00:00:00 2001 From: julian fong Date: Sun, 1 Dec 2024 16:21:29 -0500 Subject: [PATCH 03/15] fixed formatting --- src/adapters/configuration/adapter_config.py | 18 ++++++++------ src/adapters/methods/lora.py | 26 +++++++++++--------- 2 files changed, 24 insertions(+), 20 deletions(-) diff --git a/src/adapters/configuration/adapter_config.py b/src/adapters/configuration/adapter_config.py index 188c38b272..2d940ebcc6 100644 --- a/src/adapters/configuration/adapter_config.py +++ b/src/adapters/configuration/adapter_config.py @@ -485,12 +485,12 @@ class LoRAConfig(AdapterConfig): `merge_adapter()`. d (:obj:`bool` or :obj:`float`, optional): The value of d used in the VeraConfig. Defaults to None. Places a trainable - scaling parameter before the decomposition matrix A to allow scaling of the + scaling parameter before the decomposition matrix A to allow scaling of the internal weights. - + b (:obj:`bool` or :obj:`float`, optional): The value of b used in the VeraConfig. Defaults to None. Places a trainable - scaling parameter before the decomposition matrix B to allow scaling of the + scaling parameter before the decomposition matrix B to allow scaling of the internal weights. """ @@ -533,18 +533,19 @@ class IA3Config(LoRAConfig): init_weights: str = "ia3" use_gating: bool = False + @dataclass(eq=False) class VeraConfig(LoRAConfig): """ Lora Config that applies vector-based random matrix adaptation. It adds trainable matrices 'd' and 'b' while keeping the original LoRA matrices - frozen, random, and shared across layers. See more through their paper: + frozen, random, and shared across layers. See more through their paper: https://arxiv.org/pdf/2106.09685. Note that `r` will still be supplied - since we are still initializing decomposition matrices A and B, - however the `composition_mode` parameter along with the + since we are still initializing decomposition matrices A and B, + however the `composition_mode` parameter along with the `use_gating` parameter will be ignored. 
""" - + selfattn_lora: bool = False intermediate_lora: bool = False output_lora: bool = False @@ -553,7 +554,8 @@ class VeraConfig(LoRAConfig): init_weights: str = "vera" d: Union[bool, float] = 0.1 b: Union[bool, float] = 0 - + + @dataclass(eq=False) class ReftConfig(AdapterConfig): """ diff --git a/src/adapters/methods/lora.py b/src/adapters/methods/lora.py index d1b67fcfea..c6f4ccf61e 100644 --- a/src/adapters/methods/lora.py +++ b/src/adapters/methods/lora.py @@ -174,6 +174,7 @@ def forward(self, hidden_states: Optional[torch.Tensor], layer_input: torch.Tens return hidden_states, gate + class Vera(nn.Module): def __init__( self, @@ -184,20 +185,19 @@ def __init__( super().__init__() self.d = config.d self.b = config.b - + self.lora_A_shape = lora_A_shape self.lora_B_shape = lora_B_shape self.d_shape = self.lora_A_shape[1] self.b_shape = self.lora_B_shape[0] - - #initialize frozen, random tensors + + # initialize frozen, random tensors self.lora_A = torch.tensor(torch.zeros(lora_A_shape)) - self.lora_B = torch.tensor(torch - .zeros(lora_B_shape)) - + self.lora_B = torch.tensor(torch.zeros(lora_B_shape)) + # Actual trainable parameters - self.vera_D = nn.Parameter(torch.diag(torch.ones(self.d_shape)*self.d)) - self.vera_B = nn.Parameter(torch.diag(torch.ones(self.b_shape)*self.b)) + self.vera_D = nn.Parameter(torch.diag(torch.ones(self.d_shape) * self.d)) + self.vera_B = nn.Parameter(torch.diag(torch.ones(self.b_shape) * self.b)) # For compatibility with LoRA, allow all init_weights types here. # Usually should be "vera" or "lora". @@ -216,17 +216,19 @@ def __init__( nn.kaiming.uniform_(self.lora_B) else: raise ValueError("Unknown init_weights type: {}".format(config.init_weights)) - + @property def delta_w(self) -> torch.Tensor: return self.vera_B @ self.lora_B @ self.vera_D @ self.lora_A - + def forward(self, hidden_states: Optional[torch.Tensor], layer_input: torch.Tensor): if hidden_states is None: hidden_states = layer_input - hidden_states = self.vera_B @ self.lora_B @ self.vera_D @ self.lora_A + hidden_states = self.vera_B @ self.lora_B @ self.vera_D @ self.lora_A return hidden_states + + class LoRALayer(AdapterLayerBase): adapter_modules_name = "loras" @@ -265,7 +267,7 @@ def add_adapter(self, adapter_name: str, layer_idx: int) -> bool: lora_cls = IA3 else: raise ValueError(f"Unknown composition_mode: {lora_config.composition_mode}") - #figure out good criteria to load vera + # figure out good criteria to load vera # lora = lora_cls( *self._get_lora_shapes(lora_config), From acee994622dbd214052004d59b606888cea636c2 Mon Sep 17 00:00:00 2001 From: julian fong Date: Thu, 12 Dec 2024 14:16:35 -0500 Subject: [PATCH 04/15] updates --- src/adapters/__init__.py | 2 + src/adapters/configuration/adapter_config.py | 9 ++- src/adapters/methods/lora.py | 68 ++++++++++---------- src/adapters/model_mixin.py | 16 +++-- 4 files changed, 51 insertions(+), 44 deletions(-) diff --git a/src/adapters/__init__.py b/src/adapters/__init__.py index a917828e72..d6a0b17fda 100644 --- a/src/adapters/__init__.py +++ b/src/adapters/__init__.py @@ -64,6 +64,7 @@ "SeqBnInvConfig", "StaticAdapterFusionConfig", "UniPELTConfig", + "VeraConfig", ], "context": [ "AdapterSetup", @@ -181,6 +182,7 @@ SeqBnInvConfig, StaticAdapterFusionConfig, UniPELTConfig, + VeraConfig, ) from .context import AdapterSetup, ForwardContext from .heads import ( diff --git a/src/adapters/configuration/adapter_config.py b/src/adapters/configuration/adapter_config.py index 2d940ebcc6..3adbef2ddf 100644 --- 
a/src/adapters/configuration/adapter_config.py +++ b/src/adapters/configuration/adapter_config.py @@ -485,12 +485,12 @@ class LoRAConfig(AdapterConfig): `merge_adapter()`. d (:obj:`bool` or :obj:`float`, optional): The value of d used in the VeraConfig. Defaults to None. Places a trainable - scaling parameter before the decomposition matrix A to allow scaling of the + scaling parameter `d` before the decomposition matrix A to allow scaling of the internal weights. b (:obj:`bool` or :obj:`float`, optional): The value of b used in the VeraConfig. Defaults to None. Places a trainable - scaling parameter before the decomposition matrix B to allow scaling of the + scaling parameter `b` before the decomposition matrix B to allow scaling of the internal weights. """ @@ -541,9 +541,8 @@ class VeraConfig(LoRAConfig): trainable matrices 'd' and 'b' while keeping the original LoRA matrices frozen, random, and shared across layers. See more through their paper: https://arxiv.org/pdf/2106.09685. Note that `r` will still be supplied - since we are still initializing decomposition matrices A and B, - however the `composition_mode` parameter along with the - `use_gating` parameter will be ignored. + since we are still initializing decomposition matrices A and B. + The `composition_mode` parameter should also be set to `add`. """ selfattn_lora: bool = False diff --git a/src/adapters/methods/lora.py b/src/adapters/methods/lora.py index c6f4ccf61e..7e1afb9f26 100644 --- a/src/adapters/methods/lora.py +++ b/src/adapters/methods/lora.py @@ -18,7 +18,7 @@ from ..configuration import LoRAConfig, ModelAdaptersConfig from .adapter_layer_base import AdapterLayerBase, ComposableAdapterLayerBase from .utils import dequantize_bnb_weight - +from ..context import ForwardContext try: from bitsandbytes.nn import Int8Params, Linear4bit, Linear8bitLt, Params4bit @@ -45,7 +45,6 @@ def __init__( self.composition_mode = config.composition_mode self.attn_matrices = config.attn_matrices self.use_gating = config.use_gating - # Optional dropout if config.dropout > 0.0: self.lora_dropout = nn.Dropout(p=config.dropout) @@ -75,7 +74,7 @@ def __init__( if self.use_gating: self.gate = nn.Linear(lora_A_shape[-1], gating_heads) nn.init.normal_(self.gate.weight, std=0.02) - + @property def delta_w(self) -> torch.Tensor: return self.lora_B @ self.lora_A @@ -178,57 +177,56 @@ def forward(self, hidden_states: Optional[torch.Tensor], layer_input: torch.Tens class Vera(nn.Module): def __init__( self, + name, lora_A_shape, lora_B_shape, config: LoRAConfig, ): super().__init__() + self.name = name self.d = config.d self.b = config.b self.lora_A_shape = lora_A_shape self.lora_B_shape = lora_B_shape - self.d_shape = self.lora_A_shape[1] + self.d_shape = self.lora_A_shape[0] self.b_shape = self.lora_B_shape[0] - # initialize frozen, random tensors - self.lora_A = torch.tensor(torch.zeros(lora_A_shape)) - self.lora_B = torch.tensor(torch.zeros(lora_B_shape)) - # Actual trainable parameters self.vera_D = nn.Parameter(torch.diag(torch.ones(self.d_shape) * self.d)) self.vera_B = nn.Parameter(torch.diag(torch.ones(self.b_shape) * self.b)) - # For compatibility with LoRA, allow all init_weights types here. - # Usually should be "vera" or "lora". 
- if config.init_weights == "lora": - # initialize A the same way as the default for nn.Linear and B to zero - nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5)) - nn.init.zeros_(self.lora_B) - elif config.init_weights == "bert": - nn.init.normal_(self.lora_A, std=0.02) - nn.init.normal_(self.lora_B, std=0.02) - elif config.init_weights == "ia3": - nn.init.ones_(self.lora_A) - nn.init.ones_(self.lora_B) - elif config.init_weights == "vera": - nn.kaiming.uniform_(self.lora_A) - nn.kaiming.uniform_(self.lora_B) - else: - raise ValueError("Unknown init_weights type: {}".format(config.init_weights)) - @property def delta_w(self) -> torch.Tensor: - return self.vera_B @ self.lora_B @ self.vera_D @ self.lora_A + parameters = ForwardContext.get_context().shared_parameters[self.name] + lora_A = parameters["lora_A"] + lora_B = parameters["lora_B"] + return self.vera_B @ lora_B @ self.vera_D @ lora_A def forward(self, hidden_states: Optional[torch.Tensor], layer_input: torch.Tensor): + parameters = ForwardContext.get_context().shared_parameters[self.name] + lora_A = parameters["lora_A"] + lora_B = parameters["lora_B"] + if hidden_states is None: hidden_states = layer_input - hidden_states = self.vera_B @ self.lora_B @ self.vera_D @ self.lora_A + hidden_states = self.vera_B @ lora_B @ self.vera_D @ lora_A return hidden_states - +def init_shared_Vera_parameters(model_config, adapter_config, device): + hidden_size = model_config.hidden_size + r = adapter_config["r"] + parameters = nn.ParameterDict() + + # initialize frozen, random tensors A, B + parameters["lora_A"] = torch.zeros(r, hidden_size).to(device) + parameters["lora_B"] = torch.zeros(hidden_size, r).to(device) + + nn.init.kaiming_uniform_(parameters["lora_A"]) + nn.init.kaiming_uniform_(parameters["lora_B"]) + return parameters + class LoRALayer(AdapterLayerBase): adapter_modules_name = "loras" @@ -242,7 +240,7 @@ def __init__( self.loras = nn.ModuleDict(dict()) self.merged = False - + def get_n_heads(self, lora: Union[LoRA, IA3, LoRAConfig]): return 1 @@ -254,6 +252,7 @@ def _get_lora_shapes(self, config: LoRAConfig): def add_adapter(self, adapter_name: str, layer_idx: int) -> bool: self.layer_idx = layer_idx + lora_config = self.adapters_config.match( adapter_name, config_type=LoRAConfig, @@ -262,13 +261,14 @@ def add_adapter(self, adapter_name: str, layer_idx: int) -> bool: ) if lora_config is not None and self._check_lora_location(lora_config): if lora_config.composition_mode == "add": - lora_cls = LoRA + if lora_config.d and lora_config.b: + lora_cls = Vera + else: + lora_cls = LoRA elif lora_config.composition_mode == "scale": lora_cls = IA3 else: raise ValueError(f"Unknown composition_mode: {lora_config.composition_mode}") - # figure out good criteria to load vera - # lora = lora_cls( *self._get_lora_shapes(lora_config), lora_config, @@ -277,8 +277,6 @@ def add_adapter(self, adapter_name: str, layer_idx: int) -> bool: lora.train(self.training) lora = lora.to(self.weight.device) self.loras[adapter_name] = lora - return True - return False def average_adapter( diff --git a/src/adapters/model_mixin.py b/src/adapters/model_mixin.py index 659a6cfcff..d347161c3a 100644 --- a/src/adapters/model_mixin.py +++ b/src/adapters/model_mixin.py @@ -16,13 +16,13 @@ from transformers.utils import is_accelerate_available from .composition import AdapterCompositionBlock, Fuse, Stack, parse_composition -from .configuration import ADAPTER_CONFIG_MAP, AdapterConfig, AdapterFusionConfig, BnConfig +from .configuration import ADAPTER_CONFIG_MAP, 
AdapterConfig, AdapterFusionConfig, BnConfig, IA3Config from .context import AdapterSetup, ForwardContext from .hub_mixin import PushAdapterToHubMixin from .loading import AdapterFusionLoader, AdapterLoader, PredictionHeadLoader, WeightsLoader from .methods.adapter_layer_base import AdapterLayerBase from .methods.bottleneck import BottleneckLayer -from .methods.lora import LoRALayer +from .methods.lora import LoRALayer, init_shared_Vera_parameters from .methods.modeling import Adapter, GLOWCouplingBlock, NICECouplingBlock, init_shared_parameters from .methods.prefix_tuning import PrefixTuningLayer, PrefixTuningPool from .methods.prompt_tuning import PromptTuningLayer @@ -426,7 +426,7 @@ def init_adapters(self, model_config, adapters_config, add_prefix_tuning_pool=Tr This method initializes adapter modules and fusion modules from the model config. """ self.base_model.shared_parameters = nn.ModuleDict() - + # Initialize adapters config init_adapters_config(self, model_config, adapters_config) # Initialize adapters in all submodules @@ -610,13 +610,21 @@ def _add_adapter_weights(self, adapter_name: str): ) else: raise ValueError( - "The model has different hidden sizes {}. Sharing comapcter weights is only possible if" + "The model has different hidden sizes {}. Sharing compacter weights is only possible if" " the hidden_sizes match.".format(hidden_sizes) ) else: self.base_model.shared_parameters[adapter_name] = init_shared_parameters( adapter_config, self.config.hidden_size, self.device ) + + # Vera Initialization + if self.adapters_config.match(adapter_name, LoRAConfig): + adapter_config = self.adapters_config.match(adapter_name, LoRAConfig) + self.base_model.shared_parameters[adapter_name] = init_shared_Vera_parameters( + self.config, adapter_config, self.device + ) + # Prefix Tuning for module in self.modules(): if isinstance(module, PrefixTuningPool): From 18182afaafa23a5764b6e94780aeae59fda10324 Mon Sep 17 00:00:00 2001 From: julian fong Date: Thu, 12 Dec 2024 14:27:32 -0500 Subject: [PATCH 05/15] Updates --- src/adapters/methods/lora.py | 14 ++++++++------ src/adapters/model_mixin.py | 6 +++--- 2 files changed, 11 insertions(+), 9 deletions(-) diff --git a/src/adapters/methods/lora.py b/src/adapters/methods/lora.py index 7e1afb9f26..1d4329770a 100644 --- a/src/adapters/methods/lora.py +++ b/src/adapters/methods/lora.py @@ -74,7 +74,7 @@ def __init__( if self.use_gating: self.gate = nn.Linear(lora_A_shape[-1], gating_heads) nn.init.normal_(self.gate.weight, std=0.02) - + @property def delta_w(self) -> torch.Tensor: return self.lora_B @ self.lora_A @@ -207,18 +207,19 @@ def forward(self, hidden_states: Optional[torch.Tensor], layer_input: torch.Tens parameters = ForwardContext.get_context().shared_parameters[self.name] lora_A = parameters["lora_A"] lora_B = parameters["lora_B"] - + if hidden_states is None: hidden_states = layer_input hidden_states = self.vera_B @ lora_B @ self.vera_D @ lora_A return hidden_states + def init_shared_Vera_parameters(model_config, adapter_config, device): hidden_size = model_config.hidden_size r = adapter_config["r"] parameters = nn.ParameterDict() - + # initialize frozen, random tensors A, B parameters["lora_A"] = torch.zeros(r, hidden_size).to(device) parameters["lora_B"] = torch.zeros(hidden_size, r).to(device) @@ -226,7 +227,8 @@ def init_shared_Vera_parameters(model_config, adapter_config, device): nn.init.kaiming_uniform_(parameters["lora_A"]) nn.init.kaiming_uniform_(parameters["lora_B"]) return parameters - + + class LoRALayer(AdapterLayerBase): 
adapter_modules_name = "loras" @@ -240,7 +242,7 @@ def __init__( self.loras = nn.ModuleDict(dict()) self.merged = False - + def get_n_heads(self, lora: Union[LoRA, IA3, LoRAConfig]): return 1 @@ -252,7 +254,7 @@ def _get_lora_shapes(self, config: LoRAConfig): def add_adapter(self, adapter_name: str, layer_idx: int) -> bool: self.layer_idx = layer_idx - + lora_config = self.adapters_config.match( adapter_name, config_type=LoRAConfig, diff --git a/src/adapters/model_mixin.py b/src/adapters/model_mixin.py index d347161c3a..721011dbae 100644 --- a/src/adapters/model_mixin.py +++ b/src/adapters/model_mixin.py @@ -16,7 +16,7 @@ from transformers.utils import is_accelerate_available from .composition import AdapterCompositionBlock, Fuse, Stack, parse_composition -from .configuration import ADAPTER_CONFIG_MAP, AdapterConfig, AdapterFusionConfig, BnConfig, IA3Config +from .configuration import ADAPTER_CONFIG_MAP, AdapterConfig, AdapterFusionConfig, BnConfig from .context import AdapterSetup, ForwardContext from .hub_mixin import PushAdapterToHubMixin from .loading import AdapterFusionLoader, AdapterLoader, PredictionHeadLoader, WeightsLoader @@ -426,7 +426,7 @@ def init_adapters(self, model_config, adapters_config, add_prefix_tuning_pool=Tr This method initializes adapter modules and fusion modules from the model config. """ self.base_model.shared_parameters = nn.ModuleDict() - + # Initialize adapters config init_adapters_config(self, model_config, adapters_config) # Initialize adapters in all submodules @@ -617,7 +617,7 @@ def _add_adapter_weights(self, adapter_name: str): self.base_model.shared_parameters[adapter_name] = init_shared_parameters( adapter_config, self.config.hidden_size, self.device ) - + # Vera Initialization if self.adapters_config.match(adapter_name, LoRAConfig): adapter_config = self.adapters_config.match(adapter_name, LoRAConfig) From f28e50827d85a2919314990ec69e6489a53ae71a Mon Sep 17 00:00:00 2001 From: julian fong Date: Thu, 12 Dec 2024 15:20:01 -0500 Subject: [PATCH 06/15] Updates --- src/adapters/methods/lora.py | 3 ++- src/adapters/model_mixin.py | 6 +++--- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/src/adapters/methods/lora.py b/src/adapters/methods/lora.py index 1d4329770a..f65ef85a36 100644 --- a/src/adapters/methods/lora.py +++ b/src/adapters/methods/lora.py @@ -16,9 +16,10 @@ from ..composition import Average, BatchSplit, Parallel, Stack from ..configuration import LoRAConfig, ModelAdaptersConfig +from ..context import ForwardContext from .adapter_layer_base import AdapterLayerBase, ComposableAdapterLayerBase from .utils import dequantize_bnb_weight -from ..context import ForwardContext + try: from bitsandbytes.nn import Int8Params, Linear4bit, Linear8bitLt, Params4bit diff --git a/src/adapters/model_mixin.py b/src/adapters/model_mixin.py index 721011dbae..01388e0100 100644 --- a/src/adapters/model_mixin.py +++ b/src/adapters/model_mixin.py @@ -10,7 +10,7 @@ import torch from torch import nn -from adapters.configuration.adapter_config import ConfigUnion, LoRAConfig +from adapters.configuration.adapter_config import ConfigUnion, LoRAConfig, VeraConfig from transformers import GenerationConfig from transformers.modeling_outputs import ModelOutput from transformers.utils import is_accelerate_available @@ -619,8 +619,8 @@ def _add_adapter_weights(self, adapter_name: str): ) # Vera Initialization - if self.adapters_config.match(adapter_name, LoRAConfig): - adapter_config = self.adapters_config.match(adapter_name, LoRAConfig) + if 
self.adapters_config.match(adapter_name, VeraConfig): + adapter_config = self.adapters_config.match(adapter_name, VeraConfig) self.base_model.shared_parameters[adapter_name] = init_shared_Vera_parameters( self.config, adapter_config, self.device ) From f38b0e315a9d26796757b0cfd3e583c0b1ce9bb2 Mon Sep 17 00:00:00 2001 From: julian fong Date: Thu, 12 Dec 2024 15:55:25 -0500 Subject: [PATCH 07/15] removed typo --- src/adapters/methods/lora.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/adapters/methods/lora.py b/src/adapters/methods/lora.py index f65ef85a36..8b56d311c1 100644 --- a/src/adapters/methods/lora.py +++ b/src/adapters/methods/lora.py @@ -280,6 +280,8 @@ def add_adapter(self, adapter_name: str, layer_idx: int) -> bool: lora.train(self.training) lora = lora.to(self.weight.device) self.loras[adapter_name] = lora + return True + return False def average_adapter( From 385cd357070fe2481011fa1b67b74a315f5ee9fc Mon Sep 17 00:00:00 2001 From: julian fong Date: Thu, 12 Dec 2024 16:34:18 -0500 Subject: [PATCH 08/15] fix black --- src/adapters/methods/lora.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/adapters/methods/lora.py b/src/adapters/methods/lora.py index 8b56d311c1..207319a25d 100644 --- a/src/adapters/methods/lora.py +++ b/src/adapters/methods/lora.py @@ -281,7 +281,7 @@ def add_adapter(self, adapter_name: str, layer_idx: int) -> bool: lora = lora.to(self.weight.device) self.loras[adapter_name] = lora return True - + return False def average_adapter( From 46af3fd30879af722f4c79718e2a43afe7674bde Mon Sep 17 00:00:00 2001 From: julian fong Date: Fri, 13 Dec 2024 22:17:59 -0500 Subject: [PATCH 09/15] updates --- src/adapters/configuration/adapter_config.py | 6 +- src/adapters/methods/lora.py | 72 ++++++++++++++++++-- 2 files changed, 68 insertions(+), 10 deletions(-) diff --git a/src/adapters/configuration/adapter_config.py b/src/adapters/configuration/adapter_config.py index 3adbef2ddf..1ba13ab270 100644 --- a/src/adapters/configuration/adapter_config.py +++ b/src/adapters/configuration/adapter_config.py @@ -545,14 +545,14 @@ class VeraConfig(LoRAConfig): The `composition_mode` parameter should also be set to `add`. 
""" - selfattn_lora: bool = False + selfattn_lora: bool = True intermediate_lora: bool = False output_lora: bool = False r: int = 8 - init_weights: str = "vera" d: Union[bool, float] = 0.1 - b: Union[bool, float] = 0 + b: Union[bool, float] = 0.0 + init_weights: str = "vera" @dataclass(eq=False) diff --git a/src/adapters/methods/lora.py b/src/adapters/methods/lora.py index 207319a25d..f84110ba48 100644 --- a/src/adapters/methods/lora.py +++ b/src/adapters/methods/lora.py @@ -69,6 +69,9 @@ def __init__( elif config.init_weights == "ia3": nn.init.ones_(self.lora_A) nn.init.ones_(self.lora_B) + elif config.init_weights == "vera": + nn.init.kaiming_uniform_(self.lora_A) + nn.init.kaiming_uniform_(self.lora_B) else: raise ValueError("Unknown init_weights type: {}".format(config.init_weights)) @@ -91,6 +94,7 @@ def com_inv(self, weights: torch.Tensor, added: torch.Tensor) -> torch.Tensor: return weights - added * self.scaling def forward(self, hidden_states: Optional[torch.Tensor], layer_input: torch.Tensor): + print("triggered") if hidden_states is None: hidden_states = layer_input hidden_states = self.lora_dropout(hidden_states) @ torch.t(self.lora_A) @ torch.t(self.lora_B) @@ -178,15 +182,21 @@ def forward(self, hidden_states: Optional[torch.Tensor], layer_input: torch.Tens class Vera(nn.Module): def __init__( self, - name, lora_A_shape, lora_B_shape, config: LoRAConfig, + gating_heads: int = 1, ): super().__init__() - self.name = name self.d = config.d self.b = config.b + self.r = config.r + self.alpha = config.alpha + self.use_gating = config.use_gating + + # Optional dropout + if config.dropout > 0.0: + self.lora_dropout = nn.Dropout(p=config.dropout) self.lora_A_shape = lora_A_shape self.lora_B_shape = lora_B_shape @@ -196,6 +206,11 @@ def __init__( # Actual trainable parameters self.vera_D = nn.Parameter(torch.diag(torch.ones(self.d_shape) * self.d)) self.vera_B = nn.Parameter(torch.diag(torch.ones(self.b_shape) * self.b)) + self.scaling = self.alpha / self.r + + if self.use_gating: + self.gate = nn.Linear(lora_A_shape[-1], gating_heads) + nn.init.normal_(self.gate.weight, std=0.02) @property def delta_w(self) -> torch.Tensor: @@ -204,6 +219,16 @@ def delta_w(self) -> torch.Tensor: lora_B = parameters["lora_B"] return self.vera_B @ lora_B @ self.vera_D @ lora_A + def com(self, weights: torch.Tensor, added: torch.Tensor, scaling=None) -> torch.Tensor: + """Performs the composition operation between existing and injected weights.""" + if scaling is None: + scaling = self.scaling + return weights + added * scaling + + def com_inv(self, weights: torch.Tensor, added: torch.Tensor) -> torch.Tensor: + """Inverts the composition operation between existing and injected weights.""" + return weights - added * self.scaling + def forward(self, hidden_states: Optional[torch.Tensor], layer_input: torch.Tensor): parameters = ForwardContext.get_context().shared_parameters[self.name] lora_A = parameters["lora_A"] @@ -211,22 +236,51 @@ def forward(self, hidden_states: Optional[torch.Tensor], layer_input: torch.Tens if hidden_states is None: hidden_states = layer_input - hidden_states = self.vera_B @ lora_B @ self.vera_D @ lora_A - return hidden_states + if getattr(self, "lora_dropout"): + hidden_states = self.lora_dropout(hidden_states) + + hidden_states = hidden_states @ self.vera_B @ lora_B @ self.vera_D @ lora_A + + if self.use_gating: + gate = torch.sigmoid(self.gate(layer_input)) + gate = torch.mean(gate, dim=1).unsqueeze(-1) + hidden_states = hidden_states * gate + else: + gate = None + + return 
hidden_states, gate + + def set_vera_adapter_name(self, name): + self.name = name def init_shared_Vera_parameters(model_config, adapter_config, device): hidden_size = model_config.hidden_size r = adapter_config["r"] + parameters = nn.ParameterDict() # initialize frozen, random tensors A, B parameters["lora_A"] = torch.zeros(r, hidden_size).to(device) parameters["lora_B"] = torch.zeros(hidden_size, r).to(device) - nn.init.kaiming_uniform_(parameters["lora_A"]) - nn.init.kaiming_uniform_(parameters["lora_B"]) + if adapter_config["init_weights"] == "lora": + # initialize A the same way as the default for nn.Linear and B to zero + nn.init.kaiming_uniform_(parameters["lora_A"], a=math.sqrt(5)) + nn.init.zeros_(parameters["lora_B"]) + elif adapter_config["init_weights"] == "bert": + nn.init.normal_(parameters["lora_A"], std=0.02) + nn.init.normal_(parameters["lora_B"], std=0.02) + elif adapter_config["init_weights"] == "ia3": + nn.init.ones_(parameters["lora_A"]) + nn.init.ones_(parameters["lora_B"]) + elif adapter_config["init_weights"] == "vera": + nn.init.kaiming_uniform_(parameters["lora_A"]) + nn.init.kaiming_uniform_(parameters["lora_B"]) + else: + raise ValueError("Unknown init_wfffeights type: {}".format(adapter_config["init_weights"])) + return parameters @@ -264,7 +318,7 @@ def add_adapter(self, adapter_name: str, layer_idx: int) -> bool: ) if lora_config is not None and self._check_lora_location(lora_config): if lora_config.composition_mode == "add": - if lora_config.d and lora_config.b: + if isinstance(lora_config.d, float) or isinstance(lora_config.b, float): lora_cls = Vera else: lora_cls = LoRA @@ -277,6 +331,10 @@ def add_adapter(self, adapter_name: str, layer_idx: int) -> bool: lora_config, gating_heads=self.get_n_heads(lora_config), ) + # if we're using Vera, then set the adapter name into the Vera object + if lora_cls == Vera: + lora.set_vera_adapter_name(name=adapter_name) + lora.train(self.training) lora = lora.to(self.weight.device) self.loras[adapter_name] = lora From 9f3a20286cbd9e2a39fa9aed736a837e1ccda4d8 Mon Sep 17 00:00:00 2001 From: julian fong Date: Fri, 13 Dec 2024 22:54:14 -0500 Subject: [PATCH 10/15] fixed typo --- src/adapters/methods/lora.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/adapters/methods/lora.py b/src/adapters/methods/lora.py index f84110ba48..d417d5af28 100644 --- a/src/adapters/methods/lora.py +++ b/src/adapters/methods/lora.py @@ -279,7 +279,7 @@ def init_shared_Vera_parameters(model_config, adapter_config, device): nn.init.kaiming_uniform_(parameters["lora_A"]) nn.init.kaiming_uniform_(parameters["lora_B"]) else: - raise ValueError("Unknown init_wfffeights type: {}".format(adapter_config["init_weights"])) + raise ValueError("Unknown init_weights type: {}".format(adapter_config["init_weights"])) return parameters From 0c0f7e65e7ffd329d968e18dfe4a42d6a2ca4308 Mon Sep 17 00:00:00 2001 From: julian fong Date: Mon, 23 Dec 2024 11:47:48 -0500 Subject: [PATCH 11/15] updates --- src/adapters/methods/lora.py | 12 ++++++++---- src/adapters/model_mixin.py | 4 ++-- 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/src/adapters/methods/lora.py b/src/adapters/methods/lora.py index d417d5af28..ef8ec158d8 100644 --- a/src/adapters/methods/lora.py +++ b/src/adapters/methods/lora.py @@ -94,7 +94,6 @@ def com_inv(self, weights: torch.Tensor, added: torch.Tensor) -> torch.Tensor: return weights - added * self.scaling def forward(self, hidden_states: Optional[torch.Tensor], layer_input: torch.Tensor): - print("triggered") if 
hidden_states is None: hidden_states = layer_input hidden_states = self.lora_dropout(hidden_states) @ torch.t(self.lora_A) @ torch.t(self.lora_B) @@ -239,8 +238,13 @@ def forward(self, hidden_states: Optional[torch.Tensor], layer_input: torch.Tens if getattr(self, "lora_dropout"): hidden_states = self.lora_dropout(hidden_states) - - hidden_states = hidden_states @ self.vera_B @ lora_B @ self.vera_D @ lora_A + # print(self.vera_B.shape) + # print(lora_B.shape) + # print(self.vera_D.shape) + # print(lora_A.shape) + # print((self.vera_B @ lora_B @ self.vera_D @ lora_A).shape) + # print(hidden_states.shape) + hidden_states = hidden_states @ torch.t(self.vera_B @ lora_B @ self.vera_D @ lora_A ) if self.use_gating: gate = torch.sigmoid(self.gate(layer_input)) @@ -255,7 +259,7 @@ def set_vera_adapter_name(self, name): self.name = name -def init_shared_Vera_parameters(model_config, adapter_config, device): +def init_shared_vera_parameters(model_config, adapter_config, device): hidden_size = model_config.hidden_size r = adapter_config["r"] diff --git a/src/adapters/model_mixin.py b/src/adapters/model_mixin.py index 01388e0100..5e142ad4c5 100644 --- a/src/adapters/model_mixin.py +++ b/src/adapters/model_mixin.py @@ -22,7 +22,7 @@ from .loading import AdapterFusionLoader, AdapterLoader, PredictionHeadLoader, WeightsLoader from .methods.adapter_layer_base import AdapterLayerBase from .methods.bottleneck import BottleneckLayer -from .methods.lora import LoRALayer, init_shared_Vera_parameters +from .methods.lora import LoRALayer, init_shared_vera_parameters from .methods.modeling import Adapter, GLOWCouplingBlock, NICECouplingBlock, init_shared_parameters from .methods.prefix_tuning import PrefixTuningLayer, PrefixTuningPool from .methods.prompt_tuning import PromptTuningLayer @@ -621,7 +621,7 @@ def _add_adapter_weights(self, adapter_name: str): # Vera Initialization if self.adapters_config.match(adapter_name, VeraConfig): adapter_config = self.adapters_config.match(adapter_name, VeraConfig) - self.base_model.shared_parameters[adapter_name] = init_shared_Vera_parameters( + self.base_model.shared_parameters[adapter_name] = init_shared_vera_parameters( self.config, adapter_config, self.device ) From 1229fc5eb9bcaa92b3eb431b682003e3f949ed95 Mon Sep 17 00:00:00 2001 From: julian fong Date: Mon, 23 Dec 2024 23:45:33 -0500 Subject: [PATCH 12/15] added review updates --- src/adapters/configuration/adapter_config.py | 13 +++++--- src/adapters/methods/lora.py | 33 ++++++++++---------- 2 files changed, 25 insertions(+), 21 deletions(-) diff --git a/src/adapters/configuration/adapter_config.py b/src/adapters/configuration/adapter_config.py index b0f6e6b962..14580bdf9c 100644 --- a/src/adapters/configuration/adapter_config.py +++ b/src/adapters/configuration/adapter_config.py @@ -509,10 +509,11 @@ class LoRAConfig(AdapterConfig): composition_mode: str = "add" init_weights: str = "lora" use_gating: bool = False - d: Union[bool, float] = None - b: Union[bool, float] = None + vera_d: float = None + vera_b: float = None dtype: Optional[str] = None + @dataclass(eq=False) class IA3Config(LoRAConfig): """ @@ -542,7 +543,7 @@ class VeraConfig(LoRAConfig): Lora Config that applies vector-based random matrix adaptation. It adds trainable matrices 'd' and 'b' while keeping the original LoRA matrices frozen, random, and shared across layers. See more through their paper: - https://arxiv.org/pdf/2106.09685. Note that `r` will still be supplied + https://arxiv.org/pdf/2310.11454. 
Note that `r` will still be supplied since we are still initializing decomposition matrices A and B. The `composition_mode` parameter should also be set to `add`. """ @@ -552,9 +553,11 @@ class VeraConfig(LoRAConfig): output_lora: bool = False r: int = 8 - d: Union[bool, float] = 0.1 - b: Union[bool, float] = 0.0 + vera_d: float = 0.1 + vera_b: float = 0.0 init_weights: str = "vera" + composition_mode: str = "add" + dtype: Optional[str] = None @dataclass(eq=False) diff --git a/src/adapters/methods/lora.py b/src/adapters/methods/lora.py index 2135d33f32..65ff383dad 100644 --- a/src/adapters/methods/lora.py +++ b/src/adapters/methods/lora.py @@ -38,6 +38,7 @@ def __init__( lora_B_shape, config: LoRAConfig, gating_heads: int = 1, + name: str = None, ): super().__init__() assert config.composition_mode == "add", "LoRA module only supports composition_mode='add'." @@ -46,6 +47,7 @@ def __init__( self.composition_mode = config.composition_mode self.attn_matrices = config.attn_matrices self.use_gating = config.use_gating + self.name = name # Optional dropout if config.dropout > 0.0: self.lora_dropout = nn.Dropout(p=config.dropout) @@ -115,6 +117,7 @@ def __init__( lora_B_shape, config: LoRAConfig, gating_heads: int = 1, + name: str = None, ): super().__init__() assert config.composition_mode == "scale", "IA3 module only supports composition_mode='scale'." @@ -125,6 +128,7 @@ def __init__( self.composition_mode = config.composition_mode self.attn_matrices = config.attn_matrices self.use_gating = config.use_gating + self.name = name # Optional dropout if config.dropout > 0.0: raise ValueError("IA3 module does not support dropout.") @@ -186,13 +190,20 @@ def __init__( lora_B_shape, config: LoRAConfig, gating_heads: int = 1, + name: str = None, ): super().__init__() - self.d = config.d - self.b = config.b + self.d = config.vera_d + self.b = config.vera_b self.r = config.r self.alpha = config.alpha self.use_gating = config.use_gating + self.name = name + + # check to make sure that the `composition_mode` is set to `add` + self.composition_mode = config.composition_mode + if self.composition_mode != "add": + raise ValueError("Vera module only supports composition_mode='add'.") # Optional dropout if config.dropout > 0.0: @@ -239,13 +250,8 @@ def forward(self, hidden_states: Optional[torch.Tensor], layer_input: torch.Tens if getattr(self, "lora_dropout"): hidden_states = self.lora_dropout(hidden_states) - # print(self.vera_B.shape) - # print(lora_B.shape) - # print(self.vera_D.shape) - # print(lora_A.shape) - # print((self.vera_B @ lora_B @ self.vera_D @ lora_A).shape) - # print(hidden_states.shape) - hidden_states = hidden_states @ torch.t(self.vera_B @ lora_B @ self.vera_D @ lora_A ) + + hidden_states = hidden_states @ torch.t(self.vera_B @ lora_B @ self.vera_D @ lora_A) if self.use_gating: gate = torch.sigmoid(self.gate(layer_input)) @@ -256,9 +262,6 @@ def forward(self, hidden_states: Optional[torch.Tensor], layer_input: torch.Tens return hidden_states, gate - def set_vera_adapter_name(self, name): - self.name = name - def init_shared_vera_parameters(model_config, adapter_config, device): hidden_size = model_config.hidden_size @@ -323,7 +326,7 @@ def add_adapter(self, adapter_name: str, layer_idx: int) -> bool: ) if lora_config is not None and self._check_lora_location(lora_config): if lora_config.composition_mode == "add": - if isinstance(lora_config.d, float) or isinstance(lora_config.b, float): + if isinstance(lora_config.vera_d, float) or isinstance(lora_config.vera_b, float): lora_cls = Vera 
else: lora_cls = LoRA @@ -335,10 +338,8 @@ def add_adapter(self, adapter_name: str, layer_idx: int) -> bool: *self._get_lora_shapes(lora_config), lora_config, gating_heads=self.get_n_heads(lora_config), + name=adapter_name, ) - # if we're using Vera, then set the adapter name into the Vera object - if lora_cls == Vera: - lora.set_vera_adapter_name(name=adapter_name) lora.train(self.training) lora = lora.to(self.weight.device) From 20ddb5c76a972343e5be122e6db20bd167d4b063 Mon Sep 17 00:00:00 2001 From: julian fong Date: Mon, 23 Dec 2024 23:48:08 -0500 Subject: [PATCH 13/15] apply fix from #770 --- src/adapters/methods/lora.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/adapters/methods/lora.py b/src/adapters/methods/lora.py index 65ff383dad..fe67300583 100644 --- a/src/adapters/methods/lora.py +++ b/src/adapters/methods/lora.py @@ -259,6 +259,7 @@ def forward(self, hidden_states: Optional[torch.Tensor], layer_input: torch.Tens hidden_states = hidden_states * gate else: gate = None + hidden_states = hidden_states * self.scaling return hidden_states, gate From 25fe0a97920913be671aeadf063ebe5b3ef87699 Mon Sep 17 00:00:00 2001 From: julian fong Date: Tue, 24 Dec 2024 09:17:32 -0500 Subject: [PATCH 14/15] updated docstring --- src/adapters/configuration/adapter_config.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/adapters/configuration/adapter_config.py b/src/adapters/configuration/adapter_config.py index 14580bdf9c..0f315060ed 100644 --- a/src/adapters/configuration/adapter_config.py +++ b/src/adapters/configuration/adapter_config.py @@ -483,12 +483,12 @@ class LoRAConfig(AdapterConfig): Place a trainable gating module besides the added parameter module to control module activation. This is e.g. used for UniPELT. Defaults to False. Note that modules with use_gating=True cannot be merged using `merge_adapter()`. - d (:obj:`bool` or :obj:`float`, optional): + vera_d (:obj:`bool` or :obj:`float`, optional): The value of d used in the VeraConfig. Defaults to None. Places a trainable scaling parameter `d` before the decomposition matrix A to allow scaling of the internal weights. - b (:obj:`bool` or :obj:`float`, optional): + vera_b (:obj:`bool` or :obj:`float`, optional): The value of b used in the VeraConfig. Defaults to None. Places a trainable scaling parameter `b` before the decomposition matrix B to allow scaling of the internal weights. From 7f7983278659eb8b25344c3cf1b6ecd2b6211c04 Mon Sep 17 00:00:00 2001 From: julian fong Date: Tue, 24 Dec 2024 09:18:24 -0500 Subject: [PATCH 15/15] updated docstring --- src/adapters/configuration/adapter_config.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/adapters/configuration/adapter_config.py b/src/adapters/configuration/adapter_config.py index 0f315060ed..a6cc14917b 100644 --- a/src/adapters/configuration/adapter_config.py +++ b/src/adapters/configuration/adapter_config.py @@ -483,12 +483,12 @@ class LoRAConfig(AdapterConfig): Place a trainable gating module besides the added parameter module to control module activation. This is e.g. used for UniPELT. Defaults to False. Note that modules with use_gating=True cannot be merged using `merge_adapter()`. - vera_d (:obj:`bool` or :obj:`float`, optional): + vera_d (:obj:`float`, optional): The value of d used in the VeraConfig. Defaults to None. Places a trainable scaling parameter `d` before the decomposition matrix A to allow scaling of the internal weights. 
- vera_b (:obj:`bool` or :obj:`float`, optional): + vera_b (:obj:`float`, optional): The value of b used in the VeraConfig. Defaults to None. Places a trainable scaling parameter `b` before the decomposition matrix B to allow scaling of the internal weights.