
[BUG] Loading state dict in a feature extraction network #2215

Closed
ioangatop opened this issue Jun 26, 2024 · 1 comment

ioangatop commented Jun 26, 2024

Describe the bug

Hi Ross! I'm facing a small issue with the feature extractor; here are some details:

The function create_model supports the checkpoint_path argument, which allows loading custom model weights. However, when we want to load a model as a feature extractor, the model is wrapped in the FeatureGetterNet class and the loading fails because the keys no longer match; FeatureGetterNet stores the model under self.model, so for loading to work the state dict keys would need a model. prefix, for example class_token -> model.class_token.

Additionally, one workaround is to load the weights after initialisation, but this also fails because some networks, like the vision transformer, prune some layers in the feature-extraction variant, so the checkpoint state_dict ends up with extra keys.

To Reproduce

from urllib import request

from timm.models import _helpers
import timm


# download weights
request.urlretrieve("https://dl.fbaipublicfiles.com/dino/dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth", "dino_deitsmall16_pretrain.pth")

# build and load model -- works as expected
model = timm.create_model(
    model_name="vit_small_patch16_224",
    num_classes=0,
    checkpoint_path="dino_deitsmall16_pretrain.pth",
)

# RuntimeError: Error(s) in loading state_dict for FeatureGetterNet:
#   Missing key(s) in state_dict: "model.cls_token", "model.pos_embed", ...
#  Unexpected key(s) in state_dict: "cls_token", "pos_embed", ...
backbone = timm.create_model(
    model_name="vit_small_patch16_224",
    num_classes=0,
    features_only=True,
    checkpoint_path="dino_deitsmall16_pretrain.pth",
)

# RuntimeError: Error(s) in loading state_dict for VisionTransformer:
#   Unexpected key(s) in state_dict: "norm.weight", "norm.bias". 
backbone = timm.create_model(
    model_name="vit_small_patch16_224",
    num_classes=0,
    features_only=True,
)
_helpers.load_checkpoint(backbone.model, "dino_deitsmall16_pretrain.pth")
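
For reference, a rough sketch of manually remapping the keys to get around both errors (untested, and it assumes the checkpoint file is a plain state dict):

import torch

# Untested sketch: add the "model." prefix expected by FeatureGetterNet and
# load non-strictly so keys pruned from the feature-extraction variant
# (e.g. "norm.weight", "norm.bias") are ignored rather than raising.
state_dict = torch.load("dino_deitsmall16_pretrain.pth", map_location="cpu")
state_dict = {f"model.{k}": v for k, v in state_dict.items()}
missing, unexpected = backbone.load_state_dict(state_dict, strict=False)
print("missing:", missing, "unexpected:", unexpected)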

As always, thanks a lot 🙏

ioangatop added the bug label Jun 26, 2024
rwightman (Collaborator) commented Jun 27, 2024

@ioangatop if you want classifier weights loaded into feature-extraction-wrapped models, you need to load the weights as 'pretrained' so that they are loaded before the model is mutated.

See the related discussion; this should work with timm >= 0.9: https://github.com/huggingface/pytorch-image-models/discussions/1941

Although the example in that discussion should be a bit different; use the 'overlay' arg as in the train script:

if args.pretrained_path:
    # merge with pretrained_cfg of model, 'file' has priority over 'url' and 'hf_hub'.
    factory_kwargs['pretrained_cfg_overlay'] = dict(
        file=args.pretrained_path,
        num_classes=-1,  # force head adaptation
    )

The overlay dict is merged with the model's normal pretrained_cfg, whereas the pretrained_cfg arg fully overrides it.

As an alternative to the file key in the pretrained_cfg overlay dict, you can also use url to download from somewhere else, or hf_hub_id for a HF Hub location.
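
E.g., applied to the DINO checkpoint from the reproduction above, something along these lines should work (a sketch, assuming timm >= 0.9 and the previously downloaded file):

import timm

# Sketch: pass the local checkpoint as 'pretrained' weights via the overlay,
# so they are loaded before the model is wrapped for feature extraction.
backbone = timm.create_model(
    "vit_small_patch16_224",
    pretrained=True,
    num_classes=0,
    features_only=True,
    pretrained_cfg_overlay=dict(file="dino_deitsmall16_pretrain.pth"),
)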

rwightman removed the bug label Jul 10, 2024
huggingface locked and limited conversation to collaborators Jul 10, 2024
rwightman converted this issue into discussion #2227 Jul 10, 2024

