Skip to content

Commit

Permalink
[BUG] Fixes Bottleneck Configs to work with ln_before = True and init_weights = "mam_adapter" (#761)
Browse files Browse the repository at this point in the history

Fixes #745 

When "mam_adapter" is specified, the code will now look for the
`nn.Linear` or `PHMLayer` inside the `self.adapter_down` layer sequence
and apply the initialization to the correct layer

edit: also removes an extra block of code in the `AdapterPlus` notebook
  • Loading branch information
julian-fong authored Dec 2, 2024
1 parent ec4a59e commit e591965
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 21 deletions.
19 changes: 0 additions & 19 deletions notebooks/ViT_AdapterPlus_FineTuning.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -302,25 +302,6 @@
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"trainer = AdapterTrainer(\n",
" model=model,\n",
" args=training_args,\n",
" data_collator=data_collator,\n",
" train_dataset=train_dataset,\n",
" eval_dataset=eval_dataset,\n",
" tokenizer=processor,\n",
" compute_metrics = compute_metrics\n",
")\n",
"\n",
"trainer.train()"
]
},
{
"cell_type": "code",
"execution_count": null,
Expand Down
6 changes: 4 additions & 2 deletions src/adapters/methods/modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,9 +123,11 @@ def __init__(
self.gate.apply(self.init_bert_weights)
elif config["init_weights"] == "mam_adapter":
with torch.no_grad():
nn.init.kaiming_uniform_(self.adapter_down[0].weight, a=math.sqrt(5))
for layer in self.adapter_down:
if isinstance(layer, nn.Linear) or isinstance(layer, PHMLayer):
nn.init.kaiming_uniform_(layer.weight, a=math.sqrt(5))
nn.init.zeros_(layer.bias)
nn.init.zeros_(self.adapter_up.weight)
nn.init.zeros_(self.adapter_down[0].bias)
nn.init.zeros_(self.adapter_up.bias)
if self.use_gating:
self.gate.apply(self.init_bert_weights)
Expand Down

0 comments on commit e591965

Please sign in to comment.