From 6433f8fce39992a20eb629cff2203f27b35d6b0d Mon Sep 17 00:00:00 2001 From: Evan Smothers Date: Wed, 15 Nov 2023 10:18:26 -0800 Subject: [PATCH] Add CoCa model (#506) Summary: Pull Request resolved: https://github.com/facebookresearch/multimodal/pull/506 imported-using-ghimport Test Plan: Imported from OSS Reviewed By: pbontrager Differential Revision: D51332627 Pulled By: ebsmothers fbshipit-source-id: 3569349bf143283950501dc33e2a90e3807e9508 --- tests/models/coca/__init__.py | 5 + tests/models/coca/test_coca_model.py | 141 +++++ tests/models/coca/test_multimodal_decoder.py | 115 ++++ tests/models/coca/test_text_decoder.py | 282 ++++++++++ tests/modules/layers/test_attention_pooler.py | 116 ++++ torchmultimodal/models/coca/__init__.py | 5 + torchmultimodal/models/coca/coca_model.py | 503 ++++++++++++++++++ .../models/coca/multimodal_decoder.py | 108 ++++ torchmultimodal/models/coca/text_decoder.py | 252 +++++++++ .../modules/layers/attention_pooler.py | 101 ++++ 10 files changed, 1628 insertions(+) create mode 100644 tests/models/coca/__init__.py create mode 100644 tests/models/coca/test_coca_model.py create mode 100644 tests/models/coca/test_multimodal_decoder.py create mode 100644 tests/models/coca/test_text_decoder.py create mode 100644 tests/modules/layers/test_attention_pooler.py create mode 100644 torchmultimodal/models/coca/__init__.py create mode 100644 torchmultimodal/models/coca/coca_model.py create mode 100644 torchmultimodal/models/coca/multimodal_decoder.py create mode 100644 torchmultimodal/models/coca/text_decoder.py create mode 100644 torchmultimodal/modules/layers/attention_pooler.py diff --git a/tests/models/coca/__init__.py b/tests/models/coca/__init__.py new file mode 100644 index 00000000..2e41cd71 --- /dev/null +++ b/tests/models/coca/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
+# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. diff --git a/tests/models/coca/test_coca_model.py b/tests/models/coca/test_coca_model.py new file mode 100644 index 00000000..4c479d5b --- /dev/null +++ b/tests/models/coca/test_coca_model.py @@ -0,0 +1,141 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import pytest +import torch +from tests.test_utils import assert_expected, init_weights_with_constant, set_rng_seed +from torchmultimodal.models.coca.coca_model import ( + coca_vit, + CoCaForPretraining, + CoCaModelOutput, +) + + +class TestCoCaModel: + @pytest.fixture(autouse=True) + def random(self): + set_rng_seed(0) + + @pytest.fixture + def batch_size(self): + return 2 + + @pytest.fixture + def vocab_size(self): + return 50 + + @pytest.fixture + def num_text_positions(self): + return 11 + + @pytest.fixture + def attention_pooler_output_dim(self): + return 8 + + @pytest.fixture + def text_output_dim(self): + return 8 + + @pytest.fixture + def image_size(self): + return 12 + + @pytest.fixture + def coca_model( + self, + vocab_size, + num_text_positions, + attention_pooler_output_dim, + text_output_dim, + image_size, + ): + coca_model = coca_vit( + vision_patch_size=4, + vision_dim_feedforward=24, + vision_n_layer=2, + vision_n_head=2, + vocab_size=vocab_size, + num_text_positions=num_text_positions, + text_hidden_dim=8, + text_n_layer=2, + text_n_head=2, + text_dim_feedforward=32, + text_output_dim=text_output_dim, + fusion_n_layer=2, + fusion_n_head=2, + fusion_dim_feedforward=32, + multimodal_output_projection_dim=vocab_size, + pooler_input_embed_dim=6, + pooler_output_embed_dim=attention_pooler_output_dim, + image_size=image_size, + pooler_n_head=2, + cascaded_pooler=False, + ) + 
init_weights_with_constant(coca_model) + coca_model.eval() + return coca_model + + @pytest.fixture + def text_inputs(self): + return torch.LongTensor( + [ + [1, 3, 4, 5, 6, 7, 8, 2, 0, 0, 0], + [1, 25, 28, 34, 39, 45, 40, 5, 12, 6, 2], + ] + ) + + @pytest.fixture + def image_inputs(self, batch_size, image_size): + return torch.randn(batch_size, 3, image_size, image_size) + + @pytest.fixture + def expected( + self, + batch_size, + vocab_size, + num_text_positions, + attention_pooler_output_dim, + text_output_dim, + ): + pooled_val = 0.3536 + logit_val = 8.0 + return CoCaModelOutput( + image_pooled_output=pooled_val + * torch.ones(batch_size, attention_pooler_output_dim), + text_pooled_output=pooled_val * torch.ones(batch_size, text_output_dim), + multimodal_embeddings=logit_val + * torch.ones(batch_size, num_text_positions - 1, vocab_size), + ) + + @pytest.fixture + def coca_for_pretraining(self, coca_model): + coca_for_pretraining = CoCaForPretraining(coca_model) + init_weights_with_constant(coca_for_pretraining) + coca_for_pretraining.eval() + return coca_for_pretraining + + def test_coca_model(self, text_inputs, image_inputs, coca_model, expected): + actual = coca_model(image_inputs, text_inputs) + assert_expected(actual, expected, rtol=0, atol=1e-4) + + def test_scripting(self, text_inputs, image_inputs, coca_model): + scripted_model = torch.jit.script(coca_model) + assert_expected( + scripted_model(image_inputs, text_inputs), + coca_model(image_inputs, text_inputs), + rtol=0, + atol=1e-4, + ) + + def test_coca_for_pretraining( + self, text_inputs, image_inputs, coca_for_pretraining + ): + actual_losses = coca_for_pretraining(image_inputs, text_inputs) + expected_losses = { + "contrastive": torch.tensor(0.6931), + "captioning": torch.tensor(3.9120), + } + assert_expected(actual_losses, expected_losses, rtol=0, atol=1e-4) diff --git a/tests/models/coca/test_multimodal_decoder.py b/tests/models/coca/test_multimodal_decoder.py new file mode 100644 index 
00000000..e67bbdea --- /dev/null +++ b/tests/models/coca/test_multimodal_decoder.py @@ -0,0 +1,115 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import pytest +import torch +from tests.test_utils import assert_expected, init_weights_with_constant +from torch import nn, Tensor +from torchmultimodal.models.coca.multimodal_decoder import CoCaMultimodalDecoder + + +class TestCoCaMultimodalDecoder: + @pytest.fixture + def batch_size(self): + return 2 + + @pytest.fixture + def input_seq_len(self): + return 5 + + @pytest.fixture + def num_image_positions(self): + return 10 + + @pytest.fixture + def text_embedding_dim(self): + return 4 + + @pytest.fixture + def multimodal_decoder(self, input_seq_len, batch_size, text_embedding_dim): + decoder = CoCaMultimodalDecoder( + input_seq_len=input_seq_len, + text_embedding_dim=text_embedding_dim, + n_layer=2, + n_head=2, + dim_feedforward=4 * text_embedding_dim, + output_dim=3, + final_layer_norm_eps=1e-5, + ) + init_weights_with_constant(decoder) + + # Custom init final MLP layer weight, final LN, and text projection + decoder.transformer_decoder.layer[1].feedforward.model[2].weight = nn.Parameter( + torch.arange( + decoder.transformer_decoder.layer[1] + .feedforward.model[2] + .weight.numel(), + dtype=torch.float, + ).reshape( + decoder.transformer_decoder.layer[1].feedforward.model[2].weight.shape + ) + ) + decoder.output_projection.weight = nn.Parameter( + torch.arange(decoder.output_projection.weight.numel(), dtype=torch.float) + .reshape(decoder.output_projection.weight.T.shape) + .T + ) + decoder.transformer_decoder.final_layer_norm.weight = nn.Parameter( + torch.arange( + decoder.transformer_decoder.final_layer_norm.weight.numel(), + dtype=torch.float, + ) + ) + decoder.eval() + return decoder + + @pytest.fixture + def text_inputs(self, batch_size, 
input_seq_len, text_embedding_dim): + return torch.arange(0.0, 1.0, 1.0 / 40).reshape( + batch_size, input_seq_len, text_embedding_dim + ) + + @pytest.fixture + def image_inputs(self, batch_size, num_image_positions, text_embedding_dim): + return torch.arange(10.0, 20.0, 1.0 / 8).reshape( + batch_size, num_image_positions, text_embedding_dim + ) + + @pytest.fixture + def expected(self): + return Tensor( + [ + [ + [58.2492, 66.7214, 75.1935], + [58.2492, 66.7214, 75.1935], + [58.2492, 66.7214, 75.1935], + [58.2492, 66.7214, 75.1935], + [58.2492, 66.7214, 75.1935], + ], + [ + [58.2492, 66.7214, 75.1935], + [58.2492, 66.7214, 75.1935], + [58.2492, 66.7214, 75.1935], + [58.2492, 66.7214, 75.1935], + [58.2492, 66.7214, 75.1935], + ], + ] + ) + + def test_coca_multimodal_decoder( + self, text_inputs, image_inputs, multimodal_decoder, expected + ): + actual = multimodal_decoder(text_inputs, image_inputs) + assert_expected(actual, expected, rtol=0, atol=1e-4) + + def test_scripting(self, text_inputs, image_inputs, multimodal_decoder): + scripted_multimodal_decoder = torch.jit.script(multimodal_decoder) + assert_expected( + scripted_multimodal_decoder(text_inputs, image_inputs), + multimodal_decoder(text_inputs, image_inputs), + rtol=0, + atol=1e-4, + ) diff --git a/tests/models/coca/test_text_decoder.py b/tests/models/coca/test_text_decoder.py new file mode 100644 index 00000000..2ad142d8 --- /dev/null +++ b/tests/models/coca/test_text_decoder.py @@ -0,0 +1,282 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +import pytest +import torch +from tests.test_utils import assert_expected, init_weights_with_constant, set_rng_seed +from torch import nn, Tensor +from torchmultimodal.models.coca.text_decoder import CoCaTextDecoder, CoCaTextEmbeddings + + +class TestCoCaTextEmbeddings: + @pytest.fixture + def vocab_size(self) -> int: + return 15 + + @pytest.fixture + def num_positions(self) -> int: + return 7 + + @pytest.fixture + def embedding_dim(self) -> int: + return 10 + + @pytest.fixture + def batch_size(self) -> int: + return 2 + + @pytest.fixture + def text_embeddings(self, vocab_size, num_positions, embedding_dim): + embeddings = CoCaTextEmbeddings( + vocab_size=vocab_size, + num_positions=num_positions, + embedding_dim=embedding_dim, + pad_idx=1, + ) + init_weights_with_constant(embeddings) + + # Set CLS embedding and token embeddings to ranges + embeddings.cls_embedding = nn.Parameter( + torch.arange(embeddings.cls_embedding.shape[0], dtype=torch.float) + ) + embeddings.token_embeddings.weight = nn.Parameter( + torch.arange(vocab_size, dtype=torch.float) + .unsqueeze(1) + .expand_as(embeddings.token_embeddings.weight) + ) + + return embeddings + + @pytest.fixture + def inputs(self): + return torch.LongTensor([[4, 5, 6, 7, 0, 1, 2], [11, 12, 13, 14, 0, 2, 1]]) + + @pytest.fixture + def expected(self, inputs, embedding_dim, batch_size): + embeds = (inputs[:, :-1] + 1).unsqueeze(-1).repeat(1, 1, embedding_dim) + cls_embeds = ( + torch.arange(1, embedding_dim + 1) + .unsqueeze(0) + .unsqueeze(0) + .repeat(batch_size, 1, 1) + ) + expected = torch.cat([embeds, cls_embeds], dim=1).to(dtype=torch.float) + return expected + + def test_coca_text_embeddings(self, inputs, text_embeddings, expected): + actual = text_embeddings(inputs[:, :-1]) + assert_expected(actual, expected) + + +class TestCoCaTextDecoder: + @pytest.fixture + def batch_size(self): + return 2 + + @pytest.fixture + def embedding_dim(self): + return 8 + + @pytest.fixture + def vocab_size(self): + return 12 + + 
@pytest.fixture + def num_positions(self): + return 6 + + @pytest.fixture + def default_pad_idx(self): + return 0 + + @pytest.fixture + def output_dim(self): + return 3 + + @pytest.fixture + def get_text_decoder( + self, + batch_size, + vocab_size, + num_positions, + embedding_dim, + output_dim, + ): + def create_text_decoder( + pad_idx: int = 0, + embed_cls: bool = True, + custom_init: bool = False, + num_positions: int = num_positions, + ): + decoder = CoCaTextDecoder( + vocab_size=vocab_size, + num_positions=num_positions, + embedding_dim=embedding_dim, + n_layer=2, + n_head=2, + dim_feedforward=4 * embedding_dim, + output_dim=output_dim, + pad_idx=pad_idx, + embed_cls=embed_cls, + ) + + init_weights_with_constant(decoder) + if custom_init: + set_rng_seed(0) + nn.init.normal_(decoder.embeddings.token_embeddings.weight) + nn.init.normal_(decoder.text_projection.weight) + for block in decoder.transformer_decoder.layer: + nn.init.normal_(block.attention.q_proj.weight) + nn.init.normal_(block.attention.k_proj.weight) + nn.init.normal_(block.attention.v_proj.weight) + nn.init.normal_(block.attention.output_proj.weight) + if decoder.embeddings.cls_embedding is not None: + nn.init.normal_(decoder.embeddings.cls_embedding) + + decoder.eval() + return decoder + + return create_text_decoder + + @pytest.fixture + def input_ids(self): + return torch.LongTensor([[2, 4, 5, 7, 9, 1], [6, 8, 1, 0, 0, 0]]) + + @pytest.fixture + def padding_mask(self, input_ids, default_pad_idx): + return input_ids != default_pad_idx + + @pytest.fixture + def expected(self, batch_size, output_dim, num_positions, embedding_dim): + return ( + Tensor([8]).repeat(batch_size, output_dim), + Tensor([726]).repeat(batch_size, num_positions - 1, embedding_dim), + ) + + @pytest.fixture + def expected_attention_mask(self, batch_size, num_positions): + return torch.BoolTensor( + [ + [ + [ + [True, False, False, False, False, False, False], + [True, True, False, False, False, False, False], + [True, True, 
True, False, False, False, False], + [True, True, True, True, False, False, False], + [True, True, True, True, True, False, False], + [True, True, True, True, True, True, False], + [True, True, True, True, True, True, True], + ] + ], + [ + [ + [True, False, False, False, False, False, False], + [True, True, False, False, False, False, False], + [True, True, True, False, False, False, False], + [True, True, True, True, False, False, False], + [True, True, True, True, True, False, False], + [True, True, True, True, True, True, False], + [True, True, True, True, False, False, False], + ], + ], + ] + ) + + @pytest.mark.parametrize( + "pad_idx, embed_cls, expected_pooled, expected_tokens_shape, expected_tokens_mean", + [ + ( + 0, + True, + torch.Tensor([[5.5019, -4.5114, 3.0416], [3.4487, -6.2877, 3.1439]]), + torch.Size([2, 5, 8]), + torch.Tensor( + [ + [585.0038, 587.7021, 588.5288, 585.5997, 588.6697], + [586.2949, 585.1484, 588.0995, 590.9081, 591.0029], + ] + ), + ), + ( + None, + True, + torch.Tensor([[5.5019, -4.5114, 3.0416], [3.4142, -6.3097, 3.1282]]), + torch.Size([2, 5, 8]), + torch.Tensor( + [ + [585.0038, 587.7021, 588.5288, 585.5997, 588.6697], + [586.2949, 585.1484, 588.0995, 590.9081, 591.0029], + ] + ), + ), + ( + None, + False, + torch.Tensor([[5.8831, -4.7312, 3.1304], [4.3524, -5.0214, 2.7832]]), + torch.Size([2, 6, 8]), + torch.Tensor( + [ + [1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000], + [1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000], + ] + ), + ), + ], + ) + def test_coca_text_decoder( + self, + input_ids, + get_text_decoder, + pad_idx, + embed_cls, + expected_pooled, + expected_tokens_shape, + expected_tokens_mean, + ): + text_decoder = get_text_decoder( + pad_idx=pad_idx, embed_cls=embed_cls, custom_init=True + ) + actual_pooled, actual_tokens = text_decoder(input_ids) + assert_expected(actual_pooled, expected_pooled, rtol=0, atol=1e-3) + assert_expected(actual_tokens.size(), expected_tokens_shape) + assert_expected( + 
actual_tokens.mean(dim=-1), expected_tokens_mean, rtol=0, atol=1e-3 + ) + + def test_coca_text_decoder_with_padding_mask( + self, input_ids, padding_mask, get_text_decoder, expected + ): + text_decoder = get_text_decoder() + actual = text_decoder(input_ids) + actual_with_padding_mask = text_decoder(input_ids, padding_mask) + assert_expected(actual, expected) + assert_expected(actual_with_padding_mask, expected) + + def test_build_attention_mask( + self, + num_positions, + input_ids, + get_text_decoder, + padding_mask, + expected_attention_mask, + ): + # Since embed_cls is True the mask will contain an extra token + text_decoder = get_text_decoder(num_positions=num_positions + 1) + + inferred_padding_mask = text_decoder.build_mask(input_ids) + explicit_padding_mask = text_decoder.build_mask(input_ids, padding_mask) + assert_expected(inferred_padding_mask, expected_attention_mask) + assert_expected(explicit_padding_mask, expected_attention_mask) + + def test_scripting(self, get_text_decoder, input_ids): + text_decoder = get_text_decoder() + scripted_text_decoder = torch.jit.script(text_decoder) + assert_expected( + scripted_text_decoder(input_ids), + text_decoder(input_ids), + rtol=0, + atol=1e-4, + ) diff --git a/tests/modules/layers/test_attention_pooler.py b/tests/modules/layers/test_attention_pooler.py new file mode 100644 index 00000000..3b09fb02 --- /dev/null +++ b/tests/modules/layers/test_attention_pooler.py @@ -0,0 +1,116 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +import pytest + +import torch +from tests.test_utils import assert_expected, init_weights_with_constant, set_rng_seed +from torchmultimodal.modules.layers.attention_pooler import ( + AttentionPooler, + CascadedAttentionPooler, +) + + +@pytest.fixture(autouse=True) +def random(): + set_rng_seed(0) + + +class TestAttentionPooler: + @pytest.fixture + def batch_size(self): + return 2 + + @pytest.fixture + def input_embed_dim(self): + return 4 + + @pytest.fixture + def output_embed_dim(self): + return 6 + + @pytest.fixture + def cascaded_output_embed_dim(self): + return 10 + + @pytest.fixture + def seq_len(self): + return 8 + + @pytest.fixture + def n_head(self): + return 2 + + @pytest.fixture + def n_queries(self): + return 12 + + @pytest.fixture + def inputs(self, batch_size, seq_len, input_embed_dim): + return torch.randn(batch_size, seq_len, input_embed_dim) + + @pytest.fixture + def pooler(self, input_embed_dim, output_embed_dim, n_head, n_queries): + pooler = AttentionPooler( + input_embed_dim=input_embed_dim, + output_embed_dim=output_embed_dim, + n_head=n_head, + n_queries=n_queries, + ) + init_weights_with_constant(pooler) + return pooler + + @pytest.fixture + def cascaded_pooler( + self, pooler, output_embed_dim, cascaded_output_embed_dim, n_head + ): + second_pooler = AttentionPooler( + input_embed_dim=output_embed_dim, + output_embed_dim=cascaded_output_embed_dim, + n_head=n_head, + n_queries=1, + ) + init_weights_with_constant(second_pooler) + cascaded_pooler = CascadedAttentionPooler([pooler, second_pooler]) + return cascaded_pooler + + def test_forward(self, pooler, inputs, batch_size, n_queries, output_embed_dim): + actual = pooler(inputs) + expected_shape = (batch_size, n_queries, output_embed_dim) + expected_sum = torch.tensor(144.0) + assert_expected(actual.shape, expected_shape) + assert_expected(actual.sum(), expected_sum) + + def test_torchscript(self, pooler, inputs): + scripted_pooler = torch.jit.script(pooler) + out = pooler(inputs) + 
scripted_out = scripted_pooler(inputs) + assert_expected(scripted_out, out) + + def test_cascaded_pooler_forward( + self, + cascaded_pooler, + inputs, + batch_size, + n_queries, + output_embed_dim, + cascaded_output_embed_dim, + ): + actual = cascaded_pooler(inputs) + + expected_shapes = [ + (batch_size, n_queries, output_embed_dim), + (batch_size, 1, cascaded_output_embed_dim), + ] + expected_sums = [torch.tensor(144.0), torch.tensor(20.0)] + assert_expected([x.shape for x in actual], expected_shapes) + assert_expected([x.sum() for x in actual], expected_sums) + + def test_cascaded_pooler_torchscript(self, cascaded_pooler, inputs): + scripted_pooler = torch.jit.script(cascaded_pooler) + out = cascaded_pooler(inputs) + scripted_out = scripted_pooler(inputs) + assert_expected(scripted_out, out) diff --git a/torchmultimodal/models/coca/__init__.py b/torchmultimodal/models/coca/__init__.py new file mode 100644 index 00000000..2e41cd71 --- /dev/null +++ b/torchmultimodal/models/coca/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. diff --git a/torchmultimodal/models/coca/coca_model.py b/torchmultimodal/models/coca/coca_model.py new file mode 100644 index 00000000..3b48de4e --- /dev/null +++ b/torchmultimodal/models/coca/coca_model.py @@ -0,0 +1,503 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +import math +from functools import partial +from typing import Any, Callable, Dict, List, NamedTuple, Optional, Tuple, Union + +import torch +import torch.nn.functional as F +from torch import nn, Tensor +from torchmultimodal.models.coca.multimodal_decoder import CoCaMultimodalDecoder +from torchmultimodal.models.coca.text_decoder import CoCaTextDecoder +from torchmultimodal.modules.encoders.vision_transformer import vision_transformer +from torchmultimodal.modules.layers.attention_pooler import ( + AttentionPooler, + CascadedAttentionPooler, +) +from torchmultimodal.modules.layers.transformer import TransformerOutput +from torchmultimodal.modules.losses.contrastive_loss_with_temperature import ( + ContrastiveLossWithTemperature, +) + + +class CoCaModelOutput(NamedTuple): + image_pooled_output: Tensor + text_pooled_output: Tensor + multimodal_embeddings: Tensor + + +class CoCaModel(nn.Module): + """ + CoCa model class containing vision encoder, text decoder, and multimodal decoder. + Reference: https://arxiv.org/abs/2205.01917 + Args: + vision_encoder (nn.Module): Instantiated vision encoder. Should return either + TransformerOutput or Tensor. + text_decoder (CoCaTextDecoder): Instantiated CoCaTextDecoder returning a + Tuple[Tensor, Tensor], where the first element is the normalized CLS + embedding, and the second element is the full set of token embeddings. + multimodal_decoder (nn.Module): Instantiated CoCaMultimodalDecoder returning a + Tensor of multimodal embeddings. + vision_pooler (nn.Module): Pooler for vision outputs (see e.g. AttentionPooler). + vision_proj (nn.Module): Projection layer for vision encoder. Note that the + projections for the text_decoder and multimodal_decoder are handled inside + the CoCaTextDecoder and CoCaMultimodalDecoder classes, respectively, but + for vision we apply attentional pooling first so the vision projection + is separated from the vision_encoder class. + """ + + def __init__( + self, + vision_encoder: nn.Module, # e.g. 
ViT + text_decoder: CoCaTextDecoder, + multimodal_decoder: CoCaMultimodalDecoder, + vision_pooler: nn.Module, + vision_proj: nn.Module, + ): + super().__init__() + self.vision_encoder = vision_encoder + self.text_decoder = text_decoder + self.multimodal_decoder = multimodal_decoder + self.vision_pooler = vision_pooler + self.vision_proj = vision_proj + + def forward( + self, images: Tensor, texts: Tensor, text_padding_mask: Optional[Tensor] = None + ) -> CoCaModelOutput: + """ + Args: + images (Tensor): Tensor of size (bsz, c, h, w) containing image pixels. + texts (Tensor): Tensor of size (bsz, seq_len) containing text tokens. + text_padding_mask (Optional[Tensor]): Boolean mask indicating padded tokens. + True for unpadded tokens, False for padded tokens. Default: None + Returns: + CoCaModelOutput containing pooled image embeddings, text embeddings, + and multimodal embeddings. + """ + + # Image encoder + vision_encoder_outs = self.vision_encoder(images) + + # Extract image embeddings + if isinstance(vision_encoder_outs, TransformerOutput): + image_embeddings = vision_encoder_outs.last_hidden_state + elif isinstance(vision_encoder_outs, tuple): + vision_encoder_outs = vision_encoder_outs[0] + assert isinstance(vision_encoder_outs, Tensor) + image_embeddings = vision_encoder_outs + else: + assert isinstance(vision_encoder_outs, Tensor) + image_embeddings = vision_encoder_outs + assert isinstance(image_embeddings, Tensor), "Image embeddings must be Tensor" + + pooled_outputs = self.vision_pooler(image_embeddings) + if torch.jit.isinstance(pooled_outputs, List[Tensor]): + assert len(pooled_outputs) == 2 + captioning_image_embeddings, contrastive_image_embeddings = pooled_outputs + else: + assert isinstance( + pooled_outputs, Tensor + ), "Pooled image embeddings must be Tensor" + # For parallel pooler arch of CoCa, we use a single pooler and split + # the outputs for contrastive and captioning tasks + contrastive_image_embeddings, captioning_image_embeddings = ( + 
def coca_vit(
    *,
    # Required vision args
    vision_patch_size: int,
    vision_dim_feedforward: int,
    vision_n_layer: int,
    vision_n_head: int,
    # Required text args
    vocab_size: int,
    num_text_positions: int,
    text_hidden_dim: int,
    text_n_layer: int,
    text_n_head: int,
    text_dim_feedforward: int,
    text_output_dim: int,
    # Required fusion args
    fusion_n_layer: int,
    fusion_n_head: int,
    fusion_dim_feedforward: int,
    # Required attention pooler args
    pooler_input_embed_dim: int,
    pooler_output_embed_dim: int,
    pooler_n_head: int,
    # Optional vision args
    image_size: Union[int, Tuple[int, int]] = 224,
    num_channels: int = 3,
    vision_activation: Callable[..., nn.Module] = nn.GELU,
    vision_transformer_dropout: float = 0.0,
    patch_embed_dropout_prob: float = 0.0,
    vision_layer_norm_eps: float = 1e-5,
    vision_final_layer_norm_eps: Optional[float] = None,
    vision_norm_first: bool = True,
    vision_include_cls_embed: bool = False,  # This is different from ViT default
    vision_drop_path_rate: Optional[float] = None,
    vision_patch_drop_rate: Optional[Union[float, Tuple[float, float]]] = None,
    # Optional text args
    pad_idx: Optional[int] = 0,
    text_embed_cls: bool = True,
    text_dropout: float = 0.0,
    text_activation: Callable[..., nn.Module] = nn.GELU,
    text_layer_norm_eps: float = 1e-5,
    text_norm_first: bool = True,
    text_final_layer_norm_eps: Optional[float] = 1e-5,
    # Optional fusion args
    fusion_dropout: float = 0.0,
    fusion_activation: Callable[..., nn.Module] = nn.GELU,
    fusion_layer_norm_eps: float = 1e-5,
    fusion_norm_first: bool = True,
    fusion_final_layer_norm_eps: Optional[float] = 1e-5,
    multimodal_output_projection_dim: Optional[int] = None,
    # Optional attention pooler args
    cascaded_pooler: bool = True,
    pooler_n_queries: int = 256,
    pooler_layer_norm_eps: float = 1e-5,
) -> CoCaModel:
    """
    Build a CoCaModel from a ViT image encoder, attention pooler(s), a bias-free
    vision projection, a CoCa text decoder and a CoCa multimodal decoder.

    Args:
        vision_patch_size (int): ViT patch size
        vision_dim_feedforward (int): Dimension of FFN for ViT encoder.
        vision_n_layer (int): Number of layers in ViT encoder
        vision_n_head (int): Number of heads in ViT encoder.
        vocab_size (int): Text vocab size.
        num_text_positions (int): Number of positions for text tokens.
        text_hidden_dim (int): Embedding dimension in text transformer.
        text_n_layer (int): Number of layers in text transformer.
        text_n_head (int): Number of heads in text transformer.
        text_dim_feedforward (int): Dimension of FFN for text transformer.
        text_output_dim (int): Output dimension of text decoder.
        fusion_n_layer (int): Number of layers in multimodal transformer.
        fusion_n_head (int): Number of heads in multimodal transformer.
        fusion_dim_feedforward (int): Dimension of FFN for multimodal transformer.
        pooler_input_embed_dim (int): Input dimension for attention pooler. Also
            used as the hidden dimension of the ViT encoder.
        pooler_output_embed_dim (int): Output dimension for attention pooler. Also
            used as the embedding dimension of the multimodal decoder.
        pooler_n_head (int): Number of heads in attention pooler.
        image_size (Union[int, Tuple[int, int]]): Size of input image. Default: 224
        num_channels (int): Number of channels of image. Default: 3
        vision_activation (Callable[..., nn.Module]): ViT activation function.
            Default: GELU
        vision_transformer_dropout (float): ViT encoder dropout rate. Default: 0.0
        patch_embed_dropout_prob (float): Image patch embedding dropout rate.
            Default: 0.0
        vision_layer_norm_eps (float): LN epsilon in ViT encoder. Default: 1e-5
        vision_final_layer_norm_eps (Optional[float]): Final LN epsilon for ViT.
            Default: None (no final LN)
        vision_norm_first (bool): Whether to use pre-norm ViT layers. Default: True
        vision_include_cls_embed (bool): Whether to include cls as an embedding
            Default: False (to match open_clip implementation)
        vision_drop_path_rate (Optional[float]): Stochastic drop path rate in ViT.
            Default: None (no drop path)
        vision_patch_drop_rate (Optional[Union[float, Tuple[float, float]]]): Rate
            for masking patches prior to ViT encoder. Default: None (no masking)
        pad_idx (Optional[int]): Padding index of text. Default: 0
        text_embed_cls (bool): Whether to replace the final position of text with cls
            embedding. Default: True
        text_dropout (float): Dropout rate for text transformer. Default: 0.0
        text_activation (Callable[..., nn.Module]): Text transformer activation
            function. Default: GELU
        text_layer_norm_eps (float): LN epsilon in text transformer. Default: 1e-5
        text_norm_first (bool): Whether to use pre-norm layers in text decoder.
            Default: True
        text_final_layer_norm_eps (Optional[float]): Final LN epsilon for text
            decoder. Default: 1e-5
        fusion_dropout (float): Dropout rate for multimodal transformer. Default: 0.0
        fusion_activation (Callable[..., nn.Module]): Activation function for
            multimodal transformer. Default: GELU
        fusion_layer_norm_eps (float): LN epsilon in multimodal transformer.
            Default: 1e-5
        fusion_norm_first (bool): Whether to use pre-norm layers in multimodal decoder.
            Default: True
        fusion_final_layer_norm_eps (Optional[float]): Final LN epsilon for
            multimodal decoder. Default: 1e-5
        multimodal_output_projection_dim (Optional[int]): Output dimension of
            multimodal projection. If None, no projection will be applied to
            multimodal embeddings. Default: None
        cascaded_pooler (bool): Whether to cascade (stack) contrastive and captioning
            attention poolers or parallelize them. Default: True
        pooler_n_queries (int): Number of queries in attention pooler. Default: 256
        pooler_layer_norm_eps (float): LN epsilon in attention pooler. Default: 1e-5

    Returns:
        CoCaModel: Model assembled from the components described above.
    """
    attention_pooler: nn.Module
    if cascaded_pooler:
        captioning_pooler = AttentionPooler(
            input_embed_dim=pooler_input_embed_dim,
            output_embed_dim=pooler_output_embed_dim,
            n_head=pooler_n_head,
            n_queries=pooler_n_queries,
            layer_norm_eps=pooler_layer_norm_eps,
        )
        # The contrastive pooler is stacked on top of the captioning pooler, so
        # its inputs have width pooler_output_embed_dim, not the ViT width.
        # Using pooler_input_embed_dim here breaks its key LayerNorm / attention
        # whenever the two dimensions differ (e.g. 768 vs 512 in coca_vit_b_32).
        # NOTE(review): this pooler still emits pooler_n_queries tokens; confirm
        # whether the contrastive branch is meant to use a single query.
        contrastive_pooler = AttentionPooler(
            input_embed_dim=pooler_output_embed_dim,
            output_embed_dim=pooler_output_embed_dim,
            n_head=pooler_n_head,
            n_queries=pooler_n_queries,
            layer_norm_eps=pooler_layer_norm_eps,
        )
        attention_pooler = CascadedAttentionPooler(
            [captioning_pooler, contrastive_pooler]
        )
    else:
        # Parallel architecture: one pooler whose outputs are split downstream
        # into contrastive and captioning embeddings.
        attention_pooler = AttentionPooler(
            input_embed_dim=pooler_input_embed_dim,
            output_embed_dim=pooler_output_embed_dim,
            n_head=pooler_n_head,
            n_queries=pooler_n_queries,
            layer_norm_eps=pooler_layer_norm_eps,
        )

    # Bias-free projection applied to the pooled contrastive image embedding
    vision_proj = nn.Linear(
        pooler_output_embed_dim, pooler_output_embed_dim, bias=False
    )
    nn.init.normal_(vision_proj.weight, std=pooler_input_embed_dim**-0.5)

    vision_encoder = vision_transformer(
        patch_size=vision_patch_size,
        hidden_dim=pooler_input_embed_dim,
        dim_feedforward=vision_dim_feedforward,
        n_layer=vision_n_layer,
        n_head=vision_n_head,
        image_size=image_size,
        num_channels=num_channels,
        activation=vision_activation,
        transformer_dropout=vision_transformer_dropout,
        patch_embed_dropout_prob=patch_embed_dropout_prob,
        layer_norm_eps=vision_layer_norm_eps,
        final_layer_norm_eps=vision_final_layer_norm_eps,
        norm_first=vision_norm_first,
        include_cls_embed=vision_include_cls_embed,
        drop_path_rate=vision_drop_path_rate,
        patch_drop_rate=vision_patch_drop_rate,
    )

    text_decoder = CoCaTextDecoder(
        vocab_size=vocab_size,
        num_positions=num_text_positions,
        embedding_dim=text_hidden_dim,
        n_layer=text_n_layer,
        n_head=text_n_head,
        dim_feedforward=text_dim_feedforward,
        output_dim=text_output_dim,
        pad_idx=pad_idx,
        embed_cls=text_embed_cls,
        dropout=text_dropout,
        activation=text_activation,
        layer_norm_eps=text_layer_norm_eps,
        norm_first=text_norm_first,
        final_layer_norm_eps=text_final_layer_norm_eps,
    )

    # When a CLS embedding occupies the last text position, the multimodal
    # decoder only sees the remaining num_text_positions - 1 token embeddings.
    mm_input_seq_len = num_text_positions - 1 if text_embed_cls else num_text_positions

    multimodal_decoder = CoCaMultimodalDecoder(
        input_seq_len=mm_input_seq_len,
        text_embedding_dim=pooler_output_embed_dim,
        n_layer=fusion_n_layer,
        n_head=fusion_n_head,
        dim_feedforward=fusion_dim_feedforward,
        output_dim=multimodal_output_projection_dim,
        dropout=fusion_dropout,
        activation=fusion_activation,
        layer_norm_eps=fusion_layer_norm_eps,
        norm_first=fusion_norm_first,
        final_layer_norm_eps=fusion_final_layer_norm_eps,
    )

    return CoCaModel(
        vision_encoder=vision_encoder,
        text_decoder=text_decoder,
        multimodal_decoder=multimodal_decoder,
        vision_proj=vision_proj,
        vision_pooler=attention_pooler,
    )
