Skip to content

Commit

Permalink
Fixing backward compatibility with JAX 0.3.16.
Browse files Browse the repository at this point in the history
`ml_dtypes` and JAX bfloat16 dtypes not equivalent older JAX versions.
  • Loading branch information
balancap committed Aug 12, 2024
1 parent 759e9e5 commit 32c0510
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 2 deletions.
5 changes: 5 additions & 0 deletions jax_scalify/core/pow2.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from typing import Any, Dict, Optional, Sequence, Tuple, Union

import jax.numpy as jnp
import ml_dtypes
import numpy as np
from jax import core
from jax.interpreters import mlir
Expand All @@ -18,6 +19,10 @@
np.dtype(jnp.bfloat16): np.packbits(
np.array([1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1], dtype=np.uint8)
).view(np.int16),
# Copy for ml_dtypes.bfloat16, distinct in older JAX versions.
np.dtype(ml_dtypes.bfloat16): np.packbits(
np.array([1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1], dtype=np.uint8)
).view(np.int16),
np.dtype(np.float16): np.packbits(np.array([0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0], dtype=np.uint8)).view(
np.int16
),
Expand Down
5 changes: 3 additions & 2 deletions jax_scalify/quantization/scale.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Copyright (c) 2024 Graphcore Ltd. All rights reserved.
import jax.numpy as jnp
import ml_dtypes
import numpy as np

from jax_scalify.core import Array, DTypeLike, get_numpy_api
Expand Down Expand Up @@ -33,7 +34,7 @@ def as_e8m0(arr: Array) -> Array:
"""
np_api = get_numpy_api(arr)
# assert len(arr.shape) < 2
assert arr.dtype in {np.dtype(jnp.bfloat16), np.dtype(jnp.float32)}
assert arr.dtype in {np.dtype(jnp.bfloat16), np.dtype(ml_dtypes.bfloat16), np.dtype(jnp.float32)}
# Saturation => negative values saturating to min value (i.e. zero bits) in E8M0.
arr = np_api.maximum(arr, np.array(0, arr.dtype))
arr = pow2_truncate(arr)
Expand All @@ -55,7 +56,7 @@ def from_e8m0(arr: Array, dtype: DTypeLike) -> Array:
"""
np_api = get_numpy_api(arr)
assert arr.dtype == np.uint8
assert np.dtype(dtype) in {np.dtype(jnp.bfloat16), np.dtype(jnp.float32)}
assert np.dtype(dtype) in {np.dtype(jnp.bfloat16), np.dtype(ml_dtypes.bfloat16), np.dtype(jnp.float32)}
# Avoid issues with 7 mantissa bits in BF16.
# TODO: more efficient implementation!
arr = np_api.exp2(arr.astype(np.float32) - 127)
Expand Down

0 comments on commit 32c0510

Please sign in to comment.