[FLAVA] Separate the pretraining loss from the pretraining model
ghstack-source-id: d0845c05530a19026b144a145823eb7a44d54f4f
Pull Request resolved: #278
ankitade committed Aug 22, 2022
1 parent 06d3436 commit c1043aa
Showing 2 changed files with 50 additions and 9 deletions.
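
For orientation, here is a minimal sketch of what "separating the pretraining loss from the pretraining model" means at the call site. The classes below are hypothetical stand-ins rather than the torchmultimodal API: before this commit, the pretraining model computed its losses inside forward(); after it, forward() returns raw outputs and a separate loss module (FLAVAPretrainingLoss here) consumes them in the training code.

import torch
from torch import Tensor, nn


class PretrainingModelBefore(nn.Module):
    # Hypothetical stand-in: the loss is computed inside forward().
    def __init__(self) -> None:
        super().__init__()
        self.encoder = nn.Linear(8, 2)

    def forward(self, x: Tensor, labels: Tensor) -> Tensor:
        logits = self.encoder(x)
        return nn.functional.cross_entropy(logits, labels)


class PretrainingModelAfter(nn.Module):
    # Hypothetical stand-in: forward() only returns outputs; the loss
    # module is owned by the training code, mirroring this commit.
    def __init__(self) -> None:
        super().__init__()
        self.encoder = nn.Linear(8, 2)

    def forward(self, x: Tensor) -> Tensor:
        return self.encoder(x)


pretraining_loss = nn.CrossEntropyLoss()  # stand-in for FLAVAPretrainingLoss

x, labels = torch.randn(4, 8), torch.randint(0, 2, (4,))
loss_before = PretrainingModelBefore()(x, labels)
loss_after = pretraining_loss(PretrainingModelAfter()(x), labels)

Keeping the loss outside the model is what lets the example Lightning module below own FLAVAPretrainingLoss directly while the library model returns a plain output object.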
33 changes: 27 additions & 6 deletions examples/flava/model.py
@@ -4,20 +4,22 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Any, Tuple
from typing import Any, List, Tuple

import torch
from pytorch_lightning import LightningModule
from torch import Tensor
from torchmetrics import Accuracy
from torchmultimodal.models.flava.model import (
flava_model_for_classification,
flava_model_for_pretraining,
)
from torchmultimodal.modules.losses.flava import FLAVAPretrainingLoss
from transformers.optimization import get_cosine_schedule_with_warmup


def get_optimizers_for_lightning(
model: torch.nn.Module,
parameters: List[Tensor],
learning_rate: float,
adam_eps: float,
adam_weight_decay: float,
@@ -26,7 +28,7 @@ def get_optimizers_for_lightning(
max_steps: int,
):
optimizer = torch.optim.AdamW(
model.parameters(),
parameters,
lr=learning_rate,
betas=adam_betas,
eps=adam_eps,
@@ -59,6 +61,7 @@ def __init__(
self.adam_weight_decay = adam_weight_decay
self.warmup_steps = warmup_steps
self.max_steps = max_steps
self.loss = FLAVAPretrainingLoss()

def training_step(self, batch, batch_idx):
output = self._step(batch, batch_idx)
@@ -104,11 +107,29 @@ def _step(self, batch, batch_idx):
itm_labels=batch.get("itm_labels", None),
required_embedding=required_embedding,
)
return output

loss = self.loss(
multimodal_masked_sequence=output.multimodal_masked_sequence,
pos_mask=output.pos_mask,
itm_labels=output.itm_labels,
mim_labels=output.mim_labels,
mlm_labels=output.mlm_labels,
mmm_mlm_labels=output.mmm_mlm_labels,
mmm_mim_labels=output.mmm_mim_labels,
projected_image_embeddings=output.projected_image_embeddings,
projected_text_embeddings=output.projected_text_embeddings,
itm_logits=output.itm_logits,
mlm_head_output=output.mlm_head_output,
mim_head_output=output.mim_head_output,
mmm_mlm_head_output=output.mmm_mlm_head_output,
mmm_mim_head_output=output.mmm_mim_head_output,
)
return loss

def configure_optimizers(self):
parameters = list(self.model.parameters()) + list(self.loss.parameters())
return get_optimizers_for_lightning(
self.model,
parameters,
self.learning_rate,
self.adam_eps,
self.adam_weight_decay,
@@ -194,7 +215,7 @@ def _step(self, batch, batch_idx):

def configure_optimizers(self):
return get_optimizers_for_lightning(
self.model,
self.model.parameters(),
self.learning_rate,
self.adam_eps,
self.adam_weight_decay,
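
One detail worth noting in the configure_optimizers change above: once the loss is its own nn.Module with potentially learnable parameters, a single optimizer has to cover both the model's and the loss's parameters. Since .parameters() returns a generator, the two sets are materialized into lists before being concatenated. A small sketch of the pattern, with two arbitrary modules standing in for the FLAVA model and FLAVAPretrainingLoss:

import torch
from torch import nn

# Hypothetical stand-ins for the pretraining model and its loss module.
model = nn.Linear(8, 8)
loss_module = nn.Linear(8, 2)

# .parameters() returns a generator, which does not support `+`; materialize
# both into lists before handing a single iterable to the optimizer.
parameters = list(model.parameters()) + list(loss_module.parameters())

optimizer = torch.optim.AdamW(parameters, lr=2e-4)

itertools.chain(model.parameters(), loss_module.parameters()) would work just as well, since the optimizer only needs an iterable of parameters.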
26 changes: 23 additions & 3 deletions torchmultimodal/models/flava/model.py
@@ -105,6 +105,24 @@ class FLAVAForClassificationOutput(ModelOutput):
loss: Tensor


@dataclass
class FLAVAForPretrainingOutput:
multimodal_masked_sequence: Tensor
pos_mask: Tensor
mim_labels: Tensor
mlm_labels: Tensor
mmm_mlm_labels: Tensor
mmm_mim_labels: Tensor
itm_labels: Tensor
projected_image_embeddings: Tensor
projected_text_embeddings: Tensor
itm_logits: Tensor
mlm_head_output: Tensor
mim_head_output: Tensor
mmm_mlm_head_output: Tensor
mmm_mim_head_output: Tensor


class FLAVAModel(nn.Module, PretrainedMixin):
def __init__(
self,
@@ -469,8 +487,10 @@ def forward(
itm_logits = self.itm_head(multimodal_masked_sequence)

multimodal_masked_sequence = multimodal_masked_sequence[pos_mask]

if mlm_labels is not None:
mmm_mlm_labels = mlm_labels[pos_mask]

if image_labels is not None:
mmm_mim_labels = image_labels[pos_mask]

@@ -494,14 +514,14 @@ def forward(
sequence_for_image = multimodal_masked_sequence[:, 2 : 2 + total_indices, :]
mmm_mim_head_output = self.mmm_mim_head(sequence_for_image, mmm_mim_labels)

return self.loss(
return FLAVAForPretrainingOutput(
multimodal_masked_sequence=flava_output.multimodal_masked.last_hidden_state,
pos_mask=pos_mask,
itm_labels=itm_labels,
mim_labels=image_labels,
mlm_labels=mlm_labels,
mmm_mlm_labels=mmm_mlm_labels,
mmm_mim_labels=mmm_mim_labels,
mmm_mlm_labels=mmm_mlm_labels,
itm_labels=itm_labels,
projected_image_embeddings=flava_output.projected_image_embeddings,
projected_text_embeddings=flava_output.projected_text_embeddings,
itm_logits=itm_logits,
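
With the second file's change, forward() now returns a plain FLAVAForPretrainingOutput dataclass instead of calling the loss, so the model describes a data contract and leaves the training objective to the caller. A hedged sketch of why that hand-off is convenient, using trimmed-down, hypothetical field names: a caller can pass the fields to the loss explicitly, as the Lightning example above does, or unpack the dataclass with dataclasses.asdict when the keyword names line up.

from dataclasses import asdict, dataclass

import torch
from torch import Tensor


@dataclass
class TinyPretrainingOutput:
    # Trimmed-down, hypothetical stand-in for FLAVAForPretrainingOutput.
    itm_logits: Tensor
    itm_labels: Tensor


def tiny_pretraining_loss(itm_logits: Tensor, itm_labels: Tensor) -> Tensor:
    # Hypothetical stand-in for FLAVAPretrainingLoss: keyword names match
    # the output's field names, so the dataclass can be unpacked into it.
    return torch.nn.functional.cross_entropy(itm_logits, itm_labels)


output = TinyPretrainingOutput(
    itm_logits=torch.randn(4, 2),
    itm_labels=torch.randint(0, 2, (4,)),
)
loss = tiny_pretraining_loss(**asdict(output))

In a real training step the explicit-keyword form used in the diff avoids the deep copies that asdict performs on field values.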
