From 2f95258f27797f142df0d5dc2388e034285ab7e5 Mon Sep 17 00:00:00 2001 From: lucidrains Date: Thu, 20 Jun 2024 06:41:48 -0700 Subject: [PATCH] expose a few more vit hparams --- README.md | 7 +++---- pyproject.toml | 2 +- titok_pytorch/titok.py | 8 ++++++++ 3 files changed, 12 insertions(+), 5 deletions(-) diff --git a/README.md b/README.md index 65844bc..24cc5ef 100644 --- a/README.md +++ b/README.md @@ -19,9 +19,10 @@ from titok_pytorch import TiTokTokenizer images = torch.randn(2, 3, 256, 256) titok = TiTokTokenizer( - dim = 512, + dim = 1024, + patch_size = 32, num_latent_tokens = 32, # they claim only 32 tokens needed - codebook_size = 8192 # codebook size 8192 + codebook_size = 4096 # codebook size 4096 ) loss = titok(images) @@ -42,8 +43,6 @@ assert recon_images.shape == images.shape ## Todo - [ ] add multi-resolution patches -- [ ] add lfq -- [ ] support video ## Citations diff --git a/pyproject.toml b/pyproject.toml index dbab49d..198f8fe 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "titok-pytorch" -version = "0.0.4" +version = "0.0.5" description = "TiTok - Pytorch" authors = [ { name = "Phil Wang", email = "lucidrains@gmail.com" } diff --git a/titok_pytorch/titok.py b/titok_pytorch/titok.py index 39beb6a..a4cb6ca 100644 --- a/titok_pytorch/titok.py +++ b/titok_pytorch/titok.py @@ -44,7 +44,11 @@ def __init__( channels = 3, num_latent_tokens = 32, enc_depth = 6, + enc_heads = 8, + enc_dim_head = 64, dec_depth = 6, + dec_heads = 8, + dec_dim_head = 64, codebook_size = 8192, enc_kwargs: dict = dict(), dec_kwargs: dict = dict(), @@ -84,6 +88,8 @@ def __init__( self.encoder = Encoder( dim = dim, depth = enc_depth, + heads = enc_heads, + attn_dim_head = enc_dim_head, **enc_kwargs ) @@ -97,6 +103,8 @@ def __init__( self.decoder = Encoder( dim = dim, depth = dec_depth, + heads = dec_heads, + attn_dim_head = dec_dim_head, **dec_kwargs )