Allow specifying adapter dtype in AdapterConfig (#767)
Aims to fix #766
Backwards compatible, since `dtype` defaults to `None` if not set in
`AdapterConfig`.
killershrimp authored Dec 23, 2024
1 parent e591965 commit 7550099
Showing 3 changed files with 17 additions and 5 deletions.
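For reference, a minimal usage sketch of the new option (assuming `LoRAConfig` is importable from the top-level `adapters` package, with all other fields left at their defaults):

```python
from adapters import LoRAConfig

# The new dtype field takes the *name* of a torch dtype as a string,
# e.g. "bfloat16" or "float16"; leaving it unset keeps the previous behavior.
config = LoRAConfig(dtype="bfloat16")
print(config.dtype)  # "bfloat16"
```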
8 changes: 8 additions & 0 deletions src/adapters/configuration/adapter_config.py
@@ -483,6 +483,7 @@ class LoRAConfig(AdapterConfig):
Place a trainable gating module besides the added parameter module to control module activation. This is
e.g. used for UniPELT. Defaults to False. Note that modules with use_gating=True cannot be merged using
`merge_adapter()`.
+ dtype (str, optional): torch dtype for reparametrization tensors. Defaults to None.
"""

architecture: Optional[str] = "lora"
@@ -499,6 +500,7 @@ class LoRAConfig(AdapterConfig):
composition_mode: str = "add"
init_weights: str = "lora"
use_gating: bool = False
+ dtype: Optional[str] = None


@dataclass(eq=False)
@@ -521,6 +523,7 @@ class IA3Config(LoRAConfig):
composition_mode: str = "scale"
init_weights: str = "ia3"
use_gating: bool = False
+ dtype: Optional[str] = None


@dataclass(eq=False)
@@ -540,6 +543,7 @@ class ReftConfig(AdapterConfig):
subtract_projection (bool): If True, subtract the projection of the input.
dropout (float): The dropout rate used in the intervention layer.
non_linearity (str): The activation function used in the intervention layer.
+ dtype (str, optional): torch dtype for intervention tensors. Defaults to None.
"""

layers: Union[Literal["all"], List[int]]
@@ -551,6 +555,7 @@ class ReftConfig(AdapterConfig):
subtract_projection = True
dropout: float = 0.05
non_linearity: Optional[str] = None
+ dtype: Optional[str] = None

architecture: str = "reft"

@@ -569,6 +574,7 @@ class LoReftConfig(ReftConfig):
r: int = 1
orthogonality: bool = True
tied_weights: bool = False
+ dtype: Optional[str] = None


@dataclass(eq=False)
@@ -583,6 +589,7 @@ class NoReftConfig(ReftConfig):
r: int = 1
orthogonality: bool = False
tied_weights: bool = False
+ dtype: Optional[str] = None


@dataclass(eq=False)
@@ -598,6 +605,7 @@ class DiReftConfig(ReftConfig):
orthogonality: bool = False
tied_weights: bool = False
subtract_projection = False
+ dtype: Optional[str] = None


class ConfigUnion(AdapterConfig):
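Since each config above only gains an optional field, existing configurations are unaffected; a quick check of the stated default (same import assumption as above):

```python
from adapters import IA3Config, LoRAConfig

# Backwards compatibility: dtype defaults to None when not specified.
print(LoRAConfig().dtype)  # None
print(IA3Config().dtype)   # None
```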
5 changes: 3 additions & 2 deletions src/adapters/methods/lora.py
@@ -51,9 +51,10 @@ def __init__(
else:
self.lora_dropout = lambda x: x

+ dtype = getattr(torch, config.dtype) if config.dtype else None
# Actual trainable parameters
- self.lora_A = nn.Parameter(torch.zeros(lora_A_shape))
- self.lora_B = nn.Parameter(torch.zeros(lora_B_shape))
+ self.lora_A = nn.Parameter(torch.zeros(lora_A_shape, dtype=dtype))
+ self.lora_B = nn.Parameter(torch.zeros(lora_B_shape, dtype=dtype))
self.scaling = self.lora_alpha / self.r

# For compatibility with (IA)^3, allow all init_weights types here.
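A short sketch of the resolution pattern added above: the configured string is looked up as an attribute of the `torch` module, and `None` falls through to torch's default dtype (the tensor shape below is illustrative, not taken from the diff):

```python
import torch
import torch.nn as nn

config_dtype = "bfloat16"  # as stored in LoRAConfig.dtype
dtype = getattr(torch, config_dtype) if config_dtype else None
lora_A = nn.Parameter(torch.zeros((8, 768), dtype=dtype))  # shape is illustrative
print(lora_A.dtype)  # torch.bfloat16
```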
9 changes: 6 additions & 3 deletions src/adapters/methods/reft.py
@@ -1,4 +1,4 @@
- from typing import List
+ from typing import List, Optional

import torch
import torch.nn as nn
@@ -18,12 +18,13 @@ def __init__(
subtract_projection: bool = True,
non_linearity: str = None,
dropout: float = 0.0,
+ dtype: Optional[torch.dtype] = None,
):
super().__init__()
self.orthogonal = orthogonal
- self.learned_source = nn.Linear(in_dim, r_dim, bias=True)
+ self.learned_source = nn.Linear(in_dim, r_dim, bias=True, dtype=dtype)

- projection = nn.Linear(in_dim, r_dim, bias=False)
+ projection = nn.Linear(in_dim, r_dim, bias=False, dtype=dtype)
if orthogonal:
self.projection = nn.utils.parametrizations.orthogonal(projection)
else:
@@ -50,6 +51,7 @@ def __init__(self, in_features: int, config: ReftConfig):
self.suffix_positions = config.suffix_positions
self.tied_weights = config.tied_weights
n_units = 1 if config.tied_weights else 2
+ dtype = getattr(torch, config.dtype) if config.dtype else None
self.units = nn.ModuleList(
[
ReftUnit(
@@ -59,6 +61,7 @@ def __init__(self, in_features: int, config: ReftConfig):
config.subtract_projection,
config.non_linearity,
config.dropout,
+ dtype,
)
for _ in range(n_units)
]
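The same lookup feeds the `nn.Linear` layers of each `ReftUnit`; a minimal sketch with illustrative dimensions (`in_dim=768` and `r_dim=1` are assumptions, not values from the diff):

```python
import torch
import torch.nn as nn

# A ReftConfig with dtype="float16" resolves to torch.float16 via getattr.
dtype = getattr(torch, "float16")
learned_source = nn.Linear(768, 1, bias=True, dtype=dtype)
projection = nn.Linear(768, 1, bias=False, dtype=dtype)
print(learned_source.weight.dtype, projection.weight.dtype)  # torch.float16 torch.float16
```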