class CoCaForPretraining(nn.Module):
    """
    Pretraining wrapper around a CoCa model.

    Runs the wrapped model and evaluates both pretraining objectives on its
    outputs: a temperature-scaled contrastive loss between pooled image and
    text embeddings, and a cross-entropy captioning loss on the multimodal
    decoder outputs.

    Args:
        model (CoCaModel): Instantiated CoCa model.
        pad_idx (int): Index of padding tokens; targets equal to this index are
            ignored by the captioning loss. Default: 0
        contrastive_logit_scale_min (Optional[float]): Min clamp value for contrastive
            temperature. Default: 0.0
        contrastive_logit_scale_max (Optional[float]): Max clamp value for contrastive
            temperature. Default: log(100)
    """

    def __init__(
        self,
        model: CoCaModel,
        pad_idx: int = 0,
        contrastive_logit_scale_min: Optional[float] = math.log(1.0),
        contrastive_logit_scale_max: Optional[float] = math.log(100.0),
    ):
        super().__init__()
        self.model = model
        self.contrastive_loss = ContrastiveLossWithTemperature(
            logit_scale_max=contrastive_logit_scale_max,
            logit_scale_min=contrastive_logit_scale_min,
        )
        self.caption_loss = nn.CrossEntropyLoss(ignore_index=pad_idx)

    def forward(
        self, images: Tensor, texts: Tensor, text_padding_mask: Optional[Tensor] = None
    ) -> Dict[str, Tensor]:
        """
        Compute both pretraining losses for a batch.

        Args:
            images (Tensor): Tensor of size (bsz, c, h, w) containing image pixels.
            texts (Tensor): Tensor of size (bsz, seq_len) containing text tokens.
            text_padding_mask (Optional[Tensor]): Boolean mask indicating padded tokens.
                True for unpadded tokens, False for padded tokens. Default: None
        Returns:
            Dict[str, Tensor]: Dict containing contrastive and captioning losses with
                respective keys 'contrastive' and 'captioning'.
        """
        outputs = self.model(images, texts, text_padding_mask)

        # Captioning targets are the input tokens with the first position dropped.
        targets = texts[:, 1:].contiguous()

        contrastive = self.contrastive_loss(
            outputs.image_pooled_output, outputs.text_pooled_output
        )

        # Flatten (bsz, seq, vocab) predictions against (bsz * seq,) targets;
        # the last dimension of the multimodal outputs is used as the class count.
        predictions = outputs.multimodal_embeddings
        n_classes = predictions.shape[-1]
        captioning = self.caption_loss(
            predictions.view(-1, n_classes),
            targets.view(-1),
        )

        losses = {"contrastive": contrastive, "captioning": captioning}
        return losses
