diff --git a/README.md b/README.md
index 7a735ef..5e3ecf1 100644
--- a/README.md
+++ b/README.md
@@ -4,6 +4,20 @@
 
 Implementation of TiTok, proposed by Bytedance in An Image is Worth 32 Tokens for Reconstruction and Generation
 
+## Usage
+
+```python
+import torch
+from titok_pytorch.titok import TiTokTokenizer
+
+images = torch.randn(2, 3, 256, 256)
+
+tokenizer = TiTokTokenizer(dim = 512)
+
+loss = tokenizer(images)
+loss.backward()
+```
+
 ## Citations
 
 ```bibtex
diff --git a/pyproject.toml b/pyproject.toml
index f3e0dc5..c20f6e0 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -25,8 +25,10 @@ classifiers=[
 ]
 
 dependencies = [
+    "einops>=0.8.0",
     "torch>=2.0",
-    "einops>=0.8.0"
+    "x-transformers>=1.30.20",
+    "vector-quantize-pytorch>=1.14.26"
 ]
 
 [project.urls]
diff --git a/titok_pytorch/titok.py b/titok_pytorch/titok.py
index e69de29..57eb13e 100644
--- a/titok_pytorch/titok.py
+++ b/titok_pytorch/titok.py
@@ -0,0 +1,161 @@
+import torch
+from torch import nn
+import torch.nn.functional as F
+from torch.nn import Module, ModuleList
+
+from einops.layers.torch import Rearrange
+from einops import rearrange, repeat, pack, unpack
+
+from vector_quantize_pytorch import (
+    VectorQuantize as VQ,
+    LFQ
+)
+
+from x_transformers import Encoder
+
+# helpers
+
+def exists(v):
+    return v is not None
+
+def divisible_by(num, den):
+    return (num % den) == 0
+
+def pack_one(t, pattern):
+    return pack([t], pattern)
+
+def unpack_one(t, ps, pattern):
+    return unpack(t, ps, pattern)[0]
+
+# tokenizer
+
+class TiTokTokenizer(Module):
+    def __init__(
+        self,
+        *,
+        dim,
+        image_size = 256,
+        patch_size = 32,
+        channels = 3,
+        num_latent_tokens = 32,
+        enc_depth = 6,
+        dec_depth = 6,
+        codebook_size = 8192,
+        enc_kwargs: dict = dict(),
+        dec_kwargs: dict = dict(),
+        vq_kwargs: dict = dict()
+    ):
+        """
+        ein notation:
+        b - batch
+        c - channels
+        p - patch
+        h - height
+        w - width
+        l - latents
+        """
+        super().__init__()
+
+        assert divisible_by(image_size, patch_size)
+
+        dim_patch = channels * patch_size ** 2
+        num_tokens = (image_size // patch_size) ** 2
+
+        self.latents = nn.Parameter(torch.zeros(num_latent_tokens, dim))
+        self.pos_emb = nn.Parameter(torch.zeros(num_tokens, dim))
+        self.mask_tokens = nn.Parameter(torch.zeros(num_tokens, dim))
+
+        nn.init.normal_(self.latents, std = 0.02)
+        nn.init.normal_(self.pos_emb, std = 0.02)
+        nn.init.normal_(self.mask_tokens, std = 0.02)
+
+        self.image_to_tokens = nn.Sequential(
+            Rearrange('b c (h p1) (w p2) -> b h w (c p1 p2)', p1 = patch_size, p2 = patch_size),
+            nn.Linear(dim_patch, dim)
+        )
+
+        self.encoder = Encoder(
+            dim = dim,
+            depth = enc_depth,
+            **enc_kwargs
+        )
+
+        self.vq = VQ(
+            dim = dim,
+            codebook_dim = dim,
+            codebook_size = codebook_size,
+            **vq_kwargs
+        )
+
+        self.decoder = Encoder(
+            dim = dim,
+            depth = dec_depth,
+            **dec_kwargs
+        )
+
+        self.tokens_to_image = nn.Sequential(
+            nn.Linear(dim, dim_patch),
+            Rearrange('b h w (c p1 p2) -> b c (h p1) (w p2)', p1 = patch_size, p2 = patch_size)
+        )
+
+    def forward(
+        self,
+        images
+    ):
+        batch = images.shape[0]
+        orig_images = images
+
+        # image patches to tokens
+
+        tokens = self.image_to_tokens(images)
+
+        tokens, height_width_shape = pack_one(tokens, 'b * d')
+
+        # add absolute positions
+
+        pos_emb = repeat(self.pos_emb, 'n d -> b n d', b = batch)
+
+        tokens = tokens + pos_emb
+
+        # concat latents
+
+        latents = repeat(self.latents, 'l d -> b l d', b = batch)
+
+        tokens, latents_packed_shape = pack([tokens, latents], 'b * d')
+
+        # encoder
+
+        tokens = self.encoder(tokens)
+
+        # slice out latents and pass through vq as codes
+        # this is the important line of code and main proposal of the paper
+
+        _, latents = unpack(tokens, latents_packed_shape, 'b * d')
+
+        # vq - keep the commitment loss so the encoder is pulled toward the codebook
+
+        quantized, indices, commit_loss = self.vq(latents)
+
+        # append mask tokens
+
+        mask_tokens = repeat(self.mask_tokens, 'n d -> b n d', b = batch)
+
+        tokens, mask_packed_shape = pack([mask_tokens, quantized], 'b * d')
+
+        # decode
+
+        tokens = self.decoder(tokens)
+
+        tokens, _ = unpack(tokens, mask_packed_shape, 'b * d')
+
+        tokens = unpack_one(tokens, height_width_shape, 'b * d')
+
+        # tokens to image patches
+
+        recon = self.tokens_to_image(tokens)
+
+        # reconstruction loss plus the vq commitment loss
+
+        recon_loss = F.mse_loss(recon, orig_images)
+
+        return recon_loss + commit_loss