From 6bf3779a064dc72cde48793521a5be151695fc62 Mon Sep 17 00:00:00 2001 From: Qiang Zhang Date: Fri, 19 Jan 2024 20:07:42 -0800 Subject: [PATCH] Add optional field for multimodal pooled embeddings (#519) Summary: Pull Request resolved: https://github.com/facebookresearch/multimodal/pull/519 MaMMUT (in next diff) could output the multimodal pooled embeddings directly. Reviewed By: ebsmothers, satyanshukla Differential Revision: D52821534 fbshipit-source-id: a1251365384f03dcdadab7d2984fdc5d277ca26b --- tests/models/coca/test_coca_model.py | 4 ++-- torchmultimodal/models/coca/coca_model.py | 9 +++++---- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/tests/models/coca/test_coca_model.py b/tests/models/coca/test_coca_model.py index 1c4c66e7..996be8f6 100644 --- a/tests/models/coca/test_coca_model.py +++ b/tests/models/coca/test_coca_model.py @@ -10,7 +10,7 @@ from torchmultimodal.models.coca.coca_model import ( coca_vit, CoCaForPretraining, - CoCaModelOutput, + MultimodalOutput, ) @@ -105,7 +105,7 @@ def expected( ): pooled_val = 0.3536 logit_val = 8.0 - return CoCaModelOutput( + return MultimodalOutput( image_pooled_output=pooled_val * torch.ones(batch_size, attention_pooler_output_dim), text_pooled_output=pooled_val * torch.ones(batch_size, text_output_dim), diff --git a/torchmultimodal/models/coca/coca_model.py b/torchmultimodal/models/coca/coca_model.py index 3a100f7a..f1e42ae9 100644 --- a/torchmultimodal/models/coca/coca_model.py +++ b/torchmultimodal/models/coca/coca_model.py @@ -24,10 +24,11 @@ ) -class CoCaModelOutput(NamedTuple): +class MultimodalOutput(NamedTuple): image_pooled_output: Tensor text_pooled_output: Tensor multimodal_embeddings: Tensor + multimodal_pooled_embeddings: Optional[Tensor] = None class CoCaModel(nn.Module): @@ -67,7 +68,7 @@ def __init__( def forward( self, images: Tensor, texts: Tensor, text_padding_mask: Optional[Tensor] = None - ) -> CoCaModelOutput: + ) -> MultimodalOutput: """ Args: images (Tensor): Tensor of size (bsz, c, h, w) containing image pixels. @@ -75,7 +76,7 @@ def forward( text_padding_mask (Optional[Tensor]): Boolean mask indicating padded tokens. True for unpadded tokens, False for padded tokens. Default: None Returns: - CoCaModelOutput containing pooled image embeddings, text embeddings, + MultimodalOutput containing pooled image embeddings, text embeddings, and multimodal embeddings. """ @@ -122,7 +123,7 @@ def forward( text_tokens, captioning_image_embeddings ) - return CoCaModelOutput( + return MultimodalOutput( contrastive_image_embeddings, contrastive_text_embeddings, multimodal_embeddings,