Skip to content

Commit

Permalink
[TE/JAX] XLA FFI calls for three cast transpose functions (#1310)
Browse files Browse the repository at this point in the history
* FFI for some transpose & activation functions

Signed-off-by: Hua Huang <huah@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Remove comments in transformer_engine/jax/csrc/extensions/activation.cpp

Co-authored-by: Phuong Nguyen <36155692+phu0ngng@users.noreply.github.com>
Signed-off-by: Hua Huang <huangh1994@outlook.com>

---------

Signed-off-by: Hua Huang <huah@nvidia.com>
Signed-off-by: Hua Huang <huangh1994@outlook.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Phuong Nguyen <36155692+phu0ngng@users.noreply.github.com>
  • Loading branch information
3 people authored Nov 6, 2024
1 parent d4aa299 commit 4d65073
Show file tree
Hide file tree
Showing 7 changed files with 505 additions and 229 deletions.
30 changes: 29 additions & 1 deletion tests/jax/test_custom_call_compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from transformer_engine.jax.cpp_extensions.transpose import (
_jax_transpose,
_jax_cast_transpose,
_jax_dbias_cast_transpose,
)
from transformer_engine.jax.cpp_extensions.quantization import _jax_cast_fp8
from transformer_engine.jax import cpp_extensions as tex
Expand Down Expand Up @@ -504,7 +505,6 @@ def _prim_func_bwd(ctx, g):
scale_inv,
FP8Helper.BWD_DTYPE,
-1,
-2,
self.activation_type,
)
)
Expand Down Expand Up @@ -812,6 +812,34 @@ def test_cast_transpose(self, in_dtype, input_shape, transpose_axis, out_dtype):
assert_tree_like_allclose(jax_output, ffi_output)
assert_tree_like_allclose(noffi_output, ffi_output)

@pytest.mark.parametrize(
"out_dtype",
[
pytest.param(jnp.float8_e4m3fn, id="output_float8_e4m3fn"),
pytest.param(jnp.float8_e5m2, id="output_float8_e5m2"),
],
)
def test_dbias_cast_transpose(self, in_dtype, input_shape, transpose_axis, out_dtype):
amax = jnp.zeros(1, jnp.float32)
scale = jnp.ones(1, jnp.float32)
scale_inv = jnp.ones(1, jnp.float32)
key = jax.random.PRNGKey(0)
input = jax.random.uniform(key, input_shape, in_dtype)
static_axis_boundary = -1
jax_output = _jax_dbias_cast_transpose(
input, amax, scale, out_dtype, static_axis_boundary, transpose_axis
)
os.environ["NVTE_JAX_WITH_FFI"] = "0"
noffi_output = tex.dbias_cast_transpose(
input, amax, scale, scale_inv, out_dtype, static_axis_boundary, transpose_axis
)
os.environ["NVTE_JAX_WITH_FFI"] = "1"
ffi_output = tex.dbias_cast_transpose(
input, amax, scale, scale_inv, out_dtype, static_axis_boundary, transpose_axis
)
assert_tree_like_allclose(jax_output, ffi_output)
assert_tree_like_allclose(noffi_output, ffi_output)


@pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest.mark.parametrize(
Expand Down
Loading

0 comments on commit 4d65073

Please sign in to comment.