A Contrastive Learning Boost from Intermediate Pre-Trained Representations

ml-jku/MIM-Refiner

MIM-Refiner

[#1 ImageNet-1K SSL (without extra data)] [#1 ImageNet-1K Clustering (without extra data)]

[Project Page] [Paper] [Models] [Codebase Demo Video] [Model Training Demo Video] [BibTeX]

PyTorch implementation and pre-trained models of MIM-Refiner.

[Figure: mimrefiner_schematic]

MIM-Refiner efficiently combines the advantages of MIM and ID models and surpasses previous state-of-the-art methods while being easy to scale up to extremely large models.

[Figure: mimrefiner_spider]

Pre-trained Models

Pre-trained models can be found here

They can also be loaded via torchhub:

import torch

# MAE
model = torch.hub.load("ml-jku/MIM-Refiner", "mae_refined_l16")
model = torch.hub.load("ml-jku/MIM-Refiner", "mae_refined_h14")
model = torch.hub.load("ml-jku/MIM-Refiner", "mae_refined_twob14")
# D2V2
model = torch.hub.load("ml-jku/MIM-Refiner", "d2v2_refined_l16")
model = torch.hub.load("ml-jku/MIM-Refiner", "d2v2_refined_h14")
# dBOT
model = torch.hub.load("ml-jku/MIM-Refiner", "dbot_refined_l16")
model = torch.hub.load("ml-jku/MIM-Refiner", "dbot_refined_h14")
# CrossMAE
model = torch.hub.load("ml-jku/MIM-Refiner", "crossmae_refined_l16")

An example of how to use the torchhub models for a k-NN classifier can be found here.

python eval_knn_torchhub.py --model mae_refined_l16 --data_train /imagenet/train/ --data_test /imagenet/val
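
To make the idea of a k-NN classifier on frozen features concrete, here is a minimal sketch of a similarity-weighted k-NN vote. It is not the code of eval_knn_torchhub.py: the function name `knn_predict`, the defaults `k` and `tau`, and the synthetic 2-D "features" are all assumptions made for illustration. In practice the features would come from one of the torchhub encoders above.

```python
import torch
import torch.nn.functional as F


def knn_predict(train_x, train_y, test_x, num_classes, k=10, tau=0.07):
    """Similarity-weighted k-NN vote on L2-normalized features (hypothetical helper)."""
    train_x = F.normalize(train_x, dim=1)
    test_x = F.normalize(test_x, dim=1)
    sim = test_x @ train_x.T                    # cosine similarity, (n_test, n_train)
    topk_sim, topk_idx = sim.topk(k, dim=1)     # k most similar training samples
    topk_labels = train_y[topk_idx]             # labels of those neighbors, (n_test, k)
    weights = (topk_sim / tau).exp()            # closer neighbors get larger votes
    votes = torch.zeros(
        test_x.size(0), num_classes, dtype=weights.dtype, device=weights.device
    )
    votes.scatter_add_(1, topk_labels, weights)  # sum the vote weights per class
    return votes.argmax(dim=1)


# Synthetic stand-in for encoder features: two well-separated clusters.
torch.manual_seed(0)
centers = torch.tensor([[5.0, 0.0], [0.0, 5.0]])
train_y = torch.arange(2).repeat_interleave(20)
train_x = centers[train_y] + 0.1 * torch.randn(40, 2)
test_y = torch.tensor([0, 1, 1, 0])
test_x = centers[test_y] + 0.1 * torch.randn(4, 2)

pred = knn_predict(train_x, train_y, test_x, num_classes=2, k=5)
accuracy = (pred == test_y).float().mean().item()
```

The temperature `tau` only controls how sharply the vote favors the closest neighbors; the value used here is an arbitrary choice for this sketch.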

Note that the results of eval_knn_torchhub.py can differ slightly from the paper results, because the paper results remove the last LayerNorm of pre-norm ViTs and use bfloat16 precision.
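
The two differences mentioned in the note can be sketched on a toy module. Everything here is an assumption for illustration: `ToyPreNormEncoder` is a hypothetical stand-in for a pre-norm ViT, and the attribute name `norm` for the final LayerNorm may not match the attribute name used by the torchhub models, so inspect the loaded model before adapting this.

```python
import torch
from torch import nn


class ToyPreNormEncoder(nn.Module):
    """Hypothetical stand-in for a pre-norm ViT encoder with a final LayerNorm."""

    def __init__(self, dim=16):
        super().__init__()
        self.body = nn.Linear(dim, dim)
        self.norm = nn.LayerNorm(dim)  # assumed name of the last LayerNorm

    def forward(self, x):
        return self.norm(self.body(x))


model = ToyPreNormEncoder().eval()

# Remove the last LayerNorm by replacing it with a no-op.
model.norm = nn.Identity()

# Extract features under bfloat16 autocast (CPU here; use device_type="cuda" on GPU).
x = torch.randn(4, 16)
with torch.no_grad(), torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    feats = model(x)
```

Replacing the module with `nn.Identity()` keeps the forward pass unchanged otherwise, so the same evaluation code can be run with and without the final normalization.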

Train your own models

Instructions to set up the codebase in your own environment are provided in SETUP_CODE, SETUP_DATA and SETUP_MODELS.

A video that motivates the design choices of the codebase and gives an overview of it can be found here.

Configurations to train, evaluate or analyze models can be found here. Note that MIM-Refiner is trained in 2 stages. "Stage 2" trains only the ID heads with a frozen encoder, to ensure a good and stable learning signal for "stage 3", where the encoder is then trained. "Stage 2" needs significantly fewer compute resources and can also be used to quickly check whether hyperparameters (temperature, head learning rate, ...) are suitable.
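
The freeze-then-unfreeze pattern behind the two stages can be sketched with toy modules. This is not the repository's training code: the linear `encoder` and `head`, the placeholder loss, and the SGD settings are assumptions for illustration, and the sketch assumes that the head keeps training in "stage 3".

```python
import torch
from torch import nn

torch.manual_seed(0)
encoder = nn.Linear(8, 8)  # toy stand-in for the pre-trained encoder
head = nn.Linear(8, 4)     # toy stand-in for an ID head

# "stage 2": freeze the encoder and optimize only the head.
for p in encoder.parameters():
    p.requires_grad_(False)
encoder.eval()
optimizer = torch.optim.SGD(head.parameters(), lr=0.1)

encoder_before = encoder.weight.clone()
head_before = head.weight.clone()

x = torch.randn(16, 8)
with torch.no_grad():      # no graph is needed through the frozen encoder
    feats = encoder(x)
loss = head(feats).pow(2).mean()  # placeholder loss, NOT the actual ID objective
optimizer.zero_grad()
loss.backward()
optimizer.step()

# "stage 3": unfreeze the encoder and rebuild the optimizer with its parameters.
for p in encoder.parameters():
    p.requires_grad_(True)
encoder.train()
optimizer = torch.optim.SGD(
    list(encoder.parameters()) + list(head.parameters()), lr=0.1
)
```

Because only the small head receives gradients in the first phase, no backward pass through the encoder is needed there, which is consistent with that phase being much cheaper.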

A demo that showcases how to train models can be found here.

Because the trained models are quite large, training requires a lot of memory and therefore multi-GPU/multi-node setups. If memory is a bottleneck, several tradeoffs are available (see this issue).

External Evaluation Frameworks

The evaluations of VTAB-1K were done with this codebase by loading the pre-trained models from torchhub.

The evaluations for ADE20K semantic segmentation were done with this codebase by loading the pre-trained models from torchhub.

Citation

If you like our work, please consider giving it a star ⭐ and citing us:

@article{alkin2024mimrefiner,
      title={{MIM-Refiner}: A Contrastive Learning Boost from Intermediate Pre-Trained Representations}, 
      author={Benedikt Alkin and Lukas Miklautz and Sepp Hochreiter and Johannes Brandstetter},
      journal={arXiv preprint arXiv:2402.10093},
      year={2024}
}
