[FLAVA] Move projections from contrastive loss to model
ghstack-source-id: 0236ad8b49c985b62f1334a6f92945e0bce742c7
Pull Request resolved: #106
ankitade committed Jul 23, 2022
1 parent 0ad6763 commit c605279
Showing 3 changed files with 101 additions and 39 deletions.
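At a glance, this commit moves the image and text projection heads out of FLAVAGlobalContrastiveLoss and into FLAVAModel, so the encoders themselves emit projected embeddings and the loss only consumes them. Below is a minimal sketch of what the builder-level change amounts to, with sizes mirroring the 768 defaults in the diff; the variable names are illustrative, not copied from the repo.

```python
# Rough sketch, not the library's code: flava_model() now constructs the
# per-modality projections (previously owned by the contrastive loss) and
# passes them into FLAVAModel.
from torch import nn

image_hidden_size = text_hidden_size = 768
text_and_image_proj_size = 768  # new flava_model() argument

image_projection = nn.Linear(image_hidden_size, text_and_image_proj_size)
text_projection = nn.Linear(text_hidden_size, text_and_image_proj_size)

# Inside the model, each projection is applied to the encoder's CLS state, e.g.
# projected = image_projection(encoded_image.last_hidden_state[:, 0, :])
```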
2 changes: 2 additions & 0 deletions test/models/flava/test_flava.py
@@ -167,6 +167,8 @@ def setUp(self):
mm_encoder=mm_encoder,
image_to_mm_projection=image_to_mm_projection,
text_to_mm_projection=text_to_mm_projection,
text_projection=nn.Identity(),
image_projection=nn.Identity(),
)

def _assert_empty(self, field):
115 changes: 89 additions & 26 deletions torchmultimodal/models/flava/flava_model.py
@@ -37,8 +37,17 @@

FLAVAOutput = namedtuple(
"FLAVAOutput",
["image", "image_masked", "text", "text_masked", "multimodal", "multimodal_masked"],
defaults=(None, None, None, None, None, None),
[
"image",
"image_masked",
"text",
"text_masked",
"multimodal",
"multimodal_masked",
"projected_image_embeddings",
"projected_text_embeddings",
],
defaults=(None, None, None, None, None, None, None, None),
)
FLAVAOutput.__annotations__ = {
"image": FLAVATransformerOutput,
@@ -51,7 +60,7 @@


FLAVA_FOR_PRETRAINED_MAPPING = {
"flava_full": "https://huggingface.co/aps/flava_full_pretrained_encoders_torchmm/resolve/main/pytorch_model.bin",
"flava_full": "https://download.pytorch.org/models/multimodal/flava/flava_for_pretraining.pt"
}


@@ -99,6 +108,8 @@ def __init__(
mm_encoder: nn.Module,
image_to_mm_projection: nn.Module,
text_to_mm_projection: nn.Module,
text_projection: nn.Module,
image_projection: nn.Module,
**kwargs: Any,
) -> None:
super().__init__()
@@ -107,6 +118,8 @@ def __init__(
self.mm_encoder = mm_encoder
self.image_to_mm_projection = image_to_mm_projection
self.text_to_mm_projection = text_to_mm_projection
self.text_projection = text_projection
self.image_projection = image_projection

def forward(
self,
@@ -125,30 +138,50 @@ def forward(
else:
required_embedding = "text"

image_outputs = self._encode_data_to_embeddings(
image_encoding_out = self._encode_data_to_embeddings(
image,
required_embedding,
["image", "mm"],
self.encode_image,
partial(self.encode_image, projection=True),
)
text_outputs = self._encode_data_to_embeddings(
if len(image_encoding_out) == 2:
image_outputs, projected_image_embeddings = (
image_encoding_out[0],
image_encoding_out[1],
)
else:
image_outputs = image_encoding_out
projected_image_embeddings = None

text_encoding_out = self._encode_data_to_embeddings(
text,
required_embedding,
["text", "mm"],
self.encode_text,
partial(self.encode_text, projection=True),
)
if len(text_encoding_out) == 2:
text_outputs, projected_text_embeddings = (
text_encoding_out[0],
text_encoding_out[1],
)
else:
text_outputs = text_encoding_out
projected_text_embeddings = None

image_masked_outputs = self._encode_data_to_embeddings(
image,
required_embedding,
["image", "mm"],
partial(self.encode_image, image_patches_mask=image_patches_mask),
)
assert type(image_masked_outputs) == FLAVATransformerOutput
text_masked_outputs = self._encode_data_to_embeddings(
text_masked,
required_embedding,
["text", "mm"],
self.encode_text,
)
assert type(text_masked_outputs) == FLAVATransformerOutput

multimodal_outputs = FLAVATransformerOutput()
multimodal_masked_outputs = FLAVATransformerOutput()
@@ -182,39 +215,60 @@ def forward(
text_masked=text_masked_outputs,
multimodal=multimodal_outputs,
multimodal_masked=multimodal_masked_outputs,
projected_image_embeddings=projected_image_embeddings,
projected_text_embeddings=projected_text_embeddings,
)

def encode_image(
self, image: Tensor, image_patches_mask: Optional[Tensor] = None
) -> Optional[FLAVATransformerOutput]:
self,
image: Tensor,
image_patches_mask: Optional[Tensor] = None,
projection: bool = False,
) -> Union[Tuple[FLAVATransformerOutput, Tensor], Optional[FLAVATransformerOutput]]:
if image_patches_mask is not None:
return self.image_encoder(image, image_patches_mask)
encoded_image = self.image_encoder(image, image_patches_mask)
else:
return self.image_encoder(image)
encoded_image = self.image_encoder(image)
if projection:
projected_embeddings = self.image_projection(
encoded_image.last_hidden_state[:, 0, :]
)
return encoded_image, projected_embeddings
return encoded_image

def encode_text(
self,
text: Tensor,
text_mask: Optional[Tensor] = None,
) -> Optional[FLAVATransformerOutput]:
self, text: Tensor, text_mask: Optional[Tensor] = None, projection: bool = False
) -> Union[Tuple[FLAVATransformerOutput, Tensor], Optional[FLAVATransformerOutput]]:
# TODO(asg): Give proper parameter names when implementing text encoder
return self.text_encoder(
encoded_text = self.text_encoder(
input_ids=text,
attention_mask=text_mask,
)
if projection:
projected_embeddings = self.text_projection(
encoded_text.last_hidden_state[:, 0, :]
)
return encoded_text, projected_embeddings
return encoded_text

def _encode_data_to_embeddings(
self,
data: Optional[Tensor],
selected_head_encoder: EMBEDDING_OPTIONS,
encoder_options: List[EMBEDDING_OPTIONS],
encode_callable: Callable[..., FLAVATransformerOutput],
) -> Optional[FLAVATransformerOutput]:
output = FLAVATransformerOutput()
encode_callable: Callable[
...,
Union[
Tuple[FLAVATransformerOutput, Tensor], Optional[FLAVATransformerOutput]
],
],
) -> Union[Tuple[FLAVATransformerOutput, Tensor], Optional[FLAVATransformerOutput]]:
output: Union[
Tuple[FLAVATransformerOutput, Tensor], FLAVATransformerOutput
] = FLAVATransformerOutput()

if data is not None and selected_head_encoder in encoder_options:
output = encode_callable(data)

return output

def encode_mm(
@@ -251,19 +305,19 @@ def encode_image(
image: Tensor,
cls_index: int = 0,
) -> Tensor:
transformer_output = self.model.encode_image(image)
embeddings = transformer_output.last_hidden_state
return self.loss.contrastive_loss.image_projection(embeddings[:, cls_index, :])
encoded_result = self.model.encode_image(image, projection=True)
encoded_image = encoded_result[1]
return encoded_image

def encode_text(
self,
text: Tensor,
text_mask: Optional[Tensor] = None,
cls_index: int = 0,
) -> Tensor:
transformer_output = self.model.encode_text(text, text_mask)
embeddings = transformer_output.last_hidden_state
return self.loss.contrastive_loss.text_projection(embeddings[:, cls_index, :])
encoded_result = self.model.encode_text(text, text_mask, projection=True)
encoded_text = encoded_result[1]
return encoded_text

# TODO: Add options to enable losses selectively
def forward(
@@ -305,6 +359,8 @@ def forward(
itm_labels=itm_labels,
mim_labels=image_labels,
mlm_labels=mlm_labels,
projected_image_embeddings=flava_output.projected_image_embeddings,
projected_text_embeddings=flava_output.projected_text_embeddings,
)


@@ -392,6 +448,8 @@ def flava_model(
multimodal_intermediate_activation: Callable[..., Tensor] = nn.functional.gelu,
multimodal_attention_probs_dropout_prob: float = 0.0,
multimodal_layer_norm_eps: float = 1e-12,
# projection
text_and_image_proj_size: int = 768,
**kwargs: Any,
) -> FLAVAModel:
image_encoder = flava_image_encoder(
@@ -437,12 +495,17 @@
image_to_mm_projection = nn.Linear(image_hidden_size, multimodal_hidden_size)
text_to_mm_projection = nn.Linear(text_hidden_size, multimodal_hidden_size)

image_projection = nn.Linear(image_hidden_size, text_and_image_proj_size)
text_projection = nn.Linear(text_hidden_size, text_and_image_proj_size)

return FLAVAModel(
image_encoder=image_encoder,
text_encoder=text_encoder,
mm_encoder=mm_encoder,
image_to_mm_projection=image_to_mm_projection,
text_to_mm_projection=text_to_mm_projection,
text_projection=text_projection,
image_projection=image_projection,
)


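For reference, a hedged usage sketch of the encoder API introduced above: with projection=True, encode_image/encode_text return both the transformer output and the projected CLS embedding, and these also surface on FLAVAOutput as projected_image_embeddings/projected_text_embeddings. The builder call, input shapes, and vocabulary size below are assumptions for illustration only.

```python
# Usage sketch under assumed defaults; not an exact snippet from the repo.
import torch
from torchmultimodal.models.flava.flava_model import flava_model

model = flava_model()  # assumes the default builder arguments are sufficient
image = torch.randn(2, 3, 224, 224)      # assumed image resolution
text = torch.randint(0, 30000, (2, 32))  # assumed vocab size and sequence length

# With projection=True the encoders return (FLAVATransformerOutput, projected CLS embedding).
encoded_image, projected_image = model.encode_image(image, projection=True)
encoded_text, projected_text = model.encode_text(text, projection=True)

# Without the flag, only the transformer output is returned, as before.
encoded_image_only = model.encode_image(image)
```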
23 changes: 10 additions & 13 deletions torchmultimodal/modules/losses/flava.py
@@ -253,22 +253,16 @@ def __init__(
else:
self.logit_scale = nn.Parameter(logit_scale * torch.ones([]))

self.image_projection = nn.Linear(image_embedding_size, projection_size)
self.text_projection = nn.Linear(text_embedding_size, projection_size)
self.image_embedding_index = image_embedding_index
self.text_embedding_index = text_embedding_index

def forward(
self,
image_sequence: Tensor,
text_sequence: Tensor,
mask: Tensor,
) -> FLAVAGlobalContrastiveLossOutput:
text_embedding = nn.functional.normalize(
self.text_projection(text_sequence[:, self.text_embedding_index, :]), dim=-1
)

text_embedding = nn.functional.normalize(text_sequence, dim=-1)
image_embedding = nn.functional.normalize(
self.image_projection(image_sequence[:, self.image_embedding_index, :]),
image_sequence,
dim=-1,
)

@@ -380,13 +374,16 @@ def forward(
itm_labels: Optional[Tensor] = None,
mim_labels: Optional[Tensor] = None,
mlm_labels: Optional[Tensor] = None,
projected_image_embeddings: Optional[Tensor] = None,
projected_text_embeddings: Optional[Tensor] = None,
) -> FLAVAPretrainingLossOutput:
outputs = FLAVAPretrainingLossOutput()
pos_mask = None

# Check multimodal_masked_sequence to make sure this is unimodal case
# This specific case can though be backpropagated directly as MIM is independent of
# text, but that is a research question :)

if (
image_masked_sequence is not None
and self.mim_weight > 0
@@ -466,13 +463,13 @@ def forward(
outputs.losses.mmm_image_loss = outputs.mmm_image_output.loss

if (
image_sequence is not None
and text_sequence is not None
projected_image_embeddings is not None
and projected_text_embeddings is not None
and self.contrastive_loss_weight > 0
):
outputs.global_contrastive_output = self.contrastive_loss(
image_sequence,
text_sequence,
projected_image_embeddings,
projected_text_embeddings,
pos_mask,
)
outputs.global_contrastive_output.loss *= self.contrastive_loss_weight
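The loss-side counterpart of the change: FLAVAGlobalContrastiveLoss no longer owns projection layers and now expects already-projected embeddings, which FLAVAPretrainingLoss.forward receives as projected_image_embeddings/projected_text_embeddings. A small sketch of the new contract follows; the logit step at the end is the standard CLIP-style similarity and is an assumption, not copied from this diff.

```python
# Sketch of the updated contrastive-loss contract with pre-projected inputs;
# shapes and the final logit step are illustrative assumptions.
import torch
from torch import nn

batch, proj_size = 4, 768
projected_image_embeddings = torch.randn(batch, proj_size)  # from encode_image(..., projection=True)
projected_text_embeddings = torch.randn(batch, proj_size)   # from encode_text(..., projection=True)

# The loss now only normalizes its inputs; no image_projection/text_projection here.
image_embedding = nn.functional.normalize(projected_image_embeddings, dim=-1)
text_embedding = nn.functional.normalize(projected_text_embeddings, dim=-1)

# Assumed CLIP-style similarity step, shown for completeness.
logit_scale = torch.ones([]).exp()
logits_per_image = logit_scale * image_embedding @ text_embedding.t()
logits_per_text = logits_per_image.t()
```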
