
Move CA to albef dir (#437)
Summary:
Moving cross attention to the albef dir since it's only used by ALBEF

Pull Request resolved: #437

Test Plan: pytest tests/models/albef/test_albef.py

Reviewed By: ebsmothers

Differential Revision: D47655341

Pulled By: ankitade

fbshipit-source-id: 0779e66c5529eb10ad43c4219f43b766b95927cf
ankitade authored and facebook-github-bot committed Jul 22, 2023
1 parent 1940623 commit 82c1dc2
Showing 4 changed files with 229 additions and 221 deletions.
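For downstream code, the practical effect of this move is an import path change; the following is a minimal before/after sketch, based on the import edits in the diffs below:

# Before this commit, the layer lived in the shared transformer module:
# from torchmultimodal.modules.layers.transformer import TransformerCrossAttentionLayer

# After this commit, it is imported from the ALBEF package:
from torchmultimodal.models.albef.multimodal_encoder import TransformerCrossAttentionLayer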
60 changes: 59 additions & 1 deletion tests/models/albef/test_albef.py
@@ -16,7 +16,10 @@
ALBEFModelWithSimilarity,
ALBEFSimilarity,
)
from torchmultimodal.models.albef.multimodal_encoder import ALBEFMultimodalEncoder
from torchmultimodal.models.albef.multimodal_encoder import (
ALBEFMultimodalEncoder,
TransformerCrossAttentionLayer,
)
from torchmultimodal.modules.encoders.bert_text_encoder import bert_text_encoder
from torchmultimodal.utils.common import momentum_update, remove_grad

@@ -241,3 +244,58 @@ def test_neg_embeddings(albef_with_sim):
assert_expected(image_embeds_neg, expected_image_embeds_neg, rtol=0, atol=1e-4)
assert_expected(text_embeds_neg, expected_text_embeds_neg, rtol=0, atol=1e-4)
assert_expected(text_atts_neg, expected_text_atts_neg, rtol=0, atol=1e-4)


class TestTransformerCrossAttentionLayer:
@pytest.fixture(autouse=True)
def seed(self):
set_rng_seed(4)

@pytest.fixture
def get_encoder_layer(self):
def create_layer(norm_first):
model = TransformerCrossAttentionLayer(2, 1, 2, norm_first=norm_first)
model.eval()
return model

return create_layer

@pytest.fixture
def inputs(self):
return torch.randn(1, 2, 2, 2, 2)

@pytest.fixture
def cross_inputs(self):
return torch.randn(1, 2, 2, 2, 2)

def test_forward_prenorm(self, inputs, cross_inputs, get_encoder_layer):
model = get_encoder_layer(True)
actual = model(inputs, cross_inputs)
expected = torch.tensor(
[
[
[
[[-0.5925, 1.1257], [-0.5925, 1.1257]],
[[-0.5925, 1.1257], [-0.5925, 1.1257]],
],
[
[[-0.5925, 1.1257], [-0.5925, 1.1257]],
[[-0.5925, 1.1257], [-0.5925, 1.1257]],
],
]
]
)
assert_expected(actual, expected, rtol=0, atol=1e-4)

def test_forward_postnorm(self, inputs, cross_inputs, get_encoder_layer):
model = get_encoder_layer(False)
actual = model(inputs, cross_inputs)
expected = torch.tensor(
[
[
[[[-1.0, 1.0], [-1.0, 1.0]], [[-1.0, 1.0], [-1.0, 1.0]]],
[[[-1.0, 1.0], [-1.0, 1.0]], [[-1.0, 1.0], [-1.0, 1.0]]],
]
]
)
assert_expected(actual, expected, rtol=0, atol=1e-4)
52 changes: 0 additions & 52 deletions tests/modules/layers/test_transformer.py
@@ -10,7 +10,6 @@
from tests.test_utils import assert_expected, set_rng_seed
from torch import nn
from torchmultimodal.modules.layers.transformer import (
TransformerCrossAttentionLayer,
TransformerEncoder,
TransformerEncoderLayer,
)
@@ -74,57 +73,6 @@ def test_forward_postnorm(self, inputs, get_encoder_layer):
assert_expected(actual, expected, rtol=0, atol=1e-4)


class TestTransformerCrossAttentionLayer:
@pytest.fixture
def get_encoder_layer(self):
def create_layer(norm_first):
model = TransformerCrossAttentionLayer(2, 1, 2, norm_first=norm_first)
model.eval()
return model

return create_layer

@pytest.fixture
def inputs(self):
return torch.randn(1, 2, 2, 2, 2)

@pytest.fixture
def cross_inputs(self):
return torch.randn(1, 2, 2, 2, 2)

def test_forward_prenorm(self, inputs, cross_inputs, get_encoder_layer):
model = get_encoder_layer(True)
actual = model(inputs, cross_inputs)
expected = torch.tensor(
[
[
[
[[-0.5925, 1.1257], [-0.5925, 1.1257]],
[[-0.5925, 1.1257], [-0.5925, 1.1257]],
],
[
[[-0.5925, 1.1257], [-0.5925, 1.1257]],
[[-0.5925, 1.1257], [-0.5925, 1.1257]],
],
]
]
)
assert_expected(actual, expected, rtol=0, atol=1e-4)

def test_forward_postnorm(self, inputs, cross_inputs, get_encoder_layer):
model = get_encoder_layer(False)
actual = model(inputs, cross_inputs)
expected = torch.tensor(
[
[
[[[-1.0, 1.0], [-1.0, 1.0]], [[-1.0, 1.0], [-1.0, 1.0]]],
[[[-1.0, 1.0], [-1.0, 1.0]], [[-1.0, 1.0], [-1.0, 1.0]]],
]
]
)
assert_expected(actual, expected, rtol=0, atol=1e-4)


class TestTransformerEncoder:
@pytest.fixture
def encoder(self):
171 changes: 170 additions & 1 deletion torchmultimodal/models/albef/multimodal_encoder.py
@@ -8,10 +8,179 @@
from typing import Callable, Optional

from torch import nn, Tensor
from torchmultimodal.modules.layers.transformer import TransformerCrossAttentionLayer
from torchmultimodal.modules.layers.attention import MultiHeadAttention, SelfAttention
from torchmultimodal.modules.layers.mlp import MLP
from torchmultimodal.modules.layers.normalizations import Fp32LayerNorm
from torchmultimodal.utils.attention import get_extended_attention_mask


class TransformerCrossAttentionLayer(nn.Module):
"""Transformer layer with self-attention on inputs and cross-attention on an encoder's outputs.
Can be used in a transformer decoder or an encoder with cross-attention. Similar to
``nn.TransformerDecoderLayer``, but generalized for use in an encoder with cross-attention as well.
Uses a custom ``MultiHeadAttention`` that supports n-dimensional inputs including sequences,
images, video.
Attributes:
d_model (int): size of hidden dimension of input
n_head (int): number of attention heads
dim_feedforward (int): size of hidden dimension of feedforward network
dropout (float): dropout probability for all dropouts. Defaults to 0.
activation (Callable): activation function in feedforward network. Defaults to ``nn.ReLU``.
layer_norm_eps (float): the eps value in layer norms. Default is 1e-12.
norm_first (bool): if True, layer norm is done prior to each of self-attention, cross-attention,
and feedforward. Otherwise, layer norm is done after.
Args:
hidden_states (Tensor): input tensor of shape [b, d1, ..., dn, c] to calculate self-attention on.
encoder_hidden_states (Tensor): input tensor of shape [b, d1, ..., dn, c] to calculate
cross-attention on.
attention_mask (Tensor, optional): mask to be applied to self-attention inputs, ``hidden_states``.
See ``MultiHeadAttention`` for shape requirements.
cross_attention_mask (Tensor, optional): mask to be applied to cross-attention inputs,
``encoder_hidden_states``. See ``MultiHeadAttention`` for shape requirements.
"""

def __init__(
self,
d_model: int,
n_head: int,
dim_feedforward: int,
dropout: float = 0.0,
activation: Callable[..., nn.Module] = nn.ReLU,
layer_norm_eps: float = 1e-12,
norm_first: bool = False,
) -> None:
super().__init__()
# attention block
self.attention = MultiHeadAttention(
dim_q=d_model,
dim_kv=d_model,
n_head=n_head,
attn_module=SelfAttention(dropout),
)
self.attention_dropout = nn.Dropout(dropout)
# cross attention block
self.cross_attention = MultiHeadAttention(
dim_q=d_model,
dim_kv=d_model,
n_head=n_head,
attn_module=SelfAttention(dropout),
)
self.cross_attention_dropout = nn.Dropout(dropout)
# feedforward block
self.feedforward = MLP(
d_model, d_model, dim_feedforward, dropout=dropout, activation=activation
)
self.feedforward_dropout = nn.Dropout(dropout)
# layernorms
self.attention_layernorm = Fp32LayerNorm(d_model, eps=layer_norm_eps)
self.cross_attention_layernorm = Fp32LayerNorm(d_model, eps=layer_norm_eps)
self.feedforward_layernorm = Fp32LayerNorm(d_model, eps=layer_norm_eps)
self.norm_first = norm_first

def _self_attention_block(
self, hidden_states: Tensor, attention_mask: Optional[Tensor] = None
) -> Tensor:
output = self.attention(
hidden_states, attention_mask=attention_mask, return_attn_weights=False
)
output = self.attention_dropout(output)
return output

def _cross_attention_block(
self,
hidden_states: Tensor,
encoder_hidden_states: Tensor,
cross_attention_mask: Optional[Tensor] = None,
) -> Tensor:
output = self.cross_attention(
hidden_states,
encoder_hidden_states,
attention_mask=cross_attention_mask,
return_attn_weights=False,
)
output = self.cross_attention_dropout(output)
return output

def _feedforward_block(self, hidden_states: Tensor) -> Tensor:
h = self.feedforward(hidden_states)
h = self.feedforward_dropout(h)
return h

def _forward_prenorm(
self,
hidden_states: Tensor,
encoder_hidden_states: Tensor,
attention_mask: Optional[Tensor] = None,
cross_attention_mask: Optional[Tensor] = None,
) -> Tensor:
x = hidden_states
kv = encoder_hidden_states
inputs = self.attention_layernorm(x)
attn_output = self._self_attention_block(inputs, attention_mask=attention_mask)
attn_residual = attn_output + x
attn_norm_output = self.cross_attention_layernorm(attn_residual)
cross_attention_output = self._cross_attention_block(
attn_norm_output, kv, cross_attention_mask
)
cross_attention_residual = cross_attention_output + attn_norm_output
cross_attention_norm_output = self.feedforward_layernorm(
cross_attention_residual
)
ff_residual = cross_attention_norm_output + self._feedforward_block(
cross_attention_norm_output
)
return ff_residual

def _forward_postnorm(
self,
hidden_states: Tensor,
encoder_hidden_states: Tensor,
attention_mask: Optional[Tensor] = None,
cross_attention_mask: Optional[Tensor] = None,
) -> Tensor:
x = hidden_states
kv = encoder_hidden_states
attn_output = self._self_attention_block(x, attention_mask=attention_mask)
attn_residual = attn_output + x
attn_norm_output = self.attention_layernorm(attn_residual)
cross_attention_output = self._cross_attention_block(
attn_norm_output, kv, cross_attention_mask
)
cross_attention_residual = cross_attention_output + attn_norm_output
cross_attention_norm_output = self.cross_attention_layernorm(
cross_attention_residual
)
ff_residual = cross_attention_norm_output + self._feedforward_block(
cross_attention_norm_output
)
outputs = self.feedforward_layernorm(ff_residual)
return outputs

def forward(
self,
hidden_states: Tensor,
encoder_hidden_states: Tensor,
attention_mask: Optional[Tensor] = None,
cross_attention_mask: Optional[Tensor] = None,
) -> Tensor:
if self.norm_first:
return self._forward_prenorm(
hidden_states,
encoder_hidden_states,
attention_mask,
cross_attention_mask,
)
else:
return self._forward_postnorm(
hidden_states,
encoder_hidden_states,
attention_mask,
cross_attention_mask,
)


class ALBEFMultimodalEncoder(nn.Module):
"""
Construct multimodal embeddings from image embeddings, text embeddings, and text attention mask.
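As a quick sanity check of the relocated layer, here is a minimal usage sketch mirroring the configuration and input shapes exercised by TestTransformerCrossAttentionLayer above (hidden dim 2, one head, feedforward dim 2, 5-dimensional inputs); it assumes an install of torchmultimodal that includes this commit:

import torch

from torchmultimodal.models.albef.multimodal_encoder import TransformerCrossAttentionLayer

# Same tiny configuration as the tests: d_model=2, n_head=1,
# dim_feedforward=2, pre-norm variant.
layer = TransformerCrossAttentionLayer(2, 1, 2, norm_first=True)
layer.eval()

# The custom MultiHeadAttention accepts n-dimensional inputs; the tests use
# tensors of shape [batch, d1, d2, d3, channels] = (1, 2, 2, 2, 2).
hidden_states = torch.randn(1, 2, 2, 2, 2)
encoder_hidden_states = torch.randn(1, 2, 2, 2, 2)

with torch.no_grad():
    output = layer(hidden_states, encoder_hidden_states)

print(output.shape)  # torch.Size([1, 2, 2, 2, 2]) -- output keeps the input shape

Dropout, the self-attention mask, and the cross-attention mask are left at their defaults here, matching the tests.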
