Implements Vera #763

Open · wants to merge 17 commits into main
Conversation

@julian-fong (Contributor) commented Dec 1, 2024

This PR aims to implement Vera, which introduces trainable parameters d and b while keeping the LoRA matrices A and B frozen, random, and shared across layers.

I've opted to put the Vera implementation under the LoRA implementation (like IA3).

Paper: https://arxiv.org/pdf/2310.11454

This PR includes:

  • A new Vera implementation under the lora methods file
  • A new adapter config named VeraConfig
  • A shared parameter initialization function that takes in the init_weights argument set from the VeraConfig
  • A new criterion inside the add_adapter function of the LoRALayer to check whether we should use the Vera class. Currently, it checks whether the VeraConfig parameters d and b are of type float; it's the simplest criterion I could think of at the moment.
  • A new parameter added inside add_adapter of the LoRALayer as suggested. It will now pass the name of the adapter into the __init__ function.

Things to note:

In the original Vera paper, the decomposition matrices B and A are frozen and shared across layers. As suggested, I've opted to create these matrices similarly to the PHMLayer implementation, using a new method named init_shared_vera_parameters. This function takes in the init_weights argument set from the VeraConfig to set up the initialization of B and A, while the other parameters from the VeraConfig are used inside the Vera module.
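For reference, a rough sketch of the kind of shared initialization I have in mind (the shapes, the hidden_size lookup, and the init_weights handling are assumptions loosely mirroring the existing LoRA init, not the exact code in this PR):

```python
import math

import torch
import torch.nn as nn


def init_shared_Vera_parameters(model_config, adapter_config, device):
    # Sketch only: shapes and init options are assumptions based on the LoRA module.
    hidden_size = model_config.hidden_size
    r = adapter_config.r

    parameters = nn.ParameterDict()

    # B and A are random, frozen, and shared across all layers (requires_grad=False).
    # Since they stay frozen, both need a non-zero random init (unlike LoRA's zero-init B).
    parameters["lora_A"] = nn.Parameter(torch.zeros(r, hidden_size, device=device), requires_grad=False)
    parameters["lora_B"] = nn.Parameter(torch.zeros(hidden_size, r, device=device), requires_grad=False)

    if adapter_config.init_weights == "lora":
        nn.init.kaiming_uniform_(parameters["lora_A"], a=math.sqrt(5))
        nn.init.kaiming_uniform_(parameters["lora_B"], a=math.sqrt(5))
    elif adapter_config.init_weights == "bert":
        nn.init.normal_(parameters["lora_A"], std=0.02)
        nn.init.normal_(parameters["lora_B"], std=0.02)
    else:
        raise ValueError("Unknown init_weights type: {}".format(adapter_config.init_weights))

    return parameters
```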

As I'm not 100% sure what the composition modes do ('add', 'scale'), I've opted to use the Vera class only if composition_mode is set to add and either d or b is of type float (and not None).
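As a sketch, the selection inside LoRALayer.add_adapter would then look roughly like this (the add/scale branches mirror the existing LoRA/IA3 dispatch; attribute names may differ from the actual diff):

```python
# Sketch of the dispatch in LoRALayer.add_adapter; the d/b float check is the
# provisional criterion described above, the rest mirrors the existing LoRA/IA3 branches.
if lora_config.composition_mode == "add" and (
    isinstance(lora_config.d, float) or isinstance(lora_config.b, float)
):
    lora_cls = Vera
elif lora_config.composition_mode == "add":
    lora_cls = LoRA
elif lora_config.composition_mode == "scale":
    lora_cls = IA3
else:
    raise ValueError(f"Unsupported composition_mode: {lora_config.composition_mode}")
```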

reviews appreciated!

@julian-fong (Contributor, Author) commented:

@calpt Regarding the paper: from my first reading, it seems like it doesn't mention any scaling via a constant alpha/r or any gating. Should I still include these in the Vera implementation to make it consistent with the LoRA and IA3 modules? I'm also assuming com and com_inv would need to be re-included for Vera in order to allow for integration with the LoRALayer class.

@calpt (Member) commented Dec 2, 2024

> @calpt Regarding the paper: from my first reading, it seems like it doesn't mention any scaling via a constant alpha/r or any gating. Should I still include these in the Vera implementation to make it consistent with the LoRA and IA3 modules? I'm also assuming com and com_inv would need to be re-included for Vera in order to allow for integration with the LoRALayer class.

Yes, it doesn't hurt to include these options even if they're not mentioned in the paper, unless this makes the implementation significantly more challenging.
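For reference, re-using the LoRA-style composition hooks in Vera could look roughly like this (a sketch; com_inv matches the hunk quoted further down, and self.scaling = alpha / r being set in __init__ is an assumption mirroring LoRA):

```python
# Sketch: composition hooks for Vera, mirroring the existing LoRA module.
def com(self, weights: torch.Tensor, added: torch.Tensor, scaling=None) -> torch.Tensor:
    """Composes the injected weights with the existing weights."""
    if scaling is None:
        scaling = self.scaling  # assumed to be alpha / r, set in __init__
    return weights + added * scaling

def com_inv(self, weights: torch.Tensor, added: torch.Tensor) -> torch.Tensor:
    """Inverts the composition of injected and existing weights."""
    return weights - added * self.scaling
```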

@julian-fong (Contributor, Author) commented:

I'll also soon include a new test module for the Vera module, similar to the IA3 and LoRA test modules.

@calpt (Member) left a comment

did a first review pass, thanks for working on this!
will review again once we have tests & docs added.

for adding tests, you might sync with @TimoImhof, since we have a larger test folder refactoring coming up in #740, so it might make sense to directly base off that?

gate = torch.mean(gate, dim=1).unsqueeze(-1)
hidden_states = hidden_states * gate
else:
gate = None
@calpt (Member):

as this is likely merged after #770, the same fix from there should be applied here

config: LoRAConfig,
gating_heads: int = 1,
):
super().__init__()
@calpt (Member):

we should also add an assert for composition mode "add" here (same as in LoRA init), just to make sure
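e.g. something along these lines right after super().__init__() (a sketch; the exact message wording is just a suggestion):

```python
super().__init__()
# Guard against misconfiguration, same as the assertion in the LoRA __init__ (sketch).
assert config.composition_mode == "add", "Vera module only supports composition_mode='add'."
```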

Comment on lines 511 to 512
d: Union[bool, float] = None
b: Union[bool, float] = None
@calpt (Member):

could we name these "vera_b" and "vera_d", to make more obvious what these are related to?
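i.e. something like this (a sketch; using Optional[float] assumes the bool option isn't actually needed, see the next comment):

```python
vera_d: Optional[float] = None
vera_b: Optional[float] = None
```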

@calpt (Member):

also why can these be bools? ie what happens when I set d=True, b=True?

@julian-fong (Contributor, Author):

Good catch, I think this is a typo based on a previous idea I had which I scrapped later. Thanks!

@@ -90,6 +94,7 @@ def com_inv(self, weights: torch.Tensor, added: torch.Tensor) -> torch.Tensor:
return weights - added * self.scaling

def forward(self, hidden_states: Optional[torch.Tensor], layer_input: torch.Tensor):
print("triggered")
@calpt (Member):

to remove?

if getattr(self, "lora_dropout"):
hidden_states = self.lora_dropout(hidden_states)

hidden_states = hidden_states @ self.vera_B @ lora_B @ self.vera_D @ lora_A
@calpt (Member):

shouldn't the order be reversed here? ie we matmul hidden states with lora_A -> vera_d -> lora_B -> vera_b, according to §3.1 (2) of the paper?
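i.e. something along these lines (a sketch of the suggested ordering; it assumes vera_D / vera_B hold the diagonal scaling vectors Λ_d / Λ_b and that the shared lora_A / lora_B are stored so that right-multiplication works out, which may require transposes):

```python
# Apply A first, then Λ_d, then B, then Λ_b, following h = W0 x + Λ_b B Λ_d A x (Eq. 2):
hidden_states = hidden_states @ lora_A @ self.vera_D @ lora_B @ self.vera_B
```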

Comment on lines 334 to 336
# if we're using Vera, then set the adapter name into the Vera object
if lora_cls == Vera:
lora.set_vera_adapter_name(name=adapter_name)
@calpt (Member):

feels a bit hacky to do this only for vera as the name is not specific to this type. what do you think of always passing the name directly to the __init__ method of each module class (for all LoRA, Vera, IA3) and setting self.name directly there?
that might be cleaner long-term as we might want to use the name in LoRA as well in the future.

@julian-fong (Contributor, Author):

I thought about that idea as well but opted to implement it this way first, since right now LoRA and IA3 don't use self.name. I'll refactor it as you said. Thanks!

self.name = name


def init_shared_Vera_parameters(model_config, adapter_config, device):
@calpt (Member):

nit: ideally lower-case "v" in the middle of method names

Comment on lines 540 to 545
Lora Config that applies vector-based random matrix adaptation. It adds
trainable matrices 'd' and 'b' while keeping the original LoRA matrices
frozen, random, and shared across layers. See more through their paper:
https://arxiv.org/pdf/2106.09685. Note that `r` will still be supplied
since we are still initializing decomposition matrices A and B.
The `composition_mode` parameter should also be set to `add`.
@calpt (Member) commented Dec 23, 2024:

the paper link still needs updating :)
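For reference, the corrected span would read as follows (only the link changes, now pointing to the VeRA paper from the PR description):

```python
    Lora Config that applies vector-based random matrix adaptation. It adds
    trainable matrices 'd' and 'b' while keeping the original LoRA matrices
    frozen, random, and shared across layers. See more through their paper:
    https://arxiv.org/pdf/2310.11454. Note that `r` will still be supplied
    since we are still initializing decomposition matrices A and B.
    The `composition_mode` parameter should also be set to `add`.
```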
