Add float8_e4m3 and float8_e3m4 types support #23585

Status: Open. apivovarov wants to merge 1 commit into main from the float8_e4m3 branch.

Conversation

@apivovarov (Author) commented Sep 12, 2024

Description

Amazon has proposed two new FP8 types, Float8E4M3 and Float8E3M4. These types are implemented in commercially available hardware (Amazon EC2 Trn1 instances) and have already been added to MLIR builtin types, LLVM APFloat, ml_dtypes, and StableHLO.

The XLA implementation of Float8E4M3 and Float8E3M4 is currently in review; see the PR links in the Related PRs section below.

This PR adds support for the f8E4M3 and f8E3M4 types to JAX.

The f8E4M3 type follows the IEEE 754 convention.

f8E4M3 (IEEE 754)
- Exponent bias: 7
- Maximum stored exponent value: 14 (binary 1110)
- Maximum unbiased exponent value: 14 - 7 = 7
- Minimum stored exponent value: 1 (binary 0001)
- Minimum unbiased exponent value: 1 - 7 = -6
- Precision (the total number of bits in the significand (mantissa),
    including the implicit leading integer bit): 3 + 1 = 4
- Follows IEEE 754 conventions for representation of special values
- Has Positive and Negative zero
- Has Positive and Negative infinity
- Has NaNs

Additional details (sanity-checked in the sketch after this list):
- Max exp (unbiased): 7
- Min exp (unbiased): -6
- Infinities (+/-): S.1111.000
- Zeros (+/-): S.0000.000
- NaNs: S.1111.{001, 010, 011, 100, 101, 110, 111}
- Max normal number: S.1110.111 = +/-2^(7) x (1 + 0.875) = +/-240
- Min normal number: S.0001.000 = +/-2^(-6)
- Max subnormal number: S.0000.111 = +/-2^(-6) x 0.875 = +/-2^(-9) x 7
- Min subnormal number: S.0000.001 = +/-2^(-6) x 0.125 = +/-2^(-9)
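
The limits above can be checked against ml_dtypes.finfo (a minimal sketch; it assumes an ml_dtypes build that already defines float8_e4m3):

# Sanity-check the f8E4M3 limits listed above using ml_dtypes.finfo.
import ml_dtypes

fi = ml_dtypes.finfo(ml_dtypes.float8_e4m3)
assert fi.bits == 8
assert float(fi.max) == 240.0                    # S.1110.111 = 2^7 x 1.875
assert float(fi.smallest_normal) == 2.0**-6      # S.0001.000
assert float(fi.smallest_subnormal) == 2.0**-9   # S.0000.001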

The f8E3M4 type follows the IEEE 754 convention.

f8E3M4 (IEEE 754)
- Exponent bias: 3
- Maximum stored exponent value: 6 (binary 110)
- Maximum unbiased exponent value: 6 - 3 = 3
- Minimum stored exponent value: 1 (binary 001)
- Minimum unbiased exponent value: 1 - 3 = -2
- Precision (the total number of bits in the significand (mantissa),
    including the implicit leading integer bit): 4 + 1 = 5
- Follows IEEE 754 conventions for representation of special values
- Has Positive and Negative zero
- Has Positive and Negative infinity
- Has NaNs

Additional details (sanity-checked in the sketch after this list):
- Max exp (unbiased): 3
- Min exp (unbiased): -2
- Infinities (+/-): S.111.0000
- Zeros (+/-): S.000.0000
- NaNs: S.111.{0,1}⁴ except S.111.0000
- Max normal number: S.110.1111 = +/-2^(6-3) x (1 + 15/16) = +/-2^3 x 31 x 2^(-4) = +/-15.5
- Min normal number: S.001.0000 = +/-2^(1-3) x (1 + 0) = +/-2^(-2)
- Max subnormal number: S.000.1111 = +/-2^(-2) x 15/16 = +/-2^(-2) x 15 x 2^(-4) = +/-15 x 2^(-6)
- Min subnormal number: S.000.0001 = +/-2^(-2) x 1/16 =  +/-2^(-2) x 2^(-4) = +/-2^(-6)
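
As above, the f8E3M4 limits can be checked with ml_dtypes.finfo (a minimal sketch, assuming an ml_dtypes build that already defines float8_e3m4):

# Sanity-check the f8E3M4 limits listed above using ml_dtypes.finfo.
import ml_dtypes

fi = ml_dtypes.finfo(ml_dtypes.float8_e3m4)
assert fi.bits == 8
assert float(fi.max) == 15.5                     # S.110.1111 = 2^3 x 1.9375
assert float(fi.smallest_normal) == 2.0**-2      # S.001.0000
assert float(fi.smallest_subnormal) == 2.0**-6   # S.000.0001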

Related PRs:

  • LLVM PR-97179 [APFloat] Add support for f8E4M3 IEEE 754 type (Merged)
  • LLVM PR-97118 [MLIR] Add f8E4M3 IEEE 754 type (Merged)
  • LLVM PR-99698 [APFloat] Add support for f8E3M4 IEEE 754 type (Merged)
  • LLVM PR-101230 [MLIR] Add f8E3M4 IEEE 754 type (Merged)
  • StableHLO PR-2486 [RFC] Add f8E4M3 and f8E3M4 types support (Merged)
  • StableHLO PR-2482 Add f8E4M3 and f8E3M4 types support (Merged)
  • ml_dtypes PR-161 Add float8_e4m3 (Merged)
  • ml_dtypes PR-171 Add float8_e3m4 (Merged)
  • XLA PR-17075 [TSL] Bump ml_dtypes. Add float8_e4m3, float8_e3m4 (Merged)
  • XLA PR-16585 Add support for float8_e4m3 and float8_e3m4 types (in Review)

How to build/install

This PR requires ml_dtypes version 20240821 or later.

The current version on PyPI is 0.4.0, released on April 1, 2024, which is outdated. Therefore, ml_dtypes should be installed from source.

Related issue: jax-ml/ml_dtypes#185 [Question] Can we release a new version of ml_dtypes?

## Install the latest ml_dtypes
cd ml_dtypes
pip3 install .

## Install jaxlib and JAX
cd jax

### install jaxlib
python3 build/build.py
pip3 install dist/*.whl

### install jax
pip3 install .
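
After installing, a quick check confirms whether the installed ml_dtypes actually exposes the new types (a small sketch; the hasattr probe mirrors the version guard used in this PR):

import ml_dtypes
print("ml_dtypes version:", ml_dtypes.__version__)
print("float8_e4m3 available:", hasattr(ml_dtypes, "float8_e4m3"))
print("float8_e3m4 available:", hasattr(ml_dtypes, "float8_e3m4"))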

Smoke test

import jax
import jax.numpy as jnp
from jax import Array, random

key1 = random.PRNGKey(41)
key2 = random.PRNGKey(42)
a = random.uniform(key1, shape=(4,4), dtype="float8_e4m3")
b = random.uniform(key2, shape=(4,4), dtype="float8_e4m3")

def foo(a, b):
  return a@b

# StableHLO
print(jax.jit(foo).lower(a,b).as_text())

# HLO (optimized for cpu)
print(jax.jit(foo).lower(a,b).compile().as_text())

c = foo(a, b)
print(c)

Array([[1, 0.9375, 1.25, 0.5625],
       [0.75, 0.625, 0.75, 0.5],
       [0.8125, 0.8125, 1.25, 0.40625],
       [0.8125, 0.875, 1.25, 0.4375]], dtype=float8_e4m3)

@jakevdp (Collaborator) commented Sep 12, 2024

Thanks for the contribution! I don't think we'll be able to bump our ml_dtypes requirement any time soon, so if we want to merge this we'll have to make it robust to older ml_dtypes versions. (The reason is that tensorflow pins a specific ml_dtypes version, and some workflows depend on installing both JAX and tensorflow.)

The good news is this is easy enough to do with a few version guards: if you look at the initial implementation of float8 types in JAX, you can see the pattern we used previously.
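
For illustration, the guard amounts to something like the following (a minimal sketch of the pattern, not the exact code from that release):

# Sketch: expose the new dtype only if the installed ml_dtypes defines it.
import ml_dtypes
import numpy as np

if hasattr(ml_dtypes, "float8_e4m3"):
  float8_e4m3 = ml_dtypes.float8_e4m3
  _float8_e4m3_dtype = np.dtype(float8_e4m3)
else:
  float8_e4m3 = None
  _float8_e4m3_dtype = None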

@jakevdp (Collaborator) commented Sep 12, 2024

Here's an example of how this was handled in the past: https://github.com/google/jax/blob/jax-v0.4.12/jax/_src/dtypes.py#L71

Basically, we only define the dtype in JAX if it's defined in ml_dtypes.

Another strategy we could use is a module-level __getattr__ for these types, so that if the ml_dtypes version is too old we raise an error that specifies the required version.
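
A rough sketch of that alternative (the names here are hypothetical, not the final implementation):

# Sketch: module-level __getattr__ that raises an informative error when the
# installed ml_dtypes is too old to provide the requested float8 type.
import ml_dtypes

_NEW_FLOAT8_NAMES = ("float8_e4m3", "float8_e3m4")  # hypothetical constant

def __getattr__(name):
  if name in _NEW_FLOAT8_NAMES:
    if hasattr(ml_dtypes, name):
      return getattr(ml_dtypes, name)
    raise AttributeError(
        f"{name} requires a newer ml_dtypes release; please upgrade ml_dtypes.")
  raise AttributeError(f"module {__name__!r} has no attribute {name!r}")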

@hawkinsp (Collaborator) commented:

Incidentally, the current TF pin is: Requires-Dist: ml-dtypes <0.5.0,>=0.3.1.

If we release ml_dtypes as 0.4.1 instead of 0.5.0 we probably could bump the minimum version.

I suspect we could ease this process if we committed to semver for ml_dtypes so TF felt like they could be less conservative in their pins. (Adding dtypes is hopefully safe!)

@hawkinsp (Collaborator) commented:

That said I'd probably do it the way Jake said for now and then we can think about the minimum version bump separately, there may be other factors I haven't considered (e.g., users being stuck on an older TF for whatever reason).

@apivovarov changed the title from "Add float8_e4m3 type support" to "Add float8_e4m3 and float8_e3m4 types support" on Sep 26, 2024
@apivovarov (Author) commented:

> That said I'd probably do it the way Jake said for now and then we can think about the minimum version bump separately, there may be other factors I haven't considered (e.g., users being stuck on an older TF for whatever reason).

I updated the PR and tested it with ml_dtypes 0.4.0 and 0.5.0.
@jakevdp @hawkinsp

XLA_COMMIT = "a473d30392e2cea68dc90f95377de3f568ea2055"
XLA_SHA256 = "30324095a4d9454b5a8fdf0397b62cfd6f06155a077ce93cf75b64fb78f98fc0"
XLA_COMMIT = "fa92c93da44082d79390540958bbddfbf9e76899"
XLA_SHA256 = "93ea7c87d098267813cc78c41eba789e1bebef1a03e72db4755937958650fbe1"
Collaborator:

Revert this please: we should not modify the XLA commit as part of this PR.

Author:

restored

@@ -137,6 +142,20 @@ def supports_inf(dtype: DTypeLike) -> bool:
_float8_e5m2fnuz_dtype,
]

# TODO: remove the if statements below when minimum ml_dtypes version >= 0.5.0
if hasattr(ml_dtypes, "float8_e4m3"):
    float8_e4m3 = ml_dtypes.float8_e4m3
Collaborator:

Use two-space indentation please

Author:

Opened PR #23959 - Add .pylintrc file. It allows checking the modified files with pylint. The .pylintrc was copied from TF.

Collaborator:

We already have pylint configuration in pyproject.toml; for example:

jax/pyproject.toml

Lines 96 to 97 in 9f4e8d0

[tool.pylint.format]
indent-string=" "

Author:

switched to ruff

@@ -324,6 +324,10 @@ def _deserialize_pytreedef_to_pytree(p: ser_flatbuf.PyTreeDef):
dtypes._float8_e5m2fnuz_dtype: ser_flatbuf.DType.f8_e5m2fnuz,
}

if dtypes._float8_e3m4_dtype is not None:
    _dtype_to_dtype_kind[dtypes._float8_e3m4_dtype] = ser_flatbuf.DType.f8_e3m4
Collaborator:

two-space indentation

Author:

fixed

assert dtypes.uint2 is not None
_dtype_to_ir_type[np.dtype(dtypes.int2)] = partial(
Collaborator:

two-space indentation

Author:

fixed

jax/_src/lax/lax.py (outdated review thread, resolved)
@@ -90,6 +90,14 @@ def default_tolerance():
np.dtype(np.complex128): 1e-5,
}

# TODO: make this unconditional when ml_dtypes>=0.5.0 is required
if _dtypes.float8_e3m4 is not None:
    _default_tolerance[np.dtype(_dtypes.float8_e3m4)] = 1e-1
Collaborator:

two-space indentation please

Author:

fixed

@@ -410,6 +410,8 @@ def flipud(m: ArrayLike) -> Array: ...
float16: Any
float32: Any
float64: Any
float8_e3m4: Any
float8_e4m3: Any
Collaborator:

I would remove these for now, since they may or may not actually be present in the namespace (and this can only be known at runtime)

Author:

removed.

@@ -64,6 +64,10 @@
fp8_dtypes = [np.dtype(dtypes.float8_e4m3b11fnuz), np.dtype(dtypes.float8_e4m3fn),
np.dtype(dtypes.float8_e4m3fnuz), np.dtype(dtypes.float8_e5m2),
np.dtype(dtypes.float8_e5m2fnuz)]
if dtypes.float8_e3m4 is not None:
    fp8_dtypes += [np.dtype(dtypes.float8_e3m4)]
Collaborator:

two-space indentation

Author:

fixed

@jakevdp self-assigned this on Sep 26, 2024
@apivovarov force-pushed the float8_e4m3 branch 2 times, most recently from e82bfa8 to a6218f2 on September 27, 2024