diff --git a/test/models/flava/test_flava.py b/test/models/flava/test_flava.py
index 900f319a1..10df6e23c 100644
--- a/test/models/flava/test_flava.py
+++ b/test/models/flava/test_flava.py
@@ -167,6 +167,8 @@ def setUp(self):
             mm_encoder=mm_encoder,
             image_to_mm_projection=image_to_mm_projection,
             text_to_mm_projection=text_to_mm_projection,
+            text_projection=nn.Identity(),
+            image_projection=nn.Identity(),
         )

     def _assert_empty(self, field):
diff --git a/torchmultimodal/models/flava/flava_model.py b/torchmultimodal/models/flava/flava_model.py
index 56ffaa741..513698217 100644
--- a/torchmultimodal/models/flava/flava_model.py
+++ b/torchmultimodal/models/flava/flava_model.py
@@ -37,8 +37,17 @@
 FLAVAOutput = namedtuple(
     "FLAVAOutput",
-    ["image", "image_masked", "text", "text_masked", "multimodal", "multimodal_masked"],
-    defaults=(None, None, None, None, None, None),
+    [
+        "image",
+        "image_masked",
+        "text",
+        "text_masked",
+        "multimodal",
+        "multimodal_masked",
+        "projected_image_embeddings",
+        "projected_text_embeddings",
+    ],
+    defaults=(None, None, None, None, None, None, None, None),
 )
 FLAVAOutput.__annotations__ = {
     "image": FLAVATransformerOutput,
@@ -51,7 +60,7 @@
 FLAVA_FOR_PRETRAINED_MAPPING = {
-    "flava_full": "https://huggingface.co/aps/flava_full_pretrained_encoders_torchmm/resolve/main/pytorch_model.bin",
+    "flava_full": "https://download.pytorch.org/models/multimodal/flava/flava_for_pretraining.pt"
 }
@@ -99,6 +108,8 @@ def __init__(
         mm_encoder: nn.Module,
         image_to_mm_projection: nn.Module,
         text_to_mm_projection: nn.Module,
+        text_projection: nn.Module,
+        image_projection: nn.Module,
         **kwargs: Any,
     ) -> None:
         super().__init__()
@@ -107,6 +118,8 @@ def __init__(
         self.mm_encoder = mm_encoder
         self.image_to_mm_projection = image_to_mm_projection
         self.text_to_mm_projection = text_to_mm_projection
+        self.text_projection = text_projection
+        self.image_projection = image_projection

     def forward(
         self,
@@ -125,30 +138,50 @@ def forward(
         else:
             required_embedding = "text"

-        image_outputs = self._encode_data_to_embeddings(
+        image_encoding_out = self._encode_data_to_embeddings(
             image,
             required_embedding,
             ["image", "mm"],
-            self.encode_image,
+            partial(self.encode_image, projection=True),
         )
-        text_outputs = self._encode_data_to_embeddings(
+        if len(image_encoding_out) == 2:
+            image_outputs, projected_image_embeddings = (
+                image_encoding_out[0],
+                image_encoding_out[1],
+            )
+        else:
+            image_outputs = image_encoding_out
+            projected_image_embeddings = None
+
+        text_encoding_out = self._encode_data_to_embeddings(
             text,
             required_embedding,
             ["text", "mm"],
-            self.encode_text,
+            partial(self.encode_text, projection=True),
         )
+        if len(text_encoding_out) == 2:
+            text_outputs, projected_text_embeddings = (
+                text_encoding_out[0],
+                text_encoding_out[1],
+            )
+        else:
+            text_outputs = text_encoding_out
+            projected_text_embeddings = None
+
         image_masked_outputs = self._encode_data_to_embeddings(
             image,
             required_embedding,
             ["image", "mm"],
             partial(self.encode_image, image_patches_mask=image_patches_mask),
         )
+        assert type(image_masked_outputs) == FLAVATransformerOutput
         text_masked_outputs = self._encode_data_to_embeddings(
             text_masked,
             required_embedding,
             ["text", "mm"],
             self.encode_text,
         )
+        assert type(text_masked_outputs) == FLAVATransformerOutput

         multimodal_outputs = FLAVATransformerOutput()
         multimodal_masked_outputs = FLAVATransformerOutput()
@@ -182,39 +215,60 @@ def forward(
             text_masked=text_masked_outputs,
             multimodal=multimodal_outputs,
             multimodal_masked=multimodal_masked_outputs,
+            projected_image_embeddings=projected_image_embeddings,
+            projected_text_embeddings=projected_text_embeddings,
         )

     def encode_image(
-        self, image: Tensor, image_patches_mask: Optional[Tensor] = None
-    ) -> Optional[FLAVATransformerOutput]:
+        self,
+        image: Tensor,
+        image_patches_mask: Optional[Tensor] = None,
+        projection: bool = False,
+    ) -> Union[Tuple[FLAVATransformerOutput, Tensor], Optional[FLAVATransformerOutput]]:
         if image_patches_mask is not None:
-            return self.image_encoder(image, image_patches_mask)
+            encoded_image = self.image_encoder(image, image_patches_mask)
         else:
-            return self.image_encoder(image)
+            encoded_image = self.image_encoder(image)
+        if projection:
+            projected_embeddings = self.image_projection(
+                encoded_image.last_hidden_state[:, 0, :]
+            )
+            return encoded_image, projected_embeddings
+        return encoded_image

     def encode_text(
-        self,
-        text: Tensor,
-        text_mask: Optional[Tensor] = None,
-    ) -> Optional[FLAVATransformerOutput]:
+        self, text: Tensor, text_mask: Optional[Tensor] = None, projection: bool = False
+    ) -> Union[Tuple[FLAVATransformerOutput, Tensor], Optional[FLAVATransformerOutput]]:
         # TODO(asg): Give proper parameter names when implementing text encoder
-        return self.text_encoder(
+        encoded_text = self.text_encoder(
             input_ids=text,
             attention_mask=text_mask,
         )
+        if projection:
+            projected_embeddings = self.text_projection(
+                encoded_text.last_hidden_state[:, 0, :]
+            )
+            return encoded_text, projected_embeddings
+        return encoded_text

     def _encode_data_to_embeddings(
         self,
         data: Optional[Tensor],
         selected_head_encoder: EMBEDDING_OPTIONS,
         encoder_options: List[EMBEDDING_OPTIONS],
-        encode_callable: Callable[..., FLAVATransformerOutput],
-    ) -> Optional[FLAVATransformerOutput]:
-        output = FLAVATransformerOutput()
+        encode_callable: Callable[
+            ...,
+            Union[
+                Tuple[FLAVATransformerOutput, Tensor], Optional[FLAVATransformerOutput]
+            ],
+        ],
+    ) -> Union[Tuple[FLAVATransformerOutput, Tensor], Optional[FLAVATransformerOutput]]:
+        output: Union[
+            Tuple[FLAVATransformerOutput, Tensor], FLAVATransformerOutput
+        ] = FLAVATransformerOutput()
         if data is not None and selected_head_encoder in encoder_options:
             output = encode_callable(data)
-
         return output

     def encode_mm(
@@ -251,9 +305,9 @@ def encode_image(
         image: Tensor,
         cls_index: int = 0,
     ) -> Tensor:
-        transformer_output = self.model.encode_image(image)
-        embeddings = transformer_output.last_hidden_state
-        return self.loss.contrastive_loss.image_projection(embeddings[:, cls_index, :])
+        encoded_result = self.model.encode_image(image, projection=True)
+        encoded_image = encoded_result[1]
+        return encoded_image

     def encode_text(
         self,
@@ -261,9 +315,9 @@ def encode_text(
         text_mask: Optional[Tensor] = None,
         cls_index: int = 0,
     ) -> Tensor:
-        transformer_output = self.model.encode_text(text, text_mask)
-        embeddings = transformer_output.last_hidden_state
-        return self.loss.contrastive_loss.text_projection(embeddings[:, cls_index, :])
+        encoded_result = self.model.encode_text(text, text_mask, projection=True)
+        encoded_text = encoded_result[1]
+        return encoded_text

     # TODO: Add options to enable losses selectively
     def forward(
@@ -305,6 +359,8 @@ def forward(
             itm_labels=itm_labels,
             mim_labels=image_labels,
             mlm_labels=mlm_labels,
+            projected_image_embeddings=flava_output.projected_image_embeddings,
+            projected_text_embeddings=flava_output.projected_text_embeddings,
         )


@@ -392,6 +448,8 @@ def flava_model(
     multimodal_intermediate_activation: Callable[..., Tensor] = nn.functional.gelu,
     multimodal_attention_probs_dropout_prob: float = 0.0,
     multimodal_layer_norm_eps: float = 1e-12,
+    # projection
+    text_and_image_proj_size: int = 768,
     **kwargs: Any,
 ) -> FLAVAModel:
     image_encoder = flava_image_encoder(
@@ -437,12 +495,17 @@ def flava_model(
     image_to_mm_projection = nn.Linear(image_hidden_size, multimodal_hidden_size)
     text_to_mm_projection = nn.Linear(text_hidden_size, multimodal_hidden_size)

+    image_projection = nn.Linear(image_hidden_size, text_and_image_proj_size)
+    text_projection = nn.Linear(text_hidden_size, text_and_image_proj_size)
+
     return FLAVAModel(
         image_encoder=image_encoder,
         text_encoder=text_encoder,
         mm_encoder=mm_encoder,
         image_to_mm_projection=image_to_mm_projection,
         text_to_mm_projection=text_to_mm_projection,
+        text_projection=text_projection,
+        image_projection=image_projection,
     )
diff --git a/torchmultimodal/modules/losses/flava.py b/torchmultimodal/modules/losses/flava.py
index 9c4e93037..c62d4f5ba 100644
--- a/torchmultimodal/modules/losses/flava.py
+++ b/torchmultimodal/modules/losses/flava.py
@@ -253,22 +253,16 @@ def __init__(
         else:
             self.logit_scale = nn.Parameter(logit_scale * torch.ones([]))

-        self.image_projection = nn.Linear(image_embedding_size, projection_size)
-        self.text_projection = nn.Linear(text_embedding_size, projection_size)
-        self.image_embedding_index = image_embedding_index
-        self.text_embedding_index = text_embedding_index
-
     def forward(
         self,
         image_sequence: Tensor,
         text_sequence: Tensor,
         mask: Tensor,
     ) -> FLAVAGlobalContrastiveLossOutput:
-        text_embedding = nn.functional.normalize(
-            self.text_projection(text_sequence[:, self.text_embedding_index, :]), dim=-1
-        )
+
+        text_embedding = nn.functional.normalize(text_sequence, dim=-1)
         image_embedding = nn.functional.normalize(
-            self.image_projection(image_sequence[:, self.image_embedding_index, :]),
+            image_sequence,
             dim=-1,
         )

@@ -380,6 +374,8 @@ def forward(
         itm_labels: Optional[Tensor] = None,
         mim_labels: Optional[Tensor] = None,
         mlm_labels: Optional[Tensor] = None,
+        projected_image_embeddings: Optional[Tensor] = None,
+        projected_text_embeddings: Optional[Tensor] = None,
     ) -> FLAVAPretrainingLossOutput:
         outputs = FLAVAPretrainingLossOutput()
         pos_mask = None
@@ -387,6 +383,7 @@ def forward(
         # Check multimodal_masked_sequence to make sure this is unimodal case
         # This specific case can though be backpropagated directly as MIM is independent of
         # text, but that is a research question :)
+
         if (
             image_masked_sequence is not None
             and self.mim_weight > 0
@@ -466,13 +463,13 @@ def forward(
             outputs.losses.mmm_image_loss = outputs.mmm_image_output.loss

         if (
-            image_sequence is not None
-            and text_sequence is not None
+            projected_image_embeddings is not None
+            and projected_text_embeddings is not None
             and self.contrastive_loss_weight > 0
         ):
             outputs.global_contrastive_output = self.contrastive_loss(
-                image_sequence,
-                text_sequence,
+                projected_image_embeddings,
+                projected_text_embeddings,
                 pos_mask,
             )
             outputs.global_contrastive_output.loss *= self.contrastive_loss_weight
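Usage note (not part of the patch): the sketch below shows how the projection heads moved into FLAVAModel by this diff are meant to be exercised. Only encode_image/encode_text with projection=True and the text_and_image_proj_size argument are taken directly from the diff; the builder defaults, image resolution, token ids, and batch shapes are assumptions for illustration.

# Hypothetical usage sketch for the new projection path.
# Assumptions: flava_model() defaults build a complete model, 224x224 images,
# BERT-style token ids; shapes are illustrative only.
import torch

from torchmultimodal.models.flava.flava_model import flava_model

model = flava_model(text_and_image_proj_size=768)
model.eval()

image = torch.randn(2, 3, 224, 224)        # assumed image resolution
text = torch.randint(0, 30522, (2, 77))    # assumed vocab size / sequence length

with torch.no_grad():
    # With projection=True, encode_image/encode_text return a tuple of
    # (transformer output, projected CLS embedding), per the diff above.
    image_out, projected_image = model.encode_image(image, projection=True)
    text_out, projected_text = model.encode_text(text, projection=True)

print(projected_image.shape)  # torch.Size([2, 768])
print(projected_text.shape)   # torch.Size([2, 768])

These projected tensors are what FLAVAPretrainingLoss now receives as projected_image_embeddings / projected_text_embeddings, so the global contrastive loss no longer owns its own projection layers.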