fix causal ulysses attn
chengzeyi committed Nov 16, 2024
1 parent ffb4f05 commit b92e7e8
Showing 7 changed files with 15 additions and 20 deletions.
6 changes: 3 additions & 3 deletions README.md
@@ -174,9 +174,9 @@ rank = dist.get_rank()
 
 assert world_size <= torch.cuda.device_count()
 if world_size % 2 == 0:
-    mesh_shape = (world_size // 2, 2)
+    mesh_shape = (2, world_size // 2)
 else:
-    mesh_shape = (world_size, 1)
+    mesh_shape = (1, world_size)
 
 B, H, S_Q, S_KV, D = 2, 24, 4096, 4096, 64
 dtype = torch.float16
@@ -204,7 +204,7 @@ with torch.no_grad(), torch.cuda.device(rank):
     func = torch.compile(func)
 
     for _ in range(2):
-        mesh = dist.init_device_mesh(device, mesh_shape, mesh_dim_names=("ulysses", "ring"))
+        mesh = dist.init_device_mesh(device, mesh_shape, mesh_dim_names=("ring", "ulysses"))
        with para_attn_interface.UnifiedAttnMode(mesh):
            out_slice = func(
                query_slice,
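Note on the reordering above: `init_device_mesh` lays ranks out row-major, so with ("ring", "ulysses") the "ulysses" dimension is innermost and each Ulysses group consists of adjacent ranks, while the "ring" dimension strides across those groups. A minimal sketch of that layout, assuming a hypothetical 4-rank run (not part of the commit):

    # Hypothetical 4-rank illustration of the ("ring", "ulysses") ordering used above.
    import torch.distributed as dist

    mesh = dist.init_device_mesh("cuda", (2, 2), mesh_dim_names=("ring", "ulysses"))
    # Row-major rank layout: ring=0 -> ranks [0, 1], ring=1 -> ranks [2, 3].
    # Each "ulysses" group is a pair of adjacent ranks ({0, 1}, {2, 3});
    # each "ring" group strides across them ({0, 2}, {1, 3}).
    ulysses_group = mesh["ulysses"].get_group()
    ring_group = mesh["ring"].get_group()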
8 changes: 4 additions & 4 deletions src/para_attn/context_parallel/__init__.py
@@ -6,7 +6,7 @@
 
 
 def init_context_parallel_mesh(
-    device_type=None, *, mesh=None, max_batch_dim_size=None, max_ulysses_dim_size=None, max_ring_dim_size=None
+    device_type=None, *, mesh=None, max_batch_dim_size=None, max_ring_dim_size=None, max_ulysses_dim_size=None
 ):
     if mesh is not None:
         return mesh
@@ -22,7 +22,7 @@ def init_context_parallel_mesh(
     attn_world_size = world_size // batch_dim_size
 
     assert not (
-        max_ulysses_dim_size is not None and max_ring_dim_size is not None
+        max_ring_dim_size is not None and max_ulysses_dim_size is not None
     ), "Only one of max_ulysses_dim_size and max_ring_dim_size can be set"
 
     if max_ulysses_dim_size is None:
@@ -35,5 +35,5 @@ def init_context_parallel_mesh(
         ulysses_dim_size = math.gcd(attn_world_size, max_ulysses_dim_size)
         ring_dim_size = attn_world_size // ulysses_dim_size
 
-    mesh_shape = (batch_dim_size, ulysses_dim_size, ring_dim_size)
-    return dist.init_device_mesh(device_type, mesh_shape, mesh_dim_names=("batch", "ulysses", "ring"))
+    mesh_shape = (batch_dim_size, ring_dim_size, ulysses_dim_size)
+    return dist.init_device_mesh(device_type, mesh_shape, mesh_dim_names=("batch", "ring", "ulysses"))
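For reference, a minimal usage sketch of the updated helper (hypothetical multi-GPU setup; the exact ring/ulysses sizes depend on default logic not shown in this hunk, so only the dimension names and their new order are taken from the diff):

    # Hypothetical usage: the returned mesh is now ordered ("batch", "ring", "ulysses"),
    # with "ulysses" as the innermost dimension.
    from para_attn.context_parallel import init_context_parallel_mesh

    mesh = init_context_parallel_mesh("cuda", max_batch_dim_size=2)
    batch_mesh = mesh["batch"]
    # Flatten ring + ulysses into a single sequence-parallel mesh, as the
    # diffusers adapters below do.
    seq_mesh = mesh["ring", "ulysses"]._flatten()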
2 changes: 1 addition & 1 deletion src/para_attn/context_parallel/diffusers_adapters/flux.py
@@ -14,7 +14,7 @@ def parallelize_transformer(transformer: FluxTransformer2DModel, *, mesh=None) -
 
     mesh = init_context_parallel_mesh(transformer.device.type, mesh=mesh)
     batch_mesh = mesh["batch"]
-    seq_mesh = mesh["ulysses", "ring"]._flatten()
+    seq_mesh = mesh["ring", "ulysses"]._flatten()
 
     original_forward = transformer.forward
 
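For context, a minimal sketch of wiring this adapter into a Flux pipeline (assumptions: the script is launched with torchrun so the default process-group init works, and the checkpoint id is illustrative; only the module path and the `parallelize_transformer(transformer, *, mesh=None)` signature come from the diff):

    # Hypothetical end-to-end wiring; not taken from the repository's examples.
    import torch
    import torch.distributed as dist
    from diffusers import FluxPipeline

    from para_attn.context_parallel.diffusers_adapters.flux import parallelize_transformer

    dist.init_process_group()                # env:// rendezvous provided by torchrun
    torch.cuda.set_device(dist.get_rank())

    pipe = FluxPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
    ).to("cuda")
    # mesh=None lets the adapter build the default ("batch", "ring", "ulysses") mesh.
    parallelize_transformer(pipe.transformer, mesh=None)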
2 changes: 1 addition & 1 deletion src/para_attn/context_parallel/diffusers_adapters/mochi.py
@@ -14,7 +14,7 @@ def parallelize_transformer(transformer: MochiTransformer3DModel, *, mesh=None)
 
     mesh = init_context_parallel_mesh(transformer.device.type, mesh=mesh)
     batch_mesh = mesh["batch"]
-    seq_mesh = mesh["ulysses", "ring"]._flatten()
+    seq_mesh = mesh["ring", "ulysses"]._flatten()
 
     original_forward = transformer.forward
 
6 changes: 0 additions & 6 deletions src/para_attn/para_attn_interface.py
@@ -82,12 +82,6 @@ def ulysses_attn_func(
     query = _sdpa_input_all_to_all(query, mesh)
     key = _sdpa_input_all_to_all(key, mesh)
     value = _sdpa_input_all_to_all(value, mesh)
-    if attn_mask is not None:
-        s_q, s_kv = query.size(-2), key.size(-2)
-        if attn_mask.size(-1) != s_kv:
-            attn_mask = DP.get_complete_tensor(attn_mask, dim=-1, group=mesh)
-        elif attn_mask.size(-2) != s_q:
-            attn_mask = DP.get_complete_tensor(attn_mask, dim=-2, group=mesh)
 
     out = F.scaled_dot_product_attention(
         query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, scale=scale
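For context, `ulysses_attn_func` follows the Ulysses sequence-parallel pattern: an all-to-all trades sequence sharding for head sharding so each rank can run ordinary SDPA over the gathered sequence. Below is a minimal illustrative sketch of that pattern, not the library's implementation (helper names, tensor layout, contiguous sequence shards, a head count divisible by the group size, and an initialized process group are all assumptions):

    # Illustrative Ulysses all-to-all sketch. Tensors are (B, H, S_local, D);
    # each rank holds one contiguous sequence shard with all heads.
    import torch
    import torch.distributed as dist
    import torch.nn.functional as F


    def _scatter_heads_gather_seq(t, ws, group):
        # (B, H, S_local, D) -> (B, H // ws, S_local * ws, D)
        b, h, s, d = t.shape
        t = t.reshape(b, ws, h // ws, s, d).permute(1, 0, 2, 3, 4).contiguous()
        out = torch.empty_like(t)
        dist.all_to_all_single(out, t, group=group)  # dim 0: head group -> seq shard
        return out.permute(1, 2, 0, 3, 4).reshape(b, h // ws, ws * s, d)


    def _gather_heads_scatter_seq(t, ws, group):
        # (B, H // ws, S_local * ws, D) -> (B, H, S_local, D)
        b, hg, s_full, d = t.shape
        s = s_full // ws
        t = t.reshape(b, hg, ws, s, d).permute(2, 0, 1, 3, 4).contiguous()
        out = torch.empty_like(t)
        dist.all_to_all_single(out, t, group=group)  # dim 0: seq shard -> head group
        return out.permute(1, 0, 2, 3, 4).reshape(b, ws * hg, s, d)


    def ulysses_sdpa_sketch(query, key, value, *, group=None, is_causal=False):
        ws = dist.get_world_size(group)
        query = _scatter_heads_gather_seq(query, ws, group)
        key = _scatter_heads_gather_seq(key, ws, group)
        value = _scatter_heads_gather_seq(value, ws, group)
        # Each rank now sees the full gathered sequence for its head subset,
        # so a plain (optionally causal) SDPA is applied locally.
        out = F.scaled_dot_product_attention(query, key, value, is_causal=is_causal)
        return _gather_heads_scatter_seq(out, ws, group)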
2 changes: 1 addition & 1 deletion tests/context_parallel/test_diffusers_adapters.py
@@ -132,7 +132,7 @@ def test_benchmark_pipe(self, dtype, device, parallelize, compile, use_batch, use_ring):
         super()._test_benchmark_pipe(dtype, device, parallelize, compile, use_batch, use_ring)
 
 
-# instantiate_parametrized_tests(DiffusionPipelineTest)
+instantiate_parametrized_tests(DiffusionPipelineTest)
 instantiate_parametrized_tests(FluxPipelineTest)
 instantiate_parametrized_tests(MochiPipelineTest)
 
9 changes: 5 additions & 4 deletions tests/test_interface.py
@@ -204,10 +204,10 @@ class UnifiedAttnTest(ParallelAttnTest):
     def attn_mode(self, device):
         world_size = self.world_size
         if world_size % 2 == 0:
-            mesh_shape = (world_size // 2, 2)
+            mesh_shape = (2, world_size // 2)
         else:
-            mesh_shape = (world_size, 1)
-        mesh = dist.init_device_mesh(device, mesh_shape, mesh_dim_names=("ulysses", "ring"))
+            mesh_shape = (1, world_size)
+        mesh = dist.init_device_mesh(device, mesh_shape, mesh_dim_names=("ring", "ulysses"))
         return para_attn_interface.UnifiedAttnMode(mesh)
 
     @pytest.mark.skipif("not torch.cuda.is_available()")
@@ -227,7 +227,8 @@ def test_attn_mode(self, dtype, device, B, H, S_Q, S_KV, D, is_causal, compile):
         super()._test_attn_mode(dtype, device, B, H, S_Q, S_KV, D, is_causal, compile)
 
 
-# instantiate_parametrized_tests(RingAttnTest)
+instantiate_parametrized_tests(ParallelAttnTest)
+instantiate_parametrized_tests(RingAttnTest)
 instantiate_parametrized_tests(UlyssesAttnTest)
 instantiate_parametrized_tests(UnifiedAttnTest)
 