Merge pull request #4 from chengzeyi/dev_batch_parallel
Dev batch parallel
chengzeyi authored Nov 14, 2024
2 parents 3d32b5a + cf7ec43 commit ffb4f05
Showing 8 changed files with 174 additions and 61 deletions.
34 changes: 27 additions & 7 deletions README.md
@@ -6,7 +6,7 @@ supporting both [**Ulysses Style**](https://arxiv.org/abs/2309.14509) and [**Rin
This aims to provide:

- [x] An easy-to-use interface to speed up model inference with context parallelism and `torch.compile`, making `FLUX` and `Mochi` inference much faster losslessly.
- [x] A unified interface to run context parallel attention, as well as keeping the maximum performance while working with `torch.compile`
- [x] A unified interface to run context parallel attention (***cfg-ulysses-ring***), as well as keeping the maximum performance while working with `torch.compile`
- [ ] The fastest accurate attention implemented in Triton, running 50% faster than the original FA2 implementation on RTX 4090.

# Performance
@@ -15,18 +15,22 @@ This aims to provide:
| Model | GPU | Method | Wall Time (s) | Speedup |
| --- | --- | --- | --- | --- |
| FLUX.1-dev | A100-SXM4-80GB | Baseline | 13.843 | 1.00x |
| FLUX.1-dev | A100-SXM4-80GB | `torch.compile` | 9.997 | 1.38x |
| FLUX.1-dev | A100-SXM4-80GB x 2 | `para-attn (ulysses)` | 8.379 | 1.65x |
| FLUX.1-dev | A100-SXM4-80GB x 2 | `para-attn (ring)` | 8.307 | 1.66x |
| FLUX.1-dev | A100-SXM4-80GB x 2 | `para-attn (ulysses)` + `torch.compile` | 5.915 | 2.34x |
| FLUX.1-dev | A100-SXM4-80GB x 2 | `para-attn (ring)` + `torch.compile` | 5.775 | 2.39x |
| FLUX.1-dev | A100-SXM4-80GB x 4 | `para-attn (ulysses + ring)` + `torch.compile` | ? | ? |
| FLUX.1-dev | A100-SXM4-80GB x 4 | `para-attn (ulysses + ring)` | 6.157 | 2.25x |
| FLUX.1-dev | A100-SXM4-80GB x 4 | `para-attn (ulysses + ring)` + `torch.compile` | 3.557 | 3.89x |
| mochi-1-preview | A100-SXM4-80GB | Baseline | 196.534 | 1.00x |
| mochi-1-preview | A100-SXM4-80GB | `torch.compile` | 149.868 | 1.31x |
| mochi-1-preview | A100-SXM4-80GB x 2 | `para-attn (cfg)` | 105.438 | 1.86x |
| mochi-1-preview | A100-SXM4-80GB x 2 | `para-attn (ulysses)` | 110.146 | 1.78x |
| mochi-1-preview | A100-SXM4-80GB x 2 | `para-attn (ring)` | 109.435 | 1.80x |
| mochi-1-preview | A100-SXM4-80GB x 2 | `para-attn (cfg)` + `torch.compile` | 81.913 | 2.40x |
| mochi-1-preview | A100-SXM4-80GB x 2 | `para-attn (ulysses)` + `torch.compile` | 83.912 | 2.34x |
| mochi-1-preview | A100-SXM4-80GB x 2 | `para-attn (ring)` + `torch.compile` | 82.176 | 2.39x |
| mochi-1-preview | A100-SXM4-80GB x 4 | `para-attn (ulysses + ring)` + `torch.compile` | ? | ? |
| mochi-1-preview | A100-SXM4-80GB x 4 | `para-attn (cfg + ring)` | 61.206 | 3.21x |
| mochi-1-preview | A100-SXM4-80GB x 4 | `para-attn (cfg + ring)` + `torch.compile` | 47.100 | 4.17x |

NOTE: The speedup in iterations per second is generally higher than the wall-time speedup, because wall time also includes the overhead of running the text encoder and VAE decoder.

# Installation

@@ -74,9 +78,16 @@ pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to(f"cuda:{dist.get_rank()}")

from para_attn.context_parallel import init_context_parallel_mesh
from para_attn.context_parallel.diffusers_adapters import parallelize_pipe

parallelize_pipe(pipe)
parallelize_pipe(
    pipe,
    mesh=init_context_parallel_mesh(
        pipe.device.type,
        max_ring_dim_size=2,
    ),
)

torch._inductor.config.reorder_for_compute_comm_overlap = True
pipe.transformer = torch.compile(
@@ -103,6 +114,7 @@ torchrun --nproc_per_node=2 test.py

``` python
import torch
import torch.distributed as dist
from diffusers import MochiPipeline
from diffusers.utils import export_to_video

@@ -116,9 +128,17 @@ pipe = MochiPipeline.from_pretrained(
# pipe.enable_model_cpu_offload()
pipe.enable_vae_tiling()

from para_attn.context_parallel import init_context_parallel_mesh
from para_attn.context_parallel.diffusers_adapters import parallelize_pipe

parallelize_pipe(pipe)
parallelize_pipe(
    pipe,
    mesh=init_context_parallel_mesh(
        pipe.device.type,
        max_batch_dim_size=2,
        max_ring_dim_size=2,
    ),
)

torch._inductor.config.reorder_for_compute_comm_overlap = True
pipe.transformer = torch.compile(
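Note: on 4 GPUs the Mochi configuration above should resolve to a `("batch", "ulysses", "ring")` mesh of shape `(2, 1, 2)`, which corresponds to the `para-attn (cfg + ring)` rows in the performance table. A minimal sketch for checking this (the file name `inspect_mesh.py` is only an example; launch with `torchrun --nproc_per_node=4 inspect_mesh.py`):

```python
# Hypothetical helper script (inspect_mesh.py) to see what mesh the
# arguments above produce; must run under torchrun with 4 ranks.
import torch.distributed as dist

from para_attn.context_parallel import init_context_parallel_mesh

dist.init_process_group()

# Same arguments as the Mochi example: with 4 ranks this builds a
# ("batch", "ulysses", "ring") mesh of shape (2, 1, 2), i.e. cfg + ring.
mesh = init_context_parallel_mesh(
    "cuda",
    max_batch_dim_size=2,
    max_ring_dim_size=2,
)
if dist.get_rank() == 0:
    print(mesh)

dist.destroy_process_group()
```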
2 changes: 1 addition & 1 deletion pytest.ini
@@ -3,4 +3,4 @@ log_format = %(asctime)s %(filename)s:%(lineno)d %(levelname)s %(message)s
log_date_format = %Y-%m-%d %H:%M:%S
log_cli = true
log_level = INFO
addopts = --capture=no --verbose --color=auto --durations=0
addopts = --capture=tee-sys --verbose --color=auto --durations=0
39 changes: 39 additions & 0 deletions src/para_attn/context_parallel/__init__.py
@@ -0,0 +1,39 @@
import math

import torch.distributed as dist

import para_attn.primitives as DP


def init_context_parallel_mesh(
    device_type=None, *, mesh=None, max_batch_dim_size=None, max_ulysses_dim_size=None, max_ring_dim_size=None
):
    if mesh is not None:
        return mesh

    assert device_type is not None, "device_type must be provided if mesh is not provided"

    world_size = DP.get_world_size()
    if max_batch_dim_size is None:
        batch_dim_size = 1
    else:
        batch_dim_size = math.gcd(world_size, max_batch_dim_size)

    attn_world_size = world_size // batch_dim_size

    assert not (
        max_ulysses_dim_size is not None and max_ring_dim_size is not None
    ), "Only one of max_ulysses_dim_size and max_ring_dim_size can be set"

    if max_ulysses_dim_size is None:
        if max_ring_dim_size is None:
            ring_dim_size = 1
        else:
            ring_dim_size = math.gcd(attn_world_size, max_ring_dim_size)
        ulysses_dim_size = attn_world_size // ring_dim_size
    else:
        ulysses_dim_size = math.gcd(attn_world_size, max_ulysses_dim_size)
        ring_dim_size = attn_world_size // ulysses_dim_size

    mesh_shape = (batch_dim_size, ulysses_dim_size, ring_dim_size)
    return dist.init_device_mesh(device_type, mesh_shape, mesh_dim_names=("batch", "ulysses", "ring"))
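For intuition, the gcd-based sizing above can be traced standalone. A minimal sketch (the `mesh_shape` helper below is illustrative only, not part of the package) reproducing the mesh shapes behind the README benchmarks:

```python
import math

def mesh_shape(world_size, max_batch_dim_size=None, max_ulysses_dim_size=None, max_ring_dim_size=None):
    # Mirrors the sizing logic of init_context_parallel_mesh, without torch.distributed.
    batch = 1 if max_batch_dim_size is None else math.gcd(world_size, max_batch_dim_size)
    attn = world_size // batch
    if max_ulysses_dim_size is None:
        ring = 1 if max_ring_dim_size is None else math.gcd(attn, max_ring_dim_size)
        ulysses = attn // ring
    else:
        ulysses = math.gcd(attn, max_ulysses_dim_size)
        ring = attn // ulysses
    return batch, ulysses, ring

print(mesh_shape(2, max_ring_dim_size=2))                        # (1, 1, 2): pure ring
print(mesh_shape(4, max_ring_dim_size=2))                        # (1, 2, 2): ulysses + ring
print(mesh_shape(4, max_batch_dim_size=2, max_ring_dim_size=2))  # (2, 1, 2): cfg batch + ring
```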
26 changes: 21 additions & 5 deletions src/para_attn/context_parallel/diffusers_adapters/flux.py
@@ -5,12 +5,17 @@
from diffusers import DiffusionPipeline, FluxTransformer2DModel

import para_attn.primitives as DP
from para_attn.context_parallel import init_context_parallel_mesh
from para_attn.para_attn_interface import UnifiedAttnMode


def parallelize_transformer(transformer: FluxTransformer2DModel, *, mesh=None) -> None:
assert isinstance(transformer, FluxTransformer2DModel)

mesh = init_context_parallel_mesh(transformer.device.type, mesh=mesh)
batch_mesh = mesh["batch"]
seq_mesh = mesh["ulysses", "ring"]._flatten()

original_forward = transformer.forward

@functools.wraps(transformer.__class__.forward)
@@ -25,16 +30,26 @@ def new_forward(
controlnet_single_block_samples: Optional[List[torch.Tensor]] = None,
**kwargs,
):
hidden_states = DP.get_assigned_chunk(hidden_states, dim=-2)
encoder_hidden_states = DP.get_assigned_chunk(encoder_hidden_states, dim=-2)
hidden_states = DP.get_assigned_chunk(hidden_states, dim=0, group=batch_mesh)
hidden_states = DP.get_assigned_chunk(hidden_states, dim=-2, group=seq_mesh)
encoder_hidden_states = DP.get_assigned_chunk(encoder_hidden_states, dim=0, group=batch_mesh)
encoder_hidden_states = DP.get_assigned_chunk(encoder_hidden_states, dim=-2, group=seq_mesh)
img_ids = DP.get_assigned_chunk(img_ids, dim=-2)
txt_ids = DP.get_assigned_chunk(txt_ids, dim=-2)
if controlnet_block_samples is not None:
controlnet_block_samples = [DP.get_assigned_chunk(sample, dim=-2) for sample in controlnet_block_samples]
controlnet_block_samples = [
DP.get_assigned_chunk(sample, dim=0, group=batch_mesh) for sample in controlnet_block_samples
]
controlnet_block_samples = [
DP.get_assigned_chunk(sample, dim=-2, group=seq_mesh) for sample in controlnet_block_samples
]
kwargs["controlnet_block_samples"] = controlnet_block_samples
if controlnet_single_block_samples is not None:
controlnet_single_block_samples = [
DP.get_assigned_chunk(sample, dim=-2) for sample in controlnet_single_block_samples
DP.get_assigned_chunk(sample, dim=0, group=batch_mesh) for sample in controlnet_single_block_samples
]
controlnet_single_block_samples = [
DP.get_assigned_chunk(sample, dim=-2, group=seq_mesh) for sample in controlnet_single_block_samples
]
kwargs["controlnet_single_block_samples"] = controlnet_single_block_samples

Expand All @@ -50,7 +65,8 @@ def new_forward(

return_dict = not isinstance(output, tuple)
sample = output[0]
sample = DP.get_complete_tensor(sample, dim=-2)
sample = DP.get_complete_tensor(sample, dim=-2, group=seq_mesh)
sample = DP.get_complete_tensor(sample, dim=0, group=batch_mesh)
if return_dict:
return output.__class__(sample, *output[1:])
return (sample, *output[1:])
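The scatter/gather pattern in `new_forward` can be illustrated without a process group. A single-process sketch, with `torch.chunk` standing in for `DP.get_assigned_chunk` and `torch.cat` for `DP.get_complete_tensor` (all shapes are illustrative):

```python
import torch

batch_size, seq_size = 2, 2  # e.g. a (2, 1, 2) mesh, with ulysses x ring flattened into one sequence group
hidden_states = torch.randn(2, 4096, 3072)  # (batch, sequence, hidden)

# Scatter: each rank keeps one batch shard, then one sequence shard of that.
local = torch.chunk(hidden_states, batch_size, dim=0)[0]
local = torch.chunk(local, seq_size, dim=-2)[0]
assert local.shape == (1, 2048, 3072)

# Gather: new_forward reassembles the output in reverse order,
# sequence dim first, then batch dim.
full_seq = torch.cat([local] * seq_size, dim=-2)
sample = torch.cat([full_seq] * batch_size, dim=0)
assert sample.shape == hidden_states.shape
```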
33 changes: 23 additions & 10 deletions src/para_attn/context_parallel/diffusers_adapters/mochi.py
@@ -5,12 +5,17 @@
from diffusers import DiffusionPipeline, MochiTransformer3DModel

import para_attn.primitives as DP
from para_attn.context_parallel import init_context_parallel_mesh
from para_attn.para_attn_interface import UnifiedAttnMode


def parallelize_transformer(transformer: MochiTransformer3DModel, *, mesh=None) -> None:
assert isinstance(transformer, MochiTransformer3DModel)

mesh = init_context_parallel_mesh(transformer.device.type, mesh=mesh)
batch_mesh = mesh["batch"]
seq_mesh = mesh["ulysses", "ring"]._flatten()

original_forward = transformer.forward

@functools.wraps(transformer.__class__.forward)
@@ -22,9 +27,12 @@ def new_forward(
encoder_attention_mask: torch.Tensor,
**kwargs,
):
hidden_states = DP.get_assigned_chunk(hidden_states, dim=-2)
encoder_hidden_states = DP.get_assigned_chunk(encoder_hidden_states, dim=-2)
encoder_attention_mask = DP.get_assigned_chunk(encoder_attention_mask, dim=-1)
hidden_states = DP.get_assigned_chunk(hidden_states, dim=0, group=batch_mesh)
hidden_states = DP.get_assigned_chunk(hidden_states, dim=-2, group=seq_mesh)
encoder_hidden_states = DP.get_assigned_chunk(encoder_hidden_states, dim=0, group=batch_mesh)
encoder_hidden_states = DP.get_assigned_chunk(encoder_hidden_states, dim=-2, group=seq_mesh)
encoder_attention_mask = DP.get_assigned_chunk(encoder_attention_mask, dim=0, group=batch_mesh)
encoder_attention_mask = DP.get_assigned_chunk(encoder_attention_mask, dim=-1, group=seq_mesh)

with UnifiedAttnMode(mesh):
output = original_forward(
@@ -37,7 +45,8 @@

return_dict = not isinstance(output, tuple)
sample = output[0]
sample = DP.get_complete_tensor(sample, dim=-2)
sample = DP.get_complete_tensor(sample, dim=-2, group=seq_mesh)
sample = DP.get_complete_tensor(sample, dim=0, group=batch_mesh)
if return_dict:
return output.__class__(sample, *output[1:])
return (sample, *output[1:])
@@ -56,13 +65,17 @@ def new_time_embed_forward(
*args,
**kwargs,
):
encoder_hidden_states = DP.get_complete_tensor(encoder_hidden_states, dim=-2)
encoder_attention_mask = DP.get_complete_tensor(encoder_attention_mask, dim=-1)
encoder_hidden_states = DP.get_complete_tensor(encoder_hidden_states, dim=-2, group=seq_mesh)
encoder_hidden_states = DP.get_complete_tensor(encoder_hidden_states, dim=0, group=batch_mesh)
encoder_attention_mask = DP.get_complete_tensor(encoder_attention_mask, dim=-1, group=seq_mesh)
encoder_attention_mask = DP.get_complete_tensor(encoder_attention_mask, dim=0, group=batch_mesh)
with UnifiedAttnMode.disable():
conditioning, caption_proj = original_time_embed_forward(
timestep, encoder_hidden_states, encoder_attention_mask, *args, **kwargs
)
caption_proj = DP.get_assigned_chunk(caption_proj, dim=-2)
conditioning = DP.get_assigned_chunk(conditioning, dim=0, group=batch_mesh)
caption_proj = DP.get_assigned_chunk(caption_proj, dim=0, group=batch_mesh)
caption_proj = DP.get_assigned_chunk(caption_proj, dim=-2, group=seq_mesh)
return conditioning, caption_proj

new_time_embed_forward = new_time_embed_forward.__get__(transformer.time_embed)
@@ -80,7 +93,7 @@ def new_rope_forward(
*args,
**kwargs,
):
height *= DP.get_world_size()
height *= DP.get_world_size(seq_mesh)
rope_cos, rope_sin = original_rope_forward(
pos_frequencies,
num_frames,
@@ -93,8 +106,8 @@
n, h, f = rope_cos.shape
rope_cos = rope_cos.reshape(num_frames, -1, h, f)
rope_sin = rope_sin.reshape(num_frames, -1, h, f)
rope_cos = DP.get_assigned_chunk(rope_cos, dim=-3)
rope_sin = DP.get_assigned_chunk(rope_sin, dim=-3)
rope_cos = DP.get_assigned_chunk(rope_cos, dim=-3, group=seq_mesh)
rope_sin = DP.get_assigned_chunk(rope_sin, dim=-3, group=seq_mesh)
rope_cos = rope_cos.reshape(-1, h, f)
rope_sin = rope_sin.reshape(-1, h, f)
return rope_cos, rope_sin
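The RoPE wrapper above computes the tables for the full spatial grid (hence `height *= DP.get_world_size(seq_mesh)`) and then keeps only this rank's slice of the per-frame spatial dimension. A single-process sketch with illustrative shapes:

```python
import torch

seq_world_size = 2
num_frames, height, width, heads, freqs = 4, 30, 53, 24, 64

# Full-grid tables, flattened over (frames * height * width) as in the Mochi RoPE.
rope_cos = torch.randn(num_frames * height * width, heads, freqs)

n, h, f = rope_cos.shape
rope_cos = rope_cos.reshape(num_frames, -1, h, f)            # (frames, height * width, heads, freqs)
rope_cos = torch.chunk(rope_cos, seq_world_size, dim=-3)[0]  # this rank's spatial slice
rope_cos = rope_cos.reshape(-1, h, f)

assert rope_cos.shape == (num_frames * height * width // seq_world_size, heads, freqs)
```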
40 changes: 32 additions & 8 deletions src/para_attn/para_attn_interface.py
@@ -1,13 +1,17 @@
import contextlib

import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch.distributed import DeviceMesh
from torch.distributed.tensor.experimental._attention import _templated_ring_attention
from torch.overrides import TorchFunctionMode

import para_attn.primitives as DP

try:
from torch.distributed.tensor.experimental._attention import _templated_ring_attention
except ImportError:
_templated_ring_attention = None

para_attn_ops = torch.ops.para_attn

__all__ = [
@@ -106,6 +110,8 @@ def forward(
scale,
mesh,
):
assert _templated_ring_attention is not None, "RingAttnFunc requires a newer version of PyTorch"

out, lse = _templated_ring_attention(
mesh,
para_attn_ops.attention_forward_with_lse,
@@ -283,13 +289,27 @@ def __init__(self, mesh=None):
self._parallel_method = "ulysses"

if mesh is None:
self._ulysses_mesh = None
self._ulysses_mesh = DP.get_default_group()
self._ring_mesh = None
else:
assert isinstance(mesh, DeviceMesh), "mesh must be a DeviceMesh"
assert mesh.mesh.ndim == 2, "mesh must be 2D, got {}".format(mesh.mesh.ndim)
self._ulysses_mesh = mesh["ulysses"]
self._ring_mesh = mesh["ring"]
if isinstance(mesh, dist.ProcessGroup):
self._ulysses_mesh = mesh
self._ring_mesh = None
else:
assert isinstance(mesh, dist.DeviceMesh), "mesh must be a ProcessGroup or DeviceMesh"

if "ulysses" in mesh.mesh_dim_names:
self._ulysses_mesh = mesh["ulysses"]
else:
self._ulysses_mesh = None
if "ring" in mesh.mesh_dim_names:
self._ring_mesh = mesh["ring"]
else:
self._ring_mesh = None

assert (
self._ulysses_mesh is not None or self._ring_mesh is not None
), "mesh must have ulysses or ring dim"

def __torch_function__(self, func, types, args=(), kwargs=None):
kwargs = {} if kwargs is None else kwargs
@@ -300,10 +320,14 @@ def __torch_function__(self, func, types, args=(), kwargs=None):
if func is torch.nn.functional.scaled_dot_product_attention:
parallel_method = self._parallel_method
if parallel_method == "ulysses":
with self._set_parallel_method("none" if self._ring_mesh is None else "ring"), self:
with self._set_parallel_method("ring"), self:
if self._ulysses_mesh is None:
return func(*args, **kwargs)
return ulysses_attn_func(*args, **kwargs, mesh=self._ulysses_mesh)
elif parallel_method == "ring":
with self._set_parallel_method("none"), self:
if self._ring_mesh is None:
return func(*args, **kwargs)
return ring_attn_func(*args, **kwargs, mesh=self._ring_mesh)
elif parallel_method == "none":
return func(*args, **kwargs)
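With the new mesh handling, `UnifiedAttnMode` can also be exercised directly: inside the context, every `F.scaled_dot_product_attention` call is dispatched to the ulysses and/or ring path according to the mesh dims. A hedged sketch (script name and shapes are illustrative; launch with e.g. `torchrun --nproc_per_node=2 unified_attn_demo.py`):

```python
import torch
import torch.distributed as dist
import torch.nn.functional as F

from para_attn.context_parallel import init_context_parallel_mesh
from para_attn.para_attn_interface import UnifiedAttnMode

dist.init_process_group()
rank = dist.get_rank()
torch.cuda.set_device(rank)

# 2 ranks, all assigned to the ulysses dim -> mesh shape (1, 2, 1).
mesh = init_context_parallel_mesh("cuda", max_ulysses_dim_size=2)

# Each rank holds its own shard of the sequence dimension.
query, key, value = (
    torch.randn(1, 24, 1024, 64, device="cuda", dtype=torch.float16) for _ in range(3)
)

with UnifiedAttnMode(mesh):
    out = F.scaled_dot_product_attention(query, key, value)

print(rank, out.shape)  # local output shard, same shape as the local query
dist.destroy_process_group()
```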