diff --git a/torchmultimodal/models/flava/model.py b/torchmultimodal/models/flava/model.py index e2431c8ea..10f955ddd 100644 --- a/torchmultimodal/models/flava/model.py +++ b/torchmultimodal/models/flava/model.py @@ -62,7 +62,8 @@ FLAVA_FOR_PRETRAINED_MAPPING = { # This will no longer load with the updated model, but keeping here just in case # "flava_full": "https://huggingface.co/aps/flava_full_pretrained_encoders_torchmm/resolve/main/pytorch_model.bin", - "flava_full": "https://download.pytorch.org/models/multimodal/flava/flava_for_pretraining_unified_itm.pt", + + "flava_full": "https://download.pytorch.org/models/multimodal/flava/flava_for_pretraining_unified_itm_mp.pt", } FLAVA_MODEL_MAPPING = { @@ -314,6 +315,45 @@ def forward(self, hidden_states: Tensor): return logits +class MaskedPredictionHead(nn.Module): + def __init__( + self, + hidden_size: int = 768, + vocab_size: int = 30522, + transform_act_fn: Callable[[Tensor], Tensor] = nn.functional.gelu, + layer_norm_eps: float = 1e-5, + use_fp32_layer_norm: bool = True, + **kwargs: Any, + ): + super().__init__() + + self.dense = nn.Linear(hidden_size, hidden_size) + self.transform_act_fn = transform_act_fn + + self.layer_norm: nn.LayerNorm + if use_fp32_layer_norm: + self.layer_norm = Fp32LayerNorm(hidden_size, eps=layer_norm_eps) + else: + self.layer_norm = nn.LayerNorm(hidden_size, eps=layer_norm_eps) + + # The output weights are the same as the input embeddings, but there is + # an output-only bias for each token. + self.decoder = nn.Linear(hidden_size, vocab_size, bias=False) + + self.bias = nn.Parameter(torch.zeros(vocab_size)) + + # Need a link between the two variables so that the bias is + # correctly resized with `resize_token_embeddings` + self.decoder.bias = self.bias + + def forward(self, hidden_states: Tensor) -> Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.transform_act_fn(hidden_states) + hidden_states = self.layer_norm(hidden_states) + hidden_states = self.decoder(hidden_states) + return hidden_states + + class FLAVAForPreTraining(nn.Module, PretrainedMixin): # TODOs: # 1. Expose logit scale @@ -325,12 +365,20 @@ def __init__( image_codebook: nn.Module, loss: nn.Module, itm_head: nn.Module, + mlm_head: nn.Module, + mim_head: nn.Module, + mmm_mlm_head: nn.Module, + mmm_mim_head: nn.Module, ): super().__init__() self.model = model self.image_codebook = image_codebook self.loss = loss self.itm_head = itm_head + self.mlm_head = mlm_head + self.mim_head = mim_head + self.mmm_mlm_head = mmm_mlm_head + self.mmm_mim_head = mmm_mim_head def encode_image( self, @@ -383,6 +431,25 @@ def forward( if multimodal_masked_sequence is not None: itm_logits = self.itm_head(multimodal_masked_sequence) + image_masked_sequence = flava_output.image_masked.last_hidden_state + text_masked_sequence = flava_output.text_masked.last_hidden_state + mlm_head_output = ( + mim_head_output + ) = mmm_mlm_head_output = mmm_mim_head_output = None + + if image_masked_sequence is not None and multimodal_masked_sequence is None: + mim_head_output = self.mim_head(image_masked_sequence) + if text_masked_sequence is not None and multimodal_masked_sequence is None: + mlm_head_output = self.mlm_head(text_masked_sequence) + if multimodal_masked_sequence is not None: + start_index = -(text_masked_sequence.size(1)) + mmm_text_sequence = multimodal_masked_sequence[:, start_index:, :] + mmm_mlm_head_output = self.mmm_mlm_head(mmm_text_sequence) + if multimodal_masked_sequence is not None: + total_indices = image_masked_sequence.size(1) - 1 + mmm_image_sequence = multimodal_masked_sequence[:, 2 : 2 + total_indices, :] + mmm_mim_head_output = self.mmm_mim_head(mmm_image_sequence) + return self.loss( image_sequence=flava_output.image.last_hidden_state, text_sequence=flava_output.text.last_hidden_state, @@ -398,6 +465,10 @@ def forward( projected_image_embeddings=flava_output.projected_image_embeddings, projected_text_embeddings=flava_output.projected_text_embeddings, itm_logits=itm_logits, + mlm_head_output=mlm_head_output, + mim_head_output=mim_head_output, + mmm_mlm_head_output=mmm_mlm_head_output, + mmm_mim_head_output=mmm_mim_head_output, ) @@ -548,17 +619,36 @@ def flava_model( def flava_model_for_pretraining( codebook_image_size: int = 112, pretrained_model_key: Optional[str] = None, + image_vocab_size: int = 8192, **flava_model_kwargs: Any, # TODO: Add parameters for loss here ) -> FLAVAForPreTraining: model = flava_model(**flava_model_kwargs) hidden_size = flava_model_kwargs.get("hidden_size") or 768 + text_vocab_size = flava_model_kwargs.get("vocab_size") or 30522 itm_head = ITMHead(hidden_size) + mlm_head = MaskedPredictionHead(hidden_size=hidden_size, vocab_size=text_vocab_size) + mim_head = MaskedPredictionHead( + hidden_size=hidden_size, vocab_size=image_vocab_size + ) + mmm_mlm_head = MaskedPredictionHead( + hidden_size=hidden_size, vocab_size=text_vocab_size + ) + mmm_mim_head = MaskedPredictionHead( + hidden_size=hidden_size, vocab_size=image_vocab_size + ) losses = FLAVAPretrainingLoss() codebook = DalleVAEEncoder(image_size=codebook_image_size) flava = FLAVAForPreTraining( - model=model, image_codebook=codebook, loss=losses, itm_head=itm_head + model=model, + image_codebook=codebook, + loss=losses, + itm_head=itm_head, + mlm_head=mlm_head, + mim_head=mim_head, + mmm_mlm_head=mmm_mlm_head, + mmm_mim_head=mmm_mim_head, ) if pretrained_model_key is not None: diff --git a/torchmultimodal/modules/losses/flava.py b/torchmultimodal/modules/losses/flava.py index cca58a0ba..565956a06 100644 --- a/torchmultimodal/modules/losses/flava.py +++ b/torchmultimodal/modules/losses/flava.py @@ -11,7 +11,6 @@ import torch from torch import nn, Tensor -from torchmultimodal.modules.layers.normalizations import Fp32LayerNorm from torchmultimodal.modules.losses.contrastive_loss_with_temperature import ( contrastive_loss_with_temperature, ContrastiveLossOutput, @@ -130,45 +129,6 @@ def forward( return ITMLossOutput(logits=scores, loss=loss) -class MaskedPredictionHead(nn.Module): - def __init__( - self, - hidden_size: int = 768, - vocab_size: int = 30522, - transform_act_fn: Callable[[Tensor], Tensor] = nn.functional.gelu, - layer_norm_eps: float = 1e-5, - use_fp32_layer_norm: bool = True, - **kwargs: Any, - ): - super().__init__() - - self.dense = nn.Linear(hidden_size, hidden_size) - self.transform_act_fn = transform_act_fn - - self.layer_norm: nn.LayerNorm - if use_fp32_layer_norm: - self.layer_norm = Fp32LayerNorm(hidden_size, eps=layer_norm_eps) - else: - self.layer_norm = nn.LayerNorm(hidden_size, eps=layer_norm_eps) - - # The output weights are the same as the input embeddings, but there is - # an output-only bias for each token. - self.decoder = nn.Linear(hidden_size, vocab_size, bias=False) - - self.bias = nn.Parameter(torch.zeros(vocab_size)) - - # Need a link between the two variables so that the bias is - # correctly resized with `resize_token_embeddings` - self.decoder.bias = self.bias - - def forward(self, hidden_states: Tensor) -> Tensor: - hidden_states = self.dense(hidden_states) - hidden_states = self.transform_act_fn(hidden_states) - hidden_states = self.layer_norm(hidden_states) - hidden_states = self.decoder(hidden_states) - return hidden_states - - class MaskedPredictionLoss(nn.Module): def __init__( self, @@ -181,36 +141,29 @@ def __init__( **kwargs: Any, ): super().__init__() - - self.cls = MaskedPredictionHead( - hidden_size=hidden_size, - vocab_size=vocab_size, - transform_act_fn=transform_act_fn, - layer_norm_eps=layer_norm_eps, - ) self.ignore_index = ignore_index self.vocab_size = vocab_size self.ce_loss = nn.CrossEntropyLoss(ignore_index=ignore_index) self.ignore_nan = ignore_nan def forward( - self, hidden_states: Tensor, masked_labels: Optional[Tensor] = None + self, + prediction: Tensor, + masked_labels: Optional[Tensor] = None, + pos_mask: Optional[Tensor] = None, ) -> MaskedPredictionLossOutput: - if self.training: - assert_labels_are_present(masked_labels, "masked labels") - - if masked_labels is not None: - masked_tokens = masked_labels.ne(self.ignore_index) - masked_labels = masked_labels[masked_tokens] - sequence_output = hidden_states[masked_tokens, :] - else: - sequence_output = hidden_states - prediction = self.cls(sequence_output) + if pos_mask is not None: + masked_labels = masked_labels[pos_mask] + masked_tokens = masked_labels.ne(self.ignore_index) + masked_labels = masked_labels[masked_tokens] if masked_labels is None: masked_loss = prediction.sum() * 0 else: + if pos_mask is not None: + prediction = prediction[pos_mask] + prediction = prediction[masked_tokens, :] masked_loss = self.ce_loss( prediction.view(-1, self.vocab_size), masked_labels.view(-1), @@ -371,6 +324,10 @@ def forward( projected_image_embeddings: Optional[Tensor] = None, projected_text_embeddings: Optional[Tensor] = None, itm_logits: Optional[Tensor] = None, + mlm_head_output: Optional[Tensor] = None, + mim_head_output: Optional[Tensor] = None, + mmm_mlm_head_output: Optional[Tensor] = None, + mmm_mim_head_output: Optional[Tensor] = None, ) -> FLAVAPretrainingLossOutput: outputs = FLAVAPretrainingLossOutput() pos_mask = None @@ -380,28 +337,28 @@ def forward( # text, but that is a research question :) if ( - image_masked_sequence is not None + mim_head_output is not None and self.mim_weight > 0 and multimodal_masked_sequence is None ): # Remove CLS token from image_masked_sequence - start_index = -mim_labels.size(1) if mim_labels is not None else 1 outputs.mim_output = self.mim_loss( - image_masked_sequence[:, start_index:, :], mim_labels + mim_head_output[:, start_index:, :], mim_labels ) outputs.mim_output.loss *= self.mim_weight outputs.losses.mim_loss = outputs.mim_output.loss # Check multimodal_masked_sequence to make sure this is unimodal case + if ( - text_masked_sequence is not None + mlm_head_output is not None and self.mlm_weight > 0 and multimodal_masked_sequence is None ): start_index = -mlm_labels.size(1) if mlm_labels is not None else 1 outputs.mlm_output = self.mlm_loss( - text_masked_sequence[:, start_index:, :], mlm_labels + mlm_head_output[:, start_index:, :], mlm_labels ) outputs.mlm_output.loss *= self.mlm_weight outputs.losses.mlm_loss = outputs.mlm_output.loss @@ -422,27 +379,14 @@ def forward( outputs.itm_output.loss *= self.itm_loss_weight outputs.losses.itm_loss = outputs.itm_output.loss - multimodal_masked_sequence = multimodal_masked_sequence[pos_mask] - if mlm_labels is not None: - mlm_labels = mlm_labels[pos_mask] - if mim_labels is not None: - mim_labels = mim_labels[pos_mask] - - if multimodal_masked_sequence is not None and self.mmm_text_loss_weight > 0: - start_index = ( - -mlm_labels.size(1) - if mlm_labels is not None - else -(text_masked_sequence.size(1) - 1) - ) - sequence_for_text = multimodal_masked_sequence[:, start_index:, :] + if mmm_mlm_head_output is not None and self.mmm_text_loss_weight > 0: outputs.mmm_text_output = self.mmm_loss.mlm( - sequence_for_text, - mlm_labels, + mmm_mlm_head_output, mlm_labels, pos_mask ) # type: ignore outputs.mmm_text_output.loss *= self.mmm_text_loss_weight outputs.losses.mmm_text_loss = outputs.mmm_text_output.loss - if multimodal_masked_sequence is not None and self.mmm_image_loss_weight > 0: + if mmm_mim_head_output is not None and self.mmm_image_loss_weight > 0: # Starts from 2 because of 2 CLS, one for multimodal encoder and one # that comes from image encoder. total_indices = ( @@ -450,10 +394,8 @@ def forward( if mlm_labels is not None else (image_masked_sequence.size(1) - 1) ) - sequence_for_image = multimodal_masked_sequence[:, 2 : 2 + total_indices, :] outputs.mmm_image_output = self.mmm_loss.mim( - sequence_for_image, - mim_labels, + mmm_mim_head_output, mim_labels, pos_mask ) # type: ignore outputs.mmm_image_output.loss *= self.mmm_image_loss_weight outputs.losses.mmm_image_loss = outputs.mmm_image_output.loss