diff --git a/torchmultimodal/models/masked_auto_encoder/model.py b/torchmultimodal/models/masked_auto_encoder/model.py
index e0d77394..56383b2f 100644
--- a/torchmultimodal/models/masked_auto_encoder/model.py
+++ b/torchmultimodal/models/masked_auto_encoder/model.py
@@ -13,6 +13,11 @@
     get_2d_sin_cos_embeddings,
 )
 from torchmultimodal.models.masked_auto_encoder.swin_decoder import SwinTransformer
+from torchmultimodal.modules.encoders.vision_transformer import (
+    VisionTransformer,
+    vit_b_16,
+    vit_l_16,
+)
 from torchmultimodal.modules.layers.patch_embedding import PatchEmbeddings
 from torchmultimodal.modules.layers.transformer import (
     TransformerEncoder,
@@ -20,6 +25,13 @@
 )
 
 
+MAE_MODEL_MAPPING = {
+    "vit_b16_image": "https://download.pytorch.org/models/multimodal/mae/mae_pretrained_vit_base.pth",
+    "vit_l16_image": "https://download.pytorch.org/models/multimodal/mae/mae_pretrained_vit_large.pth",
+    "vit_b16_audio": "https://download.pytorch.org/models/multimodal/audio_mae/audio_mae_pretrained_vit_base.pth",
+}
+
+
 class MAEOutput(NamedTuple):
     encoder_output: Union[TransformerOutput, Tensor]
     decoder_pred: Optional[Tensor] = None
@@ -324,6 +336,16 @@ def vit_l_16_image_mae() -> MaskedAutoEncoder:
     )
 
 
+def vit_b_16_image_mae_encoder(pretrained: bool = False) -> VisionTransformer:
+    ckpt_path = MAE_MODEL_MAPPING["vit_b16_image"] if pretrained else None
+    return vit_b_16(final_layer_norm_eps=None, ckpt_path=ckpt_path)
+
+
+def vit_l_16_image_mae_encoder(pretrained: bool = False) -> VisionTransformer:
+    ckpt_path = MAE_MODEL_MAPPING["vit_l16_image"] if pretrained else None
+    return vit_l_16(final_layer_norm_eps=None, ckpt_path=ckpt_path)
+
+
 def audio_mae(
     *,
     # patch embedding
@@ -449,3 +471,13 @@ def vit_l_16_audio_mae() -> MaskedAutoEncoder:
         decoder_heads=16,
         decoder_dim_feedforward=2048,
     )
+
+
+def vit_b_16_audio_mae_encoder(pretrained: bool = False) -> VisionTransformer:
+    ckpt_path = MAE_MODEL_MAPPING["vit_b16_audio"] if pretrained else None
+    return vit_b_16(
+        final_layer_norm_eps=None,
+        num_channels=1,
+        image_size=(1024, 128),
+        ckpt_path=ckpt_path,
+    )
diff --git a/torchmultimodal/modules/encoders/vision_transformer.py b/torchmultimodal/modules/encoders/vision_transformer.py
index f051a490..4e170ee6 100644
--- a/torchmultimodal/modules/encoders/vision_transformer.py
+++ b/torchmultimodal/modules/encoders/vision_transformer.py
@@ -14,6 +14,7 @@
     TransformerEncoder,
     TransformerOutput,
 )
+from torchmultimodal.utils.common import load_module_from_url
 
 
 class VisionTransformer(nn.Module):
@@ -148,6 +149,7 @@ def vision_transformer(
     drop_path_rate: Optional[float] = None,
     patch_drop_rate: Optional[Union[float, Tuple[float, float]]] = None,
     pooler: Optional[nn.Module] = None,
+    ckpt_path: str = None,
 ) -> VisionTransformer:
     """
     Args:
@@ -198,6 +200,8 @@ def vision_transformer(
     vit = VisionTransformer(
         embeddings=image_embedding, encoder=transformer_encoder, pooler=pooler
     )
+    if ckpt_path:
+        load_module_from_url(vit, ckpt_path)
     return vit
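
Usage sketch (not part of the diff): the snippet below exercises the new encoder factories added above. The input shapes and the TransformerOutput.last_hidden_state field are assumptions based on the existing vit_b_16 defaults (224x224 RGB images; single-channel 1024x128 spectrograms for the audio variant); pretrained=True would download the checkpoints listed in MAE_MODEL_MAPPING via load_module_from_url.

import torch

from torchmultimodal.models.masked_auto_encoder.model import (
    vit_b_16_audio_mae_encoder,
    vit_b_16_image_mae_encoder,
)

# Image encoder: pretrained=True pulls the checkpoint listed in MAE_MODEL_MAPPING.
image_encoder = vit_b_16_image_mae_encoder(pretrained=True)
image_encoder.eval()
images = torch.randn(2, 3, 224, 224)  # assumed vit_b_16 default input size
with torch.no_grad():
    image_out = image_encoder(images)
print(image_out.last_hidden_state.shape)  # expected (2, 197, 768): 196 patches + CLS token

# Audio encoder: single-channel spectrograms matching image_size=(1024, 128) above.
audio_encoder = vit_b_16_audio_mae_encoder(pretrained=True)
audio_encoder.eval()
spectrograms = torch.randn(2, 1, 1024, 128)
with torch.no_grad():
    audio_out = audio_encoder(spectrograms)
print(audio_out.last_hidden_state.shape)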