Commit
ghstack-source-id: d33bb143d838e7e5303550e648a82635c98846d8
Pull Request resolved: #506
1 parent acc421e · commit df65bc9
Showing 10 changed files with 1,620 additions and 0 deletions.
@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
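The test files below lean on three helpers from tests.test_utils. Their actual implementations are not part of this commit; as a rough sketch of the assumed semantics (helper names are real, the bodies are illustrative):

import random

import torch

def set_rng_seed(seed: int) -> None:
    # Assumed behavior: seed the RNGs the tests depend on, for determinism.
    torch.manual_seed(seed)
    random.seed(seed)

def init_weights_with_constant(model: torch.nn.Module, constant: float = 1.0) -> None:
    # Assumed behavior: overwrite every parameter with a constant so that
    # expected outputs can be derived by hand (see the 1/sqrt(8) values below).
    for p in model.parameters():
        torch.nn.init.constant_(p, constant)

def assert_expected(actual, expected, rtol=None, atol=None):
    # Assumed behavior: compare tensors (or containers of tensors) with
    # torch.testing.assert_close-style tolerances.
    torch.testing.assert_close(actual, expected, rtol=rtol, atol=atol)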
@@ -0,0 +1,141 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import pytest
import torch
from tests.test_utils import assert_expected, init_weights_with_constant, set_rng_seed
from torchmultimodal.models.coca.coca_model import (
    coca_vit,
    CoCaForPretraining,
    CoCaModelOutput,
)


class TestCoCaModel:
    @pytest.fixture(autouse=True)
    def random(self):
        set_rng_seed(0)

    @pytest.fixture
    def batch_size(self):
        return 2

    @pytest.fixture
    def vocab_size(self):
        return 50

    @pytest.fixture
    def num_text_positions(self):
        return 11

    @pytest.fixture
    def attention_pooler_output_dim(self):
        return 8

    @pytest.fixture
    def text_output_dim(self):
        return 8

    @pytest.fixture
    def image_size(self):
        return 12

    @pytest.fixture
    def coca_model(
        self,
        vocab_size,
        num_text_positions,
        attention_pooler_output_dim,
        text_output_dim,
        image_size,
    ):
        coca_model = coca_vit(
            vision_patch_size=4,
            vision_dim_feedforward=24,
            vision_n_layer=2,
            vision_n_head=2,
            vocab_size=vocab_size,
            num_text_positions=num_text_positions,
            text_hidden_dim=8,
            text_n_layer=2,
            text_n_head=2,
            text_dim_feedforward=32,
            text_output_dim=text_output_dim,
            fusion_n_layer=2,
            fusion_n_head=2,
            fusion_dim_feedforward=32,
            multimodal_output_projection_dim=vocab_size,
            pooler_input_embed_dim=6,
            pooler_output_embed_dim=attention_pooler_output_dim,
            image_size=image_size,
            pooler_n_head=2,
            cascaded_pooler=False,
        )
        init_weights_with_constant(coca_model)
        coca_model.eval()
        return coca_model

    @pytest.fixture
    def text_inputs(self):
        return torch.LongTensor(
            [
                [1, 3, 4, 5, 6, 7, 8, 2, 0, 0, 0],
                [1, 25, 28, 34, 39, 45, 40, 5, 12, 6, 2],
            ]
        )

    @pytest.fixture
    def image_inputs(self, batch_size, image_size):
        return torch.randn(batch_size, 3, image_size, image_size)

    @pytest.fixture
    def expected(
        self,
        batch_size,
        vocab_size,
        num_text_positions,
        attention_pooler_output_dim,
        text_output_dim,
    ):
        # 0.3536 ≈ 1 / sqrt(8): with constant-initialized weights the pooled
        # outputs are constant vectors, so normalizing over 8 dims gives
        # 1 / sqrt(8) per element.
        pooled_val = 0.3536
        logit_val = 8.0
        return CoCaModelOutput(
            image_pooled_output=pooled_val
            * torch.ones(batch_size, attention_pooler_output_dim),
            text_pooled_output=pooled_val * torch.ones(batch_size, text_output_dim),
            multimodal_embeddings=logit_val
            * torch.ones(batch_size, num_text_positions - 1, vocab_size),
        )

    @pytest.fixture
    def coca_for_pretraining(self, coca_model):
        coca_for_pretraining = CoCaForPretraining(coca_model)
        init_weights_with_constant(coca_for_pretraining)
        coca_for_pretraining.eval()
        return coca_for_pretraining

    def test_coca_model(self, text_inputs, image_inputs, coca_model, expected):
        actual = coca_model(image_inputs, text_inputs)
        assert_expected(actual, expected, rtol=0, atol=1e-4)

    def test_scripting(self, text_inputs, image_inputs, coca_model):
        scripted_model = torch.jit.script(coca_model)
        assert_expected(
            scripted_model(image_inputs, text_inputs),
            coca_model(image_inputs, text_inputs),
            rtol=0,
            atol=1e-4,
        )

    def test_coca_for_pretraining(
        self, text_inputs, image_inputs, coca_for_pretraining
    ):
        actual_losses = coca_for_pretraining(image_inputs, text_inputs)
        # 0.6931 ≈ ln(2) and 3.9120 ≈ ln(50): with constant weights the logits
        # are uniform, so the contrastive loss is cross-entropy over the batch
        # of 2 and the captioning loss is cross-entropy over the vocab of 50.
        expected_losses = {
            "contrastive": torch.tensor(0.6931),
            "captioning": torch.tensor(3.9120),
        }
        assert_expected(actual_losses, expected_losses, rtol=0, atol=1e-4)
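For reference, a minimal sketch of exercising the same builder outside pytest. The coca_vit arguments and tensor shapes mirror the fixtures above; they are test-scale values, not recommended settings:

import torch
from torchmultimodal.models.coca.coca_model import coca_vit

# Same small CoCa-ViT variant as the coca_model fixture above.
model = coca_vit(
    vision_patch_size=4,
    vision_dim_feedforward=24,
    vision_n_layer=2,
    vision_n_head=2,
    vocab_size=50,
    num_text_positions=11,
    text_hidden_dim=8,
    text_n_layer=2,
    text_n_head=2,
    text_dim_feedforward=32,
    text_output_dim=8,
    fusion_n_layer=2,
    fusion_n_head=2,
    fusion_dim_feedforward=32,
    multimodal_output_projection_dim=50,
    pooler_input_embed_dim=6,
    pooler_output_embed_dim=8,
    image_size=12,
    pooler_n_head=2,
    cascaded_pooler=False,
)
model.eval()

images = torch.randn(2, 3, 12, 12)     # (batch, channels, H, W)
texts = torch.randint(0, 50, (2, 11))  # (batch, num_text_positions) token ids
with torch.no_grad():
    out = model(images, texts)  # CoCaModelOutput
# Per the expected fixture: out.image_pooled_output is (2, 8),
# out.text_pooled_output is (2, 8), out.multimodal_embeddings is (2, 10, 50).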
@@ -0,0 +1,115 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import pytest
import torch
from tests.test_utils import assert_expected, init_weights_with_constant
from torch import nn, Tensor
from torchmultimodal.models.coca.multimodal_decoder import CoCaMultimodalDecoder


class TestCoCaMultimodalDecoder:
    @pytest.fixture
    def batch_size(self):
        return 2

    @pytest.fixture
    def input_seq_len(self):
        return 5

    @pytest.fixture
    def num_image_positions(self):
        return 10

    @pytest.fixture
    def text_embedding_dim(self):
        return 4

    @pytest.fixture
    def multimodal_decoder(self, input_seq_len, batch_size, text_embedding_dim):
        decoder = CoCaMultimodalDecoder(
            input_seq_len=input_seq_len,
            text_embedding_dim=text_embedding_dim,
            n_layer=2,
            n_head=2,
            dim_feedforward=4 * text_embedding_dim,
            output_dim=3,
            final_layer_norm_eps=1e-5,
        )
        init_weights_with_constant(decoder)

        # Custom init final MLP layer weight, final LN, and text projection
        decoder.transformer_decoder.layer[1].feedforward.model[2].weight = nn.Parameter(
            torch.arange(
                decoder.transformer_decoder.layer[1]
                .feedforward.model[2]
                .weight.numel(),
                dtype=torch.float,
            ).reshape(
                decoder.transformer_decoder.layer[1].feedforward.model[2].weight.shape
            )
        )
        decoder.output_projection.weight = nn.Parameter(
            torch.arange(decoder.output_projection.weight.numel(), dtype=torch.float)
            .reshape(decoder.output_projection.weight.T.shape)
            .T
        )
        decoder.transformer_decoder.final_layer_norm.weight = nn.Parameter(
            torch.arange(
                decoder.transformer_decoder.final_layer_norm.weight.numel(),
                dtype=torch.float,
            )
        )
        decoder.eval()
        return decoder

    @pytest.fixture
    def text_inputs(self, batch_size, input_seq_len, text_embedding_dim):
        return torch.arange(0.0, 1.0, 1.0 / 40).reshape(
            batch_size, input_seq_len, text_embedding_dim
        )

    @pytest.fixture
    def image_inputs(self, batch_size, num_image_positions, text_embedding_dim):
        return torch.arange(10.0, 20.0, 1.0 / 8).reshape(
            batch_size, num_image_positions, text_embedding_dim
        )

    @pytest.fixture
    def expected(self):
        return Tensor(
            [
                [
                    [58.2492, 66.7214, 75.1935],
                    [58.2492, 66.7214, 75.1935],
                    [58.2492, 66.7214, 75.1935],
                    [58.2492, 66.7214, 75.1935],
                    [58.2492, 66.7214, 75.1935],
                ],
                [
                    [58.2492, 66.7214, 75.1935],
                    [58.2492, 66.7214, 75.1935],
                    [58.2492, 66.7214, 75.1935],
                    [58.2492, 66.7214, 75.1935],
                    [58.2492, 66.7214, 75.1935],
                ],
            ]
        )

    def test_coca_multimodal_decoder(
        self, text_inputs, image_inputs, multimodal_decoder, expected
    ):
        actual = multimodal_decoder(text_inputs, image_inputs)
        assert_expected(actual, expected, rtol=0, atol=1e-4)

    def test_scripting(self, text_inputs, image_inputs, multimodal_decoder):
        scripted_multimodal_decoder = torch.jit.script(multimodal_decoder)
        assert_expected(
            scripted_multimodal_decoder(text_inputs, image_inputs),
            multimodal_decoder(text_inputs, image_inputs),
            rtol=0,
            atol=1e-4,
        )
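Similarly, a minimal sketch of driving CoCaMultimodalDecoder directly; constructor arguments and input shapes mirror the fixtures above:

import torch
from torchmultimodal.models.coca.multimodal_decoder import CoCaMultimodalDecoder

decoder = CoCaMultimodalDecoder(
    input_seq_len=5,
    text_embedding_dim=4,
    n_layer=2,
    n_head=2,
    dim_feedforward=16,  # 4 * text_embedding_dim, as in the fixture
    output_dim=3,
    final_layer_norm_eps=1e-5,
)
decoder.eval()

text_embeddings = torch.randn(2, 5, 4)    # (batch, input_seq_len, text_embedding_dim)
image_embeddings = torch.randn(2, 10, 4)  # (batch, num_image_positions, text_embedding_dim)
with torch.no_grad():
    out = decoder(text_embeddings, image_embeddings)
# Per the expected fixture, out is (batch, input_seq_len, output_dim) = (2, 5, 3).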