Add CoCa model #506

Closed · wants to merge 5 commits
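For context, here is a minimal smoke-test sketch of the API this PR adds, using the same constructor arguments as the test fixtures below. The random input tensors are illustrative only; `coca_vit` and `CoCaForPretraining` are the entry points exercised by the tests.

import torch
from torchmultimodal.models.coca.coca_model import CoCaForPretraining, coca_vit

# Tiny CoCa ViT config mirroring the test fixtures; sizes are illustrative,
# not recommended defaults.
model = coca_vit(
    vision_patch_size=4,
    vision_dim_feedforward=24,
    vision_n_layer=2,
    vision_n_head=2,
    vocab_size=50,
    num_text_positions=11,
    text_hidden_dim=8,
    text_n_layer=2,
    text_n_head=2,
    text_dim_feedforward=32,
    text_output_dim=8,
    fusion_n_layer=2,
    fusion_n_head=2,
    fusion_dim_feedforward=32,
    multimodal_output_projection_dim=50,
    pooler_input_embed_dim=6,
    pooler_output_embed_dim=8,
    image_size=12,
    pooler_n_head=2,
    cascaded_pooler=False,
)
pretrainer = CoCaForPretraining(model)
images = torch.randn(2, 3, 12, 12)    # (batch, channels, height, width)
text = torch.randint(0, 50, (2, 11))  # (batch, num_text_positions) token ids
losses = pretrainer(images, text)     # dict with "contrastive" and "captioning" losses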
5 changes: 5 additions & 0 deletions tests/models/coca/__init__.py
@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
141 changes: 141 additions & 0 deletions tests/models/coca/test_coca_model.py
@@ -0,0 +1,141 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import pytest
import torch
from tests.test_utils import assert_expected, init_weights_with_constant, set_rng_seed
from torchmultimodal.models.coca.coca_model import (
coca_vit,
CoCaForPretraining,
CoCaModelOutput,
)


class TestCoCaModel:
@pytest.fixture(autouse=True)
def random(self):
set_rng_seed(0)

@pytest.fixture
def batch_size(self):
return 2

@pytest.fixture
def vocab_size(self):
return 50

@pytest.fixture
def num_text_positions(self):
return 11

@pytest.fixture
def attention_pooler_output_dim(self):
return 8

@pytest.fixture
def text_output_dim(self):
return 8

@pytest.fixture
def image_size(self):
return 12

@pytest.fixture
def coca_model(
self,
vocab_size,
num_text_positions,
attention_pooler_output_dim,
text_output_dim,
image_size,
):
coca_model = coca_vit(
vision_patch_size=4,
vision_dim_feedforward=24,
vision_n_layer=2,
vision_n_head=2,
vocab_size=vocab_size,
num_text_positions=num_text_positions,
text_hidden_dim=8,
text_n_layer=2,
text_n_head=2,
text_dim_feedforward=32,
text_output_dim=text_output_dim,
fusion_n_layer=2,
fusion_n_head=2,
fusion_dim_feedforward=32,
multimodal_output_projection_dim=vocab_size,
pooler_input_embed_dim=6,
pooler_output_embed_dim=attention_pooler_output_dim,
image_size=image_size,
pooler_n_head=2,
cascaded_pooler=False,
)
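# Constant weight init makes the forward pass deterministic so outputs
# can be checked against hand-computed values.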
init_weights_with_constant(coca_model)
coca_model.eval()
return coca_model

@pytest.fixture
def text_inputs(self):
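# Two token id sequences of length num_text_positions; the first sequence
# is right-padded with zeros.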
return torch.LongTensor(
[
[1, 3, 4, 5, 6, 7, 8, 2, 0, 0, 0],
[1, 25, 28, 34, 39, 45, 40, 5, 12, 6, 2],
]
)

@pytest.fixture
def image_inputs(self, batch_size, image_size):
return torch.randn(batch_size, 3, image_size, image_size)

@pytest.fixture
def expected(
self,
batch_size,
vocab_size,
num_text_positions,
attention_pooler_output_dim,
text_output_dim,
):
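# With every weight set to the same constant, each output tensor collapses
# to a single repeated value.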
pooled_val = 0.3536
logit_val = 8.0
return CoCaModelOutput(
image_pooled_output=pooled_val
* torch.ones(batch_size, attention_pooler_output_dim),
text_pooled_output=pooled_val * torch.ones(batch_size, text_output_dim),
multimodal_embeddings=logit_val
* torch.ones(batch_size, num_text_positions - 1, vocab_size),
)

@pytest.fixture
def coca_for_pretraining(self, coca_model):
coca_for_pretraining = CoCaForPretraining(coca_model)
init_weights_with_constant(coca_for_pretraining)
coca_for_pretraining.eval()
return coca_for_pretraining

def test_coca_model(self, text_inputs, image_inputs, coca_model, expected):
actual = coca_model(image_inputs, text_inputs)
assert_expected(actual, expected, rtol=0, atol=1e-4)

def test_scripting(self, text_inputs, image_inputs, coca_model):
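# TorchScript compilation should not change model outputs.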
scripted_model = torch.jit.script(coca_model)
assert_expected(
scripted_model(image_inputs, text_inputs),
coca_model(image_inputs, text_inputs),
rtol=0,
atol=1e-4,
)

def test_coca_for_pretraining(
self, text_inputs, image_inputs, coca_for_pretraining
):
actual_losses = coca_for_pretraining(image_inputs, text_inputs)
expected_losses = {
"contrastive": torch.tensor(0.6931),
"captioning": torch.tensor(3.9120),
}
assert_expected(actual_losses, expected_losses, rtol=0, atol=1e-4)
115 changes: 115 additions & 0 deletions tests/models/coca/test_multimodal_decoder.py
@@ -0,0 +1,115 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import pytest
import torch
from tests.test_utils import assert_expected, init_weights_with_constant
from torch import nn, Tensor
from torchmultimodal.models.coca.multimodal_decoder import CoCaMultimodalDecoder


class TestCoCaMultimodalDecoder:
@pytest.fixture
def batch_size(self):
return 2

@pytest.fixture
def input_seq_len(self):
return 5

@pytest.fixture
def num_image_positions(self):
return 10

@pytest.fixture
def text_embedding_dim(self):
return 4

@pytest.fixture
def multimodal_decoder(self, input_seq_len, batch_size, text_embedding_dim):
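# Build a tiny two-layer decoder for fast, deterministic testing.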
decoder = CoCaMultimodalDecoder(
input_seq_len=input_seq_len,
text_embedding_dim=text_embedding_dim,
n_layer=2,
n_head=2,
dim_feedforward=4 * text_embedding_dim,
output_dim=3,
final_layer_norm_eps=1e-5,
)
init_weights_with_constant(decoder)

# Custom-init the final feedforward weight, the final LayerNorm, and the
# output projection with arange values so the expected outputs are nontrivial
decoder.transformer_decoder.layer[1].feedforward.model[2].weight = nn.Parameter(
torch.arange(
decoder.transformer_decoder.layer[1]
.feedforward.model[2]
.weight.numel(),
dtype=torch.float,
).reshape(
decoder.transformer_decoder.layer[1].feedforward.model[2].weight.shape
)
)
decoder.output_projection.weight = nn.Parameter(
torch.arange(decoder.output_projection.weight.numel(), dtype=torch.float)
.reshape(decoder.output_projection.weight.T.shape)
.T
)
decoder.transformer_decoder.final_layer_norm.weight = nn.Parameter(
torch.arange(
decoder.transformer_decoder.final_layer_norm.weight.numel(),
dtype=torch.float,
)
)
decoder.eval()
return decoder

@pytest.fixture
def text_inputs(self, batch_size, input_seq_len, text_embedding_dim):
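# (2, 5, 4) ramp of evenly spaced values in [0, 1).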
return torch.arange(0.0, 1.0, 1.0 / 40).reshape(
batch_size, input_seq_len, text_embedding_dim
)

@pytest.fixture
def image_inputs(self, batch_size, num_image_positions, text_embedding_dim):
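# (2, 10, 4) ramp of evenly spaced values in [10, 20).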
return torch.arange(10.0, 20.0, 1.0 / 8).reshape(
batch_size, num_image_positions, text_embedding_dim
)

@pytest.fixture
def expected(self):
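# With the deterministic init above, every sequence position yields the
# same output row.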
return Tensor(
[
[
[58.2492, 66.7214, 75.1935],
[58.2492, 66.7214, 75.1935],
[58.2492, 66.7214, 75.1935],
[58.2492, 66.7214, 75.1935],
[58.2492, 66.7214, 75.1935],
],
[
[58.2492, 66.7214, 75.1935],
[58.2492, 66.7214, 75.1935],
[58.2492, 66.7214, 75.1935],
[58.2492, 66.7214, 75.1935],
[58.2492, 66.7214, 75.1935],
],
]
)

def test_coca_multimodal_decoder(
self, text_inputs, image_inputs, multimodal_decoder, expected
):
actual = multimodal_decoder(text_inputs, image_inputs)
assert_expected(actual, expected, rtol=0, atol=1e-4)

def test_scripting(self, text_inputs, image_inputs, multimodal_decoder):
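# TorchScript compilation should not change decoder outputs.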
scripted_multimodal_decoder = torch.jit.script(multimodal_decoder)
assert_expected(
scripted_multimodal_decoder(text_inputs, image_inputs),
multimodal_decoder(text_inputs, image_inputs),
rtol=0,
atol=1e-4,
)