Skip to content

Commit

Permalink
Add safe_check_dtypes mode to AutoScale interpreter. (#82)
Browse files Browse the repository at this point in the history
Allows checking, in unit tests as well as on large graphs, that the AutoScale graph interpreter is not silently promoting to a more accurate floating-point datatype.
  • Loading branch information
balancap authored Jan 11, 2024
1 parent 4e01b4b commit 293c1c8
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 10 deletions.
41 changes: 34 additions & 7 deletions jax_scaled_arithmetics/core/interpreters.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
)
from jax._src.util import safe_map

from .datatype import DTypeLike, NDArray, ScaledArray, as_scaled_array_base, is_scaled_leaf
from .datatype import Array, DTypeLike, NDArray, ScaledArray, as_scaled_array_base, is_scaled_leaf
from .utils import Pow2RoundMode


Expand Down Expand Up @@ -90,6 +90,12 @@ def _get_aval(val: Any) -> core.ShapedArray:
return core.ShapedArray(shape=val.shape, dtype=val.dtype)


def _get_data(val: Any) -> Array:
    """Return `val.data` when `val` is a ScaledArray; otherwise return `val` unchanged."""
    return val.data if isinstance(val, ScaledArray) else val


def promote_scalar_to_scaled_array(val: Any) -> ScaledArray:
"""Promote a scalar (Numpy, JAX, ...) to a Scaled Array.
Expand Down Expand Up @@ -192,6 +198,8 @@ def wrapped(*args, **kwargs):

def autoscale_jaxpr(jaxpr: core.Jaxpr, consts, *args):
env: Dict[core.Var, ScaledArray] = {}
# Check dtype consistency between normal and scaled modes.
safe_check_dtypes: bool = False

def read(var):
if type(var) is core.Literal:
Expand All @@ -209,6 +217,13 @@ def promote_to_scaled_array(val):
# No promotion rule => just return as such.
return val

def jaxpr_eqn_bind(eqn: core.JaxprEqn, invals: Sequence[core.ShapedArray]) -> Sequence[core.ShapedArray]:
    """Bind the primitive of `eqn` on `invals` and return its outputs as a sequence.

    Args:
        eqn: Equation whose primitive (and params) are bound.
        invals: Input values passed positionally to the primitive, after any sub-functions.

    Returns:
        The primitive outputs. When the primitive does not declare
        `multiple_results`, the single output is wrapped in a one-element list.
    """
    primitive = eqn.primitive
    sub_fns, params = primitive.get_bind_params(eqn.params)
    result = primitive.bind(*sub_fns, *invals, **params)
    # Normalize to a sequence so callers can always iterate over outputs.
    return result if primitive.multiple_results else [result]

# A few initial checks to make sure there is consistency.
assert len(jaxpr.invars) == len(args)
safe_map(write, jaxpr.invars, args)
Expand All @@ -223,19 +238,31 @@ def promote_to_scaled_array(val):

if not any_scaled_inputs and scaled_prim_type != ScaledPrimitiveType.ALWAYS_SCALE:
# Using normal JAX primitive: no scaled inputs, and not always scale rule.
subfuns, bind_params = eqn.primitive.get_bind_params(eqn.params)
outvals = eqn.primitive.bind(*subfuns, *invals, **bind_params)
outvals = jaxpr_eqn_bind(eqn, invals)
elif scaled_prim_fn is None:
raise NotImplementedError(
f"'{eqn.primitive}' JAX primitive does not have an implementation for ScaledArray inputs yet."
)
else:
# Using scaled primitive. Automatic promotion of inputs to scaled array, when possible.
invals = list(map(promote_to_scaled_array, invals))
outvals = scaled_prim_fn(*invals, **eqn.params)
scaled_invals = list(map(promote_to_scaled_array, invals))
outvals = scaled_prim_fn(*scaled_invals, **eqn.params)
if not eqn.primitive.multiple_results:
outvals = [outvals]

# Check consistency with normal JAX mode. Help catching dtype promotion errors.
# NOTE: ignoring when no outputs! (e.g. debug_callback).
if safe_check_dtypes and len(outvals) > 0:
ref_outvals = jaxpr_eqn_bind(eqn, [_get_data(v) for v in invals])
data_outvals = [_get_data(v) for v in outvals]
# Check scaled dtypes == ref dtypes.
ref_dtypes = tuple(v.dtype for v in ref_outvals)
data_dtypes = tuple(v.dtype for v in data_outvals)
if data_dtypes != ref_dtypes:
raise ValueError(
f"Output dtype of '{eqn.primitive}' scaled translation is not consistent with the JAX reference primitive implementation: {data_dtypes} vs {ref_dtypes}."
)

if not eqn.primitive.multiple_results:
outvals = [outvals]
safe_map(write, eqn.outvars, outvals)

outvals = safe_map(read, jaxpr.outvars)
Expand Down
6 changes: 4 additions & 2 deletions tests/core/test_interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,8 @@ def func(x):
scaled_input = scaled_array([1.0, 2.0], 3, dtype=np.float32)
jaxpr = jax.make_jaxpr(scaled_func)(scaled_input).jaxpr
# Need 3 equations: 2 mul + 1 cast.
assert len(jaxpr.eqns) == 3
# TODO: additional mul in `safe_check_dtypes` mode.
assert len(jaxpr.eqns) in {3, 4}
# Vars need to be primitive data types (e.g., f32) -> 2 Vars per ScaledArray
assert jaxpr.invars[0].aval.shape == scaled_input.shape
assert jaxpr.invars[1].aval.shape == ()
Expand All @@ -67,7 +68,8 @@ def myfunc(x):
scaled_input = scaled_array([1.0, 2.0], 3, dtype=np.float32)
jaxpr = jax.make_jaxpr(scaled_func)(scaled_input).jaxpr
# One main jit equation.
assert len(jaxpr.eqns) == 1
# TODO: additional mul in `safe_check_dtypes` mode.
assert len(jaxpr.eqns) in {1, 2}
eqn = jaxpr.eqns[0]
assert eqn.primitive.name in ("pjit", "xla_call")
assert eqn.params["name"] == "myfunc"
Expand Down
2 changes: 1 addition & 1 deletion tests/lax/test_base_scaling_primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ def fn(arr):
return get_data_scale(arr)

fn = self.variant(autoscale(fn))
arr = scaled_array([2, 3], 4, dtype=np.float16)
arr = scaled_array([2, 3], np.float16(4), dtype=np.float16)
data, scale = fn(arr)
npt.assert_array_equal(data, arr.data)
npt.assert_equal(scale, arr.scale)
Expand Down

0 comments on commit 293c1c8

Please sign in to comment.