From 32c0510856d5e86630d24e79baa09234d9b08be7 Mon Sep 17 00:00:00 2001 From: Paul Balanca Date: Mon, 12 Aug 2024 22:09:21 +0100 Subject: [PATCH] Fixing backward compatibility with JAX 0.3.16. `ml_dtypes` and JAX bfloat16 dtypes are not equivalent in older JAX versions. --- jax_scalify/core/pow2.py | 5 +++++ jax_scalify/quantization/scale.py | 5 +++-- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/jax_scalify/core/pow2.py b/jax_scalify/core/pow2.py index ec2fe6d..5066fc0 100644 --- a/jax_scalify/core/pow2.py +++ b/jax_scalify/core/pow2.py @@ -5,6 +5,7 @@ from typing import Any, Dict, Optional, Sequence, Tuple, Union import jax.numpy as jnp +import ml_dtypes import numpy as np from jax import core from jax.interpreters import mlir @@ -18,6 +19,10 @@ np.dtype(jnp.bfloat16): np.packbits( np.array([1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1], dtype=np.uint8) ).view(np.int16), + # Copy for ml_dtypes.bfloat16, distinct in older JAX versions. + np.dtype(ml_dtypes.bfloat16): np.packbits( + np.array([1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1], dtype=np.uint8) + ).view(np.int16), np.dtype(np.float16): np.packbits(np.array([0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0], dtype=np.uint8)).view( np.int16 ), diff --git a/jax_scalify/quantization/scale.py b/jax_scalify/quantization/scale.py index 330d30f..7eccc58 100644 --- a/jax_scalify/quantization/scale.py +++ b/jax_scalify/quantization/scale.py @@ -1,5 +1,6 @@ # Copyright (c) 2024 Graphcore Ltd. All rights reserved. import jax.numpy as jnp +import ml_dtypes import numpy as np from jax_scalify.core import Array, DTypeLike, get_numpy_api @@ -33,7 +34,7 @@ def as_e8m0(arr: Array) -> Array: """ np_api = get_numpy_api(arr) # assert len(arr.shape) < 2 - assert arr.dtype in {np.dtype(jnp.bfloat16), np.dtype(jnp.float32)} + assert arr.dtype in {np.dtype(jnp.bfloat16), np.dtype(ml_dtypes.bfloat16), np.dtype(jnp.float32)} # Saturation => negative values saturating to min value (i.e. zero bits) in E8M0. 
arr = np_api.maximum(arr, np.array(0, arr.dtype)) arr = pow2_truncate(arr) @@ -55,7 +56,7 @@ def from_e8m0(arr: Array, dtype: DTypeLike) -> Array: """ np_api = get_numpy_api(arr) assert arr.dtype == np.uint8 - assert np.dtype(dtype) in {np.dtype(jnp.bfloat16), np.dtype(jnp.float32)} + assert np.dtype(dtype) in {np.dtype(jnp.bfloat16), np.dtype(ml_dtypes.bfloat16), np.dtype(jnp.float32)} # Avoid issues with 7 mantissa bits in BF16. # TODO: more efficient implementation! arr = np_api.exp2(arr.astype(np.float32) - 127)