Commit: should suffice

lucidrains committed Jun 18, 2024
1 parent 84176c2 commit 82ea31e
Showing 4 changed files with 70 additions and 23 deletions.
7 changes: 7 additions & 0 deletions README.md
@@ -18,8 +18,15 @@ loss = titok(images)
 loss.backward()
 
 # after much training
 # extract codes for gpt, maskgit, whatever
 
 codes = titok.tokenize(images)
 
+# reconstructing images from codes
+
+recon_images = titok.codebook_ids_to_images(codes)
+
+assert recon_images.shape == images.shape
 ```
 
 ## Citations
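Not part of the diff, but to make the "extract codes for gpt, maskgit, whatever" comment above concrete, here is a minimal sketch of using the code ids as targets for a toy next-token prior. The codebook size, code length, and model below are illustrative assumptions only.

```python
import torch
from torch import nn
import torch.nn.functional as F

codebook_size = 4096                                # assumption - must match the tokenizer's codebook
codes = torch.randint(0, codebook_size, (2, 32))    # stand-in for codes = titok.tokenize(images)

# toy prior: predicts the next code id from the current one
# (a real prior would be a causal transformer such as a gpt)
embed = nn.Embedding(codebook_size, 256)
to_logits = nn.Linear(256, codebook_size)

logits = to_logits(embed(codes[:, :-1]))
loss = F.cross_entropy(logits.transpose(1, 2), codes[:, 1:])
loss.backward()
```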
23 changes: 23 additions & 0 deletions tests/test_titok.py
@@ -0,0 +1,23 @@
+import pytest
+import torch
+from titok_pytorch import TiTokTokenizer
+
+def test_titok():
+
+    images = torch.randn(2, 3, 256, 256)
+
+    titok = TiTokTokenizer(dim = 512)
+
+    loss = titok(images)
+    loss.backward()
+
+    # after much training
+    # extract codes for gpt, maskgit, whatever
+
+    codes = titok.tokenize(images)
+
+    # reconstructing images from codes
+
+    recon_images = titok.codebook_ids_to_images(codes)
+
+    assert recon_images.shape == images.shape
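A small follow-on sketch, not part of this commit, for eyeballing the reconstructions the test produces; torchvision and the [0, 1] value range assumed below are not guaranteed by the tokenizer.

```python
import torch
from torchvision.utils import save_image

# stand-in for recon_images = titok.codebook_ids_to_images(codes)
recon_images = torch.randn(2, 3, 256, 256)

# clamp before writing to disk, since the decoder output range is an assumption here
save_image(recon_images.clamp(0., 1.), 'recon.png', nrow = 2)
```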
1 change: 1 addition & 0 deletions titok_pytorch/__init__.py
@@ -0,0 +1 @@
+from titok_pytorch.titok import TiTokTokenizer
62 changes: 39 additions & 23 deletions titok_pytorch/titok.py
@@ -1,3 +1,5 @@
+from math import sqrt
+
 import torch
 from torch import nn
 import torch.nn.functional as F
@@ -21,11 +23,14 @@ def exists(v):
 def divisible_by(num, den):
     return (num % den) == 0
 
-def pack_one(t, pattern):
-    return pack([t], pattern)
+def pack_square_height_width(t):
+    assert t.ndim == 4
+    return rearrange(t, 'b h w d -> b (h w) d')
 
-def unpack_one(t, ps, pattern):
-    return unpack(t, ps, pattern)[0]
+def unpack_square_height_width(t):
+    assert t.ndim == 3
+    hw = int(sqrt(t.shape[1]))
+    return rearrange(t, 'b (h w) d -> b h w d', h = hw, w = hw)
 
 # tokenizer
 
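As a standalone sketch (not part of the diff) of what the two new helpers do: flatten a square grid of tokens into a sequence, then recover the grid by assuming the sequence length is a perfect square. The sizes below are arbitrary.

```python
import torch
from math import sqrt
from einops import rearrange

tokens = torch.randn(2, 16, 16, 512)                # (batch, height, width, dim)

# pack_square_height_width: flatten the spatial grid into one sequence axis
packed = rearrange(tokens, 'b h w d -> b (h w) d')

# unpack_square_height_width: recover the side length, assuming a square grid
side = int(sqrt(packed.shape[1]))
unpacked = rearrange(packed, 'b (h w) d -> b h w d', h = side, w = side)

assert torch.equal(tokens, unpacked)
```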
@@ -102,10 +107,37 @@ def __init__(
     def tokenize(self, images):
         return self.forward(images, return_codebook_ids = True)
 
+    def codebook_ids_to_images(self, token_ids):
+        codes = self.vq.get_output_from_indices(token_ids)
+        return self.decode(codes)
+
+    def decode(self, latents):
+        batch = latents.shape[0]
+
+        # append mask tokens
+
+        mask_tokens = repeat(self.mask_tokens, 'n d -> b n d', b = batch)
+
+        tokens, mask_packed_shape = pack([mask_tokens, latents], 'b * d')
+
+        # decode
+
+        tokens = self.decoder(tokens)
+
+        tokens, _ = unpack(tokens, mask_packed_shape, 'b * d')
+
+        tokens = unpack_square_height_width(tokens)
+
+        # tokens to image patches
+
+        recon = self.tokens_to_image(tokens)
+        return recon
+
     def forward(
         self,
         images,
-        return_codebook_ids = False
+        return_codebook_ids = False,
+        return_recons = False
     ):
         batch = images.shape[0]
         orig_images = images
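To illustrate the pack/unpack mechanics the new decode method relies on, here is a standalone sketch; the token counts, dimension, and the omitted decoder are illustrative assumptions, not the module's actual values.

```python
import torch
from einops import repeat, pack, unpack

batch, dim = 2, 512
num_mask_tokens, num_latents = 256, 32              # illustrative sizes only

mask_tokens = torch.randn(num_mask_tokens, dim)     # stands in for the learned self.mask_tokens
latents = torch.randn(batch, num_latents, dim)

# broadcast the mask tokens across the batch, as repeat(self.mask_tokens, 'n d -> b n d', b = batch) does
mask_tokens = repeat(mask_tokens, 'n d -> b n d', b = batch)

# concatenate along the sequence axis, remembering where the split was
tokens, packed_shape = pack([mask_tokens, latents], 'b * d')
assert tokens.shape == (batch, num_mask_tokens + num_latents, dim)

# ... the transformer decoder would run over `tokens` here ...

# keep only the mask-token positions, which carry the reconstruction
mask_tokens_out, _ = unpack(tokens, packed_shape, 'b * d')
assert mask_tokens_out.shape == (batch, num_mask_tokens, dim)
```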
@@ -114,7 +146,7 @@ def forward(

         tokens = self.image_to_tokens(images)
 
-        tokens, height_width_shape = pack_one(tokens, 'b * d')
+        tokens = pack_square_height_width(tokens)
 
         # add absolute positions
 
@@ -146,23 +178,7 @@ def forward(
         if return_codebook_ids:
             return indices
 
-        # append mask tokens
-
-        mask_tokens = repeat(self.mask_tokens, 'n d -> b n d', b = batch)
-
-        tokens, mask_packed_shape = pack([mask_tokens, quantized], 'b * d')
-
-        # decode
-
-        tokens = self.decoder(tokens)
-
-        tokens, _ = unpack(tokens, mask_packed_shape, 'b * d')
-
-        tokens = unpack_one(tokens, height_width_shape, 'b * d')
-
-        # tokens to image patches
-
-        recon = self.tokens_to_image(tokens)
+        recon = self.decode(latents)
 
         # reconstruction loss
 