From 84176c2ad78bf47f83fbc3105deb37791b79abb5 Mon Sep 17 00:00:00 2001
From: Phil Wang
Date: Tue, 18 Jun 2024 07:28:13 -0700
Subject: [PATCH] tokenizer method

---
 README.md              |  8 ++++++--
 titok_pytorch/titok.py | 12 +++++++++++-
 2 files changed, 17 insertions(+), 3 deletions(-)

diff --git a/README.md b/README.md
index 5e3ecf1..3121e81 100644
--- a/README.md
+++ b/README.md
@@ -12,10 +12,14 @@ from titok_pytorch.titok import TiTokTokenizer
 
 images = torch.randn(2, 3, 256, 256)
 
-tokenizer = TiTokTokenizer(dim = 512)
+titok = TiTokTokenizer(dim = 512)
 
-loss = tokenizer(images)
+loss = titok(images)
 loss.backward()
+
+# after much training
+
+codes = titok.tokenize(images)
 ```
 
 ## Citations
diff --git a/titok_pytorch/titok.py b/titok_pytorch/titok.py
index 57eb13e..62e745c 100644
--- a/titok_pytorch/titok.py
+++ b/titok_pytorch/titok.py
@@ -98,9 +98,14 @@ def __init__(
             Rearrange('b h w (c p1 p2) -> b c (h p1) (w p2)', p1 = patch_size, p2 = patch_size)
         )
 
+    @torch.no_grad()
+    def tokenize(self, images):
+        return self.forward(images, return_codebook_ids = True)
+
     def forward(
         self,
-        images
+        images,
+        return_codebook_ids = False
     ):
         batch = images.shape[0]
         orig_images = images
@@ -136,6 +141,11 @@ def forward(
 
         quantized, indices, _ = self.vq(latents)
 
+        # whether to early return
+
+        if return_codebook_ids:
+            return indices
+
         # append mask tokens
 
         mask_tokens = repeat(self.mask_tokens, 'n d -> b n d', b = batch)