From 7550099e46f468da2bf8fc645e447f3b18a2ad3b Mon Sep 17 00:00:00 2001
From: Alex Yun <30671301+killershrimp@users.noreply.github.com>
Date: Mon, 23 Dec 2024 03:25:20 -0800
Subject: [PATCH] Allow specifying adapter dtype in AdapterConfig (#767)

Aims to fix https://github.com/adapter-hub/adapters/issues/766

Backwards compatible, since `dtype` defaults to `None` if not set in
`AdapterConfig`.
---
 src/adapters/configuration/adapter_config.py | 8 ++++++++
 src/adapters/methods/lora.py                 | 5 +++--
 src/adapters/methods/reft.py                 | 9 ++++++---
 3 files changed, 17 insertions(+), 5 deletions(-)

diff --git a/src/adapters/configuration/adapter_config.py b/src/adapters/configuration/adapter_config.py
index 0f2eec2162..b5249cb9f5 100644
--- a/src/adapters/configuration/adapter_config.py
+++ b/src/adapters/configuration/adapter_config.py
@@ -483,6 +483,7 @@ class LoRAConfig(AdapterConfig):
             Place a trainable gating module besides the added parameter module to control module activation. This is
             e.g. used for UniPELT. Defaults to False. Note that modules with use_gating=True cannot be merged using
             `merge_adapter()`.
+        dtype (str, optional): torch dtype for reparametrization tensors. Defaults to None.
     """

     architecture: Optional[str] = "lora"
@@ -499,6 +500,7 @@ class LoRAConfig(AdapterConfig):
     composition_mode: str = "add"
     init_weights: str = "lora"
     use_gating: bool = False
+    dtype: Optional[str] = None


 @dataclass(eq=False)
@@ -521,6 +523,7 @@ class IA3Config(LoRAConfig):
     composition_mode: str = "scale"
     init_weights: str = "ia3"
     use_gating: bool = False
+    dtype: Optional[str] = None


 @dataclass(eq=False)
@@ -540,6 +543,7 @@ class ReftConfig(AdapterConfig):
         subtract_projection (bool): If True, subtract the projection of the input.
         dropout (float): The dropout rate used in the intervention layer.
         non_linearity (str): The activation function used in the intervention layer.
+        dtype (str, optional): torch dtype for intervention tensors. Defaults to None.
     """

     layers: Union[Literal["all"], List[int]]
@@ -551,6 +555,7 @@ class ReftConfig(AdapterConfig):
     subtract_projection = True
     dropout: float = 0.05
     non_linearity: Optional[str] = None
+    dtype: Optional[str] = None

     architecture: str = "reft"

@@ -569,6 +574,7 @@ class LoReftConfig(ReftConfig):
     r: int = 1
     orthogonality: bool = True
     tied_weights: bool = False
+    dtype: Optional[str] = None


 @dataclass(eq=False)
@@ -583,6 +589,7 @@ class NoReftConfig(ReftConfig):
     r: int = 1
     orthogonality: bool = False
     tied_weights: bool = False
+    dtype: Optional[str] = None


 @dataclass(eq=False)
@@ -598,6 +605,7 @@ class DiReftConfig(ReftConfig):
     orthogonality: bool = False
     tied_weights: bool = False
     subtract_projection = False
+    dtype: Optional[str] = None


 class ConfigUnion(AdapterConfig):
diff --git a/src/adapters/methods/lora.py b/src/adapters/methods/lora.py
index c62a94f265..d56a11a91d 100644
--- a/src/adapters/methods/lora.py
+++ b/src/adapters/methods/lora.py
@@ -51,9 +51,10 @@ def __init__(
         else:
             self.lora_dropout = lambda x: x

+        dtype = getattr(torch, config.dtype) if config.dtype else None
         # Actual trainable parameters
-        self.lora_A = nn.Parameter(torch.zeros(lora_A_shape))
-        self.lora_B = nn.Parameter(torch.zeros(lora_B_shape))
+        self.lora_A = nn.Parameter(torch.zeros(lora_A_shape, dtype=dtype))
+        self.lora_B = nn.Parameter(torch.zeros(lora_B_shape, dtype=dtype))
         self.scaling = self.lora_alpha / self.r

         # For compatibility with (IA)^3, allow all init_weights types here.
diff --git a/src/adapters/methods/reft.py b/src/adapters/methods/reft.py
index 0914e8d3aa..9c6647e399 100644
--- a/src/adapters/methods/reft.py
+++ b/src/adapters/methods/reft.py
@@ -1,4 +1,4 @@
-from typing import List
+from typing import List, Optional

 import torch
 import torch.nn as nn
@@ -18,12 +18,13 @@ def __init__(
         subtract_projection: bool = True,
         non_linearity: str = None,
         dropout: float = 0.0,
+        dtype: Optional[torch.dtype] = None,
     ):
         super().__init__()
         self.orthogonal = orthogonal
-        self.learned_source = nn.Linear(in_dim, r_dim, bias=True)
+        self.learned_source = nn.Linear(in_dim, r_dim, bias=True, dtype=dtype)

-        projection = nn.Linear(in_dim, r_dim, bias=False)
+        projection = nn.Linear(in_dim, r_dim, bias=False, dtype=dtype)
         if orthogonal:
             self.projection = nn.utils.parametrizations.orthogonal(projection)
         else:
@@ -50,6 +51,7 @@ def __init__(self, in_features: int, config: ReftConfig):
         self.suffix_positions = config.suffix_positions
         self.tied_weights = config.tied_weights
         n_units = 1 if config.tied_weights else 2
+        dtype = getattr(torch, config.dtype) if config.dtype else None
         self.units = nn.ModuleList(
             [
                 ReftUnit(
@@ -59,6 +61,7 @@ def __init__(self, in_features: int, config: ReftConfig):
                     config.subtract_projection,
                     config.non_linearity,
                     config.dropout,
+                    dtype,
                 )
                 for _ in range(n_units)
             ]
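
For reviewers, a minimal usage sketch of the new option (not part of the patch): only the `dtype` config field comes from this change; the checkpoint name and the usual adapters workflow (`adapters.init()` plus `model.add_adapter()`) are assumptions for illustration. Because the implementation resolves the string with `getattr(torch, config.dtype)`, the value must name a torch dtype attribute such as `"bfloat16"` or `"float16"`.

```python
import torch
from transformers import AutoModel

import adapters
from adapters import LoRAConfig, LoReftConfig

# Load the backbone in bf16 so the adapter parameters can match it
# (checkpoint name is just an example).
model = AutoModel.from_pretrained("bert-base-uncased", torch_dtype=torch.bfloat16)
adapters.init(model)

# `dtype` is a plain string; the LoRA layer resolves it via
# getattr(torch, config.dtype) when creating lora_A / lora_B.
lora_config = LoRAConfig(r=8, alpha=16, dtype="bfloat16")
model.add_adapter("lora_bf16", config=lora_config)

# The ReFT configs gain the same field; leaving it as None keeps the
# previous behaviour (parameters created in torch's default dtype).
reft_config = LoReftConfig(dtype="bfloat16")
model.add_adapter("loreft_bf16", config=reft_config)
```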