adapter_layer_forward() -> bottleneck_layer_forward()
calpt committed Oct 10, 2023
1 parent 55fdc0c commit 5959efa
Showing 12 changed files with 24 additions and 24 deletions.
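For downstream code that subclasses the library's mixins directly, this is a one-for-one rename of the bottleneck adapter hook; the argument order (`hidden_states`, `residual_input`, `layer_norm`) is unchanged, as the hunks below show. A minimal migration sketch, assuming a hypothetical `CustomSelfOutput` module and an import path that may differ in your install:

```python
# Sketch only: `CustomSelfOutput` is hypothetical and the mixin import path is assumed.
# The call signature matches the hunks in this commit.
import torch
import torch.nn as nn

from adapters.models.bert.mixin_bert import BertSelfOutputAdaptersMixin  # assumed path


class CustomSelfOutput(nn.Module):
    def __init__(self, hidden_size: int = 768, dropout_prob: float = 0.1):
        super().__init__()
        self.dense = nn.Linear(hidden_size, hidden_size)
        self.LayerNorm = nn.LayerNorm(hidden_size)
        self.dropout = nn.Dropout(dropout_prob)


class CustomSelfOutputWithAdapters(BertSelfOutputAdaptersMixin, CustomSelfOutput):
    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        # Before this commit: self.adapter_layer_forward(hidden_states, input_tensor, self.LayerNorm)
        hidden_states = self.bottleneck_layer_forward(hidden_states, input_tensor, self.LayerNorm)
        return hidden_states
```

The sketch only illustrates the renamed call site; the adapter modules themselves are still attached by the library's regular initialization, which this commit does not touch.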
2 changes: 1 addition & 1 deletion docs/contributing/adding_adapters_to_a_model.md
@@ -28,7 +28,7 @@ Now that we have discussed the purpose of every file in `src/adapters/models/<mo
- You can use similar model implementations for guidance.
- Often, existing mixins of another class can be reused. E.g. `BertLayer`, `RobertaLayer`, `XLMRobertaLayer`, `DebertaLayer`, `DebertaV2Layer` and `BertGenerationLayer` (all models derived from BERT) use the `BertLayerAdaptersMixin`.
- To additionally support Prefix Tuning, it's necessary to apply the forward call to the `PrefixTuningLayer` module in the respective attention layer (see step 3 for how to modify the code of a Hugging Face class).
-  - Make sure the calls to `adapter_layer_forward()` are added in the right places.
+  - Make sure the calls to `bottleneck_layer_forward()` are added in the right places.
- The mixin for the whole base model class (e.g., `BertModel`) should derive from `ModelBaseAdaptersMixin` and (if possible) `EmbeddingAdaptersMixin` and/or `InvertibleAdaptersMixin`. This mixin should at least implement the `iter_layers()` method but might require additional modifications depending on the architecture.
- If the model is a combination of different models, such as the EncoderDecoderModel, use `ModelUsingSubmodelsAdaptersMixin` instead of `ModelBaseAdaptersMixin`.
3. **Copied functions:**
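The checklist above describes two pieces every new model integration needs: per-layer modules that route their residual connection through `bottleneck_layer_forward()` (shown in the model hunks below), and a base-model mixin that exposes the layers via `iter_layers()`. A hedged sketch of the latter for a BERT-like encoder — the class name `MyModelModelAdaptersMixin`, the `self.encoder.layer` attribute, and the import path are assumptions; the mixin base classes and the `iter_layers()` requirement come from the diffed docs above:

```python
# Sketch of a base-model adapters mixin for a hypothetical BERT-like model.
# Assumes the encoder stores its transformer blocks in `self.encoder.layer`
# and that the mixins are importable from `adapters.model_mixin` (assumed path).
from typing import Iterable, Tuple

import torch.nn as nn

from adapters.model_mixin import (
    EmbeddingAdaptersMixin,
    InvertibleAdaptersMixin,
    ModelBaseAdaptersMixin,
)


class MyModelModelAdaptersMixin(EmbeddingAdaptersMixin, InvertibleAdaptersMixin, ModelBaseAdaptersMixin):
    """Adds adapter support to the (hypothetical) MyModelModel base class."""

    def iter_layers(self) -> Iterable[Tuple[int, nn.Module]]:
        # Yield (index, layer) pairs so the library can register per-layer adapter components.
        for i, layer in enumerate(self.encoder.layer):
            yield i, layer
```

Models that wrap several sub-models (e.g., the `EncoderDecoderModel`) would derive from `ModelUsingSubmodelsAdaptersMixin` instead, as the last bullet notes.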
4 changes: 2 additions & 2 deletions src/adapters/models/beit/modeling_beit.py
@@ -102,7 +102,7 @@ def forward(
attention_output = self.lambda_1 * attention_output

# first residual connection
- hidden_states = self.attention_adapters.adapter_layer_forward(
+ hidden_states = self.attention_adapters.bottleneck_layer_forward(
self.drop_path(attention_output), hidden_states, None
)

@@ -116,7 +116,7 @@ def forward(
layer_output = self.lambda_2 * layer_output

# second residual connection
- layer_output = self.output_adapters.adapter_layer_forward(self.drop_path(layer_output), hidden_states, None)
+ layer_output = self.output_adapters.bottleneck_layer_forward(self.drop_path(layer_output), hidden_states, None)

outputs = (layer_output,) + outputs

4 changes: 2 additions & 2 deletions src/adapters/models/bert/modeling_bert.py
@@ -141,13 +141,13 @@ class BertSelfOutputWithAdapters(BertSelfOutputAdaptersMixin, BertSelfOutput):
def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states)
- hidden_states = self.adapter_layer_forward(hidden_states, input_tensor, self.LayerNorm)
+ hidden_states = self.bottleneck_layer_forward(hidden_states, input_tensor, self.LayerNorm)
return hidden_states


class BertOutputWithAdapters(BertOutputAdaptersMixin, BertOutput):
def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states)
- hidden_states = self.adapter_layer_forward(hidden_states, input_tensor, self.LayerNorm)
+ hidden_states = self.bottleneck_layer_forward(hidden_states, input_tensor, self.LayerNorm)
return hidden_states
4 changes: 2 additions & 2 deletions src/adapters/models/bert_generation/modeling_bert_generation.py
@@ -36,7 +36,7 @@ class BertGenerationSelfOutputWithAdapters(BertSelfOutputAdaptersMixin, BertGene
def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states)
- hidden_states = self.adapter_layer_forward(hidden_states, input_tensor, self.LayerNorm)
+ hidden_states = self.bottleneck_layer_forward(hidden_states, input_tensor, self.LayerNorm)
return hidden_states


@@ -153,5 +153,5 @@ class BertGenerationOutputWithAdapters(BertOutputAdaptersMixin, BertGenerationOu
def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states)
- hidden_states = self.adapter_layer_forward(hidden_states, input_tensor, self.LayerNorm)
+ hidden_states = self.bottleneck_layer_forward(hidden_states, input_tensor, self.LayerNorm)
return hidden_states
4 changes: 2 additions & 2 deletions src/adapters/models/deberta/modeling_deberta.py
@@ -33,15 +33,15 @@ class DebertaSelfOutputWithAdapters(BertSelfOutputAdaptersMixin, DebertaSelfOutp
def forward(self, hidden_states, input_tensor):
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states)
- hidden_states = self.adapter_layer_forward(hidden_states, input_tensor, self.LayerNorm)
+ hidden_states = self.bottleneck_layer_forward(hidden_states, input_tensor, self.LayerNorm)
return hidden_states


class DebertaOutputWithAdapters(BertOutputAdaptersMixin, DebertaOutput):
def forward(self, hidden_states, input_tensor):
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states)
- hidden_states = self.adapter_layer_forward(hidden_states, input_tensor, self.LayerNorm)
+ hidden_states = self.bottleneck_layer_forward(hidden_states, input_tensor, self.LayerNorm)
return hidden_states


4 changes: 2 additions & 2 deletions src/adapters/models/deberta_v2/modeling_deberta_v2.py
@@ -34,7 +34,7 @@ class DebertaV2SelfOutputWithAdapters(BertSelfOutputAdaptersMixin, DebertaV2Self
def forward(self, hidden_states, input_tensor):
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states)
- hidden_states = self.adapter_layer_forward(hidden_states, input_tensor, self.LayerNorm)
+ hidden_states = self.bottleneck_layer_forward(hidden_states, input_tensor, self.LayerNorm)
return hidden_states


@@ -43,7 +43,7 @@ class DebertaV2OutputWithAdapters(BertOutputAdaptersMixin, DebertaV2Output):
def forward(self, hidden_states, input_tensor):
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states)
- hidden_states = self.adapter_layer_forward(hidden_states, input_tensor, self.LayerNorm)
+ hidden_states = self.bottleneck_layer_forward(hidden_states, input_tensor, self.LayerNorm)
return hidden_states


4 changes: 2 additions & 2 deletions src/adapters/models/electra/modeling_electra.py
@@ -122,13 +122,13 @@ class ElectraSelfOutputWithAdapters(BertSelfOutputAdaptersMixin, ElectraSelfOutp
def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states)
- hidden_states = self.adapter_layer_forward(hidden_states, input_tensor, self.LayerNorm)
+ hidden_states = self.bottleneck_layer_forward(hidden_states, input_tensor, self.LayerNorm)
return hidden_states


class ElectraOutputWithAdapters(BertOutputAdaptersMixin, ElectraOutput):
def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states)
- hidden_states = self.adapter_layer_forward(hidden_states, input_tensor, self.LayerNorm)
+ hidden_states = self.bottleneck_layer_forward(hidden_states, input_tensor, self.LayerNorm)
return hidden_states
4 changes: 2 additions & 2 deletions src/adapters/models/roberta/modeling_roberta.py
@@ -142,7 +142,7 @@ class RobertaSelfOutputWithAdapters(BertSelfOutputAdaptersMixin, RobertaSelfOutp
def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states)
- hidden_states = self.adapter_layer_forward(hidden_states, input_tensor, self.LayerNorm)
+ hidden_states = self.bottleneck_layer_forward(hidden_states, input_tensor, self.LayerNorm)
return hidden_states


@@ -151,5 +151,5 @@ class RobertaOutputWithAdapters(BertOutputAdaptersMixin, RobertaOutput):
def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states)
- hidden_states = self.adapter_layer_forward(hidden_states, input_tensor, self.LayerNorm)
+ hidden_states = self.bottleneck_layer_forward(hidden_states, input_tensor, self.LayerNorm)
return hidden_states
6 changes: 3 additions & 3 deletions src/adapters/models/t5/modeling_t5.py
@@ -45,7 +45,7 @@ class T5LayerFFWithAdapters(T5FFLayerAdaptersMixin, T5LayerFF):
def forward(self, hidden_states):
forwarded_states = self.layer_norm(hidden_states)
forwarded_states = self.DenseReluDense(forwarded_states)
- hidden_states = self.adapter_layer_forward(
+ hidden_states = self.bottleneck_layer_forward(
hidden_states=self.dropout(forwarded_states), residual_input=hidden_states, layer_norm=None
)
return hidden_states
@@ -207,7 +207,7 @@ def forward(
use_cache=use_cache,
output_attentions=output_attentions,
)
- hidden_states = self.adapter_layer_forward(
+ hidden_states = self.bottleneck_layer_forward(
hidden_states=self.dropout(attention_output[0]), residual_input=hidden_states, layer_norm=None
)
outputs = (hidden_states,) + attention_output[1:] # add attentions if we output them
@@ -239,7 +239,7 @@ def forward(
query_length=query_length,
output_attentions=output_attentions,
)
- layer_output = self.adapter_layer_forward(
+ layer_output = self.bottleneck_layer_forward(
hidden_states=self.dropout(attention_output[0]), residual_input=hidden_states, layer_norm=None
)
outputs = (layer_output,) + attention_output[1:] # add attentions if we output them
4 changes: 2 additions & 2 deletions src/adapters/models/vit/modeling_vit.py
@@ -72,7 +72,7 @@ class ViTOutputWithAdapters(ViTOutputAdaptersMixin, ViTOutput):
def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states)
- hidden_states = self.output_adapters.adapter_layer_forward(hidden_states, input_tensor, None)
+ hidden_states = self.output_adapters.bottleneck_layer_forward(hidden_states, input_tensor, None)

return hidden_states

@@ -94,7 +94,7 @@ def forward(
attention_output = self_attention_outputs[0]
outputs = self_attention_outputs[1:] # add self attentions if we output attention weights

- hidden_states = self.attention_adapters.adapter_layer_forward(attention_output, hidden_states, None)
+ hidden_states = self.attention_adapters.bottleneck_layer_forward(attention_output, hidden_states, None)

# in ViT, layernorm is also applied after self-attention
layer_output = self.layernorm_after(hidden_states)
4 changes: 2 additions & 2 deletions src/adapters/models/xlm_roberta/modeling_xlm_roberta.py
@@ -146,7 +146,7 @@ class XLMRobertaSelfOutputWithAdapters(BertSelfOutputAdaptersMixin, XLMRobertaSe
def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states)
- hidden_states = self.adapter_layer_forward(hidden_states, input_tensor, self.LayerNorm)
+ hidden_states = self.bottleneck_layer_forward(hidden_states, input_tensor, self.LayerNorm)
return hidden_states


@@ -155,5 +155,5 @@ class XLMRobertaOutputWithAdapters(BertOutputAdaptersMixin, XLMRobertaOutput):
def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states)
- hidden_states = self.adapter_layer_forward(hidden_states, input_tensor, self.LayerNorm)
+ hidden_states = self.bottleneck_layer_forward(hidden_states, input_tensor, self.LayerNorm)
return hidden_states
4 changes: 2 additions & 2 deletions src/adapters/models/xmod/modeling_xmod.py
@@ -140,7 +140,7 @@ class XmodSelfOutputWithAdapters(BertSelfOutputAdaptersMixin, XmodSelfOutput):
def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states)
- hidden_states = self.adapter_layer_forward(hidden_states, input_tensor, None)
+ hidden_states = self.bottleneck_layer_forward(hidden_states, input_tensor, None)
return hidden_states


@@ -152,5 +152,5 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor, lang_
layer_norm = self.adapter_layer_norm
elif self.adapter_reuse_layer_norm:
layer_norm = self.LayerNorm
- hidden_states = self.adapter_layer_forward(hidden_states, input_tensor, layer_norm)
+ hidden_states = self.bottleneck_layer_forward(hidden_states, input_tensor, layer_norm)
return hidden_states
