From f4889457ff6860a27d6f1e2766b74ff488f6eb8d Mon Sep 17 00:00:00 2001
From: Evan Smothers
Date: Thu, 5 Oct 2023 18:34:30 -0700
Subject: [PATCH] Add explicit builders for pretrained encoders to MAE file

---
 .../models/masked_auto_encoder/model.py | 25 +++++++++++++++++++
 1 file changed, 25 insertions(+)

diff --git a/torchmultimodal/models/masked_auto_encoder/model.py b/torchmultimodal/models/masked_auto_encoder/model.py
index 431ee383..56383b2f 100644
--- a/torchmultimodal/models/masked_auto_encoder/model.py
+++ b/torchmultimodal/models/masked_auto_encoder/model.py
@@ -13,6 +13,11 @@
     get_2d_sin_cos_embeddings,
 )
 from torchmultimodal.models.masked_auto_encoder.swin_decoder import SwinTransformer
+from torchmultimodal.modules.encoders.vision_transformer import (
+    VisionTransformer,
+    vit_b_16,
+    vit_l_16,
+)
 from torchmultimodal.modules.layers.patch_embedding import PatchEmbeddings
 from torchmultimodal.modules.layers.transformer import (
     TransformerEncoder,
@@ -331,6 +336,16 @@ def vit_l_16_image_mae() -> MaskedAutoEncoder:
     )
 
 
+def vit_b_16_image_mae_encoder(pretrained: bool = False) -> VisionTransformer:
+    ckpt_path = MAE_MODEL_MAPPING["vit_b16_image"] if pretrained else None
+    return vit_b_16(final_layer_norm_eps=None, ckpt_path=ckpt_path)
+
+
+def vit_l_16_image_mae_encoder(pretrained: bool = False) -> VisionTransformer:
+    ckpt_path = MAE_MODEL_MAPPING["vit_l16_image"] if pretrained else None
+    return vit_l_16(final_layer_norm_eps=None, ckpt_path=ckpt_path)
+
+
 def audio_mae(
     *,
     # patch embedding
@@ -456,3 +471,13 @@ def vit_l_16_audio_mae() -> MaskedAutoEncoder:
         decoder_heads=16,
         decoder_dim_feedforward=2048,
     )
+
+
+def vit_b_16_audio_mae_encoder(pretrained: bool = False) -> VisionTransformer:
+    ckpt_path = MAE_MODEL_MAPPING["vit_b16_audio"] if pretrained else None
+    return vit_b_16(
+        final_layer_norm_eps=None,
+        num_channels=1,
+        image_size=(1024, 128),
+        ckpt_path=ckpt_path,
+    )