[JAX] Fix unfused GQA performance (#643)
* Fix unfused GQA perf

Signed-off-by: Reese Wang <rewang@nvidia.com>

* Remove WAR for Check failed: reduction_kind.has_value()

Signed-off-by: Reese Wang <rewang@nvidia.com>

---------

Signed-off-by: Reese Wang <rewang@nvidia.com>
zlsh80826 authored Feb 1, 2024
1 parent e2803b1 commit 29b0c9c
Showing 3 changed files with 28 additions and 16 deletions.
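
The substance of the change is in core_attention (transformer_engine/jax/flax/transformer.py): the unfused path now computes plain MHA einsums whenever h_q == h_kv and only uses the grouped-query einsums for true GQA. A minimal sketch of why the two paths agree on the h_q == h_kv case (illustrative only, not part of the commit; shapes and names are arbitrary):

# Hypothetical check: with h_q == h_kv the grouped-query einsum used by the
# unfused GQA path has group_size == 1 and collapses to the plain MHA einsum,
# which is the case the patch now routes to the faster MHA kernels.
import jax.numpy as jnp
from jax import random

b, q_len, k_len, h, d = 2, 5, 7, 4, 16            # batch, seq lens, heads, head dim
q_rng, k_rng = random.split(random.PRNGKey(0))
query = random.normal(q_rng, (b, q_len, h, d))    # h_q == h_kv == h
key = random.normal(k_rng, (b, k_len, h, d))

# GQA path (transpose_batch_sequence=False layout): group queries per KV head.
grouped_query = query.reshape((*query.shape[:2], h, 1, d))    # group_size == 1
gqa_weights = jnp.einsum('bqhgd,bkhd->bhgqk', grouped_query, key)

# MHA path: no grouping.
mha_weights = jnp.einsum('bqhd,bkhd->bhqk', query, key)

assert jnp.allclose(gqa_weights.reshape(b, h, q_len, k_len), mha_weights)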
3 changes: 0 additions & 3 deletions qa/L0_jax_distributed_unittest/test.sh
@@ -4,9 +4,6 @@
 
 set -xe
 
-# WAR(rewang) for the "Check failed: reduction_kind.has_value()"
-export XLA_FLAGS="${XLA_FLAGS} --xla_gpu_enable_xla_runtime_executable=true"
-
 : ${TE_PATH:=/opt/transformerengine}
 pytest -Wignore -v $TE_PATH/tests/jax/test_distributed_*
 

2 changes: 0 additions & 2 deletions qa/L0_jax_unittest/test.sh
@@ -14,7 +14,5 @@ pytest -Wignore -v $TE_PATH/examples/jax/mnist
 
 # Make encoder tests to have run-to-run deterministic to have the stable CI results
 export XLA_FLAGS="${XLA_FLAGS} --xla_gpu_deterministic_ops"
-# WAR(rewang) for the "Check failed: reduction_kind.has_value()"
-export XLA_FLAGS="${XLA_FLAGS} --xla_gpu_enable_xla_runtime_executable=true"
 pytest -Wignore -v $TE_PATH/examples/jax/encoder --ignore=$TE_PATH/examples/jax/encoder/test_multiprocessing_encoder.py
 pytest -Wignore -v $TE_PATH/examples/jax/encoder/test_multiprocessing_encoder.py
39 changes: 28 additions & 11 deletions transformer_engine/jax/flax/transformer.py
@@ -207,20 +207,32 @@ def core_attention(query: Array,
     key = key.astype(jnp.float32)
 
     h_q, h_kv = query.shape[-2], key.shape[-2]
-    assert (h_q % h_kv == 0) and (h_q >= h_kv)
-    group_size = h_q // h_kv
-    grouped_query = query.reshape((*query.shape[:2], h_kv, group_size, query.shape[-1]))
+    # The generated GQA kernels are slower than normal MHA kernels even when h_q == h_kv.
+    # Therefore, we have to maintain two code paths.
+    is_gqa = (h_q != h_kv)
+
+    if is_gqa:
+        assert (h_q % h_kv == 0) and (h_q >= h_kv)
+        group_size = h_q // h_kv
+        grouped_query = query.reshape((*query.shape[:2], h_kv, group_size, query.shape[-1]))
 
     if transpose_batch_sequence:
-        attn_weights = jnp.einsum('qbhgd,kbhd->bhgqk', grouped_query, key)
+        if is_gqa:
+            attn_weights = jnp.einsum('qbhgd,kbhd->bhgqk', grouped_query, key)
+        else:
+            attn_weights = jnp.einsum('qbhd,kbhd->bhqk', query, key)
     else:
-        attn_weights = jnp.einsum('bqhgd,bkhd->bhgqk', grouped_query, key)
+        if is_gqa:
+            attn_weights = jnp.einsum('bqhgd,bkhd->bhgqk', grouped_query, key)
+        else:
+            attn_weights = jnp.einsum('bqhd,bkhd->bhqk', query, key)
 
     attn_weights = checkpoint_name(attn_weights, 'logits')
 
-    b, h, g, q, k = attn_weights_with_groups_shape = attn_weights.shape
-    attn_weights_without_groups_shape = (b, h * g, q, k)
-    attn_weights = attn_weights.reshape(attn_weights_without_groups_shape)
+    if is_gqa:
+        b, h, g, q, k = attn_weights_with_groups_shape = attn_weights.shape
+        attn_weights_without_groups_shape = (b, h * g, q, k)
+        attn_weights = attn_weights.reshape(attn_weights_without_groups_shape)
 
     attn_weights = _with_sharding_constraint(attn_weights,
                                              (BATCH_AXES, HEAD_AXES, SEQLEN_AXES, SEQLEN_AXES))
@@ -237,7 +249,8 @@ def core_attention(query: Array,
     attn_weights = Softmax(softmax_type=softmax_type,
                            scale_factor=fused_scale_factor)(attn_weights, mask, bias).astype(dtype)
 
-    attn_weights = attn_weights.reshape(attn_weights_with_groups_shape)
+    if is_gqa:
+        attn_weights = attn_weights.reshape(attn_weights_with_groups_shape)
 
     if not deterministic and dropout_rate > 0.:
         keep_prob = 1.0 - dropout_rate
@@ -248,9 +261,13 @@ def core_attention(query: Array,
         attn_weights = attn_weights * multiplier
 
     if transpose_batch_sequence:
-        return jnp.einsum('bhgqk,kbhd->qbhgd', attn_weights, value).reshape(query.shape)
+        if is_gqa:
+            return jnp.einsum('bhgqk,kbhd->qbhgd', attn_weights, value).reshape(query.shape)
+        return jnp.einsum('bhqk,kbhd->qbhd', attn_weights, value)
 
-    return jnp.einsum('bhgqk,bkhd->bqhgd', attn_weights, value).reshape(query.shape)
+    if is_gqa:
+        return jnp.einsum('bhgqk,bkhd->bqhgd', attn_weights, value).reshape(query.shape)
+    return jnp.einsum('bhqk,bkhd->bqhd', attn_weights, value)
 
 
 dynamic_vector_slice_in_dim = vmap(lax.dynamic_slice_in_dim, in_axes=(None, 0, None, None))
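
For reference, a hedged end-to-end sketch (not from the commit; names, shapes, and the plain softmax are illustrative stand-ins for the TE Softmax module) showing that the grouped-query einsums kept for the is_gqa path match the equivalent formulation that repeats each KV head and runs the plain MHA einsums:

# Hypothetical end-to-end check of the grouped (unfused GQA) einsums against an
# explicit repeat-KV reference; mirrors the non-transposed branch of the patch.
import jax
import jax.numpy as jnp
from jax import random

b, q_len, k_len, h_q, h_kv, d = 2, 5, 7, 8, 2, 16
g = h_q // h_kv                                    # queries per KV head
rngs = random.split(random.PRNGKey(0), 3)
query = random.normal(rngs[0], (b, q_len, h_q, d))
key = random.normal(rngs[1], (b, k_len, h_kv, d))
value = random.normal(rngs[2], (b, k_len, h_kv, d))

# Grouped path: einsum over groups, softmax over flattened heads, einsum with values.
grouped_query = query.reshape((b, q_len, h_kv, g, d))
w = jnp.einsum('bqhgd,bkhd->bhgqk', grouped_query, key)
w = jax.nn.softmax(w.reshape(b, h_kv * g, q_len, k_len), axis=-1)
w = w.reshape(b, h_kv, g, q_len, k_len)
out_gqa = jnp.einsum('bhgqk,bkhd->bqhgd', w, value).reshape(query.shape)

# Reference: repeat each KV head g times, then use the plain MHA einsums.
key_rep = jnp.repeat(key, g, axis=2)
value_rep = jnp.repeat(value, g, axis=2)
w_ref = jax.nn.softmax(jnp.einsum('bqhd,bkhd->bhqk', query, key_rep), axis=-1)
out_mha = jnp.einsum('bhqk,bkhd->bqhd', w_ref, value_rep)

assert jnp.allclose(out_gqa, out_mha, atol=1e-5)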
