ENH Support different layer shapes for VeRA (#1817)
dkopi committed Jun 10, 2024
1 parent a8286a7 commit 7b1c08d
Showing 6 changed files with 61 additions and 61 deletions.
3 changes: 2 additions & 1 deletion docs/source/package_reference/vera.md
Original file line number Diff line number Diff line change
@@ -20,9 +20,10 @@ rendered properly in your Markdown viewer.

When saving the adapter parameters, it's possible to eschew storing the low rank matrices by setting `save_projection=False` on the `VeraConfig`. In that case, these matrices will be restored based on the fixed random seed from the `projection_prng_key` argument. This cuts down on the size of the checkpoint, but we cannot guarantee reproducibility on all devices and for all future versions of PyTorch. If you want to ensure reproducibility, set `save_projection=True` (which is the default).

To handle different shapes of adapted layers, VeRA initializes shared A and B matrices with the largest required size for each dimension. During the forward pass, submatrices A and B for a given layer are sliced out from these shared matrices and used as described in the paper. For example, adapting two linear layers of shapes (100, 20) and (80, 50) will create A and B matrices of shapes (rank, 50) and (100, rank) respectively. Then, to adapt a layer of shape (100, 20), submatrices A and B of shapes (rank, 20) and (100, rank) will be extracted.
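
The shapes in this example can be checked with a short PyTorch sketch. It is illustrative only and not the PEFT implementation: `layer_shapes`, `shared_A`, `shared_B`, and `vera_update` are hypothetical names, shapes are assumed to be `(out_features, in_features)`, and the scaling vectors use placeholder values.

```python
import torch
import torch.nn.functional as F

rank = 4
# (out_features, in_features) of the two adapted layers from the example above
layer_shapes = [(100, 20), (80, 50)]

# Shared matrices are sized by the largest dimension in each direction.
max_out = max(out_f for out_f, _ in layer_shapes)
max_in = max(in_f for _, in_f in layer_shapes)
shared_A = torch.randn(rank, max_in)   # shape (rank, 50)
shared_B = torch.randn(max_out, rank)  # shape (100, rank)


def vera_update(x, out_features, in_features):
    """Compute the VeRA update for one layer from slices of the shared matrices."""
    sliced_A = shared_A[:, :in_features]   # (rank, in_features)
    sliced_B = shared_B[:out_features, :]  # (out_features, rank)
    lambda_d = torch.full((rank,), 0.1)    # placeholder scaling vector
    lambda_b = torch.ones(out_features)    # placeholder scaling vector
    return lambda_b * F.linear(lambda_d * F.linear(x, sliced_A), sliced_B)


for out_f, in_f in layer_shapes:
    x = torch.randn(3, in_f)
    print((out_f, in_f), tuple(vera_update(x, out_f, in_f).shape))
```

Both layers read from the same pair of buffers, so adding a layer with a different shape only grows the shared matrices when it exceeds the current maximum in some dimension.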

VeRA currently has the following constraints:

- All targeted parameters must have the same shape.
- Only `nn.Linear` layers are supported.
- Quantized layers are not supported.

48 changes: 18 additions & 30 deletions examples/sequence_classification/VeRA.ipynb
@@ -94,7 +94,7 @@
" task_type=\"SEQ_CLS\", \n",
" r=rank,\n",
" d_initial=0.1,\n",
" target_modules=[\"query\", \"value\"],\n",
" target_modules=[\"query\", \"value\", \"intermediate.dense\"],\n",
" save_projection=True,\n",
")\n",
"head_lr = 1e-2\n",
@@ -205,7 +205,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
"trainable params: 610,754 || all params: 125,257,924 || trainable%: 0.48759709605278145\n"
"trainable params: 647,714 || all params: 125,294,884 || trainable%: 0.5170\n"
]
}
],
@@ -255,76 +255,76 @@
"name": "stderr",
"output_type": "stream",
"text": [
" 0%| | 0/29 [00:00<?, ?it/s]You're using a RobertaTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.\n",
"100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 29/29 [00:23<00:00, 1.24it/s]\n",
"100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:01<00:00, 2.33it/s]\n"
" 0%| | 0/29 [00:00<?, ?it/s]You're using a RobertaTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.\n",
"100%|██████████| 29/29 [00:18<00:00, 1.58it/s]\n",
"100%|██████████| 4/4 [00:01<00:00, 3.52it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"epoch 0: {'accuracy': 0.7132352941176471, 'f1': 0.823529411764706}\n"
"epoch 0: {'accuracy': 0.7475490196078431, 'f1': 0.8367670364500792}\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 29/29 [00:23<00:00, 1.26it/s]\n",
"100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:01<00:00, 2.30it/s]\n"
"100%|██████████| 29/29 [00:17<00:00, 1.68it/s]\n",
"100%|██████████| 4/4 [00:01<00:00, 3.37it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"epoch 1: {'accuracy': 0.7671568627450981, 'f1': 0.8484848484848485}\n"
"epoch 1: {'accuracy': 0.7671568627450981, 'f1': 0.8536209553158706}\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 29/29 [00:23<00:00, 1.24it/s]\n",
"100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:01<00:00, 2.30it/s]\n"
"100%|██████████| 29/29 [00:17<00:00, 1.66it/s]\n",
"100%|██████████| 4/4 [00:01<00:00, 3.33it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"epoch 2: {'accuracy': 0.8259803921568627, 'f1': 0.8738898756660745}\n"
"epoch 2: {'accuracy': 0.8553921568627451, 'f1': 0.8959435626102292}\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 29/29 [00:23<00:00, 1.25it/s]\n",
"100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:01<00:00, 2.41it/s]\n"
"100%|██████████| 29/29 [00:17<00:00, 1.64it/s]\n",
"100%|██████████| 4/4 [00:01<00:00, 3.35it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"epoch 3: {'accuracy': 0.8431372549019608, 'f1': 0.891156462585034}\n"
"epoch 3: {'accuracy': 0.8823529411764706, 'f1': 0.9133574007220215}\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 29/29 [00:23<00:00, 1.25it/s]\n",
"100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:01<00:00, 2.35it/s]"
"100%|██████████| 29/29 [00:17<00:00, 1.63it/s]\n",
"100%|██████████| 4/4 [00:01<00:00, 3.17it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"epoch 4: {'accuracy': 0.8480392156862745, 'f1': 0.8938356164383561}\n"
"epoch 4: {'accuracy': 0.8897058823529411, 'f1': 0.9183303085299456}\n"
]
},
{
@@ -520,18 +520,6 @@
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.11"
},
"vscode": {
"interpreter": {
"hash": "aee8b7b246df8f9039afb4144a1f6fd8d2ca17a180786b69acc140d282b71a49"
12 changes: 10 additions & 2 deletions src/peft/tuners/vera/layer.py
@@ -217,9 +217,11 @@ def get_delta_weight(self, adapter) -> torch.Tensor:
lambda_d = lambda_d.float()
lambda_b = lambda_b.float()

sliced_A = vera_A[:, : self.in_features]
sliced_B = vera_B[: self.out_features, :]
lambda_b = lambda_b.unsqueeze(-1)
lambda_d = lambda_d.unsqueeze(-1)
output_tensor = transpose((lambda_b * vera_B) @ (lambda_d * vera_A), self.fan_in_fan_out)
output_tensor = transpose((lambda_b * sliced_B) @ (lambda_d * sliced_A), self.fan_in_fan_out)

if cast_to_fp32:
output_tensor = output_tensor.to(dtype=dtype)
@@ -252,9 +254,15 @@ def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
vera_A = self.vera_A[active_adapter]
vera_B = self.vera_B[active_adapter]

# As adapted layers may have different shapes and VeRA contains a single shared pair of A and B matrices,
# we initialize these matrices with the largest required size for each dimension.
# During the forward pass, required submatrices are sliced out from the shared vera_A and vera_B.
sliced_A = vera_A[:, : self.in_features]
sliced_B = vera_B[: self.out_features, :]

dropout = self.vera_dropout[active_adapter]
x = x.to(lambda_d.dtype)
result = result + lambda_b * F.linear(lambda_d * F.linear(dropout(x), vera_A), vera_B)
result = result + lambda_b * F.linear(lambda_d * F.linear(dropout(x), sliced_A), sliced_B)

result = result.to(previous_dtype)
return result
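
The two hunks above apply the same slicing in `get_delta_weight` and in `forward`. The sketch below is a standalone check rather than PEFT code, and it suggests that the merged and unmerged paths agree once both use the sliced matrices. The tensor sizes are placeholders, dropout is omitted, and the `fan_in_fan_out` transpose is assumed to be off.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
rank, out_features, in_features = 4, 80, 20

# Shared buffers that are larger than this layer needs (placeholder sizes).
vera_A = torch.randn(rank, 50, dtype=torch.float64)
vera_B = torch.randn(100, rank, dtype=torch.float64)
lambda_d = torch.randn(rank, dtype=torch.float64)
lambda_b = torch.randn(out_features, dtype=torch.float64)

sliced_A = vera_A[:, :in_features]
sliced_B = vera_B[:out_features, :]

# Merged form, mirroring get_delta_weight: one (out_features, in_features) matrix.
delta_w = (lambda_b.unsqueeze(-1) * sliced_B) @ (lambda_d.unsqueeze(-1) * sliced_A)

# Unmerged form, mirroring the update term added in forward.
x = torch.randn(3, in_features, dtype=torch.float64)
update = lambda_b * F.linear(lambda_d * F.linear(x, sliced_A), sliced_B)

# Applying the merged delta as a linear weight should give the same update.
matches = torch.allclose(update, F.linear(x, delta_w))
print(tuple(delta_w.shape), matches)
```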
29 changes: 12 additions & 17 deletions src/peft/tuners/vera/model.py
@@ -101,13 +101,11 @@ class VeraModel(BaseTuner):
def __init__(self, model, config, adapter_name) -> None:
super().__init__(model, config, adapter_name)

def _find_first_dim(self, config) -> tuple[int, int]:
def _find_dim(self, config) -> tuple[int, int]:
"""
Finds the first linear layer that has been wrapped with Vera, and extract the input and output dimension.
Finds the largest input and output dimensions across linear layers that have been wrapped with VeRA.
This will be used for determining the size of the shared vera_A and vera_B matrices.
This will throw an error if there are multiple layers of the same type with different shapes.
"""
model_config = getattr(self.model, "config", {"model_type": "custom"})
if hasattr(model_config, "to_dict"):
@@ -116,7 +114,7 @@ def _find_first_dim(self, config) -> tuple[int, int]:
peft_config = self._prepare_adapter_config(config, model_config)
peft_config = _maybe_include_all_linear_layers(peft_config, self.model)

first_shape = None
largest_shape = None
for key, module in self.model.named_modules():
if not self._check_target_module_exists(peft_config, key):
continue
@@ -128,33 +126,30 @@ def _find_first_dim(self, config) -> tuple[int, int]:
else:
continue

if first_shape is None:
first_shape = module_shape
if largest_shape is None:
largest_shape = module_shape
continue

if module_shape != first_shape:
raise ValueError(
"Multiple target layers with different dimensions were specified. VeRA only supports a "
f"single dimension size. Expected shape {first_shape}, got {module_shape}."
)
if module_shape != largest_shape:
largest_shape = tuple(max(a, b) for a, b in zip(largest_shape, module_shape))

if first_shape is None:
if largest_shape is None:
msg = "No layers types compatible with VeRA were found. Please check `peft_config.target_modules`."
raise ValueError(msg)

return first_shape
return largest_shape

def _init_vera_A_vera_B(self, config: VeraConfig, adapter_name: str) -> None:
first_linear_out_dim, first_linear_in_dim = self._find_first_dim(config)
linear_out_dim, linear_in_dim = self._find_dim(config)

# use of persistent to exclude vera_A and vera_B from the state dict if we choose not to save them.
self.vera_A = BufferDict({}, persistent=config.save_projection)
self.vera_B = BufferDict({}, persistent=config.save_projection)

# deterministic init of vera_A and vera_B if we know the key
generator = torch.Generator(device="cpu").manual_seed(config.projection_prng_key)
vera_A = _kaiming_init((config.r, first_linear_in_dim), generator=generator)
vera_B = _kaiming_init((first_linear_out_dim, config.r), generator=generator)
vera_A = _kaiming_init((config.r, linear_in_dim), generator=generator)
vera_B = _kaiming_init((linear_out_dim, config.r), generator=generator)
self.vera_A[adapter_name] = vera_A
self.vera_B[adapter_name] = vera_B

1 change: 1 addition & 0 deletions tests/test_custom_models.py
@@ -338,6 +338,7 @@
("Vanilla MLP 1 VeRA", "MLP", VeraConfig, {"target_modules": "lin0"}),
("Vanilla MLP 2 VeRA", "MLP", VeraConfig, {"target_modules": ["lin0"]}),
("Vanilla MLP 3 VeRA", "MLP", VeraConfig, {"target_modules": ["lin1"]}),
("Vanilla MLP 4 VeRA", "MLP", VeraConfig, {"target_modules": ["lin0", "lin1"]}),
(
"Vanilla MLP 5 VeRA",
"MLP",