Add float8_e4m3 and float8_e3m4 types support #23585

Open · wants to merge 1 commit into base: main
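For context, here is a minimal usage sketch (not part of this PR): it assumes the change is merged and that ml_dtypes >= 0.5.0 is installed, since with an older ml_dtypes the new names are simply not defined.

```python
# Minimal sketch, assuming this PR is merged and ml_dtypes >= 0.5.0 is installed.
import jax.numpy as jnp

# float8_e4m3: 1 sign, 4 exponent, 3 mantissa bits.
# float8_e3m4: 1 sign, 3 exponent, 4 mantissa bits (more precision, less range).
x = jnp.ones((4,), dtype=jnp.float8_e4m3)
y = x.astype(jnp.float8_e3m4)
print(x.dtype, y.dtype)  # float8_e4m3 float8_e3m4
```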
19 changes: 19 additions & 0 deletions jax/_src/dtypes.py
@@ -90,12 +90,17 @@ def type(self) -> type: ...


# fp8 support
# TODO: remove Optional when minimum ml_dtypes version >= 0.5.0
float8_e3m4: type[np.generic] | None = None
float8_e4m3: type[np.generic] | None = None
float8_e4m3b11fnuz: type[np.generic] = ml_dtypes.float8_e4m3b11fnuz
float8_e4m3fn: type[np.generic] = ml_dtypes.float8_e4m3fn
float8_e4m3fnuz: type[np.generic] = ml_dtypes.float8_e4m3fnuz
float8_e5m2: type[np.generic] = ml_dtypes.float8_e5m2
float8_e5m2fnuz: type[np.generic] = ml_dtypes.float8_e5m2fnuz

_float8_e3m4_dtype: np.dtype | None = None
_float8_e4m3_dtype: np.dtype | None = None
_float8_e4m3b11fnuz_dtype: np.dtype = np.dtype(float8_e4m3b11fnuz)
_float8_e4m3fn_dtype: np.dtype = np.dtype(float8_e4m3fn)
_float8_e4m3fnuz_dtype: np.dtype = np.dtype(float8_e4m3fnuz)
@@ -137,6 +142,20 @@ def supports_inf(dtype: DTypeLike) -> bool:
_float8_e5m2fnuz_dtype,
]

# TODO: remove the if statements below when minimum ml_dtypes version >= 0.5.0
if hasattr(ml_dtypes, "float8_e4m3"):
float8_e4m3 = ml_dtypes.float8_e4m3
_float8_e4m3_dtype = np.dtype(float8_e4m3)
_custom_float_scalar_types.insert(0, float8_e4m3) # type: ignore[arg-type]
_custom_float_dtypes.insert(0, _float8_e4m3_dtype) # type: ignore[arg-type]
_float8_dtypes.insert(0, _float8_e4m3_dtype) # type: ignore[arg-type]
if hasattr(ml_dtypes, "float8_e3m4"):
float8_e3m4 = ml_dtypes.float8_e3m4
_float8_e3m4_dtype = np.dtype(float8_e3m4)
_custom_float_scalar_types.insert(0, float8_e3m4) # type: ignore[arg-type]
_custom_float_dtypes.insert(0, _float8_e3m4_dtype) # type: ignore[arg-type]
_float8_dtypes.insert(0, _float8_e3m4_dtype) # type: ignore[arg-type]

# 2-bit integer support
int2: type[np.generic] | None = None
uint2: type[np.generic] | None = None
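Since the two new dtypes are module-level Optionals (None on older ml_dtypes), every consumer in the rest of this diff guards on None before using them. A rough illustration of the pattern, using only names that already exist in jax._src.dtypes:

```python
# Illustration of the None-guard pattern used throughout this PR (a sketch,
# not new API): build an fp8 list and extend it only when the optional
# dtypes were picked up from ml_dtypes at import time.
import numpy as np
from jax._src import dtypes

fp8_dtypes = [np.dtype(dtypes.float8_e4m3fn), np.dtype(dtypes.float8_e5m2)]
if dtypes.float8_e3m4 is not None:   # requires ml_dtypes >= 0.5.0
  fp8_dtypes.append(np.dtype(dtypes.float8_e3m4))
if dtypes.float8_e4m3 is not None:
  fp8_dtypes.append(np.dtype(dtypes.float8_e4m3))
```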
2 changes: 2 additions & 0 deletions jax/_src/export/serialization.fbs
@@ -64,6 +64,8 @@ enum DType: byte {
i4 = 15,
ui4 = 16,

f8_e3m4 = 24,
f8_e4m3 = 23,
f8_e4m3b11fnuz = 17,
f8_e4m3fn = 18,
f8_e4m3fnuz = 19,
4 changes: 4 additions & 0 deletions jax/_src/export/serialization.py
@@ -324,6 +324,10 @@ def _deserialize_pytreedef_to_pytree(p: ser_flatbuf.PyTreeDef):
dtypes._float8_e5m2fnuz_dtype: ser_flatbuf.DType.f8_e5m2fnuz,
}

if dtypes._float8_e3m4_dtype is not None:
_dtype_to_dtype_kind[dtypes._float8_e3m4_dtype] = ser_flatbuf.DType.f8_e3m4
if dtypes._float8_e4m3_dtype is not None:
_dtype_to_dtype_kind[dtypes._float8_e4m3_dtype] = ser_flatbuf.DType.f8_e4m3

_dtype_kind_to_dtype = {
kind: dtype for dtype, kind in _dtype_to_dtype_kind.items()
2 changes: 2 additions & 0 deletions jax/_src/export/serialization_generated.py
@@ -52,6 +52,8 @@ class DType:
bf16 = 14
i4 = 15
ui4 = 16
f8_e3m4 = 24
f8_e4m3 = 23
f8_e4m3b11fnuz = 17
f8_e4m3fn = 18
f8_e4m3fnuz = 19
12 changes: 6 additions & 6 deletions jax/_src/interpreters/mlir.py
@@ -186,13 +186,13 @@ def _is_ir_values(x: IrValues) -> bool:

if dtypes.int2 is not None:
assert dtypes.uint2 is not None
_dtype_to_ir_type[np.dtype(dtypes.int2)] = partial(
ir.IntegerType.get_signless, 2
)
_dtype_to_ir_type[np.dtype(dtypes.uint2)] = partial(
ir.IntegerType.get_unsigned, 2
)
_dtype_to_ir_type[np.dtype(dtypes.int2)] = partial(ir.IntegerType.get_signless, 2)
_dtype_to_ir_type[np.dtype(dtypes.uint2)] = partial(ir.IntegerType.get_unsigned, 2)

if dtypes.float8_e3m4 is not None:
_dtype_to_ir_type[np.dtype(dtypes.float8_e3m4)] = ir.Float8E3M4Type.get
if dtypes.float8_e4m3 is not None:
_dtype_to_ir_type[np.dtype(dtypes.float8_e4m3)] = ir.Float8E4M3Type.get

def dtype_to_ir_type(dtype: core.bint | np.dtype | np.generic) -> ir.Type:
if isinstance(dtype, core.bint):
29 changes: 22 additions & 7 deletions jax/_src/lax/lax.py
@@ -831,11 +831,17 @@ def _convert_to_hlo_attr(self, lhs_dtype: DTypeLike,

if self in (DotAlgorithm.Preset.ANY_F8_ANY_F8_F32,
DotAlgorithm.Preset.ANY_F8_ANY_F8_F32_FAST_ACCUM):
fp8_dtypes = (np.dtype(dtypes.float8_e4m3b11fnuz),
np.dtype(dtypes.float8_e4m3fn),
np.dtype(dtypes.float8_e4m3fnuz),
np.dtype(dtypes.float8_e5m2),
np.dtype(dtypes.float8_e5m2fnuz))
fp8_dtypes = [
np.dtype(dtypes.float8_e4m3b11fnuz),
np.dtype(dtypes.float8_e4m3fn),
np.dtype(dtypes.float8_e4m3fnuz),
np.dtype(dtypes.float8_e5m2),
np.dtype(dtypes.float8_e5m2fnuz),
]
if dtypes.float8_e3m4 is not None:
fp8_dtypes += [np.dtype(dtypes.float8_e3m4)]
if dtypes.float8_e4m3 is not None:
fp8_dtypes += [np.dtype(dtypes.float8_e4m3)]
if lhs_dtype not in fp8_dtypes or rhs_dtype not in fp8_dtypes:
raise ValueError(
f"The dot algorithm '{self}' requires both inputs to have float8 "
@@ -3386,9 +3392,18 @@ def _dot_general_lower(ctx, lhs, rhs, *, dimension_numbers,
transpose_algorithm: DotTransposeAlgorithm | None = None,
platform: str = "default"):
del transpose_algorithm # unused

def _is_fp8_mixed_precision_matmul(_lhs_dtypes, _rhs_dtypes):
fp8_dtypes = (dtypes.float8_e4m3fn, dtypes.float8_e5m2,
dtypes.float8_e5m2fnuz, dtypes.float8_e4m3fnuz)
fp8_dtypes = [
dtypes.float8_e4m3fn,
dtypes.float8_e5m2,
dtypes.float8_e5m2fnuz,
dtypes.float8_e4m3fnuz,
]
if dtypes.float8_e3m4 is not None:
fp8_dtypes += [dtypes.float8_e3m4]
if dtypes.float8_e4m3 is not None:
fp8_dtypes += [dtypes.float8_e4m3]
return _lhs_dtypes in fp8_dtypes and _rhs_dtypes in fp8_dtypes
del preferred_element_type # Implied by the output aval
lhs_aval, rhs_aval = ctx.avals_in
4 changes: 4 additions & 0 deletions jax/_src/numpy/lax_numpy.py
@@ -211,6 +211,10 @@ def _make_scalar_type(np_scalar_type: type) -> _ScalarMeta:
int16 = _make_scalar_type(np.int16)
int32 = _make_scalar_type(np.int32)
int64 = _make_scalar_type(np.int64)
if dtypes.float8_e3m4 is not None:
float8_e3m4 = _make_scalar_type(dtypes.float8_e3m4)
if dtypes.float8_e4m3 is not None:
float8_e4m3 = _make_scalar_type(dtypes.float8_e4m3)
float8_e4m3fn = _make_scalar_type(dtypes.float8_e4m3fn)
float8_e4m3fnuz = _make_scalar_type(dtypes.float8_e4m3fnuz)
float8_e5m2 = _make_scalar_type(dtypes.float8_e5m2)
14 changes: 14 additions & 0 deletions jax/_src/public_test_util.py
@@ -90,6 +90,14 @@ def default_tolerance():
np.dtype(np.complex128): 1e-5,
}

# TODO: make this unconditional when ml_dtypes>=0.5.0 is required
if _dtypes.float8_e3m4 is not None:
_default_tolerance[np.dtype(_dtypes.float8_e3m4)] = 1e-1
default_gradient_tolerance[np.dtype(_dtypes.float8_e3m4)] = 1e-1
if _dtypes.float8_e4m3 is not None:
_default_tolerance[np.dtype(_dtypes.float8_e4m3)] = 1e-1
default_gradient_tolerance[np.dtype(_dtypes.float8_e4m3)] = 1e-1

def is_python_scalar(val):
return not isinstance(val, np.generic) and isinstance(val, (bool, int, float, complex))

@@ -106,6 +114,12 @@ def _assert_numpy_allclose(a, b, atol=None, rtol=None, err_msg=''):
_dtypes.float8_e5m2fnuz,
_dtypes.bfloat16,
]

if _dtypes.float8_e4m3 is not None:
custom_float_dtypes.insert(0, _dtypes.float8_e4m3) # type: ignore[arg-type]
if _dtypes.float8_e3m4 is not None:
custom_float_dtypes.insert(0, _dtypes.float8_e3m4) # type: ignore[arg-type]

def maybe_upcast(x):
if x.dtype in custom_float_dtypes:
return x.astype(np.float32)
17 changes: 13 additions & 4 deletions jax/_src/test_util.py
@@ -1433,10 +1433,19 @@ def supported(self, dtypes):

@_cached_property
def custom_floats(self):
return [np.dtype(t) for t in [
_dtypes.bfloat16, _dtypes.float8_e4m3b11fnuz,
_dtypes.float8_e4m3fn, _dtypes.float8_e4m3fnuz,
_dtypes.float8_e5m2, _dtypes.float8_e5m2fnuz]]
float_dtypes = [
_dtypes.bfloat16,
_dtypes.float8_e4m3b11fnuz,
_dtypes.float8_e4m3fn,
_dtypes.float8_e4m3fnuz,
_dtypes.float8_e5m2,
_dtypes.float8_e5m2fnuz,
]
if _dtypes.float8_e3m4 is not None:
float_dtypes += [_dtypes.float8_e3m4]
if _dtypes.float8_e4m3 is not None:
float_dtypes += [_dtypes.float8_e4m3]
return [np.dtype(t) for t in float_dtypes]

@_cached_property
def floating(self):
9 changes: 9 additions & 0 deletions jax/numpy/__init__.py
@@ -273,6 +273,15 @@
except ImportError:
pass

# TODO: Remove the try-except once we upgrade to ml_dtypes 0.5.0
try:
from jax._src.numpy.lax_numpy import (
float8_e3m4 as float8_e3m4,
float8_e4m3 as float8_e4m3,
)
except ImportError:
pass

from jax._src.numpy.array_api_metadata import (
__array_api_version__ as __array_api_version__,
__array_namespace_info__ as __array_namespace_info__,
4 changes: 4 additions & 0 deletions tests/dtypes_test.py
@@ -64,6 +64,10 @@
fp8_dtypes = [np.dtype(dtypes.float8_e4m3b11fnuz), np.dtype(dtypes.float8_e4m3fn),
np.dtype(dtypes.float8_e4m3fnuz), np.dtype(dtypes.float8_e5m2),
np.dtype(dtypes.float8_e5m2fnuz)]
if dtypes.float8_e3m4 is not None:
fp8_dtypes += [np.dtype(dtypes.float8_e3m4)]
if dtypes.float8_e4m3 is not None:
fp8_dtypes += [np.dtype(dtypes.float8_e4m3)]
float_dtypes += fp8_dtypes
custom_float_dtypes += fp8_dtypes

Expand Down