Skip to content

Commit

Permalink
FIX Init AdaLoRA to be identity transform (#1884)
Browse files Browse the repository at this point in the history
Resolves #1836

A previous PR accidentally changed the initialization of lora_E to a normal
distribution, when it should be initialized to zeros so that the AdaLoRA
adapter starts out as an identity transform.
  • Loading branch information
BenjaminBossan committed Jun 25, 2024
1 parent ef23712 commit c9b19bb
Show file tree
Hide file tree
Showing 5 changed files with 34 additions and 2 deletions.
2 changes: 1 addition & 1 deletion src/peft/tuners/adalora/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def update_layer(self, adapter_name, r, lora_alpha, lora_dropout, init_lora_weig

def reset_lora_parameters(self, adapter_name):
if adapter_name in self.lora_A.keys():
nn.init.normal_(self.lora_E[adapter_name], mean=0.0, std=0.02)
nn.init.zeros_(self.lora_E[adapter_name])
nn.init.normal_(self.lora_A[adapter_name], mean=0.0, std=0.02)
nn.init.normal_(self.lora_B[adapter_name], mean=0.0, std=0.02)

Expand Down
1 change: 1 addition & 0 deletions tests/test_decoder_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,7 @@ def test_from_pretrained_config_construction(self, test_name, model_id, config_c
{
"model_ids": PEFT_DECODER_MODELS_TO_TEST,
"lora_kwargs": {"init_lora_weights": [False]},
"adalora_kwargs": {"init_lora_weights": [False]},
"ia3_kwargs": {"init_ia3_weights": [False]},
"boft_kwargs": {"init_weights": [False]},
"vera_kwargs": {"init_weights": [False]},
Expand Down
1 change: 1 addition & 0 deletions tests/test_encoder_decoder_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ def test_from_pretrained_config_construction(self, test_name, model_id, config_c
{
"model_ids": PEFT_ENCODER_DECODER_MODELS_TO_TEST,
"lora_kwargs": {"init_lora_weights": [False]},
"adalora_kwargs": {"init_lora_weights": [False]},
"ia3_kwargs": {"init_ia3_weights": [False]},
"vera_kwargs": {"init_weights": [False]},
"task_type": "SEQ_2_SEQ_LM",
Expand Down
1 change: 1 addition & 0 deletions tests/test_feature_extraction_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ def test_from_pretrained_config_construction(self, test_name, model_id, config_c
{
"model_ids": PEFT_FEATURE_EXTRACTION_MODELS_TO_TEST,
"lora_kwargs": {"init_lora_weights": [False]},
"adalora_kwargs": {"init_lora_weights": [False]},
"ia3_kwargs": {"init_ia3_weights": [False]},
"boft_kwargs": {"init_weights": [False]},
"vera_kwargs": {"init_weights": [False]},
Expand Down
31 changes: 30 additions & 1 deletion tests/test_initialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@


class TestLoraInitialization:
"""Test class to check the initialization of adapters."""
"""Test class to check the initialization of LoRA adapters."""

torch_device = infer_device()

Expand Down Expand Up @@ -520,6 +520,8 @@ def test_lora_use_dora_with_megatron_core_raises(self):


class TestAdaLoraInitialization:
torch_device = infer_device()

def test_adalora_target_modules_set(self):
config = AdaLoraConfig(target_modules=["linear", "embed", "conv2d"])
assert config.target_modules == {"linear", "embed", "conv2d"}
Expand All @@ -532,6 +534,33 @@ def test_adalora_loftq_config_raises(self):
with pytest.raises(ValueError, match="ADALORA does not support LOFTQ"):
AdaLoraConfig(loftq_config={"loftq": "config"})

def get_model(self):
class MyModule(nn.Module):
def __init__(self):
super().__init__()
# choose a large weight so that averages are close to expected values
self.linear = nn.Linear(1000, 1000)

def forward(self, x):
return self.linear(x)

return MyModule().eval().to(self.torch_device)

@pytest.fixture
def data(self):
return torch.rand(10, 1000).to(self.torch_device)

def test_adalora_default_init_identity(self, data):
# default is True
torch.manual_seed(0)

model = self.get_model()
output_before = model(data)
config = AdaLoraConfig(target_modules=["linear"])
model = get_peft_model(model, config)
output_after = model(data)
assert torch.allclose(output_before, output_after)


class TestPromptTuningInitialization:
torch_device = infer_device()
Expand Down

0 comments on commit c9b19bb

Please sign in to comment.