From 293c1c813086080efd1e5f56eea5138eb40dbac1 Mon Sep 17 00:00:00 2001 From: Paul Balanca Date: Thu, 11 Jan 2024 17:44:48 +0000 Subject: [PATCH] Add `safe_check_dtypes` mode to AutoScale interpreter. (#82) This allows checking, in unit tests as well as in large graphs, that the AutoScale graph interpreter is not silently promoting to a higher-precision floating-point datatype. --- jax_scaled_arithmetics/core/interpreters.py | 41 +++++++++++++++++---- tests/core/test_interpreter.py | 6 ++- tests/lax/test_base_scaling_primitives.py | 2 +- 3 files changed, 39 insertions(+), 10 deletions(-) diff --git a/jax_scaled_arithmetics/core/interpreters.py b/jax_scaled_arithmetics/core/interpreters.py index 9b00264..45791f5 100644 --- a/jax_scaled_arithmetics/core/interpreters.py +++ b/jax_scaled_arithmetics/core/interpreters.py @@ -15,7 +15,7 @@ ) from jax._src.util import safe_map -from .datatype import DTypeLike, NDArray, ScaledArray, as_scaled_array_base, is_scaled_leaf +from .datatype import Array, DTypeLike, NDArray, ScaledArray, as_scaled_array_base, is_scaled_leaf from .utils import Pow2RoundMode @@ -90,6 +90,12 @@ def _get_aval(val: Any) -> core.ShapedArray: return core.ShapedArray(shape=val.shape, dtype=val.dtype) +def _get_data(val: Any) -> Array: + if isinstance(val, ScaledArray): + return val.data + return val + + def promote_scalar_to_scaled_array(val: Any) -> ScaledArray: """Promote a scalar (Numpy, JAX, ...) to a Scaled Array. @@ -192,6 +198,8 @@ def wrapped(*args, **kwargs): def autoscale_jaxpr(jaxpr: core.Jaxpr, consts, *args): env: Dict[core.Var, ScaledArray] = {} + # Check dtype consistency between normal and scaled modes. + safe_check_dtypes: bool = False def read(var): if type(var) is core.Literal: @@ -209,6 +217,13 @@ def promote_to_scaled_array(val): # No promotion rule => just return as such. 
return val + def jaxpr_eqn_bind(eqn: core.JaxprEqn, invals: Sequence[core.ShapedArray]) -> Sequence[core.ShapedArray]: + subfuns, bind_params = eqn.primitive.get_bind_params(eqn.params) + outvals = eqn.primitive.bind(*subfuns, *invals, **bind_params) + if not eqn.primitive.multiple_results: + outvals = [outvals] + return outvals + # A few initial checks to make sure there is consistency. assert len(jaxpr.invars) == len(args) safe_map(write, jaxpr.invars, args) @@ -223,19 +238,31 @@ def promote_to_scaled_array(val): if not any_scaled_inputs and scaled_prim_type != ScaledPrimitiveType.ALWAYS_SCALE: # Using normal JAX primitive: no scaled inputs, and not always scale rule. - subfuns, bind_params = eqn.primitive.get_bind_params(eqn.params) - outvals = eqn.primitive.bind(*subfuns, *invals, **bind_params) + outvals = jaxpr_eqn_bind(eqn, invals) elif scaled_prim_fn is None: raise NotImplementedError( f"'{eqn.primitive}' JAX primitive does not have an implementation for ScaledArray inputs yet." ) else: # Using scaled primitive. Automatic promotion of inputs to scaled array, when possible. - invals = list(map(promote_to_scaled_array, invals)) - outvals = scaled_prim_fn(*invals, **eqn.params) + scaled_invals = list(map(promote_to_scaled_array, invals)) + outvals = scaled_prim_fn(*scaled_invals, **eqn.params) + if not eqn.primitive.multiple_results: + outvals = [outvals] + + # Check consistency with normal JAX mode. Help catching dtype promotion errors. + # NOTE: ignoring when no outputs! (e.g. debug_callback). + if safe_check_dtypes and len(outvals) > 0: + ref_outvals = jaxpr_eqn_bind(eqn, [_get_data(v) for v in invals]) + data_outvals = [_get_data(v) for v in outvals] + # Check scaled dtypes == ref dtypes. 
+ ref_dtypes = tuple(v.dtype for v in ref_outvals) + data_dtypes = tuple(v.dtype for v in data_outvals) + if data_dtypes != ref_dtypes: + raise ValueError( + f"Output dtype of '{eqn.primitive}' scaled translation is not consistent with the JAX reference primitive implementation: {data_dtypes} vs {ref_dtypes}." + ) - if not eqn.primitive.multiple_results: - outvals = [outvals] safe_map(write, eqn.outvars, outvals) outvals = safe_map(read, jaxpr.outvars) diff --git a/tests/core/test_interpreter.py b/tests/core/test_interpreter.py index 808fa19..e68fac9 100644 --- a/tests/core/test_interpreter.py +++ b/tests/core/test_interpreter.py @@ -52,7 +52,8 @@ def func(x): scaled_input = scaled_array([1.0, 2.0], 3, dtype=np.float32) jaxpr = jax.make_jaxpr(scaled_func)(scaled_input).jaxpr # Need 3 equations: 2 mul + 1 cast. - assert len(jaxpr.eqns) == 3 + # TODO: additional mul in `safe_check_dtypes` mode. + assert len(jaxpr.eqns) in {3, 4} # Vars need to be primitive data types (e.g., f32) -> 2 Vars per ScaledArray assert jaxpr.invars[0].aval.shape == scaled_input.shape assert jaxpr.invars[1].aval.shape == () @@ -67,7 +68,8 @@ def myfunc(x): scaled_input = scaled_array([1.0, 2.0], 3, dtype=np.float32) jaxpr = jax.make_jaxpr(scaled_func)(scaled_input).jaxpr # One main jit equation. - assert len(jaxpr.eqns) == 1 + # TODO: additional mul in `safe_check_dtypes` mode. 
+ assert len(jaxpr.eqns) in {1, 2} eqn = jaxpr.eqns[0] assert eqn.primitive.name in ("pjit", "xla_call") assert eqn.params["name"] == "myfunc" diff --git a/tests/lax/test_base_scaling_primitives.py b/tests/lax/test_base_scaling_primitives.py index 8a017f4..dbaf497 100644 --- a/tests/lax/test_base_scaling_primitives.py +++ b/tests/lax/test_base_scaling_primitives.py @@ -164,7 +164,7 @@ def fn(arr): return get_data_scale(arr) fn = self.variant(autoscale(fn)) - arr = scaled_array([2, 3], 4, dtype=np.float16) + arr = scaled_array([2, 3], np.float16(4), dtype=np.float16) data, scale = fn(arr) npt.assert_array_equal(data, arr.data) npt.assert_equal(scale, arr.scale)