Skip to content

Commit

Permalink
Consolidate VideoGPT components (facebookresearch#478)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: facebookresearch#478

Condolidating all of the VideoGPT components and models into a single folder which follows our general convention. This change also removes model-specific components like AxialAttention from the general library to reduce confusion for users.

Full set of changes include:
- Move gpt.y, video_gpt.py and video_vqvae.py to models/video_gpt/
- Rename video_gpt.py to model.py to follow convention
- Move AxialAttention and AxialAttentionBlock to video_vqvae.py
- Update all tests
- Update code in examples/mugen

Reviewed By: ebsmothers, pikapecan

Differential Revision: D49753884

fbshipit-source-id: b882e6028bf166bb50efa51db81276b8f271f28e
  • Loading branch information
Kartikay Khandelwal authored and facebook-github-bot committed Oct 3, 2023
1 parent 5a0952c commit 0de91e1
Show file tree
Hide file tree
Showing 11 changed files with 243 additions and 222 deletions.
8 changes: 4 additions & 4 deletions examples/mugen/generation/text_video_gpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

from torch import nn, Tensor

from torchmultimodal.models.gpt import (
from torchmultimodal.models.video_gpt.gpt import (
MultimodalGPT,
MultimodalTransformerDecoder,
RightShift,
Expand Down Expand Up @@ -69,7 +69,7 @@ def text_video_gpt(
new frame will be of shape ``(8, 8, 8)`` with each dim divided by the rate of downsample. Defaults to
``(4, 32, 32)``.
d_model (int): Dimension of the underlying transformer decoder.
See :py:class:`torchmultimodal.models.gpt.TransformerDecoderLayer`. Defaults to ``768``.
See :py:class:`torchmultimodal.models.video_gpt.gpt.TransformerDecoderLayer`. Defaults to ``768``.
n_head (int): Number of attention heads used by the transformer decoder. Defaults to ``8``.
dropout (float): Dropout probability used by the projection layer of the transformer decoder.
Defaults to ``0.2``.
Expand All @@ -93,7 +93,7 @@ def text_video_gpt(
Defaults to ``None``.
Returns:
An instance of :py:class:`torchmultimodal.models.gpt.MultimodalGPT`.
An instance of :py:class:`torchmultimodal.models.video_gpt.gpt.MultimodalGPT`.
"""

# builds text tokenizer from pre-trained
Expand Down Expand Up @@ -195,7 +195,7 @@ class TextTokenizer(nn.Module):
"""Converts between text and tokens / embedings
Wrapper around the tokenizer to be consistent with the API required by
:py:class:`torchmultimodal.models.gpt.MultimodalGPT`. It also contains the
:py:class:`torchmultimodal.models.video_gpt.gpt.MultimodalGPT`. It also contains the
embedding layer to enable lookup by token ids.
"""

Expand Down
8 changes: 4 additions & 4 deletions examples/mugen/generation/video_vqvae.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from typing import Optional

from torchmultimodal.models.video_vqvae import (
from torchmultimodal.models.video_gpt.video_vqvae import (
preprocess_int_conv_params,
VideoDecoder,
VideoEncoder,
Expand Down Expand Up @@ -50,7 +50,7 @@ def video_vqvae_mugen(
n_res_layers (int, optional): Number of ``AttentionResidualBlocks`` to include in encoder and decoder.
Defaults to ``4``.
attn_hidden_dim (int, optional): Size of hidden dim of
:class:`~torchmultimodal.models.video_vqvae.AttentionResidualBlocks`. Defaults to ``240``.
:class:`~torchmultimodal.models.video_gpt.video_vqvae.AttentionResidualBlocks`. Defaults to ``240``.
num_embeddings (int, optional): Number of embedding vectors used in
:class:`~torchmultimodal.modules.layers.codebook.Codebook`. Defaults to ``2048``.
embedding_dim (int, optional): Dimensionality of embedding vectors in
Expand All @@ -63,8 +63,8 @@ def video_vqvae_mugen(
Returns:
An instance of :class:`~torchmultimodal.models.vqvae.VQVAE` constructed with:
* :class:`~torchmultimodal.model.video_vqvae.VideoEncoder`
* :class:`~torchmultimodal.model.video_vqvae.VideoDecoder`
* :class:`~torchmultimodal.model.video_gpt.video_vqvae.VideoEncoder`
* :class:`~torchmultimodal.model.video_gpt.video_vqvae.VideoDecoder`
"""
encoder_strides = ((2, 2, 2), (2, 2, 2), (1, 2, 2), (1, 2, 2), (1, 2, 2), (1, 1, 1))
decoder_strides = ((2, 2, 2), (2, 2, 2), (1, 2, 2), (1, 2, 2), (1, 2, 2))
Expand Down
4 changes: 2 additions & 2 deletions tests/models/test_gpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from tests.test_utils import assert_expected, assert_expected_namedtuple, set_rng_seed
from torch import nn
from torch.nn import functional as F
from torchmultimodal.models.gpt import (
from torchmultimodal.models.video_gpt.gpt import (
MultimodalGPT,
MultimodalGPTOutput,
MultimodalTransformerDecoder,
Expand Down Expand Up @@ -257,7 +257,7 @@ def test_initialize_parameters(self, gpt, mocker):
# Testing mean and std of the initialized weights data requires a large
# amount samples to be statistically stable. Here we just test whether
# the method in question has been called to avoid test flakiness.
mock_init = mocker.patch("torchmultimodal.models.gpt.Tensor.normal_")
mock_init = mocker.patch("torchmultimodal.models.video_gpt.gpt.Tensor.normal_")
gpt = gpt(use_gpt_init=True)
mock_init.assert_called()

Expand Down
2 changes: 1 addition & 1 deletion tests/models/test_video_gpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

from tests.test_utils import assert_expected, set_rng_seed

from torchmultimodal.models.video_gpt import video_gpt, video_vqvae
from torchmultimodal.models.video_gpt.model import video_gpt, video_vqvae


@pytest.fixture(autouse=True)
Expand Down
89 changes: 88 additions & 1 deletion tests/models/test_video_vqvae.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,16 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from itertools import repeat

import pytest
import torch
from tests.test_utils import assert_expected, assert_expected_namedtuple, set_rng_seed

from torchmultimodal.models.video_vqvae import (
from torchmultimodal.models.video_gpt.video_vqvae import (
AttentionResidualBlock,
AxialAttention,
AxialAttentionBlock,
preprocess_int_conv_params,
video_vqvae,
VideoDecoder,
Expand All @@ -36,6 +40,89 @@ def input_tensor():
return torch.ones(1, 2, 2, 2, 2)


class TestAxialBlock:
@pytest.fixture
def hidden_dim(self):
return 3

@pytest.fixture
def n_dim(self):
return 3

@pytest.fixture
def input_shape(self, n_dim):
return tuple(repeat(2, n_dim))

@pytest.fixture
def axial_block(self, input_shape, hidden_dim):
return AxialAttentionBlock(len(input_shape), hidden_dim, 1)

@pytest.fixture
def q(self, input_shape, hidden_dim):
n_heads = 1
return torch.randn(1, n_heads, *input_shape, hidden_dim // n_heads)

@pytest.fixture
def kv(self, input_shape, hidden_dim):
n_heads = 1
return torch.randn(1, n_heads, *input_shape, hidden_dim // n_heads)

@pytest.fixture
def axial_attn(self):
return AxialAttention(1) # only on second axis of input

def test_axial_attention(self, axial_attn, q, kv):
k = v = kv
actual, _ = axial_attn(q, k, v)
expected = torch.tensor(
[
[
[
[
[[-0.5869, 1.8958, 0.8688], [0.0299, 0.2098, 1.2741]],
[[-0.6662, 1.9747, 0.8980], [0.1002, 0.2094, 1.5472]],
],
[
[[0.5902, -0.3275, -0.8727], [-1.0557, 1.0791, 0.3916]],
[[0.6623, -0.3223, -0.8948], [-1.0755, 1.0763, 0.3708]],
],
]
]
]
)
assert_expected(actual, expected, rtol=0, atol=1e-4)

def test_axial_block_forward(self, axial_block, hidden_dim, input_shape):
"""Test AxialAttentionBlock with sub-components"""
x = 2 * torch.ones(1, hidden_dim, *input_shape)
actual = axial_block(x)
expected = torch.tensor(
[
[
[
[[0.822055, 0.822055], [0.822055, 0.822055]],
[[0.822055, 0.822055], [0.822055, 0.822055]],
],
[
[[-0.767143, -0.767143], [-0.767143, -0.767143]],
[[-0.767143, -0.767143], [-0.767143, -0.767143]],
],
[
[[-0.916860, -0.916860], [-0.916860, -0.916860]],
[[-0.916860, -0.916860], [-0.916860, -0.916860]],
],
]
]
)
assert_expected(actual, expected, rtol=0, atol=1e-4)

def test_axial_block_channel_dim(self, axial_block, hidden_dim, input_shape):
"""Test dim check in forward of AxialAttentionBlock"""
x = torch.zeros(1, hidden_dim + 1, *input_shape)
with pytest.raises(ValueError):
_ = axial_block(x)


class TestAttentionResidualBlock:
def test_hidden_dim_assertion(self):
with pytest.raises(ValueError):
Expand Down
65 changes: 0 additions & 65 deletions tests/modules/layers/test_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,6 @@
import torch
from tests.test_utils import assert_expected, set_rng_seed
from torchmultimodal.modules.layers.attention import (
AxialAttention,
AxialAttentionBlock,
merge_multihead,
MultiHeadAttention,
scaled_dot_product_attention,
Expand Down Expand Up @@ -58,11 +56,6 @@ def self_attn():
return SelfAttention(attn_dropout=0.0)


@pytest.fixture
def axial_attn():
return AxialAttention(1) # only on second axis of input


class TestMultiheadAttention:
@pytest.fixture
def multihead_attn(self, hidden_dim):
Expand Down Expand Up @@ -395,28 +388,6 @@ def test_self_attention(self_attn, q, kv):
assert_expected(actual, expected, rtol=0, atol=1e-4)


def test_axial_attention(axial_attn, q, kv):
k = v = kv
actual, _ = axial_attn(q, k, v)
expected = torch.tensor(
[
[
[
[
[[-0.5869, 1.8958, 0.8688], [0.0299, 0.2098, 1.2741]],
[[-0.6662, 1.9747, 0.8980], [0.1002, 0.2094, 1.5472]],
],
[
[[0.5902, -0.3275, -0.8727], [-1.0557, 1.0791, 0.3916]],
[[0.6623, -0.3223, -0.8948], [-1.0755, 1.0763, 0.3708]],
],
]
]
]
)
assert_expected(actual, expected, rtol=0, atol=1e-4)


def test_split_multihead(input_shape):
x = torch.randn(1, *input_shape, 6) # (b, d1, ..., dn, c)
out = split_multihead(x, 2)
Expand All @@ -430,39 +401,3 @@ def test_merge_multihead(input_shape, hidden_dim, q):
actual = torch.tensor(out.shape)
expected = torch.tensor((1, *input_shape, hidden_dim))
assert_expected(actual, expected)


class TestAxialBlock:
@pytest.fixture
def axial_block(self, input_shape, hidden_dim):
return AxialAttentionBlock(len(input_shape), hidden_dim, 1)

def test_axial_block_forward(self, axial_block, hidden_dim, input_shape):
"""Test AxialAttentionBlock with sub-components"""
x = 2 * torch.ones(1, hidden_dim, *input_shape)
actual = axial_block(x)
expected = torch.tensor(
[
[
[
[[0.822055, 0.822055], [0.822055, 0.822055]],
[[0.822055, 0.822055], [0.822055, 0.822055]],
],
[
[[-0.767143, -0.767143], [-0.767143, -0.767143]],
[[-0.767143, -0.767143], [-0.767143, -0.767143]],
],
[
[[-0.916860, -0.916860], [-0.916860, -0.916860]],
[[-0.916860, -0.916860], [-0.916860, -0.916860]],
],
]
]
)
assert_expected(actual, expected, rtol=0, atol=1e-4)

def test_axial_block_channel_dim(self, axial_block, hidden_dim, input_shape):
"""Test dim check in forward of AxialAttentionBlock"""
x = torch.zeros(1, hidden_dim + 1, *input_shape)
with pytest.raises(ValueError):
_ = axial_block(x)
2 changes: 1 addition & 1 deletion tests/utils/test_generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

from tests.test_utils import assert_expected, set_rng_seed

from torchmultimodal.models.video_gpt import video_gpt
from torchmultimodal.models.video_gpt.model import video_gpt
from torchmultimodal.utils.generate import (
GenerationUtil,
get_logits_mask,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@


class TransformerDecoderOutput(NamedTuple):
"""Outputs from :class:`~torchmultimodal.models.gpt.TransformerDecoder`.
"""Outputs from :class:`~torchmultimodal.models.video_gpt.gpt.TransformerDecoder`.
Attributes:
last_hidden_states (Tensor): Output from the last layer of the transformer.
Expand All @@ -36,7 +36,7 @@ class TransformerDecoderOutput(NamedTuple):


class TransformerLayerOutput(NamedTuple):
"""Outputs from :class:`~torchmultimodal.models.gpt.TransformerDecoderLayer`.
"""Outputs from :class:`~torchmultimodal.models.video_gpt.gpt.TransformerDecoderLayer`.
Attributes:
hidden_states (Tensor): Output from the current layer.
Expand All @@ -52,7 +52,7 @@ class TransformerLayerOutput(NamedTuple):


class MultimodalGPTOutput(NamedTuple):
"""Outputs from :meth:`~torchmultimodal.models.gpt.MultimodalGPT.forward`.
"""Outputs from :meth:`~torchmultimodal.models.video_gpt.gpt.MultimodalGPT.forward`.
Attributes:
decoder_output (TransformerDeocoderOutput): Contains output from the multimodal transformer decoder.
Expand Down Expand Up @@ -200,7 +200,7 @@ def forward(
Defaults to ``False``.
Returns:
An instance of :class:`~torchmultimodal.models.gpt.MultimodalGPTOutput`.
An instance of :class:`~torchmultimodal.models.video_gpt.gpt.MultimodalGPTOutput`.
"""
decoder_output = self.fwd(
in_tokens=in_tokens,
Expand Down Expand Up @@ -462,7 +462,7 @@ def forward(
Defaults to ``False``.
Returns:
An instace of :class:`~torchmultimodal.models.gpt.TransformerDecoderOutput`.
An instace of :class:`~torchmultimodal.models.video_gpt.gpt.TransformerDecoderOutput`.
"""
if (in_modality is None) and (out_modality is None):
raise ValueError(
Expand Down Expand Up @@ -562,7 +562,7 @@ def forward(
Defaults to ``False``.
Returns:
An instance of :class:`~torchmultimodal.models.gpt.TransformerDecoderOutput`.
An instance of :class:`~torchmultimodal.models.video_gpt.gpt.TransformerDecoderOutput`.
"""
if attn_mask is not None and attn_mask.dim() == 2:
attn_mask = attn_mask[
Expand Down Expand Up @@ -680,7 +680,7 @@ def forward(
Defaults to ``False``.
Returns:
An instance of :class:`~torchmultimodal.models.gpt.TransformerLayerOutput`.
An instance of :class:`~torchmultimodal.models.video_gpt.gpt.TransformerLayerOutput`.
"""
attn_probs = None
past_key_values = None
Expand Down
Loading

0 comments on commit 0de91e1

Please sign in to comment.