Skip to content

Commit

Permalink
Add CoCa model
Browse files Browse the repository at this point in the history
ghstack-source-id: d33bb143d838e7e5303550e648a82635c98846d8
Pull Request resolved: #506
  • Loading branch information
ebsmothers committed Nov 2, 2023
1 parent acc421e commit df65bc9
Show file tree
Hide file tree
Showing 10 changed files with 1,620 additions and 0 deletions.
5 changes: 5 additions & 0 deletions tests/models/coca/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
141 changes: 141 additions & 0 deletions tests/models/coca/test_coca_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import pytest
import torch
from tests.test_utils import assert_expected, init_weights_with_constant, set_rng_seed
from torchmultimodal.models.coca.coca_model import (
coca_vit,
CoCaForPretraining,
CoCaModelOutput,
)


class TestCoCaModel:
@pytest.fixture(autouse=True)
def random(self):
set_rng_seed(0)

@pytest.fixture
def batch_size(self):
return 2

@pytest.fixture
def vocab_size(self):
return 50

@pytest.fixture
def num_text_positions(self):
return 11

@pytest.fixture
def attention_pooler_output_dim(self):
return 8

@pytest.fixture
def text_output_dim(self):
return 8

@pytest.fixture
def image_size(self):
return 12

@pytest.fixture
def coca_model(
self,
vocab_size,
num_text_positions,
attention_pooler_output_dim,
text_output_dim,
image_size,
):
coca_model = coca_vit(
vision_patch_size=4,
vision_dim_feedforward=24,
vision_n_layer=2,
vision_n_head=2,
vocab_size=vocab_size,
num_text_positions=num_text_positions,
text_hidden_dim=8,
text_n_layer=2,
text_n_head=2,
text_dim_feedforward=32,
text_output_dim=text_output_dim,
fusion_n_layer=2,
fusion_n_head=2,
fusion_dim_feedforward=32,
multimodal_output_projection_dim=vocab_size,
pooler_input_embed_dim=6,
pooler_output_embed_dim=attention_pooler_output_dim,
image_size=image_size,
pooler_n_head=2,
cascaded_pooler=False,
)
init_weights_with_constant(coca_model)
coca_model.eval()
return coca_model

@pytest.fixture
def text_inputs(self):
return torch.LongTensor(
[
[1, 3, 4, 5, 6, 7, 8, 2, 0, 0, 0],
[1, 25, 28, 34, 39, 45, 40, 5, 12, 6, 2],
]
)

@pytest.fixture
def image_inputs(self, batch_size, image_size):
return torch.randn(batch_size, 3, image_size, image_size)

@pytest.fixture
def expected(
self,
batch_size,
vocab_size,
num_text_positions,
attention_pooler_output_dim,
text_output_dim,
):
pooled_val = 0.3536
logit_val = 8.0
return CoCaModelOutput(
image_pooled_output=pooled_val
* torch.ones(batch_size, attention_pooler_output_dim),
text_pooled_output=pooled_val * torch.ones(batch_size, text_output_dim),
multimodal_embeddings=logit_val
* torch.ones(batch_size, num_text_positions - 1, vocab_size),
)

@pytest.fixture
def coca_for_pretraining(self, coca_model):
coca_for_pretraining = CoCaForPretraining(coca_model)
init_weights_with_constant(coca_for_pretraining)
coca_for_pretraining.eval()
return coca_for_pretraining

def test_coca_model(self, text_inputs, image_inputs, coca_model, expected):
actual = coca_model(image_inputs, text_inputs)
assert_expected(actual, expected, rtol=0, atol=1e-4)

def test_scripting(self, text_inputs, image_inputs, coca_model):
scripted_model = torch.jit.script(coca_model)
assert_expected(
scripted_model(image_inputs, text_inputs),
coca_model(image_inputs, text_inputs),
rtol=0,
atol=1e-4,
)

