Correcting cls token issues and enabling unified ViT-MAE
Summary:
1. Made the cls token optional for the MAE encoder and decoder.
2. Modified the reconstruction lightning module to work with unified modalities.
3. Added unified configs.

Reviewed By: tsungyu

Differential Revision: D49884612

Privacy Context Container: L1098566

fbshipit-source-id: 4ce66046f5002958b211778e916b3dcdd9b08755
Arkabandhu Chowdhury authored and facebook-github-bot committed Nov 14, 2023
1 parent e6b92b5 commit 4cc6e0d
Showing 1 changed file with 13 additions and 4 deletions.
torchmultimodal/models/masked_auto_encoder/position_embeddings.py (13 additions, 4 deletions)
@@ -11,7 +11,10 @@


 def get_3d_sin_cos_embeddings(
-    embed_dim: int, temporal_size: int, spatial_size: Tuple[int, int]
+    embed_dim: int,
+    temporal_size: int,
+    spatial_size: Tuple[int, int],
+    include_cls_embed: bool = True,
 ) -> Tensor:
     """
     3d position sin cos embeddings. This implementation has been adapted from internal
@@ -20,6 +23,7 @@ def get_3d_sin_cos_embeddings(
         embed_dim (int): embedding dimension of the position embedding
         temporal_size (int): temporal input dimensions of the grid
         spatial_size (Tuple[int, int]): spatial input dimensions of the grid
+        include_cls_embed (bool): Whether to include positional embedding for [CLS] token. Defaults to True.
     return:
         embed (Tensor[int]): [1+temporal_size*spatial_size[0]*spatial_size[1], embed_dim] (w/ cls_token)
     """
@@ -60,17 +64,21 @@ def get_3d_sin_cos_embeddings(
     embed = embed.reshape([-1, embed_dim])  # [T*H*W, D]

     # Add pos embed for cls token
-    embed = torch.cat([torch.zeros(1, embed_dim), embed], dim=0)
+    if include_cls_embed:
+        embed = torch.cat([torch.zeros(1, embed_dim), embed], dim=0)
     embed = embed.unsqueeze(0)
     return embed


-def get_2d_sin_cos_embeddings(embed_dim: int, input_size: Tuple[int, int]) -> Tensor:
+def get_2d_sin_cos_embeddings(
+    embed_dim: int, input_size: Tuple[int, int], include_cls_embed: bool = True
+) -> Tensor:
     """
     2d position sin cos embeddings.
     Args:
         embed_dim (int): embedding dimension of the position embedding
         input_size (Tuple[int, int]): input dimensions of the grid
+        include_cls_embed (bool): Whether to include positional embedding for [CLS] token. Defaults to True.
     """

     # dim gets halved twice, once for h and w axis and once for sin and cos
@@ -85,7 +93,8 @@ def get_2d_sin_cos_embeddings(embed_dim: int, input_size: Tuple[int, int]) -> Tensor:
     # h*w x embed_dim
     embed = torch.cat([embed_w, embed_h], dim=1)
     # Add pos embed for cls token
-    embed = torch.cat([torch.zeros(1, embed_dim), embed], dim=0)
+    if include_cls_embed:
+        embed = torch.cat([torch.zeros(1, embed_dim), embed], dim=0)
     embed = embed.unsqueeze(0)
     return embed

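For context, a minimal usage sketch of the new flag (not part of the commit; embed_dim=768, the 14x14 grid, and temporal_size=4 are illustrative assumptions). The output shapes follow the reshape/cat logic shown in the diff above:

# Hypothetical usage sketch; parameter values are illustrative, not from this commit.
from torchmultimodal.models.masked_auto_encoder.position_embeddings import (
    get_2d_sin_cos_embeddings,
    get_3d_sin_cos_embeddings,
)

# Default behavior is unchanged: a zero embedding is prepended for the
# [CLS] token, so a 14x14 grid yields 1 + 196 = 197 positions.
with_cls = get_2d_sin_cos_embeddings(embed_dim=768, input_size=(14, 14))
assert with_cls.shape == (1, 197, 768)

# With include_cls_embed=False the [CLS] row is omitted: 196 positions.
without_cls = get_2d_sin_cos_embeddings(
    embed_dim=768, input_size=(14, 14), include_cls_embed=False
)
assert without_cls.shape == (1, 196, 768)

# The 3d variant gets the same flag, over temporal_size * H * W positions.
video = get_3d_sin_cos_embeddings(
    embed_dim=768, temporal_size=4, spatial_size=(14, 14), include_cls_embed=False
)
assert video.shape == (1, 4 * 14 * 14, 768)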