class CoCaMultimodalDecoder(nn.Module):
    """
    Multimodal decoder for CoCa model.
    Runs text embeddings through a causally-masked transformer decoder that
    cross-attends to image embeddings, then optionally applies a bias-free
    linear output projection.
    Based on the implementation in open_clip: https://tinyurl.com/mn35vdmd

    Args:
        input_seq_len (int): Number of text positions (used to construct
            causal mask)
        text_embedding_dim (int): Dimension of text embeddings
            inside transformer decoder.
        n_layer (int): Number of transformer layers
        n_head (int): Number of heads in multi-head attention
        dim_feedforward (int): Dimension of FFN in transformer decoder
        output_dim (Optional[int]): Output dimension of the final projection.
            Default: None (no projection; decoder hidden states are returned)
        dropout (float): Dropout probability in transformer decoder. Default: 0.0
        activation (Callable[..., nn.Module]): Activation function of transformer
            decoder. Default: nn.GELU
        layer_norm_eps (float): Epsilon value for transformer decoder layer norms.
            Default: 1e-5
        norm_first (bool): Whether to apply layer normalization before or after
            self-attention in transformer decoder. Default: True
        final_layer_norm_eps (Optional[float]): Regularization value for final layer norm
            in transformer decoder. Default: 1e-5
        visual_embedding_dim (Optional[int]): Dimension of visual embeddings inside
            transformer decoder (used for cross-attention). Default: None (visual
            embeddings assumed to be same dimension as text embeddings)
    """

    def __init__(
        self,
        input_seq_len: int,
        text_embedding_dim: int,
        n_layer: int,
        n_head: int,
        dim_feedforward: int,
        output_dim: Optional[int] = None,
        dropout: float = 0.0,
        activation: Callable[..., nn.Module] = nn.GELU,
        layer_norm_eps: float = 1e-5,
        norm_first: bool = True,
        final_layer_norm_eps: Optional[float] = 1e-5,
        visual_embedding_dim: Optional[int] = None,
    ):
        super().__init__()
        self.transformer_decoder = TransformerDecoder(
            n_layer=n_layer,
            d_model=text_embedding_dim,
            n_head=n_head,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            activation=activation,
            layer_norm_eps=layer_norm_eps,
            norm_first=norm_first,
            use_cross_attention=True,
            final_layer_norm_eps=final_layer_norm_eps,
            dim_kv=visual_embedding_dim,
        )

        # The projection is optional; None means "return hidden states as-is".
        self.output_projection = (
            nn.Linear(text_embedding_dim, output_dim, bias=False)
            if output_dim is not None
            else None
        )

        # Fixed-size boolean causal mask; non-persistent, so it is not saved in
        # the state dict.
        bool_causal_mask = get_causal_attention_mask(input_seq_len).to(dtype=torch.bool)
        self.register_buffer("causal_mask", bool_causal_mask, persistent=False)

    def forward(self, texts: Tensor, images: Tensor) -> Tensor:
        """
        Args:
            texts (Tensor): Tensor containing text embeddings of shape [batch_size, text_seq_length, embeddings_dim]
            images (Tensor): Tensor containing image embeddings of shape [batch_size, image_seq_length, embeddings_dim]
        Returns:
            Tensor: Tensor containing output embeddings of shape [batch_size, text_seq_length, output_dim]
                (or the decoder hidden states when no output projection is configured)
        """
        # The registered mask has a fixed size, so the text length must match it.
        expected_mask_shape = (texts.shape[1], texts.shape[1])
        assert self.causal_mask.shape == expected_mask_shape

        last_hidden = self.transformer_decoder(
            hidden_states=texts,
            encoder_hidden_states=images,
            attention_mask=self.causal_mask,
        ).last_hidden_state
        assert last_hidden is not None, "hidden states must not be None"

        if self.output_projection is None:
            return last_hidden
        return self.output_projection(last_hidden)
