Larger batch size to generate images in text2im.ipynb? #29

Open
brijow opened this issue Mar 10, 2022 · 2 comments
Comments

@brijow

brijow commented Mar 10, 2022

Hi, in the example notebook text2im.ipynb, I'm not clear on how to use a batch size larger than 1, or on the recommended way to generate many images.

I'd like to play around with the model, generate several thousand images for some captions I have collected, and evaluate the overall quality of the results. However, I'm not clear on a better way to do this than something along the lines of the pseudo-code below:

for caption in my_captions:
    tokens = encode(caption)
    model_kwargs = {...}
    sample = diffusion.p_sample_loop(...)
    save_sample(sample)

Would there be a faster way to do this than (more or less) following the recipe above?

@woctezuma

Check the code of the notebook at:

@tristanengst

Not sure if this is too late to be helpful, but the following is about twice as fast as looping over a list of captions, and it seems to be what you want. I've cleaned it up from my own file, so I haven't had a chance to run it and there may be an error lurking somewhere. The basic idea is to replace tiling the tokens from a single prompt (see the multiplications by batch_size in the original notebook code) with tokens from additional prompts.

from PIL import Image
from glide_text2im.download import load_checkpoint
from glide_text2im.model_creation import (
    create_model_and_diffusion,
    model_and_diffusion_defaults,
    model_and_diffusion_defaults_upsampler
)
import torch
import matplotlib.pyplot as plt

has_cuda = torch.cuda.is_available()
device = torch.device('cpu' if not has_cuda else 'cuda')
# Create base glide.
options = model_and_diffusion_defaults()
options['use_fp16'] = has_cuda
options['timestep_respacing'] = '100' # use 100 diffusion steps for fast sampling
glide, diffusion = create_model_and_diffusion(**options)
glide.eval()
if has_cuda:
    glide.convert_to_fp16()
glide.to(device)
glide.load_state_dict(load_checkpoint('base', device))
print('total base parameters', sum(x.numel() for x in glide.parameters()))

guidance_scale = 3.0
upsample_temp = 0.997  # kept from the notebook; the upsampler stage isn't used in this snippet

# Create a classifier-free guidance sampling function
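# The doubled batch is laid out as [conditional rows | unconditional rows]. Both halves
# share the same x_t (the first half), so one forward pass produces both the conditional
# and unconditional eps predictions, which are then mixed with guidance_scale.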
def model_fn(x_t, ts, **kwargs):
    half = x_t[: len(x_t) // 2]
    combined = torch.cat([half, half], dim=0)
    model_out = glide(combined, ts, **kwargs)
    eps, rest = model_out[:, :3], model_out[:, 3:]
    cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
    half_eps = uncond_eps + guidance_scale * (cond_eps - uncond_eps)
    eps = torch.cat([half_eps, half_eps], dim=0)
    return torch.cat([eps, rest], dim=1)

def glide_generate(prompts):
    """Returns a tensor of images where the ith image is generated from the ith prompt in [prompts].
    
    Args:
    prompts    -- list of string prompts
    """
    batch_size = len(prompts)

    # Tokenize and pad each prompt separately; one row of tokens/mask per prompt
    tokens = [glide.tokenizer.encode(p) for p in prompts]
    tokens_and_masks = [glide.tokenizer.padded_tokens_and_mask(t, options['text_ctx']) for t in tokens]
    tokens = [t for t, _ in tokens_and_masks]
    masks = [m for _, m in tokens_and_masks]

    # Create the classifier-free guidance tokens (empty)
    full_batch_size = batch_size * 2
    uncond_tokens, uncond_mask = glide.tokenizer.padded_tokens_and_mask([], options['text_ctx'])

    # Pack the tokens together into glide kwargs: the first batch_size rows are the
    # real prompts, the last batch_size rows are the empty (unconditional) prompt.
    model_kwargs = dict(
        tokens=torch.tensor(tokens + [uncond_tokens] * batch_size, device=device),
        mask=torch.tensor(masks + [uncond_mask] * batch_size, dtype=torch.bool, device=device),
    )

    glide.del_cache()
    samples = diffusion.p_sample_loop(
        model_fn,
        (full_batch_size, 3, options["image_size"], options["image_size"]),
        device=device,
        clip_denoised=True,
        progress=True,
        model_kwargs=model_kwargs,
        cond_fn=None,
    )[:batch_size]
    glide.del_cache()

    # Uncomment what's below to validate the function
    # scaled = ((samples + 1) * 127.5).round().clamp(0, 255).to(torch.uint8).cpu()
    # for s in scaled:
    #     plt.imshow(s.permute(1, 2, 0))
    #     plt.show()

    return (samples + 1) / 2
    
# Uncomment what's below to validate the function
# glide_generate(["a painting of a blue bird", "a painting of a red cat", "a painting of a purple apple"])
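
For the "several thousand images" use case from the original post, here is a minimal, untested sketch of how you might drive glide_generate over a long caption list in fixed-size chunks and write each sample to disk with PIL. The names generate_and_save, chunk_size, save_dir, and my_captions are placeholders I made up here, not part of the notebook; this assumes glide_generate as defined above and the base model's default 64x64 output.

import os

def generate_and_save(captions, chunk_size=8, save_dir="samples"):
    """Runs glide_generate over [captions] in chunks and writes each image as a PNG."""
    os.makedirs(save_dir, exist_ok=True)
    for start in range(0, len(captions), chunk_size):
        chunk = captions[start:start + chunk_size]
        with torch.no_grad():
            images = glide_generate(chunk)  # (len(chunk), 3, 64, 64), values in [0, 1]
        images = (images * 255).round().clamp(0, 255).to(torch.uint8).cpu()
        for i, img in enumerate(images):
            # CHW -> HWC, then save as PNG named by the caption's index in the full list
            Image.fromarray(img.permute(1, 2, 0).numpy()).save(
                os.path.join(save_dir, f"{start + i:06d}.png"))

# generate_and_save(my_captions, chunk_size=8)

If you want the 256x256 images the notebook produces, you'd still need to run its upsampler stage on these base samples afterwards.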
