make hunyuanvideo faster and more memory efficient
chengzeyi committed Dec 19, 2024
1 parent e91b702 commit a9defb2
Showing 6 changed files with 252 additions and 56 deletions.
8 changes: 2 additions & 6 deletions README.md
@@ -5,7 +5,7 @@ supporting both [**Ulysses Style**](https://arxiv.org/abs/2309.14509) and [**Rin

This aims to provide:

- [x] An easy to use interface to speed up model inference with context parallel and `torch.compile`. Make `FLUX`, `HunyuanVideo` and `Mochi` inference much faster losslessly.
- [x] An easy to use interface to speed up model inference with context parallel and `torch.compile`. Make **`FLUX`**, **`HunyuanVideo`** and **`Mochi`** inference much faster losslessly.
- [x] A unified interface to run context parallel attention (***cfg-ulysses-ring***), as well as keeping the maximum performance while working with `torch.compile`
- [ ] The fastest accurate attention implemented in Triton, running 50% faster than the original FA2 implementation on RTX 4090.

@@ -34,7 +34,7 @@ torchrun --nproc_per_node=2 examples/run_flux.py
- [CogVideoX](examples/run_cogvideox.py)

**NOTE**: To run `HunyuanVideo`, you need to install `diffusers` from its latest master branch.
It is suggested to run `HunyuanVideo` with GPUs with 80GB memory, or you might experience OOM errors,
It is suggested to run `HunyuanVideo` with GPUs with at least 48GB memory, or you might experience OOM errors,
and the performance might be worse due to frequent memory re-allocation.

# Performance
@@ -197,10 +197,6 @@ parallelize_pipe(
)
parallelize_vae(pipe.vae, mesh=mesh._flatten())

# Fix OOM because of awful inductor lowering of attn_bias of _scaled_dot_product_efficient_attention
# import para_attn
# para_attn.config.attention.force_dispatch_to_custom_ops = True

# torch._inductor.config.reorder_for_compute_comm_overlap = True
# pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune-no-cudagraphs")

4 changes: 0 additions & 4 deletions examples/run_hunyuan_video.py
@@ -45,10 +45,6 @@
)
parallelize_vae(pipe.vae, mesh=mesh._flatten())

# Fix OOM because of awful inductor lowering of attn_bias of _scaled_dot_product_efficient_attention
# import para_attn
# para_attn.config.attention.force_dispatch_to_custom_ops = True

# torch._inductor.config.reorder_for_compute_comm_overlap = True
# pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune-no-cudagraphs")

142 changes: 99 additions & 43 deletions src/para_attn/context_parallel/diffusers_adapters/hunyuan_video.py
@@ -1,52 +1,121 @@
import functools
import itertools
from typing import List, Optional, Tuple, Union
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
from diffusers import DiffusionPipeline, HunyuanVideoTransformer3DModel
from diffusers.models.modeling_outputs import Transformer2DModelOutput

import para_attn.primitives as DP
from para_attn.context_parallel import init_context_parallel_mesh
from para_attn.para_attn_interface import UnifiedAttnMode
from para_attn.para_attn_interface import SparseKVAttnMode, UnifiedAttnMode


def parallelize_transformer(transformer: HunyuanVideoTransformer3DModel, *, mesh=None):
mesh = init_context_parallel_mesh(transformer.device.type, mesh=mesh)
batch_mesh = mesh["batch"]
seq_mesh = mesh["ring", "ulysses"]._flatten()

original_forward = transformer.forward

@functools.wraps(transformer.__class__.forward)
def new_forward(
self,
*args,
**kwargs,
):
with UnifiedAttnMode(mesh):
output = original_forward(*args, **kwargs)
hidden_states: torch.Tensor,
timestep: torch.LongTensor,
encoder_hidden_states: torch.Tensor,
encoder_attention_mask: torch.Tensor,
pooled_projections: torch.Tensor,
guidance: torch.Tensor = None,
return_dict: bool = True,
) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
batch_size, num_channels, num_frames, height, width = hidden_states.shape
p, p_t = self.config.patch_size, self.config.patch_size_t
post_patch_num_frames = num_frames // p_t
post_patch_height = height // p
post_patch_width = width // p

# 1. RoPE
image_rotary_emb = self.rope(hidden_states)

# 2. Conditional embeddings
temb = self.time_text_embed(timestep, guidance, pooled_projections)
hidden_states = self.x_embedder(hidden_states)
encoder_hidden_states = self.context_embedder(encoder_hidden_states, timestep, encoder_attention_mask)

# 3. Attention mask preparation
latent_sequence_length = hidden_states.shape[1]
latent_attention_mask = torch.ones(
batch_size, 1, latent_sequence_length, device=hidden_states.device, dtype=torch.bool
) # [B, 1, N]
attention_mask = torch.cat(
[latent_attention_mask, encoder_attention_mask.unsqueeze(1).to(torch.bool)], dim=-1
) # [B, 1, N + M]

with SparseKVAttnMode(), UnifiedAttnMode(mesh):
# 4. Transformer blocks
if torch.is_grad_enabled() and self.gradient_checkpointing:

def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)

return custom_forward

ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False}

for block in self.transformer_blocks:
hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states,
encoder_hidden_states,
temb,
attention_mask,
image_rotary_emb,
**ckpt_kwargs,
)

for block in self.single_transformer_blocks:
hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states,
encoder_hidden_states,
temb,
attention_mask,
image_rotary_emb,
**ckpt_kwargs,
)

else:
for block in self.transformer_blocks:
hidden_states, encoder_hidden_states = block(
hidden_states, encoder_hidden_states, temb, attention_mask, image_rotary_emb
)

for block in self.single_transformer_blocks:
hidden_states, encoder_hidden_states = block(
hidden_states, encoder_hidden_states, temb, attention_mask, image_rotary_emb
)

# 5. Output projection
hidden_states = self.norm_out(hidden_states, temb)
hidden_states = self.proj_out(hidden_states)

hidden_states = hidden_states.reshape(
batch_size, post_patch_num_frames, post_patch_height, post_patch_width, -1, p_t, p, p
)
hidden_states = hidden_states.permute(0, 4, 1, 5, 2, 6, 3, 7)
hidden_states = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)

return output
if not return_dict:
return (hidden_states,)

return Transformer2DModelOutput(sample=hidden_states)

new_forward = new_forward.__get__(transformer)
transformer.forward = new_forward

original_context_embedder_forward = transformer.context_embedder.forward

@functools.wraps(transformer.context_embedder.__class__.forward)
def new_context_embedder_forward(
self,
*args,
**kwargs,
):
with UnifiedAttnMode.disable():
output = original_context_embedder_forward(*args, **kwargs)

return output

new_context_embedder_forward = new_context_embedder_forward.__get__(transformer.context_embedder)
transformer.context_embedder.forward = new_context_embedder_forward

"""
torch._dynamo hit config.cache_size_limit (8)
function: 'new_transformer_block_forward'
@@ -76,26 +145,13 @@ def new_transformer_block_forward(
**kwargs,
):
world_size = DP.get_world_size(seq_mesh)
if attention_mask is not None and world_size > 1:
if attention_mask is not None:
assert attention_mask.shape[0] == 1, "Only support batch size 1 for now"

hidden_states_len = hidden_states.shape[-2]
encoder_hidden_states_len = encoder_hidden_states.shape[-2]

new_attention_mask = []
for i in range(world_size):
new_attention_mask.append(
attention_mask[..., i * hidden_states_len : (i + 1) * hidden_states_len, :]
)
new_attention_mask.append(
attention_mask[
...,
world_size * hidden_states_len
+ i * encoder_hidden_states_len : world_size * hidden_states_len
+ (i + 1) * encoder_hidden_states_len,
:,
]
)
new_attention_mask = torch.cat(new_attention_mask, dim=-2)
attention_mask = new_attention_mask
attention_mask = attention_mask[:1, ..., :1, :]

new_attention_mask = []
for i in range(world_size):
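
For reference, the joint attention mask that the rewritten forward builds in step 3 is a plain boolean padding mask over the concatenated latent and text tokens. A minimal standalone sketch of that construction (the sizes below are illustrative, not taken from a real pipeline run):

```python
import torch

# Illustrative sizes: 1 sample, 12 latent (video) tokens, 6 text tokens of
# which only the first 4 are real and the last 2 are padding.
batch_size, latent_len = 1, 12

# Latent tokens are never padded, so their mask is all True: [B, 1, N]
latent_attention_mask = torch.ones(batch_size, 1, latent_len, dtype=torch.bool)

# Text padding mask as produced by the text encoder / tokenizer: [B, M]
encoder_attention_mask = torch.tensor([[1, 1, 1, 1, 0, 0]], dtype=torch.bool)

# Joint mask over the concatenated (latent + text) sequence: [B, 1, N + M]
attention_mask = torch.cat(
    [latent_attention_mask, encoder_attention_mask.unsqueeze(1)], dim=-1
)
print(attention_mask.shape)  # torch.Size([1, 1, 18])
```

Because this is a pure key/value padding mask (identical for every query row), the sparse-KV path introduced below can simply drop the padded key/value tokens instead of materializing an attention bias.
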
84 changes: 84 additions & 0 deletions src/para_attn/ops/__init__.py
@@ -235,3 +235,87 @@ def _(
is_causal=is_causal,
scale=scale,
)


def _attention_forward_sparse_kv(
query,
key,
value,
attn_mask=None,
dropout_p=0.0,
is_causal=False,
*,
scale=None,
):
if attn_mask is None:
return aten.scaled_dot_product_attention(
query,
key,
value,
attn_mask=attn_mask,
dropout_p=dropout_p,
is_causal=is_causal,
scale=scale,
)

assert attn_mask.dtype == torch.bool, "attn_mask must be a boolean tensor"

s_kv = key.shape[-2]
while attn_mask.ndim > 1:
attn_mask = attn_mask[0]
indices = torch.arange(s_kv, device=key.device)
indices = indices[attn_mask]
key = key[..., indices, :]
value = value[..., indices, :]
return aten.scaled_dot_product_attention(
query,
key,
value,
dropout_p=dropout_p,
is_causal=is_causal,
scale=scale,
)


@_torch_custom_op_wrapper("para_attn::attention_forward_sparse_kv", mutates_args=(), device_types=("cpu", "cuda"))
def attention_forward_sparse_kv(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
dropout_p: float = 0.0,
is_causal: bool = False,
*,
scale: Optional[float] = None,
) -> torch.Tensor:
return _attention_forward_sparse_kv(
query,
key,
value,
attn_mask=attn_mask,
dropout_p=dropout_p,
is_causal=is_causal,
scale=scale,
)


@_torch_register_fake_wrapper("para_attn::attention_forward_sparse_kv")
def _(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
dropout_p: float = 0.0,
is_causal: bool = False,
*,
scale: Optional[float] = None,
) -> torch.Tensor:
return _attention_forward_sparse_kv(
query,
key,
value,
attn_mask=attn_mask,
dropout_p=dropout_p,
is_causal=is_causal,
scale=scale,
)
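
The new `attention_forward_sparse_kv` op relies on the fact that, for a boolean key/value padding mask, running SDPA on the pruned keys/values is equivalent to passing the mask to SDPA over the full sequence. A small numerical sanity check of that equivalence (shapes and tolerances are illustrative):

```python
import torch
import torch.nn.functional as F

q = torch.randn(1, 8, 16, 64)  # [B, H, S_q, D]
k = torch.randn(1, 8, 20, 64)  # [B, H, S_kv, D]
v = torch.randn(1, 8, 20, 64)

# Keep the first 14 key/value positions, mask out the trailing 6 (padding).
keep = torch.zeros(20, dtype=torch.bool)
keep[:14] = True
attn_mask = keep.view(1, 1, 1, 20)  # broadcastable boolean mask, True = attend

masked = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
pruned = F.scaled_dot_product_attention(q, k[..., keep, :], v[..., keep, :])

print(torch.allclose(masked, pruned, atol=1e-4))  # expect True up to backend numerics
```

Dropping the masked tokens avoids building the `[S_q, S_kv]` attention bias entirely, which is presumably why the `force_dispatch_to_custom_ops` workaround is commented out in the examples and tests in this commit.
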
64 changes: 64 additions & 0 deletions src/para_attn/para_attn_interface.py
@@ -212,6 +212,27 @@ def in_batch_attn_func(
return out


def sparse_kv_attn_func(
query,
key,
value,
attn_mask=None,
dropout_p=0.0,
is_causal=False,
*,
scale=None,
):
return para_attn_ops.attention_forward_sparse_kv(
query,
key,
value,
attn_mask=attn_mask,
dropout_p=dropout_p,
is_causal=is_causal,
scale=scale,
)


def _get_arg(args, kwargs, *field):
if len(field) == 1:
if isinstance(field, int):
@@ -458,3 +479,46 @@ def _set_disabled(cls, value):
old_disabled = cls.disabled
cls.disabled = value
return old_disabled


class SparseKVAttnMode(TorchFunctionMode):
disabled = False

@torch.compiler.disable()
def __init__(self):
super().__init__()

def __torch_function__(self, func, types, args=(), kwargs=None):
kwargs = {} if kwargs is None else kwargs

if SparseKVAttnMode.disabled:
return func(*args, **kwargs)

if func is torch.nn.functional.scaled_dot_product_attention:
return sparse_kv_attn_func(*args, **kwargs)

return func(*args, **kwargs)

@torch.compiler.disable()
def __enter__(self):
super().__enter__()

@torch.compiler.disable()
def __exit__(self, *args):
super().__exit__(*args)

@classmethod
@contextlib.contextmanager
def disable(cls):
old_disabled = cls._set_disabled(True)
try:
yield
finally:
cls._set_disabled(old_disabled)

@classmethod
@torch.compiler.disable()
def _set_disabled(cls, value):
old_disabled = cls.disabled
cls.disabled = value
return old_disabled
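
`SparseKVAttnMode` follows the same pattern as the other attention-mode wrappers in this file: used as a context manager, it reroutes `torch.nn.functional.scaled_dot_product_attention` to the `para_attn::attention_forward_sparse_kv` op. A minimal usage sketch with toy CPU tensors (sizes are illustrative):

```python
import torch
import torch.nn.functional as F

from para_attn.para_attn_interface import SparseKVAttnMode

q = torch.randn(1, 8, 16, 64)
k = torch.randn(1, 8, 24, 64)
v = torch.randn(1, 8, 24, 64)

# Boolean KV padding mask: only the first 20 key/value tokens are real.
attn_mask = torch.zeros(1, 1, 1, 24, dtype=torch.bool)
attn_mask[..., :20] = True

# Inside the mode, the SDPA call is dispatched to attention_forward_sparse_kv,
# which strips the masked-out keys/values before running attention.
with SparseKVAttnMode():
    out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)

print(out.shape)  # torch.Size([1, 8, 16, 64])
```
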
6 changes: 3 additions & 3 deletions tests/context_parallel/test_diffusers_adapters.py
@@ -219,9 +219,9 @@ def new_pipe(self, dtype, device):
)

# Fix OOM because of awful inductor lowering of attn_bias of _scaled_dot_product_efficient_attention
import para_attn

para_attn.config.attention.force_dispatch_to_custom_ops = True
# import para_attn
#
# para_attn.config.attention.force_dispatch_to_custom_ops = True

return pipe

