diff --git a/examples/flava/model.py b/examples/flava/model.py
index bf4e9073..6d1026b4 100644
--- a/examples/flava/model.py
+++ b/examples/flava/model.py
@@ -8,16 +8,18 @@
 import torch
 from pytorch_lightning import LightningModule
+from torch import nn
 from torchmetrics import Accuracy
 from torchmultimodal.models.flava.model import (
     flava_model_for_classification,
     flava_model_for_pretraining,
 )
+from torchmultimodal.modules.losses.flava import FLAVAPretrainingLoss
 from transformers.optimization import get_cosine_schedule_with_warmup


 def get_optimizers_for_lightning(
-    model: torch.nn.Module,
+    model: nn.Module,
     learning_rate: float,
     adam_eps: float,
     adam_weight_decay: float,
@@ -59,6 +61,7 @@ def __init__(
         self.adam_weight_decay = adam_weight_decay
         self.warmup_steps = warmup_steps
         self.max_steps = max_steps
+        self.loss = FLAVAPretrainingLoss(logit_scale=self.model.logit_scale)

     def training_step(self, batch, batch_idx):
         output = self._step(batch, batch_idx)
@@ -104,7 +107,24 @@ def _step(self, batch, batch_idx):
             itm_labels=batch.get("itm_labels", None),
             required_embedding=required_embedding,
         )
-        return output
+
+        loss = self.loss(
+            multimodal_masked_sequence=output.multimodal_masked_sequence,
+            pos_mask=output.pos_mask,
+            itm_labels=output.itm_labels,
+            mim_labels=output.mim_labels,
+            mlm_labels=output.mlm_labels,
+            mmm_mlm_labels=output.mmm_mlm_labels,
+            mmm_mim_labels=output.mmm_mim_labels,
+            projected_image_embeddings=output.projected_image_embeddings,
+            projected_text_embeddings=output.projected_text_embeddings,
+            itm_logits=output.itm_logits,
+            mlm_head_output=output.mlm_head_output,
+            mim_head_output=output.mim_head_output,
+            mmm_mlm_head_output=output.mmm_mlm_head_output,
+            mmm_mim_head_output=output.mmm_mim_head_output,
+        )
+        return loss

     def configure_optimizers(self):
         return get_optimizers_for_lightning(
diff --git a/test/models/flava/test_checkpoint.py b/test/models/flava/test_checkpoint.py
index 677d24aa..d295c58f 100644
--- a/test/models/flava/test_checkpoint.py
+++ b/test/models/flava/test_checkpoint.py
@@ -12,6 +12,7 @@
     flava_model_for_classification,
     flava_model_for_pretraining,
 )
+from torchmultimodal.modules.losses.flava import FLAVAPretrainingLoss


 @pytest.fixture(autouse=True)
@@ -139,8 +140,25 @@ def test_flava_model_for_pretraining(self, inputs_pretraining, pretraining_model
         image_input = inputs_pretraining("image")
         text_input = inputs_pretraining("text")
         flava = pretraining_model()
-
+        losses = FLAVAPretrainingLoss(flava.logit_scale)
         output = flava(*mm_input)
+        output = losses(
+            multimodal_masked_sequence=output.multimodal_masked_sequence,
+            pos_mask=output.pos_mask,
+            itm_labels=output.itm_labels,
+            mim_labels=output.mim_labels,
+            mlm_labels=output.mlm_labels,
+            mmm_mlm_labels=output.mmm_mlm_labels,
+            mmm_mim_labels=output.mmm_mim_labels,
+            projected_image_embeddings=output.projected_image_embeddings,
+            projected_text_embeddings=output.projected_text_embeddings,
+            itm_logits=output.itm_logits,
+            mlm_head_output=output.mlm_head_output,
+            mim_head_output=output.mim_head_output,
+            mmm_mlm_head_output=output.mmm_mlm_head_output,
+            mmm_mim_head_output=output.mmm_mim_head_output,
+        )
+
         actual = output.losses
         expected = dict(
             mmm_text_loss=10.9567,
@@ -153,6 +171,22 @@ def test_flava_model_for_pretraining(self, inputs_pretraining, pretraining_model
         self._assert_tensor_dicts_equal(actual, expected)

         output = flava(*image_input)
+        output = losses(
+            multimodal_masked_sequence=output.multimodal_masked_sequence,
+            pos_mask=output.pos_mask,
+            itm_labels=output.itm_labels,
+            mim_labels=output.mim_labels,
+            mlm_labels=output.mlm_labels,
+            mmm_mlm_labels=output.mmm_mlm_labels,
+            mmm_mim_labels=output.mmm_mim_labels,
+            projected_image_embeddings=output.projected_image_embeddings,
+            projected_text_embeddings=output.projected_text_embeddings,
+            itm_logits=output.itm_logits,
+            mlm_head_output=output.mlm_head_output,
+            mim_head_output=output.mim_head_output,
+            mmm_mlm_head_output=output.mmm_mlm_head_output,
+            mmm_mim_head_output=output.mmm_mim_head_output,
+        )
         actual = output.losses
         expected = dict(
             mmm_text_loss=None,
@@ -165,6 +199,22 @@ def test_flava_model_for_pretraining(self, inputs_pretraining, pretraining_model
         self._assert_tensor_dicts_equal(actual, expected)

         output = flava(*text_input)
+        output = losses(
+            multimodal_masked_sequence=output.multimodal_masked_sequence,
+            pos_mask=output.pos_mask,
+            itm_labels=output.itm_labels,
+            mim_labels=output.mim_labels,
+            mlm_labels=output.mlm_labels,
+            mmm_mlm_labels=output.mmm_mlm_labels,
+            mmm_mim_labels=output.mmm_mim_labels,
+            projected_image_embeddings=output.projected_image_embeddings,
+            projected_text_embeddings=output.projected_text_embeddings,
+            itm_logits=output.itm_logits,
+            mlm_head_output=output.mlm_head_output,
+            mim_head_output=output.mim_head_output,
+            mmm_mlm_head_output=output.mmm_mlm_head_output,
+            mmm_mim_head_output=output.mmm_mim_head_output,
+        )
         actual = output.losses
         expected = dict(
             mmm_text_loss=None,
diff --git a/test/models/flava/test_flava.py b/test/models/flava/test_flava.py
index 21818680..9b3ca33f 100644
--- a/test/models/flava/test_flava.py
+++ b/test/models/flava/test_flava.py
@@ -18,6 +18,7 @@
     FLAVAOutput,
 )
 from torchmultimodal.modules.layers.transformer import TransformerOutput
+from torchmultimodal.modules.losses.flava import FLAVAPretrainingLoss


 NUM_CLASSES = 2
@@ -61,6 +62,7 @@ def test_forward_pretraining(self):
         mlm_labels[:, 1:3] = text[:, 1:3]
         itm_labels = torch.tensor((0, 1), dtype=torch.long)
         flava = flava_model_for_pretraining()
+        losses = FLAVAPretrainingLoss(flava.logit_scale)
         flava.eval()
         output = flava(
             image=image,
@@ -72,6 +74,22 @@ def test_forward_pretraining(self):
             itm_labels=itm_labels,
             mlm_labels=mlm_labels,
         )
+        output = losses(
+            multimodal_masked_sequence=output.multimodal_masked_sequence,
+            pos_mask=output.pos_mask,
+            itm_labels=output.itm_labels,
+            mim_labels=output.mim_labels,
+            mlm_labels=output.mlm_labels,
+            mmm_mlm_labels=output.mmm_mlm_labels,
+            mmm_mim_labels=output.mmm_mim_labels,
+            projected_image_embeddings=output.projected_image_embeddings,
+            projected_text_embeddings=output.projected_text_embeddings,
+            itm_logits=output.itm_logits,
+            mlm_head_output=output.mlm_head_output,
+            mim_head_output=output.mim_head_output,
+            mmm_mlm_head_output=output.mmm_mlm_head_output,
+            mmm_mim_head_output=output.mmm_mim_head_output,
+        )
         self.assertIsNone(output.mlm_output)
         self.assertIsNone(output.mim_output)
         self.assertIsNotNone(output.global_contrastive_output)
@@ -96,6 +114,22 @@ def test_forward_pretraining(self):
             itm_labels=itm_labels,
             mlm_labels=mlm_labels,
         )
+        output = losses(
+            multimodal_masked_sequence=output.multimodal_masked_sequence,
+            pos_mask=output.pos_mask,
+            itm_labels=output.itm_labels,
+            mim_labels=output.mim_labels,
+            mlm_labels=output.mlm_labels,
+            mmm_mlm_labels=output.mmm_mlm_labels,
+            mmm_mim_labels=output.mmm_mim_labels,
+            projected_image_embeddings=output.projected_image_embeddings,
+            projected_text_embeddings=output.projected_text_embeddings,
+            itm_logits=output.itm_logits,
+            mlm_head_output=output.mlm_head_output,
+            mim_head_output=output.mim_head_output,
+            mmm_mlm_head_output=output.mmm_mlm_head_output,
+            mmm_mim_head_output=output.mmm_mim_head_output,
+        )
         self.assertIsNone(output.mlm_output)
         self.assertIsNotNone(output.mim_output)
         self.assertIsNone(output.global_contrastive_output)
@@ -120,6 +154,22 @@ def test_forward_pretraining(self):
             itm_labels=itm_labels,
             mlm_labels=mlm_labels,
         )
+        output = losses(
+            multimodal_masked_sequence=output.multimodal_masked_sequence,
+            pos_mask=output.pos_mask,
+            itm_labels=output.itm_labels,
+            mim_labels=output.mim_labels,
+            mlm_labels=output.mlm_labels,
+            mmm_mlm_labels=output.mmm_mlm_labels,
+            mmm_mim_labels=output.mmm_mim_labels,
+            projected_image_embeddings=output.projected_image_embeddings,
+            projected_text_embeddings=output.projected_text_embeddings,
+            itm_logits=output.itm_logits,
+            mlm_head_output=output.mlm_head_output,
+            mim_head_output=output.mim_head_output,
+            mmm_mlm_head_output=output.mmm_mlm_head_output,
+            mmm_mim_head_output=output.mmm_mim_head_output,
+        )
         self.assertIsNotNone(output.mlm_output)
         self.assertIsNone(output.mim_output)
         self.assertIsNone(output.global_contrastive_output)
diff --git a/torchmultimodal/models/flava/model.py b/torchmultimodal/models/flava/model.py
index b696a390..c6ad3ae2 100644
--- a/torchmultimodal/models/flava/model.py
+++ b/torchmultimodal/models/flava/model.py
@@ -24,11 +24,7 @@
     TransformerEncoder,
     TransformerOutput,
 )
-from torchmultimodal.modules.losses.flava import (
-    FLAVAPretrainingLoss,
-    FLAVAPretrainingLossOutput,
-    Pooler,
-)
+from torchmultimodal.modules.losses.flava import FLAVAPretrainingLossOutput, Pooler
 from torchmultimodal.utils.common import ModelOutput, PretrainedMixin
 from typing_extensions import Literal

@@ -62,7 +58,7 @@
 FLAVA_FOR_PRETRAINED_MAPPING = {
     # This will no longer load with the updated model, but keeping here just in case
     # "flava_full": "https://huggingface.co/aps/flava_full_pretrained_encoders_torchmm/resolve/main/pytorch_model.bin",
-    "flava_full": "https://download.pytorch.org/models/multimodal/flava/flava_for_pretraining_unified_itm_mp.pt",
+    "flava_full": "https://download.pytorch.org/models/multimodal/flava/flava_for_pretraining_no_loss.pt",
 }

 FLAVA_MODEL_MAPPING = {
@@ -105,6 +101,24 @@ class FLAVAForClassificationOutput(ModelOutput):
     loss: Tensor


+@dataclass
+class FLAVAForPretrainingOutput:
+    multimodal_masked_sequence: Tensor
+    pos_mask: Tensor
+    mim_labels: Tensor
+    mlm_labels: Tensor
+    mmm_mlm_labels: Tensor
+    mmm_mim_labels: Tensor
+    itm_labels: Tensor
+    projected_image_embeddings: Tensor
+    projected_text_embeddings: Tensor
+    itm_logits: Tensor
+    mlm_head_output: Tensor
+    mim_head_output: Tensor
+    mmm_mlm_head_output: Tensor
+    mmm_mim_head_output: Tensor
+
+
 class FLAVAModel(nn.Module, PretrainedMixin):
     def __init__(
         self,
@@ -367,22 +381,22 @@ def __init__(
         self,
         model: FLAVAModel,
         image_codebook: nn.Module,
-        loss: nn.Module,
         itm_head: nn.Module,
         mlm_head: nn.Module,
         mim_head: nn.Module,
         mmm_mlm_head: nn.Module,
         mmm_mim_head: nn.Module,
+        logit_scale: nn.Module,
     ):
         super().__init__()
         self.model = model
         self.image_codebook = image_codebook
-        self.loss = loss
         self.itm_head = itm_head
         self.mlm_head = mlm_head
         self.mim_head = mim_head
         self.mmm_mlm_head = mmm_mlm_head
         self.mmm_mim_head = mmm_mim_head
+        self.logit_scale = logit_scale

     def encode_image(
         self,
@@ -469,8 +483,10 @@ def forward(
         itm_logits = self.itm_head(multimodal_masked_sequence)

         multimodal_masked_sequence = multimodal_masked_sequence[pos_mask]
+
         if mlm_labels is not None:
             mmm_mlm_labels = mlm_labels[pos_mask]
+
         if image_labels is not None:
             mmm_mim_labels = image_labels[pos_mask]

@@ -494,14 +510,14 @@
             sequence_for_image = multimodal_masked_sequence[:, 2 : 2 + total_indices, :]
             mmm_mim_head_output = self.mmm_mim_head(sequence_for_image, mmm_mim_labels)

-        return self.loss(
+        return FLAVAForPretrainingOutput(
             multimodal_masked_sequence=flava_output.multimodal_masked.last_hidden_state,
             pos_mask=pos_mask,
-            itm_labels=itm_labels,
             mim_labels=image_labels,
             mlm_labels=mlm_labels,
-            mmm_mlm_labels=mmm_mlm_labels,
-            mmm_mim_labels=mmm_mim_labels,
+            mmm_mim_labels=mmm_mim_labels,
+            mmm_mlm_labels=mmm_mlm_labels,
+            itm_labels=itm_labels,
             projected_image_embeddings=flava_output.projected_image_embeddings,
             projected_text_embeddings=flava_output.projected_text_embeddings,
             itm_logits=itm_logits,
@@ -660,6 +676,7 @@
     codebook_image_size: int = 112,
     pretrained_model_key: Optional[str] = None,
     image_vocab_size: int = 8192,
+    logit_scale: float = math.log(1 / 0.07),
     **flava_model_kwargs: Any,
     # TODO: Add parameters for loss here
 ) -> FLAVAForPreTraining:
@@ -677,18 +694,18 @@
     mmm_mim_head = MaskedPredictionHead(
         hidden_size=hidden_size, vocab_size=image_vocab_size
     )
-    losses = FLAVAPretrainingLoss()
+
    codebook = DalleVAEEncoder(image_size=codebook_image_size)

     flava = FLAVAForPreTraining(
         model=model,
         image_codebook=codebook,
-        loss=losses,
         itm_head=itm_head,
         mlm_head=mlm_head,
         mim_head=mim_head,
         mmm_mlm_head=mmm_mlm_head,
         mmm_mim_head=mmm_mim_head,
+        logit_scale=nn.Parameter(logit_scale * torch.ones([])),
     )

     if pretrained_model_key is not None:
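Note for reviewers: a minimal sketch of the ownership pattern this diff introduces. After the change, FLAVAForPreTraining no longer constructs or applies FLAVAPretrainingLoss; it only owns the learnable logit_scale parameter, and callers (the Lightning module and the tests above) build the loss with that same parameter and feed it the fields of FLAVAForPretrainingOutput. The toy classes below illustrate only the parameter-sharing idea with stand-in names (ToyPretrainingModel and ToyContrastiveLoss are not part of the codebase); a real end-to-end FLAVA call would also need the masked-input pipeline used in the tests.

import math

import torch
from torch import nn


class ToyPretrainingModel(nn.Module):
    """Stand-in for FLAVAForPreTraining: owns only the contrastive temperature."""

    def __init__(self, logit_scale: float = math.log(1 / 0.07)) -> None:
        super().__init__()
        # Same initialization the diff uses in flava_model_for_pretraining().
        self.logit_scale = nn.Parameter(logit_scale * torch.ones([]))


class ToyContrastiveLoss(nn.Module):
    """Stand-in for FLAVAPretrainingLoss: receives the model's parameter at init."""

    def __init__(self, logit_scale: nn.Parameter) -> None:
        super().__init__()
        # The same Parameter object is registered here, not a copy, so the
        # temperature keeps training whether the optimizer sees the model's
        # parameters, the loss module's parameters, or both.
        self.logit_scale = logit_scale

    def forward(self, image_emb: torch.Tensor, text_emb: torch.Tensor) -> torch.Tensor:
        # Standard image-text contrastive loss with a learnable temperature.
        logits = self.logit_scale.exp() * (image_emb @ text_emb.t())
        targets = torch.arange(logits.size(0))
        return nn.functional.cross_entropy(logits, targets)


model = ToyPretrainingModel()
loss_fn = ToyContrastiveLoss(model.logit_scale)
assert loss_fn.logit_scale is model.logit_scale  # shared, not duplicated

The new FLAVA_FOR_PRETRAINED_MAPPING entry (flava_for_pretraining_no_loss.pt) is consistent with this split: with the loss no longer a submodule of the pretraining model, its parameters no longer appear in the model's state dict, so the previous full checkpoint would not load unchanged.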