
Commit 1d6ebeb

Add (almost) full set of aimv2 model instances. Switch back to unpacked SwiGLU. Verify correctness. Add DFN L/14 39B weight.

rwightman committed Dec 30, 2024
1 parent a4146b7, commit 1d6ebeb
Showing 1 changed file with 250 additions and 20 deletions.

timm/models/vision_transformer.py (270 changes: 250 additions & 20 deletions)
@@ -42,9 +42,9 @@

 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD, \
     OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
-from timm.layers import PatchEmbed, Mlp, DropPath, AttentionPoolLatent, RmsNorm, PatchDropout, SwiGLUPacked, \
+from timm.layers import PatchEmbed, Mlp, DropPath, AttentionPoolLatent, RmsNorm, PatchDropout, SwiGLUPacked, SwiGLU, \
     trunc_normal_, lecun_normal_, resample_patch_embed, resample_abs_pos_embed, use_fused_attn, \
-    SwiGLU, get_act_layer, get_norm_layer, LayerType
+    get_act_layer, get_norm_layer, LayerType
 from ._builder import build_model_with_cfg
 from ._features import feature_take_indices
 from ._manipulate import named_apply, checkpoint_seq, adapt_input_conv
@@ -1159,13 +1159,16 @@ def _convert_aimv2(
         k = k.replace('trunk.', '')
         k = k.replace('post_trunk_norm.', 'norm.')

-        if 'mlp.fc1' in k:
-            if k in out_dict:
-                v = torch.cat([v, out_dict[k]], dim=0)
-        elif 'mlp.fc3' in k:
-            k = k.replace('mlp.fc3', 'mlp.fc1')
-            if k in out_dict:
-                v = torch.cat([out_dict[k], v], dim=0)
+        # packed ver, FIXME to delete
+        # if 'mlp.fc1' in k:
+        #     if k in out_dict:
+        #         v = torch.cat([v, out_dict[k]], dim=0)
+        # elif 'mlp.fc3' in k:
+        #     k = k.replace('mlp.fc3', 'mlp.fc1')
+        #     if k in out_dict:
+        #         v = torch.cat([out_dict[k], v], dim=0)
+        k = k.replace('mlp.fc1', 'mlp.fc1_g')
+        k = k.replace('mlp.fc3', 'mlp.fc1_x')

         out_dict[k] = v

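The new mapping targets timm's unpacked SwiGLU, which keeps the gate and value projections as separate fc1_g / fc1_x linear layers rather than one fused fc1. A minimal sketch of that layout (illustrative module, not timm's exact implementation):

import torch
import torch.nn as nn
import torch.nn.functional as F

class SwiGLUSketch(nn.Module):
    # Unpacked layout: separate gate/value projections, matching the
    # 'mlp.fc1_g' / 'mlp.fc1_x' keys the converter now produces.
    def __init__(self, dim: int, hidden_dim: int):
        super().__init__()
        self.fc1_g = nn.Linear(dim, hidden_dim)  # gate branch, SiLU-activated
        self.fc1_x = nn.Linear(dim, hidden_dim)  # value branch
        self.fc2 = nn.Linear(hidden_dim, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.fc2(F.silu(self.fc1_g(x)) * self.fc1_x(x))

The commented-out branch built the packed layout instead: it torch.cat-ed the two weight matrices into a single fc1, which a packed SwiGLU would chunk back apart at forward time.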
@@ -1682,18 +1685,27 @@ def _cfg(url: str = '', **kwargs) -> Dict[str, Any]:

     'vit_base_patch16_clip_224.dfn2b': _cfg(
         hf_hub_id='timm/',
+        license='apple-ascl',
         mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=512),
+    'vit_large_patch14_clip_224.dfn2b_s39b': _cfg(
+        #hf_hub_id='timm/',
+        hf_hub_id='apple/DFN2B-CLIP-ViT-L-14-39B', hf_hub_filename='open_clip_pytorch_model.bin',
+        license='apple-ascl',
+        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=768),
     'vit_large_patch14_clip_224.dfn2b': _cfg(
         hf_hub_id='timm/',
+        license='apple-ascl',
         notes=('natively QuickGELU, use quickgelu model variant for original results',),
         mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=768),
     'vit_huge_patch14_clip_224.dfn5b': _cfg(
         hf_hub_id='timm/',
+        license='apple-ascl',
         notes=('natively QuickGELU, use quickgelu model variant for original results',),
         mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=1024),
     'vit_huge_patch14_clip_378.dfn5b': _cfg(
         hf_hub_id='timm/',
         mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
+        license='apple-ascl',
         notes=('natively QuickGELU, use quickgelu model variant for original results',),
         crop_pct=1.0, input_size=(3, 378, 378), num_classes=1024),

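Per the notes fields above, the DFN weights are natively QuickGELU. A rough usage sketch, assuming the quickgelu variants follow timm's usual *_clip_quickgelu_* naming (verify against timm.list_models):

import timm

# Default variant: approximates the original activation with standard GELU.
model = timm.create_model('vit_huge_patch14_clip_224.dfn5b', pretrained=True)

# For results matching the original release, the notes point at the quickgelu
# model variant (name assumed from timm's naming convention):
model_qg = timm.create_model('vit_huge_patch14_clip_quickgelu_224.dfn5b', pretrained=True)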
@@ -2164,11 +2176,62 @@ def _cfg(url: str = '', **kwargs) -> Dict[str, Any]:
         input_size=(3, 448, 448), crop_pct=1.0, num_classes=0,
     ),

-    'vit_large_patch14_aimv2_224': _cfg(
+    'aimv2_large_patch14_224.apple_pt': _cfg(
         hf_hub_id='apple/aimv2-large-patch14-224',
-        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
-        input_size=(3, 224, 224), crop_pct=1.0,
-        num_classes=0),
+        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, license='apple-ascl',
+        crop_pct=1.0, num_classes=0),
+    'aimv2_large_patch14_224.apple_pt_dist': _cfg(
+        hf_hub_id='apple/aimv2-large-patch14-224-distilled',
+        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, license='apple-ascl',
+        crop_pct=1.0, num_classes=0),
+    'aimv2_huge_patch14_224.apple_pt': _cfg(
+        hf_hub_id='apple/aimv2-huge-patch14-224',
+        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, license='apple-ascl',
+        crop_pct=1.0, num_classes=0),
+    'aimv2_1b_patch14_224.apple_pt': _cfg(
+        hf_hub_id='apple/aimv2-1b-patch14-224',
+        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, license='apple-ascl',
+        crop_pct=1.0, num_classes=0),
+    'aimv2_3b_patch14_224.apple_pt': _cfg(
+        hf_hub_id='apple/aimv2-3b-patch14-224',
+        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, license='apple-ascl',
+        crop_pct=1.0, num_classes=0),
+    'aimv2_large_patch14_336.apple_pt': _cfg(
+        hf_hub_id='apple/aimv2-large-patch14-336',
+        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, license='apple-ascl',
+        input_size=(3, 336, 336), crop_pct=1.0, num_classes=0),
+    'aimv2_large_patch14_336.apple_pt_dist': _cfg(
+        hf_hub_id='apple/aimv2-large-patch14-336-distilled',
+        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, license='apple-ascl',
+        input_size=(3, 336, 336), crop_pct=1.0, num_classes=0),
+    'aimv2_huge_patch14_336.apple_pt': _cfg(
+        hf_hub_id='apple/aimv2-huge-patch14-336',
+        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, license='apple-ascl',
+        input_size=(3, 336, 336), crop_pct=1.0, num_classes=0),
+    'aimv2_1b_patch14_336.apple_pt': _cfg(
+        hf_hub_id='apple/aimv2-1b-patch14-336',
+        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, license='apple-ascl',
+        input_size=(3, 336, 336), crop_pct=1.0, num_classes=0),
+    'aimv2_3b_patch14_336.apple_pt': _cfg(
+        hf_hub_id='apple/aimv2-3b-patch14-336',
+        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, license='apple-ascl',
+        input_size=(3, 336, 336), crop_pct=1.0, num_classes=0),
+    'aimv2_large_patch14_448.apple_pt': _cfg(
+        hf_hub_id='apple/aimv2-large-patch14-448',
+        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, license='apple-ascl',
+        input_size=(3, 448, 448), crop_pct=1.0, num_classes=0),
+    'aimv2_huge_patch14_448.apple_pt': _cfg(
+        hf_hub_id='apple/aimv2-huge-patch14-448',
+        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, license='apple-ascl',
+        input_size=(3, 448, 448), crop_pct=1.0, num_classes=0),
+    'aimv2_1b_patch14_448.apple_pt': _cfg(
+        hf_hub_id='apple/aimv2-1b-patch14-448',
+        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, license='apple-ascl',
+        input_size=(3, 448, 448), crop_pct=1.0, num_classes=0),
+    'aimv2_3b_patch14_448.apple_pt': _cfg(
+        hf_hub_id='apple/aimv2-3b-patch14-448',
+        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, license='apple-ascl',
+        input_size=(3, 448, 448), crop_pct=1.0, num_classes=0),

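All of the new cfgs set num_classes=0: these are encoder-only weights pulled straight from the apple/ hub repos. A usage sketch (the expected output shape follows from embed_dim=1024 and avg pooling in the large model definition further down):

import timm
import torch

# num_classes=0 leaves the classifier head as identity, so the forward
# pass returns pooled features rather than logits.
model = timm.create_model('aimv2_large_patch14_224.apple_pt', pretrained=True).eval()
with torch.no_grad():
    feats = model(torch.randn(1, 3, 224, 224))
print(feats.shape)  # expected: torch.Size([1, 1024])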
     'test_vit.r160_in1k': _cfg(
         hf_hub_id='timm/',
@@ -3442,17 +3505,171 @@ def vit_intern300m_patch14_448(pretrained: bool = False, **kwargs) -> VisionTransformer:


 @register_model
-def vit_large_patch14_aimv2_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
-    """ ViT-Large model (ViT-L/14) w/ parallel blocks and qk norm enabled.
+def aimv2_large_patch14_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
+    """ ViT Large AIM-v2 model
     """
+    model_args = dict(
+        patch_size=14, embed_dim=1024, depth=24, num_heads=8, class_token=False, fc_norm=False,
+        mlp_ratio=2.75, global_pool='avg', qkv_bias=False, proj_bias=False, act_layer='silu',
+        norm_layer=partial(RmsNorm, eps=1e-5), embed_norm_layer=partial(RmsNorm, eps=1e-5), mlp_layer=SwiGLU,
+    )
+    model = _create_vision_transformer(
+        'aimv2_large_patch14_224', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model

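The fractional mlp_ratio values in these definitions are not arbitrary: assuming timm derives the SwiGLU hidden width as int(embed_dim * mlp_ratio), they land exactly on the round hidden dims of the original AIMv2 configs:

# Assumed derivation: hidden = int(embed_dim * mlp_ratio)
for name, dim, ratio in [
    ('large', 1024, 2.75),    # -> 2816
    ('huge',  1536, 2.6667),  # -> 4096
    ('1b',    2048, 2.75),    # -> 5632
    ('3b',    3072, 2.6667),  # -> 8192
]:
    print(f'{name}: {int(dim * ratio)}')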

+@register_model
+def aimv2_huge_patch14_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
+    """ ViT Huge AIM-v2 model
+    """

+    model_args = dict(
+        patch_size=14, embed_dim=1536, depth=24, num_heads=12, class_token=False, fc_norm=False,
+        mlp_ratio=2.6667, global_pool='avg', qkv_bias=False, proj_bias=False, act_layer='silu',
+        norm_layer=partial(RmsNorm, eps=1e-5), embed_norm_layer=partial(RmsNorm, eps=1e-5), mlp_layer=SwiGLU,
+    )
+    model = _create_vision_transformer(
+        'aimv2_huge_patch14_224', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model


+@register_model
+def aimv2_1b_patch14_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
+    """ ViT 1B AIM-v2 model
+    """
-    rms_norm = partial(RmsNorm, eps=1e-5)
     model_args = dict(
-        patch_size=14, embed_dim=1024, depth=24, num_heads=16, class_token=False, fc_norm=False,
-        mlp_ratio=5.5, global_pool='avg', norm_layer=rms_norm, embed_norm_layer=rms_norm, mlp_layer=SwiGLUPacked,
-        qkv_bias=False, proj_bias=False, act_layer='silu'
+        patch_size=14, embed_dim=2048, depth=24, num_heads=16, class_token=False, fc_norm=False,
+        mlp_ratio=2.75, global_pool='avg', qkv_bias=False, proj_bias=False, act_layer='silu',
+        norm_layer=partial(RmsNorm, eps=1e-5), embed_norm_layer=partial(RmsNorm, eps=1e-5), mlp_layer=SwiGLU,
     )
     model = _create_vision_transformer(
-        'vit_large_patch14_aimv2_224', pretrained=pretrained, **dict(model_args, **kwargs))
+        'aimv2_1b_patch14_224', pretrained=pretrained, **dict(model_args, **kwargs))
     return model


+@register_model
+def aimv2_3b_patch14_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
+    """ ViT 3B AIM-v2 model
+    """
+    model_args = dict(
+        patch_size=14, embed_dim=3072, depth=24, num_heads=24, class_token=False, fc_norm=False,
+        mlp_ratio=2.6667, global_pool='avg', qkv_bias=False, proj_bias=False, act_layer='silu',
+        norm_layer=partial(RmsNorm, eps=1e-5), embed_norm_layer=partial(RmsNorm, eps=1e-5), mlp_layer=SwiGLU,
+    )
+    model = _create_vision_transformer(
+        'aimv2_3b_patch14_224', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model


+@register_model
+def aimv2_large_patch14_336(pretrained: bool = False, **kwargs) -> VisionTransformer:
+    """ ViT Large AIM-v2 model
+    """
+    model_args = dict(
+        patch_size=14, embed_dim=1024, depth=24, num_heads=8, class_token=False, fc_norm=False,
+        mlp_ratio=2.75, global_pool='avg', qkv_bias=False, proj_bias=False, act_layer='silu',
+        norm_layer=partial(RmsNorm, eps=1e-5), embed_norm_layer=partial(RmsNorm, eps=1e-5), mlp_layer=SwiGLU,
+    )
+    model = _create_vision_transformer(
+        'aimv2_large_patch14_336', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model


+@register_model
+def aimv2_huge_patch14_336(pretrained: bool = False, **kwargs) -> VisionTransformer:
+    """ ViT Huge AIM-v2 model
+    """
+    model_args = dict(
+        patch_size=14, embed_dim=1536, depth=24, num_heads=12, class_token=False, fc_norm=False,
+        mlp_ratio=2.6667, global_pool='avg', qkv_bias=False, proj_bias=False, act_layer='silu',
+        norm_layer=partial(RmsNorm, eps=1e-5), embed_norm_layer=partial(RmsNorm, eps=1e-5), mlp_layer=SwiGLU,
+    )
+    model = _create_vision_transformer(
+        'aimv2_huge_patch14_336', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model


+@register_model
+def aimv2_1b_patch14_336(pretrained: bool = False, **kwargs) -> VisionTransformer:
+    """ ViT 1B AIM-v2 model
+    """
+    model_args = dict(
+        patch_size=14, embed_dim=2048, depth=24, num_heads=16, class_token=False, fc_norm=False,
+        mlp_ratio=2.75, global_pool='avg', qkv_bias=False, proj_bias=False, act_layer='silu',
+        norm_layer=partial(RmsNorm, eps=1e-5), embed_norm_layer=partial(RmsNorm, eps=1e-5), mlp_layer=SwiGLU,
+    )
+    model = _create_vision_transformer(
+        'aimv2_1b_patch14_336', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model


+@register_model
+def aimv2_3b_patch14_336(pretrained: bool = False, **kwargs) -> VisionTransformer:
+    """ ViT 3B AIM-v2 model
+    """
+    model_args = dict(
+        patch_size=14, embed_dim=3072, depth=24, num_heads=24, class_token=False, fc_norm=False,
+        mlp_ratio=2.6667, global_pool='avg', qkv_bias=False, proj_bias=False, act_layer='silu',
+        norm_layer=partial(RmsNorm, eps=1e-5), embed_norm_layer=partial(RmsNorm, eps=1e-5), mlp_layer=SwiGLU,
+    )
+    model = _create_vision_transformer(
+        'aimv2_3b_patch14_336', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model


+@register_model
+def aimv2_large_patch14_448(pretrained: bool = False, **kwargs) -> VisionTransformer:
+    """ ViT Large AIM-v2 model
+    """
+    model_args = dict(
+        patch_size=14, embed_dim=1024, depth=24, num_heads=8, class_token=False, fc_norm=False,
+        mlp_ratio=2.75, global_pool='avg', qkv_bias=False, proj_bias=False, act_layer='silu',
+        norm_layer=partial(RmsNorm, eps=1e-5), embed_norm_layer=partial(RmsNorm, eps=1e-5), mlp_layer=SwiGLU,
+    )
+    model = _create_vision_transformer(
+        'aimv2_large_patch14_448', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model


+@register_model
+def aimv2_huge_patch14_448(pretrained: bool = False, **kwargs) -> VisionTransformer:
+    """ ViT Huge AIM-v2 model
+    """
+    model_args = dict(
+        patch_size=14, embed_dim=1536, depth=24, num_heads=12, class_token=False, fc_norm=False,
+        mlp_ratio=2.6667, global_pool='avg', qkv_bias=False, proj_bias=False, act_layer='silu',
+        norm_layer=partial(RmsNorm, eps=1e-5), embed_norm_layer=partial(RmsNorm, eps=1e-5), mlp_layer=SwiGLU,
+    )
+    model = _create_vision_transformer(
+        'aimv2_huge_patch14_448', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model


+@register_model
+def aimv2_1b_patch14_448(pretrained: bool = False, **kwargs) -> VisionTransformer:
+    """ ViT 1B AIM-v2 model
+    """
+    model_args = dict(
+        patch_size=14, embed_dim=2048, depth=24, num_heads=16, class_token=False, fc_norm=False,
+        mlp_ratio=2.75, global_pool='avg', qkv_bias=False, proj_bias=False, act_layer='silu',
+        norm_layer=partial(RmsNorm, eps=1e-5), embed_norm_layer=partial(RmsNorm, eps=1e-5), mlp_layer=SwiGLU,
+    )
+    model = _create_vision_transformer(
+        'aimv2_1b_patch14_448', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model


+@register_model
+def aimv2_3b_patch14_448(pretrained: bool = False, **kwargs) -> VisionTransformer:
+    """ ViT 3B AIM-v2 model
+    """
+    model_args = dict(
+        patch_size=14, embed_dim=3072, depth=24, num_heads=24, class_token=False, fc_norm=False,
+        mlp_ratio=2.6667, global_pool='avg', qkv_bias=False, proj_bias=False, act_layer='silu',
+        norm_layer=partial(RmsNorm, eps=1e-5), embed_norm_layer=partial(RmsNorm, eps=1e-5), mlp_layer=SwiGLU,
+    )
+    model = _create_vision_transformer(
+        'aimv2_3b_patch14_448', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model

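The commit message mentions verifying correctness; a sketch of one way to do that against the reference Hugging Face release (the transformers-side call signature and output attribute are assumptions, and the apple repos ship remote code):

import timm
import torch
from transformers import AutoModel  # reference implementation, assumed available

timm_model = timm.create_model('aimv2_large_patch14_224.apple_pt', pretrained=True).eval()
ref_model = AutoModel.from_pretrained(
    'apple/aimv2-large-patch14-224', trust_remote_code=True).eval()

x = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    timm_tokens = timm_model.forward_features(x)             # unpooled patch tokens
    ref_tokens = ref_model(pixel_values=x).last_hidden_state  # reference tokens
print((timm_tokens - ref_tokens).abs().max())  # expect a small max abs diff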

@@ -3487,6 +3704,19 @@ def test_vit3(pretrained: bool = False, **kwargs) -> VisionTransformer:
     return model


+@register_model
+def test_vit4(pretrained: bool = False, **kwargs) -> VisionTransformer:
+    """ ViT Test
+    """
+    model_args = dict(
+        patch_size=16, embed_dim=96, depth=9, num_heads=3, mlp_ratio=3,
+        class_token=False, reg_tokens=1, global_pool='avg', init_values=1e-5, dynamic_img_size=True,
+        norm_layer='rmsnorm',
+    )
+    model = _create_vision_transformer('test_vit4', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model


 register_model_deprecations(__name__, {
     'vit_tiny_patch16_224_in21k': 'vit_tiny_patch16_224.augreg_in21k',
     'vit_small_patch32_224_in21k': 'vit_small_patch32_224.augreg_in21k',
