Context parallel attention that works with `torch.compile`, supporting both Ulysses-style and Ring-style parallelism.
This aims to provide:
- An easy-to-use interface to speed up model inference with context parallelism and `torch.compile`. It makes FLUX and Mochi inference much faster without loss of quality.
- A unified interface to run context parallel attention (cfg-ulysses-ring) that keeps maximum performance when working with `torch.compile`.
- The fastest accurate attention implemented in Triton, running 50% faster than the original FA2 implementation on RTX 4090.
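Ring-style attention in general stays exact because partial attention results computed over disjoint key/value chunks can be merged using their log-sum-exp values. The following is a minimal single-process sketch of that merge in plain Python. It is not ParaAttention's Triton or distributed code, and the function names `attend` and `merge` are illustrative only.

```python
import math


def attend(q, keys, values):
    """Softmax attention for one query vector.

    Returns the output vector and the log-sum-exp of the scores.
    """
    scores = [sum(a * b for a, b in zip(q, k)) for k in keys]
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    dim = len(values[0])
    out = [sum(w * v[d] for w, v in zip(weights, values)) / total for d in range(dim)]
    return out, m + math.log(total)


def merge(out_a, lse_a, out_b, lse_b):
    """Combine two partial results computed over disjoint key/value chunks."""
    m = max(lse_a, lse_b)
    wa, wb = math.exp(lse_a - m), math.exp(lse_b - m)
    out = [(wa * a + wb * b) / (wa + wb) for a, b in zip(out_a, out_b)]
    return out, m + math.log(wa + wb)


q = [0.5, -1.0]
keys = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [-1.0, 2.0]]
values = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]

# Reference: attention over the whole key/value sequence at once.
full, full_lse = attend(q, keys, values)

# "Two devices": each sees half of the keys/values, then the results are merged.
part_a = attend(q, keys[:2], values[:2])
part_b = attend(q, keys[2:], values[2:])
merged, merged_lse = merge(*part_a, *part_b)

matches = all(math.isclose(x, y, rel_tol=1e-9) for x, y in zip(full, merged))
```

Here `matches` is true up to floating-point rounding, which is why sharding the key/value sequence does not have to change the result.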
What's different from other implementations:
- No unnecessary graph breaks during `torch.compile`. All the heavy computations are captured in a single graph, which gives the backend compiler the maximum opportunity to optimize it, for example by overlapping computation and communication.
- Easy to use. You don't need to change the model code to enable context parallelism; you only need to call a function to parallelize the model.
- Easy to use with custom models, too. To use context parallelism with your own model, you only need to wrap the call with our special `TorchFunctionMode` context manager.
- Easy to adjust. You can change the parallelism style and the mesh shape with a few lines of code.
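The custom-model path in the list above builds on PyTorch's `TorchFunctionMode`. The sketch below is not ParaAttention's implementation. It uses a hypothetical `CountingAttnMode` to show how such a mode can intercept `F.scaled_dot_product_attention` without touching the model code. A context-parallel mode would presumably substitute a distributed attention at the marked point instead of counting the call.

```python
import torch
import torch.nn.functional as F
from torch.overrides import TorchFunctionMode


class CountingAttnMode(TorchFunctionMode):
    """Hypothetical mode: counts SDPA calls, then runs them unchanged."""

    def __init__(self):
        super().__init__()
        self.calls = 0

    def __torch_function__(self, func, types, args=(), kwargs=None):
        kwargs = {} if kwargs is None else kwargs
        if func is F.scaled_dot_product_attention:
            # A parallel attention implementation could be swapped in here.
            self.calls += 1
        return func(*args, **kwargs)


def tiny_model(q, k, v):
    # Stands in for an unmodified model that calls SDPA internally.
    return F.scaled_dot_product_attention(q, k, v)


q = torch.randn(1, 2, 8, 4)
k = torch.randn(1, 2, 8, 4)
v = torch.randn(1, 2, 8, 4)

mode = CountingAttnMode()
with mode:
    out = tiny_model(q, k, v)
```

The model function is untouched. Only the call site is wrapped, which is the property the bullet above describes.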
You can run the following examples with `torchrun`.
For example, to run FLUX with 2 GPUs:
torchrun --nproc_per_node=2 examples/run_flux.py
Model | GPU | Method | Wall Time (s) | Speedup |
---|---|---|---|---|
FLUX.1-dev | A100-SXM4-80GB | Baseline | 13.843 | 1.00x |
FLUX.1-dev | A100-SXM4-80GB | torch.compile | 9.997 | 1.38x |
FLUX.1-dev | A100-SXM4-80GB x 2 | para-attn (ring) | 8.307 | 1.66x |
FLUX.1-dev | A100-SXM4-80GB x 2 | para-attn (ring) + torch.compile | 5.775 | 2.39x |
FLUX.1-dev | A100-SXM4-80GB x 4 | para-attn (ulysses + ring) | 6.157 | 2.25x |
FLUX.1-dev | A100-SXM4-80GB x 4 | para-attn (ulysses + ring) + torch.compile | 3.557 | 3.89x |
mochi-1-preview | A100-SXM4-80GB | Baseline | 196.534 | 1.00x |
mochi-1-preview | A100-SXM4-80GB | torch.compile | 149.868 | 1.31x |
mochi-1-preview | A100-SXM4-80GB x 2 | para-attn (cfg) | 105.438 | 1.86x |
mochi-1-preview | A100-SXM4-80GB x 2 | para-attn (ulysses) | 110.146 | 1.78x |
mochi-1-preview | A100-SXM4-80GB x 2 | para-attn (ring) | 109.435 | 1.80x |
mochi-1-preview | A100-SXM4-80GB x 2 | para-attn (cfg) + torch.compile | 81.913 | 2.40x |
mochi-1-preview | A100-SXM4-80GB x 2 | para-attn (ulysses) + torch.compile | 83.912 | 2.34x |
mochi-1-preview | A100-SXM4-80GB x 2 | para-attn (ring) + torch.compile | 82.176 | 2.39x |
mochi-1-preview | A100-SXM4-80GB x 4 | para-attn (cfg + ring) | 61.206 | 3.21x |
mochi-1-preview | A100-SXM4-80GB x 4 | para-attn (cfg + ring) + torch.compile | 47.100 | 4.17x |
NOTE: The speedup in iterations per second is generally higher than the speedup in wall time, because the wall time includes the overhead of calling the text encoder and VAE decoder.
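The arithmetic behind this note can be sketched as follows. The two wall times come from the FLUX.1-dev rows of the table above. The `fixed_overhead` value is a made-up placeholder for the text encoder and VAE decoder time, not a measurement, and the sketch assumes that overhead is the same in both runs.

```python
def speedup(baseline_s, optimized_s):
    """Ratio of baseline time to optimized time."""
    return baseline_s / optimized_s


baseline_wall = 13.843  # FLUX.1-dev baseline wall time (s), from the table
parallel_wall = 3.557   # 4 GPUs, ulysses + ring + torch.compile (s), from the table

# Hypothetical, unmeasured time spent outside the denoising loop.
fixed_overhead = 1.0

wall_speedup = speedup(baseline_wall, parallel_wall)
# Removing the same fixed cost from both runs isolates the part that is
# actually parallelized, so this ratio is larger than the wall-time ratio.
denoise_speedup = speedup(baseline_wall - fixed_overhead, parallel_wall - fixed_overhead)

print(f"wall-time speedup: {wall_speedup:.2f}x")
print(f"denoising-only speedup (assumed overhead): {denoise_speedup:.2f}x")
```

Any positive fixed overhead gives the same ordering; only the size of the gap depends on the placeholder value.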
pip3 install 'torch==2.5.0'
pip3 install para-attn
git clone https://github.com/chengzeyi/ParaAttention.git
cd ParaAttention
git submodule update --init --recursive
pip3 install 'torch==2.5.0'
pip3 install 'setuptools>=64' 'setuptools_scm>=8'
# Pass --no-use-pep517 to speed up rebuild by using the legacy build system
# which doesn't use a one-time tmp directory for the build
pip3 install -e '.[dev]' --no-build-isolation
# Or:
# python3 setup.py develop
# Code formatting and linting
pip3 install pre-commit
pre-commit install
pre-commit run --all-files
import torch
import torch.distributed as dist
from diffusers import FluxPipeline
dist.init_process_group()
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    torch_dtype=torch.bfloat16,
).to(f"cuda:{dist.get_rank()}")
from para_attn.context_parallel import init_context_parallel_mesh
from para_attn.context_parallel.diffusers_adapters import parallelize_pipe
parallelize_pipe(
    pipe,
    mesh=init_context_parallel_mesh(
        pipe.device.type,
        max_ring_dim_size=2,
    ),
)
torch._inductor.config.reorder_for_compute_comm_overlap = True
pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune-no-cudagraphs")
image = pipe(
    "A cat holding a sign that says hello world",
    num_inference_steps=28,
    output_type="pil" if dist.get_rank() == 0 else "latent",
).images[0]  # the pipeline returns an output object; take the first image
if dist.get_rank() == 0:
    print("Saving image to flux.png")
    image.save("flux.png")
dist.destroy_process_group()
Save the above code to `run_flux.py` and run it with `torchrun`:
torchrun --nproc_per_node=2 run_flux.py
import torch
import torch.distributed as dist
from diffusers import MochiPipeline
from diffusers.utils import export_to_video
dist.init_process_group()
pipe = MochiPipeline.from_pretrained(
    "genmo/mochi-1-preview",
    torch_dtype=torch.float16,
).to(f"cuda:{dist.get_rank()}")
# Enable memory savings
# pipe.enable_model_cpu_offload()
pipe.enable_vae_tiling()
from para_attn.context_parallel import init_context_parallel_mesh
from para_attn.context_parallel.diffusers_adapters import parallelize_pipe
parallelize_pipe(
    pipe,
    mesh=init_context_parallel_mesh(
        pipe.device.type,
        max_batch_dim_size=2,
        max_ring_dim_size=2,
    ),
)
torch._inductor.config.reorder_for_compute_comm_overlap = True
pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune-no-cudagraphs")
prompt = "Close-up of a chameleon's eye, with its scaly skin changing color. Ultra high resolution 4k."
video = pipe(
    prompt,
    num_frames=84,
    output_type="pil" if dist.get_rank() == 0 else "latent",
).frames[0]
if dist.get_rank() == 0:
    print("Saving video to mochi.mp4")
    export_to_video(video, "mochi.mp4", fps=30)
dist.destroy_process_group()
Save the above code to `run_mochi.py` and run it with `torchrun`:
torchrun --nproc_per_node=2 run_mochi.py
Model | Command |
---|---|
FLUX | torchrun --nproc_per_node=2 examples/run_flux.py |
Mochi | torchrun --nproc_per_node=2 examples/run_mochi.py |
CogVideoX | torchrun --nproc_per_node=2 examples/run_cogvideox.py |
import torch
import torch.distributed as dist
import torch.nn.functional as F
from para_attn import para_attn_interface
dist.init_process_group()
world_size = dist.get_world_size()
rank = dist.get_rank()
assert world_size <= torch.cuda.device_count()
if world_size % 2 == 0:
    mesh_shape = (2, world_size // 2)
else:
    mesh_shape = (1, world_size)
B, H, S_Q, S_KV, D = 2, 24, 4096, 4096, 64
dtype = torch.float16
device = "cuda"
torch._inductor.config.reorder_for_compute_comm_overlap = True
with torch.no_grad(), torch.cuda.device(rank):
    torch.manual_seed(0)
    query = torch.randn(B, H, S_Q, D, dtype=dtype, device=device)
    key = torch.randn(B, H, S_KV, D, dtype=dtype, device=device)
    value = torch.randn(B, H, S_KV, D, dtype=dtype, device=device)
    attn_mask = None
    dropout_p = 0.0
    is_causal = False
    # Each rank keeps only its own slice along the sequence dimension
    query_slice = query.chunk(world_size, dim=-2)[rank]
    key_slice = key.chunk(world_size, dim=-2)[rank]
    value_slice = value.chunk(world_size, dim=-2)[rank]
    def func(*args, **kwargs):
        return F.scaled_dot_product_attention(*args, **kwargs)
    func = torch.compile(func)
    for _ in range(2):
        mesh = dist.init_device_mesh(device, mesh_shape, mesh_dim_names=("ring", "ulysses"))
        with para_attn_interface.UnifiedAttnMode(mesh):
            out_slice = func(
                query_slice,
                key_slice,
                value_slice,
                attn_mask=attn_mask,
                dropout_p=dropout_p,
                is_causal=is_causal,
            )
    # Reference: full attention on one device, sliced the same way
    out_slice_ref = F.scaled_dot_product_attention(
        query,
        key,
        value,
        attn_mask=attn_mask,
        dropout_p=dropout_p,
        is_causal=is_causal,
    ).chunk(world_size, dim=-2)[rank]
    torch.testing.assert_close(out_slice, out_slice_ref, rtol=1e-5, atol=1e-3 * world_size)
dist.destroy_process_group()
Save the above code to `test.py` and run it with `torchrun`:
torchrun --nproc_per_node=2 test.py
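To see what the test script does for different GPU counts, here is a small pure-Python sketch of its mesh-shape rule and of which sequence rows each rank keeps. The helper names `pick_mesh_shape` and `local_rows` are hypothetical and not part of para-attn. `local_rows` assumes the sequence length divides evenly by the world size, as it does for 4096 with 2 or 4 ranks.

```python
def pick_mesh_shape(world_size):
    """Same rule as the script: returns (ring size, ulysses size)."""
    if world_size % 2 == 0:
        return (2, world_size // 2)
    return (1, world_size)


def local_rows(seq_len, world_size, rank):
    """Sequence rows kept by `rank`, assuming seq_len % world_size == 0."""
    per_rank = seq_len // world_size
    return range(rank * per_rank, (rank + 1) * per_rank)


for world_size in (2, 3, 4, 8):
    ring, ulysses = pick_mesh_shape(world_size)
    print(f"world_size={world_size}: ring={ring}, ulysses={ulysses}")

# With 2 ranks and S_Q = 4096, rank 1 holds the second half of the sequence.
rows = local_rows(4096, 2, 1)
print(rows[0], rows[-1])
```

The product of the two mesh dimensions always equals the world size, so every rank belongs to exactly one ring group and one ulysses group.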
DISTRIBUTED_TESTS_DEFAULT_TIMEOUT=3000 pytest tests