Skip to content

Commit

Permalink
Switch hf hub entries for new aimv2 / dfn weights to point to timm locations. Undo forced device for SDR linspace, part of another change.
Browse files Browse the repository at this point in the history
  • Loading branch information
rwightman committed Dec 31, 2024
1 parent cc7fd34 commit b0068ba
Showing 1 changed file with 17 additions and 28 deletions.
45 changes: 17 additions & 28 deletions timm/models/vision_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -556,7 +556,7 @@ def __init__(
self.patch_drop = nn.Identity()
self.norm_pre = norm_layer(embed_dim) if pre_norm else nn.Identity()

dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth, device='cpu')] # stochastic depth decay rule
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
self.blocks = nn.Sequential(*[
block_fn(
dim=embed_dim,
Expand Down Expand Up @@ -1158,22 +1158,12 @@ def _convert_aimv2(
k = k.replace('preprocessor.pos_embed', 'pos_embed')
k = k.replace('trunk.', '')
k = k.replace('post_trunk_norm.', 'norm.')

# packed ver, FIXME to delete
# if 'mlp.fc1' in k:
# if k in out_dict:
# v = torch.cat([v, out_dict[k]], dim=0)
# elif 'mlp.fc3' in k:
# k = k.replace('mlp.fc3', 'mlp.fc1')
# if k in out_dict:
# v = torch.cat([out_dict[k], v], dim=0)
k = k.replace('mlp.fc1', 'mlp.fc1_g')
k = k.replace('mlp.fc3', 'mlp.fc1_x')

out_dict[k] = v

return out_dict


def checkpoint_filter_fn(
state_dict: Dict[str, torch.Tensor],
model: VisionTransformer,
Expand Down Expand Up @@ -1688,8 +1678,7 @@ def _cfg(url: str = '', **kwargs) -> Dict[str, Any]:
license='apple-ascl',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=512),
'vit_large_patch14_clip_224.dfn2b_s39b': _cfg(
#hf_hub_id='timm/',
hf_hub_id='apple/DFN2B-CLIP-ViT-L-14-39B', hf_hub_filename='open_clip_pytorch_model.bin',
hf_hub_id='timm/',
license='apple-ascl',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=768),
'vit_large_patch14_clip_224.dfn2b': _cfg(
Expand Down Expand Up @@ -2177,59 +2166,59 @@ def _cfg(url: str = '', **kwargs) -> Dict[str, Any]:
),

'aimv2_large_patch14_224.apple_pt': _cfg(
hf_hub_id='apple/aimv2-large-patch14-224',
hf_hub_id='timm/',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, license='apple-ascl',
crop_pct=1.0, num_classes=0),
'aimv2_large_patch14_224.apple_pt_dist': _cfg(
hf_hub_id='apple/aimv2-large-patch14-224-distilled',
hf_hub_id='timm/',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, license='apple-ascl',
crop_pct=1.0, num_classes=0),
'aimv2_huge_patch14_224.apple_pt': _cfg(
hf_hub_id='apple/aimv2-huge-patch14-224',
hf_hub_id='timm/',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, license='apple-ascl',
crop_pct=1.0, num_classes=0),
'aimv2_1b_patch14_224.apple_pt': _cfg(
hf_hub_id='apple/aimv2-1b-patch14-224',
hf_hub_id='timm/',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, license='apple-ascl',
crop_pct=1.0, num_classes=0),
'aimv2_3b_patch14_224.apple_pt': _cfg(
hf_hub_id='apple/aimv2-3b-patch14-224',
hf_hub_id='timm/',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, license='apple-ascl',
crop_pct=1.0, num_classes=0),
'aimv2_large_patch14_336.apple_pt': _cfg(
hf_hub_id='apple/aimv2-large-patch14-336',
hf_hub_id='timm/',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, license='apple-ascl',
input_size=(3, 336, 336), crop_pct=1.0, num_classes=0),
'aimv2_large_patch14_336.apple_pt_dist': _cfg(
hf_hub_id='apple/aimv2-large-patch14-336-distilled',
hf_hub_id='timm/',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, license='apple-ascl',
input_size=(3, 336, 336), crop_pct=1.0, num_classes=0),
'aimv2_huge_patch14_336.apple_pt': _cfg(
hf_hub_id='apple/aimv2-huge-patch14-336',
hf_hub_id='timm/',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, license='apple-ascl',
input_size=(3, 336, 336), crop_pct=1.0, num_classes=0),
'aimv2_1b_patch14_336.apple_pt': _cfg(
hf_hub_id='apple/aimv2-1b-patch14-336',
hf_hub_id='timm/',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, license='apple-ascl',
input_size=(3, 336, 336), crop_pct=1.0, num_classes=0),
'aimv2_3b_patch14_336.apple_pt': _cfg(
hf_hub_id='apple/aimv2-3b-patch14-336',
hf_hub_id='timm/',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, license='apple-ascl',
input_size=(3, 336, 336), crop_pct=1.0, num_classes=0),
'aimv2_large_patch14_448.apple_pt': _cfg(
hf_hub_id='apple/aimv2-large-patch14-448',
hf_hub_id='timm/',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, license='apple-ascl',
input_size=(3, 448, 448), crop_pct=1.0, num_classes=0),
'aimv2_huge_patch14_448.apple_pt': _cfg(
hf_hub_id='apple/aimv2-huge-patch14-448',
hf_hub_id='timm/',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, license='apple-ascl',
input_size=(3, 448, 448), crop_pct=1.0, num_classes=0),
'aimv2_1b_patch14_448.apple_pt': _cfg(
hf_hub_id='apple/aimv2-1b-patch14-448',
hf_hub_id='timm/',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, license='apple-ascl',
input_size=(3, 448, 448), crop_pct=1.0, num_classes=0),
'aimv2_3b_patch14_448.apple_pt': _cfg(
hf_hub_id='apple/aimv2-3b-patch14-448',
hf_hub_id='timm/',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, license='apple-ascl',
input_size=(3, 448, 448), crop_pct=1.0, num_classes=0),

Expand Down

0 comments on commit b0068ba

Please sign in to comment.