Skip to content

PyTorch implementation tools for using vision transformer with Google pretrained model. (npz2pth)

License

Notifications You must be signed in to change notification settings

rookiie/Vit_Pretrain_Tools

Folders and files

NameName
Last commit message
Last commit date

Latest commit

 

History

4 Commits
 
 
 
 
 
 
 
 

Repository files navigation

Vit_Pretrain_Tools

PyTorch implementation tools for using vision transformer with Google pretrained model. (npz2pth)

Details

  • Transformer.py provides a generic pytorch implementation of Vision Transformer. And the basic module draws on SETR.

  • weights/npz2pth.py help us to convert the pre-trained model provided by Google in NPZ format into the applicable pth format.

Usage

  • Edit the config in Transformer.py to build the vit type you want. More detail setting

  • Edit the config in npz2pth.py. Note that we do not check the output categories number of the classification head. Download the suitable npz file according to your config from Vision Transformer. And this config dict has slight differences compared to config mentioned above.

python npz2pth.py --config=[$CONFIG_DICT] --npz_path=[$NPZ_FILE] --save_path=[$SAVE]
  • Initialize the parameter with the pretrained model. When key 'pretrained' is not setting in the cfg dict, function init_weight() will initialize network randomly.
S_16_224_config = {
        'depth': 12,
        'model_name': 'vit_S_16_224_image21K', # pth name
        'image_size':(512,512), # size of the images fed into network
        'heads': 6,
        'embed_dim': 384,
        'mlp_ratio': 4,
        'patch_size':16,
        'drop_rate': 0.1,
    }
cfg = S_16_224_config
cfg['pretrained'] = './weight/' # root path of pth file

model = build_vit_from_cfg(cfg)
model.init_weight()

About

PyTorch implementation tools for using vision transformer with Google pretrained model. (npz2pth)

Resources

License

Stars

Watchers

Forks

Releases

No releases published

Packages

No packages published

Languages