Skip to content

Commit

Permalink
[FLAVA] Move masked prediction head to flava_for_pretraining
Browse files Browse the repository at this point in the history
ghstack-source-id: f1ba0bfa02fedc5ef86c9c6d824e4ba98f47c6e9
Pull Request resolved: #195
  • Loading branch information
ankitade committed Aug 22, 2022
1 parent 6c4880a commit 06d3436
Show file tree
Hide file tree
Showing 2 changed files with 161 additions and 148 deletions.
148 changes: 139 additions & 9 deletions torchmultimodal/models/flava/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@
FLAVA_FOR_PRETRAINED_MAPPING = {
# This will no longer load with the updated model, but keeping here just in case
# "flava_full": "https://huggingface.co/aps/flava_full_pretrained_encoders_torchmm/resolve/main/pytorch_model.bin",
"flava_full": "https://download.pytorch.org/models/multimodal/flava/flava_for_pretraining_unified_itm.pt",
"flava_full": "https://download.pytorch.org/models/multimodal/flava/flava_for_pretraining_unified_itm_mp.pt",
}

FLAVA_MODEL_MAPPING = {
Expand Down Expand Up @@ -314,6 +314,50 @@ def forward(self, hidden_states: Tensor):
return logits


class MaskedPredictionHead(nn.Module):
def __init__(
self,
hidden_size: int = 768,
vocab_size: int = 30522,
transform_act_fn: Callable[[Tensor], Tensor] = nn.functional.gelu,
layer_norm_eps: float = 1e-5,
use_fp32_layer_norm: bool = True,
ignore_index: int = -1,
**kwargs: Any,
):
super().__init__()

self.dense = nn.Linear(hidden_size, hidden_size)
self.transform_act_fn = transform_act_fn

self.layer_norm: nn.LayerNorm
if use_fp32_layer_norm:
self.layer_norm = Fp32LayerNorm(hidden_size, eps=layer_norm_eps)
else:
self.layer_norm = nn.LayerNorm(hidden_size, eps=layer_norm_eps)

# The output weights are the same as the input embeddings, but there is
# an output-only bias for each token.
self.decoder = nn.Linear(hidden_size, vocab_size, bias=False)

self.bias = nn.Parameter(torch.zeros(vocab_size))

# Need a link between the two variables so that the bias is
# correctly resized with `resize_token_embeddings`
self.decoder.bias = self.bias
self.ignore_index = ignore_index

def forward(self, hidden_states: Tensor, masked_labels: Tensor) -> Tensor:
masked_tokens = masked_labels.ne(self.ignore_index)
sequence_output = hidden_states[masked_tokens, :]

head_output = self.dense(sequence_output)
head_output = self.transform_act_fn(head_output)
head_output = self.layer_norm(head_output)
head_output = self.decoder(head_output)
return head_output


class FLAVAForPreTraining(nn.Module, PretrainedMixin):
# TODOs:
# 1. Expose logit scale
Expand All @@ -325,12 +369,20 @@ def __init__(
image_codebook: nn.Module,
loss: nn.Module,
itm_head: nn.Module,
mlm_head: nn.Module,
mim_head: nn.Module,
mmm_mlm_head: nn.Module,
mmm_mim_head: nn.Module,
):
super().__init__()
self.model = model
self.image_codebook = image_codebook
self.loss = loss
self.itm_head = itm_head
self.mlm_head = mlm_head
self.mim_head = mim_head
self.mmm_mlm_head = mmm_mlm_head
self.mmm_mim_head = mmm_mim_head

def encode_image(
self,
Expand Down Expand Up @@ -380,24 +432,83 @@ def forward(
)
multimodal_masked_sequence = flava_output.multimodal_masked.last_hidden_state
itm_logits = None

image_masked_sequence = flava_output.image_masked.last_hidden_state
text_masked_sequence = flava_output.text_masked.last_hidden_state
mlm_head_output = (
mim_head_output
) = mmm_mlm_head_output = mmm_mim_head_output = None
pos_mask = None
if image_masked_sequence is not None and multimodal_masked_sequence is None:
# Remove CLS token from image_masked_sequence
start_index = -image_labels.size(1) if image_labels is not None else 1
mim_head_output = self.mim_head(
image_masked_sequence[:, start_index:, :], image_labels
)

if text_masked_sequence is not None and multimodal_masked_sequence is None:
start_index = -mlm_labels.size(1) if mlm_labels is not None else 1
mlm_head_output = self.mlm_head(
text_masked_sequence[:, start_index:, :], mlm_labels
)

mmm_mlm_labels = mlm_labels
mmm_mim_labels = image_labels

if multimodal_masked_sequence is not None:
if itm_labels is not None:
pos_pairs = itm_labels.ne(0)
pos_mask = torch.where(
pos_pairs.any(), pos_pairs, pos_pairs.new([True])
)
else:
pos_mask = torch.ones(
multimodal_masked_sequence.size(0),
device=multimodal_masked_sequence.device,
).bool()
itm_logits = self.itm_head(multimodal_masked_sequence)

multimodal_masked_sequence = multimodal_masked_sequence[pos_mask]
if mlm_labels is not None:
mmm_mlm_labels = mlm_labels[pos_mask]
if image_labels is not None:
mmm_mim_labels = image_labels[pos_mask]

if multimodal_masked_sequence is not None:
start_index = (
-mmm_mlm_labels.size(1)
if mmm_mlm_labels is not None
else -(text_masked_sequence.size(1) - 1)
)
sequence_for_text = multimodal_masked_sequence[:, start_index:, :]
mmm_mlm_head_output = self.mmm_mlm_head(sequence_for_text, mmm_mlm_labels)

if multimodal_masked_sequence is not None:
# Starts from 2 because of 2 CLS, one for multimodal encoder and one
# that comes from image encoder.
total_indices = (
mmm_mim_labels.size(1)
if mmm_mim_labels is not None
else (image_masked_sequence.size(1) - 1)
)
sequence_for_image = multimodal_masked_sequence[:, 2 : 2 + total_indices, :]
mmm_mim_head_output = self.mmm_mim_head(sequence_for_image, mmm_mim_labels)

return self.loss(
image_sequence=flava_output.image.last_hidden_state,
text_sequence=flava_output.text.last_hidden_state,
image_masked_sequence=flava_output.image_masked.last_hidden_state,
text_masked_sequence=flava_output.text_masked.last_hidden_state,
multimodal_sequence=flava_output.multimodal.last_hidden_state
if not skip_unmasked_mm_encoder
else None,
multimodal_masked_sequence=flava_output.multimodal_masked.last_hidden_state,
pos_mask=pos_mask,
itm_labels=itm_labels,
mim_labels=image_labels,
mlm_labels=mlm_labels,
mmm_mlm_labels=mmm_mlm_labels,
mmm_mim_labels=mmm_mim_labels,
projected_image_embeddings=flava_output.projected_image_embeddings,
projected_text_embeddings=flava_output.projected_text_embeddings,
itm_logits=itm_logits,
mlm_head_output=mlm_head_output,
mim_head_output=mim_head_output,
mmm_mlm_head_output=mmm_mlm_head_output,
mmm_mim_head_output=mmm_mim_head_output,
)


Expand Down Expand Up @@ -548,17 +659,36 @@ def flava_model(
def flava_model_for_pretraining(
codebook_image_size: int = 112,
pretrained_model_key: Optional[str] = None,
image_vocab_size: int = 8192,
**flava_model_kwargs: Any,
# TODO: Add parameters for loss here
) -> FLAVAForPreTraining:
model = flava_model(**flava_model_kwargs)
hidden_size = flava_model_kwargs.get("hidden_size") or 768
text_vocab_size = flava_model_kwargs.get("vocab_size") or 30522
itm_head = ITMHead(hidden_size)
mlm_head = MaskedPredictionHead(hidden_size=hidden_size, vocab_size=text_vocab_size)
mim_head = MaskedPredictionHead(
hidden_size=hidden_size, vocab_size=image_vocab_size
)
mmm_mlm_head = MaskedPredictionHead(
hidden_size=hidden_size, vocab_size=text_vocab_size
)
mmm_mim_head = MaskedPredictionHead(
hidden_size=hidden_size, vocab_size=image_vocab_size
)
losses = FLAVAPretrainingLoss()
codebook = DalleVAEEncoder(image_size=codebook_image_size)

flava = FLAVAForPreTraining(
model=model, image_codebook=codebook, loss=losses, itm_head=itm_head
model=model,
image_codebook=codebook,
loss=losses,
itm_head=itm_head,
mlm_head=mlm_head,
mim_head=mim_head,
mmm_mlm_head=mmm_mlm_head,
mmm_mim_head=mmm_mim_head,
)

if pretrained_model_key is not None:
Expand Down
Loading

0 comments on commit 06d3436

Please sign in to comment.