PyTorch Implementation of Transfusion: Predict the Next Token and Diffuse Images with One Multi-Modal Model

Transfusion [Paper]

  • Transfusion is a multi-modal transformer: it can generate text like GPTs and images like diffusion models, all within a single model rather than separately!
  • It can easily switch between the text and image modalities during generation, and it is nothing complicated: just a single transformer with a few modality-specific components!
  • This can easily be extended to other modalities like video, audio, etc., but for now it only takes images and text as input
  • TODO: Train on a large multi-modal dataset (something like the TinyStories dataset with images in between illustrating the story...?)
import torch

from src import LLaMA, Transfussion

class config:
    ... # Fill in some parameters for the model | see src/configs.py for reference

model = Transfussion(
    model=LLaMA(config),
    config=config
)

text_and_images = [
    [
        torch.randint(0, 10, (39,)), # text
        # You get "image" after passing the image to PatchOps.patchify() while preprocessing
        (torch.randn(345, config.patch_size**2 * config.in_channels), torch.randint(0, config.num_timesteps, (1,))), # (image, timestep)
        torch.randint(0, 10, (14,)) # text
    ],
    [
        torch.randint(0, 10, (16,)), # text
        # You get "image" after passing the image to PatchOps.patchify() while preprocessing
        (torch.randn(359, config.patch_size**2 * config.in_channels), torch.randint(0, config.num_timesteps, (1,))), # (image, timestep)
        torch.randint(0, 10, (5,)), # text
        # You get "image" after passing the image to PatchOps.patchify() while preprocessing
        (torch.randn(2, config.patch_size**2 * config.in_channels), torch.randint(0, config.num_timesteps, (1,))),   # (image, timestep)
        torch.randint(0, 10, (9,))  # text
    ]
]
output = model(text_and_images, [["text", "image", "text"], ["text", "image", "text", "image", "text"]])

Contents

Introduction

  • Transfusion is trained by pretraining a single transformer on 50% text and 50% image data, using a different objective for each modality: next-token prediction for text and diffusion for images
  • We apply causal attention for text tokens and bidirectional attention for image patches. For inference, we introduce a decoding algorithm that combines the standard practices of text generation from language models and image generation from diffusion models
  • Intra-image bidirectional attention is important, and replacing it with causal attention hurts text-to-image generation

Language Modelling Utils and Loss

  • Autoregressive Classification
  • Usual Cross-Entropy Loss
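A minimal sketch of the next-token objective (illustrative names, not the repo's exact API): shift the targets by one position and apply cross-entropy over the vocabulary.

import torch
import torch.nn.functional as F

def lm_loss(logits: torch.Tensor, tokens: torch.Tensor) -> torch.Tensor:
    # logits: (batch, seq_len, vocab_size) | tokens: (batch, seq_len)
    # Position t predicts token t+1, so drop the last logit and the first token.
    pred = logits[:, :-1].reshape(-1, logits.size(-1))
    target = tokens[:, 1:].reshape(-1)
    return F.cross_entropy(pred, target)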

Diffusion Utils and Loss

  • Noise Schedule: cosine scheduler (We found that while the linear noise schedule used in Ho et al. (2020) worked well for high-resolution images, it was sub-optimal for images of resolution 64 × 64 and 32 × 32)
  • Loss: Mean Squared Error
  • Latent Image Representation: Variational autoencoders (VAEs) [Kingma and Welling, 2013] can save compute by encoding images into a lower-dimensional latent space
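A minimal sketch of the cosine noise schedule and the noise-prediction (MSE) loss, assuming an epsilon-prediction model; names and signatures here are illustrative, see src for the actual implementation.

import math
import torch
import torch.nn.functional as F

def cosine_alpha_bar(num_timesteps: int, s: float = 0.008) -> torch.Tensor:
    # Cumulative signal level alpha_bar_t from the cosine schedule (Nichol & Dhariwal, 2021).
    t = torch.linspace(0, num_timesteps, num_timesteps + 1) / num_timesteps
    f = torch.cos((t + s) / (1 + s) * math.pi / 2) ** 2
    return (f / f[0]).clamp(1e-5, 1.0)[1:]  # (num_timesteps,)

def diffusion_loss(model, x0: torch.Tensor, alpha_bar: torch.Tensor) -> torch.Tensor:
    # x0: (batch, num_patches, dim) latent image patches
    t = torch.randint(0, alpha_bar.size(0), (x0.size(0),))
    eps = torch.randn_like(x0)
    a = alpha_bar[t].view(-1, 1, 1)
    xt = a.sqrt() * x0 + (1 - a).sqrt() * eps  # forward process q(x_t | x_0)
    return F.mse_loss(model(xt, t), eps)       # predict the added noise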

Data Representation

  • Discrete text and continuous images
  • Each text string is tokenized into a sequence of discrete tokens from a fixed vocabulary, where each token is represented as an integer
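A minimal sketch of turning an image into the (num_patches, patch_size² · in_channels) patch sequence used in the usage example above (the repo's PatchOps.patchify may differ in details such as channel ordering):

import torch

def patchify(img: torch.Tensor, patch_size: int) -> torch.Tensor:
    # img: (C, H, W) -> (num_patches, patch_size**2 * C)
    C, H, W = img.shape
    patches = img.unfold(1, patch_size, patch_size).unfold(2, patch_size, patch_size)
    # (C, H//p, W//p, p, p) -> (H//p * W//p, C * p * p)
    return patches.permute(1, 2, 0, 3, 4).reshape(-1, C * patch_size * patch_size)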

Model Architecture

  • The vast majority of the model’s parameters belong to a single transformer, which processes every sequence, regardless of modality (We follow Llama’s [Touvron et al., 2023a] flavour of the transformer block, which includes the SwiGLU activation function [Shazeer, 2020] and RoPE [Su et al., 2024])
  • To convert our data into this space, we use lightweight modality-specific components with unshared parameters
  • For text, these are the embedding matrices
  • For images, we experiment with two alternatives for compressing local windows of k × k patch vectors into a single transformer vector (and vice versa):
    1. a simple linear layer (We add an embedding of the timestep t to every patch vector before the linear layer)
    2. up and down blocks of a U-Net (We replace the U-Net’s AdaLayerNorm with regular layer norm in our implementation)
  • Transfusion Attention: While text is naturally sequential, images are not, and are usually modelled with unrestricted (bidirectional) attention. Transfusion combines both attention patterns by applying causal attention to every element in the sequence, and bidirectional attention within the elements (patches) of each individual image (see the mask sketch below)
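A minimal sketch of how such a mask can be built, assuming hypothetical (start, end) spans marking where each image's patches sit in the sequence: causal attention everywhere, plus full bidirectional visibility inside each image span.

import torch

def transfusion_mask(seq_len: int, image_spans: list[tuple[int, int]]) -> torch.Tensor:
    # True = attention allowed. Start from a standard causal (lower-triangular) mask.
    mask = torch.ones(seq_len, seq_len, dtype=torch.bool).tril()
    # Let every patch of an image attend to every other patch of the same image.
    for start, end in image_spans:  # end is exclusive
        mask[start:end, start:end] = True
    return mask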

Training Objective

  • LM loss is computed per token (When the input is a BOI token, we do not compute any loss), while diffusion loss is computed per image, which may span multiple elements (image patches) in the sequence
  • Specifically, we add noise ϵ to each input latent image x0 according to the diffusion process to produce xt before patchification, and then compute the image-level diffusion loss
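The two losses are combined with a balancing coefficient λ (set in Optimization below); restated here as a sketch in the paper's notation:

\mathcal{L}_{\text{Transfusion}} = \mathcal{L}_{\text{LM}} + \lambda \cdot \mathcal{L}_{\text{DDPM}},
\qquad
\mathcal{L}_{\text{DDPM}} = \mathbb{E}_{x_0,\ \epsilon \sim \mathcal{N}(0, I),\ t}\big[\, \lVert \epsilon - \epsilon_\theta(x_t, t) \rVert^2 \,\big]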

Optimization

  • AdamW => | betas=(0.9, 0.95) | eps=1e-8 | lr=3e-4 | warmup=4000 | min_lr=1.5e-5 | weight_decay=0.1 | clip_norm=1.0 |
  • balancing_coeff (lambda in loss function) = 5
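A minimal sketch of the optimizer setup with the hyper-parameters above, using the model built in the usage example; the warmup-then-cosine-decay schedule is an assumption about how lr, warmup and min_lr fit together, not necessarily the repo's exact scheduler.

import math
import torch

optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=3e-4, betas=(0.9, 0.95), eps=1e-8, weight_decay=0.1,
)

def lr_at(step: int, max_steps: int, warmup: int = 4000,
          max_lr: float = 3e-4, min_lr: float = 1.5e-5) -> float:
    # Linear warmup to max_lr, then cosine decay down to min_lr.
    if step < warmup:
        return max_lr * (step + 1) / warmup
    progress = (step - warmup) / max(1, max_steps - warmup)
    return min_lr + 0.5 * (max_lr - min_lr) * (1 + math.cos(math.pi * progress))

# Before each optimizer.step(): clip gradients to the configured norm.
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)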

Inference

  • 250 diffusion steps (but trained on 1000 timesteps)
  • cfg_coeff = 5.0
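A minimal sketch of one guided noise prediction during sampling, assuming an epsilon-prediction model and a hypothetical context argument for the text conditioning (the repo's decoding loop interleaves this with LM token generation):

import torch

@torch.no_grad()
def guided_eps(model, xt, t, cond, cfg_coeff: float = 5.0) -> torch.Tensor:
    # Classifier-free guidance: push the conditional prediction away from the unconditional one.
    eps_cond = model(xt, t, context=cond)    # conditioned on the text prefix
    eps_uncond = model(xt, t, context=None)  # unconditional pass
    return eps_uncond + cfg_coeff * (eps_cond - eps_uncond)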
