Skip to content

Commit

Permalink
Add optional field for multimodal pooled embeddings (facebookresearch…
Browse files Browse the repository at this point in the history
…#519)

Summary:
Pull Request resolved: facebookresearch#519

MaMMUT (in next diff) could output the multimodal pooled embeddings directly.

Reviewed By: ebsmothers, satyanshukla

Differential Revision: D52821534

fbshipit-source-id: a1251365384f03dcdadab7d2984fdc5d277ca26b
  • Loading branch information
zhangtemplar authored and facebook-github-bot committed Jan 20, 2024
1 parent 63c629a commit 6bf3779
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 6 deletions.
4 changes: 2 additions & 2 deletions tests/models/coca/test_coca_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from torchmultimodal.models.coca.coca_model import (
coca_vit,
CoCaForPretraining,
CoCaModelOutput,
MultimodalOutput,
)


Expand Down Expand Up @@ -105,7 +105,7 @@ def expected(
):
pooled_val = 0.3536
logit_val = 8.0
return CoCaModelOutput(
return MultimodalOutput(
image_pooled_output=pooled_val
* torch.ones(batch_size, attention_pooler_output_dim),
text_pooled_output=pooled_val * torch.ones(batch_size, text_output_dim),
Expand Down
9 changes: 5 additions & 4 deletions torchmultimodal/models/coca/coca_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,11 @@
)


class CoCaModelOutput(NamedTuple):
class MultimodalOutput(NamedTuple):
image_pooled_output: Tensor
text_pooled_output: Tensor
multimodal_embeddings: Tensor
multimodal_pooled_embeddings: Optional[Tensor] = None


class CoCaModel(nn.Module):
Expand Down Expand Up @@ -67,15 +68,15 @@ def __init__(

def forward(
self, images: Tensor, texts: Tensor, text_padding_mask: Optional[Tensor] = None
) -> CoCaModelOutput:
) -> MultimodalOutput:
"""
Args:
images (Tensor): Tensor of size (bsz, c, h, w) containing image pixels.
texts (Tensor): Tensor of size (bsz, seq_len) containing text tokens.
text_padding_mask (Optional[Tensor]): Boolean mask indicating padded tokens.
True for unpadded tokens, False for padded tokens. Default: None
Returns:
CoCaModelOutput containing pooled image embeddings, text embeddings,
MultimodalOutput containing pooled image embeddings, text embeddings,
and multimodal embeddings.
"""

Expand Down Expand Up @@ -122,7 +123,7 @@ def forward(
text_tokens, captioning_image_embeddings
)

return CoCaModelOutput(
return MultimodalOutput(
contrastive_image_embeddings,
contrastive_text_embeddings,
multimodal_embeddings,
Expand Down

0 comments on commit 6bf3779

Please sign in to comment.