From 8e14725cddc03263cc2fc5cc9b8cb1fdebd70f66 Mon Sep 17 00:00:00 2001
From: Paul Balanca
Date: Mon, 29 Jan 2024 11:39:41 +0000
Subject: [PATCH] Implement AutoScale/Scalify TracerMetaArray data structure. (#95)

Introducing the dataclass `ScalifyTracerArray` in the JSA interpreter/tracer,
so that additional metadata can be passed along with each array (e.g. whether
it is a broadcasted scalar tensor).
---
 jax_scaled_arithmetics/core/interpreters.py | 189 +++++++++++++++-----
 jax_scaled_arithmetics/core/utils.py        |  16 ++
 tests/core/test_interpreter.py              |  54 +++++-
 tests/core/test_utils.py                    |  14 +-
 4 files changed, 225 insertions(+), 48 deletions(-)

diff --git a/jax_scaled_arithmetics/core/interpreters.py b/jax_scaled_arithmetics/core/interpreters.py
index 9add00a..8de426d 100644
--- a/jax_scaled_arithmetics/core/interpreters.py
+++ b/jax_scaled_arithmetics/core/interpreters.py
@@ -2,7 +2,7 @@
 from dataclasses import dataclass
 from enum import IntEnum
 from functools import partial, wraps
-from typing import Any, Dict, Optional, Sequence, Tuple
+from typing import Any, Dict, List, Optional, Sequence, Tuple, Union

 import jax
 import numpy as np
@@ -14,9 +14,10 @@
     custom_vjp_call_p,
 )
 from jax._src.util import safe_map
+from jax.tree_util import register_pytree_node_class

-from .datatype import Array, DTypeLike, ScaledArray, as_scaled_array_base, is_scaled_leaf
-from .utils import Pow2RoundMode
+from .datatype import Array, ArrayTypes, DTypeLike, ScaledArray, Shape, as_scaled_array_base, is_scaled_leaf
+from .utils import Pow2RoundMode, python_scalar_as_numpy


 @dataclass(frozen=True)
@@ -68,7 +69,17 @@ class ScaledPrimitiveType(IntEnum):
     ALWAYS_SCALE = 2


+_scaled_jaxpr_ops_registry: Dict[core.Primitive, Any] = {}
+"""Registry of (sub) "jaxpr" ops/primitives and their scaled translation.
+
+The core "jaxpr" primitives are typically `pjit` and `xla_call`, where the JSA interpreter
+needs to be run on sub-jaxprs, passing the full metadata on input/output tensors.
+"""
+
+
 _scaled_ops_registry: Dict[core.Primitive, Tuple[Any, ScaledPrimitiveType]] = {}
+"""Registry of common JAX primitives and their scaled translations.
+"""


 def _get_lax_prim(scaled_func: Any) -> core.Primitive:
@@ -118,6 +129,8 @@ def register_scaled_op(
         scaled_type: Scaled primitive type => behaviour when `autoscale` tracing.
     """
     assert isinstance(prim, core.Primitive)
+    # Cannot register a jaxpr-type op this way.
+    assert prim not in _scaled_jaxpr_ops_registry
     if prim in _scaled_ops_registry:
         raise KeyError(f"A scaled translation is already registered for the JAX primitive '{prim}'.")
     _scaled_ops_registry[prim] = (scaled_func, scaled_type)
@@ -147,6 +160,80 @@ def find_registered_scaled_op(prim: core.Primitive) -> Tuple[Any, ScaledPrimitiv
     return _scaled_ops_registry.get(prim, (None, ScaledPrimitiveType.NEVER))


+def promote_to_scaled_array(val, scale_dtype: Optional[DTypeLike] = None):
+    if isinstance(val, ScaledArray):
+        return val
+    elif np.ndim(val) == 0:
+        return promote_scalar_to_scaled_array(val, scale_dtype)
+    # No promotion rule => just return as such.
+    return val
+
+
+@register_pytree_node_class
+@dataclass(frozen=True, init=False)
+class ScalifyTracerArray:
+    """Meta-array class used in the scalify tracer. It can represent
+    any array, scaled or not, and tracks whether the array corresponds to a broadcasted scalar.
+
+    Compatible with JAX PyTrees, so that a graph can be traced with `ScalifyTracerArray`
+    inputs/outputs.
+
+    Args:
+        array: Normal or scaled array.
+        is_broadcasted_scalar: Is the array a broadcasted scalar (metadata).
+    """
+
+    array: Union[Array, ScaledArray] = None
+    is_broadcasted_scalar: bool = False
+
+    def __init__(self, arr: Union[Array, ScaledArray], is_broadcasted_scalar: Optional[bool] = None) -> None:
+        # Convert Python scalars, if necessary.
+        arr = python_scalar_as_numpy(arr)
+        assert isinstance(arr, (np.bool_, np.number, np.ndarray, ScaledArray, *ArrayTypes))
+        object.__setattr__(self, "array", arr)
+        # Optional `is_broadcasted_scalar` metadata.
+        is_scalar = self.array.size == 1
+        is_broadcasted_scalar = is_scalar if is_broadcasted_scalar is None else is_broadcasted_scalar or is_scalar
+        object.__setattr__(self, "is_broadcasted_scalar", is_broadcasted_scalar)
+
+    def tree_flatten(self):
+        # See official JAX documentation on extending PyTrees.
+        # Note: using explicit tree flatten instead of chex for MyPy compatibility.
+        children = (self.array,)
+        aux_data = (self.is_broadcasted_scalar,)
+        return (children, aux_data)
+
+    @classmethod
+    def tree_unflatten(cls, aux_data, children):
+        # See official JAX documentation on extending PyTrees.
+        assert len(aux_data) == 1
+        assert len(children) == 1
+        return cls(children[0], aux_data[0])
+
+    @property
+    def size(self) -> int:
+        return self.array.size
+
+    @property
+    def shape(self) -> Shape:
+        return self.array.shape
+
+    @property
+    def is_scaled_array(self) -> bool:
+        return isinstance(self.array, ScaledArray)
+
+    def to_scaled_array(self, scale_dtype: Optional[DTypeLike] = None) -> ScaledArray:
+        if self.is_scaled_array:
+            return self.array
+        # TODO: improve the logic for broadcasted scalar arrays!
+        return promote_to_scaled_array(self.array, scale_dtype)
+
+    def to_array(self) -> Array:
+        if not self.is_scaled_array:
+            return self.array
+        return self.array.to_array()
+
+
 def autoscale(fun):
     """`autoscale` JAX graph transformation.
@@ -174,8 +261,12 @@ def wrapped(*args, **kwargs):
         # Flattening of PyTree inputs.
         inputs_scaled = args
         inputs_scaled_flat, _ = jax.tree_util.tree_flatten(inputs_scaled, is_leaf=is_scaled_leaf)
+        # Convert to Scalify tracer (meta) arrays.
+        inputs_tracer_flat = list(map(ScalifyTracerArray, inputs_scaled_flat))
+        consts_tracer_flat = list(map(ScalifyTracerArray, closed_jaxpr.literals))
         # Trace the graph & convert to scaled one.
-        outputs_scaled_flat = autoscale_jaxpr(closed_jaxpr.jaxpr, closed_jaxpr.literals, *inputs_scaled_flat)
+        outputs_tracer_flat = autoscale_jaxpr(closed_jaxpr.jaxpr, consts_tracer_flat, *inputs_tracer_flat)
+        outputs_scaled_flat = [v.array for v in outputs_tracer_flat]
         # Reconstruct the output Pytree, with scaled arrays.
         # NOTE: this step is also handling single vs multi outputs.
         assert len(out_leaves) == len(outputs_scaled_flat)
@@ -185,66 +276,73 @@ def wrapped(*args, **kwargs):
     return wrapped


-def autoscale_jaxpr(jaxpr: core.Jaxpr, consts, *args):
-    env: Dict[core.Var, ScaledArray] = {}
+def jaxpr_eqn_bind(eqn: core.JaxprEqn, invals: Sequence[core.ShapedArray]) -> Sequence[core.ShapedArray]:
+    """Bind a Jaxpr equation to arrays."""
+    subfuns, bind_params = eqn.primitive.get_bind_params(eqn.params)
+    outvals = eqn.primitive.bind(*subfuns, *invals, **bind_params)
+    if not eqn.primitive.multiple_results:
+        outvals = [outvals]
+    return outvals
+
+
+def autoscale_jaxpr(jaxpr: core.Jaxpr, consts: Sequence[ScalifyTracerArray], *args: ScalifyTracerArray):
+    env: Dict[core.Var, ScalifyTracerArray] = {}
     # Check dtype consistency between normal and scaled modes.
     safe_check_dtypes: bool = False
     # AutoScale config to use.
     autoscale_cfg = get_autoscale_config()

-    def read(var):
+    def read(var) -> ScalifyTracerArray:
         if type(var) is core.Literal:
-            return var.val
+            # Wrap the constant in a tracer array.
+            return ScalifyTracerArray(var.val)
         return env[var]

-    def write(var, val):
+    def write(var, val: ScalifyTracerArray):
         env[var] = val

-    def promote_to_scaled_array(val, scale_dtype):
-        if isinstance(val, ScaledArray):
-            return val
-        elif np.ndim(val) == 0:
-            return promote_scalar_to_scaled_array(val, scale_dtype)
-        # No promotion rule => just return as such.
-        return val
-
-    def jaxpr_eqn_bind(eqn: core.JaxprEqn, invals: Sequence[core.ShapedArray]) -> Sequence[core.ShapedArray]:
-        subfuns, bind_params = eqn.primitive.get_bind_params(eqn.params)
-        outvals = eqn.primitive.bind(*subfuns, *invals, **bind_params)
-        if not eqn.primitive.multiple_results:
-            outvals = [outvals]
-        return outvals
-
     # A few initial checks to make sure there is consistency.
     assert len(jaxpr.invars) == len(args)
     safe_map(write, jaxpr.invars, args)
     safe_map(write, jaxpr.constvars, consts)

     for eqn in jaxpr.eqns:
-        invals = safe_map(read, eqn.invars)
+        invals_tracer: List[ScalifyTracerArray] = safe_map(read, eqn.invars)
+        if eqn.primitive in _scaled_jaxpr_ops_registry:
+            # Core sub-jaxpr primitive => pass the complete tracer array with metadata.
+            scaled_jaxpr_prim_fn = _scaled_jaxpr_ops_registry[eqn.primitive]
+            outvals_tracer = scaled_jaxpr_prim_fn(*invals_tracer, **eqn.params)
+            # Save outputs and move on!
+            safe_map(write, eqn.outvars, outvals_tracer)
+            continue
+
+        # Common (scaled) JAX primitives path.
         # Is there any ScaledArray among inputs?
-        any_scaled_inputs = any([isinstance(v, ScaledArray) for v in invals])
+        any_scaled_inputs = any([v.is_scaled_array for v in invals_tracer])
         # Is there a scaled primitive associated?
         scaled_prim_fn, scaled_prim_type = _scaled_ops_registry.get(eqn.primitive, (None, ScaledPrimitiveType.NEVER))

         if not any_scaled_inputs and scaled_prim_type != ScaledPrimitiveType.ALWAYS_SCALE:
             # Using normal JAX primitive: no scaled inputs, and not always scale rule.
+            invals = [v.to_array() for v in invals_tracer]
             outvals = jaxpr_eqn_bind(eqn, invals)
+            outvals_tracer = list(map(ScalifyTracerArray, outvals))
         elif scaled_prim_fn is None:
             raise NotImplementedError(
                 f"'{eqn.primitive}' JAX primitive does not have an implementation for ScaledArray inputs yet."
             )
         else:
             # Using scaled primitive. Automatic promotion of inputs to scaled array, when possible.
-            scaled_invals = list(map(lambda v: promote_to_scaled_array(v, autoscale_cfg.scale_dtype), invals))
+            scaled_invals = [v.to_scaled_array(autoscale_cfg.scale_dtype) for v in invals_tracer]
             outvals = scaled_prim_fn(*scaled_invals, **eqn.params)
             if not eqn.primitive.multiple_results:
                 outvals = [outvals]
+            outvals_tracer = list(map(ScalifyTracerArray, outvals))

         # Check consistency with normal JAX mode. Help catching dtype promotion errors.
         # NOTE: ignoring when no outputs! (e.g. debug_callback).
         if safe_check_dtypes and len(outvals) > 0:
-            ref_outvals = jaxpr_eqn_bind(eqn, [_get_data(v) for v in invals])
+            ref_outvals = jaxpr_eqn_bind(eqn, [_get_data(v.array) for v in invals_tracer])
             data_outvals = [_get_data(v) for v in outvals]
             # Check scaled dtypes == ref dtypes.
             ref_dtypes = tuple(v.dtype for v in ref_outvals)
@@ -254,13 +352,13 @@ def jaxpr_eqn_bind(eqn: core.JaxprEqn, invals: Sequence[core.ShapedArray]) -> Se
                 f"Output dtype of '{eqn.primitive}' scaled translation is not consistent with the JAX reference primitive implementation: {data_dtypes} vs {ref_dtypes}."
             )

-        safe_map(write, eqn.outvars, outvals)
+        safe_map(write, eqn.outvars, outvals_tracer)

-    outvals = safe_map(read, jaxpr.outvars)
-    return outvals
+    outvals_tracer = safe_map(read, jaxpr.outvars)
+    return outvals_tracer


-def scaled_pjit_translation(*args: ScaledArray, **kwargs: Any) -> Sequence[ScaledArray]:
+def scaled_pjit_translation(*args: ScalifyTracerArray, **kwargs: Any) -> Sequence[ScalifyTracerArray]:
     """Scaled translation of `pjit`. Basically re-running `autoscale` on sub-jaxpr.

     NOTE: the `pjit` call will be kept, forwarding the proper parameters (shardings, ...).
     """
@@ -274,24 +372,24 @@ def scaled_pjit_translation(*args: ScaledArray, **kwargs: Any) -> Sequence[Scale
     # in_shardings = kwargs["in_shardings"]
     # out_shardings = kwargs["out_shardings"]

+    consts_tracer_flat = [ScalifyTracerArray(v) for v in closed_jaxpr.literals]
     # Generate the sub-scaled function, with proper `jax.jit` options.
-    subfunc = partial(autoscale_jaxpr, closed_jaxpr.jaxpr, closed_jaxpr.literals)
+    subfunc = partial(autoscale_jaxpr, closed_jaxpr.jaxpr, consts_tracer_flat)
     subfunc.__name__ = name  # type:ignore
     subfunc = jax.jit(subfunc, inline=inline, keep_unused=keep_unused)
-
-    outputs_scaled_flat = subfunc(*args)
-    return outputs_scaled_flat
+    outvals = subfunc(*args)
+    return outvals


 try:
     from jax._src.pjit import pjit_p

-    register_scaled_op(pjit_p, scaled_pjit_translation)
+    _scaled_jaxpr_ops_registry[pjit_p] = scaled_pjit_translation
 except (ImportError, ModuleNotFoundError):
     pass


-def scaled_xla_call_translation(*args: ScaledArray, **kwargs: Any) -> Sequence[ScaledArray]:
+def scaled_xla_call_translation(*args: ScalifyTracerArray, **kwargs: Any) -> Sequence[ScalifyTracerArray]:
     """Scaled translation of `xla_call`. Basically re-running `autoscale` on sub-jaxpr.

     Useful for JAX 0.3 compatibility
@@ -310,7 +408,6 @@ def scaled_xla_call_translation(*args: ScaledArray, **kwargs: Any) -> Sequence[S
     subfunc = partial(autoscale_jaxpr, jaxpr, [])
     subfunc.__name__ = name  # type:ignore
     subfunc = jax.jit(subfunc, inline=inline, keep_unused=keep_unused)
-
     outputs_scaled_flat = subfunc(*args)
     return outputs_scaled_flat

@@ -318,12 +415,12 @@ def scaled_xla_call_translation(*args: ScaledArray, **kwargs: Any) -> Sequence[S
 try:
     from jax.interpreters.xla import xla_call_p

-    register_scaled_op(xla_call_p, scaled_xla_call_translation)
+    _scaled_jaxpr_ops_registry[xla_call_p] = scaled_xla_call_translation
 except (ImportError, ModuleNotFoundError):
     pass


-def scaled_custom_jvp_call_translation(*args: ScaledArray, **params: Any) -> Sequence[ScaledArray]:
+def scaled_custom_jvp_call_translation(*args: ScalifyTracerArray, **params: Any) -> Sequence[ScalifyTracerArray]:
     """Scaled translation of `custom_jvp_call` primitive.
     Forwarding the scaled call to sub-jaxpr, and modifying the underlying `jvp` function.
     """
@@ -337,11 +434,11 @@ def scaled_custom_jvp_call_translation(*args: ScaledArray, **params: Any) -> Seq
         return call_subfunc(*args)


-register_scaled_op(custom_jvp_call_p, scaled_custom_jvp_call_translation)
-register_scaled_op(custom_jvp_call_jaxpr_p, scaled_custom_jvp_call_translation)
+_scaled_jaxpr_ops_registry[custom_jvp_call_p] = scaled_custom_jvp_call_translation
+_scaled_jaxpr_ops_registry[custom_jvp_call_jaxpr_p] = scaled_custom_jvp_call_translation


-def scaled_custom_vjp_call_translation(*args: ScaledArray, **params: Any) -> Sequence[ScaledArray]:
+def scaled_custom_vjp_call_translation(*args: ScalifyTracerArray, **params: Any) -> Sequence[ScalifyTracerArray]:
     """Scaled translation of `custom_vjp_call` primitive.
     Forwarding the scaled call to sub-jaxpr, and modifying the underlying `vjp` function.
     """
@@ -352,5 +449,5 @@ def scaled_custom_vjp_call_translation(*args: ScaledArray, **params: Any) -> Seq
         return call_subfunc(*args)


-register_scaled_op(custom_vjp_call_p, scaled_custom_vjp_call_translation)
-register_scaled_op(custom_vjp_call_jaxpr_p, scaled_custom_vjp_call_translation)
+_scaled_jaxpr_ops_registry[custom_vjp_call_p] = scaled_custom_vjp_call_translation
+_scaled_jaxpr_ops_registry[custom_vjp_call_jaxpr_p] = scaled_custom_vjp_call_translation
diff --git a/jax_scaled_arithmetics/core/utils.py b/jax_scaled_arithmetics/core/utils.py
index 9e8ef96..6d399ca 100644
--- a/jax_scaled_arithmetics/core/utils.py
+++ b/jax_scaled_arithmetics/core/utils.py
@@ -102,3 +102,19 @@ def safe_reciprocal(val: Array) -> Array:
         return np.reciprocal(val, out=np.array(0, dtype=val.dtype), where=val != 0)
     # JAX general implementation.
     return jax.lax.select(val == 0, val, jax.lax.reciprocal(val))
+
+
+def python_scalar_as_numpy(val: Any) -> Any:
+    """Convert a Python scalar to a NumPy scalar, if possible.
+
+    Uses JAX's default 32-bit precision, instead of 64-bit.
+
+    Returns the value unchanged if it is not a bool, int or float.
+    """
+    if isinstance(val, bool):
+        return np.bool_(val)
+    elif isinstance(val, int):
+        return np.int32(val)
+    elif isinstance(val, float):
+        return np.float32(val)
+    return val
diff --git a/tests/core/test_interpreter.py b/tests/core/test_interpreter.py
index abebc69..37681b9 100644
--- a/tests/core/test_interpreter.py
+++ b/tests/core/test_interpreter.py
@@ -19,7 +19,59 @@
     register_scaled_op,
     scaled_array,
 )
-from jax_scaled_arithmetics.core.interpreters import promote_scalar_to_scaled_array
+from jax_scaled_arithmetics.core.interpreters import ScalifyTracerArray, promote_scalar_to_scaled_array
+
+
+class ScalifyTracerArrayTests(chex.TestCase):
+    @parameterized.parameters(
+        {"arr": True},
+        {"arr": 2},
+        {"arr": 3.0},
+    )
+    def test__scalify_tracer_array__init__from_python_value(self, arr):
+        tracer_arr = ScalifyTracerArray(arr)
+        assert tracer_arr.array == arr
+        assert not tracer_arr.is_scaled_array
+        assert tracer_arr.is_broadcasted_scalar == (tracer_arr.size == 1)
+        assert tracer_arr.to_array() is tracer_arr.array
+
+    @parameterized.parameters(
+        {"arr": np.float32(2)},
+        {"arr": np.array([1, 2])},
+        {"arr": jnp.array([3, 4])},
+    )
+    def test__scalify_tracer_array__init__from_normal_array(self, arr):
+        tracer_arr = ScalifyTracerArray(arr)
+        assert tracer_arr.array is arr
+        assert not tracer_arr.is_scaled_array
+        assert tracer_arr.is_broadcasted_scalar == (tracer_arr.size == 1)
+        assert tracer_arr.to_array() is tracer_arr.array
+        # Basic properties.
+        assert tracer_arr.shape == arr.shape
+        assert tracer_arr.size == arr.size
+
+    @parameterized.parameters({"arr": scaled_array([1, 2], 3.0)})
+    def test__scalify_tracer_array__init__from_scaled_array(self, arr):
+        tracer_arr = ScalifyTracerArray(arr)
+        assert tracer_arr.array is arr
+        assert tracer_arr.is_scaled_array
+        assert tracer_arr.to_scaled_array() is tracer_arr.array
+
+    def test__scalify_tracer_array__init__is_broadcasted_scalar_kwarg(self):
+        arr = scaled_array([1, 2], 3.0)
+        assert ScalifyTracerArray(arr, is_broadcasted_scalar=True).is_broadcasted_scalar
+        assert not ScalifyTracerArray(arr, is_broadcasted_scalar=False).is_broadcasted_scalar
+
+    def test__scalify_tracer_array__flatten__proper_pytree(self):
+        arr = scaled_array([1, 2], 3.0)
+        tracer_arr_in = ScalifyTracerArray(arr, True)
+        # Proper round trip!
+        flat_arrays, pytree = jax.tree_util.tree_flatten(tracer_arr_in)
+        tracer_arr_out = jax.tree_util.tree_unflatten(pytree, flat_arrays)
+
+        assert isinstance(tracer_arr_out, ScalifyTracerArray)
+        assert tracer_arr_out.is_broadcasted_scalar == tracer_arr_in.is_broadcasted_scalar
+        npt.assert_array_equal(np.asarray(tracer_arr_out.array), np.asarray(tracer_arr_in.array))


 class AutoScaleInterpreterTests(chex.TestCase):
diff --git a/tests/core/test_utils.py b/tests/core/test_utils.py
index e174a34..0640550 100644
--- a/tests/core/test_utils.py
+++ b/tests/core/test_utils.py
@@ -6,7 +6,13 @@
 from absl.testing import parameterized

 from jax_scaled_arithmetics.core import Array, pow2_round_down, pow2_round_up
-from jax_scaled_arithmetics.core.utils import _exponent_bits_mask, get_mantissa, safe_div, safe_reciprocal
+from jax_scaled_arithmetics.core.utils import (
+    _exponent_bits_mask,
+    get_mantissa,
+    python_scalar_as_numpy,
+    safe_div,
+    safe_reciprocal,
+)


 class Pow2RoundingUtilTests(chex.TestCase):
@@ -90,3 +96,9 @@ def test__safe_reciprocal__zero_div(self, val):
         out = safe_reciprocal(val)
         assert out.dtype == val.dtype
         npt.assert_almost_equal(out, 0)
+
+
+def test__python_scalar_as_numpy__proper_conversion():
+    npt.assert_equal(python_scalar_as_numpy(False), np.bool_(False))
+    npt.assert_equal(python_scalar_as_numpy(4), np.int32(4))
+    npt.assert_equal(python_scalar_as_numpy(3.2), np.float32(3.2))
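
Usage sketch: a minimal example of how the new `ScalifyTracerArray` behaves,
exercising only the API introduced in this patch. Import paths follow the
tests above; anything beyond what the diff shows is an assumption.

    import jax
    import numpy as np

    from jax_scaled_arithmetics.core import scaled_array
    from jax_scaled_arithmetics.core.interpreters import ScalifyTracerArray

    # Wrap a ScaledArray and tag it as a broadcasted scalar (pure metadata).
    tracer_arr = ScalifyTracerArray(scaled_array([1, 2], 3.0), is_broadcasted_scalar=True)
    assert tracer_arr.is_scaled_array

    # PyTree round trip: the wrapped array travels as a leaf, while the
    # `is_broadcasted_scalar` flag travels as static aux_data.
    leaves, treedef = jax.tree_util.tree_flatten(tracer_arr)
    restored = jax.tree_util.tree_unflatten(treedef, leaves)
    assert restored.is_broadcasted_scalar

    # Python scalars are normalized by `python_scalar_as_numpy` to NumPy
    # scalars at JAX's default 32-bit precision.
    assert ScalifyTracerArray(3.0).array == np.float32(3.0)

This is why the tracer arrays can be passed straight through `jax.jit` in the
`pjit`/`xla_call` translations above: only the raw data is traced, while the
broadcasted-scalar metadata stays static across the sub-jaxpr boundary.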