diff --git a/qa/L0_jax_distributed_unittest/test.sh b/qa/L0_jax_distributed_unittest/test.sh index 8483629a13..51512d0744 100644 --- a/qa/L0_jax_distributed_unittest/test.sh +++ b/qa/L0_jax_distributed_unittest/test.sh @@ -4,9 +4,6 @@ set -xe -# WAR(rewang) for the "Check failed: reduction_kind.has_value()" -export XLA_FLAGS="${XLA_FLAGS} --xla_gpu_enable_xla_runtime_executable=true" - : ${TE_PATH:=/opt/transformerengine} pytest -Wignore -v $TE_PATH/tests/jax/test_distributed_* diff --git a/qa/L0_jax_unittest/test.sh b/qa/L0_jax_unittest/test.sh index b96b8c4ed1..9f20769045 100644 --- a/qa/L0_jax_unittest/test.sh +++ b/qa/L0_jax_unittest/test.sh @@ -14,7 +14,5 @@ pytest -Wignore -v $TE_PATH/examples/jax/mnist # Make encoder tests to have run-to-run deterministic to have the stable CI results export XLA_FLAGS="${XLA_FLAGS} --xla_gpu_deterministic_ops" -# WAR(rewang) for the "Check failed: reduction_kind.has_value()" -export XLA_FLAGS="${XLA_FLAGS} --xla_gpu_enable_xla_runtime_executable=true" pytest -Wignore -v $TE_PATH/examples/jax/encoder --ignore=$TE_PATH/examples/jax/encoder/test_multiprocessing_encoder.py pytest -Wignore -v $TE_PATH/examples/jax/encoder/test_multiprocessing_encoder.py diff --git a/transformer_engine/jax/flax/transformer.py b/transformer_engine/jax/flax/transformer.py index 8f2203ae0b..66b0003a11 100644 --- a/transformer_engine/jax/flax/transformer.py +++ b/transformer_engine/jax/flax/transformer.py @@ -207,20 +207,32 @@ def core_attention(query: Array, key = key.astype(jnp.float32) h_q, h_kv = query.shape[-2], key.shape[-2] - assert (h_q % h_kv == 0) and (h_q >= h_kv) - group_size = h_q // h_kv - grouped_query = query.reshape((*query.shape[:2], h_kv, group_size, query.shape[-1])) + # The generated GQA kernels are slower than normal MHA kernels even when h_q == h_kv. + # Therefore, we have to maintain two code paths. + is_gqa = (h_q != h_kv) + + if is_gqa: + assert (h_q % h_kv == 0) and (h_q >= h_kv) + group_size = h_q // h_kv + grouped_query = query.reshape((*query.shape[:2], h_kv, group_size, query.shape[-1])) if transpose_batch_sequence: - attn_weights = jnp.einsum('qbhgd,kbhd->bhgqk', grouped_query, key) + if is_gqa: + attn_weights = jnp.einsum('qbhgd,kbhd->bhgqk', grouped_query, key) + else: + attn_weights = jnp.einsum('qbhd,kbhd->bhqk', query, key) else: - attn_weights = jnp.einsum('bqhgd,bkhd->bhgqk', grouped_query, key) + if is_gqa: + attn_weights = jnp.einsum('bqhgd,bkhd->bhgqk', grouped_query, key) + else: + attn_weights = jnp.einsum('bqhd,bkhd->bhqk', query, key) attn_weights = checkpoint_name(attn_weights, 'logits') - b, h, g, q, k = attn_weights_with_groups_shape = attn_weights.shape - attn_weights_without_groups_shape = (b, h * g, q, k) - attn_weights = attn_weights.reshape(attn_weights_without_groups_shape) + if is_gqa: + b, h, g, q, k = attn_weights_with_groups_shape = attn_weights.shape + attn_weights_without_groups_shape = (b, h * g, q, k) + attn_weights = attn_weights.reshape(attn_weights_without_groups_shape) attn_weights = _with_sharding_constraint(attn_weights, (BATCH_AXES, HEAD_AXES, SEQLEN_AXES, SEQLEN_AXES)) @@ -237,7 +249,8 @@ def core_attention(query: Array, attn_weights = Softmax(softmax_type=softmax_type, scale_factor=fused_scale_factor)(attn_weights, mask, bias).astype(dtype) - attn_weights = attn_weights.reshape(attn_weights_with_groups_shape) + if is_gqa: + attn_weights = attn_weights.reshape(attn_weights_with_groups_shape) if not deterministic and dropout_rate > 0.: keep_prob = 1.0 - dropout_rate @@ -248,9 +261,13 @@ def core_attention(query: Array, attn_weights = attn_weights * multiplier if transpose_batch_sequence: - return jnp.einsum('bhgqk,kbhd->qbhgd', attn_weights, value).reshape(query.shape) + if is_gqa: + return jnp.einsum('bhgqk,kbhd->qbhgd', attn_weights, value).reshape(query.shape) + return jnp.einsum('bhqk,kbhd->qbhd', attn_weights, value) - return jnp.einsum('bhgqk,bkhd->bqhgd', attn_weights, value).reshape(query.shape) + if is_gqa: + return jnp.einsum('bhgqk,bkhd->bqhgd', attn_weights, value).reshape(query.shape) + return jnp.einsum('bhqk,bkhd->bqhd', attn_weights, value) dynamic_vector_slice_in_dim = vmap(lax.dynamic_slice_in_dim, in_axes=(None, 0, None, None))