From 0ad6763ac71dc9356801cd1d1376329312150c6b Mon Sep 17 00:00:00 2001
From: ankitade
Date: Sat, 23 Jul 2022 01:54:57 +0000
Subject: [PATCH] [FLAVA] Change ordering of contrastive loss initialization

ghstack-source-id: e5f700c514c105884f4db21ffc827ddaf3c74a40
Pull Request resolved: https://github.com/facebookresearch/multimodal/pull/105
---
 test/models/flava/test_flava.py             | 22 +++++++++++++---------
 test/models/flava/test_flava_checkpoint.py  | 20 ++++++++++--------
 torchmultimodal/models/flava/flava_model.py |  7 +++----
 3 files changed, 28 insertions(+), 21 deletions(-)

diff --git a/test/models/flava/test_flava.py b/test/models/flava/test_flava.py
index 6d606ae76..900f319a1 100644
--- a/test/models/flava/test_flava.py
+++ b/test/models/flava/test_flava.py
@@ -27,27 +27,30 @@ def setUp(self):
 
     @torch.no_grad()
     def test_forward_classification(self):
-        flava = flava_model_for_classification(NUM_CLASSES, pretrained_model_key=None)
         text = torch.randint(0, 30500, (2, 77), dtype=torch.long)
         image = torch.rand((2, 3, 224, 224))
         labels = torch.randint(0, 2, (2,), dtype=torch.long)
+        flava = flava_model_for_classification(NUM_CLASSES, pretrained_model_key=None)
+        flava.eval()
 
+        # Test multimodal scenario
         output = flava(image, text, "mm", labels)
-        self.assertAlmostEqual(output.loss.item(), 0.9724, places=4)
+        self.assertAlmostEqual(output.loss.item(), 0.7180, places=4)
 
         # Test unimodal image scenario
         output = flava(image, text, "image", labels)
-        self.assertAlmostEqual(output.loss.item(), 0.5453, places=4)
+        self.assertAlmostEqual(output.loss.item(), 0.7020, places=4)
 
         # Test unimodal text scenario
         output = flava(image, text, "text", labels)
-        self.assertAlmostEqual(output.loss.item(), 0.7074, places=4)
+        self.assertAlmostEqual(output.loss.item(), 0.6663, places=4)
 
     @torch.no_grad()
     def test_forward_pretraining(self):
-        flava = flava_model_for_pretraining()
+
         text = torch.randint(0, 30500, (2, 77), dtype=torch.long)
         image = torch.rand((2, 3, 224, 224))
         image_for_codebook = torch.rand(2, 3, 112, 112)
@@ -58,7 +61,8 @@ def test_forward_pretraining(self):
         mlm_labels[:, :] = -1
         mlm_labels[:, 1:3] = text[:, 1:3]
         itm_labels = torch.tensor((0, 1), dtype=torch.long)
-
+        flava = flava_model_for_pretraining()
+        flava.eval()
         output = flava(
             image=image,
             text=text,
@@ -79,7 +83,7 @@ def test_forward_pretraining(self):
             sum(
                 value if value is not None else 0 for value in output.losses.values()
             ).item(),
-            20.4199,
+            21.5150,
             places=4,
         )
 
@@ -103,7 +107,7 @@ def test_forward_pretraining(self):
             sum(
                 value if value is not None else 0 for value in output.losses.values()
             ).item(),
-            9.3403,
+            8.9674,
             places=4,
         )
 
@@ -128,7 +132,7 @@ def test_forward_pretraining(self):
             sum(
                 value if value is not None else 0 for value in output.losses.values()
             ).item(),
-            10.8777,
+            10.0305,
             places=4,
         )
 
diff --git a/test/models/flava/test_flava_checkpoint.py b/test/models/flava/test_flava_checkpoint.py
index ae6ba1e2b..5cd2127ac 100644
--- a/test/models/flava/test_flava_checkpoint.py
+++ b/test/models/flava/test_flava_checkpoint.py
@@ -32,7 +32,7 @@ def image_input(self):
     @pytest.fixture
     def inputs_classification(self, image_input, text_input):
         def gather_inputs(required_embedding):
-            labels = torch.randint(0, 2, (2,), dtype=torch.long)
+            labels = torch.tensor((0, 1), dtype=torch.long)
             return image_input, text_input, required_embedding, labels
 
         return gather_inputs
@@ -88,21 +88,25 @@ def _assert_tensor_dicts_equal(self, dict_actual, dict_expected):
             assert_expected(actual, expected, rtol=0, atol=1e-4)
 
     def test_flava_model_for_classification(
-        self, classification_model, inputs_classification
+        self, inputs_classification, classification_model
     ):
-        output = classification_model(*inputs_classification("mm"))
+        mm_input = inputs_classification("mm")
+        image_input = inputs_classification("image")
+        text_input = inputs_classification("text")
+        classification_model.eval()
+        output = classification_model(*mm_input)
         actual = output.loss
-        expected = torch.tensor(1.1017)
+        expected = torch.tensor(1.0827)
         assert_expected(actual, expected, rtol=0, atol=1e-4)
 
-        output = classification_model(*inputs_classification("image"))
+        output = classification_model(*image_input)
         actual = output.loss
-        expected = torch.tensor(1.0912)
+        expected = torch.tensor(1.0849)
         assert_expected(actual, expected, rtol=0, atol=1e-4)
 
-        output = classification_model(*inputs_classification("text"))
+        output = classification_model(*text_input)
         actual = output.loss
-        expected = torch.tensor(1.1136)
+        expected = torch.tensor(1.0822)
         assert_expected(actual, expected, rtol=0, atol=1e-4)
 
     def test_flava_model_for_pretraining(self, pretraining_model, inputs_pretraining):

diff --git a/torchmultimodal/models/flava/flava_model.py b/torchmultimodal/models/flava/flava_model.py
index 11d7372cf..56ffaa741 100644
--- a/torchmultimodal/models/flava/flava_model.py
+++ b/torchmultimodal/models/flava/flava_model.py
@@ -453,9 +453,8 @@ def flava_model_for_pretraining(
     # TODO: Add parameters for loss here
 ) -> FLAVAForPreTraining:
     model = flava_model(**flava_model_kwargs)
-
-    codebook = DalleVAEEncoder(image_size=codebook_image_size)
     losses = FLAVAPretrainingLoss()
+    codebook = DalleVAEEncoder(image_size=codebook_image_size)
 
     flava = FLAVAForPreTraining(
         model=model,
@@ -480,7 +479,7 @@ def flava_model_for_classification(
     pretrained_model_key: Optional[str] = "flava_full",
     **flava_model_kwargs: Any,
 ) -> FLAVAForClassification:
-    model = flava_model(**flava_model_kwargs)
+
     classifier = MLP(
         in_dim=classifier_in_dim,
         out_dim=num_classes,
@@ -489,7 +488,7 @@ def flava_model_for_classification(
         activation=classifier_activation,
         normalization=classifier_normalization,
     )
-
+    model = flava_model(**flava_model_kwargs)
     if loss_fn is None:
         loss_fn = nn.CrossEntropyLoss()
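Why the expected values change: under a fixed seed (which the exact assertAlmostEqual
constants imply), PyTorch consumes the global RNG stream in construction order, so
initializing FLAVAPretrainingLoss before DalleVAEEncoder, and building the test inputs
before the model, shifts every subsequent random draw; the asserted loss values move
with it. A minimal sketch of the effect, assuming nothing beyond core PyTorch; the
nn.Linear modules and the seed below are hypothetical stand-ins, not FLAVA's actual
components:

    import torch
    from torch import nn

    # Order A: loss-like module constructed first (analogous to the new ordering).
    torch.manual_seed(1234)
    loss_a = nn.Linear(8, 8)      # hypothetical stand-in for FLAVAPretrainingLoss
    codebook_a = nn.Linear(4, 4)  # hypothetical stand-in for DalleVAEEncoder

    # Order B: codebook-like module constructed first (analogous to the old ordering).
    torch.manual_seed(1234)
    codebook_b = nn.Linear(4, 4)
    loss_b = nn.Linear(8, 8)

    # Same seed, same modules, different construction order: each module's parameters
    # come from a different position in the RNG stream, so the weights (and any loss
    # computed from them) no longer match.
    print(torch.equal(loss_a.weight, loss_b.weight))          # False: different draws
    print(torch.equal(codebook_a.weight, codebook_b.weight))  # False: different draws

The flava.eval() calls added alongside are complementary rather than part of the
reordering: torch.no_grad() only disables gradient tracking, while eval() switches
modules such as dropout to their deterministic inference behavior, keeping the asserted
values stable across runs.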