From f50d7d1436618941115430c350fefb63b2a07274 Mon Sep 17 00:00:00 2001
From: Phil Wang
Date: Wed, 9 Oct 2024 07:32:25 -0700
Subject: [PATCH] add a hypersphere vit, adapted from https://arxiv.org/abs/2410.01131

---
Reviewer notes (this area after the three-dash line is commentary, not part
of the commit message or of the applied change):

  * Line breaks and indentation were restored from a whitespace-collapsed
    copy of this patch. Column alignment inside the BibTeX entry and the
    position of blank context lines were reconstructed from the hunk line
    counts - confirm against the upstream commit before applying.
  * The From: header carries no e-mail address in this copy.
  * normalized_vit.py, Attention.forward: the comment "query key rmsnorm"
    sits above an l2norm call followed by multiplication with qk_scale.
  * normalized_vit.py, FeedForward: the hidden width actually used is
    int(dim_inner * 2 / 3), not the dim_inner value passed by the caller.
  * normalized_vit.py, nViT: pair() is applied to image_size only;
    patch_size is used as a single integer for both patch axes.
  * normalized_vit.py, nViT.norm_weights_: copies each NormLinear's
    normalized weight into parametrizations.weight.original. Presumably
    intended to be called by the training loop after optimizer steps -
    confirm with the author.
  * rvt.py: autocast is now imported from torch.amp and receives 'cuda' as
    an explicit first argument.

 README.md                     |   9 ++
 setup.py                      |   2 +-
 vit_pytorch/normalized_vit.py | 260 ++++++++++++++++++++++++++++++++++
 vit_pytorch/rvt.py            |   6 +-
 4 files changed, 273 insertions(+), 4 deletions(-)
 create mode 100644 vit_pytorch/normalized_vit.py

