Add explicit builders for pretrained encoders to MAE file
ebsmothers committed Oct 6, 2023
1 parent 06098cd commit f488945
Showing 1 changed file with 25 additions and 0 deletions.
25 changes: 25 additions & 0 deletions torchmultimodal/models/masked_auto_encoder/model.py
@@ -13,6 +13,11 @@
get_2d_sin_cos_embeddings,
)
from torchmultimodal.models.masked_auto_encoder.swin_decoder import SwinTransformer
from torchmultimodal.modules.encoders.vision_transformer import (
VisionTransformer,
vit_b_16,
vit_l_16,
)
from torchmultimodal.modules.layers.patch_embedding import PatchEmbeddings
from torchmultimodal.modules.layers.transformer import (
TransformerEncoder,
@@ -331,6 +336,16 @@ def vit_l_16_image_mae() -> MaskedAutoEncoder:
)


def vit_b_16_image_mae_encoder(pretrained: bool = False) -> VisionTransformer:
ckpt_path = MAE_MODEL_MAPPING["vit_b16_image"] if pretrained else None
return vit_b_16(final_layer_norm_eps=None, ckpt_path=ckpt_path)


def vit_l_16_image_mae_encoder(pretrained: bool = False) -> VisionTransformer:
ckpt_path = MAE_MODEL_MAPPING["vit_l16_image"] if pretrained else None
return vit_l_16(final_layer_norm_eps=None, ckpt_path=ckpt_path)


def audio_mae(
*,
# patch embedding
@@ -456,3 +471,13 @@ def vit_l_16_audio_mae() -> MaskedAutoEncoder:
decoder_heads=16,
decoder_dim_feedforward=2048,
)


def vit_b_16_audio_mae_encoder(pretrained: bool = False) -> VisionTransformer:
ckpt_path = MAE_MODEL_MAPPING["vit_b16_audio"] if pretrained else None
return vit_b_16(
final_layer_norm_eps=None,
num_channels=1,
image_size=(1024, 128),
ckpt_path=ckpt_path,
)
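
Usage sketch for the new builders. This is a minimal example, not part of the commit: the 3-channel 224x224 image input is the assumed standard ViT-B/16 default, while the audio encoder's single-channel 1024x128 input shape comes directly from the builder above. Passing pretrained=True instead loads weights from the checkpoint paths in MAE_MODEL_MAPPING (defined elsewhere in this file), which requires downloading the checkpoint.

import torch
from torchmultimodal.models.masked_auto_encoder.model import (
    vit_b_16_audio_mae_encoder,
    vit_b_16_image_mae_encoder,
)

# Randomly initialized ViT-B/16 image encoder; pretrained=True would load
# the MAE checkpoint referenced by MAE_MODEL_MAPPING["vit_b16_image"].
image_encoder = vit_b_16_image_mae_encoder(pretrained=False)
images = torch.randn(2, 3, 224, 224)  # assumed default ViT-B/16 input size
image_features = image_encoder(images)

# Audio variant: single-channel 1024x128 spectrograms, per the builder above.
audio_encoder = vit_b_16_audio_mae_encoder(pretrained=False)
spectrograms = torch.randn(2, 1, 1024, 128)
audio_features = audio_encoder(spectrograms)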
