-
Notifications
You must be signed in to change notification settings - Fork 354
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Implements Vera
#763
base: main
Are you sure you want to change the base?
Implements Vera
#763
Changes from 10 commits
57c5131
259a268
b66571c
acee994
18182af
f28e508
f38b0e3
385cd35
46af3fd
9f3a202
12379e3
0c0f7e6
99cfb68
1229fc5
20ddb5c
25fe0a9
7f79832
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -478,11 +478,20 @@ class LoRAConfig(AdapterConfig): | |
(addition of decomposed matrix, as in LoRA) or "scale" (element-wise multiplication of vector, as in | ||
(IA)^3). "scale" can only be used together with r=1. Defaults to "add". | ||
init_weights (:obj:`str`, optional): Initialization method for the weights of the LoRA modules. | ||
Currently, this can be either "lora" (default) or "bert". | ||
Currently, this can be either "lora" (default) or "bert", or "vera". | ||
use_gating (:obj:`bool`, optional): | ||
Place a trainable gating module besides the added parameter module to control module activation. This is | ||
e.g. used for UniPELT. Defaults to False. Note that modules with use_gating=True cannot be merged using | ||
`merge_adapter()`. | ||
d (:obj:`bool` or :obj:`float`, optional): | ||
The value of d used in the VeraConfig. Defaults to None. Places a trainable | ||
scaling parameter `d` before the decomposition matrix A to allow scaling of the | ||
internal weights. | ||
|
||
b (:obj:`bool` or :obj:`float`, optional): | ||
The value of b used in the VeraConfig. Defaults to None. Places a trainable | ||
scaling parameter `b` before the decomposition matrix B to allow scaling of the | ||
internal weights. | ||
""" | ||
|
||
architecture: Optional[str] = "lora" | ||
|
@@ -499,6 +508,8 @@ class LoRAConfig(AdapterConfig): | |
composition_mode: str = "add" | ||
init_weights: str = "lora" | ||
use_gating: bool = False | ||
d: Union[bool, float] = None | ||
b: Union[bool, float] = None | ||
|
||
|
||
@dataclass(eq=False) | ||
|
@@ -523,6 +534,27 @@ class IA3Config(LoRAConfig): | |
use_gating: bool = False | ||
|
||
|
||
@dataclass(eq=False) | ||
class VeraConfig(LoRAConfig): | ||
""" | ||
Lora Config that applies vector-based random matrix adaptation. It adds | ||
trainable matrices 'd' and 'b' while keeping the original LoRA matrices | ||
frozen, random, and shared across layers. See more through their paper: | ||
https://arxiv.org/pdf/2106.09685. Note that `r` will still be supplied | ||
since we are still initializing decomposition matrices A and B. | ||
The `composition_mode` parameter should also be set to `add`. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. the paper link still needs updating :) |
||
""" | ||
|
||
selfattn_lora: bool = True | ||
intermediate_lora: bool = False | ||
output_lora: bool = False | ||
|
||
r: int = 8 | ||
d: Union[bool, float] = 0.1 | ||
b: Union[bool, float] = 0.0 | ||
init_weights: str = "vera" | ||
|
||
|
||
@dataclass(eq=False) | ||
class ReftConfig(AdapterConfig): | ||
""" | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -16,6 +16,7 @@ | |
|
||
from ..composition import Average, BatchSplit, Parallel, Stack | ||
from ..configuration import LoRAConfig, ModelAdaptersConfig | ||
from ..context import ForwardContext | ||
from .adapter_layer_base import AdapterLayerBase, ComposableAdapterLayerBase | ||
from .utils import dequantize_bnb_weight | ||
|
||
|
@@ -68,6 +69,9 @@ def __init__( | |
elif config.init_weights == "ia3": | ||
nn.init.ones_(self.lora_A) | ||
nn.init.ones_(self.lora_B) | ||
elif config.init_weights == "vera": | ||
nn.init.kaiming_uniform_(self.lora_A) | ||
nn.init.kaiming_uniform_(self.lora_B) | ||
else: | ||
raise ValueError("Unknown init_weights type: {}".format(config.init_weights)) | ||
|
||
|
@@ -90,6 +94,7 @@ def com_inv(self, weights: torch.Tensor, added: torch.Tensor) -> torch.Tensor: | |
return weights - added * self.scaling | ||
|
||
def forward(self, hidden_states: Optional[torch.Tensor], layer_input: torch.Tensor): | ||
print("triggered") | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. to remove? |
||
if hidden_states is None: | ||
hidden_states = layer_input | ||
hidden_states = self.lora_dropout(hidden_states) @ torch.t(self.lora_A) @ torch.t(self.lora_B) | ||
|
@@ -131,7 +136,7 @@ def __init__( | |
# For compatibility with LoRA, allow all init_weights types here. | ||
# Usually should be "ia3". | ||
if config.init_weights == "lora": | ||
logger.warning("(IA)^3 module initialized with LoRA zeo init. Ignore if this is intended.") | ||
logger.warning("(IA)^3 module initialized with LoRA zero init. Ignore if this is intended.") | ||
nn.init.zeros_(self.lora_B) | ||
elif config.init_weights == "bert": | ||
nn.init.normal_(self.lora_B, std=0.02) | ||
|
@@ -174,6 +179,111 @@ def forward(self, hidden_states: Optional[torch.Tensor], layer_input: torch.Tens | |
return hidden_states, gate | ||
|
||
|
||
class Vera(nn.Module): | ||
def __init__( | ||
self, | ||
lora_A_shape, | ||
lora_B_shape, | ||
config: LoRAConfig, | ||
gating_heads: int = 1, | ||
): | ||
super().__init__() | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. we should also add an assert for composition mode "add" here (same as in LoRA init), just to make sure |
||
self.d = config.d | ||
self.b = config.b | ||
self.r = config.r | ||
self.alpha = config.alpha | ||
self.use_gating = config.use_gating | ||
|
||
# Optional dropout | ||
if config.dropout > 0.0: | ||
self.lora_dropout = nn.Dropout(p=config.dropout) | ||
|
||
self.lora_A_shape = lora_A_shape | ||
self.lora_B_shape = lora_B_shape | ||
self.d_shape = self.lora_A_shape[0] | ||
self.b_shape = self.lora_B_shape[0] | ||
|
||
# Actual trainable parameters | ||
self.vera_D = nn.Parameter(torch.diag(torch.ones(self.d_shape) * self.d)) | ||
self.vera_B = nn.Parameter(torch.diag(torch.ones(self.b_shape) * self.b)) | ||
self.scaling = self.alpha / self.r | ||
|
||
if self.use_gating: | ||
self.gate = nn.Linear(lora_A_shape[-1], gating_heads) | ||
nn.init.normal_(self.gate.weight, std=0.02) | ||
|
||
@property | ||
def delta_w(self) -> torch.Tensor: | ||
parameters = ForwardContext.get_context().shared_parameters[self.name] | ||
lora_A = parameters["lora_A"] | ||
lora_B = parameters["lora_B"] | ||
return self.vera_B @ lora_B @ self.vera_D @ lora_A | ||
|
||
def com(self, weights: torch.Tensor, added: torch.Tensor, scaling=None) -> torch.Tensor: | ||
"""Performs the composition operation between existing and injected weights.""" | ||
if scaling is None: | ||
scaling = self.scaling | ||
return weights + added * scaling | ||
|
||
def com_inv(self, weights: torch.Tensor, added: torch.Tensor) -> torch.Tensor: | ||
"""Inverts the composition operation between existing and injected weights.""" | ||
return weights - added * self.scaling | ||
|
||
def forward(self, hidden_states: Optional[torch.Tensor], layer_input: torch.Tensor): | ||
parameters = ForwardContext.get_context().shared_parameters[self.name] | ||
lora_A = parameters["lora_A"] | ||
lora_B = parameters["lora_B"] | ||
|
||
if hidden_states is None: | ||
hidden_states = layer_input | ||
|
||
if getattr(self, "lora_dropout"): | ||
hidden_states = self.lora_dropout(hidden_states) | ||
|
||
hidden_states = hidden_states @ self.vera_B @ lora_B @ self.vera_D @ lora_A | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. shouldn't the order be reversed here? ie we matmul hidden states with lora_A -> vera_d -> lora_B -> vera_b, according to §3.1 (2) of the paper? |
||
|
||
if self.use_gating: | ||
gate = torch.sigmoid(self.gate(layer_input)) | ||
gate = torch.mean(gate, dim=1).unsqueeze(-1) | ||
hidden_states = hidden_states * gate | ||
else: | ||
gate = None | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. as this is likely merged after #770, the same fix from there should be applied here |
||
|
||
return hidden_states, gate | ||
|
||
def set_vera_adapter_name(self, name): | ||
self.name = name | ||
|
||
|
||
def init_shared_Vera_parameters(model_config, adapter_config, device): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: ideally lower-case "v" in the middle of method names |
||
hidden_size = model_config.hidden_size | ||
r = adapter_config["r"] | ||
|
||
parameters = nn.ParameterDict() | ||
|
||
# initialize frozen, random tensors A, B | ||
parameters["lora_A"] = torch.zeros(r, hidden_size).to(device) | ||
parameters["lora_B"] = torch.zeros(hidden_size, r).to(device) | ||
|
||
if adapter_config["init_weights"] == "lora": | ||
# initialize A the same way as the default for nn.Linear and B to zero | ||
nn.init.kaiming_uniform_(parameters["lora_A"], a=math.sqrt(5)) | ||
nn.init.zeros_(parameters["lora_B"]) | ||
elif adapter_config["init_weights"] == "bert": | ||
nn.init.normal_(parameters["lora_A"], std=0.02) | ||
nn.init.normal_(parameters["lora_B"], std=0.02) | ||
elif adapter_config["init_weights"] == "ia3": | ||
nn.init.ones_(parameters["lora_A"]) | ||
nn.init.ones_(parameters["lora_B"]) | ||
elif adapter_config["init_weights"] == "vera": | ||
nn.init.kaiming_uniform_(parameters["lora_A"]) | ||
nn.init.kaiming_uniform_(parameters["lora_B"]) | ||
else: | ||
raise ValueError("Unknown init_weights type: {}".format(adapter_config["init_weights"])) | ||
|
||
return parameters | ||
|
||
|
||
class LoRALayer(AdapterLayerBase): | ||
adapter_modules_name = "loras" | ||
|
||
|
@@ -199,6 +309,7 @@ def _get_lora_shapes(self, config: LoRAConfig): | |
|
||
def add_adapter(self, adapter_name: str, layer_idx: int) -> bool: | ||
self.layer_idx = layer_idx | ||
|
||
lora_config = self.adapters_config.match( | ||
adapter_name, | ||
config_type=LoRAConfig, | ||
|
@@ -207,7 +318,10 @@ def add_adapter(self, adapter_name: str, layer_idx: int) -> bool: | |
) | ||
if lora_config is not None and self._check_lora_location(lora_config): | ||
if lora_config.composition_mode == "add": | ||
lora_cls = LoRA | ||
if isinstance(lora_config.d, float) or isinstance(lora_config.b, float): | ||
lora_cls = Vera | ||
else: | ||
lora_cls = LoRA | ||
elif lora_config.composition_mode == "scale": | ||
lora_cls = IA3 | ||
else: | ||
|
@@ -217,6 +331,10 @@ def add_adapter(self, adapter_name: str, layer_idx: int) -> bool: | |
lora_config, | ||
gating_heads=self.get_n_heads(lora_config), | ||
) | ||
# if we're using Vera, then set the adapter name into the Vera object | ||
if lora_cls == Vera: | ||
lora.set_vera_adapter_name(name=adapter_name) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. feels a bit hacky to do this only for vera as the name is not specific to this type. what do you think of always passing the name directly to the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I've thought about that idea as well but opted for now to implement this idea first since right now |
||
|
||
lora.train(self.training) | ||
lora = lora.to(self.weight.device) | ||
self.loras[adapter_name] = lora | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
could we name these "vera_b" and "vera_d", to make more obvious what these are related to?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
also why can these be bools? ie what happens when I set
d=True, b=True
?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good catch, i think this is a typo based on a previous idea I had which i scraped later.. thanks