Skip to content

arshadshk/Position-Prediction-Pretraining

Folders and files

NameName
Last commit message
Last commit date

Latest commit

 

History

16 Commits
 
 
 
 
 
 
 
 
 
 
 
 

Repository files navigation

Position-Prediction-Pretraining-Pytorch

Pytorch Implementation of : Position Prediction as an Effective Pretraining Strategy
https://arxiv.org/abs/2207.07611
https://proceedings.mlr.press/v162/zhai22a.html

About

This is an unofficial implementation/replication built on top of DeiT and timm (pytorch-image-transformers) and is limited to vision models. Vision model(ViT) in timm is extended to support Attention Masking, Position Prediction and Position loss, whereas DieT is used for required pretraining configurations that the authors have used( page4 section4.2 ). In this paper, the authors propose simple alternative to content reconstruction – that of predicting locations from content, without providing positional information for it. Doing so requires the Transformer to understand the positional relationships between different parts of the input, from their content alone. This amounts to an efficient implementation where the pretext task is a classification problem among all possible positions for each input token. The paper contains both Vision and Speech benchmarks, but this repo only supports vision models.

Usage

Clone this repo, and then update the submoules/sub-repos. Make sure to use DieT and timm from the submodules provided here, as the original versions dont support attention masking, position prediction and losses that were present in the paper.

    git submodule update --init --recursive     

Model

import torch
import sys
sys.path.append("./pytorch-image-models/")

from timm.models import create_model 

model = create_model(   
                        model_name= "deit_small_patch4_32",
                        pretrained=False,
                        num_classes=100,
                        drop_rate=0.3,
                        drop_path_rate=0.1,
                        drop_block_rate=None,
                        img_size=32,
                        attn_mask=True,
                        eta=0.75,       
                        patch_stride=4
                    )
B,C,H,W = 64,3,32,32
out = model( torch.randn(B,C,H,W) )
print( out.shape )
# torch.Size([64, 65, 64]) 
# (batch, cls + patches, position_logits) 
# N:patches and +1 for cls

Pretraining

Pretrainig on (CIFAR-100) dataset without distillation:

    python3 deit/main.py    --attn_mask true  \
                            --eta 0.5  \
                            --patch_stride 4  \
                            --data-set CIFAR  \
                            --data-path ./data/cifar-100-python \
                            --datadownload true  \
                            --device cpu  \
                            --num_workers 0  \
                            --model deit_small_patch4_32  \
                            --smoothing 0  \
                            --mixup 0  \
                            --cutmix 0  \
                            --distillation-alpha 0  

Pretrainig using distillation: (WIP)

Finetunning (WIP)


Results

Results from paper :

Results from running this repo:
Dataset : CIFAR 100
epochs : 362
eta : 0.75
( $\eta = 0.75$ )
checkpoint link : https://www.kaggle.com/datasets/arshadshaikh7/pretrainedvit
script to plot results :


Citations

@InProceedings{pmlr-v162-zhai22a,
  title = 	 {Position Prediction as an Effective Pretraining Strategy},
  author =       {Zhai, Shuangfei and Jaitly, Navdeep and Ramapuram, Jason and Busbridge, Dan and Likhomanenko, Tatiana and Cheng, Joseph Y and Talbott, Walter and Huang, Chen and Goh, Hanlin and Susskind, Joshua M},
  booktitle = 	 {Proceedings of the 39th International Conference on Machine Learning},
  pages = 	 {26010--26027},
  year = 	 {2022},
  pdf = 	 {https://proceedings.mlr.press/v162/zhai22a/zhai22a.pdf},
  url = 	 {https://proceedings.mlr.press/v162/zhai22a.html},

About

Position Prediction as an Effective Pretraining Strategy

Topics

Resources

Stars

Watchers

Forks

Releases

No releases published

Packages

No packages published

Languages