From 6f32ca1f14ef48d81cc46f8d0f660b00f16debd9 Mon Sep 17 00:00:00 2001
From: Evan Smothers
Date: Fri, 6 Oct 2023 11:57:02 -0700
Subject: [PATCH] Add pretrained MAE weights, option to load checkpoints in ViT builder (#479)

Summary:
For MAE, fine-tuning is performed on the encoder (ViT) only. This change allows MAE pretrained weights to be loaded directly into our ViT class.

Pull Request resolved: https://github.com/facebookresearch/multimodal/pull/479

Test Plan:
```
python -m pytest -v tests/models/*
...
========== 207 passed, 25 warnings in 424.67s (0:07:04) ===========================
python -m pytest -v tests/modules/*
...
======================== 192 passed, 2 skipped, 22 warnings in 10.75s ==========================
```

Test instantiating ViT using MAE pretrained weights for each of the 3 checkpoints:
Screenshot 2023-10-05 at 6 39 02 PM

Reviewed By: kartikayk

Differential Revision: D50015711

Pulled By: ebsmothers

fbshipit-source-id: e09fd02560b31574427b9f66373f12e7fd663f06
---
 .../models/masked_auto_encoder/model.py        | 32 +++++++++++++++++++
 .../modules/encoders/vision_transformer.py     |  4 +++
 2 files changed, 36 insertions(+)

diff --git a/torchmultimodal/models/masked_auto_encoder/model.py b/torchmultimodal/models/masked_auto_encoder/model.py
index e0d77394..56383b2f 100644
--- a/torchmultimodal/models/masked_auto_encoder/model.py
+++ b/torchmultimodal/models/masked_auto_encoder/model.py
@@ -13,6 +13,11 @@
     get_2d_sin_cos_embeddings,
 )
 from torchmultimodal.models.masked_auto_encoder.swin_decoder import SwinTransformer
+from torchmultimodal.modules.encoders.vision_transformer import (
+    VisionTransformer,
+    vit_b_16,
+    vit_l_16,
+)
 from torchmultimodal.modules.layers.patch_embedding import PatchEmbeddings
 from torchmultimodal.modules.layers.transformer import (
     TransformerEncoder,
@@ -20,6 +25,13 @@
 )
 
 
+MAE_MODEL_MAPPING = {
+    "vit_b16_image": "https://download.pytorch.org/models/multimodal/mae/mae_pretrained_vit_base.pth",
+    "vit_l16_image": "https://download.pytorch.org/models/multimodal/mae/mae_pretrained_vit_large.pth",
+    "vit_b16_audio": "https://download.pytorch.org/models/multimodal/audio_mae/audio_mae_pretrained_vit_base.pth",
+}
+
+
 class MAEOutput(NamedTuple):
     encoder_output: Union[TransformerOutput, Tensor]
     decoder_pred: Optional[Tensor] = None
@@ -324,6 +336,16 @@ def vit_l_16_image_mae() -> MaskedAutoEncoder:
     )
 
 
+def vit_b_16_image_mae_encoder(pretrained: bool = False) -> VisionTransformer:
+    ckpt_path = MAE_MODEL_MAPPING["vit_b16_image"] if pretrained else None
+    return vit_b_16(final_layer_norm_eps=None, ckpt_path=ckpt_path)
+
+
+def vit_l_16_image_mae_encoder(pretrained: bool = False) -> VisionTransformer:
+    ckpt_path = MAE_MODEL_MAPPING["vit_l16_image"] if pretrained else None
+    return vit_l_16(final_layer_norm_eps=None, ckpt_path=ckpt_path)
+
+
 def audio_mae(
     *,
     # patch embedding
@@ -449,3 +471,13 @@ def vit_l_16_audio_mae() -> MaskedAutoEncoder:
         decoder_heads=16,
         decoder_dim_feedforward=2048,
     )
+
+
+def vit_b_16_audio_mae_encoder(pretrained: bool = False) -> VisionTransformer:
+    ckpt_path = MAE_MODEL_MAPPING["vit_b16_audio"] if pretrained else None
+    return vit_b_16(
+        final_layer_norm_eps=None,
+        num_channels=1,
+        image_size=(1024, 128),
+        ckpt_path=ckpt_path,
+    )
diff --git a/torchmultimodal/modules/encoders/vision_transformer.py b/torchmultimodal/modules/encoders/vision_transformer.py
index f051a490..4e170ee6 100644
--- a/torchmultimodal/modules/encoders/vision_transformer.py
+++ b/torchmultimodal/modules/encoders/vision_transformer.py
@@ -14,6 +14,7 @@
     TransformerEncoder,
     TransformerOutput,
 )
+from torchmultimodal.utils.common import load_module_from_url
 
 
 class VisionTransformer(nn.Module):
@@ -148,6 +149,7 @@ def vision_transformer(
     drop_path_rate: Optional[float] = None,
     patch_drop_rate: Optional[Union[float, Tuple[float, float]]] = None,
     pooler: Optional[nn.Module] = None,
+    ckpt_path: Optional[str] = None,
 ) -> VisionTransformer:
     """
     Args:
@@ -198,6 +200,8 @@ def vision_transformer(
     vit = VisionTransformer(
         embeddings=image_embedding, encoder=transformer_encoder, pooler=pooler
     )
+    if ckpt_path:
+        load_module_from_url(vit, ckpt_path)
     return vit
 
 
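
For reference, a minimal usage sketch of the encoder builders added by this patch (not part of the commit itself). It assumes the default 3-channel, 224x224 input size for the ViT-B/16 image builder, that the encoder module can be called directly on an image batch, and network access to download the checkpoints listed in MAE_MODEL_MAPPING:

```python
import torch

from torchmultimodal.models.masked_auto_encoder.model import (
    vit_b_16_audio_mae_encoder,
    vit_b_16_image_mae_encoder,
)

# ViT-B/16 image encoder with MAE pretrained weights; pretrained=True passes the
# checkpoint URL from MAE_MODEL_MAPPING as ckpt_path to the ViT builder.
image_encoder = vit_b_16_image_mae_encoder(pretrained=True)
image_encoder.eval()

# Assumed default input size for vit_b_16: a batch of 3-channel 224x224 images.
with torch.no_grad():
    out = image_encoder(torch.randn(1, 3, 224, 224))

# ViT-B/16 audio encoder: the builder in this patch configures single-channel
# inputs of size 1024x128 (num_channels=1, image_size=(1024, 128)).
audio_encoder = vit_b_16_audio_mae_encoder(pretrained=True)
```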