diff --git a/README.md b/README.md
index 390bfe4..ecc405b 100644
--- a/README.md
+++ b/README.md
@@ -2133,4 +2133,13 @@ Coming from computer vision and new to transformers? Here are some resources tha
 }
 ```
 
+```bibtex
+@inproceedings{Loshchilov2024nGPTNT,
+    title = {nGPT: Normalized Transformer with Representation Learning on the Hypersphere},
+    author = {Ilya Loshchilov and Cheng-Ping Hsieh and Simeng Sun and Boris Ginsburg},
+    year = {2024},
+    url = {https://api.semanticscholar.org/CorpusID:273026160}
+}
+```
+
 *I visualise a time when we will be to robots what dogs are to humans, and I’m rooting for the machines.* — Claude Shannon
diff --git a/setup.py b/setup.py
index 54afec4..53b727d 100644
--- a/setup.py
+++ b/setup.py
@@ -6,7 +6,7 @@
 setup(
   name = 'vit-pytorch',
   packages = find_packages(exclude=['examples']),
-  version = '1.7.14',
+  version = '1.8.1',
   license='MIT',
   description = 'Vision Transformer (ViT) - Pytorch',
   long_description=long_description,
diff --git a/vit_pytorch/normalized_vit.py b/vit_pytorch/normalized_vit.py
new file mode 100644
index 0000000..0c47ff1
--- /dev/null
+++ b/vit_pytorch/normalized_vit.py
@@ -0,0 +1,260 @@
+import torch
+from torch import nn
+from torch.nn import Module, ModuleList
+import torch.nn.functional as F
+import torch.nn.utils.parametrize as parametrize
+
+from einops import rearrange, reduce
+from einops.layers.torch import Rearrange
+
+# functions
+
+def exists(v):
+    return v is not None
+
+def default(v, d):
+    return v if exists(v) else d
+
+def pair(t):
+    return t if isinstance(t, tuple) else (t, t)
+
+def divisible_by(numer, denom):
+    return (numer % denom) == 0
+
+def l2norm(t, dim = -1):
+    return F.normalize(t, dim = dim, p = 2)
+
+# for use with parametrize
+
+class L2Norm(Module):
+    def __init__(self, dim = -1):
+        super().__init__()
+        self.dim = dim
+
+    def forward(self, t):
+        return l2norm(t, dim = self.dim)
+
+class NormLinear(Module):
+    def __init__(
+        self,
+        dim,
+        dim_out,
+        norm_dim_in = True
+    ):
+        super().__init__()
+        self.linear = nn.Linear(dim, dim_out, bias = False)
+
+        parametrize.register_parametrization(
+            self.linear,
+            'weight',
+            L2Norm(dim = -1 if norm_dim_in else 0)
+        )
+
+    @property
+    def weight(self):
+        return self.linear.weight
+
+    def forward(self, x):
+        return self.linear(x)
+
+# attention and feedforward
+
+class Attention(Module):
+    def __init__(
+        self,
+        dim,
+        *,
+        dim_head = 64,
+        heads = 8,
+        dropout = 0.
+    ):
+        super().__init__()
+        dim_inner = dim_head * heads
+        self.to_q = NormLinear(dim, dim_inner)
+        self.to_k = NormLinear(dim, dim_inner)
+        self.to_v = NormLinear(dim, dim_inner)
+
+        self.dropout = dropout
+
+        self.qk_scale = nn.Parameter(torch.ones(dim_head) * (dim_head ** 0.25))
+
+        self.split_heads = Rearrange('b n (h d) -> b h n d', h = heads)
+        self.merge_heads = Rearrange('b h n d -> b n (h d)')
+
+        self.to_out = NormLinear(dim_inner, dim, norm_dim_in = False)
+
+    def forward(
+        self,
+        x
+    ):
+        q, k, v = self.to_q(x), self.to_k(x), self.to_v(x)
+
+        q, k, v = map(self.split_heads, (q, k, v))
+
+        # query key rmsnorm
+
+        q, k = map(l2norm, (q, k))
+        q, k = (q * self.qk_scale), (k * self.qk_scale)
+
+        # scale is 1., as scaling factor is moved to s_qk (dk ^ 0.25) - eq. 16
+
+        out = F.scaled_dot_product_attention(
+            q, k, v,
+            dropout_p = self.dropout if self.training else 0.,
+            scale = 1.
+        )
+
+        out = self.merge_heads(out)
+        return self.to_out(out)
+
+class FeedForward(Module):
+    def __init__(
+        self,
+        dim,
+        *,
+        dim_inner,
+        dropout = 0.
+    ):
+        super().__init__()
+        dim_inner = int(dim_inner * 2 / 3)
+
+        self.dim = dim
+        self.dropout = nn.Dropout(dropout)
+
+        self.to_hidden = NormLinear(dim, dim_inner)
+        self.to_gate = NormLinear(dim, dim_inner)
+
+        self.hidden_scale = nn.Parameter(torch.ones(dim_inner))
+        self.gate_scale = nn.Parameter(torch.ones(dim_inner))
+
+        self.to_out = NormLinear(dim_inner, dim, norm_dim_in = False)
+
+    def forward(self, x):
+        hidden, gate = self.to_hidden(x), self.to_gate(x)
+
+        hidden = hidden * self.hidden_scale
+        gate = gate * self.gate_scale * (self.dim ** 0.5)
+
+        hidden = F.silu(gate) * hidden
+
+        hidden = self.dropout(hidden)
+        return self.to_out(hidden)
+
+# classes
+
+class nViT(Module):
+    """ https://arxiv.org/abs/2410.01131 """
+
+    def __init__(
+        self,
+        *,
+        image_size,
+        patch_size,
+        num_classes,
+        dim,
+        depth,
+        heads,
+        mlp_dim,
+        dropout = 0.,
+        channels = 3,
+        dim_head = 64,
+        residual_lerp_scale_init = None
+    ):
+        super().__init__()
+        image_height, image_width = pair(image_size)
+
+        # calculate patching related stuff
+
+        assert divisible_by(image_height, patch_size) and divisible_by(image_width, patch_size), 'Image dimensions must be divisible by the patch size.'
+
+        patch_height_dim, patch_width_dim = (image_height // patch_size), (image_width // patch_size)
+        patch_dim = channels * (patch_size ** 2)
+        num_patches = patch_height_dim * patch_width_dim
+
+        self.channels = channels
+        self.patch_size = patch_size
+
+        self.to_patch_embedding = nn.Sequential(
+            Rearrange('b c (h p1) (w p2) -> b (h w) (c p1 p2)', p1 = patch_size, p2 = patch_size),
+            nn.LayerNorm(patch_dim),
+            nn.Linear(patch_dim, dim),
+            nn.LayerNorm(dim),
+        )
+
+        self.abs_pos_emb = nn.Embedding(num_patches, dim)
+
+        residual_lerp_scale_init = default(residual_lerp_scale_init, 1. / depth)
+
+        # layers
+
+        self.dim = dim
+        self.layers = ModuleList([])
+        self.residual_lerp_scales = nn.ParameterList([])
+
+        for _ in range(depth):
+            self.layers.append(ModuleList([
+                Attention(dim, dim_head = dim_head, heads = heads, dropout = dropout),
+                FeedForward(dim, dim_inner = mlp_dim, dropout = dropout),
+            ]))
+
+            self.residual_lerp_scales.append(nn.ParameterList([
+                nn.Parameter(torch.ones(dim) * residual_lerp_scale_init),
+                nn.Parameter(torch.ones(dim) * residual_lerp_scale_init),
+            ]))
+
+        self.logit_scale = nn.Parameter(torch.ones(num_classes))
+
+        self.to_pred = NormLinear(dim, num_classes)
+
+    @torch.no_grad()
+    def norm_weights_(self):
+        for module in self.modules():
+            if not isinstance(module, NormLinear):
+                continue
+
+            normed = module.weight
+            original = module.linear.parametrizations.weight.original
+
+            original.copy_(normed)
+
+    def forward(self, images):
+        device = images.device
+
+        tokens = self.to_patch_embedding(images)
+
+        pos_emb = self.abs_pos_emb(torch.arange(tokens.shape[-2], device = device))
+
+        tokens = l2norm(tokens + pos_emb)
+
+        for (attn, ff), (attn_alpha, ff_alpha) in zip(self.layers, self.residual_lerp_scales):
+
+            attn_out = l2norm(attn(tokens))
+            tokens = l2norm(tokens.lerp(attn_out, attn_alpha))
+
+            ff_out = l2norm(ff(tokens))
+            tokens = l2norm(tokens.lerp(ff_out, ff_alpha))
+
+        pooled = reduce(tokens, 'b n d -> b d', 'mean')
+
+        logits = self.to_pred(pooled)
+        logits = logits * self.logit_scale * (self.dim ** 0.5)
+
+        return logits
+
+# quick test
+
+if __name__ == '__main__':
+
+    v = nViT(
+        image_size = 256,
+        patch_size = 16,
+        num_classes = 1000,
+        dim = 1024,
+        depth = 6,
+        heads = 8,
+        mlp_dim = 2048,
+    )
+
+    img = torch.randn(4, 3, 256, 256)
+    logits = v(img) # (4, 1000)
+    assert logits.shape == (4, 1000)
diff --git a/vit_pytorch/rvt.py b/vit_pytorch/rvt.py
index 1d7559c..c46188c 100644
--- a/vit_pytorch/rvt.py
+++ b/vit_pytorch/rvt.py
@@ -3,14 +3,14 @@
 import torch
 from torch import nn, einsum
 import torch.nn.functional as F
-from torch.cuda.amp import autocast
+from torch.amp import autocast
 
 from einops import rearrange, repeat
 from einops.layers.torch import Rearrange
 
 # rotary embeddings
 
-@autocast(enabled = False)
+@autocast('cuda', enabled = False)
 def rotate_every_two(x):
     x = rearrange(x, '... (d j) -> ... d j', j = 2)
     x1, x2 = x.unbind(dim = -1)
@@ -24,7 +24,7 @@ def __init__(self, dim, max_freq = 10):
         scales = torch.linspace(1., max_freq / 2, self.dim // 4)
         self.register_buffer('scales', scales)
 
-    @autocast(enabled = False)
+    @autocast('cuda', enabled = False)
     def forward(self, x):
         device, dtype, n = x.device, x.dtype, int(sqrt(x.shape[-2]))
 