[FLAVA] Separate the pretraining loss from the pretraining model
ghstack-source-id: 417f0746292f666c28f5092eb4545ab0dca567c0
Pull Request resolved: #278
ankitade committed Aug 23, 2022
1 parent 06d3436 commit 476d23b
Showing 4 changed files with 153 additions and 16 deletions.
24 changes: 22 additions & 2 deletions examples/flava/model.py
@@ -8,16 +8,18 @@

import torch
from pytorch_lightning import LightningModule
from torch import nn
from torchmetrics import Accuracy
from torchmultimodal.models.flava.model import (
flava_model_for_classification,
flava_model_for_pretraining,
)
from torchmultimodal.modules.losses.flava import FLAVAPretrainingLoss
from transformers.optimization import get_cosine_schedule_with_warmup


def get_optimizers_for_lightning(
model: torch.nn.Module,
model: nn.Module,
learning_rate: float,
adam_eps: float,
adam_weight_decay: float,
@@ -59,6 +61,7 @@ def __init__(
self.adam_weight_decay = adam_weight_decay
self.warmup_steps = warmup_steps
self.max_steps = max_steps
self.loss = FLAVAPretrainingLoss(logit_scale=self.model.logit_scale)

def training_step(self, batch, batch_idx):
output = self._step(batch, batch_idx)
@@ -104,7 +107,24 @@ def _step(self, batch, batch_idx):
itm_labels=batch.get("itm_labels", None),
required_embedding=required_embedding,
)
return output

loss = self.loss(
multimodal_masked_sequence=output.multimodal_masked_sequence,
pos_mask=output.pos_mask,
itm_labels=output.itm_labels,
mim_labels=output.mim_labels,
mlm_labels=output.mlm_labels,
mmm_mlm_labels=output.mmm_mlm_labels,
mmm_mim_labels=output.mmm_mim_labels,
projected_image_embeddings=output.projected_image_embeddings,
projected_text_embeddings=output.projected_text_embeddings,
itm_logits=output.itm_logits,
mlm_head_output=output.mlm_head_output,
mim_head_output=output.mim_head_output,
mmm_mlm_head_output=output.mmm_mlm_head_output,
mmm_mim_head_output=output.mmm_mim_head_output,
)
return loss

def configure_optimizers(self):
return get_optimizers_for_lightning(
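The Lightning example above forwards all fourteen fields of the model output into FLAVAPretrainingLoss by hand. A minimal sketch (not part of this diff) of a tighter _step tail, assuming, as the diff shows, that the field names of FLAVAForPretrainingOutput line up one-to-one with the loss's keyword arguments:

import dataclasses

# Inside _step, after `output = self.model(...)`: build the loss kwargs by
# reading each dataclass field directly. (dataclasses.asdict is avoided here
# because it deep-copies the field values, i.e. the output tensors.)
loss_kwargs = {f.name: getattr(output, f.name) for f in dataclasses.fields(output)}
loss = self.loss(**loss_kwargs)
return loss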
52 changes: 51 additions & 1 deletion test/models/flava/test_checkpoint.py
@@ -12,6 +12,7 @@
flava_model_for_classification,
flava_model_for_pretraining,
)
from torchmultimodal.modules.losses.flava import FLAVAPretrainingLoss


@pytest.fixture(autouse=True)
@@ -139,8 +140,25 @@ def test_flava_model_for_pretraining(self, inputs_pretraining, pretraining_model
image_input = inputs_pretraining("image")
text_input = inputs_pretraining("text")
flava = pretraining_model()

losses = FLAVAPretrainingLoss(flava.logit_scale)
output = flava(*mm_input)
output = losses(
multimodal_masked_sequence=output.multimodal_masked_sequence,
pos_mask=output.pos_mask,
itm_labels=output.itm_labels,
mim_labels=output.mim_labels,
mlm_labels=output.mlm_labels,
mmm_mlm_labels=output.mmm_mlm_labels,
mmm_mim_labels=output.mmm_mim_labels,
projected_image_embeddings=output.projected_image_embeddings,
projected_text_embeddings=output.projected_text_embeddings,
itm_logits=output.itm_logits,
mlm_head_output=output.mlm_head_output,
mim_head_output=output.mim_head_output,
mmm_mlm_head_output=output.mmm_mlm_head_output,
mmm_mim_head_output=output.mmm_mim_head_output,
)

actual = output.losses
expected = dict(
mmm_text_loss=10.9567,
@@ -153,6 +171,22 @@ def test_flava_model_for_pretraining(self, inputs_pretraining, pretraining_model
self._assert_tensor_dicts_equal(actual, expected)

output = flava(*image_input)
output = losses(
multimodal_masked_sequence=output.multimodal_masked_sequence,
pos_mask=output.pos_mask,
itm_labels=output.itm_labels,
mim_labels=output.mim_labels,
mlm_labels=output.mlm_labels,
mmm_mlm_labels=output.mmm_mlm_labels,
mmm_mim_labels=output.mmm_mim_labels,
projected_image_embeddings=output.projected_image_embeddings,
projected_text_embeddings=output.projected_text_embeddings,
itm_logits=output.itm_logits,
mlm_head_output=output.mlm_head_output,
mim_head_output=output.mim_head_output,
mmm_mlm_head_output=output.mmm_mlm_head_output,
mmm_mim_head_output=output.mmm_mim_head_output,
)
actual = output.losses
expected = dict(
mmm_text_loss=None,
@@ -165,6 +199,22 @@ def test_flava_model_for_pretraining(self, inputs_pretraining, pretraining_model
self._assert_tensor_dicts_equal(actual, expected)

output = flava(*text_input)
output = losses(
multimodal_masked_sequence=output.multimodal_masked_sequence,
pos_mask=output.pos_mask,
itm_labels=output.itm_labels,
mim_labels=output.mim_labels,
mlm_labels=output.mlm_labels,
mmm_mlm_labels=output.mmm_mlm_labels,
mmm_mim_labels=output.mmm_mim_labels,
projected_image_embeddings=output.projected_image_embeddings,
projected_text_embeddings=output.projected_text_embeddings,
itm_logits=output.itm_logits,
mlm_head_output=output.mlm_head_output,
mim_head_output=output.mim_head_output,
mmm_mlm_head_output=output.mmm_mlm_head_output,
mmm_mim_head_output=output.mmm_mim_head_output,
)
actual = output.losses
expected = dict(
mmm_text_loss=None,
50 changes: 50 additions & 0 deletions test/models/flava/test_flava.py
@@ -18,6 +18,7 @@
FLAVAOutput,
)
from torchmultimodal.modules.layers.transformer import TransformerOutput
from torchmultimodal.modules.losses.flava import FLAVAPretrainingLoss

NUM_CLASSES = 2

@@ -61,6 +62,7 @@ def test_forward_pretraining(self):
mlm_labels[:, 1:3] = text[:, 1:3]
itm_labels = torch.tensor((0, 1), dtype=torch.long)
flava = flava_model_for_pretraining()
losses = FLAVAPretrainingLoss(flava.logit_scale)
flava.eval()
output = flava(
image=image,
@@ -72,6 +74,22 @@
itm_labels=itm_labels,
mlm_labels=mlm_labels,
)
output = losses(
multimodal_masked_sequence=output.multimodal_masked_sequence,
pos_mask=output.pos_mask,
itm_labels=output.itm_labels,
mim_labels=output.mim_labels,
mlm_labels=output.mlm_labels,
mmm_mlm_labels=output.mmm_mlm_labels,
mmm_mim_labels=output.mmm_mim_labels,
projected_image_embeddings=output.projected_image_embeddings,
projected_text_embeddings=output.projected_text_embeddings,
itm_logits=output.itm_logits,
mlm_head_output=output.mlm_head_output,
mim_head_output=output.mim_head_output,
mmm_mlm_head_output=output.mmm_mlm_head_output,
mmm_mim_head_output=output.mmm_mim_head_output,
)
self.assertIsNone(output.mlm_output)
self.assertIsNone(output.mim_output)
self.assertIsNotNone(output.global_contrastive_output)
@@ -96,6 +114,22 @@ def test_forward_pretraining(self):
itm_labels=itm_labels,
mlm_labels=mlm_labels,
)
output = losses(
multimodal_masked_sequence=output.multimodal_masked_sequence,
pos_mask=output.pos_mask,
itm_labels=output.itm_labels,
mim_labels=output.mim_labels,
mlm_labels=output.mlm_labels,
mmm_mlm_labels=output.mmm_mlm_labels,
mmm_mim_labels=output.mmm_mim_labels,
projected_image_embeddings=output.projected_image_embeddings,
projected_text_embeddings=output.projected_text_embeddings,
itm_logits=output.itm_logits,
mlm_head_output=output.mlm_head_output,
mim_head_output=output.mim_head_output,
mmm_mlm_head_output=output.mmm_mlm_head_output,
mmm_mim_head_output=output.mmm_mim_head_output,
)
self.assertIsNone(output.mlm_output)
self.assertIsNotNone(output.mim_output)
self.assertIsNone(output.global_contrastive_output)
@@ -120,6 +154,22 @@ def test_forward_pretraining(self):
itm_labels=itm_labels,
mlm_labels=mlm_labels,
)
output = losses(
multimodal_masked_sequence=output.multimodal_masked_sequence,
pos_mask=output.pos_mask,
itm_labels=output.itm_labels,
mim_labels=output.mim_labels,
mlm_labels=output.mlm_labels,
mmm_mlm_labels=output.mmm_mlm_labels,
mmm_mim_labels=output.mmm_mim_labels,
projected_image_embeddings=output.projected_image_embeddings,
projected_text_embeddings=output.projected_text_embeddings,
itm_logits=output.itm_logits,
mlm_head_output=output.mlm_head_output,
mim_head_output=output.mim_head_output,
mmm_mlm_head_output=output.mmm_mlm_head_output,
mmm_mim_head_output=output.mmm_mim_head_output,
)
self.assertIsNotNone(output.mlm_output)
self.assertIsNone(output.mim_output)
self.assertIsNone(output.global_contrastive_output)
43 changes: 30 additions & 13 deletions torchmultimodal/models/flava/model.py
@@ -24,11 +24,7 @@
TransformerEncoder,
TransformerOutput,
)
from torchmultimodal.modules.losses.flava import (
FLAVAPretrainingLoss,
FLAVAPretrainingLossOutput,
Pooler,
)
from torchmultimodal.modules.losses.flava import FLAVAPretrainingLossOutput, Pooler
from torchmultimodal.utils.common import ModelOutput, PretrainedMixin
from typing_extensions import Literal

@@ -62,7 +58,7 @@
FLAVA_FOR_PRETRAINED_MAPPING = {
# This will no longer load with the updated model, but keeping here just in case
# "flava_full": "https://huggingface.co/aps/flava_full_pretrained_encoders_torchmm/resolve/main/pytorch_model.bin",
"flava_full": "https://download.pytorch.org/models/multimodal/flava/flava_for_pretraining_unified_itm_mp.pt",
"flava_full": "https://download.pytorch.org/models/multimodal/flava/flava_for_pretraining_no_loss.pt",
}

FLAVA_MODEL_MAPPING = {
@@ -105,6 +101,24 @@ class FLAVAForClassificationOutput(ModelOutput):
loss: Tensor


@dataclass
class FLAVAForPretrainingOutput:
multimodal_masked_sequence: Tensor
pos_mask: Tensor
mim_labels: Tensor
mlm_labels: Tensor
mmm_mlm_labels: Tensor
mmm_mim_labels: Tensor
itm_labels: Tensor
projected_image_embeddings: Tensor
projected_text_embeddings: Tensor
itm_logits: Tensor
mlm_head_output: Tensor
mim_head_output: Tensor
mmm_mlm_head_output: Tensor
mmm_mim_head_output: Tensor
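
Because FLAVAForPretrainingOutput exposes the raw heads, labels, and projected embeddings rather than computed losses, a caller can apply only the terms it needs. As a purely hypothetical illustration (not part of this commit), a CLIP-style contrastive term could be computed from just the projected embeddings and the model's logit_scale:

import torch
import torch.nn.functional as F

def contrastive_only_loss(output, logit_scale):
    # Normalize the projected embeddings and score every image-text pair.
    img = F.normalize(output.projected_image_embeddings, dim=-1)
    txt = F.normalize(output.projected_text_embeddings, dim=-1)
    logits = logit_scale.exp() * img @ txt.t()
    targets = torch.arange(logits.size(0), device=logits.device)
    # Symmetric cross-entropy over image->text and text->image directions.
    return 0.5 * (F.cross_entropy(logits, targets) + F.cross_entropy(logits.t(), targets))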


class FLAVAModel(nn.Module, PretrainedMixin):
def __init__(
self,
@@ -367,22 +381,22 @@ def __init__(
self,
model: FLAVAModel,
image_codebook: nn.Module,
loss: nn.Module,
itm_head: nn.Module,
mlm_head: nn.Module,
mim_head: nn.Module,
mmm_mlm_head: nn.Module,
mmm_mim_head: nn.Module,
logit_scale: nn.Module,
):
super().__init__()
self.model = model
self.image_codebook = image_codebook
self.loss = loss
self.itm_head = itm_head
self.mlm_head = mlm_head
self.mim_head = mim_head
self.mmm_mlm_head = mmm_mlm_head
self.mmm_mim_head = mmm_mim_head
self.logit_scale = logit_scale

def encode_image(
self,
@@ -469,8 +483,10 @@ def forward(
itm_logits = self.itm_head(multimodal_masked_sequence)

multimodal_masked_sequence = multimodal_masked_sequence[pos_mask]

if mlm_labels is not None:
mmm_mlm_labels = mlm_labels[pos_mask]

if image_labels is not None:
mmm_mim_labels = image_labels[pos_mask]

@@ -494,14 +510,14 @@ def forward(
sequence_for_image = multimodal_masked_sequence[:, 2 : 2 + total_indices, :]
mmm_mim_head_output = self.mmm_mim_head(sequence_for_image, mmm_mim_labels)

return self.loss(
return FLAVAForPretrainingOutput(
multimodal_masked_sequence=flava_output.multimodal_masked.last_hidden_state,
pos_mask=pos_mask,
itm_labels=itm_labels,
mim_labels=image_labels,
mlm_labels=mlm_labels,
mmm_mlm_labels=mmm_mlm_labels,
mmm_mim_labels=mmm_mim_labels,
mmm_mlm_labels=mmm_mlm_labels,
itm_labels=itm_labels,
projected_image_embeddings=flava_output.projected_image_embeddings,
projected_text_embeddings=flava_output.projected_text_embeddings,
itm_logits=itm_logits,
@@ -660,6 +676,7 @@ def flava_model_for_pretraining(
codebook_image_size: int = 112,
pretrained_model_key: Optional[str] = None,
image_vocab_size: int = 8192,
logit_scale: float = math.log(1 / 0.07),
**flava_model_kwargs: Any,
# TODO: Add parameters for loss here
) -> FLAVAForPreTraining:
@@ -677,18 +694,18 @@
mmm_mim_head = MaskedPredictionHead(
hidden_size=hidden_size, vocab_size=image_vocab_size
)
losses = FLAVAPretrainingLoss()

codebook = DalleVAEEncoder(image_size=codebook_image_size)

flava = FLAVAForPreTraining(
model=model,
image_codebook=codebook,
loss=losses,
itm_head=itm_head,
mlm_head=mlm_head,
mim_head=mim_head,
mmm_mlm_head=mmm_mlm_head,
mmm_mim_head=mmm_mim_head,
logit_scale=nn.Parameter(logit_scale * torch.ones([])),
)

if pretrained_model_key is not None:
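Putting the pieces of this commit together, the updated tests build the loss outside the model and share the model's logit_scale parameter. A minimal usage sketch following that pattern:

from torchmultimodal.models.flava.model import flava_model_for_pretraining
from torchmultimodal.modules.losses.flava import FLAVAPretrainingLoss

# logit_scale now defaults to log(1 / 0.07) and is stored on the model.
flava = flava_model_for_pretraining()
loss_fn = FLAVAPretrainingLoss(logit_scale=flava.logit_scale)

# output = flava(image=..., text=..., ...)   # returns FLAVAForPretrainingOutput
# losses = loss_fn(multimodal_masked_sequence=output.multimodal_masked_sequence, ...)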
