diff --git a/torchmultimodal/models/masked_auto_encoder/position_embeddings.py b/torchmultimodal/models/masked_auto_encoder/position_embeddings.py
index 2d83ba2c..ee4f857f 100644
--- a/torchmultimodal/models/masked_auto_encoder/position_embeddings.py
+++ b/torchmultimodal/models/masked_auto_encoder/position_embeddings.py
@@ -11,7 +11,10 @@
 
 
 def get_3d_sin_cos_embeddings(
-    embed_dim: int, temporal_size: int, spatial_size: Tuple[int, int]
+    embed_dim: int,
+    temporal_size: int,
+    spatial_size: Tuple[int, int],
+    include_cls_embed: bool = True,
 ) -> Tensor:
     """
     3d position sin cos embeddings. This implementation has been adapted from internal
@@ -20,6 +23,7 @@ def get_3d_sin_cos_embeddings(
         embed_dim (int): embedding dimension of the position embedding
         temporal_size (int): temporal input dimensions of the grid
         spatial_size (Tuple[int, int]): spatial input dimensions of the grid
+        include_cls_embed (bool): Whether to include positional embedding for [CLS] token. Defaults to True.
     return:
         embed (Tensor[int]): [1+temporal_size*spatial_size[0]*spatial_size[1], embed_dim] (w/ cls_token)
     """
@@ -60,17 +64,21 @@ def get_3d_sin_cos_embeddings(
     embed = embed.reshape([-1, embed_dim])  # [T*H*W, D]
 
     # Add pos embed for cls token
-    embed = torch.cat([torch.zeros(1, embed_dim), embed], dim=0)
+    if include_cls_embed:
+        embed = torch.cat([torch.zeros(1, embed_dim), embed], dim=0)
     embed = embed.unsqueeze(0)
     return embed
 
 
-def get_2d_sin_cos_embeddings(embed_dim: int, input_size: Tuple[int, int]) -> Tensor:
+def get_2d_sin_cos_embeddings(
+    embed_dim: int, input_size: Tuple[int, int], include_cls_embed: bool = True
+) -> Tensor:
     """
     2d position sin cos embeddings.
     Args:
         embed_dim (int): embedding dimension of the position embedding
         input_size (Tuple[int, int]): input dimensions of the grid
+        include_cls_embed (bool): Whether to include positional embedding for [CLS] token. Defaults to True.
     """
 
     # dim gets halved twice, once for h and w axis and once for sin and cos
@@ -85,7 +93,8 @@ def get_2d_sin_cos_embeddings(embed_dim: int, input_size: Tuple[int, int]) -> Te
     # h*w x embed_dim
     embed = torch.cat([embed_w, embed_h], dim=1)
 
     # Add pos embed for cls token
-    embed = torch.cat([torch.zeros(1, embed_dim), embed], dim=0)
+    if include_cls_embed:
+        embed = torch.cat([torch.zeros(1, embed_dim), embed], dim=0)
     embed = embed.unsqueeze(0)
     return embed
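
Below is a minimal usage sketch, not part of the patch, illustrating the effect of the new include_cls_embed flag. It assumes the two functions are imported from the module path shown in the diff; embed_dim, the grid sizes, and the shape checks are illustrative assumptions only.

# Hypothetical usage sketch for the new include_cls_embed flag. Values are
# illustrative and assume embed_dim satisfies whatever divisibility the
# (elided) function bodies require, since the dim is split across axes and
# across sin/cos.
from torchmultimodal.models.masked_auto_encoder.position_embeddings import (
    get_2d_sin_cos_embeddings,
    get_3d_sin_cos_embeddings,
)

embed_dim = 768
spatial_size = (14, 14)  # e.g. a 224x224 image with 16x16 patches
temporal_size = 8

# Default behaviour is unchanged: a zero embedding is prepended for the [CLS] token.
with_cls = get_2d_sin_cos_embeddings(embed_dim, spatial_size)
assert with_cls.shape == (1, 1 + 14 * 14, embed_dim)

# With the flag disabled, only the patch positions are embedded.
without_cls = get_2d_sin_cos_embeddings(embed_dim, spatial_size, include_cls_embed=False)
assert without_cls.shape == (1, 14 * 14, embed_dim)

# The 3d variant behaves the same way over temporal_size * H * W positions.
video_embed = get_3d_sin_cos_embeddings(
    embed_dim, temporal_size, spatial_size, include_cls_embed=False
)
assert video_embed.shape == (1, temporal_size * 14 * 14, embed_dim)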