From 090ff3ea0c8ee30619e58027fb448d242324a0bc Mon Sep 17 00:00:00 2001 From: Alexander Pivovarov Date: Wed, 11 Sep 2024 23:11:57 +0000 Subject: [PATCH] Add float8_e4m3 and float8_e3m4 types support --- jax/_src/dtypes.py | 19 +++++++++++++++++++ jax/_src/export/serialization.fbs | 2 ++ jax/_src/export/serialization.py | 4 ++++ jax/_src/export/serialization_generated.py | 2 ++ jax/_src/interpreters/mlir.py | 10 +++++++--- jax/_src/lax/lax.py | 16 ++++++++++++---- jax/_src/numpy/lax_numpy.py | 4 ++++ jax/_src/public_test_util.py | 12 ++++++++++++ jax/_src/test_util.py | 9 +++++++-- jax/numpy/__init__.py | 9 +++++++++ jax/numpy/__init__.pyi | 2 ++ tests/dtypes_test.py | 4 ++++ third_party/xla/workspace.bzl | 4 ++-- 13 files changed, 86 insertions(+), 11 deletions(-) diff --git a/jax/_src/dtypes.py b/jax/_src/dtypes.py index 82be38d1cb57..acf78649dce5 100644 --- a/jax/_src/dtypes.py +++ b/jax/_src/dtypes.py @@ -90,12 +90,17 @@ def type(self) -> type: ... # fp8 support +# TODO: remove Optional when minimum ml_dtypes version >= 0.5.0 +float8_e3m4: type[np.generic] | None = None +float8_e4m3: type[np.generic] | None = None float8_e4m3b11fnuz: type[np.generic] = ml_dtypes.float8_e4m3b11fnuz float8_e4m3fn: type[np.generic] = ml_dtypes.float8_e4m3fn float8_e4m3fnuz: type[np.generic] = ml_dtypes.float8_e4m3fnuz float8_e5m2: type[np.generic] = ml_dtypes.float8_e5m2 float8_e5m2fnuz: type[np.generic] = ml_dtypes.float8_e5m2fnuz +_float8_e3m4_dtype: np.dtype | None = None +_float8_e4m3_dtype: np.dtype | None = None _float8_e4m3b11fnuz_dtype: np.dtype = np.dtype(float8_e4m3b11fnuz) _float8_e4m3fn_dtype: np.dtype = np.dtype(float8_e4m3fn) _float8_e4m3fnuz_dtype: np.dtype = np.dtype(float8_e4m3fnuz) @@ -137,6 +142,20 @@ def supports_inf(dtype: DTypeLike) -> bool: _float8_e5m2fnuz_dtype, ] +# TODO: remove the if statements below when minimum ml_dtypes version >= 0.5.0 +if hasattr(ml_dtypes, "float8_e4m3"): + float8_e4m3 = ml_dtypes.float8_e4m3 + _float8_e4m3_dtype = np.dtype(float8_e4m3) + _custom_float_scalar_types.insert(0, float8_e4m3) # type: ignore[arg-type] + _custom_float_dtypes.insert(0, _float8_e4m3_dtype) # type: ignore[arg-type] + _float8_dtypes.insert(0, _float8_e4m3_dtype) # type: ignore[arg-type] +if hasattr(ml_dtypes, "float8_e3m4"): + float8_e3m4 = ml_dtypes.float8_e3m4 + _float8_e3m4_dtype = np.dtype(float8_e3m4) + _custom_float_scalar_types.insert(0, float8_e3m4) # type: ignore[arg-type] + _custom_float_dtypes.insert(0, _float8_e3m4_dtype) # type: ignore[arg-type] + _float8_dtypes.insert(0, _float8_e3m4_dtype) # type: ignore[arg-type] + # 2-bit integer support int2: type[np.generic] | None = None uint2: type[np.generic] | None = None diff --git a/jax/_src/export/serialization.fbs b/jax/_src/export/serialization.fbs index 758950adaa8e..59e169dc6fb6 100644 --- a/jax/_src/export/serialization.fbs +++ b/jax/_src/export/serialization.fbs @@ -64,6 +64,8 @@ enum DType: byte { i4 = 15, ui4 = 16, + f8_e3m4 = 24, + f8_e4m3 = 23, f8_e4m3b11fnuz = 17, f8_e4m3fn = 18, f8_e4m3fnuz = 19, diff --git a/jax/_src/export/serialization.py b/jax/_src/export/serialization.py index a47b095e4450..f94627cd2c78 100644 --- a/jax/_src/export/serialization.py +++ b/jax/_src/export/serialization.py @@ -324,6 +324,10 @@ def _deserialize_pytreedef_to_pytree(p: ser_flatbuf.PyTreeDef): dtypes._float8_e5m2fnuz_dtype: ser_flatbuf.DType.f8_e5m2fnuz, } +if dtypes._float8_e3m4_dtype is not None: + _dtype_to_dtype_kind[dtypes._float8_e3m4_dtype] = ser_flatbuf.DType.f8_e3m4 +if dtypes._float8_e4m3_dtype is not None: + _dtype_to_dtype_kind[dtypes._float8_e4m3_dtype] = ser_flatbuf.DType.f8_e4m3 _dtype_kind_to_dtype = { kind: dtype for dtype, kind in _dtype_to_dtype_kind.items() diff --git a/jax/_src/export/serialization_generated.py b/jax/_src/export/serialization_generated.py index a872d03a9fdd..583b41814963 100644 --- a/jax/_src/export/serialization_generated.py +++ b/jax/_src/export/serialization_generated.py @@ -52,6 +52,8 @@ class DType: bf16 = 14 i4 = 15 ui4 = 16 + f8_e3m4 = 24 + f8_e4m3 = 23 f8_e4m3b11fnuz = 17 f8_e4m3fn = 18 f8_e4m3fnuz = 19 diff --git a/jax/_src/interpreters/mlir.py b/jax/_src/interpreters/mlir.py index af773365b12d..48af8c61fcee 100644 --- a/jax/_src/interpreters/mlir.py +++ b/jax/_src/interpreters/mlir.py @@ -185,14 +185,18 @@ def _is_ir_values(x: IrValues) -> bool: if dtypes.int2 is not None: - assert dtypes.uint2 is not None - _dtype_to_ir_type[np.dtype(dtypes.int2)] = partial( + assert dtypes.uint2 is not None + _dtype_to_ir_type[np.dtype(dtypes.int2)] = partial( ir.IntegerType.get_signless, 2 ) - _dtype_to_ir_type[np.dtype(dtypes.uint2)] = partial( + _dtype_to_ir_type[np.dtype(dtypes.uint2)] = partial( ir.IntegerType.get_unsigned, 2 ) +if dtypes.float8_e3m4 is not None: + _dtype_to_ir_type[np.dtype(dtypes.float8_e3m4)] = ir.Float8E3M4Type.get +if dtypes.float8_e4m3 is not None: + _dtype_to_ir_type[np.dtype(dtypes.float8_e4m3)] = ir.Float8E4M3Type.get def dtype_to_ir_type(dtype: core.bint | np.dtype | np.generic) -> ir.Type: if isinstance(dtype, core.bint): diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index f51f0436b7a9..3a51d14d8930 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -831,11 +831,15 @@ def _convert_to_hlo_attr(self, lhs_dtype: DTypeLike, if self in (DotAlgorithm.Preset.ANY_F8_ANY_F8_F32, DotAlgorithm.Preset.ANY_F8_ANY_F8_F32_FAST_ACCUM): - fp8_dtypes = (np.dtype(dtypes.float8_e4m3b11fnuz), + fp8_dtypes = [np.dtype(dtypes.float8_e4m3b11fnuz), np.dtype(dtypes.float8_e4m3fn), np.dtype(dtypes.float8_e4m3fnuz), np.dtype(dtypes.float8_e5m2), - np.dtype(dtypes.float8_e5m2fnuz)) + np.dtype(dtypes.float8_e5m2fnuz)] + if dtypes.float8_e3m4 is not None: + fp8_dtypes += [np.dtype(dtypes.float8_e3m4)] + if dtypes.float8_e4m3 is not None: + fp8_dtypes += [np.dtype(dtypes.float8_e4m3)] if lhs_dtype not in fp8_dtypes or rhs_dtype not in fp8_dtypes: raise ValueError( f"The dot algorithm '{self}' requires both inputs to have float8 " @@ -3387,8 +3391,12 @@ def _dot_general_lower(ctx, lhs, rhs, *, dimension_numbers, platform: str = "default"): del transpose_algorithm # unused def _is_fp8_mixed_precision_matmul(_lhs_dtypes, _rhs_dtypes): - fp8_dtypes = (dtypes.float8_e4m3fn, dtypes.float8_e5m2, - dtypes.float8_e5m2fnuz, dtypes.float8_e4m3fnuz) + fp8_dtypes = [dtypes.float8_e4m3fn, dtypes.float8_e5m2, + dtypes.float8_e5m2fnuz, dtypes.float8_e4m3fnuz] + if dtypes.float8_e3m4 is not None: + fp8_dtypes += [dtypes.float8_e3m4] + if dtypes.float8_e4m3 is not None: + fp8_dtypes += [dtypes.float8_e4m3] return _lhs_dtypes in fp8_dtypes and _rhs_dtypes in fp8_dtypes del preferred_element_type # Implied by the output aval lhs_aval, rhs_aval = ctx.avals_in diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 73e27245cfa9..895d7847f53e 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -211,6 +211,10 @@ def _make_scalar_type(np_scalar_type: type) -> _ScalarMeta: int16 = _make_scalar_type(np.int16) int32 = _make_scalar_type(np.int32) int64 = _make_scalar_type(np.int64) +if dtypes.float8_e3m4 is not None: + float8_e3m4 = _make_scalar_type(dtypes.float8_e3m4) +if dtypes.float8_e4m3 is not None: + float8_e4m3 = _make_scalar_type(dtypes.float8_e4m3) float8_e4m3fn = _make_scalar_type(dtypes.float8_e4m3fn) float8_e4m3fnuz = _make_scalar_type(dtypes.float8_e4m3fnuz) float8_e5m2 = _make_scalar_type(dtypes.float8_e5m2) diff --git a/jax/_src/public_test_util.py b/jax/_src/public_test_util.py index 9859eb64cda2..8ac8daa5c969 100644 --- a/jax/_src/public_test_util.py +++ b/jax/_src/public_test_util.py @@ -90,6 +90,14 @@ def default_tolerance(): np.dtype(np.complex128): 1e-5, } +# TODO: make this unconditional when ml_dtypes>=0.5.0 is required +if _dtypes.float8_e3m4 is not None: + _default_tolerance[np.dtype(_dtypes.float8_e3m4)] = 1e-1 + default_gradient_tolerance[np.dtype(_dtypes.float8_e3m4)] = 1e-1 +if _dtypes.float8_e4m3 is not None: + _default_tolerance[np.dtype(_dtypes.float8_e4m3)] = 1e-1 + default_gradient_tolerance[np.dtype(_dtypes.float8_e4m3)] = 1e-1 + def is_python_scalar(val): return not isinstance(val, np.generic) and isinstance(val, (bool, int, float, complex)) @@ -106,6 +114,10 @@ def _assert_numpy_allclose(a, b, atol=None, rtol=None, err_msg=''): _dtypes.float8_e5m2fnuz, _dtypes.bfloat16, ] + if _dtypes.float8_e4m3 is not None: + custom_float_dtypes.insert(0, _dtypes.float8_e4m3) # type: ignore[arg-type] + if _dtypes.float8_e3m4 is not None: + custom_float_dtypes.insert(0, _dtypes.float8_e3m4) # type: ignore[arg-type] def maybe_upcast(x): if x.dtype in custom_float_dtypes: return x.astype(np.float32) diff --git a/jax/_src/test_util.py b/jax/_src/test_util.py index 81737f27540b..1a31f76cea2f 100644 --- a/jax/_src/test_util.py +++ b/jax/_src/test_util.py @@ -1433,10 +1433,15 @@ def supported(self, dtypes): @_cached_property def custom_floats(self): - return [np.dtype(t) for t in [ + float_dtypes = [ _dtypes.bfloat16, _dtypes.float8_e4m3b11fnuz, _dtypes.float8_e4m3fn, _dtypes.float8_e4m3fnuz, - _dtypes.float8_e5m2, _dtypes.float8_e5m2fnuz]] + _dtypes.float8_e5m2, _dtypes.float8_e5m2fnuz] + if _dtypes.float8_e3m4 is not None: + float_dtypes += [_dtypes.float8_e3m4] + if _dtypes.float8_e4m3 is not None: + float_dtypes += [_dtypes.float8_e4m3] + return [np.dtype(t) for t in float_dtypes] @_cached_property def floating(self): diff --git a/jax/numpy/__init__.py b/jax/numpy/__init__.py index 20c37c55902c..7a1ec3aa5cc5 100644 --- a/jax/numpy/__init__.py +++ b/jax/numpy/__init__.py @@ -273,6 +273,15 @@ except ImportError: pass +# TODO: Remove the try-except once we upgrade to ml_dtypes 0.5.0 +try: + from jax._src.numpy.lax_numpy import ( + float8_e3m4 as float8_e3m4, + float8_e4m3 as float8_e4m3, + ) +except ImportError: + pass + from jax._src.numpy.array_api_metadata import ( __array_api_version__ as __array_api_version__, __array_namespace_info__ as __array_namespace_info__, diff --git a/jax/numpy/__init__.pyi b/jax/numpy/__init__.pyi index c23f659bd3f9..b62dc83a84f8 100644 --- a/jax/numpy/__init__.pyi +++ b/jax/numpy/__init__.pyi @@ -410,6 +410,8 @@ def flipud(m: ArrayLike) -> Array: ... float16: Any float32: Any float64: Any +float8_e3m4: Any +float8_e4m3: Any float8_e4m3b11fnuz: Any float8_e4m3fn: Any float8_e4m3fnuz: Any diff --git a/tests/dtypes_test.py b/tests/dtypes_test.py index 89d70871a8f9..77eafd65adf4 100644 --- a/tests/dtypes_test.py +++ b/tests/dtypes_test.py @@ -64,6 +64,10 @@ fp8_dtypes = [np.dtype(dtypes.float8_e4m3b11fnuz), np.dtype(dtypes.float8_e4m3fn), np.dtype(dtypes.float8_e4m3fnuz), np.dtype(dtypes.float8_e5m2), np.dtype(dtypes.float8_e5m2fnuz)] +if dtypes.float8_e3m4 is not None: + fp8_dtypes += [np.dtype(dtypes.float8_e3m4)] +if dtypes.float8_e4m3 is not None: + fp8_dtypes += [np.dtype(dtypes.float8_e4m3)] float_dtypes += fp8_dtypes custom_float_dtypes += fp8_dtypes diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index 4ad4e48c02d1..6139964e9256 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "a473d30392e2cea68dc90f95377de3f568ea2055" -XLA_SHA256 = "30324095a4d9454b5a8fdf0397b62cfd6f06155a077ce93cf75b64fb78f98fc0" +XLA_COMMIT = "fa92c93da44082d79390540958bbddfbf9e76899" +XLA_SHA256 = "93ea7c87d098267813cc78c41eba789e1bebef1a03e72db4755937958650fbe1" def repo(): tf_http_archive(