-
Notifications
You must be signed in to change notification settings - Fork 2.8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add float8_e4m3 and float8_e3m4 types support #23585
base: main
Are you sure you want to change the base?
Conversation
Thanks for the contribution! I don't think we'll be able to bump our The good news is this is easy enough to do with a few version guards: if you look at the initial implementation of |
Here's an example of how this was handled in the past: https://github.com/google/jax/blob/jax-v0.4.12/jax/_src/dtypes.py#L71 Basically, we only define the dtype in JAX if it's defined in Another strategy we could use is the module-level |
Incidentally, the current TF pin is : If we release I suspect we could ease this process if we committed to semver for |
That said I'd probably do it the way Jake said for now and then we can think about the minimum version bump separately, there may be other factors I haven't considered (e.g., users being stuck on an older TF for whatever reason). |
0336705
to
5553be7
Compare
5553be7
to
090ff3e
Compare
I updated the PR and tested it with ml_dtypes 0.4.0 and 0.5.0 |
third_party/xla/workspace.bzl
Outdated
XLA_COMMIT = "a473d30392e2cea68dc90f95377de3f568ea2055" | ||
XLA_SHA256 = "30324095a4d9454b5a8fdf0397b62cfd6f06155a077ce93cf75b64fb78f98fc0" | ||
XLA_COMMIT = "fa92c93da44082d79390540958bbddfbf9e76899" | ||
XLA_SHA256 = "93ea7c87d098267813cc78c41eba789e1bebef1a03e72db4755937958650fbe1" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Revert this please: we should not modify the XLA commit as part of this PR.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
restored
jax/_src/dtypes.py
Outdated
@@ -137,6 +142,20 @@ def supports_inf(dtype: DTypeLike) -> bool: | |||
_float8_e5m2fnuz_dtype, | |||
] | |||
|
|||
# TODO: remove the if statements below when minimum ml_dtypes version >= 0.5.0 | |||
if hasattr(ml_dtypes, "float8_e4m3"): | |||
float8_e4m3 = ml_dtypes.float8_e4m3 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Use two-space indentation please
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Opened PR #23959 - Add .pylintrc file. It allows to check modified files with pylint
. Copied .pylintrc
from TF
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We already have pylint configuration in pyproject.toml
; for example:
Lines 96 to 97 in 9f4e8d0
[tool.pylint.format] | |
indent-string=" " |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
switched to ruff
jax/_src/export/serialization.py
Outdated
@@ -324,6 +324,10 @@ def _deserialize_pytreedef_to_pytree(p: ser_flatbuf.PyTreeDef): | |||
dtypes._float8_e5m2fnuz_dtype: ser_flatbuf.DType.f8_e5m2fnuz, | |||
} | |||
|
|||
if dtypes._float8_e3m4_dtype is not None: | |||
_dtype_to_dtype_kind[dtypes._float8_e3m4_dtype] = ser_flatbuf.DType.f8_e3m4 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
two-space indentation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
fixed
jax/_src/interpreters/mlir.py
Outdated
assert dtypes.uint2 is not None | ||
_dtype_to_ir_type[np.dtype(dtypes.int2)] = partial( | ||
assert dtypes.uint2 is not None | ||
_dtype_to_ir_type[np.dtype(dtypes.int2)] = partial( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
two-space indentation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
fixed
jax/_src/public_test_util.py
Outdated
@@ -90,6 +90,14 @@ def default_tolerance(): | |||
np.dtype(np.complex128): 1e-5, | |||
} | |||
|
|||
# TODO: make this unconditional when ml_dtypes>=0.5.0 is required | |||
if _dtypes.float8_e3m4 is not None: | |||
_default_tolerance[np.dtype(_dtypes.float8_e3m4)] = 1e-1 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
two-space indentation please
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
fixed
jax/numpy/__init__.pyi
Outdated
@@ -410,6 +410,8 @@ def flipud(m: ArrayLike) -> Array: ... | |||
float16: Any | |||
float32: Any | |||
float64: Any | |||
float8_e3m4: Any | |||
float8_e4m3: Any |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I would remove these for now, since they may or may not actually be present in the namespace (and this can only be known at runtime)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
removed.
tests/dtypes_test.py
Outdated
@@ -64,6 +64,10 @@ | |||
fp8_dtypes = [np.dtype(dtypes.float8_e4m3b11fnuz), np.dtype(dtypes.float8_e4m3fn), | |||
np.dtype(dtypes.float8_e4m3fnuz), np.dtype(dtypes.float8_e5m2), | |||
np.dtype(dtypes.float8_e5m2fnuz)] | |||
if dtypes.float8_e3m4 is not None: | |||
fp8_dtypes += [np.dtype(dtypes.float8_e3m4)] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
two-space indentation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
fixed
e82bfa8
to
a6218f2
Compare
a6218f2
to
0b862ed
Compare
Description
Amazon has proposed two new FP8 types,
Float8E4M3
andFloat8E3M4
. These types are implemented in commercially available hardware Amazon EC2 Trn1 Instances, and added to MLIR builtin types, LLVM APFloat, ml_dtypes, StableHLO.XLA has Float8E4M3 and Float8E3M4 implementation in Review. See PR links in Related PRs section below.
This PR adds f8E4M3 and f8E3M4 types support to JAX.
f8E4M3
type follows IEEE 754 convention.f8E3M4
type follows IEEE 754 conventionRelated PRs:
How to build/install
This PR requires ml_dtype version 20240821 or later.
The current version on PyPI is 0.4.0, released on April 1, 2024, which is outdated. Therefore, ml_dtypes should be installed from source.
Related issue: jax-ml/ml_dtypes#185 [Question] Can we release a new version of ml_dtypes?
Smoke test