Dev first block cache (#12)
* implement first block cache

* fix

* fix

* fix

* add doc

* fix

* make flux work

* fix

* fix

* fix

* fix

* refactor

* fix

* fix

* Update fastest_hunyuan_video.md

* Update fastest_hunyuan_video.md

* fix

* fix

* fix

* fix
chengzeyi authored Jan 3, 2025
1 parent 48f5fd2 commit 1324733
Showing 32 changed files with 1,193 additions and 608 deletions.
280 changes: 85 additions & 195 deletions README.md
@@ -1,11 +1,13 @@
# ParaAttention

Context parallel attention that accelerates DiT model inference with dynamic caching,
supporting both [**Ulysses Style**](https://arxiv.org/abs/2309.14509) and [**Ring Style**](https://arxiv.org/abs/2310.01889) parallelism.

🔥[Fastest HunyuanVideo Inference with Context Parallelism and First Block Cache on NVIDIA L20 GPUs](doc/fastest_hunyuan_video.md)🔥

This project aims to provide:

- [x] An easy-to-use interface to speed up model inference with context parallelism, dynamic caching, and `torch.compile`, making **`FLUX`**, **`HunyuanVideo`**, and **`Mochi`** inference much faster without quality loss.
- [x] A unified interface to run context parallel attention (***cfg-ulysses-ring***) while keeping maximum performance when working with `torch.compile`.
- [ ] The fastest accurate attention implemented in Triton, running 50% faster than the original FA2 implementation on RTX 4090.

@@ -16,9 +18,55 @@ What's different from other implementations:
- Easy to use, too. If you want to use context parallelism with your custom model, you only need to wrap the call with our special `TorchFunctionMode` context manager.
- Easy to adjust. You can adjust the parallelism style and the mesh shape with a few lines of code.

# Key Features

### Context Parallelism

**Context Parallelism** (CP) is a method for parallelizing the processing of neural network activations across multiple GPUs by partitioning the input tensors along the sequence dimension.
Unlike Sequence Parallelism (SP), which partitions the activations of specific layers, CP divides the activations of all layers.
In `ParaAttention`, we parallelize the attention layer with a mixture of Ulysses Style and Ring Style parallelism, which we call Unified Attention.
This allows us to achieve the best performance across different models and hardware configurations.
We also provide a unified interface to parallelize the model inference.

You only need to call a single function to enable context parallelism on your `diffusers` pipeline:

```python
from para_attn.context_parallel.diffusers_adapters import parallelize_pipe

parallelize_pipe(pipe)
```
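
If you need finer control over the mesh shape, the `init_context_parallel_mesh` helper used in the full examples below can be passed in explicitly. A minimal sketch, assuming `pipe` is an already constructed `diffusers` pipeline and using an illustrative ring dimension size:

```python
import torch.distributed as dist

from para_attn.context_parallel import init_context_parallel_mesh
from para_attn.context_parallel.diffusers_adapters import parallelize_pipe

dist.init_process_group()

# Mix Ulysses-style and Ring-style parallelism; capping the ring dimension
# at 2 here is only an example value.
mesh = init_context_parallel_mesh(
    pipe.device.type,
    max_ring_dim_size=2,
)
parallelize_pipe(pipe, mesh=mesh)
```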

### First Block Cache (Our Dynamic Caching)

Inspired by [TeaCache](https://github.com/ali-vilab/TeaCache) and other denoising caching algorithms, we introduce **First Block Cache** (FBCache) to use the residual output of the first transformer block as the cache indicator.
If the difference between the current and the previous residual output of the first transformer block is small enough, we can reuse the previous final residual output and skip the computation of all the following transformer blocks.
This can significantly reduce the computation cost of the model, achieving a speedup of up to 2x while maintaining high accuracy.
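
The idea can be summarized with a simplified sketch. This is not the library's actual implementation; the helper names, cache layout, and relative-difference metric below are illustrative:

```python
import torch

def cached_transformer_forward(hidden_states, blocks, cache, threshold=0.1):
    # Always run the first transformer block and compute its residual.
    first_out = blocks[0](hidden_states)
    first_residual = first_out - hidden_states

    prev_residual = cache.get("first_residual")
    if prev_residual is not None:
        # Relative change of the first block's residual since the last step.
        diff = (first_residual - prev_residual).abs().mean() / prev_residual.abs().mean()
        if diff.item() < threshold:
            # Small change: skip the remaining blocks and reuse the residual
            # they contributed at the previous step.
            return first_out + cache["remaining_residual"]

    # Large change: run the remaining blocks and refresh the cache.
    out = first_out
    for block in blocks[1:]:
        out = block(out)
    cache["first_residual"] = first_residual
    cache["remaining_residual"] = out - first_out
    return out
```

In this sketch, a larger threshold skips more steps (bigger speedup, lower accuracy), while a threshold of `0.0` never skips.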

| Model | Optimizations | Preview |
| - | - | - |
| HunyuanVideo | Original | [Original](https://github.com/user-attachments/assets/883d771a-e74e-4081-aa2a-416985d6c713) |
| HunyuanVideo | FBCache | [FBCache](https://github.com/user-attachments/assets/f77c2f58-2b59-4dd1-a06a-a36974cb1e40) |
| FLUX.1-dev | Original | [Original](./assets/flux_original.png) |
| FLUX.1-dev | FBCache | [FBCache](./assets/flux_fbc.png) |

You only need to call a single function to enable First Block Cache on your `diffusers` pipeline:

```python
from para_attn.first_block_cache.diffusers_adapters import apply_cache_on_pipe

apply_cache_on_pipe(
    pipe,
    # residual_diff_threshold=0.0,
)
```
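
First Block Cache composes with the other features in this README. A sketch combining it with context parallelism and `torch.compile`, assuming `pipe` is an already constructed pipeline (the ordering of the calls shown here is an assumption):

```python
import torch

from para_attn.context_parallel.diffusers_adapters import parallelize_pipe
from para_attn.first_block_cache.diffusers_adapters import apply_cache_on_pipe

parallelize_pipe(pipe)
apply_cache_on_pipe(pipe)

# Optionally compile the transformer for an additional speedup.
pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune-no-cudagraphs")
```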

# Officially Supported Models

## Context Parallelism with First Block Cache

You can run the following examples with `torchrun` to enable context parallelism with dynamic caching.
You can modify the code to enable `torch.compile` to further accelerate inference.
If you want quantization, please refer to [diffusers-torchao](https://github.com/sayakpaul/diffusers-torchao) for more information.
For example, to run FLUX with 2 GPUs:

**NOTE**: To measure the performance correctly with `torch.compile`, you need to warm up the model by running it for a few iterations before measuring the performance.
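
A minimal timing sketch following that advice, assuming `pipe` and `prompt` are already defined (the step count mirrors the FLUX example below):

```python
import time
import torch

# Warm up: the first calls trigger compilation and autotuning.
for _ in range(2):
    pipe(prompt, num_inference_steps=28)

# Measure only after warm-up, synchronizing around the timed call.
torch.cuda.synchronize()
begin = time.time()
pipe(prompt, num_inference_steps=28)
torch.cuda.synchronize()
print(f"Wall time: {time.time() - begin:.3f} s")
```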
@@ -33,12 +81,26 @@ torchrun --nproc_per_node=2 parallel_examples/run_flux.py
- [Mochi](parallel_examples/run_mochi.py)
- [CogVideoX](parallel_examples/run_cogvideox.py)

## Single GPU Inference with First Block Cache

You can also run the following examples on a single GPU with First Block Cache enabled to speed up inference.

```bash
python3 first_block_cache_examples/run_hunyuan_video.py
```

- [HunyuanVideo🚀](first_block_cache_examples/run_hunyuan_video.py)
- [Mochi](first_block_cache_examples/run_mochi.py)
- [CogVideoX](first_block_cache_examples/run_cogvideox.py)

**NOTE**: To run `HunyuanVideo`, you need to install `diffusers` from its latest master branch.
It is suggested to run `HunyuanVideo` on GPUs with at least 48GB of memory; otherwise you might hit OOM errors, and performance may suffer due to frequent memory re-allocation.
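
On smaller GPUs, the memory-saving options used in the parallel examples in this README can help; a short sketch, assuming `pipe` is the constructed pipeline:

```python
# Decode in tiles to reduce peak VAE memory.
pipe.vae.enable_tiling()

# Optionally offload model components to the CPU between uses.
# pipe.enable_model_cpu_offload()
```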

# Performance

## Context Parallelism (without First Block Cache)

| Model | GPU | Method | Wall Time (s) | Speedup |
| --- | --- | --- | --- | --- |
| FLUX.1-dev | A100-SXM4-80GB | Baseline | 13.843 | 1.00x |
@@ -93,197 +155,32 @@ pre-commit run --all-files

# Usage

## Run FLUX.1-dev with Parallel Inference

```python
import torch
import torch.distributed as dist
from diffusers import FluxPipeline

dist.init_process_group()

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    torch_dtype=torch.bfloat16,
).to(f"cuda:{dist.get_rank()}")

from para_attn.context_parallel import init_context_parallel_mesh
from para_attn.context_parallel.diffusers_adapters import parallelize_pipe
from para_attn.parallel_vae.diffusers_adapters import parallelize_vae

mesh = init_context_parallel_mesh(
    pipe.device.type,
    max_ring_dim_size=2,
)
parallelize_pipe(
    pipe,
    mesh=mesh,
)
parallelize_vae(pipe.vae, mesh=mesh._flatten())

# pipe.enable_model_cpu_offload(gpu_id=dist.get_rank())

# torch._inductor.config.reorder_for_compute_comm_overlap = True
# pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune-no-cudagraphs")

image = pipe(
    "A cat holding a sign that says hello world",
    num_inference_steps=28,
    output_type="pil" if dist.get_rank() == 0 else "pt",
).images[0]

if dist.get_rank() == 0:
    print("Saving image to flux.png")
    image.save("flux.png")

dist.destroy_process_group()
```

Save the above code to `run_flux.py` and run it with `torchrun`:

```bash
torchrun --nproc_per_node=2 run_flux.py
```

## Run HunyuanVideo🚀 with Parallel Inference

**NOTE**: To run `HunyuanVideo`, you need to install `diffusers` from its latest master branch.
It is suggested to run `HunyuanVideo` on GPUs with at least 48GB of memory; otherwise you might hit OOM errors, and performance may suffer due to frequent memory re-allocation.

```python
import torch
import torch.distributed as dist
from diffusers import HunyuanVideoPipeline, HunyuanVideoTransformer3DModel
from diffusers.utils import export_to_video

dist.init_process_group()

# [rank1]: RuntimeError: Expected mha_graph->execute(handle, variant_pack, workspace_ptr.get()).is_good() to be true, but got false. (Could this error message be improved? If so, please report an enhancement request to PyTorch.)
torch.backends.cuda.enable_cudnn_sdp(False)

model_id = "tencent/HunyuanVideo"
transformer = HunyuanVideoTransformer3DModel.from_pretrained(
    model_id,
    subfolder="transformer",
    torch_dtype=torch.bfloat16,
    revision="refs/pr/18",
)
pipe = HunyuanVideoPipeline.from_pretrained(
    model_id,
    transformer=transformer,
    torch_dtype=torch.float16,
    revision="refs/pr/18",
).to(f"cuda:{dist.get_rank()}")

pipe.vae.enable_tiling(
    # Make it runnable on GPUs with 48GB memory
    # tile_sample_min_height=128,
    # tile_sample_stride_height=96,
    # tile_sample_min_width=128,
    # tile_sample_stride_width=96,
    # tile_sample_min_num_frames=32,
    # tile_sample_stride_num_frames=24,
)

from para_attn.context_parallel import init_context_parallel_mesh
from para_attn.context_parallel.diffusers_adapters import parallelize_pipe
from para_attn.parallel_vae.diffusers_adapters import parallelize_vae

mesh = init_context_parallel_mesh(
    pipe.device.type,
)
parallelize_pipe(
    pipe,
    mesh=mesh,
)
parallelize_vae(pipe.vae, mesh=mesh._flatten())

# pipe.enable_model_cpu_offload(gpu_id=dist.get_rank())

# torch._inductor.config.reorder_for_compute_comm_overlap = True
# pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune-no-cudagraphs")

output = pipe(
    prompt="A cat walks on the grass, realistic",
    height=720,
    width=1280,
    num_frames=129,
    num_inference_steps=30,
    output_type="pil" if dist.get_rank() == 0 else "pt",
).frames[0]

if dist.get_rank() == 0:
    print("Saving video to hunyuan_video.mp4")
    export_to_video(output, "hunyuan_video.mp4", fps=15)

dist.destroy_process_group()
```

Save the above code to `run_hunyuan_video.py` and run it with `torchrun`:

```bash
torchrun --nproc_per_node=2 run_hunyuan_video.py
```

## Run Mochi with Parallel Inference

```python
import torch
import torch.distributed as dist
from diffusers import MochiPipeline
from diffusers.utils import export_to_video

dist.init_process_group()

pipe = MochiPipeline.from_pretrained(
    "genmo/mochi-1-preview",
    torch_dtype=torch.bfloat16,
).to(f"cuda:{dist.get_rank()}")

from para_attn.context_parallel import init_context_parallel_mesh
from para_attn.context_parallel.diffusers_adapters import parallelize_pipe

parallelize_pipe(
    pipe,
    mesh=init_context_parallel_mesh(
        pipe.device.type,
        max_batch_dim_size=2,
        max_ring_dim_size=2,
    ),
)

# Enable memory savings
# pipe.enable_model_cpu_offload(gpu_id=dist.get_rank())
pipe.enable_vae_tiling()

# torch._inductor.config.reorder_for_compute_comm_overlap = True
# pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune-no-cudagraphs")

prompt = "Close-up of a chameleon's eye, with its scaly skin changing color. Ultra high resolution 4k."
video = pipe(
    prompt,
    num_frames=84,
    output_type="pil" if dist.get_rank() == 0 else "pt",
).frames[0]

if dist.get_rank() == 0:
    print("Saving video to mochi.mp4")
    export_to_video(video, "mochi.mp4", fps=30)

dist.destroy_process_group()
```

Save the above code to `run_mochi.py` and run it with `torchrun`:

```bash
torchrun --nproc_per_node=2 run_mochi.py
```

## All Examples

Please refer to examples in the `parallel_examples` and `first_block_cache_examples` directories.

### Parallelize Models

| Model | Command |
| - | - |
| `FLUX` | `torchrun --nproc_per_node=2 parallel_examples/run_flux.py` |
| `HunyuanVideo` | `torchrun --nproc_per_node=2 parallel_examples/run_hunyuan_video.py` |
| `Mochi` | `torchrun --nproc_per_node=2 parallel_examples/run_mochi.py` |
| `CogVideoX` | `torchrun --nproc_per_node=2 parallel_examples/run_cogvideox.py` |

### Apply First Block Cache

| Model | Command |
| - | - |
| `FLUX` | `python3 first_block_cache_examples/run_flux.py` |
| `HunyuanVideo` | `python3 first_block_cache_examples/run_hunyuan_video.py` |
| `Mochi` | `python3 first_block_cache_examples/run_mochi.py` |
| `CogVideoX` | `python3 first_block_cache_examples/run_cogvideox.py` |

## Parallelize VAE

VAE can be parallelized with `para_attn.parallel_vae.diffusers_adapters.parallelize_vae`.
Currently, only `AutoencoderKL` and `AutoencoderKLHunyuanVideo` are supported.

```python
import torch
Expand All @@ -292,25 +189,18 @@ from diffusers import AutoencoderKL

dist.init_process_group()

torch.cuda.set_device(dist.get_rank())

vae = AutoencoderKL.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    torch_dtype=torch.bfloat16,
).to("cuda")

from para_attn.parallel_vae.diffusers_adapters import parallelize_vae

parallelize_vae(vae)
```
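
As with the other parallel examples, save the snippet above to a file and launch it with `torchrun` (for example, `torchrun --nproc_per_node=2 your_vae_script.py`, where the file name is whatever you saved it as).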

## Run Unified Attention (Hybrid Ulysses Style and Ring Style) with `torch.compile`

Binary file added assets/flux_fbc.png
Binary file added assets/flux_original.png
Binary file added assets/hunyuan_video_fbc.mp4
Binary file added assets/hunyuan_video_original.mp4