def test_coca_for_pretraining(
self, text_inputs, image_inputs, coca_for_pretraining
):
actual_losses = coca_for_pretraining(image_inputs, text_inputs)
expected_losses = {
"contrastive": torch.tensor(0.6931),
"captioning": torch.tensor(3.9120),
}
assert_expected(actual_losses, expected_losses, rtol=0, atol=1e-4)
115 changes: 115 additions & 0 deletions tests/models/coca/test_multimodal_decoder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import pytest
import torch
from tests.test_utils import assert_expected, init_weights_with_constant
from torch import nn, Tensor
from torchmultimodal.models.coca.multimodal_decoder import CoCaMultimodalDecoder


class TestCoCaMultimodalDecoder:
@pytest.fixture
def batch_size(self):
return 2

@pytest.fixture
def input_seq_len(self):
return 5

@pytest.fixture
def num_image_positions(self):
return 10

@pytest.fixture
def text_embedding_dim(self):
return 4

@pytest.fixture
def multimodal_decoder(self, input_seq_len, batch_size, text_embedding_dim):
decoder = CoCaMultimodalDecoder(
input_seq_len=input_seq_len,
text_embedding_dim=text_embedding_dim,
n_layer=2,
n_head=2,
dim_feedforward=4 * text_embedding_dim,
output_dim=3,
final_layer_norm_eps=1e-5,
)
init_weights_with_constant(decoder)

# Custom init final MLP layer weight, final LN, and text projection
decoder.transformer_decoder.layer[1].feedforward.model[2].weight = nn.Parameter(
torch.arange(
decoder.transformer_decoder.layer[1]
.feedforward.model[2]
.weight.numel(),
dtype=torch.float,
).reshape(
decoder.transformer_decoder.layer[1].feedforward.model[2].weight.shape
)
)
decoder.output_projection.weight = nn.Parameter(
torch.arange(decoder.output_projection.weight.numel(), dtype=torch.float)
.reshape(decoder.output_projection.weight.T.shape)
.T
)
decoder.transformer_decoder.final_layer_norm.weight = nn.Parameter(
torch.arange(
decoder.transformer_decoder.final_layer_norm.weight.numel(),
dtype=torch.float,
)
)
decoder.eval()
return decoder

@pytest.fixture
def text_inputs(self, batch_size, input_seq_len, text_embedding_dim):
return torch.arange(0.0, 1.0, 1.0 / 40).reshape(
batch_size, input_seq_len, text_embedding_dim
)

@pytest.fixture
def image_inputs(self, batch_size, num_image_positions, text_embedding_dim):
return torch.arange(10.0, 20.0, 1.0 / 8).reshape(
batch_size, num_image_positions, text_embedding_dim
)

@pytest.fixture
def expected(self):
return Tensor(
[
[
[58.2492, 66.7214, 75.1935],
[58.2492, 66.7214, 75.1935],
[58.2492, 66.7214, 75.1935],
[58.2492, 66.7214, 75.1935],
[58.2492, 66.7214, 75.1935],
],
[
[58.2492, 66.7214, 75.1935],
[58.2492, 66.7214, 75.1935],
[58.2492, 66.7214, 75.1935],
[58.2492, 66.7214, 75.1935],
[58.2492, 66.7214, 75.1935],
],
]
)

def test_coca_multimodal_decoder(
self, text_inputs, image_inputs, multimodal_decoder, expected
):
actual = multimodal_decoder(text_inputs, image_inputs)
assert_expected(actual, expected, rtol=0, atol=1e-4)

def test_scripting(self, text_inputs, image_inputs, multimodal_decoder):
scripted_multimodal_decoder = torch.jit.script(multimodal_decoder)
assert_expected(
scripted_multimodal_decoder(text_inputs, image_inputs),
multimodal_decoder(text_inputs, image_inputs),
rtol=0,
atol=1e-4,
)
Loading

0 comments on commit df65bc9

Please sign in to comment.