
Commit

[FLAVA] Move ITM head to FLAVA model for pretraining
ghstack-source-id: 9555f3ae021a61f70831d537f587b8cb2a3b8d9c
Pull Request resolved: #132
ankitade committed Jul 23, 2022
1 parent 27c8eec commit a586c88
Showing 2 changed files with 42 additions and 13 deletions.
torchmultimodal/models/flava/flava_model.py (38 additions, 6 deletions)

@@ -60,7 +60,7 @@


 FLAVA_FOR_PRETRAINED_MAPPING = {
-    "flava_full": "https://download.pytorch.org/models/multimodal/flava/flava_for_pretraining.pt"
+    "flava_full": "https://download.pytorch.org/models/multimodal/flava/flava_for_pretraining_cl_itm.pt"
 }


@@ -287,18 +287,45 @@ def encode_mm(
         return self.mm_encoder(fused_state)
 
 
+class TwoWayHead(nn.Module):
+    def __init__(self, hidden_size: int = 768, **kwargs: Any):
+        super().__init__()
+
+        self.seq_relationship = nn.Linear(hidden_size, 2)
+
+    def forward(self, pooled_output):
+        return self.seq_relationship(pooled_output)
+
+
+class ITMHead(nn.Module):
+    def __init__(self, hidden_size: int = 768):
+        super().__init__()
+        self.pooler = Pooler(hidden_size=hidden_size)
+        self.cls = TwoWayHead(hidden_size=hidden_size)
+
+    def forward(self, hidden_states: Tensor):
+        pooled_output = self.pooler(hidden_states)
+        logits = self.cls(pooled_output)
+        return logits
+
+
 class FLAVAForPreTraining(nn.Module, PretrainedMixin):
     # TODOs:
     # 1. Expose logit scale
     # 2. For FLAVA model, allow interpolating the embeddings to
     # for patch embeddings
     def __init__(
-        self, model: FLAVAModel, image_codebook: nn.Module, loss: FLAVAPretrainingLoss
-    ) -> None:
+        self,
+        model: FLAVAModel,
+        image_codebook: nn.Module,
+        loss: nn.Module,
+        itm_head: nn.Module,
+    ):
         super().__init__()
         self.model = model
         self.image_codebook = image_codebook
         self.loss = loss
+        self.itm_head = itm_head
 
     def encode_image(
         self,
@@ -346,6 +373,10 @@ def forward(
             required_embedding=required_embedding,
             skip_unmasked_mm_encoder=skip_unmasked_mm_encoder,
         )
+        multimodal_masked_sequence = flava_output.multimodal_masked.last_hidden_state
+        itm_logits = None
+        if multimodal_masked_sequence is not None:
+            itm_logits = self.itm_head(multimodal_masked_sequence)
 
         return self.loss(
             image_sequence=flava_output.image.last_hidden_state,
@@ -361,6 +392,7 @@ def forward(
             mlm_labels=mlm_labels,
             projected_image_embeddings=flava_output.projected_image_embeddings,
             projected_text_embeddings=flava_output.projected_text_embeddings,
+            itm_logits=itm_logits,
         )


@@ -516,13 +548,13 @@ def flava_model_for_pretraining(
     # TODO: Add parameters for loss here
 ) -> FLAVAForPreTraining:
     model = flava_model(**flava_model_kwargs)
+    hidden_size = flava_model_kwargs.get("hidden_size") or 768
+    itm_head = ITMHead(hidden_size)
     losses = FLAVAPretrainingLoss()
     codebook = DalleVAEEncoder(image_size=codebook_image_size)
 
     flava = FLAVAForPreTraining(
-        model=model,
-        image_codebook=codebook,
-        loss=losses,
+        model=model, image_codebook=codebook, loss=losses, itm_head=itm_head
     )
 
     if pretrained_model_key is not None:
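Taken together, the flava_model.py changes move pooling and two-way classification out of the loss and into the model: FLAVAForPreTraining now owns an ITMHead and applies it to the fused multimodal sequence before calling the loss. Below is a minimal, self-contained sketch of that flow. Pooler is imported, not defined, in this diff, so the BERT-style first-token pooler here is an assumption, and the tensor shapes are illustrative only.

import torch
from torch import Tensor, nn


class Pooler(nn.Module):
    # Assumed stand-in for torchmultimodal's Pooler (not shown in this diff):
    # a dense + tanh projection of the first ([CLS]) token.
    def __init__(self, hidden_size: int = 768):
        super().__init__()
        self.dense = nn.Linear(hidden_size, hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states: Tensor) -> Tensor:
        return self.activation(self.dense(hidden_states[:, 0]))


class TwoWayHead(nn.Module):
    # As added in this commit: a binary match/no-match classifier.
    def __init__(self, hidden_size: int = 768):
        super().__init__()
        self.seq_relationship = nn.Linear(hidden_size, 2)

    def forward(self, pooled_output: Tensor) -> Tensor:
        return self.seq_relationship(pooled_output)


class ITMHead(nn.Module):
    # As added in this commit: pool the fused sequence, then classify.
    def __init__(self, hidden_size: int = 768):
        super().__init__()
        self.pooler = Pooler(hidden_size=hidden_size)
        self.cls = TwoWayHead(hidden_size=hidden_size)

    def forward(self, hidden_states: Tensor) -> Tensor:
        pooled_output = self.pooler(hidden_states)
        return self.cls(pooled_output)


# Illustrative shapes: a batch of 4 fused multimodal sequences, 196 tokens each.
multimodal_masked_sequence = torch.randn(4, 196, 768)
itm_logits = ITMHead(hidden_size=768)(multimodal_masked_sequence)
print(itm_logits.shape)  # torch.Size([4, 2])

Because the head now lives on the model, the loss module only ever sees precomputed logits, which is exactly what the losses/flava.py changes below accommodate.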
torchmultimodal/modules/losses/flava.py (4 additions, 7 deletions)

@@ -110,21 +110,16 @@ def __init__(
         **kwargs: Any,
     ):
         super().__init__()
-        self.pooler = Pooler(hidden_size=hidden_size)
-        self.cls = TwoWayHead(hidden_size=hidden_size)
         self.ce_loss = nn.CrossEntropyLoss(ignore_index=ignore_index)
 
     def forward(
         self,
-        hidden_states: Tensor,
+        scores: Tensor,
         labels: Tensor,
     ) -> ITMLossOutput:
         if self.training:
             assert_labels_are_present(labels, "itm labels")
 
-        pooled_output = self.pooler(hidden_states)
-        scores = self.cls(pooled_output)
-
         if labels is None:
-            loss = pooled_output.sum() * 0
+            loss = scores.sum() * 0
         else:
@@ -375,6 +370,7 @@ def forward(
         mlm_labels: Optional[Tensor] = None,
         projected_image_embeddings: Optional[Tensor] = None,
         projected_text_embeddings: Optional[Tensor] = None,
+        itm_logits: Optional[Tensor] = None,
     ) -> FLAVAPretrainingLossOutput:
         outputs = FLAVAPretrainingLossOutput()
         pos_mask = None
@@ -411,6 +407,7 @@
             outputs.losses.mlm_loss = outputs.mlm_output.loss
 
         if multimodal_masked_sequence is not None and self.itm_loss_weight > 0:
+            assert itm_logits is not None
            if itm_labels is not None:
                 pos_pairs = itm_labels.ne(0)
                 pos_mask = torch.where(
@@ -421,7 +418,7 @@
                     multimodal_masked_sequence.size(0),
                     device=multimodal_masked_sequence.device,
                 ).bool()
-            outputs.itm_output = self.itm_loss(multimodal_masked_sequence, itm_labels)
+            outputs.itm_output = self.itm_loss(itm_logits, itm_labels)
             outputs.itm_output.loss *= self.itm_loss_weight
             outputs.losses.itm_loss = outputs.itm_output.loss

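On the loss side, ITMLoss no longer pools or classifies; its forward now takes the precomputed scores and reduces to a two-class cross-entropy (falling back to scores.sum() * 0 to keep a zero-valued, graph-preserving loss when labels are absent). A short sketch of the reworked contract; the ignore_index value and tensor shapes are illustrative assumptions:

import torch
from torch import nn

# What ITMLoss.forward now boils down to, given precomputed scores.
ce_loss = nn.CrossEntropyLoss(ignore_index=-1)  # ignore_index assumed

itm_logits = torch.randn(4, 2)           # produced by ITMHead inside the model
itm_labels = torch.tensor([1, 0, 1, 1])  # 1 = matched image-text pair, 0 = mismatched

loss = ce_loss(itm_logits, itm_labels)
print(loss.item())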
