Implement ring attention primitive for Jax.
Signed-off-by: Michael Goldfarb <mgoldfarb@nvidia.com>
Signed-off-by: Ming Huang <mingh@nvidia.com>
mgoldfarb-nvidia committed Nov 5, 2024
1 parent 77c37d4 commit c67cc11
Showing 5 changed files with 594 additions and 4 deletions.
8 changes: 8 additions & 0 deletions tests/jax/test_distributed_fused_attn.py
@@ -35,7 +35,9 @@
     get_qkv_format,
     reorder_causal_load_balancing,
     inverse_reorder_causal_load_balancing,
+    CPStrategy,
 )
+from transformer_engine.jax.sharding import MeshResource
 
 # We will use the golden reference model from our non distributed attention test fixture.
 from test_fused_attn import general_dot_product_attention, make_mask
@@ -400,6 +402,10 @@ def qkv_to_layout(self, q, k, v, qkv_layout):
         "load_balanced",
         [pytest.param(False, id="UNBALANCED"), pytest.param(True, id="BALANCED")],
     )
+    @pytest.mark.parametrize(
+        "cp_strategy",
+        [pytest.param(CPStrategy.ALL_GATHER, id="AG"), pytest.param(CPStrategy.RING, id="RING")],
+    )
     def test_contex_parallel_self_attn(
         self,
         device_count,
@@ -412,6 +418,7 @@ def test_contex_parallel_self_attn(
         dtype,
         qkv_layout,
         load_balanced,
+        cp_strategy,
     ):
         attn_bias_type = AttnBiasType.NO_BIAS
         dropout_prob = 0.0
@@ -469,6 +476,7 @@ def target_func(q, k, v, mask):
                 scaling_factor=scaling_factor,
                 dropout_probability=dropout_prob,
                 is_training=is_training,
+                context_parallel_strategy=cp_strategy,
                 context_parallel_causal_load_balanced=load_balanced,
                 context_parallel_axis="cp",
             ).astype(dtype)
25 changes: 24 additions & 1 deletion transformer_engine/jax/attention.py
@@ -79,6 +79,19 @@ class QKVFormat(Enum):
     THD = NVTE_QKV_Format.NVTE_THD
 
 
+class CPStrategy(Enum):
+    """Defines the context parallel strategies of Jax fused attention.
+
+    DEFAULT: Default strategy will choose automatically if context parallel axis is sharded.
+    ALL_GATHER: All-gather/reduce scatter implementation.
+    RING: Ring attention implementation (https://arxiv.org/abs/2310.01889).
+    """
+
+    DEFAULT = 0
+    ALL_GATHER = 1
+    RING = 2
+
+
 def get_qkv_format(qkv_layout):
     """
     Get qkv_format from qkv_layout
@@ -260,6 +273,7 @@ def fused_attn(
     dropout_probability: float,
     is_training: bool,
     window_size: Optional[Tuple[int, int]] = None,
+    context_parallel_strategy: CPStrategy = CPStrategy.DEFAULT,
     context_parallel_causal_load_balanced: bool = False,
     context_parallel_axis: str = "",
 ):
@@ -347,6 +361,7 @@
         is_training=is_training,
         max_segments_per_seq=1,
         window_size=window_size,
+        context_parallel_strategy=context_parallel_strategy,
         context_parallel_causal_load_balanced=context_parallel_causal_load_balanced,
         context_parallel_axis=context_parallel_axis,
     )
@@ -370,6 +385,7 @@ def fused_attn_thd(
     is_training: bool,
     max_segments_per_seq: int = 1,
     window_size: Optional[Tuple[int, int]] = None,
+    context_parallel_strategy: CPStrategy = CPStrategy.DEFAULT,
     context_parallel_causal_load_balanced: bool = False,
     context_parallel_axis: str = "",
 ):
@@ -470,14 +486,15 @@
         is_training=is_training,
         max_segments_per_seq=max_segments_per_seq,
         window_size=window_size,
+        context_parallel_strategy=context_parallel_strategy,
         context_parallel_causal_load_balanced=context_parallel_causal_load_balanced,
         context_parallel_axis=context_parallel_axis,
     )
 
     return output
 
 
-@partial(jax.custom_vjp, nondiff_argnums=(7, 8, 9, 10, 11, 12, 13, 14, 15, 16))
+@partial(jax.custom_vjp, nondiff_argnums=(7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17))
 def _fused_attn(
     qkv: Tuple[jnp.ndarray, ...],
     bias: Optional[jnp.ndarray],
@@ -494,6 +511,7 @@ def _fused_attn(
     is_training: bool,
     max_segments_per_seq: int,
     window_size: Optional[Tuple[int, int]],
+    context_parallel_strategy: CPStrategy,
     context_parallel_causal_load_balanced: bool,
     context_parallel_axis: str,
 ):
@@ -513,6 +531,7 @@
         is_training,
         max_segments_per_seq,
         window_size,
+        context_parallel_strategy,
         context_parallel_causal_load_balanced,
         context_parallel_axis,
     )
@@ -535,6 +554,7 @@ def _fused_attn_fwd_rule(
     is_training,
     max_segments_per_seq,
     window_size,
+    context_parallel_strategy,
     context_parallel_causal_load_balanced,
     context_parallel_axis,
 ):
@@ -554,6 +574,7 @@
         is_training=is_training,
         max_segments_per_seq=max_segments_per_seq,
         window_size=window_size,
+        context_parallel_strategy=context_parallel_strategy,
         context_parallel_causal_load_balanced=context_parallel_causal_load_balanced,
         context_parallel_axis=context_parallel_axis,
     )
@@ -582,6 +603,7 @@ def _fused_attn_bwd_rule(
     is_training,
     max_segments_per_seq,
     window_size,
+    context_parallel_strategy,
     context_parallel_causal_load_balanced,
     context_parallel_axis,
     ctx,
@@ -617,6 +639,7 @@ def _fused_attn_bwd_rule(
         is_training=is_training,
         max_segments_per_seq=max_segments_per_seq,
         window_size=window_size,
+        context_parallel_strategy=context_parallel_strategy,
         context_parallel_causal_load_balanced=context_parallel_causal_load_balanced,
         context_parallel_axis=context_parallel_axis,
     )
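
A note on the @partial(jax.custom_vjp, ...) change above: indices listed in nondiff_argnums mark static, non-differentiable arguments, so threading the new context_parallel_strategy parameter through _fused_attn means extending the tuple by one index (17). JAX forwards static arguments to the VJP rules instead of differentiating them, which is why context_parallel_strategy also appears in _fused_attn_fwd_rule and, ahead of ctx, in _fused_attn_bwd_rule. A minimal toy sketch of the pattern (illustrative only, not TE code):

from functools import partial

import jax


# `exponent` is static, so it is listed in nondiff_argnums and arrives
# as the leading argument of the bwd rule, before the residuals.
@partial(jax.custom_vjp, nondiff_argnums=(1,))
def power(x, exponent):
    return x**exponent


def power_fwd(x, exponent):
    # fwd takes the same arguments as the primal; returns (output, residuals).
    return power(x, exponent), x


def power_bwd(exponent, x, g):
    # bwd receives static args first, then residuals, then the cotangent.
    return (g * exponent * x ** (exponent - 1),)


power.defvjp(power_fwd, power_bwd)

print(jax.grad(power)(3.0, 2))  # 6.0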
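For callers, opting into the new primitive is a one-argument change. Below is a hypothetical usage sketch assembled from the parameters visible in this diff; the array shapes are illustrative, names not shown above (mask, seed, AttnMaskType, QKVLayout) are assumptions about the surrounding API, and ring attention only takes effect when the named context-parallel axis ("cp" here) is sharded across devices in the active mesh, as the test above arranges via MeshResource:

import math

import jax.numpy as jnp
from transformer_engine.jax.attention import (
    AttnBiasType,
    AttnMaskType,
    CPStrategy,
    QKVLayout,
    fused_attn,
)

# Illustrative BSHD tensors: batch=2, seqlen=2048, heads=8, head_dim=64.
q = jnp.zeros((2, 2048, 8, 64), dtype=jnp.bfloat16)
k = jnp.zeros((2, 2048, 8, 64), dtype=jnp.bfloat16)
v = jnp.zeros((2, 2048, 8, 64), dtype=jnp.bfloat16)

out = fused_attn(
    (q, k, v),
    bias=None,
    mask=None,
    seed=None,  # no dropout rng needed when dropout_probability is 0
    attn_bias_type=AttnBiasType.NO_BIAS,
    attn_mask_type=AttnMaskType.CAUSAL_MASK,
    qkv_layout=QKVLayout.BSHD_BSHD_BSHD,
    scaling_factor=1.0 / math.sqrt(64),
    dropout_probability=0.0,
    is_training=True,
    context_parallel_strategy=CPStrategy.RING,  # new in this commit
    context_parallel_causal_load_balanced=True,
    context_parallel_axis="cp",
)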
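Conceptually, CPStrategy.RING follows the linked paper (https://arxiv.org/abs/2310.01889): each device keeps its local query shard while key/value shards rotate around the ring of context-parallel devices, and partial outputs are merged with online softmax statistics so the full sequence is never gathered on one device. The sketch below illustrates only the idea; the single-head (seq, head_dim) layout and all names are assumptions, and it omits the causal masking, load balancing, and compute/communication overlap a production kernel needs:

import jax
import jax.numpy as jnp


def ring_attention_sketch(q, k, v, axis_name, axis_size):
    """Single-head ring attention over a (seq_chunk, head_dim) shard.

    Must run inside a collective context (e.g. jax.pmap or shard_map)
    whose mesh axis `axis_name` has size `axis_size`.
    """
    scale = q.shape[-1] ** -0.5
    perm = [(i, (i + 1) % axis_size) for i in range(axis_size)]
    out = jnp.zeros_like(q)
    row_max = jnp.full(q.shape[:-1], -jnp.inf)  # running softmax max
    row_sum = jnp.zeros(q.shape[:-1])           # running softmax denominator
    for _ in range(axis_size):
        scores = jnp.einsum("qd,kd->qk", q, k) * scale
        new_max = jnp.maximum(row_max, scores.max(axis=-1))
        probs = jnp.exp(scores - new_max[:, None])
        rescale = jnp.exp(row_max - new_max)  # correct earlier partial sums
        row_sum = row_sum * rescale + probs.sum(axis=-1)
        out = out * rescale[:, None] + probs @ v
        row_max = new_max
        # Pass our K/V shard to the next device and receive the previous one's.
        k = jax.lax.ppermute(k, axis_name, perm)
        v = jax.lax.ppermute(v, axis_name, perm)
    return out / row_sum[:, None]

Run under jax.pmap (or shard_map) with axis_name="cp" and Q/K/V sequence-sharded, each iteration computes one block of the score matrix; the actual primitive in this commit drives the fused attention backend per step rather than the naive einsum used here.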
