From 0d87caeffff595d74109828fad306dd6dcdd25a2 Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Sat, 28 Dec 2024 21:05:38 -0800
Subject: [PATCH] Switch aimv2 to use packed SwiGLU

---
 timm/layers/mlp.py                |  6 +++---
 timm/models/vision_transformer.py | 24 ++++++++++++------------
 2 files changed, 15 insertions(+), 15 deletions(-)

diff --git a/timm/layers/mlp.py b/timm/layers/mlp.py
index 09472eede..d1e6774cc 100644
--- a/timm/layers/mlp.py
+++ b/timm/layers/mlp.py
@@ -83,9 +83,9 @@ def __init__(
 
     def init_weights(self):
         # override init of fc1 w/ gate portion set to weight near zero, bias=1
-        fc1_mid = self.fc1.bias.shape[0] // 2
-        nn.init.ones_(self.fc1.bias[fc1_mid:])
-        nn.init.normal_(self.fc1.weight[fc1_mid:], std=1e-6)
+        if self.fc1.bias is not None:
+            nn.init.ones_(self.fc1.bias[self.fc1.bias.shape[0] // 2:])
+        nn.init.normal_(self.fc1.weight[self.fc1.weight.shape[0] // 2:], std=1e-6)
 
     def forward(self, x):
         x = self.fc1(x)
diff --git a/timm/models/vision_transformer.py b/timm/models/vision_transformer.py
index 7935089da..2dbe6ff70 100644
--- a/timm/models/vision_transformer.py
+++ b/timm/models/vision_transformer.py
@@ -1150,25 +1150,25 @@ def _convert_aimv2(
         state_dict: Dict[str, torch.Tensor],
         model: VisionTransformer,
 ) -> Dict[str, torch.Tensor]:
-    #import re
     out_dict = {}
-
     for k, v in state_dict.items():
         k = k.replace('norm_1', 'norm1')
         k = k.replace('norm_2', 'norm2')
         k = k.replace('preprocessor.patchifier.', 'patch_embed.')
         k = k.replace('preprocessor.pos_embed', 'pos_embed')
         k = k.replace('trunk.', '')
-        k = k.replace('mlp.fc1', 'mlp.fc1_g')
-        k = k.replace('mlp.fc3', 'mlp.fc1_x')
         k = k.replace('post_trunk_norm.', 'norm.')
-        # if re.match(r"blocks\.(\d+)\.mlp\.w12\.(?:weight|bias)", k):
-        #     out_dict[k.replace("w12", "fc1")] = v
-        #     continue
-        # elif re.match(r"blocks\.(\d+)\.mlp\.w3\.(?:weight|bias)", k):
-        #     out_dict[k.replace("w3", "fc2")] = v
-        #     continue
+
+        if 'mlp.fc1' in k:
+            if k in out_dict:
+                v = torch.cat([v, out_dict[k]], dim=0)
+        elif 'mlp.fc3' in k:
+            k = k.replace('mlp.fc3', 'mlp.fc1')
+            if k in out_dict:
+                v = torch.cat([out_dict[k], v], dim=0)
+
         out_dict[k] = v
+
     return out_dict
 
 def checkpoint_filter_fn(
@@ -3448,8 +3448,8 @@ def vit_large_patch14_aimv2_224(pretrained: bool = False, **kwargs) -> VisionTra
     rms_norm = partial(RmsNorm, eps=1e-5)
     model_args = dict(
         patch_size=14, embed_dim=1024, depth=24, num_heads=16, class_token=False, fc_norm=False,
-        mlp_ratio=2.75, global_pool='avg', norm_layer=rms_norm, embed_norm_layer=rms_norm, mlp_layer=SwiGLU,
-        qkv_bias=False, proj_bias=False,
+        mlp_ratio=5.5, global_pool='avg', norm_layer=rms_norm, embed_norm_layer=rms_norm, mlp_layer=SwiGLUPacked,
+        qkv_bias=False, proj_bias=False, act_layer='silu'
     )
     model = _create_vision_transformer(
         'vit_large_patch14_aimv2_224', pretrained=pretrained, **dict(model_args, **kwargs))
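
For context on the remap in _convert_aimv2: the converter folds the checkpoint's separate mlp.fc1 (gate) and mlp.fc3 (value) projections into one packed fc1, gate rows first, so the packed SwiGLU layer can split its fc1 output back into the two branches. A minimal sketch of that equivalence in plain PyTorch (tensor sizes hypothetical, biases omitted, not part of the patch):

import torch
import torch.nn.functional as F

torch.manual_seed(0)
dim, hidden = 8, 16                      # hypothetical sizes for illustration
x = torch.randn(2, dim)

# Separate branches as stored in the AIM-v2 checkpoint.
w1 = torch.randn(hidden, dim)            # mlp.fc1.weight -> gate branch
w3 = torch.randn(hidden, dim)            # mlp.fc3.weight -> value branch
ref = F.silu(x @ w1.T) * (x @ w3.T)      # SwiGLU with separate projections

# Packed form produced by the converter: gate rows first, value rows
# appended along dim 0, matching the torch.cat order in the patch.
w_packed = torch.cat([w1, w3], dim=0)
gate, value = (x @ w_packed.T).chunk(2, dim=-1)
out = F.silu(gate) * value               # packed fc1 output split back in two

print(torch.allclose(ref, out, atol=1e-6))   # True

The mlp_ratio change from 2.75 to 5.5 follows from the same packing: the packed fc1 now carries both branches, so its hidden width must double to keep each branch the same size as before.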