From 5052bed7d5567792fc8bda620302adb97bdb1e40 Mon Sep 17 00:00:00 2001
From: ankitade
Date: Tue, 5 Jul 2022 01:20:35 +0000
Subject: [PATCH] [FLAVA] Move itm head to flava model for pretraining

[ghstack-poisoned]
---
 torchmultimodal/models/flava/flava_model.py | 44 ++++++++++++++++++---
 torchmultimodal/modules/losses/flava.py     | 23 ++--------
 2 files changed, 44 insertions(+), 23 deletions(-)

diff --git a/torchmultimodal/models/flava/flava_model.py b/torchmultimodal/models/flava/flava_model.py
index 722533a5..f1ae6397 100644
--- a/torchmultimodal/models/flava/flava_model.py
+++ b/torchmultimodal/models/flava/flava_model.py
@@ -60,7 +60,7 @@

 FLAVA_FOR_PRETRAINED_MAPPING = {
-    "flava_full": "https://download.pytorch.org/models/multimodal/flava/flava_for_pretraining.pt"
+    "flava_full": "https://download.pytorch.org/models/multimodal/flava/flava_for_pretraining_cl_itm.pt"
 }

@@ -201,13 +201,13 @@ def flava_model_for_pretraining(
     # TODO: Add parameters for loss here
 ):
     model = flava_model(**flava_model_kwargs)
+    hidden_size = flava_model_kwargs.get("hidden_size") or 768
+    itm_head = ITMHead(hidden_size)
     losses = FLAVAPretrainingLoss()
     codebook = DalleVAEEncoder(image_size=codebook_image_size)

     flava = FLAVAForPreTraining(
-        model=model,
-        image_codebook=codebook,
-        loss=losses,
+        model=model, image_codebook=codebook, loss=losses, itm_head=itm_head
     )

     if pretrained_model_key is not None:
@@ -433,16 +433,45 @@ def encode_mm(
         return self.mm_encoder(fused_state)


+class TwoWayHead(nn.Module):
+    def __init__(self, hidden_size: int = 768, **kwargs: Any):
+        super().__init__()
+
+        self.seq_relationship = nn.Linear(hidden_size, 2)
+
+    def forward(self, pooled_output: Tensor) -> Tensor:
+        return self.seq_relationship(pooled_output)
+
+
+class ITMHead(nn.Module):
+    def __init__(self, hidden_size: int = 768):
+        super().__init__()
+        self.pooler = Pooler(hidden_size=hidden_size)
+        self.cls = TwoWayHead(hidden_size=hidden_size)
+
+    def forward(self, hidden_states: Tensor) -> Tensor:
+        pooled_output = self.pooler(hidden_states)
+        logits = self.cls(pooled_output)
+        return logits
+
+
 class FLAVAForPreTraining(nn.Module, PretrainedMixin):
     # TODOs:
     # 1. Expose logit scale
     # 2. For FLAVA model, allow interpolating the embeddings to
     # for patch embeddings
-    def __init__(self, model: FLAVAModel, image_codebook: nn.Module, loss: nn.Module):
+    def __init__(
+        self,
+        model: FLAVAModel,
+        image_codebook: nn.Module,
+        loss: nn.Module,
+        itm_head: nn.Module,
+    ):
         super().__init__()
         self.model = model
         self.image_codebook = image_codebook
         self.loss = loss
+        self.itm_head = itm_head

     def encode_image(
         self,
@@ -488,6 +517,10 @@ def forward(
             required_embedding=required_embedding,
             skip_unmasked_mm_encoder=skip_unmasked_mm_encoder,
         )
+        multimodal_masked_sequence = flava_output.multimodal_masked.last_hidden_state
+        itm_logits = None
+        if multimodal_masked_sequence is not None:
+            itm_logits = self.itm_head(multimodal_masked_sequence)

         return self.loss(
             image_sequence=flava_output.image.last_hidden_state,
@@ -503,6 +536,7 @@ def forward(
             mlm_labels=mlm_labels,
             projected_image_embeddings=flava_output.projected_image_embeddings,
             projected_text_embeddings=flava_output.projected_text_embeddings,
+            itm_logits=itm_logits,
         )

diff --git a/torchmultimodal/modules/losses/flava.py b/torchmultimodal/modules/losses/flava.py
index 68528ed2..61ab257f 100644
--- a/torchmultimodal/modules/losses/flava.py
+++ b/torchmultimodal/modules/losses/flava.py
@@ -90,16 +90,6 @@ def forward(self, hidden_states):
         return pooled_output


-class TwoWayHead(nn.Module):
-    def __init__(self, hidden_size: int = 768, **kwargs: Any):
-        super().__init__()
-
-        self.seq_relationship = nn.Linear(hidden_size, 2)
-
-    def forward(self, pooled_output):
-        return self.seq_relationship(pooled_output)
-
-
 class ITMLoss(nn.Module):
     def __init__(
         self,
@@ -108,21 +98,16 @@ def __init__(
         **kwargs: Any,
     ):
         super().__init__()
-        self.pooler = Pooler(hidden_size=hidden_size)
-        self.cls = TwoWayHead(hidden_size=hidden_size)
         self.ce_loss = nn.CrossEntropyLoss(ignore_index=ignore_index)

     def forward(
         self,
-        hidden_states: Tensor,
+        scores: Tensor,
         labels: Tensor,
     ):
         if self.training:
             assert_labels_are_present(labels, "itm labels")

-        pooled_output = self.pooler(hidden_states)
-        scores = self.cls(pooled_output)
-
         if labels is None:
-            loss = pooled_output.sum() * 0
+            loss = scores.sum() * 0
         else:
@@ -371,6 +356,7 @@ def forward(
         mlm_labels: Optional[Tensor] = None,
         projected_image_embeddings: Optional[Tensor] = None,
         projected_text_embeddings: Optional[Tensor] = None,
+        itm_logits: Optional[Tensor] = None,
     ) -> FLAVAPretrainingLossOutput:
         outputs = FLAVAPretrainingLossOutput()
         pos_mask = None
@@ -422,6 +408,7 @@ def forward(
             outputs.losses.mlm_loss = outputs.mlm_output.loss

         if multimodal_masked_sequence is not None and self.itm_loss_weight > 0:
+            assert itm_logits is not None
             if itm_labels is not None:
                 pos_pairs = itm_labels.ne(0)
                 pos_mask = torch.where(
@@ -432,7 +419,7 @@ def forward(
                     multimodal_masked_sequence.size(0),
                     device=multimodal_masked_sequence.device,
                 ).bool()
-            outputs.itm_output = self.itm_loss(multimodal_masked_sequence, itm_labels)
+            outputs.itm_output = self.itm_loss(itm_logits, itm_labels)
             outputs.itm_output.loss *= self.itm_loss_weight
             outputs.losses.itm_loss = outputs.itm_output.loss
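
Usage note: after this change the ITM head's parameters live in
FLAVAForPreTraining (and therefore in the pretraining checkpoint, which is why
the pretrained key now points at flava_for_pretraining_cl_itm.pt), while
ITMLoss reduces to a plain cross-entropy over precomputed logits. The sketch
below shows the refactored ITM path in isolation. It is a minimal illustration
rather than the library's API surface: the Pooler here is assumed to be a
BERT-style pooler (dense + tanh over the first token) and is re-implemented
inline only so the snippet runs without torchmultimodal installed.

    import torch
    from torch import Tensor, nn

    class Pooler(nn.Module):
        # Assumed stand-in for torchmultimodal's Pooler: BERT-style pooling
        # of the first (CLS-like) token through a dense layer and tanh.
        def __init__(self, hidden_size: int = 768):
            super().__init__()
            self.dense = nn.Linear(hidden_size, hidden_size)
            self.activation = nn.Tanh()

        def forward(self, hidden_states: Tensor) -> Tensor:
            return self.activation(self.dense(hidden_states[:, 0]))

    class ITMHead(nn.Module):
        # Mirrors the head added to flava_model.py: pool the multimodal
        # sequence, then classify matched vs. mismatched image-text pairs.
        def __init__(self, hidden_size: int = 768):
            super().__init__()
            self.pooler = Pooler(hidden_size=hidden_size)
            self.seq_relationship = nn.Linear(hidden_size, 2)

        def forward(self, hidden_states: Tensor) -> Tensor:
            return self.seq_relationship(self.pooler(hidden_states))

    if __name__ == "__main__":
        batch, seq_len, hidden_size = 4, 16, 768
        head = ITMHead(hidden_size)

        # Stand-in for flava_output.multimodal_masked.last_hidden_state.
        multimodal_masked_sequence = torch.randn(batch, seq_len, hidden_size)

        # The model now computes the logits ...
        itm_logits = head(multimodal_masked_sequence)  # shape: (batch, 2)

        # ... and the loss only consumes them (1 = matched pair, 0 = mismatched).
        itm_labels = torch.randint(0, 2, (batch,))
        itm_loss = nn.CrossEntropyLoss()(itm_logits, itm_labels)
        print(itm_logits.shape, float(itm_loss))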