From 0f0dd3d35db1f2ce91b00c982c1bf516a622a24c Mon Sep 17 00:00:00 2001
From: Michael Goldfarb
Date: Thu, 31 Oct 2024 18:56:45 +0000
Subject: [PATCH] Implement ring attention primitive for JAX.

Signed-off-by: Michael Goldfarb
---
 tests/jax/test_distributed_fused_attn.py      |   8 +
 transformer_engine/jax/attention.py           |  25 +-
 .../jax/cpp_extensions/attention.py           | 609 +++++++++++++++++-
 transformer_engine/jax/cpp_extensions/misc.py |  17 +
 transformer_engine/jax/sharding.py            |   2 +-
 5 files changed, 657 insertions(+), 4 deletions(-)

diff --git a/tests/jax/test_distributed_fused_attn.py b/tests/jax/test_distributed_fused_attn.py
index 23a26087d4..d11bdef844 100644
--- a/tests/jax/test_distributed_fused_attn.py
+++ b/tests/jax/test_distributed_fused_attn.py
@@ -35,7 +35,9 @@
     get_qkv_format,
     reorder_causal_load_balancing,
     inverse_reorder_causal_load_balancing,
+    CPStrategy,
 )
+from transformer_engine.jax.sharding import MeshResource

 # We will use the golden reference model from our non distributed attention test fixture.
 from test_fused_attn import general_dot_product_attention, make_mask
@@ -402,6 +404,10 @@ def qkv_to_layout(self, q, k, v, qkv_layout):
         "load_balanced",
         [pytest.param(False, id="UNBALANCED"), pytest.param(True, id="BALANCED")],
     )
+    @pytest.mark.parametrize(
+        "cp_strategy",
+        [pytest.param(CPStrategy.ALL_GATHER, id="AG"), pytest.param(CPStrategy.RING, id="RING")],
+    )
     def test_contex_parallel_self_attn(
         self,
         device_count,
@@ -414,6 +420,7 @@ def test_contex_parallel_self_attn(
         dtype,
         qkv_layout,
         load_balanced,
+        cp_strategy,
     ):
         attn_bias_type = AttnBiasType.NO_BIAS
         dropout_prob = 0.0
@@ -461,6 +468,7 @@ def target_func(q, k, v, mask):
                 scaling_factor=scaling_factor,
                 dropout_probability=dropout_prob,
                 is_training=is_training,
+                context_parallel_strategy=cp_strategy,
                 context_parallel_causal_load_balanced=load_balanced,
                 context_parallel_axis="cp",
             ).astype(dtype)
diff --git a/transformer_engine/jax/attention.py b/transformer_engine/jax/attention.py
index b3b11bb9dd..74b64a0de4 100644
--- a/transformer_engine/jax/attention.py
+++ b/transformer_engine/jax/attention.py
@@ -79,6 +79,19 @@ class QKVFormat(Enum):
     THD = NVTE_QKV_Format.NVTE_THD


+class CPStrategy(Enum):
+    """Defines the context parallel strategies for JAX fused attention.
+
+    DEFAULT: Automatically select an implementation if the context parallel axis is sharded.
+    ALL_GATHER: All-gather/reduce-scatter implementation.
+    RING: Ring attention implementation (https://arxiv.org/abs/2310.01889).
+ """ + + DEFAULT = 0 + ALL_GATHER = 1 + RING = 2 + + def get_qkv_format(qkv_layout): """ Get qkv_format from qkv_layout @@ -266,6 +279,7 @@ def fused_attn( dropout_probability: float, is_training: bool, window_size: Optional[Tuple[int, int]] = None, + context_parallel_strategy: CPStrategy = CPStrategy.DEFAULT, context_parallel_causal_load_balanced: bool = False, context_parallel_axis: str = "", ): @@ -353,6 +367,7 @@ def fused_attn( is_training=is_training, max_segments_per_seq=1, window_size=window_size, + context_parallel_strategy=context_parallel_strategy, context_parallel_causal_load_balanced=context_parallel_causal_load_balanced, context_parallel_axis=context_parallel_axis, ) @@ -376,6 +391,7 @@ def fused_attn_thd( is_training: bool, max_segments_per_seq: int = 1, window_size: Optional[Tuple[int, int]] = None, + context_parallel_strategy: CPStrategy = CPStrategy.DEFAULT, context_parallel_causal_load_balanced: bool = False, context_parallel_axis: str = "", ): @@ -476,6 +492,7 @@ def fused_attn_thd( is_training=is_training, max_segments_per_seq=max_segments_per_seq, window_size=window_size, + context_parallel_strategy=context_parallel_strategy, context_parallel_causal_load_balanced=context_parallel_causal_load_balanced, context_parallel_axis=context_parallel_axis, ) @@ -483,7 +500,7 @@ def fused_attn_thd( return output -@partial(jax.custom_vjp, nondiff_argnums=(7, 8, 9, 10, 11, 12, 13, 14, 15, 16)) +@partial(jax.custom_vjp, nondiff_argnums=(7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17)) def _fused_attn( qkv: Tuple[jnp.ndarray, ...], bias: Optional[jnp.ndarray], @@ -500,6 +517,7 @@ def _fused_attn( is_training: bool, max_segments_per_seq: int, window_size: Optional[Tuple[int, int]], + context_parallel_strategy: CPStrategy, context_parallel_causal_load_balanced: bool, context_parallel_axis: str, ): @@ -519,6 +537,7 @@ def _fused_attn( is_training, max_segments_per_seq, window_size, + context_parallel_strategy, context_parallel_causal_load_balanced, context_parallel_axis, ) @@ -541,6 +560,7 @@ def _fused_attn_fwd_rule( is_training, max_segments_per_seq, window_size, + context_parallel_strategy, context_parallel_causal_load_balanced, context_parallel_axis, ): @@ -560,6 +580,7 @@ def _fused_attn_fwd_rule( is_training=is_training, max_segments_per_seq=max_segments_per_seq, window_size=window_size, + context_parallel_strategy=context_parallel_strategy, context_parallel_causal_load_balanced=context_parallel_causal_load_balanced, context_parallel_axis=context_parallel_axis, ) @@ -588,6 +609,7 @@ def _fused_attn_bwd_rule( is_training, max_segments_per_seq, window_size, + context_parallel_strategy, context_parallel_causal_load_balanced, context_parallel_axis, ctx, @@ -623,6 +645,7 @@ def _fused_attn_bwd_rule( is_training=is_training, max_segments_per_seq=max_segments_per_seq, window_size=window_size, + context_parallel_strategy=context_parallel_strategy, context_parallel_causal_load_balanced=context_parallel_causal_load_balanced, context_parallel_axis=context_parallel_axis, ) diff --git a/transformer_engine/jax/cpp_extensions/attention.py b/transformer_engine/jax/cpp_extensions/attention.py index ae74a8ca46..809fb03224 100644 --- a/transformer_engine/jax/cpp_extensions/attention.py +++ b/transformer_engine/jax/cpp_extensions/attention.py @@ -17,6 +17,8 @@ from jax.sharding import PartitionSpec, NamedSharding from jax.extend import ffi +from transformer_engine.jax.attention import CPStrategy + from transformer_engine import transformer_engine_jax from transformer_engine.transformer_engine_jax import ( 
    NVTE_Bias_Type,
@@ -35,6 +37,7 @@
     get_padded_spec,
     get_cudnn_version,
     is_ffi_enabled,
+    get_xla_flag,
 )
 from ..sharding import (
     global_mesh_resource,
@@ -1411,6 +1414,592 @@ def _cross_attn_bwd(
 register_primitive(FusedAttnCPWithAllGatherBwdPrimitive)


+@dataclass(frozen=True)
+class _FusedAttnCPWithP2PHelper:
+    """Helper class to assist with running the P2P ring strategy for CP attention."""
+
+    mesh: jax.sharding.Mesh
+    config: _FusedAttnConfig
+
+    @property
+    @cache
+    def use_scanloop(self):
+        use_scan = bool(int(os.getenv("NVTE_FUSED_RING_ATTENTION_USE_SCAN", "1")))
+        # nvbug(4675071): Disable the HLO verifier for channel ID checks.
+        # A WAR was added to XLA: https://github.com/openxla/xla/pull/16779
+        return use_scan and bool(get_xla_flag("--xla_experimental_ignore_channel_id", False))
+
+    def check_supported(self):
+        """Checks if the context parallel implementation is supported by the given arguments."""
+        header = "Context parallel fused ring attention"
+
+        allowed_layouts = [NVTE_QKV_Layout.NVTE_BSHD_BS2HD, NVTE_QKV_Layout.NVTE_BSHD_BSHD_BSHD]
+        if self.config.qkv_layout not in allowed_layouts:
+            raise ValueError(
+                f"{header} only supports layouts:"
+                f" {','.join([str(x) for x in allowed_layouts])}, got: {self.config.qkv_layout}"
+            )
+
+        if self.config.attn_bias_type != NVTE_Bias_Type.NVTE_NO_BIAS:
+            raise ValueError(f"{header} does not support bias, got: {self.config.attn_bias_type}")
+
+        allowed_masks = [NVTE_Mask_Type.NVTE_NO_MASK, NVTE_Mask_Type.NVTE_CAUSAL_MASK]
+        if self.config.attn_mask_type not in allowed_masks:
+            raise ValueError(
+                f"{header} only supports masking types:"
+                f" {','.join([str(x) for x in allowed_masks])}, got: {self.config.attn_mask_type}"
+            )
+
+        if self.config.max_segments_per_seq != 1:
+            raise ValueError(
+                f"{header} only supports max_segments_per_seq == 1, got:"
+                f" {self.config.max_segments_per_seq}"
+            )
+
+        if self.config.dropout_probability != 0.0:
+            raise ValueError(f"{header} does not support dropout")
+
+        # We want to encourage use of the scan loop to minimize unrolling and ensure more
+        # predictable scheduling from XLA. The unrolled flavor remains supported but is
+        # not the preferred implementation.
+        if not self.use_scanloop:
+            warnings.warn(
+                "Scan loop is disabled for fused ring attention. 
To enable set" + " NVTE_FUSED_RING_ATTENTION_USE_SCAN=1 in your environment and" + " XLA_FLAGS=--xla_experimental_ignore_channel_id=true" + ) + + def get_step_config(self, attn_mask_type) -> _FusedAttnConfig: + """Returns a _FusedAttnConfig for single CP step call to fused attention.""" + return _FusedAttnConfig( + attn_bias_type=self.config.attn_bias_type, + attn_mask_type=attn_mask_type, + qkv_layout=NVTE_QKV_Layout.NVTE_BSHD_BS2HD, + scaling_factor=self.config.scaling_factor, + dropout_probability=self.config.dropout_probability, + is_training=self.config.is_training, + max_segments_per_seq=self.config.max_segments_per_seq, + window_size=self.config.window_size, + context_parallel_load_balanced=self.config.context_parallel_load_balanced, + cp_axis=self.config.cp_axis, + ) + + def stack_kv(self, k, v): + """Stacks k and v tensors if not stacked.""" + _not_used = jnp.zeros(0, dtype=k.dtype) + match self.config.qkv_layout: + case NVTE_QKV_Layout.NVTE_BSHD_BS2HD: + return k + case NVTE_QKV_Layout.NVTE_BSHD_BSHD_BSHD: + return jnp.stack([k, v], axis=2) + return _not_used + + def unstack_kv(self, kv): + """Un-stacks k and v tensors if not stacked.""" + _not_used = jnp.zeros(0, dtype=kv.dtype) + match self.config.qkv_layout: + case NVTE_QKV_Layout.NVTE_BSHD_BS2HD: + return kv, _not_used + case NVTE_QKV_Layout.NVTE_BSHD_BSHD_BSHD: + return jnp.unstack(kv, axis=2) + return _not_used, _not_used # fall through + + def permute_kv(self, kv, cp_perm): + """Permutes kv around the ring as described by cp_perm.""" + return lax_paral_op(kv, lax.ppermute, self.config.cp_axis, mesh=self.mesh, perm=cp_perm) + + def correct_softmax_aux(self, softmax_aux, softmax_aux_per_step): + """Apply soft max correction after an attention step.""" + max_scale = jnp.maximum(softmax_aux, softmax_aux_per_step) + min_scale = jnp.minimum(softmax_aux, softmax_aux_per_step) + new_softmax_aux = max_scale + jnp.log(1 + jnp.exp(min_scale - max_scale)) + return new_softmax_aux + + def adjust_seqlen(self, seqlen, max_seqlen, idx): + """Adjust the sequence length per step.""" + seqlen_of_curr_step = seqlen - max_seqlen * idx + seqlen_of_curr_step = jnp.where(seqlen_of_curr_step < 0, 0, seqlen_of_curr_step) + seqlen_per_step = jnp.where( + seqlen_of_curr_step < max_seqlen, seqlen_of_curr_step, max_seqlen + ) + return seqlen_per_step + + +class FusedRingAttnFwdPrimitive(FusedAttnFwdPrimitive): + """ + Fused Ring Attention Forward Primitive + """ + + @staticmethod + def partition(config, mesh, arg_infos, result_infos): + is_context_parallel = get_mesh_axis_size(config.cp_axis, mesh) > 1 + assert ( + not is_context_parallel or config.window_size[0] == -1 + ), "Sliding window attention is not supported when context parallelism is enabled" + if not is_context_parallel: + return FusedAttnFwdPrimitive.partition(config, mesh, arg_infos, result_infos) + + helper = _FusedAttnCPWithP2PHelper(mesh, config) + helper.check_supported() + + out_sharding = result_infos[0].sharding + softmax_aux_sharding = result_infos[1].sharding + rng_state_sharding = seed_sharding = NamedSharding( + mesh, PartitionSpec(get_all_mesh_axes(), None) + ) + arg_shardings = tuple([arg_i.sharding for arg_i in arg_infos[:-1]] + [seed_sharding]) + out_shardings = (out_sharding, softmax_aux_sharding, rng_state_sharding) + + def ring_attn_fwd_impl( + q, + k, + v, + bias, + q_seqlen, + kv_seqlen, + q_seq_offsets, + k_seq_offsets, + seed, + ): + _not_used = jnp.zeros(0, dtype=v.dtype) + + # Combine KV tensors if separate for better permute scheduling and performance. 
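+            # (For the separate BSHD_BSHD_BSHD layout this packs K and V into a single
+            #  [batch, seqlen, 2, heads, dim] tensor so one ppermute moves both; the packed
+            #  BS2HD layout is already in this form and stack_kv is then a no-op.)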
+ # Eventually XLA should perform this automatically. + kv = helper.stack_kv(k, v) + + batch, q_max_seqlen, head, _ = q.shape + kv_max_seqlen = k.shape[1] + + cp_size = get_mesh_axis_size(config.cp_axis, mesh) + cp_rank = get_mesh_axis_rank(config.cp_axis, mesh) + cp_perm = [(i, (i + 1) % cp_size) for i in range(cp_size)] + + output_per_steps = jnp.zeros((cp_size, *q.shape), dtype=jnp.float32) + softmax_aux_per_steps = jnp.zeros( + (cp_size, batch, head, q_max_seqlen, 1), dtype=jnp.float32 + ) + softmax_aux = jnp.full((batch, head, q_max_seqlen, 1), -jnp.inf, dtype=jnp.float32) + + # RNG shape should be the shared shape. This is unused for ring attention as we do not + # support dropout currently. + rng_state_shape = (result_infos[2].shape[0] // mesh.size, *result_infos[2].shape[1:]) + rng_state = jnp.zeros(rng_state_shape).astype(result_infos[2].dtype) + + def scan_kv_block(idx, carry): + kv, softmax_aux, output_per_steps, softmax_aux_per_steps = carry + + # Send KV block to next step so we can overlap compute. + kv_next = helper.permute_kv(kv, cp_perm) + + def causal_mask_compute(): + q_seqlen_per_step = helper.adjust_seqlen(q_seqlen, q_max_seqlen, idx) + kv_seqlen_per_step = helper.adjust_seqlen(kv_seqlen, kv_max_seqlen, idx) + output_per_step, softmax_aux_per_step, rng_state_per_step = ( + FusedAttnFwdPrimitive.impl( + q, + kv, + _not_used, + bias, + q_seqlen_per_step, + kv_seqlen_per_step, + q_seq_offsets, + k_seq_offsets, + seed, + helper.get_step_config(NVTE_Mask_Type.NVTE_CAUSAL_MASK), + ) + ) + return output_per_step, softmax_aux_per_step + + def half_kv_no_mask_compute(): + q_seqlen_per_step = helper.adjust_seqlen(q_seqlen, q_max_seqlen, idx) + kv_seqlen_per_step = helper.adjust_seqlen(kv_seqlen, kv_max_seqlen, idx) // 2 + kv_part = lax.slice_in_dim(kv, 0, kv.shape[1] // 2, axis=1) + output_per_step, softmax_aux_per_step, rng_state_per_step = ( + FusedAttnFwdPrimitive.impl( + q, + kv_part, + _not_used, + bias, + q_seqlen_per_step, + kv_seqlen_per_step, + q_seq_offsets, + k_seq_offsets, + seed, + config=helper.get_step_config(NVTE_Mask_Type.NVTE_NO_MASK), + ) + ) + return output_per_step, softmax_aux_per_step + + def half_q_no_mask_compute(): + q_seqlen_per_step = helper.adjust_seqlen(q_seqlen, q_max_seqlen, idx) // 2 + kv_seqlen_per_step = helper.adjust_seqlen(kv_seqlen, kv_max_seqlen, idx) + q_part = q[:, q_max_seqlen // 2 :, :, :] + output_per_step, softmax_aux_per_step, rng_state_per_step = ( + FusedAttnFwdPrimitive.impl( + q_part, + kv, + _not_used, + bias, + q_seqlen_per_step, + kv_seqlen_per_step, + q_seq_offsets, + k_seq_offsets, + seed, + config=helper.get_step_config(NVTE_Mask_Type.NVTE_NO_MASK), + ) + ) + output_per_step = jnp.concat([jnp.zeros_like(q_part), output_per_step], axis=1) + softmax_aux_per_step = jnp.concat( + [ + jnp.full_like(softmax_aux_per_step, -jnp.inf), + softmax_aux_per_step, + ], + axis=2, + ) + return output_per_step, softmax_aux_per_step + + def no_mask_compute(): + q_seqlen_per_step = helper.adjust_seqlen(q_seqlen, q_max_seqlen, idx) + kv_seqlen_per_step = helper.adjust_seqlen(kv_seqlen, kv_max_seqlen, idx) + output_per_step, softmax_aux_per_step, rng_state_per_step = ( + FusedAttnFwdPrimitive.impl( + q, + kv, + _not_used, + bias, + q_seqlen_per_step, + kv_seqlen_per_step, + q_seq_offsets, + k_seq_offsets, + seed, + config=helper.get_step_config(NVTE_Mask_Type.NVTE_NO_MASK), + ) + ) + return output_per_step, softmax_aux_per_step + + def skip_compute(): + output_per_step = jnp.zeros_like(q) + softmax_aux_per_step = jnp.full( + (batch, head, 
q.shape[1], 1), -jnp.inf, dtype=jnp.float32 + ) + return output_per_step, softmax_aux_per_step + + if config.attn_mask_type == NVTE_Mask_Type.NVTE_CAUSAL_MASK: + # This is for nested jax.lax.cond + def jax_cond_wrap(): + if config.context_parallel_load_balanced: + return lax.cond( + (idx <= cp_rank), half_kv_no_mask_compute, half_q_no_mask_compute + ) + return lax.cond((idx <= cp_rank), no_mask_compute, skip_compute) + + output_per_step, softmax_aux_per_step = lax.cond( + idx == 0, causal_mask_compute, jax_cond_wrap + ) + else: + q_seqlen_per_step = helper.adjust_seqlen(q_seqlen, q_max_seqlen, idx) + kv_seqlen_per_step = helper.adjust_seqlen(kv_seqlen, kv_max_seqlen, idx) + output_per_step, softmax_aux_per_step, _ = FusedAttnFwdPrimitive.impl( + q, + kv, + _not_used, + bias, + q_seqlen_per_step, + kv_seqlen_per_step, + q_seq_offsets, + k_seq_offsets, + seed, + config=helper.get_step_config(config.attn_mask_type) + ) + + softmax_aux = helper.correct_softmax_aux(softmax_aux, softmax_aux_per_step) + output_per_steps = output_per_steps.at[idx].set(output_per_step.astype(jnp.float32)) + softmax_aux_per_steps = softmax_aux_per_steps.at[idx].set(softmax_aux_per_step) + + return (kv_next, softmax_aux, output_per_steps, softmax_aux_per_steps) + + carry = (kv, softmax_aux, output_per_steps, softmax_aux_per_steps) + if helper.use_scanloop: + carry = lax.fori_loop(0, cp_size, scan_kv_block, carry) + else: + for i in range(0, cp_size): + carry = scan_kv_block(i, carry) + (kv, softmax_aux, output_per_steps, softmax_aux_per_steps) = carry + + output = jnp.zeros(q.shape).astype(jnp.float32) + for idx in range(cp_size): + output = output + output_per_steps[idx].astype(jnp.float32) * jnp.exp( + softmax_aux_per_steps[idx] - softmax_aux + ).transpose(0, 2, 1, 3) + output = output.astype(q.dtype) + return output, softmax_aux, rng_state + + return mesh, ring_attn_fwd_impl, out_shardings, arg_shardings + + +register_primitive(FusedRingAttnFwdPrimitive) + + +class FusedRingAttnBwdPrimitive(FusedAttnBwdPrimitive): + """ + Fused Ring Attention Backward Primitive + """ + + @staticmethod + def partition(config, mesh, arg_infos, result_infos): + is_context_parallel = get_mesh_axis_size(config.cp_axis, mesh) > 1 + assert ( + not is_context_parallel or config.window_size[0] == -1 + ), "Sliding window attention is not supported when context parallelism is enabled" + if not is_context_parallel: + return FusedAttnBwdPrimitive.partition(config, mesh, arg_infos, result_infos) + + del result_infos + q_spec = get_padded_spec(arg_infos[0]) + k_spec = get_padded_spec(arg_infos[1]) + v_spec = get_padded_spec(arg_infos[2]) + bias_spec = get_padded_spec(arg_infos[3]) + dq_sharding = NamedSharding(mesh, PartitionSpec(*q_spec)) + dk_sharding = NamedSharding(mesh, PartitionSpec(*k_spec)) + dv_sharding = NamedSharding(mesh, PartitionSpec(*v_spec)) + dbias_sharding = NamedSharding(mesh, PartitionSpec(*bias_spec)) + arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos) + out_shardings = (dq_sharding, dk_sharding, dv_sharding, dbias_sharding) + + helper = _FusedAttnCPWithP2PHelper(mesh, config) + helper.check_supported() + + def ring_attn_bwd_impl( + q, + k, + v, + bias, + softmax_aux, + rng_state, + output, + doutput, + q_seqlen, + kv_seqlen, + q_seq_offsets, + k_seq_offsets, + ): + _not_used = jnp.zeros(0, dtype=output.dtype) + + # Combine KV tensors if separate for better permute scheduling and performance. + # Eventually XLA should perform this automatically. 
+            kv = helper.stack_kv(k, v)
+
+            q_max_seqlen = q.shape[1]
+            kv_max_seqlen = k.shape[1]
+
+            cp_size = get_mesh_axis_size(config.cp_axis, mesh)
+            cp_rank = get_mesh_axis_rank(config.cp_axis, mesh)
+            cp_perm = [(i, (i + 1) % cp_size) for i in range(cp_size)]
+
+            dq = jnp.zeros_like(q)
+            dk_dv = helper.stack_kv(jnp.zeros_like(k), jnp.zeros_like(v))
+            dbias = jnp.zeros_like(bias)
+
+            def scan_kv_block(idx, carry):
+
+                kv, dq, dq_prev_step, dk_dv, dk_dv_prev_step, dbias, dbias_prev_step = carry
+
+                def accum_grads(dq, dq_prev_step, dk_dv, dk_dv_prev_step, dbias, dbias_prev_step):
+                    dq = dq + dq_prev_step
+                    dk_dv = dk_dv + dk_dv_prev_step
+                    if config.attn_bias_type is not NVTE_Bias_Type.NVTE_NO_BIAS:
+                        dbias = dbias + dbias_prev_step
+                    return dq, dk_dv, dbias
+
+                def skip_accum_grads(
+                    dq, dq_prev_step, dk_dv, dk_dv_prev_step, dbias, dbias_prev_step
+                ):
+                    return dq, dk_dv, dbias
+
+                # Accumulate gradient from previous iteration.
+                dq, dk_dv, dbias = lax.cond(
+                    (idx > 0),
+                    accum_grads,
+                    skip_accum_grads,
+                    dq,
+                    dq_prev_step,
+                    dk_dv,
+                    dk_dv_prev_step,
+                    dbias,
+                    dbias_prev_step,
+                )
+
+                # Start communication that feeds the next iteration.
+                # We further combine the tensors to improve overlap.
+                kv_dk_dv = jnp.stack([kv, dk_dv])
+                kv_dk_dv = helper.permute_kv(kv_dk_dv, cp_perm)
+
+                def causal_mask_compute():
+                    q_seqlen_per_step = helper.adjust_seqlen(q_seqlen, q_max_seqlen, idx)
+                    kv_seqlen_per_step = helper.adjust_seqlen(kv_seqlen, kv_max_seqlen, idx)
+                    dq_per_step, dk_dv_per_step, _, dbias_per_step = FusedAttnBwdPrimitive.impl(
+                        q,
+                        kv,
+                        _not_used,
+                        bias,
+                        softmax_aux,
+                        rng_state,
+                        output,
+                        doutput,
+                        q_seqlen_per_step,
+                        kv_seqlen_per_step,
+                        q_seq_offsets,
+                        k_seq_offsets,
+                        config=helper.get_step_config(NVTE_Mask_Type.NVTE_CAUSAL_MASK),
+                    )
+                    return dq_per_step, dk_dv_per_step, dbias_per_step
+
+                def half_kv_no_mask_compute():
+                    q_seqlen_per_step = helper.adjust_seqlen(q_seqlen, q_max_seqlen, idx)
+                    kv_seqlen_per_step = helper.adjust_seqlen(kv_seqlen, kv_max_seqlen, idx) // 2
+                    kv_part = lax.slice_in_dim(kv, 0, kv.shape[1] // 2, axis=1)
+                    dq_per_step, dk_dv_per_step, _, dbias_per_step = FusedAttnBwdPrimitive.impl(
+                        q,
+                        kv_part,
+                        _not_used,
+                        bias,
+                        softmax_aux,
+                        rng_state,
+                        output,
+                        doutput,
+                        q_seqlen_per_step,
+                        kv_seqlen_per_step,
+                        q_seq_offsets,
+                        k_seq_offsets,
+                        config=helper.get_step_config(NVTE_Mask_Type.NVTE_NO_MASK),
+                    )
+                    dk_dv_per_step = jnp.concat(
+                        [dk_dv_per_step, jnp.zeros_like(dk_dv_per_step)], axis=1
+                    )
+                    return dq_per_step, dk_dv_per_step, dbias_per_step
+
+                def half_q_no_mask_compute():
+                    q_seqlen_per_step = helper.adjust_seqlen(q_seqlen, q_max_seqlen, idx) // 2
+                    kv_seqlen_per_step = helper.adjust_seqlen(kv_seqlen, kv_max_seqlen, idx)
+                    # doutput_part = lax.slice_in_dim(doutput, 0, q_max_seqlen // 2, axis=1)
+                    # output_part = lax.slice_in_dim(output, 0, q_max_seqlen // 2, axis=1)
+                    # softmax_aux_part = lax.slice_in_dim(softmax_aux, 0, q_max_seqlen // 2, axis=2)
+                    # q_part = lax.slice_in_dim(q, 0, q_max_seqlen // 2, axis=1)
+                    doutput_part = doutput[:, q_max_seqlen // 2 :, :, :]
+                    output_part = output[:, q_max_seqlen // 2 :, :, :]
+                    softmax_aux_part = softmax_aux[:, :, q_max_seqlen // 2 :, :]
+                    q_part = q[:, q_max_seqlen // 2 :, :, :]
+                    dq_per_step, dk_dv_per_step, _, dbias_per_step = FusedAttnBwdPrimitive.impl(
+                        q_part,
+                        kv,
+                        _not_used,
+                        bias,
+                        softmax_aux_part,
+                        rng_state,
+                        output_part,
+                        doutput_part,
+                        q_seqlen_per_step,
+                        kv_seqlen_per_step,
+                        q_seq_offsets,
+                        k_seq_offsets,
+                        config=helper.get_step_config(NVTE_Mask_Type.NVTE_NO_MASK),
+                    
) + dq_per_step = jnp.concat([jnp.zeros_like(dq_per_step), dq_per_step], axis=1) + return dq_per_step, dk_dv_per_step, dbias_per_step + + def no_mask_compute(): + q_seqlen_per_step = helper.adjust_seqlen(q_seqlen, q_max_seqlen, idx) + kv_seqlen_per_step = helper.adjust_seqlen(kv_seqlen, kv_max_seqlen, idx) + dq_per_step, dk_dv_per_step, _, dbias_per_step = FusedAttnBwdPrimitive.impl( + q, + kv, + _not_used, + bias, + softmax_aux, + rng_state, + output, + doutput, + q_seqlen_per_step, + kv_seqlen_per_step, + q_seq_offsets, + k_seq_offsets, + config=helper.get_step_config(NVTE_Mask_Type.NVTE_NO_MASK), + ) + return dq_per_step, dk_dv_per_step, dbias_per_step + + def skip_compute(): + return jnp.zeros_like(q), jnp.zeros_like(kv), jnp.zeros_like(bias) + + if config.attn_mask_type == NVTE_Mask_Type.NVTE_CAUSAL_MASK: + # This is for nested jax.lax.cond + def jax_cond_wrap(): + if config.context_parallel_load_balanced: + return lax.cond( + (idx <= cp_rank), half_kv_no_mask_compute, half_q_no_mask_compute + ) + return lax.cond((idx <= cp_rank), no_mask_compute, skip_compute) + + dq_per_step, dk_dv_per_step, dbias_per_step = lax.cond( + idx == 0, causal_mask_compute, jax_cond_wrap + ) + else: + q_seqlen_per_step = helper.adjust_seqlen(q_seqlen, q_max_seqlen, idx) + kv_seqlen_per_step = helper.adjust_seqlen(kv_seqlen, kv_max_seqlen, idx) + dq_per_step, dk_dv_per_step, _, dbias_per_step = FusedAttnBwdPrimitive.impl( + q, + kv, + _not_used, + bias, + softmax_aux, + rng_state, + output, + doutput, + q_seqlen_per_step, + kv_seqlen_per_step, + q_seq_offsets, + k_seq_offsets, + config=helper.get_step_config(config.attn_mask_type), + ) + + kv_next, dk_dv = jnp.unstack(kv_dk_dv) + lax.optimization_barrier((kv_next, dk_dv)) + return (kv_next, dq, dq_per_step, dk_dv, dk_dv_per_step, dbias, dbias_per_step) + + carry = ( + kv, + dq, + jnp.zeros_like(dq), + dk_dv, + jnp.zeros_like(dk_dv), + dbias, + jnp.zeros_like(dbias), + ) + if helper.use_scanloop: + carry = lax.fori_loop(0, cp_size, scan_kv_block, carry) + else: + for i in range(0, cp_size): + carry = scan_kv_block(i, carry) + (kv, dq, dq_per_step, dk_dv, dk_dv_per_step, dbias, dbias_per_step) = carry + + # Final accumulate + permute since we've pipelined the accumulation. 
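+            # The per-step gradients lag one iteration behind, so fold in the final step here.
+            # The accumulated dK/dV also ends up one hop before its owner, so one last ppermute
+            # delivers each shard back to the rank that holds the corresponding KV block.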
+            dq = dq + dq_per_step
+            dk_dv = dk_dv + dk_dv_per_step
+            dk_dv = helper.permute_kv(dk_dv, cp_perm)
+
+            global_dbias = dbias
+            if config.attn_bias_type is not NVTE_Bias_Type.NVTE_NO_BIAS:
+                dbias = dbias + dbias_per_step
+                global_dbias = all_reduce_sum_along_dp_fsdp(dbias, mesh)
+
+            dk, dv = helper.unstack_kv(dk_dv)
+            return dq, dk, dv, global_dbias
+
+        return mesh, ring_attn_bwd_impl, out_shardings, arg_shardings
+
+
+register_primitive(FusedRingAttnBwdPrimitive)
+
+
 def _maybe_context_parallel_axis(cp_axis: str):
     if not cp_axis:
         gmr = global_mesh_resource()
@@ -1437,6 +2026,7 @@ def fused_attn_fwd(
     is_training: bool,
     max_segments_per_seq: int,
     window_size: Optional[Tuple[int, int]] = None,
+    context_parallel_strategy: CPStrategy = CPStrategy.DEFAULT,
     context_parallel_causal_load_balanced: bool = False,
     context_parallel_axis: str = "",
 ) -> jnp.ndarray:
@@ -1519,7 +2109,14 @@ def fused_attn_fwd(
         cp_axis=_maybe_context_parallel_axis(context_parallel_axis),
     )

-    return FusedAttnCPWithAllGatherFwdPrimitive.outer_primitive.bind(
+    primitive = None
+    match context_parallel_strategy:
+        case CPStrategy.DEFAULT | CPStrategy.ALL_GATHER:
+            primitive = FusedAttnCPWithAllGatherFwdPrimitive.outer_primitive
+        case CPStrategy.RING:
+            primitive = FusedRingAttnFwdPrimitive.outer_primitive
+
+    return primitive.bind(
         *qkv_for_primitive,
         bias,
         q_seqlen,
@@ -1550,6 +2147,7 @@ def fused_attn_bwd(
     is_training: bool,
     max_segments_per_seq: int,
     window_size: Optional[Tuple[int, int]] = None,
+    context_parallel_strategy: CPStrategy = CPStrategy.DEFAULT,
     context_parallel_causal_load_balanced: bool = False,
     context_parallel_axis: str = "",
 ):
@@ -1636,7 +2234,14 @@ def fused_attn_bwd(
         cp_axis=_maybe_context_parallel_axis(context_parallel_axis),
     )

-    *qkv_grads, bias_grad = FusedAttnCPWithAllGatherBwdPrimitive.outer_primitive.bind(
+    primitive = None
+    match context_parallel_strategy:
+        case CPStrategy.DEFAULT | CPStrategy.ALL_GATHER:
+            primitive = FusedAttnCPWithAllGatherBwdPrimitive.outer_primitive
+        case CPStrategy.RING:
+            primitive = FusedRingAttnBwdPrimitive.outer_primitive
+
+    *qkv_grads, bias_grad = primitive.bind(
         *qkv_for_primitive,
         bias,
         softmax_aux,
diff --git a/transformer_engine/jax/cpp_extensions/misc.py b/transformer_engine/jax/cpp_extensions/misc.py
index d3df614ac9..686258eb2c 100644
--- a/transformer_engine/jax/cpp_extensions/misc.py
+++ b/transformer_engine/jax/cpp_extensions/misc.py
@@ -167,3 +167,20 @@ def is_ffi_enabled():
         is_enabled = int(os.getenv("NVTE_JAX_WITH_FFI", "1"))
         assert is_enabled in (0, 1), "Invalid NVTE_JAX_WITH_FFI value"
     return is_supported and is_enabled
+
+
+def get_xla_flag(flag: str, default=None):
+    """
+    Returns the value of a flag in the XLA_FLAGS environment variable if set, otherwise default.
+    """
+    xla_flags = []
+    if xla_flags_env := os.getenv("XLA_FLAGS"):
+        xla_flags.extend(xla_flags_env.split())
+    for flag_i in sorted(xla_flags):
+        if "=" in flag_i:
+            name, val = flag_i.split("=", 1)
+            if name == flag:
+                return val
+        elif flag_i == flag:
+            return True  # bool flag with no value
+    return default
diff --git a/transformer_engine/jax/sharding.py b/transformer_engine/jax/sharding.py
index a14a8384cf..f2da288be5 100644
--- a/transformer_engine/jax/sharding.py
+++ b/transformer_engine/jax/sharding.py
@@ -197,7 +197,7 @@ class MeshResource:
         The axis name in Mesh used to split the batch and weights along.
         If it is None, then full-sharded data parallelism is disabled.
     pp_resource : str, default = None
-        The axis name in Mesh used to split model layers. along.
+ The axis name in Mesh used to split model layers along. If it is None, then pipeline parallelism is disabled. cp_resource : str, default = None The axis name in Mesh used to split sequence (context) dimensions along