Implement ring attention primitive for JAX.
Signed-off-by: Michael Goldfarb <mgoldfarb@nvidia.com>
mgoldfarb-nvidia committed Nov 5, 2024
1 parent 23caab3 commit 0f0dd3d
Showing 5 changed files with 657 additions and 4 deletions.
8 changes: 8 additions & 0 deletions tests/jax/test_distributed_fused_attn.py
@@ -35,7 +35,9 @@
get_qkv_format,
reorder_causal_load_balancing,
inverse_reorder_causal_load_balancing,
CPStrategy,
)
from transformer_engine.jax.sharding import MeshResource

# We will use the golden reference model from our non-distributed attention test fixture.
from test_fused_attn import general_dot_product_attention, make_mask
@@ -402,6 +404,10 @@ def qkv_to_layout(self, q, k, v, qkv_layout):
"load_balanced",
[pytest.param(False, id="UNBALANCED"), pytest.param(True, id="BALANCED")],
)
@pytest.mark.parametrize(
"cp_strategy",
[pytest.param(CPStrategy.ALL_GATHER, id="AG"), pytest.param(CPStrategy.RING, id="RING")],
)
def test_contex_parallel_self_attn(
self,
device_count,
@@ -414,6 +420,7 @@ def test_contex_parallel_self_attn(
dtype,
qkv_layout,
load_balanced,
cp_strategy,
):
attn_bias_type = AttnBiasType.NO_BIAS
dropout_prob = 0.0
@@ -461,6 +468,7 @@ def target_func(q, k, v, mask):
scaling_factor=scaling_factor,
dropout_probability=dropout_prob,
is_training=is_training,
context_parallel_strategy=cp_strategy,
context_parallel_causal_load_balanced=load_balanced,
context_parallel_axis="cp",
).astype(dtype)
25 changes: 24 additions & 1 deletion transformer_engine/jax/attention.py
@@ -79,6 +79,19 @@ class QKVFormat(Enum):
THD = NVTE_QKV_Format.NVTE_THD


class CPStrategy(Enum):
"""Defines the context parallel strategies of Jax fused attention.
DEFAULT: Default strategy will choose automatically if context parallel axis is sharded.
ALL_GATHER: All-gather/reduce scatter implementation.
RING: Ring attention implementation (https://arxiv.org/abs/2310.01889).
"""

DEFAULT = 0
ALL_GATHER = 1
RING = 2
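
The RING strategy rotates key/value blocks around the device ring while each device keeps its local query block, following the paper linked above. Below is a minimal sketch of that communication pattern in JAX, not the primitive this commit adds: the axis name "cp", the 2-D block shapes, and the omission of online-softmax max rescaling are all simplifying assumptions.

import jax
import jax.numpy as jnp

def ring_attention_sketch(q, k, v, n_devices, axis_name="cp"):
    # q, k, v: this device's [seq_block, head_dim] shards; run under
    # shard_map/pmap with `axis_name` mapped over n_devices devices.
    perm = [(i, (i + 1) % n_devices) for i in range(n_devices)]
    num = jnp.zeros_like(q)                    # running sum of exp(S) @ V
    den = jnp.zeros((q.shape[0], 1), q.dtype)  # running softmax denominator
    for _ in range(n_devices):
        s = jnp.einsum("qd,kd->qk", q, k)
        p = jnp.exp(s)  # real implementations rescale by a running row max
        num += jnp.einsum("qk,kd->qd", p, v)
        den += p.sum(axis=-1, keepdims=True)
        # Pass K/V one hop around the ring; after n_devices steps every
        # device has attended to every block.
        k = jax.lax.ppermute(k, axis_name, perm)
        v = jax.lax.ppermute(v, axis_name, perm)
    return num / den

The real primitive overlaps each step's compute with the next K/V transfer; the loop above just makes the dataflow explicit.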


def get_qkv_format(qkv_layout):
"""
Get qkv_format from qkv_layout
@@ -266,6 +279,7 @@ def fused_attn(
dropout_probability: float,
is_training: bool,
window_size: Optional[Tuple[int, int]] = None,
context_parallel_strategy: CPStrategy = CPStrategy.DEFAULT,
context_parallel_causal_load_balanced: bool = False,
context_parallel_axis: str = "",
):
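
The new keyword slots in beside the two existing context parallel options. As a hedged illustration of the keyword surface only (values mirror the test above; every other fused_attn argument is elided, not filled in):

from transformer_engine.jax.attention import CPStrategy

# Context parallel options threaded by this commit, as exercised in the test:
cp_kwargs = dict(
    context_parallel_strategy=CPStrategy.RING,   # or ALL_GATHER / DEFAULT
    context_parallel_causal_load_balanced=True,  # pairs with the reorder helpers
    context_parallel_axis="cp",                  # mesh axis sharding the sequence
)
# fused_attn(..., **cp_kwargs)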
@@ -353,6 +367,7 @@
is_training=is_training,
max_segments_per_seq=1,
window_size=window_size,
context_parallel_strategy=context_parallel_strategy,
context_parallel_causal_load_balanced=context_parallel_causal_load_balanced,
context_parallel_axis=context_parallel_axis,
)
@@ -376,6 +391,7 @@ def fused_attn_thd(
is_training: bool,
max_segments_per_seq: int = 1,
window_size: Optional[Tuple[int, int]] = None,
context_parallel_strategy: CPStrategy = CPStrategy.DEFAULT,
context_parallel_causal_load_balanced: bool = False,
context_parallel_axis: str = "",
):
@@ -476,14 +492,15 @@
is_training=is_training,
max_segments_per_seq=max_segments_per_seq,
window_size=window_size,
context_parallel_strategy=context_parallel_strategy,
context_parallel_causal_load_balanced=context_parallel_causal_load_balanced,
context_parallel_axis=context_parallel_axis,
)

return output


@partial(jax.custom_vjp, nondiff_argnums=(7, 8, 9, 10, 11, 12, 13, 14, 15, 16))
@partial(jax.custom_vjp, nondiff_argnums=(7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17))
def _fused_attn(
qkv: Tuple[jnp.ndarray, ...],
bias: Optional[jnp.ndarray],
@@ -500,6 +517,7 @@ def _fused_attn(
is_training: bool,
max_segments_per_seq: int,
window_size: Optional[Tuple[int, int]],
context_parallel_strategy: CPStrategy,
context_parallel_causal_load_balanced: bool,
context_parallel_axis: str,
):
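
The widened nondiff_argnums tuple above marks the new argument (position 17) as static configuration for jax.custom_vjp: it is handed to the forward and backward rules but never differentiated. A self-contained toy illustrating the mechanism, unrelated to the attention primitive itself:

from functools import partial
import jax

# Position 1 ("scale", standing in for context_parallel_strategy) is static:
# custom_vjp passes it to the bwd rule first, ahead of the residuals.
@partial(jax.custom_vjp, nondiff_argnums=(1,))
def scaled_square(x, scale):
    return scale * x * x

def _fwd(x, scale):
    return scale * x * x, x        # primal output plus residual

def _bwd(scale, x, g):
    return (g * 2.0 * scale * x,)  # cotangent for x only; none for scale

scaled_square.defvjp(_fwd, _bwd)

assert jax.grad(scaled_square)(3.0, 2.0) == 12.0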
@@ -519,6 +537,7 @@ def _fused_attn(
is_training,
max_segments_per_seq,
window_size,
context_parallel_strategy,
context_parallel_causal_load_balanced,
context_parallel_axis,
)
@@ -541,6 +560,7 @@ def _fused_attn_fwd_rule(
is_training,
max_segments_per_seq,
window_size,
context_parallel_strategy,
context_parallel_causal_load_balanced,
context_parallel_axis,
):
@@ -560,6 +580,7 @@ def _fused_attn_fwd_rule(
is_training=is_training,
max_segments_per_seq=max_segments_per_seq,
window_size=window_size,
context_parallel_strategy=context_parallel_strategy,
context_parallel_causal_load_balanced=context_parallel_causal_load_balanced,
context_parallel_axis=context_parallel_axis,
)
@@ -588,6 +609,7 @@ def _fused_attn_bwd_rule(
is_training,
max_segments_per_seq,
window_size,
context_parallel_strategy,
context_parallel_causal_load_balanced,
context_parallel_axis,
ctx,
@@ -623,6 +645,7 @@ def _fused_attn_bwd_rule(
is_training=is_training,
max_segments_per_seq=max_segments_per_seq,
window_size=window_size,
context_parallel_strategy=context_parallel_strategy,
context_parallel_causal_load_balanced=context_parallel_causal_load_balanced,
context_parallel_axis=context_parallel_axis,
)
