Skip to content

Commit

Permalink
[FLAVA] Change ordering on contrastive loss initialization
Browse files Browse the repository at this point in the history
ghstack-source-id: e5f700c514c105884f4db21ffc827ddaf3c74a40
Pull Request resolved: #105
  • Loading branch information
ankitade committed Jul 23, 2022
1 parent e315f02 commit 0ad6763
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 21 deletions.
22 changes: 13 additions & 9 deletions test/models/flava/test_flava.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,27 +27,30 @@ def setUp(self):

@torch.no_grad()
def test_forward_classification(self):
flava = flava_model_for_classification(NUM_CLASSES, pretrained_model_key=None)
text = torch.randint(0, 30500, (2, 77), dtype=torch.long)
image = torch.rand((2, 3, 224, 224))

labels = torch.randint(0, 2, (2,), dtype=torch.long)

flava = flava_model_for_classification(NUM_CLASSES, pretrained_model_key=None)
flava.eval()

# Test multimodal scenario

output = flava(image, text, "mm", labels)
self.assertAlmostEqual(output.loss.item(), 0.9724, places=4)
self.assertAlmostEqual(output.loss.item(), 0.7180, places=4)

# Test unimodal image scenario
output = flava(image, text, "image", labels)
self.assertAlmostEqual(output.loss.item(), 0.5453, places=4)
self.assertAlmostEqual(output.loss.item(), 0.7020, places=4)

# Test unimodal text scenario
output = flava(image, text, "text", labels)
self.assertAlmostEqual(output.loss.item(), 0.7074, places=4)
self.assertAlmostEqual(output.loss.item(), 0.6663, places=4)

@torch.no_grad()
def test_forward_pretraining(self):
flava = flava_model_for_pretraining()

text = torch.randint(0, 30500, (2, 77), dtype=torch.long)
image = torch.rand((2, 3, 224, 224))
image_for_codebook = torch.rand(2, 3, 112, 112)
Expand All @@ -58,7 +61,8 @@ def test_forward_pretraining(self):
mlm_labels[:, :] = -1
mlm_labels[:, 1:3] = text[:, 1:3]
itm_labels = torch.tensor((0, 1), dtype=torch.long)

flava = flava_model_for_pretraining()
flava.eval()
output = flava(
image=image,
text=text,
Expand All @@ -79,7 +83,7 @@ def test_forward_pretraining(self):
sum(
value if value is not None else 0 for value in output.losses.values()
).item(),
20.4199,
21.5150,
places=4,
)

Expand All @@ -103,7 +107,7 @@ def test_forward_pretraining(self):
sum(
value if value is not None else 0 for value in output.losses.values()
).item(),
9.3403,
8.9674,
places=4,
)

Expand All @@ -128,7 +132,7 @@ def test_forward_pretraining(self):
sum(
value if value is not None else 0 for value in output.losses.values()
).item(),
10.8777,
10.0305,
places=4,
)

Expand Down
20 changes: 12 additions & 8 deletions test/models/flava/test_flava_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def image_input(self):
@pytest.fixture
def inputs_classification(self, image_input, text_input):
def gather_inputs(required_embedding):
labels = torch.randint(0, 2, (2,), dtype=torch.long)
labels = torch.tensor((0, 1), dtype=torch.long)
return image_input, text_input, required_embedding, labels

return gather_inputs
Expand Down Expand Up @@ -88,21 +88,25 @@ def _assert_tensor_dicts_equal(self, dict_actual, dict_expected):
assert_expected(actual, expected, rtol=0, atol=1e-4)

def test_flava_model_for_classification(
self, classification_model, inputs_classification
self, inputs_classification, classification_model
):
output = classification_model(*inputs_classification("mm"))
mm_input = inputs_classification("mm")
image_input = inputs_classification("image")
text_input = inputs_classification("text")
classification_model.eval()
output = classification_model(*mm_input)
actual = output.loss
expected = torch.tensor(1.1017)
expected = torch.tensor(1.0827)
assert_expected(actual, expected, rtol=0, atol=1e-4)

output = classification_model(*inputs_classification("image"))
output = classification_model(*image_input)
actual = output.loss
expected = torch.tensor(1.0912)
expected = torch.tensor(1.0849)
assert_expected(actual, expected, rtol=0, atol=1e-4)

output = classification_model(*inputs_classification("text"))
output = classification_model(*text_input)
actual = output.loss
expected = torch.tensor(1.1136)
expected = torch.tensor(1.0822)
assert_expected(actual, expected, rtol=0, atol=1e-4)

def test_flava_model_for_pretraining(self, pretraining_model, inputs_pretraining):
Expand Down
7 changes: 3 additions & 4 deletions torchmultimodal/models/flava/flava_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -453,9 +453,8 @@ def flava_model_for_pretraining(
# TODO: Add parameters for loss here
) -> FLAVAForPreTraining:
model = flava_model(**flava_model_kwargs)

codebook = DalleVAEEncoder(image_size=codebook_image_size)
losses = FLAVAPretrainingLoss()
codebook = DalleVAEEncoder(image_size=codebook_image_size)

flava = FLAVAForPreTraining(
model=model,
Expand All @@ -480,7 +479,7 @@ def flava_model_for_classification(
pretrained_model_key: Optional[str] = "flava_full",
**flava_model_kwargs: Any,
) -> FLAVAForClassification:
model = flava_model(**flava_model_kwargs)

classifier = MLP(
in_dim=classifier_in_dim,
out_dim=num_classes,
Expand All @@ -489,7 +488,7 @@ def flava_model_for_classification(
activation=classifier_activation,
normalization=classifier_normalization,
)

model = flava_model(**flava_model_kwargs)
if loss_fn is None:
loss_fn = nn.CrossEntropyLoss()

Expand Down

0 comments on commit 0ad6763

Please sign in to comment.