[FLAVA] Move masked prediction head to flava_for_pretraining
ghstack-source-id: 0ca76d743540249a472057b64d4f4a85d6c943e3
Pull Request resolved: #195
ankitade committed Aug 20, 2022
1 parent 4f995c6 commit c973d20
Showing 2 changed files with 116 additions and 84 deletions.
94 changes: 92 additions & 2 deletions torchmultimodal/models/flava/model.py
@@ -62,7 +62,8 @@
FLAVA_FOR_PRETRAINED_MAPPING = {
# This will no longer load with the updated model, but keeping here just in case
# "flava_full": "https://huggingface.co/aps/flava_full_pretrained_encoders_torchmm/resolve/main/pytorch_model.bin",
"flava_full": "https://download.pytorch.org/models/multimodal/flava/flava_for_pretraining_unified_itm.pt",

"flava_full": "https://download.pytorch.org/models/multimodal/flava/flava_for_pretraining_unified_itm_mp.pt",
}

FLAVA_MODEL_MAPPING = {
@@ -314,6 +315,45 @@ def forward(self, hidden_states: Tensor):
return logits


class MaskedPredictionHead(nn.Module):
def __init__(
self,
hidden_size: int = 768,
vocab_size: int = 30522,
transform_act_fn: Callable[[Tensor], Tensor] = nn.functional.gelu,
layer_norm_eps: float = 1e-5,
use_fp32_layer_norm: bool = True,
**kwargs: Any,
):
super().__init__()

self.dense = nn.Linear(hidden_size, hidden_size)
self.transform_act_fn = transform_act_fn

self.layer_norm: nn.LayerNorm
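# Fp32LayerNorm (from torchmultimodal's normalization layers) runs the
# normalization in fp32 and casts back to the input dtype, which keeps
# LayerNorm numerically stable under mixed-precision training.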
if use_fp32_layer_norm:
self.layer_norm = Fp32LayerNorm(hidden_size, eps=layer_norm_eps)
else:
self.layer_norm = nn.LayerNorm(hidden_size, eps=layer_norm_eps)

# The output weights are the same as the input embeddings, but there is
# an output-only bias for each token.
self.decoder = nn.Linear(hidden_size, vocab_size, bias=False)

self.bias = nn.Parameter(torch.zeros(vocab_size))

# Need a link between the two variables so that the bias is
# correctly resized with `resize_token_embeddings`
self.decoder.bias = self.bias

def forward(self, hidden_states: Tensor) -> Tensor:
hidden_states = self.dense(hidden_states)
hidden_states = self.transform_act_fn(hidden_states)
hidden_states = self.layer_norm(hidden_states)
hidden_states = self.decoder(hidden_states)
return hidden_states
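
# A minimal usage sketch of MaskedPredictionHead, assuming toy shapes:
#
#     head = MaskedPredictionHead(hidden_size=768, vocab_size=30522)
#     hidden_states = torch.randn(2, 16, 768)   # (batch, seq_len, hidden_size)
#     logits = head(hidden_states)              # (2, 16, 30522) token logits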


class FLAVAForPreTraining(nn.Module, PretrainedMixin):
# TODOs:
# 1. Expose logit scale
@@ -325,12 +365,20 @@ def __init__(
image_codebook: nn.Module,
loss: nn.Module,
itm_head: nn.Module,
mlm_head: nn.Module,
mim_head: nn.Module,
mmm_mlm_head: nn.Module,
mmm_mim_head: nn.Module,
):
super().__init__()
self.model = model
self.image_codebook = image_codebook
self.loss = loss
self.itm_head = itm_head
self.mlm_head = mlm_head
self.mim_head = mim_head
self.mmm_mlm_head = mmm_mlm_head
self.mmm_mim_head = mmm_mim_head

def encode_image(
self,
@@ -383,6 +431,25 @@ def forward(
if multimodal_masked_sequence is not None:
itm_logits = self.itm_head(multimodal_masked_sequence)

image_masked_sequence = flava_output.image_masked.last_hidden_state
text_masked_sequence = flava_output.text_masked.last_hidden_state
mlm_head_output = (
mim_head_output
) = mmm_mlm_head_output = mmm_mim_head_output = None

if image_masked_sequence is not None and multimodal_masked_sequence is None:
mim_head_output = self.mim_head(image_masked_sequence)
if text_masked_sequence is not None and multimodal_masked_sequence is None:
mlm_head_output = self.mlm_head(text_masked_sequence)
if multimodal_masked_sequence is not None:
start_index = -(text_masked_sequence.size(1))
mmm_text_sequence = multimodal_masked_sequence[:, start_index:, :]
mmm_mlm_head_output = self.mmm_mlm_head(mmm_text_sequence)
if multimodal_masked_sequence is not None:
total_indices = image_masked_sequence.size(1) - 1
mmm_image_sequence = multimodal_masked_sequence[:, 2 : 2 + total_indices, :]
mmm_mim_head_output = self.mmm_mim_head(mmm_image_sequence)

return self.loss(
image_sequence=flava_output.image.last_hidden_state,
text_sequence=flava_output.text.last_hidden_state,
@@ -398,6 +465,10 @@ def forward(
projected_image_embeddings=flava_output.projected_image_embeddings,
projected_text_embeddings=flava_output.projected_text_embeddings,
itm_logits=itm_logits,
mlm_head_output=mlm_head_output,
mim_head_output=mim_head_output,
mmm_mlm_head_output=mmm_mlm_head_output,
mmm_mim_head_output=mmm_mim_head_output,
)
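
# The slicing above assumes the multimodal masked sequence is laid out as
# (lengths illustrative):
#
#     [multimodal CLS | image CLS | image patch tokens ... | text tokens ...]
#
# so the text block is the last text_masked_sequence.size(1) positions, and
# the image block is positions 2 : 2 + (image_masked_sequence.size(1) - 1),
# skipping the multimodal encoder's CLS and the image encoder's CLS.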


@@ -548,17 +619,36 @@ def flava_model(
def flava_model_for_pretraining(
codebook_image_size: int = 112,
pretrained_model_key: Optional[str] = None,
image_vocab_size: int = 8192,
**flava_model_kwargs: Any,
# TODO: Add parameters for loss here
) -> FLAVAForPreTraining:
model = flava_model(**flava_model_kwargs)
hidden_size = flava_model_kwargs.get("hidden_size") or 768
text_vocab_size = flava_model_kwargs.get("vocab_size") or 30522
itm_head = ITMHead(hidden_size)
mlm_head = MaskedPredictionHead(hidden_size=hidden_size, vocab_size=text_vocab_size)
mim_head = MaskedPredictionHead(
hidden_size=hidden_size, vocab_size=image_vocab_size
)
mmm_mlm_head = MaskedPredictionHead(
hidden_size=hidden_size, vocab_size=text_vocab_size
)
mmm_mim_head = MaskedPredictionHead(
hidden_size=hidden_size, vocab_size=image_vocab_size
)
losses = FLAVAPretrainingLoss()
codebook = DalleVAEEncoder(image_size=codebook_image_size)

flava = FLAVAForPreTraining(
model=model, image_codebook=codebook, loss=losses, itm_head=itm_head
model=model,
image_codebook=codebook,
loss=losses,
itm_head=itm_head,
mlm_head=mlm_head,
mim_head=mim_head,
mmm_mlm_head=mmm_mlm_head,
mmm_mim_head=mmm_mim_head,
)

if pretrained_model_key is not None:
...
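A minimal sketch of building the updated pretraining model (the import path follows the file above; "flava_full" is the key registered in FLAVA_FOR_PRETRAINED_MAPPING):

from torchmultimodal.models.flava.model import flava_model_for_pretraining

# Fresh model: default text vocab (30522) and image codebook vocab (8192);
# the four masked-prediction heads are constructed and attached here.
model = flava_model_for_pretraining()

# Or start from the re-exported checkpoint that matches the new head layout.
model = flava_model_for_pretraining(pretrained_model_key="flava_full")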
106 changes: 24 additions & 82 deletions torchmultimodal/modules/losses/flava.py
@@ -11,7 +11,6 @@

import torch
from torch import nn, Tensor
from torchmultimodal.modules.layers.normalizations import Fp32LayerNorm
from torchmultimodal.modules.losses.contrastive_loss_with_temperature import (
contrastive_loss_with_temperature,
ContrastiveLossOutput,
@@ -130,45 +129,6 @@ def forward(
return ITMLossOutput(logits=scores, loss=loss)


class MaskedPredictionHead(nn.Module):
def __init__(
self,
hidden_size: int = 768,
vocab_size: int = 30522,
transform_act_fn: Callable[[Tensor], Tensor] = nn.functional.gelu,
layer_norm_eps: float = 1e-5,
use_fp32_layer_norm: bool = True,
**kwargs: Any,
):
super().__init__()

self.dense = nn.Linear(hidden_size, hidden_size)
self.transform_act_fn = transform_act_fn

self.layer_norm: nn.LayerNorm
if use_fp32_layer_norm:
self.layer_norm = Fp32LayerNorm(hidden_size, eps=layer_norm_eps)
else:
self.layer_norm = nn.LayerNorm(hidden_size, eps=layer_norm_eps)

# The output weights are the same as the input embeddings, but there is
# an output-only bias for each token.
self.decoder = nn.Linear(hidden_size, vocab_size, bias=False)

self.bias = nn.Parameter(torch.zeros(vocab_size))

# Need a link between the two variables so that the bias is
# correctly resized with `resize_token_embeddings`
self.decoder.bias = self.bias

def forward(self, hidden_states: Tensor) -> Tensor:
hidden_states = self.dense(hidden_states)
hidden_states = self.transform_act_fn(hidden_states)
hidden_states = self.layer_norm(hidden_states)
hidden_states = self.decoder(hidden_states)
return hidden_states


class MaskedPredictionLoss(nn.Module):
def __init__(
self,
@@ -181,36 +141,29 @@ def __init__(
**kwargs: Any,
):
super().__init__()

self.cls = MaskedPredictionHead(
hidden_size=hidden_size,
vocab_size=vocab_size,
transform_act_fn=transform_act_fn,
layer_norm_eps=layer_norm_eps,
)
self.ignore_index = ignore_index
self.vocab_size = vocab_size
self.ce_loss = nn.CrossEntropyLoss(ignore_index=ignore_index)
self.ignore_nan = ignore_nan

def forward(
self, hidden_states: Tensor, masked_labels: Optional[Tensor] = None
self,
prediction: Tensor,
masked_labels: Optional[Tensor] = None,
pos_mask: Optional[Tensor] = None,
) -> MaskedPredictionLossOutput:
if self.training:
assert_labels_are_present(masked_labels, "masked labels")

if masked_labels is not None:
masked_tokens = masked_labels.ne(self.ignore_index)
masked_labels = masked_labels[masked_tokens]
sequence_output = hidden_states[masked_tokens, :]
else:
sequence_output = hidden_states

prediction = self.cls(sequence_output)
if pos_mask is not None:
masked_labels = masked_labels[pos_mask]
masked_tokens = masked_labels.ne(self.ignore_index)
masked_labels = masked_labels[masked_tokens]

if masked_labels is None:
masked_loss = prediction.sum() * 0
else:
if pos_mask is not None:
prediction = prediction[pos_mask]
prediction = prediction[masked_tokens, :]
masked_loss = self.ce_loss(
prediction.view(-1, self.vocab_size),
masked_labels.view(-1),
@@ -371,6 +324,10 @@ def forward(
projected_image_embeddings: Optional[Tensor] = None,
projected_text_embeddings: Optional[Tensor] = None,
itm_logits: Optional[Tensor] = None,
mlm_head_output: Optional[Tensor] = None,
mim_head_output: Optional[Tensor] = None,
mmm_mlm_head_output: Optional[Tensor] = None,
mmm_mim_head_output: Optional[Tensor] = None,
) -> FLAVAPretrainingLossOutput:
outputs = FLAVAPretrainingLossOutput()
pos_mask = None
@@ -380,28 +337,28 @@
# text, but that is a research question :)

if (
image_masked_sequence is not None
mim_head_output is not None
and self.mim_weight > 0
and multimodal_masked_sequence is None
):
# Remove CLS token from image_masked_sequence

start_index = -mim_labels.size(1) if mim_labels is not None else 1
outputs.mim_output = self.mim_loss(
image_masked_sequence[:, start_index:, :], mim_labels
mim_head_output[:, start_index:, :], mim_labels
)
outputs.mim_output.loss *= self.mim_weight
outputs.losses.mim_loss = outputs.mim_output.loss

# Check multimodal_masked_sequence to make sure this is unimodal case

if (
text_masked_sequence is not None
mlm_head_output is not None
and self.mlm_weight > 0
and multimodal_masked_sequence is None
):
start_index = -mlm_labels.size(1) if mlm_labels is not None else 1
outputs.mlm_output = self.mlm_loss(
text_masked_sequence[:, start_index:, :], mlm_labels
mlm_head_output[:, start_index:, :], mlm_labels
)
outputs.mlm_output.loss *= self.mlm_weight
outputs.losses.mlm_loss = outputs.mlm_output.loss
@@ -422,38 +379,23 @@
outputs.itm_output.loss *= self.itm_loss_weight
outputs.losses.itm_loss = outputs.itm_output.loss

multimodal_masked_sequence = multimodal_masked_sequence[pos_mask]
if mlm_labels is not None:
mlm_labels = mlm_labels[pos_mask]
if mim_labels is not None:
mim_labels = mim_labels[pos_mask]

if multimodal_masked_sequence is not None and self.mmm_text_loss_weight > 0:
start_index = (
-mlm_labels.size(1)
if mlm_labels is not None
else -(text_masked_sequence.size(1) - 1)
)
sequence_for_text = multimodal_masked_sequence[:, start_index:, :]
if mmm_mlm_head_output is not None and self.mmm_text_loss_weight > 0:
outputs.mmm_text_output = self.mmm_loss.mlm(
sequence_for_text,
mlm_labels,
mmm_mlm_head_output, mlm_labels, pos_mask
) # type: ignore
outputs.mmm_text_output.loss *= self.mmm_text_loss_weight
outputs.losses.mmm_text_loss = outputs.mmm_text_output.loss

if multimodal_masked_sequence is not None and self.mmm_image_loss_weight > 0:
if mmm_mim_head_output is not None and self.mmm_image_loss_weight > 0:
# Starts from 2 because of 2 CLS, one for multimodal encoder and one
# that comes from image encoder.
total_indices = (
mim_labels.size(1)
if mlm_labels is not None
else (image_masked_sequence.size(1) - 1)
)
sequence_for_image = multimodal_masked_sequence[:, 2 : 2 + total_indices, :]
outputs.mmm_image_output = self.mmm_loss.mim(
sequence_for_image,
mim_labels,
mmm_mim_head_output, mim_labels, pos_mask
) # type: ignore
outputs.mmm_image_output.loss *= self.mmm_image_loss_weight
outputs.losses.mmm_image_loss = outputs.mmm_image_output.loss
...
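A small sketch of the loss's new contract, with assumed toy values: the loss now receives head logits (plus an optional pos_mask selecting ITM-positive pairs) instead of raw hidden states, since the prediction heads live in FLAVAForPreTraining:

import torch
from torchmultimodal.modules.losses.flava import MaskedPredictionLoss

loss_fn = MaskedPredictionLoss(vocab_size=30522, ignore_index=-1)
logits = torch.randn(2, 16, 30522)   # head output: (batch, seq_len, vocab)
labels = torch.full((2, 16), -1)     # ignore_index everywhere ...
labels[:, 3] = 42                    # ... except one masked position per row
output = loss_fn(prediction=logits, masked_labels=labels)

# Restrict the loss to ITM-positive pairs via pos_mask:
pos_mask = torch.tensor([True, False])
output = loss_fn(prediction=logits, masked_labels=labels, pos_mask=pos_mask)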
