diff --git a/examples/mugen/generation/text_video_gpt.py b/examples/mugen/generation/text_video_gpt.py index 05cb75fc..779120fc 100644 --- a/examples/mugen/generation/text_video_gpt.py +++ b/examples/mugen/generation/text_video_gpt.py @@ -12,7 +12,7 @@ from torch import nn, Tensor -from torchmultimodal.models.gpt import ( +from torchmultimodal.models.video_gpt.gpt import ( MultimodalGPT, MultimodalTransformerDecoder, RightShift, @@ -69,7 +69,7 @@ def text_video_gpt( new frame will be of shape ``(8, 8, 8)`` with each dim divided by the rate of downsample. Defaults to ``(4, 32, 32)``. d_model (int): Dimension of the underlying transformer decoder. - See :py:class:`torchmultimodal.models.gpt.TransformerDecoderLayer`. Defaults to ``768``. + See :py:class:`torchmultimodal.models.video_gpt.gpt.TransformerDecoderLayer`. Defaults to ``768``. n_head (int): Number of attention heads used by the transformer decoder. Defaults to ``8``. dropout (float): Dropout probability used by the projection layer of the transformer decoder. Defaults to ``0.2``. @@ -93,7 +93,7 @@ def text_video_gpt( Defaults to ``None``. Returns: - An instance of :py:class:`torchmultimodal.models.gpt.MultimodalGPT`. + An instance of :py:class:`torchmultimodal.models.video_gpt.gpt.MultimodalGPT`. """ # builds text tokenizer from pre-trained @@ -195,7 +195,7 @@ class TextTokenizer(nn.Module): """Converts between text and tokens / embedings Wrapper around the tokenizer to be consistent with the API required by - :py:class:`torchmultimodal.models.gpt.MultimodalGPT`. It also contains the + :py:class:`torchmultimodal.models.video_gpt.gpt.MultimodalGPT`. It also contains the embedding layer to enable lookup by token ids. """ diff --git a/examples/mugen/generation/video_vqvae.py b/examples/mugen/generation/video_vqvae.py index a3b4f772..49040d0b 100644 --- a/examples/mugen/generation/video_vqvae.py +++ b/examples/mugen/generation/video_vqvae.py @@ -6,7 +6,7 @@ from typing import Optional -from torchmultimodal.models.video_vqvae import ( +from torchmultimodal.models.video_gpt.video_vqvae import ( preprocess_int_conv_params, VideoDecoder, VideoEncoder, @@ -50,7 +50,7 @@ def video_vqvae_mugen( n_res_layers (int, optional): Number of ``AttentionResidualBlocks`` to include in encoder and decoder. Defaults to ``4``. attn_hidden_dim (int, optional): Size of hidden dim of - :class:`~torchmultimodal.models.video_vqvae.AttentionResidualBlocks`. Defaults to ``240``. + :class:`~torchmultimodal.models.video_gpt.video_vqvae.AttentionResidualBlocks`. Defaults to ``240``. num_embeddings (int, optional): Number of embedding vectors used in :class:`~torchmultimodal.modules.layers.codebook.Codebook`. Defaults to ``2048``. 
embedding_dim (int, optional): Dimensionality of embedding vectors in @@ -63,8 +63,8 @@ def video_vqvae_mugen( Returns: An instance of :class:`~torchmultimodal.models.vqvae.VQVAE` constructed with: - * :class:`~torchmultimodal.model.video_vqvae.VideoEncoder` - * :class:`~torchmultimodal.model.video_vqvae.VideoDecoder` + * :class:`~torchmultimodal.models.video_gpt.video_vqvae.VideoEncoder` + * :class:`~torchmultimodal.models.video_gpt.video_vqvae.VideoDecoder` """ encoder_strides = ((2, 2, 2), (2, 2, 2), (1, 2, 2), (1, 2, 2), (1, 2, 2), (1, 1, 1)) decoder_strides = ((2, 2, 2), (2, 2, 2), (1, 2, 2), (1, 2, 2), (1, 2, 2)) diff --git a/tests/models/test_gpt.py b/tests/models/test_gpt.py index 8b2158a1..71699ce8 100644 --- a/tests/models/test_gpt.py +++ b/tests/models/test_gpt.py @@ -11,7 +11,7 @@ from tests.test_utils import assert_expected, assert_expected_namedtuple, set_rng_seed from torch import nn from torch.nn import functional as F -from torchmultimodal.models.gpt import ( +from torchmultimodal.models.video_gpt.gpt import ( MultimodalGPT, MultimodalGPTOutput, MultimodalTransformerDecoder, @@ -257,7 +257,7 @@ def test_initialize_parameters(self, gpt, mocker): # Testing mean and std of the initialized weights data requires a large # amount samples to be statistically stable. Here we just test whether # the method in question has been called to avoid test flakiness. - mock_init = mocker.patch("torchmultimodal.models.gpt.Tensor.normal_") + mock_init = mocker.patch("torchmultimodal.models.video_gpt.gpt.Tensor.normal_") gpt = gpt(use_gpt_init=True) mock_init.assert_called() diff --git a/tests/models/test_video_gpt.py b/tests/models/test_video_gpt.py index dccbd0cc..091655a0 100644 --- a/tests/models/test_video_gpt.py +++ b/tests/models/test_video_gpt.py @@ -10,7 +10,7 @@ from tests.test_utils import assert_expected, set_rng_seed -from torchmultimodal.models.video_gpt import video_gpt, video_vqvae +from torchmultimodal.models.video_gpt.model import video_gpt, video_vqvae @pytest.fixture(autouse=True) diff --git a/tests/models/test_video_vqvae.py b/tests/models/test_video_vqvae.py index 5c40d698..2f7c3ece 100644 --- a/tests/models/test_video_vqvae.py +++ b/tests/models/test_video_vqvae.py @@ -4,12 +4,16 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree.
+from itertools import repeat + import pytest import torch from tests.test_utils import assert_expected, assert_expected_namedtuple, set_rng_seed -from torchmultimodal.models.video_vqvae import ( +from torchmultimodal.models.video_gpt.video_vqvae import ( AttentionResidualBlock, + AxialAttention, + AxialAttentionBlock, preprocess_int_conv_params, video_vqvae, VideoDecoder, @@ -36,6 +40,89 @@ def input_tensor(): return torch.ones(1, 2, 2, 2, 2) +class TestAxialBlock: + @pytest.fixture + def hidden_dim(self): + return 3 + + @pytest.fixture + def n_dim(self): + return 3 + + @pytest.fixture + def input_shape(self, n_dim): + return tuple(repeat(2, n_dim)) + + @pytest.fixture + def axial_block(self, input_shape, hidden_dim): + return AxialAttentionBlock(len(input_shape), hidden_dim, 1) + + @pytest.fixture + def q(self, input_shape, hidden_dim): + n_heads = 1 + return torch.randn(1, n_heads, *input_shape, hidden_dim // n_heads) + + @pytest.fixture + def kv(self, input_shape, hidden_dim): + n_heads = 1 + return torch.randn(1, n_heads, *input_shape, hidden_dim // n_heads) + + @pytest.fixture + def axial_attn(self): + return AxialAttention(1) # only on second axis of input + + def test_axial_attention(self, axial_attn, q, kv): + k = v = kv + actual, _ = axial_attn(q, k, v) + expected = torch.tensor( + [ + [ + [ + [ + [[-0.5869, 1.8958, 0.8688], [0.0299, 0.2098, 1.2741]], + [[-0.6662, 1.9747, 0.8980], [0.1002, 0.2094, 1.5472]], + ], + [ + [[0.5902, -0.3275, -0.8727], [-1.0557, 1.0791, 0.3916]], + [[0.6623, -0.3223, -0.8948], [-1.0755, 1.0763, 0.3708]], + ], + ] + ] + ] + ) + assert_expected(actual, expected, rtol=0, atol=1e-4) + + def test_axial_block_forward(self, axial_block, hidden_dim, input_shape): + """Test AxialAttentionBlock with sub-components""" + x = 2 * torch.ones(1, hidden_dim, *input_shape) + actual = axial_block(x) + expected = torch.tensor( + [ + [ + [ + [[0.822055, 0.822055], [0.822055, 0.822055]], + [[0.822055, 0.822055], [0.822055, 0.822055]], + ], + [ + [[-0.767143, -0.767143], [-0.767143, -0.767143]], + [[-0.767143, -0.767143], [-0.767143, -0.767143]], + ], + [ + [[-0.916860, -0.916860], [-0.916860, -0.916860]], + [[-0.916860, -0.916860], [-0.916860, -0.916860]], + ], + ] + ] + ) + assert_expected(actual, expected, rtol=0, atol=1e-4) + + def test_axial_block_channel_dim(self, axial_block, hidden_dim, input_shape): + """Test dim check in forward of AxialAttentionBlock""" + x = torch.zeros(1, hidden_dim + 1, *input_shape) + with pytest.raises(ValueError): + _ = axial_block(x) + + class TestAttentionResidualBlock: def test_hidden_dim_assertion(self): with pytest.raises(ValueError): diff --git a/tests/modules/layers/test_attention.py b/tests/modules/layers/test_attention.py index 12abbe15..d844cff9 100644 --- a/tests/modules/layers/test_attention.py +++ b/tests/modules/layers/test_attention.py @@ -11,8 +11,6 @@ import torch from tests.test_utils import assert_expected, set_rng_seed from torchmultimodal.modules.layers.attention import ( - AxialAttention, - AxialAttentionBlock, merge_multihead, MultiHeadAttention, scaled_dot_product_attention, @@ -58,11 +56,6 @@ def self_attn(): return SelfAttention(attn_dropout=0.0) -@pytest.fixture -def axial_attn(): - return AxialAttention(1) # only on second axis of input - - class TestMultiheadAttention: @pytest.fixture def multihead_attn(self, hidden_dim): @@ -395,28 +388,6 @@ def test_self_attention(self_attn, q, kv): assert_expected(actual, expected, rtol=0, atol=1e-4) -def test_axial_attention(axial_attn, q, kv): - k = v = kv - actual, _ = 
axial_attn(q, k, v) - expected = torch.tensor( - [ - [ - [ - [ - [[-0.5869, 1.8958, 0.8688], [0.0299, 0.2098, 1.2741]], - [[-0.6662, 1.9747, 0.8980], [0.1002, 0.2094, 1.5472]], - ], - [ - [[0.5902, -0.3275, -0.8727], [-1.0557, 1.0791, 0.3916]], - [[0.6623, -0.3223, -0.8948], [-1.0755, 1.0763, 0.3708]], - ], - ] - ] - ] - ) - assert_expected(actual, expected, rtol=0, atol=1e-4) - - def test_split_multihead(input_shape): x = torch.randn(1, *input_shape, 6) # (b, d1, ..., dn, c) out = split_multihead(x, 2) @@ -430,39 +401,3 @@ def test_merge_multihead(input_shape, hidden_dim, q): actual = torch.tensor(out.shape) expected = torch.tensor((1, *input_shape, hidden_dim)) assert_expected(actual, expected) - - -class TestAxialBlock: - @pytest.fixture - def axial_block(self, input_shape, hidden_dim): - return AxialAttentionBlock(len(input_shape), hidden_dim, 1) - - def test_axial_block_forward(self, axial_block, hidden_dim, input_shape): - """Test AxialAttentionBlock with sub-components""" - x = 2 * torch.ones(1, hidden_dim, *input_shape) - actual = axial_block(x) - expected = torch.tensor( - [ - [ - [ - [[0.822055, 0.822055], [0.822055, 0.822055]], - [[0.822055, 0.822055], [0.822055, 0.822055]], - ], - [ - [[-0.767143, -0.767143], [-0.767143, -0.767143]], - [[-0.767143, -0.767143], [-0.767143, -0.767143]], - ], - [ - [[-0.916860, -0.916860], [-0.916860, -0.916860]], - [[-0.916860, -0.916860], [-0.916860, -0.916860]], - ], - ] - ] - ) - assert_expected(actual, expected, rtol=0, atol=1e-4) - - def test_axial_block_channel_dim(self, axial_block, hidden_dim, input_shape): - """Test dim check in forward of AxialAttentionBlock""" - x = torch.zeros(1, hidden_dim + 1, *input_shape) - with pytest.raises(ValueError): - _ = axial_block(x) diff --git a/tests/utils/test_generate.py b/tests/utils/test_generate.py index 4a24889a..458e3d4a 100644 --- a/tests/utils/test_generate.py +++ b/tests/utils/test_generate.py @@ -10,7 +10,7 @@ from tests.test_utils import assert_expected, set_rng_seed -from torchmultimodal.models.video_gpt import video_gpt +from torchmultimodal.models.video_gpt.model import video_gpt from torchmultimodal.utils.generate import ( GenerationUtil, get_logits_mask, diff --git a/torchmultimodal/models/gpt.py b/torchmultimodal/models/video_gpt/gpt.py similarity index 98% rename from torchmultimodal/models/gpt.py rename to torchmultimodal/models/video_gpt/gpt.py index d083f9ff..ebee7188 100644 --- a/torchmultimodal/models/gpt.py +++ b/torchmultimodal/models/video_gpt/gpt.py @@ -17,7 +17,7 @@ class TransformerDecoderOutput(NamedTuple): - """Outputs from :class:`~torchmultimodal.models.gpt.TransformerDecoder`. + """Outputs from :class:`~torchmultimodal.models.video_gpt.gpt.TransformerDecoder`. Attributes: last_hidden_states (Tensor): Output from the last layer of the transformer. @@ -36,7 +36,7 @@ class TransformerDecoderOutput(NamedTuple): class TransformerLayerOutput(NamedTuple): - """Outputs from :class:`~torchmultimodal.models.gpt.TransformerDecoderLayer`. + """Outputs from :class:`~torchmultimodal.models.video_gpt.gpt.TransformerDecoderLayer`. Attributes: hidden_states (Tensor): Output from the current layer. @@ -52,7 +52,7 @@ class TransformerLayerOutput(NamedTuple): class MultimodalGPTOutput(NamedTuple): - """Outputs from :meth:`~torchmultimodal.models.gpt.MultimodalGPT.forward`. + """Outputs from :meth:`~torchmultimodal.models.video_gpt.gpt.MultimodalGPT.forward`. Attributes: decoder_output (TransformerDeocoderOutput): Contains output from the multimodal transformer decoder. 
@@ -200,7 +200,7 @@ def forward( Defaults to ``False``. Returns: - An instance of :class:`~torchmultimodal.models.gpt.MultimodalGPTOutput`. + An instance of :class:`~torchmultimodal.models.video_gpt.gpt.MultimodalGPTOutput`. """ decoder_output = self.fwd( in_tokens=in_tokens, @@ -462,7 +462,7 @@ def forward( Defaults to ``False``. Returns: - An instace of :class:`~torchmultimodal.models.gpt.TransformerDecoderOutput`. + An instance of :class:`~torchmultimodal.models.video_gpt.gpt.TransformerDecoderOutput`. """ if (in_modality is None) and (out_modality is None): raise ValueError( @@ -562,7 +562,7 @@ def forward( Defaults to ``False``. Returns: - An instance of :class:`~torchmultimodal.models.gpt.TransformerDecoderOutput`. + An instance of :class:`~torchmultimodal.models.video_gpt.gpt.TransformerDecoderOutput`. """ if attn_mask is not None and attn_mask.dim() == 2: attn_mask = attn_mask[ @@ -680,7 +680,7 @@ def forward( Defaults to ``False``. Returns: - An instance of :class:`~torchmultimodal.models.gpt.TransformerLayerOutput`. + An instance of :class:`~torchmultimodal.models.video_gpt.gpt.TransformerLayerOutput`. """ attn_probs = None past_key_values = None diff --git a/torchmultimodal/models/video_gpt.py b/torchmultimodal/models/video_gpt/model.py similarity index 91% rename from torchmultimodal/models/video_gpt.py rename to torchmultimodal/models/video_gpt/model.py index 934be0ff..29a423d3 100644 --- a/torchmultimodal/models/video_gpt.py +++ b/torchmultimodal/models/video_gpt/model.py @@ -7,14 +7,14 @@ from typing import Tuple from torch import nn -from torchmultimodal.models.gpt import ( +from torchmultimodal.models.video_gpt.gpt import ( MultimodalGPT, MultimodalTransformerDecoder, RightShift, TransformerDecoder, TransformerDecoderLayer, ) -from torchmultimodal.models.video_vqvae import VideoDecoder, VideoEncoder +from torchmultimodal.models.video_gpt.video_vqvae import VideoDecoder, VideoEncoder from torchmultimodal.models.vqvae import VQVAE from torchmultimodal.modules.layers.attention import SelfAttention @@ -46,14 +46,14 @@ def video_gpt( Defaults to ``(16, 64, 64)``. latent_shape (Tuple[int, int, int]): Shape of the encoded video data. This should be consistent with the actual latent shape inferred by the video encoder. - See :class:`~torchmultimodal.models.video_vqvae.VideoEncoder`. + See :class:`~torchmultimodal.models.video_gpt.video_vqvae.VideoEncoder`. Defaults to ``(8, 32, 32)``. d_model (int): Dimension of the underlying transformer decoder. Value taken from: https://github.com/wilson1yan/VideoGPT/blob/master/videogpt/gpt.py#L177 Note that this is different from the paper due to :class:`~torchmultimodal.modules.layers.position_embedding.BroadcastedPositionEmbedding` requires that ``d_model`` is a multiple of ``len(latent_shape)``. - See :py:class:`torchmultimodal.models.gpt.TransformerDecoderLayer`. Defaults to ``576``. + See :py:class:`torchmultimodal.models.video_gpt.gpt.TransformerDecoderLayer`. Defaults to ``576``. n_head (int): Number of attention heads used by the transformer decoder. Defaults to ``4``. dropout (float): Dropout probability used by the projection layer of the transformer decoder. Defaults to ``0.2``. @@ -61,10 +61,10 @@ def video_gpt( Defaults to ``0.3``. num_decoder_layers (int): Number of transformer decoder layers. Defaults to ``16``. use_gpt_init (bool): Whether to use weight initialization of GPT model. - See :class:`~torchmultimodal.models.gpt.MultimodalGPT`. Defaults to ``True``. + See :class:`~torchmultimodal.models.video_gpt.gpt.MultimodalGPT`.
Defaults to ``True``. Returns: - An instance of :class:`~torchmultimodal.models.gpt.MultimodalGPT`. + An instance of :class:`~torchmultimodal.models.video_gpt.gpt.MultimodalGPT`. """ # constructs in and out tokenizers in_tokenizer = video_vqvae() @@ -138,7 +138,7 @@ def video_vqvae( Dimension-wise strides of the last conv layer of the encoder. Defaults to ``(1, 1, 1)``. in_channel_dim (int, optional): Size of channel dim in input. Defaults to ``3``. encoder_hidden_dim (int, optional): Size of channel dims in encoder conv layers. Defaults to ``240``. - n_res_layers (int, optional): Number of :class:`~torchmultimodal.models.video_vqvae.AttentionResidualBlocks` + n_res_layers (int, optional): Number of :class:`~torchmultimodal.models.video_gpt.video_vqvae.AttentionResidualBlocks` to include in encoder and decoder. Defaults to ``4``. attn_hidden_dim (int, optional): Size of hidden dim of ``AttentionResidualBlocks``. Defaults to ``240``. num_embeddings (int, optional): Number of embedding vectors used in ``Codebook``. Defaults to ``1024``. @@ -156,8 +156,8 @@ def video_vqvae( Returns: An instance of :class:`~torchmultimodal.models.vqvae.VQVAE` constructed with: - * :class:`~torchmultimodal.model.video_vqvae.VideoEncoder` - * :class:`~torchmultimodal.model.video_vqvae.VideoDecoder` + * :class:`~torchmultimodal.models.video_gpt.video_vqvae.VideoEncoder` + * :class:`~torchmultimodal.models.video_gpt.video_vqvae.VideoDecoder` """ encoder_kernel_sizes = conv_filter_sizes + (encoder_filter_size,) encoder_strides = conv_filter_strides + (encoder_filter_stride,) diff --git a/torchmultimodal/models/video_vqvae.py b/torchmultimodal/models/video_gpt/video_vqvae.py similarity index 70% rename from torchmultimodal/models/video_vqvae.py rename to torchmultimodal/models/video_gpt/video_vqvae.py index 6eeb0271..e70fc298 100644 --- a/torchmultimodal/models/video_vqvae.py +++ b/torchmultimodal/models/video_gpt/video_vqvae.py @@ -6,13 +6,138 @@ from typing import Any, cast, List, Optional, Tuple, Union +import torch + from torch import nn, Size, Tensor from torchmultimodal.models.vqvae import VQVAE -from torchmultimodal.modules.layers.attention import AxialAttentionBlock +from torchmultimodal.modules.layers.attention import ( + MultiHeadAttention, + scaled_dot_product_attention, +) from torchmultimodal.modules.layers.conv import SamePadConv3d, SamePadConvTranspose3d from torchmultimodal.utils.assertion import assert_equal_lengths -from torchmultimodal.utils.common import to_tuple_tuple +from torchmultimodal.utils.common import shift_dim, to_tuple_tuple + + +class AxialAttention(nn.Module): + """Computes attention over a single axis of the input. Other dims are flattened into the batch dimension. + + Args: + axial_dim (int): Dimension to compute attention on, indexed by input dimensions + (i.e., ``0`` for first input dimension, ``1`` for second). + attn_dropout (float): Probability of dropout after softmax. Default is ``0.0``.
+ """ + + def __init__(self, axial_dim: int, attn_dropout: float = 0.0) -> None: + super().__init__() + self.axial_dim = axial_dim + 2 # account for batch, head + self.attn_dropout = attn_dropout + + def forward( + self, + q: Tensor, + k: Tensor, + v: Tensor, + attention_mask: Optional[Tensor] = None, + head_mask: Optional[Tensor] = None, + ) -> Tuple[Tensor, Tensor]: + """ + Args: + q (Tensor): Query input of shape ``(b, h, d1, ..., dn, dim_q)`` where ``h`` is the number of + attention heads, ``(d1, ..., dn)`` are the query latent dimensions and ``dim_q`` is the dimension + of the query embeddings. + k, v (Tensor): Key/value input of shape ``(b, h, d1', ..., dn', dim_kv)`` where ``h`` is the number + of attention heads, ``(d1', ..., dn')`` are the key/value latent dimensions and ``dim_kv`` is + the dimension of the key/value embeddings. + attention_mask (Tensor, optional): Tensor of shape ``(b, h, d1, ..., q_dn, k_dn)`` where ``q_dn`` is + the dimension of the axis to compute attention on of the query and ``k_dn`` that of the key. + Contains 1s for positions to attend to and 0s for masked positions. + head_mask (Tensor, optional): Tensor of shape ``(b, h, d1, ..., q_dn, k_dn)``. + Contains 1s for positions to attend to and 0s for masked positions. + + Returns: + A tuple of output tensor and attention probabilities. + """ + # Ensure axial dim is within right dimensions, should be between head dim and embedding dim + if self.axial_dim >= len(q.shape) - 1: + raise ValueError("axial dim does not match input shape") + + # flatten all dims into batch dimension except chosen axial dim and channel dim + # b, h, d1, ..., dn, dim_q/dim_kv -> (b, h, d1, ..., dn-1), axial_dim, dim_q/dim_kv + q = shift_dim(q, self.axial_dim, -2).flatten(end_dim=-3) + k = shift_dim(k, self.axial_dim, -2).flatten(end_dim=-3) + v = shift_dim(v, self.axial_dim, -2) + old_shape = list(v.shape) + v = v.flatten(end_dim=-3) + + out, attn_probs = scaled_dot_product_attention( + q, + k, + v, + attention_mask=attention_mask, + head_mask=head_mask, + attn_dropout=self.attn_dropout if self.training else 0.0, + ) + out = out.view(*old_shape) + out = shift_dim(out, -2, self.axial_dim) + return out, attn_probs + + +class AxialAttentionBlock(nn.Module): + """Computes multihead axial attention across all dims of the input. + + Axial attention is an alternative to standard full attention, where instead + of computing attention across the entire flattened input, you compute it for + each dimension. To capture the global context that full attention does, stacking + multiple axial attention layers will allow information to propagate among the + multiple dimensions of the input. This enables attention calculations on high + dimensional inputs (images, videos) where full attention would be computationally + expensive and unfeasible. For more details, see `"Axial Attention in + Multidimensional Transformers (Ho et al. 2019)"`_ + and `"CCNet: Criss-Cross Attention for Semantic Segmentation (Huang et al. 2019) + "`_. + + Follows implementation by VideoGPT: + https://github.com/wilson1yan/VideoGPT/blob/master/videogpt/vqvae.py + + Args: + n_dims (int): Dimensionality of input data, not including batch or embedding dims. + qkv_dim (int): Dimensionality of query/key/value embedding vectors. + n_head (int): Number of heads in multihead attention. Must divide into ``qkv_dim`` + evenly. 
+ """ + + def __init__(self, n_dims: int, qkv_dim: int, n_head: int) -> None: + super().__init__() + self.qkv_dim = qkv_dim + self.mha_attns = nn.ModuleList( + [ + MultiHeadAttention( + dim_q=qkv_dim, + dim_kv=qkv_dim, + n_head=n_head, + attn_module=AxialAttention(d), + add_bias=False, + ) + for d in range(n_dims) + ] + ) + + def forward(self, x: Tensor) -> Tensor: + n_channel = x.shape[1] + if n_channel != self.qkv_dim: + raise ValueError( + f"Input channel dimension is {n_channel}, expected {self.qkv_dim}" + ) + + h = shift_dim(x, 1, -1) # (b, c, d1, ..., dn) -> (b, d1, ..., dn, c) + attn_out = torch.zeros_like(h) + for mha_attn in self.mha_attns: + attn_out += mha_attn(h, causal=False) + h = attn_out + h = shift_dim(h, -1, 1) # (b, d1, ..., dn, c) -> (b, c, d1, ..., dn) + return h def video_vqvae( diff --git a/torchmultimodal/modules/layers/attention.py b/torchmultimodal/modules/layers/attention.py index 63d167ac..d7155228 100644 --- a/torchmultimodal/modules/layers/attention.py +++ b/torchmultimodal/modules/layers/attention.py @@ -67,70 +67,6 @@ def forward( return out.unflatten(2, shape), attn_probs -class AxialAttention(nn.Module): - """Computes attention over a single axis of the input. Other dims are flattened into the batch dimension. - - Args: - axial_dim (int): Dimension to compute attention on, indexed by input dimensions - (i.e., ``0`` for first input dimension, ``1`` for second). - attn_dropout (float): Probability of dropout after softmax. Default is ``0.0``. - """ - - def __init__(self, axial_dim: int, attn_dropout: float = 0.0) -> None: - super().__init__() - self.axial_dim = axial_dim + 2 # account for batch, head - self.attn_dropout = attn_dropout - - def forward( - self, - q: Tensor, - k: Tensor, - v: Tensor, - attention_mask: Optional[Tensor] = None, - head_mask: Optional[Tensor] = None, - ) -> Tuple[Tensor, Tensor]: - """ - Args: - q (Tensor): Query input of shape ``(b, h, d1, ..., dn, dim_q)`` where ``h`` is the number of - attention heads, ``(d1, ..., dn)`` are the query latent dimensions and ``dim_q`` is the dimension - of the query embeddings. - k, v (Tensor): Key/value input of shape ``(b, h, d1', ..., dn', dim_kv)`` where ``h`` is the number - of attention heads, ``(d1', ..., dn')`` are the key/value latent dimensions and ``dim_kv`` is - the dimension of the key/value embeddings. - attention_mask (Tensor, optional): Tensor of shape ``(b, h, d1, ..., q_dn, k_dn)`` where ``q_dn`` is - the dimension of the axis to compute attention on of the query and ``k_dn`` that of the key. - Contains 1s for positions to attend to and 0s for masked positions. - head_mask (Tensor, optional): Tensor of shape ``(b, h, d1, ..., q_dn, k_dn)``. - Contains 1s for positions to attend to and 0s for masked positions. - - Returns: - A tuple of output tensor and attention probabilities. 
- """ - # Ensure axial dim is within right dimensions, should be between head dim and embedding dim - if self.axial_dim >= len(q.shape) - 1: - raise ValueError("axial dim does not match input shape") - - # flatten all dims into batch dimension except chosen axial dim and channel dim - # b, h, d1, ..., dn, dim_q/dim_kv -> (b, h, d1, ..., dn-1), axial_dim, dim_q/dim_kv - q = shift_dim(q, self.axial_dim, -2).flatten(end_dim=-3) - k = shift_dim(k, self.axial_dim, -2).flatten(end_dim=-3) - v = shift_dim(v, self.axial_dim, -2) - old_shape = list(v.shape) - v = v.flatten(end_dim=-3) - - out, attn_probs = scaled_dot_product_attention( - q, - k, - v, - attention_mask=attention_mask, - head_mask=head_mask, - attn_dropout=self.attn_dropout if self.training else 0.0, - ) - out = out.view(*old_shape) - out = shift_dim(out, -2, self.axial_dim) - return out, attn_probs - - class MultiHeadAttention(nn.Module): """Computes multihead attention with flexible attention mechanism and caching for fast decoding. @@ -204,13 +140,7 @@ def forward( Returns: * If ``return_attn_weights`` is ``True``: A tuple of output tensor and attention probabilities. * If ``return_attn_weights`` is ``False``: A single output tensor. - - Raises: - TypeError: An error occurred when ``causal`` is ``True`` and ``attn_module`` is ``AxialAttention``. """ - if isinstance(self.attn, AxialAttention) and causal: - raise TypeError("Causal axial attention is not supported.") - # If kv is specified use those inputs for cross-attention, otherwise use q k = v = q if kv is None else kv # compute q @@ -252,62 +182,6 @@ def forward( return a -class AxialAttentionBlock(nn.Module): - """Computes multihead axial attention across all dims of the input. - - Axial attention is an alternative to standard full attention, where instead - of computing attention across the entire flattened input, you compute it for - each dimension. To capture the global context that full attention does, stacking - multiple axial attention layers will allow information to propagate among the - multiple dimensions of the input. This enables attention calculations on high - dimensional inputs (images, videos) where full attention would be computationally - expensive and unfeasible. For more details, see `"Axial Attention in - Multidimensional Transformers (Ho et al. 2019)"`_ - and `"CCNet: Criss-Cross Attention for Semantic Segmentation (Huang et al. 2019) - "`_. - - Follows implementation by VideoGPT: - https://github.com/wilson1yan/VideoGPT/blob/master/videogpt/vqvae.py - - Args: - n_dims (int): Dimensionality of input data, not including batch or embedding dims. - qkv_dim (int): Dimensionality of query/key/value embedding vectors. - n_head (int): Number of heads in multihead attention. Must divide into ``qkv_dim`` - evenly. 
- """ - - def __init__(self, n_dims: int, qkv_dim: int, n_head: int) -> None: - super().__init__() - self.qkv_dim = qkv_dim - self.mha_attns = nn.ModuleList( - [ - MultiHeadAttention( - dim_q=qkv_dim, - dim_kv=qkv_dim, - n_head=n_head, - attn_module=AxialAttention(d), - add_bias=False, - ) - for d in range(n_dims) - ] - ) - - def forward(self, x: Tensor) -> Tensor: - n_channel = x.shape[1] - if n_channel != self.qkv_dim: - raise ValueError( - f"Input channel dimension is {n_channel}, expected {self.qkv_dim}" - ) - - h = shift_dim(x, 1, -1) # (b, c, d1, ..., dn) -> (b, d1, ..., dn, c) - attn_out = torch.zeros_like(h) - for mha_attn in self.mha_attns: - attn_out += mha_attn(h) - h = attn_out - h = shift_dim(h, -1, 1) # (b, d1, ..., dn, c) -> (b, c, d1, ..., dn) - return h - - def scaled_dot_product_attention( q: Tensor, k: Tensor,