Skip to content

Commit

Permalink
tokenizer method
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Jun 18, 2024
1 parent d856fee commit 84176c2
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 3 deletions.
8 changes: 6 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,14 @@ from titok_pytorch.titok import TiTokTokenizer

images = torch.randn(2, 3, 256, 256)

tokenizer = TiTokTokenizer(dim = 512)
titok = TiTokTokenizer(dim = 512)

loss = tokenizer(images)
loss = titok(images)
loss.backward()

# after much training

codes = titok.tokenize(images)
```

## Citations
Expand Down
12 changes: 11 additions & 1 deletion titok_pytorch/titok.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,9 +98,14 @@ def __init__(
Rearrange('b h w (c p1 p2) -> b c (h p1) (w p2)', p1 = patch_size, p2 = patch_size)
)

@torch.no_grad()
def tokenize(self, images):
return self.forward(images, return_codebook_ids = True)

def forward(
self,
images
images,
return_codebook_ids = False
):
batch = images.shape[0]
orig_images = images
Expand Down Expand Up @@ -136,6 +141,11 @@ def forward(

quantized, indices, _ = self.vq(latents)

# whether to early return

if return_codebook_ids:
return indices

# append mask tokens

mask_tokens = repeat(self.mask_tokens, 'n d -> b n d', b = batch)
Expand Down

0 comments on commit 84176c2

Please sign in to comment.