diff --git a/torchmultimodal/models/flava/flava_model.py b/torchmultimodal/models/flava/flava_model.py index 513698217..80c86d6f1 100644 --- a/torchmultimodal/models/flava/flava_model.py +++ b/torchmultimodal/models/flava/flava_model.py @@ -60,7 +60,7 @@ FLAVA_FOR_PRETRAINED_MAPPING = { - "flava_full": "https://download.pytorch.org/models/multimodal/flava/flava_for_pretraining.pt" + "flava_full": "https://download.pytorch.org/models/multimodal/flava/flava_for_pretraining_cl_itm.pt" } @@ -287,18 +287,45 @@ def encode_mm( return self.mm_encoder(fused_state) +class TwoWayHead(nn.Module): + def __init__(self, hidden_size: int = 768, **kwargs: Any): + super().__init__() + + self.seq_relationship = nn.Linear(hidden_size, 2) + + def forward(self, pooled_output): + return self.seq_relationship(pooled_output) + + +class ITMHead(nn.Module): + def __init__(self, hidden_size: int = 768): + super().__init__() + self.pooler = Pooler(hidden_size=hidden_size) + self.cls = TwoWayHead(hidden_size=hidden_size) + + def forward(self, hidden_states: Tensor): + pooled_output = self.pooler(hidden_states) + logits = self.cls(pooled_output) + return logits + + class FLAVAForPreTraining(nn.Module, PretrainedMixin): # TODOs: # 1. Expose logit scale # 2. For FLAVA model, allow interpolating the embeddings to # for patch embeddings def __init__( - self, model: FLAVAModel, image_codebook: nn.Module, loss: FLAVAPretrainingLoss - ) -> None: + self, + model: FLAVAModel, + image_codebook: nn.Module, + loss: nn.Module, + itm_head: nn.Module, + ): super().__init__() self.model = model self.image_codebook = image_codebook self.loss = loss + self.itm_head = itm_head def encode_image( self, @@ -346,6 +373,10 @@ def forward( required_embedding=required_embedding, skip_unmasked_mm_encoder=skip_unmasked_mm_encoder, ) + multimodal_masked_sequence = flava_output.multimodal_masked.last_hidden_state + itm_logits = None + if multimodal_masked_sequence is not None: + itm_logits = self.itm_head(multimodal_masked_sequence) return self.loss( image_sequence=flava_output.image.last_hidden_state, @@ -361,6 +392,7 @@ def forward( mlm_labels=mlm_labels, projected_image_embeddings=flava_output.projected_image_embeddings, projected_text_embeddings=flava_output.projected_text_embeddings, + itm_logits=itm_logits, ) @@ -516,13 +548,13 @@ def flava_model_for_pretraining( # TODO: Add parameters for loss here ) -> FLAVAForPreTraining: model = flava_model(**flava_model_kwargs) + hidden_size = flava_model_kwargs.get("hidden_size") or 768 + itm_head = ITMHead(hidden_size) losses = FLAVAPretrainingLoss() codebook = DalleVAEEncoder(image_size=codebook_image_size) flava = FLAVAForPreTraining( - model=model, - image_codebook=codebook, - loss=losses, + model=model, image_codebook=codebook, loss=losses, itm_head=itm_head ) if pretrained_model_key is not None: diff --git a/torchmultimodal/modules/losses/flava.py b/torchmultimodal/modules/losses/flava.py index 78214c361..cca58a0ba 100644 --- a/torchmultimodal/modules/losses/flava.py +++ b/torchmultimodal/modules/losses/flava.py @@ -110,21 +110,16 @@ def __init__( **kwargs: Any, ): super().__init__() - self.pooler = Pooler(hidden_size=hidden_size) - self.cls = TwoWayHead(hidden_size=hidden_size) self.ce_loss = nn.CrossEntropyLoss(ignore_index=ignore_index) def forward( self, - hidden_states: Tensor, + scores: Tensor, labels: Tensor, ) -> ITMLossOutput: if self.training: assert_labels_are_present(labels, "itm labels") - pooled_output = self.pooler(hidden_states) - scores = self.cls(pooled_output) - if labels is None: loss = pooled_output.sum() * 0 else: @@ -375,6 +370,7 @@ def forward( mlm_labels: Optional[Tensor] = None, projected_image_embeddings: Optional[Tensor] = None, projected_text_embeddings: Optional[Tensor] = None, + itm_logits: Optional[Tensor] = None, ) -> FLAVAPretrainingLossOutput: outputs = FLAVAPretrainingLossOutput() pos_mask = None @@ -411,6 +407,7 @@ def forward( outputs.losses.mlm_loss = outputs.mlm_output.loss if multimodal_masked_sequence is not None and self.itm_loss_weight > 0: + assert itm_logits is not None if itm_labels is not None: pos_pairs = itm_labels.ne(0) pos_mask = torch.where( @@ -421,7 +418,7 @@ def forward( multimodal_masked_sequence.size(0), device=multimodal_masked_sequence.device, ).bool() - outputs.itm_output = self.itm_loss(multimodal_masked_sequence, itm_labels) + outputs.itm_output = self.itm_loss(itm_logits, itm_labels) outputs.itm_output.loss *= self.itm_loss_weight outputs.losses.itm_loss = outputs.itm_output.loss