
How to change SDXL pipeline precision to bfloat16 by assigning a new value to the torch_dtype variable? #3691

Closed
gokerguner opened this issue Mar 1, 2024 · 7 comments
Labels
triaged Issue has been triaged by maintainers

Comments

@gokerguner

gokerguner commented Mar 1, 2024

Description

I am trying to develop a txt2img SDXL application based on the repo (forked from this repo) mentioned in Relevant Files below.

If I run SDXL through the plain diffusers pipeline, changing the torch_dtype precision to bfloat16 is enough:

from diffusers import DiffusionPipeline
import torch
import time

prompt = "a man"
seed = 42
pipeline = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.bfloat16,
).to("cuda")

# time 50 runs of a 2-image batch at 25 denoising steps
for i in range(50):
    starttime = time.time()
    images = pipeline([prompt] * 2, generator=torch.manual_seed(seed), num_inference_steps=25).images[0:2]
    endtime = time.time()
    time_past = endtime - starttime
    print(time_past)

I want to implement this precision change in the repo above. In stable_diffusion_pipeline.py there are two methods named initialize_latents and decode_latent. When I change the dtypes in these methods, the VAE section speeds up by about 0.1 seconds.
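For reference, a plain-diffusers sketch of the same idea (the attribute names below follow the standard diffusers API, not the TensorRT demo's decode_latent, so treat them as illustrative only):

import torch
from diffusers import DiffusionPipeline

# Illustrative only: run just the VAE decoder in bfloat16.
pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")
pipe.vae.to(dtype=torch.bfloat16)

latents = torch.randn(1, 4, 128, 128, device="cuda", dtype=torch.bfloat16)
with torch.no_grad():
    image = pipe.vae.decode(latents / pipe.vae.config.scaling_factor).sample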

Current output looks like:

SD-XL Base Pipeline
|------------|--------------|
|   Module   |   Latency    |
|------------|--------------|
|    CLIP    |     14.95 ms |
| UNet x 25  |   4886.70 ms |
|  VAE-Dec   |    630.26 ms |
|------------|--------------|
|  Pipeline  |   5674.99 ms |
|------------|--------------|
Throughput: 0.35 image/s
Saving image 1 / 2 to: output/txt2img-xl-fp16-Astronaut_-1-2250.png
Saving image 2 / 2 to: output/txt2img-xl-fp16-Astronaut_-2-8143.png
|------------|--------------|
|    e2e     |   5674.99 ms |
|------------|--------------|

I also tried to change the precision of the UNet. In the denoise_latent method, around line 539, I added a new timestep definition and checked it with a print:

timestep = timestep.to(dtype=torch.bfloat16) 
timestep_float = timestep.bfloat16() if timestep.dtype != torch.bfloat16 else timestep
print(type(timestep_float), " timestep_float: ", timestep_float, " timestep.dtype: ",timestep.dtype)

Even though timestep.dtype changes successfully (at least as far as I can see in the print output), the UNet section still runs for nearly 5 seconds. How can I implement the precision change properly?

The print output looks like this:

<class 'torch.Tensor'>  timestep_float:  tensor(960., device='cuda:0', dtype=torch.bfloat16)  timestep.dtype:  torch.bfloat16

Environment

TensorRT Version: torch-tensorrt 1.5.0.dev0

NVIDIA GPU: A100 80GB

NVIDIA Driver Version: 535.86.10

CUDA Version: 12.2

CUDNN Version: nvidia-cudnn-cu12 8.9.7.29

Operating System: Google VM a2-ultragpu-1g type, Debian GNU/Linux 11

Python Version : 3.10.6

PyTorch Version : torch 2.1.0a0+4136153

Docker --version: Docker version 20.10.17, build 100c701

Relevant Files

Repo link: https://github.com/rajeevsrao/TensorRT/tree/release/8.6/demo/Diffusion

Steps To Reproduce

Follow the setup instructions in the README file of the repo, then run the command below:

Commands or scripts: python3 demo_txt2img_xl.py "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k" --width 1024 --height 1024 --denoising-steps 25 --repeat-prompt 2 --num-warmup-runs 0

If you run the code for the first time, compilation may take more than 30 minutes.

@gokerguner
Author

There is an update:

I printed the noise_pred variable in stable_diffusion_pipeline.py and it still shows a different dtype. Digging deeper, I ended up in the infer() method in utilities.py and tried to change the dtypes of the tensors in self.tensors with (a buffer-allocation sketch follows after the traceback):

for k in self.tensors.keys():
    self.tensors[k] = self.tensors[k].to(dtype=torch.bfloat16)

But I get a new error:

[E] 1: [genericReformat.cuh::copyVectorizedRunKernel::1579] Error Code 1: Cuda Runtime (an illegal memory access was encountered)
False
Traceback (most recent call last):
  File "/workspace/mode.py", line 124, in <module>
    images, pipeline_time = run_sd_xl_inference(warmup=False, verbose=args.verbose)
  File "/workspace/mode.py", line 110, in run_sd_xl_inference
    images, time_base = demo_base.infer(prompt, negative_prompt, image_height, image_width, warmup=warmup, verbose=verbose, seed=args.seed, return_type="image")
  File "/workspace/txt2img_xl_pipeline.py", line 108, in infer
    text_embeddings = self.encode_prompt(prompt, negative_prompt, 
  File "/workspace/stable_diffusion_pipeline.py", line 461, in encode_prompt
    outputs = self.runEngine(encoder, {"input_ids": text_input_ids_inp})
  File "/workspace/stable_diffusion_pipeline.py", line 392, in runEngine
    return engine.infer(feed_dict, self.stream, use_cuda_graph=self.use_cuda_graph)
  File "/workspace/utilities.py", line 294, in infer
    raise ValueError(f"ERROR: inference failed.")
ValueError: ERROR: inference failed.
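If it helps, here is a sketch of allocating the I/O buffers to match the engine's build dtypes instead of casting them afterwards (names are assumptions and static shapes are assumed); replacing already-bound buffers with tensors of a different dtype is one plausible way to end up with stale device addresses and an illegal memory access:

import tensorrt as trt
import torch

TRT_TO_TORCH = {
    trt.DataType.FLOAT: torch.float32,
    trt.DataType.HALF: torch.float16,
    trt.DataType.INT32: torch.int32,
}

def allocate_buffers(engine, device="cuda"):
    # Create one tensor per I/O binding, in the dtype the engine expects.
    buffers = {}
    for i in range(engine.num_io_tensors):
        name = engine.get_tensor_name(i)
        shape = tuple(engine.get_tensor_shape(name))
        dtype = TRT_TO_TORCH[engine.get_tensor_dtype(name)]
        buffers[name] = torch.empty(shape, dtype=dtype, device=device)
    return buffers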

@zerollzeng
Collaborator

@rajeevsrao @ttyio ^ ^

@zerollzeng zerollzeng added the triaged Issue has been triaged by maintainers label Mar 8, 2024
@gokerguner
Author

It seems the latest version (9.2) brings bfloat16 support, but I am still confused about how to implement it correctly.

@ttyio
Collaborator

ttyio commented Mar 9, 2024

The SD demo uses FP16; the code that enables the FP16 build is here:

engine = engine_from_network(
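A rough sketch of that build call with Polygraphy (fp16=True is what the 8.6 demo enables; the bf16 flag is an assumption and requires a TensorRT/Polygraphy release with BF16 support, i.e. 9.x or later):

from polygraphy.backend.trt import CreateConfig, engine_from_network, network_from_onnx_path

# Sketch: enable reduced precision at engine build time, not at inference time.
engine = engine_from_network(
    network_from_onnx_path("unetxl.onnx"),
    config=CreateConfig(
        fp16=True,  # what the 8.6 demo builds with
        bf16=True,  # assumption: only available with BF16-capable TensorRT/Polygraphy
    ),
)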

What's the motivation to move to BF16?

@gokerguner
Author

gokerguner commented Mar 12, 2024

We are referring to the blog post below. We want to speed up our SDXL processes, which is the main reason for our interest in TensorRT. Apparently the bfloat16 conversion will also give us the necessary acceleration: when we apply it to the VAE layer, we see a speedup of about 15%. Now we are trying to do the same for the UNet.

Pytorch blog

We were using version 8.6, but bfloat16 support seems to have come with 9.2. Is this useful to us?
9.2 update

@ttyio
Collaborator

ttyio commented Apr 16, 2024

Yes, bfloat16 helps in some of the kernels, but for MHA we can get more performance gain using FP16/INT8.

FYI, we also have an INT8 SDXL demo: https://github.com/NVIDIA/TensorRT/tree/release/10.0/demo/Diffusion

@ttyio
Collaborator

ttyio commented Jul 2, 2024

Closing since there has been no activity for more than 3 weeks; please reopen if you still have questions. Thanks all!

@ttyio ttyio closed this as completed Jul 2, 2024