From 82ea31e51ecb6ccbb612f42264d0fa3325734137 Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Tue, 18 Jun 2024 08:44:04 -0700 Subject: [PATCH] should suffice --- README.md | 7 +++++ tests/test_titok.py | 23 +++++++++++++++ titok_pytorch/__init__.py | 1 + titok_pytorch/titok.py | 62 ++++++++++++++++++++++++--------------- 4 files changed, 70 insertions(+), 23 deletions(-) create mode 100644 tests/test_titok.py diff --git a/README.md b/README.md index 3121e81..4e15c7e 100644 --- a/README.md +++ b/README.md @@ -18,8 +18,15 @@ loss = titok(images) loss.backward() # after much training +# extract codes for gpt, maskgit, whatever codes = titok.tokenize(images) + +# reconstructing images from codes + +recon_images = titok.codebook_ids_to_images(codes) + +assert recon_images.shape == images.shape ``` ## Citations diff --git a/tests/test_titok.py b/tests/test_titok.py new file mode 100644 index 0000000..d778303 --- /dev/null +++ b/tests/test_titok.py @@ -0,0 +1,23 @@ +import pytest +import torch +from titok_pytorch import TiTokTokenizer + +def test_titok(): + + images = torch.randn(2, 3, 256, 256) + + titok = TiTokTokenizer(dim = 512) + + loss = titok(images) + loss.backward() + + # after much training + # extract codes for gpt, maskgit, whatever + + codes = titok.tokenize(images) + + # reconstructing images from codes + + recon_images = titok.codebook_ids_to_images(codes) + + assert recon_images.shape == images.shape diff --git a/titok_pytorch/__init__.py b/titok_pytorch/__init__.py index e69de29..85eed79 100644 --- a/titok_pytorch/__init__.py +++ b/titok_pytorch/__init__.py @@ -0,0 +1 @@ +from titok_pytorch.titok import TiTokTokenizer diff --git a/titok_pytorch/titok.py b/titok_pytorch/titok.py index 62e745c..c50110f 100644 --- a/titok_pytorch/titok.py +++ b/titok_pytorch/titok.py @@ -1,3 +1,5 @@ +from math import sqrt + import torch from torch import nn import torch.nn.functional as F @@ -21,11 +23,14 @@ def exists(v): def divisible_by(num, den): return (num % den) == 0 -def pack_one(t, pattern): - return pack([t], pattern) +def pack_square_height_width(t): + assert t.ndim == 4 + return rearrange(t, 'b h w d -> b (h w) d') -def unpack_one(t, ps, pattern): - return unpack(t, ps, pattern)[0] +def unpack_square_height_width(t): + assert t.ndim == 3 + hw = int(sqrt(t.shape[1])) + return rearrange(t, 'b (h w) d -> b h w d', h = hw, w = hw) # tokenizer @@ -102,10 +107,37 @@ def __init__( def tokenize(self, images): return self.forward(images, return_codebook_ids = True) + def codebook_ids_to_images(self, token_ids): + codes = self.vq.get_output_from_indices(token_ids) + return self.decode(codes) + + def decode(self, latents): + batch = latents.shape[0] + + # append mask tokens + + mask_tokens = repeat(self.mask_tokens, 'n d -> b n d', b = batch) + + tokens, mask_packed_shape = pack([mask_tokens, latents], 'b * d') + + # decode + + tokens = self.decoder(tokens) + + tokens, _ = unpack(tokens, mask_packed_shape, 'b * d') + + tokens = unpack_square_height_width(tokens) + + # tokens to image patches + + recon = self.tokens_to_image(tokens) + return recon + def forward( self, images, - return_codebook_ids = False + return_codebook_ids = False, + return_recons = False ): batch = images.shape[0] orig_images = images @@ -114,7 +146,7 @@ def forward( tokens = self.image_to_tokens(images) - tokens, height_width_shape = pack_one(tokens, 'b * d') + tokens = pack_square_height_width(tokens) # add absolute positions @@ -146,23 +178,7 @@ def forward( if return_codebook_ids: return indices - # append mask tokens - - mask_tokens = repeat(self.mask_tokens, 'n d -> b n d', b = batch) - - tokens, mask_packed_shape = pack([mask_tokens, quantized], 'b * d') - - # decode - - tokens = self.decoder(tokens) - - tokens, _ = unpack(tokens, mask_packed_shape, 'b * d') - - tokens = unpack_one(tokens, height_width_shape, 'b * d') - - # tokens to image patches - - recon = self.tokens_to_image(tokens) + recon = self.decode(latents) # reconstruction loss