From 4cc6e0d263a5c97ca0732add541755ef8f0d105a Mon Sep 17 00:00:00 2001
From: Arkabandhu Chowdhury
Date: Tue, 14 Nov 2023 09:19:36 -0800
Subject: [PATCH] Correcting cls token issues and enabling unified ViT-MAE

Summary:
1. Made cls token optional for MAE encoder and decoder.
2. Modified reconstruction lightning module to work with unified modalities
3. Added unified configs

Reviewed By: tsungyu

Differential Revision: D49884612

Privacy Context Container: L1098566

fbshipit-source-id: 4ce66046f5002958b211778e916b3dcdd9b08755
---
 .../masked_auto_encoder/position_embeddings.py | 17 +++++++++++++----
 1 file changed, 13 insertions(+), 4 deletions(-)

diff --git a/torchmultimodal/models/masked_auto_encoder/position_embeddings.py b/torchmultimodal/models/masked_auto_encoder/position_embeddings.py
index 2d83ba2cd..ee4f857fa 100644
--- a/torchmultimodal/models/masked_auto_encoder/position_embeddings.py
+++ b/torchmultimodal/models/masked_auto_encoder/position_embeddings.py
@@ -11,7 +11,10 @@
 
 
 def get_3d_sin_cos_embeddings(
-    embed_dim: int, temporal_size: int, spatial_size: Tuple[int, int]
+    embed_dim: int,
+    temporal_size: int,
+    spatial_size: Tuple[int, int],
+    include_cls_embed: bool = True,
 ) -> Tensor:
     """
     3d position sin cos embeddings. This implementation has been adapted from internal
@@ -20,6 +23,7 @@
         embed_dim (int): embedding dimension of the position embedding
         temporal_size (int): temporal input dimensions of the grid
         spatial_size (Tuple[int, int]): spatial input dimensions of the grid
+        include_cls_embed (bool): Whether to include positional embedding for [CLS] token. Defaults to True.
     return:
         embed (Tensor[int]): [1+temporal_size*spatial_size[0]*spatial_size[1], embed_dim] (w/ cls_token)
     """
@@ -60,17 +64,21 @@ def get_3d_sin_cos_embeddings(
 
     embed = embed.reshape([-1, embed_dim])  # [T*H*W, D]
     # Add pos embed for cls token
-    embed = torch.cat([torch.zeros(1, embed_dim), embed], dim=0)
+    if include_cls_embed:
+        embed = torch.cat([torch.zeros(1, embed_dim), embed], dim=0)
     embed = embed.unsqueeze(0)
     return embed
 
 
-def get_2d_sin_cos_embeddings(embed_dim: int, input_size: Tuple[int, int]) -> Tensor:
+def get_2d_sin_cos_embeddings(
+    embed_dim: int, input_size: Tuple[int, int], include_cls_embed: bool = True
+) -> Tensor:
     """
     2d position sin cos embeddings.
     Args:
         embed_dim (int): embedding dimension of the position embedding
         input_size (Tuple[int, int]): input dimensions of the grid
+        include_cls_embed (bool): Whether to include positional embedding for [CLS] token. Defaults to True.
     """
 
     # dim gets halved twice, once for h and w axis and once for sin and cos
@@ -85,7 +93,8 @@ def get_2d_sin_cos_embeddings(embed_dim: int, input_size: Tuple[int, int]) -> Te
     # h*w x embed_dim
     embed = torch.cat([embed_w, embed_h], dim=1)
     # Add pos embed for cls token
-    embed = torch.cat([torch.zeros(1, embed_dim), embed], dim=0)
+    if include_cls_embed:
+        embed = torch.cat([torch.zeros(1, embed_dim), embed], dim=0)
     embed = embed.unsqueeze(0)
     return embed
 
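
Usage note (illustrative sketch, not part of the patch): the snippet below assumes the patch above has been applied and uses arbitrary example sizes (a ViT-B style embed_dim of 768, an 8-frame clip on a 14x14 patch grid, chosen to satisfy the helpers' internal dimension splits). It exercises only the two public helpers touched by this diff and their new include_cls_embed flag; the expected shapes follow from the docstrings and the trailing unsqueeze(0).

    from torchmultimodal.models.masked_auto_encoder.position_embeddings import (
        get_2d_sin_cos_embeddings,
        get_3d_sin_cos_embeddings,
    )

    # Default (include_cls_embed=True) keeps the old behavior: a zero
    # position embedding is prepended for the [CLS] token.
    embed_2d = get_2d_sin_cos_embeddings(embed_dim=768, input_size=(14, 14))
    assert embed_2d.shape == (1, 1 + 14 * 14, 768)

    # With include_cls_embed=False only the patch positions are returned.
    embed_2d_no_cls = get_2d_sin_cos_embeddings(
        embed_dim=768, input_size=(14, 14), include_cls_embed=False
    )
    assert embed_2d_no_cls.shape == (1, 14 * 14, 768)

    # The 3d (video) variant takes the same flag.
    embed_3d_no_cls = get_3d_sin_cos_embeddings(
        embed_dim=768,
        temporal_size=8,
        spatial_size=(14, 14),
        include_cls_embed=False,
    )
    assert embed_3d_no_cls.shape == (1, 8 * 14 * 14, 768)

Because the flag defaults to True, existing callers that rely on the [CLS] slot are unaffected; only configurations that drop the cls token need to pass include_cls_embed=False.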