Skip to content

Commit

Permalink
Add JAX→Triton fp8 type mapping.
Browse files Browse the repository at this point in the history
Both Triton and JAX support them now.

PiperOrigin-RevId: 673774496
  • Loading branch information
mooskagh authored and The jax_triton Authors committed Sep 12, 2024
1 parent 94fba56 commit 4d1279e
Showing 1 changed file with 4 additions and 1 deletion.
5 changes: 4 additions & 1 deletion jax_triton/triton_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,10 @@
jnp.dtype("float64"): "fp64",
jnp.dtype("float32"): "fp32",
jnp.dtype("float16"): "fp16",
# Triton has 'fp8' as well which Jax doesn't support yet.
jnp.dtype("float8_e4m3fn"): "fp8e4nv",
jnp.dtype("float8_e5m2"): "fp8e5",
jnp.dtype("float8_e4m3fnuz"): "fp8e4b8",
jnp.dtype("float8_e5m2fnuz"): "fp8e5b16",
jnp.dtype("int64"): "i64",
jnp.dtype("int32"): "i32",
jnp.dtype("int16"): "i16",
Expand Down

0 comments on commit 4d1279e

Please sign in to comment.