Implement ring attention primitive for JAX.
Signed-off-by: Michael Goldfarb <mgoldfarb@nvidia.com>
Signed-off-by: Ming Huang <mingh@nvidia.com>
mgoldfarb-nvidia committed Nov 7, 2024
1 parent e5ffaa7 commit ad91add
Showing 8 changed files with 701 additions and 36 deletions.
9 changes: 8 additions & 1 deletion qa/L1_jax_distributed_unittest/test.sh
@@ -5,4 +5,11 @@
set -xe

: ${TE_PATH:=/opt/transformerengine}
pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/tests/jax/test_distributed_*

# Skip ring attention tests since they need fixed environment vars
pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/tests/jax/test_distributed_* -k 'not test_context_parallel_ring_attn'

# Test ring attention with and without scan loop
NVTE_FUSED_RING_ATTENTION_USE_SCAN=0 pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/tests/jax/test_distributed_fused_attn.py -k test_context_parallel_ring_attn
NVTE_FUSED_RING_ATTENTION_USE_SCAN=1 XLA_FLAGS="--xla_experimental_ignore_channel_id" \
pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/tests/jax/test_distributed_fused_attn.py -k test_context_parallel_ring_attn
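For interactive debugging, the same selection can also be driven from Python. This is only an illustrative sketch mirroring the shell commands above; the helper name and the default install path are assumptions, and pytest.main is standard pytest:

import os
import pytest

# Hypothetical helper mirroring the shell commands above: run the ring attention
# tests once with the unrolled loop and once with the scan-based loop.
def run_ring_attention_tests(te_path="/opt/transformerengine"):
    for use_scan in ("0", "1"):
        os.environ["NVTE_FUSED_RING_ATTENTION_USE_SCAN"] = use_scan
        if use_scan == "1":
            os.environ["XLA_FLAGS"] = "--xla_experimental_ignore_channel_id"
        pytest.main([
            "-c", f"{te_path}/tests/jax/pytest.ini",
            "-v", f"{te_path}/tests/jax/test_distributed_fused_attn.py",
            "-k", "test_context_parallel_ring_attn",
        ])

Note that these environment variables typically need to be set before JAX and XLA initialize, which is why the shell script exports them on the command line rather than inside an already-running test process.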
120 changes: 89 additions & 31 deletions tests/jax/test_distributed_fused_attn.py
@@ -35,7 +35,9 @@
get_qkv_format,
reorder_causal_load_balancing,
inverse_reorder_causal_load_balancing,
CPStrategy,
)
from transformer_engine.jax.sharding import MeshResource

# We will use the golden reference model from our non distributed attention test fixture.
from test_fused_attn import general_dot_product_attention, make_mask
@@ -333,6 +335,36 @@ def ref_func(query, kv, mask):
)


@pytest.mark.parametrize(
"device_count,mesh_shape,mesh_axes,mesh_resource", generate_context_parallel_configs()
)
@pytest.mark.parametrize(
"data_shape",
[
pytest.param([2, 512, 12, 128], id="2-512-12-128"),
pytest.param([4, 1024, 16, 64], id="4-1024-16-64"),
],
)
@pytest.mark.parametrize("kv_groups", [1, 4, 8, 12, 16])
@pytest.mark.parametrize(
"attn_mask_type",
[
pytest.param(AttnMaskType.CAUSAL_MASK, id="CAUSAL_MASK"),
pytest.param(AttnMaskType.NO_MASK, id="NO_MASK"),
],
)
@pytest.mark.parametrize("dtype", [jnp.bfloat16])
@pytest.mark.parametrize(
"qkv_layout",
[
pytest.param(QKVLayout.BSHD_BS2HD, id="COMBINED_KV"),
pytest.param(QKVLayout.BSHD_BSHD_BSHD, id="SEPARATE"),
],
)
@pytest.mark.parametrize(
"load_balanced",
[pytest.param(False, id="UNBALANCED"), pytest.param(True, id="BALANCED")],
)
class TestDistributedContextParallelSelfAttn:

def generate_inputs(self, shape, kv_groups: int, attn_mask_type: AttnMaskType, dtype):
@@ -370,37 +402,7 @@ def qkv_to_layout(self, q, k, v, qkv_layout):
raise ValueError(f"Unsupported {qkv_layout=}")
return qkv_args

@pytest.mark.parametrize(
"device_count,mesh_shape,mesh_axes,mesh_resource", generate_context_parallel_configs()
)
@pytest.mark.parametrize(
"data_shape",
[
pytest.param([2, 512, 12, 128], id="2-512-12-128"),
pytest.param([4, 1024, 16, 64], id="4-1024-16-64"),
],
)
@pytest.mark.parametrize("kv_groups", [1, 4, 8, 12, 16])
@pytest.mark.parametrize(
"attn_mask_type",
[
pytest.param(AttnMaskType.CAUSAL_MASK, id="CAUSAL_MASK"),
pytest.param(AttnMaskType.NO_MASK, id="NO_MASK"),
],
)
@pytest.mark.parametrize("dtype", [jnp.bfloat16])
@pytest.mark.parametrize(
"qkv_layout",
[
pytest.param(QKVLayout.BSHD_BS2HD, id="COMBINED_KV"),
pytest.param(QKVLayout.BSHD_BSHD_BSHD, id="SEPARATE"),
],
)
@pytest.mark.parametrize(
"load_balanced",
[pytest.param(False, id="UNBALANCED"), pytest.param(True, id="BALANCED")],
)
def test_contex_parallel_self_attn(
def impl_test_contex_parallel_attn(
self,
device_count,
mesh_shape,
@@ -412,6 +414,7 @@ def test_contex_parallel_self_attn(
dtype,
qkv_layout,
load_balanced,
cp_strategy,
):
attn_bias_type = AttnBiasType.NO_BIAS
dropout_prob = 0.0
@@ -469,6 +472,7 @@ def target_func(q, k, v, mask):
scaling_factor=scaling_factor,
dropout_probability=dropout_prob,
is_training=is_training,
context_parallel_strategy=cp_strategy,
context_parallel_causal_load_balanced=load_balanced,
context_parallel_axis="cp",
).astype(dtype)
@@ -574,6 +578,60 @@ def grad_func(func, *args, **kwargs):

assert_allclose(target_grads[i], ref_grads[i], dtype=dtype)

def test_contex_parallel_allgather_attn(
self,
device_count,
mesh_shape,
mesh_axes,
mesh_resource,
data_shape,
kv_groups,
attn_mask_type,
dtype,
qkv_layout,
load_balanced,
):
return self.impl_test_contex_parallel_attn(
device_count,
mesh_shape,
mesh_axes,
mesh_resource,
data_shape,
kv_groups,
attn_mask_type,
dtype,
qkv_layout,
load_balanced,
CPStrategy.ALL_GATHER,
)

def test_context_parallel_ring_attn(
self,
device_count,
mesh_shape,
mesh_axes,
mesh_resource,
data_shape,
kv_groups,
attn_mask_type,
dtype,
qkv_layout,
load_balanced,
):
return self.impl_test_contex_parallel_attn(
device_count,
mesh_shape,
mesh_axes,
mesh_resource,
data_shape,
kv_groups,
attn_mask_type,
dtype,
qkv_layout,
load_balanced,
CPStrategy.RING,
)


class TestReorderCausalLoadBalancing:
@pytest.mark.parametrize("cp_size", [2, 4, 8])
1 change: 1 addition & 0 deletions tests/jax/test_fused_attn.py
@@ -7,6 +7,7 @@
from functools import partial
from math import sqrt
from typing import Tuple, Optional
import random

import jax
import jax.numpy as jnp
40 changes: 40 additions & 0 deletions tests/jax/test_misc.py
@@ -0,0 +1,40 @@
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

import pytest
from functools import partial
import os

from transformer_engine.jax.cpp_extensions.misc import get_xla_flag


@pytest.fixture(autouse=True, scope="function")
def preserve_xla_flags():
"""Ensures the XLA flags environment variable is restored after any tests in this file run."""
old_flags = os.getenv("XLA_FLAGS")
yield
if old_flags is not None:
os.environ["XLA_FLAGS"] = old_flags


def test_get_xla_flag(request):
os.environ["XLA_FLAGS"] = ""
assert get_xla_flag("") is None
assert get_xla_flag("--foo") is None
assert get_xla_flag("--bar=1") is None

os.environ["XLA_FLAGS"] = "--foo --bar=1 --baz=biz"
assert get_xla_flag("--foo") == True
assert get_xla_flag("--bar") == "1"
assert get_xla_flag("--bar", cast=int) == 1
assert get_xla_flag("--bar", cast=bool) == True
assert get_xla_flag("--baz") == "biz"
with pytest.raises(ValueError):
# cast will fail
assert get_xla_flag("--baz", cast=int)
assert get_xla_flag("--xla") is None

os.environ["XLA_FLAGS"] = "--xla_abc --xla_abb"
assert get_xla_flag("--xla_abc") == True
assert get_xla_flag("--xla_abb") == True
25 changes: 24 additions & 1 deletion transformer_engine/jax/attention.py
@@ -79,6 +79,19 @@ class QKVFormat(Enum):
THD = NVTE_QKV_Format.NVTE_THD


class CPStrategy(Enum):
"""Defines the context parallel strategies of Jax fused attention.
DEFAULT: Default strategy will choose automatically if context parallel axis is sharded.
ALL_GATHER: All-gather/reduce scatter implementation.
RING: Ring attention implementation (https://arxiv.org/abs/2310.01889).
"""

DEFAULT = 0
ALL_GATHER = 1
RING = 2

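A usage sketch (not part of this diff): with the new keyword arguments added to fused_attn below, a caller can request the ring implementation explicitly. The positional arguments and tensor shapes here are assumptions based on the surrounding signature, and the RING path additionally expects q/k/v to be sharded over a device mesh whose context-parallel axis matches context_parallel_axis ("cp" in the tests above):

import math
import jax
import jax.numpy as jnp
from transformer_engine.jax.attention import (
    AttnBiasType, AttnMaskType, QKVLayout, CPStrategy, fused_attn,
)

batch, seqlen, heads, head_dim = 2, 512, 12, 128  # one of the shapes exercised by the tests
key = jax.random.PRNGKey(0)
q, k, v = (jax.random.normal(key, (batch, seqlen, heads, head_dim), jnp.bfloat16) for _ in range(3))

out = fused_attn(
    (q, k, v),  # separate Q/K/V tensors, i.e. QKVLayout.BSHD_BSHD_BSHD
    None,       # bias
    None,       # mask
    None,       # dropout RNG seed
    attn_bias_type=AttnBiasType.NO_BIAS,
    attn_mask_type=AttnMaskType.CAUSAL_MASK,
    qkv_layout=QKVLayout.BSHD_BSHD_BSHD,
    scaling_factor=1.0 / math.sqrt(head_dim),
    dropout_probability=0.0,
    is_training=True,
    context_parallel_strategy=CPStrategy.RING,
    context_parallel_causal_load_balanced=True,
    context_parallel_axis="cp",
)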

def get_qkv_format(qkv_layout):
"""
Get qkv_format from qkv_layout
@@ -260,6 +273,7 @@ def fused_attn(
dropout_probability: float,
is_training: bool,
window_size: Optional[Tuple[int, int]] = None,
context_parallel_strategy: CPStrategy = CPStrategy.DEFAULT,
context_parallel_causal_load_balanced: bool = False,
context_parallel_axis: str = "",
):
@@ -347,6 +361,7 @@ def fused_attn(
is_training=is_training,
max_segments_per_seq=1,
window_size=window_size,
context_parallel_strategy=context_parallel_strategy,
context_parallel_causal_load_balanced=context_parallel_causal_load_balanced,
context_parallel_axis=context_parallel_axis,
)
@@ -370,6 +385,7 @@ def fused_attn_thd(
is_training: bool,
max_segments_per_seq: int = 1,
window_size: Optional[Tuple[int, int]] = None,
context_parallel_strategy: CPStrategy = CPStrategy.DEFAULT,
context_parallel_causal_load_balanced: bool = False,
context_parallel_axis: str = "",
):
@@ -470,14 +486,15 @@ def fused_attn_thd(
is_training=is_training,
max_segments_per_seq=max_segments_per_seq,
window_size=window_size,
context_parallel_strategy=context_parallel_strategy,
context_parallel_causal_load_balanced=context_parallel_causal_load_balanced,
context_parallel_axis=context_parallel_axis,
)

return output


@partial(jax.custom_vjp, nondiff_argnums=(7, 8, 9, 10, 11, 12, 13, 14, 15, 16))
@partial(jax.custom_vjp, nondiff_argnums=(7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17))
def _fused_attn(
qkv: Tuple[jnp.ndarray, ...],
bias: Optional[jnp.ndarray],
@@ -494,6 +511,7 @@ def _fused_attn(
is_training: bool,
max_segments_per_seq: int,
window_size: Optional[Tuple[int, int]],
context_parallel_strategy: CPStrategy,
context_parallel_causal_load_balanced: bool,
context_parallel_axis: str,
):
@@ -513,6 +531,7 @@
is_training,
max_segments_per_seq,
window_size,
context_parallel_strategy,
context_parallel_causal_load_balanced,
context_parallel_axis,
)
@@ -535,6 +554,7 @@ def _fused_attn_fwd_rule(
is_training,
max_segments_per_seq,
window_size,
context_parallel_strategy,
context_parallel_causal_load_balanced,
context_parallel_axis,
):
@@ -554,6 +574,7 @@ def _fused_attn_fwd_rule(
is_training=is_training,
max_segments_per_seq=max_segments_per_seq,
window_size=window_size,
context_parallel_strategy=context_parallel_strategy,
context_parallel_causal_load_balanced=context_parallel_causal_load_balanced,
context_parallel_axis=context_parallel_axis,
)
@@ -582,6 +603,7 @@ def _fused_attn_bwd_rule(
is_training,
max_segments_per_seq,
window_size,
context_parallel_strategy,
context_parallel_causal_load_balanced,
context_parallel_axis,
ctx,
@@ -617,6 +639,7 @@ def _fused_attn_bwd_rule(
is_training=is_training,
max_segments_per_seq=max_segments_per_seq,
window_size=window_size,
context_parallel_strategy=context_parallel_strategy,
context_parallel_causal_load_balanced=context_parallel_causal_load_balanced,
context_parallel_axis=context_parallel_axis,
)
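The nondiff_argnums tuple on _fused_attn grows from (7, ..., 16) to (7, ..., 17) because the new context_parallel_strategy argument is a static, non-differentiable value that jax.custom_vjp must forward to the forward and backward rules rather than differentiate. A minimal, self-contained toy example of that mechanism, unrelated to this commit:

from functools import partial
import jax

# `mode` plays the role of a static argument such as context_parallel_strategy.
@partial(jax.custom_vjp, nondiff_argnums=(1,))
def scale(x, mode):
    return 2.0 * x if mode == "double" else x

def scale_fwd(x, mode):
    # The forward rule sees the static argument in its usual position.
    return scale(x, mode), None

def scale_bwd(mode, residuals, g):
    # The backward rule receives static arguments first, then residuals and the cotangent,
    # and returns gradients only for the differentiable inputs.
    grad_x = 2.0 * g if mode == "double" else g
    return (grad_x,)

scale.defvjp(scale_fwd, scale_bwd)

print(jax.grad(scale)(3.0, "double"))  # 2.0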
(Remaining diff content not loaded.)