class CoCaTextDecoder(nn.Module):
    """
    Text decoder for CoCa model.
    Embeds the input tokens (optionally appending a CLS embedding), runs them
    through a causally-masked transformer decoder without cross-attention, and
    returns a projected pooled embedding together with per-token hidden states.
    Based on the implementation in open_clip: https://tinyurl.com/2jswrb9h

    Args:
        vocab_size (int): Size of the vocab
        num_positions (int): Number of token positions for positional embeddings.
        embedding_dim (int): Embedding dimension for transformer
        n_layer (int): Number of transformer layers
        n_head (int): Number of attention heads
        dim_feedforward (int): Hidden dimension in transformer FFN
        output_dim (int): Output dimension of decoder cls / eos projection
        pad_idx (Optional[int]): Padding index (will be masked from CLS token).
            Default: 0
        embed_cls (bool): Whether to append CLS embedding. Default: True
        dropout (float): Dropout probability in transformer decoder. Default: 0.0
        activation (Callable[..., nn.Module]): Activation function of transformer
            decoder. Default: nn.GELU
        layer_norm_eps (float): Epsilon value for transformer decoder layer norms.
            Default: 1e-5
        norm_first (bool): Whether to apply layer normalization before or after
            self-attention in transformer decoder. Default: True
        final_layer_norm_eps (Optional[float]): Final layer norm epsilon. Only applied
            to CLS token if embed_cls=True. If None, no final layer norm is
            applied. Default: 1e-5
    """

    def __init__(
        self,
        vocab_size: int,
        num_positions: int,
        embedding_dim: int,
        n_layer: int,
        n_head: int,
        dim_feedforward: int,
        output_dim: int,
        pad_idx: Optional[int] = 0,
        embed_cls: bool = True,
        dropout: float = 0.0,
        activation: Callable[..., nn.Module] = nn.GELU,
        layer_norm_eps: float = 1e-5,
        norm_first: bool = True,
        final_layer_norm_eps: Optional[float] = 1e-5,
    ):
        super().__init__()
        self.pad_idx = pad_idx
        self.embed_cls = embed_cls
        self.num_positions = num_positions
        self.embeddings = CoCaTextEmbeddings(
            vocab_size=vocab_size,
            num_positions=num_positions,
            embedding_dim=embedding_dim,
            pad_idx=pad_idx,
            embed_cls=embed_cls,
        )
        self.transformer_decoder = TransformerDecoder(
            n_layer=n_layer,
            d_model=embedding_dim,
            n_head=n_head,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            activation=activation,
            layer_norm_eps=layer_norm_eps,
            norm_first=norm_first,
            use_cross_attention=False,
        )
        # Always define ln_final so forward() can test it against None; leaving
        # the attribute unset made forward() raise AttributeError whenever
        # final_layer_norm_eps was None.
        if final_layer_norm_eps is not None:
            self.ln_final = nn.LayerNorm(
                normalized_shape=embedding_dim, eps=final_layer_norm_eps
            )
        else:
            self.ln_final = None
        self.text_projection = nn.Linear(embedding_dim, output_dim, bias=False)
        # Boolean causal mask; non-persistent, so it is not saved in the state dict.
        self.register_buffer(
            "causal_mask",
            get_causal_attention_mask(num_positions).to(dtype=torch.bool),
            persistent=False,
        )
        self.init_parameters(embedding_dim, n_layer)

    def init_parameters(self, embedding_dim: int, n_layer: int) -> None:
        """
        Initialize attention, feedforward and text projection weights from
        normal distributions whose std depends on the width and depth.

        Args:
            embedding_dim (int): Transformer embedding dimension.
            n_layer (int): Number of transformer layers.
        """
        # Initialization based on https://tinyurl.com/cmm7cwjt
        attn_std = embedding_dim**-0.5
        proj_std = (2 * embedding_dim * n_layer) ** -0.5
        fc_std = (2 * embedding_dim) ** -0.5
        for layer in self.transformer_decoder.layer:
            nn.init.normal_(layer.attention.q_proj.weight, std=attn_std)
            nn.init.normal_(layer.attention.k_proj.weight, std=attn_std)
            nn.init.normal_(layer.attention.v_proj.weight, std=attn_std)
            nn.init.normal_(layer.attention.output_proj.weight, std=proj_std)
            nn.init.normal_(layer.feedforward.model[0].weight, std=fc_std)
            nn.init.normal_(layer.feedforward.model[2].weight, std=proj_std)
        # width ** -0.5, like the attention projections above and the reference
        # scheme; the previous positive exponent produced a std that grew with
        # the width (e.g. ~27.7 for a width of 768).
        nn.init.normal_(self.text_projection.weight, std=embedding_dim**-0.5)

    def build_mask(
        self,
        input_ids: Tensor,
        padding_mask: Optional[Tensor] = None,
    ) -> Tensor:
        """
        Build the attention mask passed to the transformer decoder.

        Args:
            input_ids (Tensor): Token indices of size (batch_size, seq_len).
            padding_mask (Optional[Tensor]): Boolean mask of size
                (batch_size, seq_len), True for unpadded tokens. If None, it is
                inferred as ``input_ids != pad_idx``.
        Returns:
            Tensor: The plain causal mask when there is no CLS token or no
                padding index; otherwise a (batch_size, 1, seq_len + 1, seq_len + 1)
                mask combining the causal mask with the padding mask.
        """
        # If no CLS token, we can directly return the causal mask
        if not self.embed_cls or self.pad_idx is None:
            return self.causal_mask

        # If padding_mask is not passed, infer it
        if padding_mask is None:
            padding_mask = input_ids != self.pad_idx
        assert padding_mask is not None
        # (batch_size, seq_len) -> (batch_size, 1, seq_len)
        padding_mask = padding_mask.unsqueeze(1)

        # (batch_size, 1, seq_len) -> (batch_size, seq_len+1, seq_len+1)
        # Only the last row (the CLS position) carries the padding mask; the
        # rows padded in on top are filled with ones.
        padding_mask = F.pad(padding_mask, (1, 0, padding_mask.shape[2], 0), value=1.0)
        # Make broadcastable for MHA
        mask = (padding_mask * self.causal_mask).unsqueeze(1)

        return mask

    def forward(
        self,
        input_ids: Tensor,
        padding_mask: Optional[Tensor] = None,
    ) -> Tuple[Tensor, Tensor]:
        """
        Args:
            input_ids (Tensor of size (batch_size, seq_length)):
                Indices of input sequence tokens.
            padding_mask (Optional[Tensor] of size (batch_size, seq_length)):
                Boolean tensor: True for unpadded tokens, False for padded tokens.
        Returns:
            A tuple including
                pooled (Tensor): Projected CLS embedding (or, without CLS, the
                    embedding at the argmax token position) of shape
                    (batch_size, output_dim) (for use in contrastive loss).
                    It is not normalized here.
                tokens (Tensor): Unprojected hidden states for all non-CLS
                    tokens, with embedding_dim as the last dimension.
        """

        # If using CLS embedding, drop the final token
        if self.embed_cls:
            if input_ids.shape[1] == self.num_positions:
                input_ids = input_ids[:, :-1]
            if padding_mask is not None and padding_mask.shape[1] == self.num_positions:
                padding_mask = padding_mask[:, :-1]

        target_shape = self.num_positions - 1 if self.embed_cls else self.num_positions
        assert (
            input_ids.shape[1] == target_shape
        ), f"{input_ids.shape} doesn't match ({target_shape},*)"

        embeddings = self.embeddings(input_ids)
        mask = self.build_mask(input_ids, padding_mask)
        decoder_out = self.transformer_decoder(embeddings, attention_mask=mask)
        hidden_states = decoder_out.last_hidden_state
        assert hidden_states is not None, "hidden states must not be None"
        if self.embed_cls:
            # CLS sits in the last position; the final LN is applied to it only
            pooled, tokens = hidden_states[:, -1], hidden_states[:, :-1]
            if self.ln_final is not None:
                pooled = self.ln_final(pooled)
        else:
            # The final layer norm is optional in this branch too
            if self.ln_final is not None:
                hidden_states = self.ln_final(hidden_states)
            # Use argmax to get EOS embedding (assumes EOS token has highest value)
            pooled, tokens = (
                hidden_states[
                    torch.arange(hidden_states.shape[0]), input_ids.argmax(dim=-1)
                ],
                hidden_states,
            )

        if self.text_projection is not None:
            pooled = self.text_projection(pooled)

        return pooled, tokens
