[FLAVA] Move ITM head to FLAVA model for pretraining #132

Draft: wants to merge 7 commits into base gh/ankitade/9/base
torchmultimodal/models/flava/model.py (44 changes: 38 additions & 6 deletions)
@@ -62,7 +62,7 @@
 FLAVA_FOR_PRETRAINED_MAPPING = {
     # This will no longer load with the updated model, but keeping here just in case
     # "flava_full": "https://huggingface.co/aps/flava_full_pretrained_encoders_torchmm/resolve/main/pytorch_model.bin",
-    "flava_full": "https://download.pytorch.org/models/multimodal/flava/flava_for_pretraining_unified_text_encoder.pt",
+    "flava_full": "https://download.pytorch.org/models/multimodal/flava/flava_for_pretraining_unified_itm.pt",
 }

 FLAVA_MODEL_MAPPING = {
@@ -292,18 +292,45 @@ def encode_mm(
         return self.mm_encoder(fused_state)


+class TwoWayHead(nn.Module):
+    def __init__(self, hidden_size: int = 768, **kwargs: Any):
+        super().__init__()
+
+        self.seq_relationship = nn.Linear(hidden_size, 2)
+
+    def forward(self, pooled_output: Tensor) -> Tensor:
+        return self.seq_relationship(pooled_output)
+
+
+class ITMHead(nn.Module):
+    def __init__(self, hidden_size: int = 768):
+        super().__init__()
+        self.pooler = Pooler(hidden_size=hidden_size)
+        self.cls = TwoWayHead(hidden_size=hidden_size)
+
+    def forward(self, hidden_states: Tensor) -> Tensor:
+        pooled_output = self.pooler(hidden_states)
+        logits = self.cls(pooled_output)
+        return logits
+
+
 class FLAVAForPreTraining(nn.Module, PretrainedMixin):
     # TODOs:
     # 1. Expose logit scale
     # 2. For FLAVA model, allow interpolating the embeddings to
     # for patch embeddings
     def __init__(
-        self, model: FLAVAModel, image_codebook: nn.Module, loss: FLAVAPretrainingLoss
-    ) -> None:
+        self,
+        model: FLAVAModel,
+        image_codebook: nn.Module,
+        loss: nn.Module,
+        itm_head: nn.Module,
+    ):
         super().__init__()
         self.model = model
         self.image_codebook = image_codebook
         self.loss = loss
+        self.itm_head = itm_head

     def encode_image(
         self,
@@ -351,6 +378,10 @@ def forward(
             required_embedding=required_embedding,
             skip_unmasked_mm_encoder=skip_unmasked_mm_encoder,
         )
+        multimodal_masked_sequence = flava_output.multimodal_masked.last_hidden_state
+        itm_logits = None
+        if multimodal_masked_sequence is not None:
+            itm_logits = self.itm_head(multimodal_masked_sequence)

         return self.loss(
             image_sequence=flava_output.image.last_hidden_state,
@@ -366,6 +397,7 @@ def forward(
             mlm_labels=mlm_labels,
             projected_image_embeddings=flava_output.projected_image_embeddings,
             projected_text_embeddings=flava_output.projected_text_embeddings,
+            itm_logits=itm_logits,
         )


@@ -520,13 +552,13 @@ def flava_model_for_pretraining(
     # TODO: Add parameters for loss here
 ) -> FLAVAForPreTraining:
     model = flava_model(**flava_model_kwargs)
+    hidden_size = flava_model_kwargs.get("hidden_size") or 768
+    itm_head = ITMHead(hidden_size)
     losses = FLAVAPretrainingLoss()
     codebook = DalleVAEEncoder(image_size=codebook_image_size)

     flava = FLAVAForPreTraining(
-        model=model,
-        image_codebook=codebook,
-        loss=losses,
+        model=model, image_codebook=codebook, loss=losses, itm_head=itm_head
     )

     if pretrained_model_key is not None:
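Taken together, the model-side change is: `FLAVAForPreTraining` now owns an `ITMHead` (pooler plus two-way classifier) and computes `itm_logits` from the masked multimodal sequence itself, instead of the loss doing that work. Keeping the head on the model also puts its parameters in the model's state dict, which is presumably why the pretrained checkpoint key moves to `flava_for_pretraining_unified_itm.pt`. Below is a minimal, self-contained sketch of the new flow; the `Pooler` stand-in assumes the usual BERT-style pooler (dense + tanh over the first token), which may differ in detail from torchmultimodal's `Pooler`.

```python
import torch
from torch import nn, Tensor


class Pooler(nn.Module):
    """Stand-in for torchmultimodal's Pooler (assumed BERT-style)."""

    def __init__(self, hidden_size: int = 768):
        super().__init__()
        self.dense = nn.Linear(hidden_size, hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states: Tensor) -> Tensor:
        # Pool the sequence by taking the first ([CLS]) token's hidden state.
        return self.activation(self.dense(hidden_states[:, 0]))


class TwoWayHead(nn.Module):
    def __init__(self, hidden_size: int = 768):
        super().__init__()
        self.seq_relationship = nn.Linear(hidden_size, 2)

    def forward(self, pooled_output: Tensor) -> Tensor:
        return self.seq_relationship(pooled_output)


class ITMHead(nn.Module):
    def __init__(self, hidden_size: int = 768):
        super().__init__()
        self.pooler = Pooler(hidden_size=hidden_size)
        self.cls = TwoWayHead(hidden_size=hidden_size)

    def forward(self, hidden_states: Tensor) -> Tensor:
        return self.cls(self.pooler(hidden_states))


# Mirrors what FLAVAForPreTraining.forward now does: run the head over the
# masked multimodal sequence, then pass the logits on to the loss.
batch_size, seq_len, hidden_size = 2, 16, 768
multimodal_masked_sequence = torch.randn(batch_size, seq_len, hidden_size)
itm_head = ITMHead(hidden_size=hidden_size)
itm_logits = itm_head(multimodal_masked_sequence)
print(itm_logits.shape)  # torch.Size([2, 2]): match/no-match logits per pair
```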
torchmultimodal/modules/losses/flava.py (25 changes: 5 additions & 20 deletions)
@@ -92,41 +92,25 @@ def forward(self, hidden_states: Tensor) -> Tensor:
         return pooled_output


-class TwoWayHead(nn.Module):
-    def __init__(self, hidden_size: int = 768, **kwargs: Any):
-        super().__init__()
-
-        self.seq_relationship = nn.Linear(hidden_size, 2)
-
-    def forward(self, pooled_output: Tensor) -> Tensor:
-        return self.seq_relationship(pooled_output)
-
-
 class ITMLoss(nn.Module):
     def __init__(
         self,
-        hidden_size: int = 768,
         ignore_index: int = -1,
         **kwargs: Any,
     ):
         super().__init__()
-        self.pooler = Pooler(hidden_size=hidden_size)
-        self.cls = TwoWayHead(hidden_size=hidden_size)
         self.ce_loss = nn.CrossEntropyLoss(ignore_index=ignore_index)

     def forward(
         self,
-        hidden_states: Tensor,
+        scores: Tensor,
         labels: Tensor,
     ) -> ITMLossOutput:
         if self.training:
             assert_labels_are_present(labels, "itm labels")

-        pooled_output = self.pooler(hidden_states)
-        scores = self.cls(pooled_output)
-
         if labels is None:
-            loss = pooled_output.sum() * 0
+            loss = scores.sum() * 0
         else:
             loss = self.ce_loss(
                 scores.view(-1, 2),
@@ -309,7 +293,6 @@ def __init__(
     ):
         super().__init__()
         self.itm_loss = ITMLoss(
-            hidden_size=hidden_size,
             ignore_index=ignore_index,
         )
         self.contrastive_loss = FLAVAGlobalContrastiveLoss(
@@ -375,6 +358,7 @@ def forward(
         mlm_labels: Optional[Tensor] = None,
         projected_image_embeddings: Optional[Tensor] = None,
         projected_text_embeddings: Optional[Tensor] = None,
+        itm_logits: Optional[Tensor] = None,
     ) -> FLAVAPretrainingLossOutput:
         outputs = FLAVAPretrainingLossOutput()
         pos_mask = None
@@ -411,6 +395,7 @@ def forward(
             outputs.losses.mlm_loss = outputs.mlm_output.loss

         if multimodal_masked_sequence is not None and self.itm_loss_weight > 0:
+            assert itm_logits is not None
             if itm_labels is not None:
                 pos_pairs = itm_labels.ne(0)
                 pos_mask = torch.where(
@@ -421,7 +406,7 @@
                     multimodal_masked_sequence.size(0),
                     device=multimodal_masked_sequence.device,
                 ).bool()
-            outputs.itm_output = self.itm_loss(multimodal_masked_sequence, itm_labels)
+            outputs.itm_output = self.itm_loss(itm_logits, itm_labels)
             outputs.itm_output.loss *= self.itm_loss_weight
             outputs.losses.itm_loss = outputs.itm_output.loss
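On the loss side, `ITMLoss` is reduced to plain cross-entropy over logits computed upstream: the pooler and `TwoWayHead` move out, and `forward` takes `scores` where it used to take `hidden_states`. Here is a rough sketch of the new contract; the `ITMLossOutput` fields and the `labels.view(-1)` continuation of the truncated `ce_loss` call are inferred from context, not shown in this diff.

```python
from dataclasses import dataclass
from typing import Optional

import torch
from torch import nn, Tensor


@dataclass
class ITMLossOutput:
    # Field names are an assumption; the real dataclass may carry more.
    logits: Optional[Tensor] = None
    loss: Optional[Tensor] = None


class ITMLoss(nn.Module):
    def __init__(self, ignore_index: int = -1):
        super().__init__()
        self.ce_loss = nn.CrossEntropyLoss(ignore_index=ignore_index)

    def forward(self, scores: Tensor, labels: Optional[Tensor]) -> ITMLossOutput:
        if labels is None:
            # Zero loss that stays connected to the autograd graph.
            loss = scores.sum() * 0
        else:
            loss = self.ce_loss(scores.view(-1, 2), labels.view(-1))
        return ITMLossOutput(logits=scores, loss=loss)


# Example: two image-text pairs, label 1 = matched, 0 = mismatched.
logits = torch.randn(2, 2, requires_grad=True)
labels = torch.tensor([1, 0])
output = ITMLoss()(logits, labels)
print(output.loss)  # scalar cross-entropy over the two pairs
```

Note that `scores.sum() * 0` keeps the no-label branch returning a graph-connected zero rather than a detached constant, so callers can backpropagate unconditionally.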