From 8cf653a03549bfeabf0d8b88ed4024da15649167 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Thu, 6 Jun 2024 13:16:39 -0700 Subject: [PATCH] Add MobileCLIP-B & conversion. Add ViTamin configs. Some refactoring of transformer module. * Move NLD -> LND transpose into Transformer module forward(). * Started working on CustomTransformer for MobileCLIP-S0 text-tower but scope too large. Leaving CustomTransformer in for potential use in future. --- src/open_clip/convert.py | 17 +- src/open_clip/model.py | 2 - src/open_clip/model_configs/MobileCLIP-B.json | 21 +++ ...{mobileclip_s1.json => MobileCLIP-S1.json} | 0 ...{mobileclip_s2.json => MobileCLIP-S2.json} | 0 .../model_configs/ViTamin-B-LTT.json | 20 +++ src/open_clip/model_configs/ViTamin-B.json | 20 +++ .../model_configs/ViTamin-L-256.json | 20 +++ .../model_configs/ViTamin-L-336.json | 20 +++ src/open_clip/model_configs/ViTamin-L.json | 20 +++ .../model_configs/ViTamin-L2-256.json | 20 +++ .../model_configs/ViTamin-L2-336.json | 20 +++ src/open_clip/model_configs/ViTamin-L2.json | 20 +++ .../model_configs/ViTamin-S-LTT.json | 20 +++ src/open_clip/model_configs/ViTamin-S.json | 20 +++ .../model_configs/ViTamin-XL-256.json | 20 +++ .../model_configs/ViTamin-XL-336.json | 20 +++ .../model_configs/ViTamin-XL-384.json | 20 +++ src/open_clip/pretrained.py | 62 ++++++- src/open_clip/transformer.py | 158 ++++++++++++++---- 20 files changed, 476 insertions(+), 44 deletions(-) create mode 100644 src/open_clip/model_configs/MobileCLIP-B.json rename src/open_clip/model_configs/{mobileclip_s1.json => MobileCLIP-S1.json} (100%) rename src/open_clip/model_configs/{mobileclip_s2.json => MobileCLIP-S2.json} (100%) create mode 100644 src/open_clip/model_configs/ViTamin-B-LTT.json create mode 100644 src/open_clip/model_configs/ViTamin-B.json create mode 100644 src/open_clip/model_configs/ViTamin-L-256.json create mode 100644 src/open_clip/model_configs/ViTamin-L-336.json create mode 100644 src/open_clip/model_configs/ViTamin-L.json create mode 100644 src/open_clip/model_configs/ViTamin-L2-256.json create mode 100644 src/open_clip/model_configs/ViTamin-L2-336.json create mode 100644 src/open_clip/model_configs/ViTamin-L2.json create mode 100644 src/open_clip/model_configs/ViTamin-S-LTT.json create mode 100644 src/open_clip/model_configs/ViTamin-S.json create mode 100644 src/open_clip/model_configs/ViTamin-XL-256.json create mode 100644 src/open_clip/model_configs/ViTamin-XL-336.json create mode 100644 src/open_clip/model_configs/ViTamin-XL-384.json diff --git a/src/open_clip/convert.py b/src/open_clip/convert.py index 0bfe35112..84571e0f1 100644 --- a/src/open_clip/convert.py +++ b/src/open_clip/convert.py @@ -139,11 +139,14 @@ def _convert_openclip_txt(module: TextTransformer, prefix): @torch.no_grad() -def convert_mobile_clip_state_dict(model: CustomTextCLIP, state_dict): - from timm.models.fastvit import _checkpoint_filter_fn - - def _convert_timm_img(state_dict, prefix='image_encoder.'): - timm_state_dict = _checkpoint_filter_fn(state_dict, model.visual.trunk) +def convert_mobile_clip_state_dict(model: CustomTextCLIP, state_dict, fastvit = True): + + def _convert_timm_img(state_dict): + if fastvit: + from timm.models.fastvit import checkpoint_filter_fn + else: + from timm.models.vision_transformer_hybrid import checkpoint_filter_fn + timm_state_dict = checkpoint_filter_fn(state_dict, model.visual.trunk) timm_state_dict = {'visual.trunk.' 
+ k: v for k, v in timm_state_dict.items()} return timm_state_dict @@ -181,5 +184,7 @@ def convert_state_dict(model: Union[CustomTextCLIP, CLIP], state_dict): if 'image_encoder.model.patch_embed.0.rbr_conv.0.conv.weight' in state_dict: # Apple MobileCLIP s1 & s2 state_dicts (s0 and b not currently supported) state_dict = convert_mobile_clip_state_dict(model, state_dict) - + if 'image_encoder.model.patch_emb.0.block.conv.weight' in state_dict: + # convert b model + state_dict = convert_mobile_clip_state_dict(model, state_dict, fastvit=False) return state_dict diff --git a/src/open_clip/model.py b/src/open_clip/model.py index 469d7f5a9..5a0fc935f 100644 --- a/src/open_clip/model.py +++ b/src/open_clip/model.py @@ -272,9 +272,7 @@ def encode_text(self, text, normalize: bool = False): x = self.token_embedding(text).to(cast_dtype) # [batch_size, n_ctx, d_model] x = x + self.positional_embedding.to(cast_dtype) - x = x.permute(1, 0, 2) # NLD -> LND x = self.transformer(x, attn_mask=self.attn_mask) - x = x.permute(1, 0, 2) # LND -> NLD x = self.ln_final(x) # [batch_size, n_ctx, transformer.width] x, _ = text_global_pool(x, text, self.text_pool_type) if self.text_projection is not None: diff --git a/src/open_clip/model_configs/MobileCLIP-B.json b/src/open_clip/model_configs/MobileCLIP-B.json new file mode 100644 index 000000000..9907d86b3 --- /dev/null +++ b/src/open_clip/model_configs/MobileCLIP-B.json @@ -0,0 +1,21 @@ +{ + "embed_dim": 512, + "vision_cfg": { + "timm_model_name": "vit_base_mci_224", + "timm_model_pretrained": false, + "timm_pool": "token", + "timm_proj": null, + "timm_drop": 0.0, + "timm_drop_path": 0.0, + "image_size": 224 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12, + "no_causal_mask": false + }, + "custom_text": true +} \ No newline at end of file diff --git a/src/open_clip/model_configs/mobileclip_s1.json b/src/open_clip/model_configs/MobileCLIP-S1.json similarity index 100% rename from src/open_clip/model_configs/mobileclip_s1.json rename to src/open_clip/model_configs/MobileCLIP-S1.json diff --git a/src/open_clip/model_configs/mobileclip_s2.json b/src/open_clip/model_configs/MobileCLIP-S2.json similarity index 100% rename from src/open_clip/model_configs/mobileclip_s2.json rename to src/open_clip/model_configs/MobileCLIP-S2.json diff --git a/src/open_clip/model_configs/ViTamin-B-LTT.json b/src/open_clip/model_configs/ViTamin-B-LTT.json new file mode 100644 index 000000000..775621409 --- /dev/null +++ b/src/open_clip/model_configs/ViTamin-B-LTT.json @@ -0,0 +1,20 @@ +{ + "embed_dim": 768, + "vision_cfg": { + "timm_model_name": "vitamin_base_224", + "timm_model_pretrained": false, + "timm_pool": "", + "timm_proj": "linear", + "timm_drop": 0.0, + "timm_drop_path": 0.1, + "image_size": 224 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 768, + "heads": 12, + "layers": 12 + }, + "custom_text": true +} \ No newline at end of file diff --git a/src/open_clip/model_configs/ViTamin-B.json b/src/open_clip/model_configs/ViTamin-B.json new file mode 100644 index 000000000..bf09a8e69 --- /dev/null +++ b/src/open_clip/model_configs/ViTamin-B.json @@ -0,0 +1,20 @@ +{ + "embed_dim": 512, + "vision_cfg": { + "timm_model_name": "vitamin_base_224", + "timm_model_pretrained": false, + "timm_pool": "", + "timm_proj": "linear", + "timm_drop": 0.0, + "timm_drop_path": 0.1, + "image_size": 224 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 
12 + }, + "custom_text": true +} \ No newline at end of file diff --git a/src/open_clip/model_configs/ViTamin-L-256.json b/src/open_clip/model_configs/ViTamin-L-256.json new file mode 100644 index 000000000..66990842e --- /dev/null +++ b/src/open_clip/model_configs/ViTamin-L-256.json @@ -0,0 +1,20 @@ +{ + "embed_dim": 768, + "vision_cfg": { + "timm_model_name": "vitamin_large_256", + "timm_model_pretrained": false, + "timm_pool": "", + "timm_proj": "linear", + "timm_drop": 0.0, + "timm_drop_path": 0.1, + "image_size": 256 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 768, + "heads": 12, + "layers": 12 + }, + "custom_text": true +} \ No newline at end of file diff --git a/src/open_clip/model_configs/ViTamin-L-336.json b/src/open_clip/model_configs/ViTamin-L-336.json new file mode 100644 index 000000000..63aa8cebe --- /dev/null +++ b/src/open_clip/model_configs/ViTamin-L-336.json @@ -0,0 +1,20 @@ +{ + "embed_dim": 768, + "vision_cfg": { + "timm_model_name": "vitamin_large_336", + "timm_model_pretrained": false, + "timm_pool": "", + "timm_proj": "linear", + "timm_drop": 0.0, + "timm_drop_path": 0.1, + "image_size": 336 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 768, + "heads": 12, + "layers": 12 + }, + "custom_text": true +} \ No newline at end of file diff --git a/src/open_clip/model_configs/ViTamin-L.json b/src/open_clip/model_configs/ViTamin-L.json new file mode 100644 index 000000000..c74e56e9d --- /dev/null +++ b/src/open_clip/model_configs/ViTamin-L.json @@ -0,0 +1,20 @@ +{ + "embed_dim": 768, + "vision_cfg": { + "timm_model_name": "vitamin_large_224", + "timm_model_pretrained": false, + "timm_pool": "", + "timm_proj": "linear", + "timm_drop": 0.0, + "timm_drop_path": 0.1, + "image_size": 224 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 768, + "heads": 12, + "layers": 12 + }, + "custom_text": true +} \ No newline at end of file diff --git a/src/open_clip/model_configs/ViTamin-L2-256.json b/src/open_clip/model_configs/ViTamin-L2-256.json new file mode 100644 index 000000000..68465befb --- /dev/null +++ b/src/open_clip/model_configs/ViTamin-L2-256.json @@ -0,0 +1,20 @@ +{ + "embed_dim": 1024, + "vision_cfg": { + "timm_model_name": "vitamin_large2_256", + "timm_model_pretrained": false, + "timm_pool": "", + "timm_proj": "linear", + "timm_drop": 0.0, + "timm_drop_path": 0.1, + "image_size": 256 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 1024, + "heads": 16, + "layers": 24 + }, + "custom_text": true +} \ No newline at end of file diff --git a/src/open_clip/model_configs/ViTamin-L2-336.json b/src/open_clip/model_configs/ViTamin-L2-336.json new file mode 100644 index 000000000..4b48a5263 --- /dev/null +++ b/src/open_clip/model_configs/ViTamin-L2-336.json @@ -0,0 +1,20 @@ +{ + "embed_dim": 1024, + "vision_cfg": { + "timm_model_name": "vitamin_large2_336", + "timm_model_pretrained": false, + "timm_pool": "", + "timm_proj": "linear", + "timm_drop": 0.0, + "timm_drop_path": 0.1, + "image_size": 336 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 1024, + "heads": 16, + "layers": 24 + }, + "custom_text": true +} \ No newline at end of file diff --git a/src/open_clip/model_configs/ViTamin-L2.json b/src/open_clip/model_configs/ViTamin-L2.json new file mode 100644 index 000000000..3d14b7109 --- /dev/null +++ b/src/open_clip/model_configs/ViTamin-L2.json @@ -0,0 +1,20 @@ +{ + "embed_dim": 1024, + "vision_cfg": { + "timm_model_name": 
"vitamin_large2_224", + "timm_model_pretrained": false, + "timm_pool": "", + "timm_proj": "linear", + "timm_drop": 0.0, + "timm_drop_path": 0.1, + "image_size": 224 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 1024, + "heads": 16, + "layers": 24 + }, + "custom_text": true +} \ No newline at end of file diff --git a/src/open_clip/model_configs/ViTamin-S-LTT.json b/src/open_clip/model_configs/ViTamin-S-LTT.json new file mode 100644 index 000000000..b01c95b41 --- /dev/null +++ b/src/open_clip/model_configs/ViTamin-S-LTT.json @@ -0,0 +1,20 @@ +{ + "embed_dim": 768, + "vision_cfg": { + "timm_model_name": "vitamin_small_224", + "timm_model_pretrained": false, + "timm_pool": "", + "timm_proj": "linear", + "timm_drop": 0.0, + "timm_drop_path": 0.1, + "image_size": 224 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 768, + "heads": 12, + "layers": 12 + }, + "custom_text": true +} \ No newline at end of file diff --git a/src/open_clip/model_configs/ViTamin-S.json b/src/open_clip/model_configs/ViTamin-S.json new file mode 100644 index 000000000..1fb6cd24a --- /dev/null +++ b/src/open_clip/model_configs/ViTamin-S.json @@ -0,0 +1,20 @@ +{ + "embed_dim": 384, + "vision_cfg": { + "timm_model_name": "vitamin_small_224", + "timm_model_pretrained": false, + "timm_pool": "", + "timm_proj": "linear", + "timm_drop": 0.0, + "timm_drop_path": 0.1, + "image_size": 224 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 384, + "heads": 6, + "layers": 12 + }, + "custom_text": true +} \ No newline at end of file diff --git a/src/open_clip/model_configs/ViTamin-XL-256.json b/src/open_clip/model_configs/ViTamin-XL-256.json new file mode 100644 index 000000000..68f672f0c --- /dev/null +++ b/src/open_clip/model_configs/ViTamin-XL-256.json @@ -0,0 +1,20 @@ +{ + "embed_dim": 1152, + "vision_cfg": { + "timm_model_name": "vitamin_xlarge_256", + "timm_model_pretrained": false, + "timm_pool": "", + "timm_proj": "linear", + "timm_drop": 0.0, + "timm_drop_path": 0.1, + "image_size": 256 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 1152, + "heads": 16, + "layers": 27 + }, + "custom_text": true +} \ No newline at end of file diff --git a/src/open_clip/model_configs/ViTamin-XL-336.json b/src/open_clip/model_configs/ViTamin-XL-336.json new file mode 100644 index 000000000..116c30e73 --- /dev/null +++ b/src/open_clip/model_configs/ViTamin-XL-336.json @@ -0,0 +1,20 @@ +{ + "embed_dim": 1152, + "vision_cfg": { + "timm_model_name": "vitamin_xlarge_336", + "timm_model_pretrained": false, + "timm_pool": "", + "timm_proj": "linear", + "timm_drop": 0.0, + "timm_drop_path": 0.1, + "image_size": 336 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 1152, + "heads": 16, + "layers": 27 + }, + "custom_text": true +} \ No newline at end of file diff --git a/src/open_clip/model_configs/ViTamin-XL-384.json b/src/open_clip/model_configs/ViTamin-XL-384.json new file mode 100644 index 000000000..3070f70e7 --- /dev/null +++ b/src/open_clip/model_configs/ViTamin-XL-384.json @@ -0,0 +1,20 @@ +{ + "embed_dim": 1152, + "vision_cfg": { + "timm_model_name": "vitamin_xlarge_384", + "timm_model_pretrained": false, + "timm_pool": "", + "timm_proj": "linear", + "timm_drop": 0.0, + "timm_drop_path": 0.1, + "image_size": 256 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 1152, + "heads": 16, + "layers": 27 + }, + "custom_text": true +} \ No newline at end of file diff --git 
a/src/open_clip/pretrained.py b/src/open_clip/pretrained.py index 399f9bdae..a7ed923a1 100644 --- a/src/open_clip/pretrained.py +++ b/src/open_clip/pretrained.py @@ -454,10 +454,64 @@ def _mccfg(url='', hf_hub='', **kwargs): mrl=_slpcfg(hf_hub='visheratin/nllb-siglip-mrl-large/'), ), - "mobileclip_s1": dict( - datacompdr=_mccfg(url='https://docs-assets.developer.apple.com/ml-research/datasets/mobileclip/mobileclip_s1.pt')), - "mobileclip_s2": dict( - datacompdr=_mccfg(url='https://docs-assets.developer.apple.com/ml-research/datasets/mobileclip/mobileclip_s2.pt')) + "MobileCLIP-S1": dict( + datacompdr=_mccfg( + url='https://docs-assets.developer.apple.com/ml-research/datasets/mobileclip/mobileclip_s1.pt')), + "MobileCLIP-S2": dict( + datacompdr=_mccfg( + url='https://docs-assets.developer.apple.com/ml-research/datasets/mobileclip/mobileclip_s2.pt')), + "MobileCLIP-B": dict( + datacompdr=_mccfg( + url='https://docs-assets.developer.apple.com/ml-research/datasets/mobileclip/mobileclip_b.pt'), + datacompdr_lt=_mccfg( + url='https://docs-assets.developer.apple.com/ml-research/datasets/mobileclip/mobileclip_blt.pt'), + ), + + "ViTamin-S": dict( + datacomp1b=_pcfg(hf_hub='jienengchen/ViTamin-S/pytorch_model.bin'), + ), + "ViTamin-S-LTT": dict( + datacomp1b=_pcfg(hf_hub='jienengchen/ViTamin-S-LTT/pytorch_model.bin'), + ), + "ViTamin-B": dict( + datacomp1b=_pcfg(hf_hub='jienengchen/ViTamin-B/pytorch_model.bin'), + ), + "ViTamin-B-LTT": dict( + datacomp1b=_pcfg(hf_hub='jienengchen/ViTamin-B-LTT/pytorch_model.bin'), + ), + "ViTamin-L": dict( + datacomp1b=_pcfg(hf_hub='jienengchen/ViTamin-L-224px/pytorch_model.bin'), + ), + "ViTamin-L-256": dict( + datacomp1b=_pcfg(hf_hub='jienengchen/ViTamin-L-256px/pytorch_model.bin'), + ), + "ViTamin-L-336": dict( + datacomp1b=_pcfg(hf_hub='jienengchen/ViTamin-L-336px/pytorch_model.bin'), + ), + "ViTamin-L-384": dict( + datacomp1b=_pcfg(hf_hub='jienengchen/ViTamin-L-384px/pytorch_model.bin'), + ), + "ViTamin-L2": dict( + datacomp1b=_pcfg(hf_hub='jienengchen/ViTamin-L2-224px/pytorch_model.bin'), + ), + "ViTamin-L2-256": dict( + datacomp1b=_pcfg(hf_hub='jienengchen/ViTamin-L2-256px/pytorch_model.bin'), + ), + "ViTamin-L2-336": dict( + datacomp1b=_pcfg(hf_hub='jienengchen/ViTamin-L2-336px/pytorch_model.bin'), + ), + "ViTamin-L2-384": dict( + datacomp1b=_pcfg(hf_hub='jienengchen/ViTamin-L2-384px/pytorch_model.bin'), + ), + "ViTamin-XL-256": dict( + datacomp1b=_pcfg(hf_hub='jienengchen/ViTamin-XL-256px/pytorch_model.bin'), + ), + "ViTamin-XL-336": dict( + datacomp1b=_pcfg(hf_hub='jienengchen/ViTamin-XL-336px/pytorch_model.bin'), + ), + "ViTamin-XL-384": dict( + datacomp1b=_pcfg(hf_hub='jienengchen/ViTamin-XL-384px/pytorch_model.bin'), + ), } diff --git a/src/open_clip/transformer.py b/src/open_clip/transformer.py index 6d4e604d8..e76ec4328 100644 --- a/src/open_clip/transformer.py +++ b/src/open_clip/transformer.py @@ -1,6 +1,6 @@ from collections import OrderedDict import math -from typing import Callable, Optional, Sequence, Tuple +from typing import Callable, List, Optional, Sequence, Tuple, Union from functools import partial import torch @@ -89,14 +89,15 @@ def forward(self, x): class Attention(nn.Module): def __init__( self, - dim, - num_heads=8, - qkv_bias=True, - scaled_cosine=False, - scale_heads=False, - logit_scale_max=math.log(1. / 0.01), - attn_drop=0., - proj_drop=0. + dim: int, + num_heads: int = 8, + qkv_bias: bool = True, + scaled_cosine: bool = False, + scale_heads: bool = False, + logit_scale_max: float = math.log(1. 
/ 0.01), + batch_first: bool = True, + attn_drop: float = 0., + proj_drop: float = 0. ): super().__init__() self.scaled_cosine = scaled_cosine @@ -106,6 +107,8 @@ def __init__( self.head_dim = dim // num_heads self.scale = self.head_dim ** -0.5 self.logit_scale_max = logit_scale_max + self.batch_first = batch_first + self.use_fsdpa = hasattr(nn.functional, 'scaled_dot_product_attention') # keeping in_proj in this form (instead of nn.Linear) to match weight scheme of original self.in_proj_weight = nn.Parameter(torch.randn((dim * 3, dim)) * self.scale) @@ -127,36 +130,55 @@ def __init__( self.out_drop = nn.Dropout(proj_drop) def forward(self, x, attn_mask: Optional[torch.Tensor] = None): + if self.batch_first: + x = x.transpose(0, 1) + L, N, C = x.shape q, k, v = F.linear(x, self.in_proj_weight, self.in_proj_bias).chunk(3, dim=-1) - q = q.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1) - k = k.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1) - v = v.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1) + q = q.reshape(L, N * self.num_heads, -1).transpose(0, 1) + k = k.reshape(L, N * self.num_heads, -1).transpose(0, 1) + v = v.reshape(L, N * self.num_heads, -1).transpose(0, 1) + + if attn_mask is not None and attn_mask.dtype == torch.bool: + new_attn_mask = torch.zeros_like(attn_mask, dtype=q.dtype) + new_attn_mask.masked_fill_(attn_mask, float("-inf")) + attn_mask = new_attn_mask if self.logit_scale is not None: attn = torch.bmm(F.normalize(q, dim=-1), F.normalize(k, dim=-1).transpose(-1, -2)) logit_scale = torch.clamp(self.logit_scale, max=self.logit_scale_max).exp() attn = attn.view(N, self.num_heads, L, L) * logit_scale attn = attn.view(-1, L, L) + if attn_mask is not None: + attn = attn + attn_mask + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + x = torch.bmm(attn, v) else: - q = q * self.scale - attn = torch.bmm(q, k.transpose(-1, -2)) - - if attn_mask is not None: - if attn_mask.dtype == torch.bool: - new_attn_mask = torch.zeros_like(attn_mask, dtype=q.dtype) - new_attn_mask.masked_fill_(attn_mask, float("-inf")) - attn_mask = new_attn_mask - attn += attn_mask - - attn = attn.softmax(dim=-1) - attn = self.attn_drop(attn) + if self.use_fsdpa: + x = F.scaled_dot_product_attention( + q, k, v, + attn_mask=attn_mask, + dropout_p=self.attn_drop.p if self.training else 0., + ) + else: + q = q * self.scale + attn = torch.bmm(q, k.transpose(-1, -2)) + if attn_mask is not None: + attn += attn_mask + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + x = torch.bmm(attn, v) - x = torch.bmm(attn, v) if self.head_scale is not None: x = x.view(N, self.num_heads, L, C) * self.head_scale x = x.view(-1, L, C) + x = x.transpose(0, 1).reshape(L, N, C) + + if self.batch_first: + x = x.transpose(0, 1) + x = self.out_proj(x) x = self.out_drop(x) return x @@ -237,7 +259,6 @@ def forward( ): k_x = self.ln_1_kv(k_x) if hasattr(self, "ln_1_kv") and k_x is not None else None v_x = self.ln_1_kv(v_x) if hasattr(self, "ln_1_kv") and v_x is not None else None - x = q_x + self.ls_1(self.attention(q_x=self.ln_1(q_x), k_x=k_x, v_x=v_x, attn_mask=attn_mask)) x = x + self.ls_2(self.mlp(self.ln_2(x))) return x @@ -256,6 +277,7 @@ def __init__( scale_heads: bool = False, scale_attn: bool = False, scale_fc: bool = False, + batch_first: bool = True, ): super().__init__() @@ -264,6 +286,7 @@ def __init__( d_model, n_head, scaled_cosine=scale_cosine_attn, scale_heads=scale_heads, + batch_first=batch_first, ) self.ln_attn = norm_layer(d_model) if scale_attn else 
nn.Identity() self.ls_1 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity() @@ -278,6 +301,9 @@ def __init__( ])) self.ls_2 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity() + def get_reference_weight(self): + return self.mlp.c_fc.weight + def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None): x = x + self.ls_1(self.ln_attn(self.attn(self.ln_1(x), attn_mask=attn_mask))) x = x + self.ls_2(self.mlp(self.ln_2(x))) @@ -306,7 +332,13 @@ def __init__( self.resblocks = nn.ModuleList([ ResidualAttentionBlock( - width, heads, mlp_ratio, ls_init_value=ls_init_value, act_layer=act_layer, norm_layer=norm_layer) + width, + heads, + mlp_ratio, + ls_init_value=ls_init_value, + act_layer=act_layer, + norm_layer=norm_layer, + ) for _ in range(layers) ]) @@ -316,12 +348,79 @@ def get_cast_dtype(self) -> torch.dtype: return self.resblocks[0].mlp.c_fc.weight.dtype def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None): + x = x.transpose(0, 1).contiguous() # NLD -> LND + for r in self.resblocks: + if self.grad_checkpointing and not torch.jit.is_scripting(): + # TODO: handle kwargs https://github.com/pytorch/pytorch/issues/79887#issuecomment-1161758372 + x = checkpoint(r, x, None, None, attn_mask) + else: + x = r(x, attn_mask=attn_mask) + x = x.transpose(0, 1) # LND -> NLD + return x + + +class CustomTransformer(nn.Module): + """ A custom transformer that can use different block types. """ + def __init__( + self, + width: int, + layers: int, + heads: int, + mlp_ratio: float = 4.0, + ls_init_value: float = None, + act_layer: Callable = nn.GELU, + norm_layer: Callable = LayerNorm, + batch_first: bool = True, + block_types: Union[str, List[str]] = 'CustomResidualAttentionBlock', + ): + super().__init__() + self.width = width + self.layers = layers + self.batch_first = batch_first # run transformer stack in batch first (N, L, D) + self.grad_checkpointing = False + + if isinstance(block_types, str): + block_types = [block_types] * layers + assert len(block_types) == layers + + def _create_block(bt: str): + if bt == 'CustomResidualAttentionBlock': + return CustomResidualAttentionBlock( + width, + heads, + mlp_ratio=mlp_ratio, + ls_init_value=ls_init_value, + act_layer=act_layer, + norm_layer=norm_layer, + batch_first=batch_first, + ) + else: + assert False + + self.resblocks = nn.ModuleList([ + _create_block(bt) + for bt in block_types + ]) + + def get_cast_dtype(self) -> torch.dtype: + weight = self.resblocks[0].get_reference_weight() + if hasattr(weight, 'int8_original_dtype'): + return weight.int8_original_dtype + return weight.dtype + + def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None): + if not self.batch_first: + x = x.transpose(0, 1) # NLD -> LND + for r in self.resblocks: if self.grad_checkpointing and not torch.jit.is_scripting(): # TODO: handle kwargs https://github.com/pytorch/pytorch/issues/79887#issuecomment-1161758372 x = checkpoint(r, x, None, None, attn_mask) else: x = r(x, attn_mask=attn_mask) + + if not self.batch_first: + x = x.transpose(0, 1) # LND -> NLD return x @@ -511,10 +610,7 @@ def forward(self, x: torch.Tensor): x = self.patch_dropout(x) x = self.ln_pre(x) - - x = x.permute(1, 0, 2) # NLD -> LND x = self.transformer(x) - x = x.permute(1, 0, 2) # LND -> NLD if self.attn_pool is not None: if self.attn_pool_contrastive is not None: @@ -683,9 +779,7 @@ def forward(self, text): attn_mask = attn_mask[None, :seq_len, :seq_len] + cls_mask[:, :seq_len,
:seq_len] x = x + self.positional_embedding[:seq_len].to(cast_dtype) - x = x.permute(1, 0, 2) # NLD -> LND x = self.transformer(x, attn_mask=attn_mask) - x = x.permute(1, 0, 2) # LND -> NLD # x.shape = [batch_size, n_ctx, transformer.width] if self.cls_emb is not None:
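A minimal usage sketch, assuming this patch is applied and a timm version providing vit_base_mci_224 is installed. It shows the new batch-first contract of Transformer.forward() (the NLD <-> LND transposes now live inside the module, so callers pass (N, L, D) tensors straight through) and that the new MobileCLIP-B config resolves via the standard open_clip factory. The tensor sizes, prompts, and pretrained=None choice below are illustrative assumptions, not part of the patch.

    import torch
    import open_clip
    from open_clip.transformer import Transformer

    # Transformer now consumes and returns batch-first (N, L, D) tensors;
    # the blocks still run in LND via the transpose inside forward().
    t = Transformer(width=64, layers=2, heads=4)
    x = torch.randn(8, 77, 64)   # (batch, seq_len, width), arbitrary demo sizes
    y = t(x)
    assert y.shape == x.shape    # still (N, L, D) on the way out

    # The new config name resolves through the usual factory; weights can be
    # pulled with the 'datacompdr' / 'datacompdr_lt' tags registered above.
    model = open_clip.create_model('MobileCLIP-B', pretrained=None)
    tokenizer = open_clip.get_tokenizer('MobileCLIP-B')
    tokens = tokenizer(["a diagram", "a dog", "a cat"])
    with torch.no_grad():
        text_features = model.encode_text(tokens)
    print(text_features.shape)   # torch.Size([3, 512]), matching embed_dim in MobileCLIP-B.json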