class AttentionPooler(nn.Module):
    """
    Attention pooling layer: pools inputs to sequence length n_queries by performing
    cross-attention with learned query embeddings. Originally proposed in
    https://arxiv.org/abs/1810.00825. This implementation is based on the one
    in open_clip repo: https://tinyurl.com/4yj492sc.

    Layer norms are applied to the learned queries, to the inputs (used as both
    keys and values) and to the attention output.

    Args:
        input_embed_dim (int): Embedding dimension of inputs.
        output_embed_dim (int): Embedding dimension of outputs.
        n_head (int): Number of attention heads.
        n_queries (int): Number of queries. Defaults to 256
        layer_norm_eps (float): Epsilon for layer norms. Defaults to 1e-5
    """

    def __init__(
        self,
        input_embed_dim: int,
        output_embed_dim: int,
        n_head: int,
        n_queries: int = 256,
        layer_norm_eps: float = 1e-5,
    ):
        super().__init__()
        # Learned query embeddings, one row per pooled output position
        self.query = nn.Parameter(torch.randn(n_queries, output_embed_dim))
        self.attn = MultiHeadAttentionWithCache(
            dim_q=output_embed_dim, dim_kv=input_embed_dim, num_heads=n_head
        )
        self.ln_q = nn.LayerNorm(output_embed_dim, eps=layer_norm_eps)
        self.ln_k = nn.LayerNorm(input_embed_dim, eps=layer_norm_eps)
        self.ln_post = nn.LayerNorm(output_embed_dim, eps=layer_norm_eps)

    def forward(self, x: Tensor) -> Tensor:
        """
        Inputs:
            x (Tensor): Input tensor of shape (batch_size, seq_len, input_embed_dim).
        Returns:
            Attention pooled tensor with shape
                (batch_size, n_queries, output_embed_dim).
        """
        keys = self.ln_k(x)

        # (n_queries, output_embed_dim) -> (batch_size, n_queries, output_embed_dim)
        queries = self._repeat(self.ln_q(self.query), keys.shape[0])

        # Normalized inputs serve as both keys and values
        pooled = self.attn(queries, keys, keys)
        assert isinstance(pooled, Tensor)
        return self.ln_post(pooled)

    def _repeat(self, query: Tensor, n: int) -> Tensor:
        # Add a leading batch dimension and tile the queries n times along it
        return query[None].repeat(n, 1, 1)


class CascadedAttentionPooler(nn.Module):
    """
    Wrapper class to perform cascaded attention pooling given multiple attention
    poolers. E.g. in CoCa the contrastive pooler is applied on top of the outputs of
    the captioning pooler.

    Args:
        poolers (List[AttentionPooler]): List of individual attention poolers,
            applied in the order given.
    """

    def __init__(
        self,
        poolers: List[AttentionPooler],
    ):
        super().__init__()
        self.poolers = nn.ModuleList(poolers)

    def forward(self, x: Tensor) -> List[Tensor]:
        """
        Inputs:
            x (Tensor): Input tensor of shape (batch_size, seq_len, input_embed_dim).
        Returns:
            List[Tensor] containing attention pooled tensors with shapes
                (batch_size, n_queries, output_embed_dim), where n_queries and
                output_embed_dim are determined by each individual pooler.
        """
        stage_outputs: List[Tensor] = []
        current = x
        # Each pooler consumes the previous pooler's output; every intermediate
        # result is kept so callers can use all stages.
        for stage in self.poolers:
            current = stage(current)
            stage_outputs.append(current)
        return stage_outputs