From e82bfa86385ac245a48c33c7669d8ed987e52b03 Mon Sep 17 00:00:00 2001 From: Alexander Pivovarov Date: Wed, 11 Sep 2024 23:11:57 +0000 Subject: [PATCH] Add float8_e4m3 and float8_e3m4 types support --- jax/_src/dtypes.py | 22 +++++++++++++++- jax/_src/export/serialization.fbs | 2 ++ jax/_src/export/serialization.py | 4 +++ jax/_src/export/serialization_generated.py | 2 ++ jax/_src/interpreters/mlir.py | 12 ++++----- jax/_src/lax/lax.py | 29 ++++++++++++++++------ jax/_src/numpy/lax_numpy.py | 4 +++ jax/_src/public_test_util.py | 14 +++++++++++ jax/_src/test_util.py | 17 ++++++++++--- jax/numpy/__init__.py | 9 +++++++ tests/dtypes_test.py | 4 +++ 11 files changed, 101 insertions(+), 18 deletions(-) diff --git a/jax/_src/dtypes.py b/jax/_src/dtypes.py index 82be38d1cb57..074b0a70ab91 100644 --- a/jax/_src/dtypes.py +++ b/jax/_src/dtypes.py @@ -90,12 +90,17 @@ def type(self) -> type: ... # fp8 support +# TODO: remove Optional when minimum ml_dtypes version >= 0.5.0 +float8_e3m4: type[np.generic] | None = None +float8_e4m3: type[np.generic] | None = None float8_e4m3b11fnuz: type[np.generic] = ml_dtypes.float8_e4m3b11fnuz float8_e4m3fn: type[np.generic] = ml_dtypes.float8_e4m3fn float8_e4m3fnuz: type[np.generic] = ml_dtypes.float8_e4m3fnuz float8_e5m2: type[np.generic] = ml_dtypes.float8_e5m2 float8_e5m2fnuz: type[np.generic] = ml_dtypes.float8_e5m2fnuz +_float8_e3m4_dtype: np.dtype | None = None +_float8_e4m3_dtype: np.dtype | None = None _float8_e4m3b11fnuz_dtype: np.dtype = np.dtype(float8_e4m3b11fnuz) _float8_e4m3fn_dtype: np.dtype = np.dtype(float8_e4m3fn) _float8_e4m3fnuz_dtype: np.dtype = np.dtype(float8_e4m3fnuz) @@ -105,7 +110,8 @@ def type(self) -> type: ... def supports_inf(dtype: DTypeLike) -> bool: """Return true if the dtype supports infinity, else return False.""" typ = np.dtype(dtype).type - if typ in {float8_e4m3b11fnuz, float8_e4m3fn, float8_e4m3fnuz, float8_e5m2fnuz}: + if typ in {float8_e4m3b11fnuz, float8_e4m3fn, + float8_e4m3fnuz, float8_e5m2fnuz}: return False return issubdtype(dtype, np.inexact) @@ -137,6 +143,20 @@ def supports_inf(dtype: DTypeLike) -> bool: _float8_e5m2fnuz_dtype, ] +# TODO: remove the if statements below when minimum ml_dtypes version >= 0.5.0 +if hasattr(ml_dtypes, "float8_e4m3"): + float8_e4m3 = ml_dtypes.float8_e4m3 + _float8_e4m3_dtype = np.dtype(float8_e4m3) + _custom_float_scalar_types.insert(0, float8_e4m3) # type: ignore[arg-type] + _custom_float_dtypes.insert(0, _float8_e4m3_dtype) # type: ignore[arg-type] + _float8_dtypes.insert(0, _float8_e4m3_dtype) # type: ignore[arg-type] +if hasattr(ml_dtypes, "float8_e3m4"): + float8_e3m4 = ml_dtypes.float8_e3m4 + _float8_e3m4_dtype = np.dtype(float8_e3m4) + _custom_float_scalar_types.insert(0, float8_e3m4) # type: ignore[arg-type] + _custom_float_dtypes.insert(0, _float8_e3m4_dtype) # type: ignore[arg-type] + _float8_dtypes.insert(0, _float8_e3m4_dtype) # type: ignore[arg-type] + # 2-bit integer support int2: type[np.generic] | None = None uint2: type[np.generic] | None = None diff --git a/jax/_src/export/serialization.fbs b/jax/_src/export/serialization.fbs index 758950adaa8e..59e169dc6fb6 100644 --- a/jax/_src/export/serialization.fbs +++ b/jax/_src/export/serialization.fbs @@ -64,6 +64,8 @@ enum DType: byte { i4 = 15, ui4 = 16, + f8_e3m4 = 24, + f8_e4m3 = 23, f8_e4m3b11fnuz = 17, f8_e4m3fn = 18, f8_e4m3fnuz = 19, diff --git a/jax/_src/export/serialization.py b/jax/_src/export/serialization.py index a47b095e4450..e283e0d57528 100644 --- a/jax/_src/export/serialization.py +++ b/jax/_src/export/serialization.py @@ -324,6 +324,10 @@ def _deserialize_pytreedef_to_pytree(p: ser_flatbuf.PyTreeDef): dtypes._float8_e5m2fnuz_dtype: ser_flatbuf.DType.f8_e5m2fnuz, } +if dtypes._float8_e3m4_dtype is not None: + _dtype_to_dtype_kind[dtypes._float8_e3m4_dtype] = ser_flatbuf.DType.f8_e3m4 +if dtypes._float8_e4m3_dtype is not None: + _dtype_to_dtype_kind[dtypes._float8_e4m3_dtype] = ser_flatbuf.DType.f8_e4m3 _dtype_kind_to_dtype = { kind: dtype for dtype, kind in _dtype_to_dtype_kind.items() diff --git a/jax/_src/export/serialization_generated.py b/jax/_src/export/serialization_generated.py index a872d03a9fdd..583b41814963 100644 --- a/jax/_src/export/serialization_generated.py +++ b/jax/_src/export/serialization_generated.py @@ -52,6 +52,8 @@ class DType: bf16 = 14 i4 = 15 ui4 = 16 + f8_e3m4 = 24 + f8_e4m3 = 23 f8_e4m3b11fnuz = 17 f8_e4m3fn = 18 f8_e4m3fnuz = 19 diff --git a/jax/_src/interpreters/mlir.py b/jax/_src/interpreters/mlir.py index af773365b12d..9b149e6af951 100644 --- a/jax/_src/interpreters/mlir.py +++ b/jax/_src/interpreters/mlir.py @@ -186,13 +186,13 @@ def _is_ir_values(x: IrValues) -> bool: if dtypes.int2 is not None: assert dtypes.uint2 is not None - _dtype_to_ir_type[np.dtype(dtypes.int2)] = partial( - ir.IntegerType.get_signless, 2 - ) - _dtype_to_ir_type[np.dtype(dtypes.uint2)] = partial( - ir.IntegerType.get_unsigned, 2 - ) + _dtype_to_ir_type[np.dtype(dtypes.int2)] = partial(ir.IntegerType.get_signless, 2) + _dtype_to_ir_type[np.dtype(dtypes.uint2)] = partial(ir.IntegerType.get_unsigned, 2) +if dtypes.float8_e3m4 is not None: + _dtype_to_ir_type[np.dtype(dtypes.float8_e3m4)] = ir.Float8E3M4Type.get +if dtypes.float8_e4m3 is not None: + _dtype_to_ir_type[np.dtype(dtypes.float8_e4m3)] = ir.Float8E4M3Type.get def dtype_to_ir_type(dtype: core.bint | np.dtype | np.generic) -> ir.Type: if isinstance(dtype, core.bint): diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index f51f0436b7a9..4a9c34be1c1c 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -831,11 +831,17 @@ def _convert_to_hlo_attr(self, lhs_dtype: DTypeLike, if self in (DotAlgorithm.Preset.ANY_F8_ANY_F8_F32, DotAlgorithm.Preset.ANY_F8_ANY_F8_F32_FAST_ACCUM): - fp8_dtypes = (np.dtype(dtypes.float8_e4m3b11fnuz), - np.dtype(dtypes.float8_e4m3fn), - np.dtype(dtypes.float8_e4m3fnuz), - np.dtype(dtypes.float8_e5m2), - np.dtype(dtypes.float8_e5m2fnuz)) + fp8_dtypes = [ + np.dtype(dtypes.float8_e4m3b11fnuz), + np.dtype(dtypes.float8_e4m3fn), + np.dtype(dtypes.float8_e4m3fnuz), + np.dtype(dtypes.float8_e5m2), + np.dtype(dtypes.float8_e5m2fnuz), + ] + if dtypes.float8_e3m4 is not None: + fp8_dtypes += [np.dtype(dtypes.float8_e3m4)] + if dtypes.float8_e4m3 is not None: + fp8_dtypes += [np.dtype(dtypes.float8_e4m3)] if lhs_dtype not in fp8_dtypes or rhs_dtype not in fp8_dtypes: raise ValueError( f"The dot algorithm '{self}' requires both inputs to have float8 " @@ -3386,9 +3392,18 @@ def _dot_general_lower(ctx, lhs, rhs, *, dimension_numbers, transpose_algorithm: DotTransposeAlgorithm | None = None, platform: str = "default"): del transpose_algorithm # unused + def _is_fp8_mixed_precision_matmul(_lhs_dtypes, _rhs_dtypes): - fp8_dtypes = (dtypes.float8_e4m3fn, dtypes.float8_e5m2, - dtypes.float8_e5m2fnuz, dtypes.float8_e4m3fnuz) + fp8_dtypes = [ + dtypes.float8_e4m3fn, + dtypes.float8_e5m2, + dtypes.float8_e5m2fnuz, + dtypes.float8_e4m3fnuz, + ] + if dtypes.float8_e3m4 is not None: + fp8_dtypes += [dtypes.float8_e3m4] + if dtypes.float8_e4m3 is not None: + fp8_dtypes += [dtypes.float8_e4m3] return _lhs_dtypes in fp8_dtypes and _rhs_dtypes in fp8_dtypes del preferred_element_type # Implied by the output aval lhs_aval, rhs_aval = ctx.avals_in diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 387b3b2a51a7..aaf70648dbc2 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -211,6 +211,10 @@ def _make_scalar_type(np_scalar_type: type) -> _ScalarMeta: int16 = _make_scalar_type(np.int16) int32 = _make_scalar_type(np.int32) int64 = _make_scalar_type(np.int64) +if dtypes.float8_e3m4 is not None: + float8_e3m4 = _make_scalar_type(dtypes.float8_e3m4) +if dtypes.float8_e4m3 is not None: + float8_e4m3 = _make_scalar_type(dtypes.float8_e4m3) float8_e4m3fn = _make_scalar_type(dtypes.float8_e4m3fn) float8_e4m3fnuz = _make_scalar_type(dtypes.float8_e4m3fnuz) float8_e5m2 = _make_scalar_type(dtypes.float8_e5m2) diff --git a/jax/_src/public_test_util.py b/jax/_src/public_test_util.py index 9859eb64cda2..7256bb12c95d 100644 --- a/jax/_src/public_test_util.py +++ b/jax/_src/public_test_util.py @@ -90,6 +90,14 @@ def default_tolerance(): np.dtype(np.complex128): 1e-5, } +# TODO: make this unconditional when ml_dtypes>=0.5.0 is required +if _dtypes.float8_e3m4 is not None: + _default_tolerance[np.dtype(_dtypes.float8_e3m4)] = 1e-1 + default_gradient_tolerance[np.dtype(_dtypes.float8_e3m4)] = 1e-1 +if _dtypes.float8_e4m3 is not None: + _default_tolerance[np.dtype(_dtypes.float8_e4m3)] = 1e-1 + default_gradient_tolerance[np.dtype(_dtypes.float8_e4m3)] = 1e-1 + def is_python_scalar(val): return not isinstance(val, np.generic) and isinstance(val, (bool, int, float, complex)) @@ -106,6 +114,12 @@ def _assert_numpy_allclose(a, b, atol=None, rtol=None, err_msg=''): _dtypes.float8_e5m2fnuz, _dtypes.bfloat16, ] + + if _dtypes.float8_e4m3 is not None: + custom_float_dtypes.insert(0, _dtypes.float8_e4m3) # type: ignore[arg-type] + if _dtypes.float8_e3m4 is not None: + custom_float_dtypes.insert(0, _dtypes.float8_e3m4) # type: ignore[arg-type] + def maybe_upcast(x): if x.dtype in custom_float_dtypes: return x.astype(np.float32) diff --git a/jax/_src/test_util.py b/jax/_src/test_util.py index 81737f27540b..5ed178eac5df 100644 --- a/jax/_src/test_util.py +++ b/jax/_src/test_util.py @@ -1433,10 +1433,19 @@ def supported(self, dtypes): @_cached_property def custom_floats(self): - return [np.dtype(t) for t in [ - _dtypes.bfloat16, _dtypes.float8_e4m3b11fnuz, - _dtypes.float8_e4m3fn, _dtypes.float8_e4m3fnuz, - _dtypes.float8_e5m2, _dtypes.float8_e5m2fnuz]] + float_dtypes = [ + _dtypes.bfloat16, + _dtypes.float8_e4m3b11fnuz, + _dtypes.float8_e4m3fn, + _dtypes.float8_e4m3fnuz, + _dtypes.float8_e5m2, + _dtypes.float8_e5m2fnuz, + ] + if _dtypes.float8_e3m4 is not None: + float_dtypes += [_dtypes.float8_e3m4] + if _dtypes.float8_e4m3 is not None: + float_dtypes += [_dtypes.float8_e4m3] + return [np.dtype(t) for t in float_dtypes] @_cached_property def floating(self): diff --git a/jax/numpy/__init__.py b/jax/numpy/__init__.py index 20c37c55902c..7a1ec3aa5cc5 100644 --- a/jax/numpy/__init__.py +++ b/jax/numpy/__init__.py @@ -273,6 +273,15 @@ except ImportError: pass +# TODO: Remove the try-except once we upgrade to ml_dtypes 0.5.0 +try: + from jax._src.numpy.lax_numpy import ( + float8_e3m4 as float8_e3m4, + float8_e4m3 as float8_e4m3, + ) +except ImportError: + pass + from jax._src.numpy.array_api_metadata import ( __array_api_version__ as __array_api_version__, __array_namespace_info__ as __array_namespace_info__, diff --git a/tests/dtypes_test.py b/tests/dtypes_test.py index 89d70871a8f9..6c7e9e3ab712 100644 --- a/tests/dtypes_test.py +++ b/tests/dtypes_test.py @@ -64,6 +64,10 @@ fp8_dtypes = [np.dtype(dtypes.float8_e4m3b11fnuz), np.dtype(dtypes.float8_e4m3fn), np.dtype(dtypes.float8_e4m3fnuz), np.dtype(dtypes.float8_e5m2), np.dtype(dtypes.float8_e5m2fnuz)] +if dtypes.float8_e3m4 is not None: + fp8_dtypes += [np.dtype(dtypes.float8_e3m4)] +if dtypes.float8_e4m3 is not None: + fp8_dtypes += [np.dtype(dtypes.float8_e4m3)] float_dtypes += fp8_dtypes custom_float_dtypes += fp8_dtypes