diff --git a/README.md b/README.md index 97d97db..2775411 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@ -# JAX Scaled Arithmetics +# JAX Scalify: end-to-end scaled Arithmetics -**JAX Scaled Arithmetics** is a thin library implementing numerically stable scaled arithmetics, allowing easy training and inference of +**JAX Scalify** is a thin library implementing numerically stable scaled arithmetics, allowing easy training and inference of deep neural networks in low precision (BF16, FP16, FP8). Loss scaling, tensor scaling and block scaling have been widely used in the deep learning literature to unlock training and inference at lower precision. Usually, these works have focused on ad-hoc approaches around scaling of matmuls (and sometimes reduction operations). The JSA library is adopting a more systematic approach by transforming the full computational graph into a `ScaledArray` graph, i.e. every operation taking `ScaledArray` inputs and returning `ScaledArray`, where the latter is a simple datastructure: @@ -14,14 +14,14 @@ class ScaledArray: return data * scale ``` -A typical JAX training loop requires just a few modifications to take advantage of `autoscale`: +A typical JAX training loop requires just a few modifications to take advantage of `scalify`: ```python -import jax_scaled_arithmetics as jsa +import jax_scalify as jsa params = jsa.as_scaled_array(params) @jit -@jsa.autoscale +@jsa.scalify def update(params, batch): grads = grad(loss)(params, batch) return opt_update(params, grads) @@ -30,7 +30,7 @@ for batch in batches: batch = jsa.as_scaled_array(batch) params = update(params, batch) ``` -In other words: model parameters and micro-batch are converted to `ScaledArray` objects, and the decorator `jsa.autoscale` properly transforms the graph into a scaled arithmetics graph (see the [MNIST examples](./experiments/mnist/) for more details). +In other words: model parameters and micro-batch are converted to `ScaledArray` objects, and the decorator `jsa.scalify` properly transforms the graph into a scaled arithmetics graph (see the [MNIST examples](./experiments/mnist/) for more details). There are multiple benefits to this systematic approach: @@ -46,7 +46,7 @@ There are multiple benefits to this systematic approach: JSA library can be easily installed in Python virtual environnment: ```bash -git clone git@github.com:graphcore-research/jax-scaled-arithmetics.git +git clone git@github.com:graphcore-research/jax-scalify.git pip install -e ./ ``` The main dependencies are `numpy`, `jax` and `chex` libraries. diff --git a/docs/design.md b/docs/design.md index 3b39e05..e324cb4 100644 --- a/docs/design.md +++ b/docs/design.md @@ -1,6 +1,6 @@ -# AutoScale: stable scaled arithmetics +# Scalify: stable scaled arithmetics ## Introduction @@ -63,7 +63,7 @@ Summary of how it would compare to existing methods (credits to @thecharlieblake | ~ perfect scale at initialisation | We can use the "unit scaling rules" to achieve approximately unit variance for the first fwd & bwd pass. Without this, our tensors immediately require re-scaling. Starting in the ideal range may mean few re-scalings are required. | Apply appropriate scaling factors to the output of our operations (as determined in the unit scaling work). | --> -## AutoScale: JAX scaled arithmetics +## JAX Scalify: end-to-end scaled arithmetics We focus here on a JAX implementation, but it should be possible to adopt a similar approach in Pytorch (with tracing or dynamo?). 
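For quick reference alongside the design notes, a minimal end-to-end sketch of the renamed user-facing API; it uses only names introduced by this diff (`jax_scalify`, `jsa.scalify`, `jsa.as_scaled_array`), and the values shown in the comments are approximate:

```python
import numpy as np
import jax
import jax_scalify as jsa

# `scalify` traces the function and substitutes scaled translations of the
# JAX primitives whenever the inputs are `ScaledArray`s.
@jax.jit
@jsa.scalify
def fn(a, b):
    return a + b

# A ScaledArray stores a (data, scale) pair and semantically represents `data * scale`.
sa = jsa.as_scaled_array(np.array([1.0, 2.0], dtype=np.float32), scale=np.float32(1))
sb = jsa.as_scaled_array(np.array([3.0, 4.0], dtype=np.float32), scale=np.float32(1))

sout = fn(sa, sb)              # scaled inputs -> scaled output
print(sout.data, sout.scale)   # data component and (power-of-2 rounded) scale
print(np.asarray(sout))        # ~[4. 6.], matching the unscaled computation
```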
Modern ML frameworks expose their IR at the Python level, allowing users to perform complex transforms on the computational graph without any modification to the C++ backend (XLA, ...). diff --git a/docs/operators.md b/docs/operators.md index f0b0b97..2a79342 100644 --- a/docs/operators.md +++ b/docs/operators.md @@ -1,6 +1,6 @@ # JAX Scaled Operators coverage -Summary of JAX LAX operators supported in `autoscale` graph transformation. +Summary of JAX LAX operators supported in `scalify` graph transformation. ## [JAX LAX operations](https://jax.readthedocs.io/en/latest/jax.lax.html) diff --git a/examples/autoscale-quickstart.ipynb b/examples/autoscale-quickstart.ipynb index fb83cdd..f273d37 100644 --- a/examples/autoscale-quickstart.ipynb +++ b/examples/autoscale-quickstart.ipynb @@ -20,7 +20,7 @@ "source": [ "import numpy as np\n", "import jax\n", - "import jax_scaled_arithmetics as jsa" + "import jax_scalify as jsa" ] }, { @@ -38,8 +38,8 @@ "metadata": {}, "outputs": [], "source": [ - "# `autoscale` interpreter is tracing the graph, adding scale propagation where necessary.\n", - "@jsa.autoscale\n", + "# `scalify` interpreter is tracing the graph, adding scale propagation where necessary.\n", + "@jsa.scalify\n", "def fn(a, b):\n", " return a + b" ] @@ -127,13 +127,13 @@ } ], "source": [ - "# Running `fn` on scaled arrays triggers `autoscale` graph transformation\n", + "# Running `fn` on scaled arrays triggers `scalify` graph transformation\n", "sout = fn(sa, sb)\n", "# NOTE: by default, scale propagation is using power-of-2.\n", "print(\"SCALED OUTPUT:\", sout)\n", "\n", "# To choose a different scale rounding:\n", - "with jsa.AutoScaleConfig(rounding_mode=jsa.Pow2RoundMode.NONE):\n", + "with jsa.ScalifyConfig(rounding_mode=jsa.Pow2RoundMode.NONE):\n", " print(\"No scale rounding:\", fn(sa, sb))" ] }, @@ -228,7 +228,7 @@ "# Similarly to `dynamic_rescale`, `cast_ml_dtype(_grad)` are available to cast in forward and backward passes\n", "sc = jsa.as_scaled_array(np.array([17., 19.]), scale=np.float32(1))\n", "\n", - "@jsa.autoscale\n", + "@jsa.scalify\n", "def cast_fn(v):\n", " return jsa.ops.cast_ml_dtype(v, ml_dtypes.float8_e4m3fn)\n", "\n", diff --git a/experiments/mnist/cifar_training.py b/experiments/mnist/cifar_training.py index d1fb772..d715d1d 100644 --- a/experiments/mnist/cifar_training.py +++ b/experiments/mnist/cifar_training.py @@ -27,7 +27,7 @@ import numpy.random as npr from jax import grad, jit, lax -import jax_scaled_arithmetics as jsa +import jax_scalify as jsa def logsumexp(a, axis=None, keepdims=False): @@ -88,8 +88,8 @@ def accuracy(params, batch): if __name__ == "__main__": width = 2048 lr = 1e-4 - use_autoscale = True - autoscale = jsa.autoscale if use_autoscale else lambda f: f + use_scalify = True + scalify = jsa.scalify if use_scalify else lambda f: f layer_sizes = [3072, width, width, 10] param_scale = 1.0 @@ -116,12 +116,12 @@ def data_stream(): batches = data_stream() params = init_random_params(param_scale, layer_sizes) # Transform parameters to `ScaledArray` and proper dtype. 
- if use_autoscale: + if use_scalify: params = jsa.as_scaled_array(params, scale=scale_dtype(param_scale)) params = jax.tree_map(lambda v: v.astype(training_dtype), params, is_leaf=jsa.core.is_scaled_leaf) @jit - @autoscale + @scalify def update(params, batch): grads = grad(loss)(params, batch) return [(w - step_size * dw, b - step_size * db) for (w, b), (dw, db) in zip(params, grads)] @@ -131,11 +131,11 @@ def update(params, batch): for _ in range(num_batches): batch = next(batches) # Scaled micro-batch + training dtype cast. - if use_autoscale: + if use_scalify: batch = jsa.as_scaled_array(batch, scale=scale_dtype(param_scale)) batch = jax.tree_map(lambda v: v.astype(training_dtype), batch, is_leaf=jsa.core.is_scaled_leaf) - with jsa.AutoScaleConfig(rounding_mode=jsa.Pow2RoundMode.DOWN, scale_dtype=scale_dtype): + with jsa.ScalifyConfig(rounding_mode=jsa.Pow2RoundMode.DOWN, scale_dtype=scale_dtype): params = update(params, batch) epoch_time = time.time() - start_time diff --git a/experiments/mnist/flax_example/train.py b/experiments/mnist/flax_example/train.py index 75a5dde..3c09e2c 100644 --- a/experiments/mnist/flax_example/train.py +++ b/experiments/mnist/flax_example/train.py @@ -32,7 +32,7 @@ from flax.metrics import tensorboard from flax.training import train_state -import jax_scaled_arithmetics as jsa +import jax_scalify as jsa class CNN(nn.Module): @@ -75,7 +75,7 @@ def update_model(state, grads): @jax.jit -@jsa.autoscale +@jsa.scalify def apply_and_update_model(state, batch_images, batch_labels): # Jitting together forward + backward + update. grads, loss, accuracy = apply_model(state, batch_images, batch_labels) diff --git a/experiments/mnist/mnist_classifier.py b/experiments/mnist/mnist_classifier.py index 1f052cf..efb058a 100644 --- a/experiments/mnist/mnist_classifier.py +++ b/experiments/mnist/mnist_classifier.py @@ -31,7 +31,7 @@ from jax.example_libraries import optimizers, stax from jax.example_libraries.stax import Dense, LogSoftmax, Relu -import jax_scaled_arithmetics as jsa +import jax_scalify as jsa def loss(params, batch): @@ -75,7 +75,7 @@ def data_stream(): opt_init, opt_update, get_params = optimizers.momentum(step_size, mass=momentum_mass) @jit - @jsa.autoscale + @jsa.scalify def update(i, opt_state, batch): params = get_params(opt_state) return opt_update(i, grad(loss)(params, batch), opt_state) diff --git a/experiments/mnist/mnist_classifier_from_scratch.py b/experiments/mnist/mnist_classifier_from_scratch.py index ccf4268..69cab06 100644 --- a/experiments/mnist/mnist_classifier_from_scratch.py +++ b/experiments/mnist/mnist_classifier_from_scratch.py @@ -27,7 +27,7 @@ import numpy.random as npr from jax import grad, jit, lax -import jax_scaled_arithmetics as jsa +import jax_scalify as jsa # from jax.scipy.special import logsumexp @@ -107,7 +107,7 @@ def data_stream(): params = jax.tree_util.tree_map(lambda v: v.astype(training_dtype), params, is_leaf=jsa.core.is_scaled_leaf) @jit - @jsa.autoscale + @jsa.scalify def update(params, batch): grads = grad(loss)(params, batch) return [(w - step_size * dw, b - step_size * db) for (w, b), (dw, db) in zip(params, grads)] @@ -120,7 +120,7 @@ def update(params, batch): batch = jsa.as_scaled_array(batch, scale=scale_dtype(1)) batch = jax.tree_util.tree_map(lambda v: v.astype(training_dtype), batch, is_leaf=jsa.core.is_scaled_leaf) - with jsa.AutoScaleConfig(rounding_mode=jsa.Pow2RoundMode.DOWN, scale_dtype=scale_dtype): + with jsa.ScalifyConfig(rounding_mode=jsa.Pow2RoundMode.DOWN, scale_dtype=scale_dtype): params = 
update(params, batch) epoch_time = time.time() - start_time diff --git a/experiments/mnist/mnist_classifier_from_scratch_fp8.py b/experiments/mnist/mnist_classifier_from_scratch_fp8.py index 4a84055..01912f3 100644 --- a/experiments/mnist/mnist_classifier_from_scratch_fp8.py +++ b/experiments/mnist/mnist_classifier_from_scratch_fp8.py @@ -28,7 +28,7 @@ import numpy.random as npr from jax import grad, jit, lax -import jax_scaled_arithmetics as jsa +import jax_scalify as jsa # from functools import partial @@ -134,7 +134,7 @@ def data_stream(): params = jax.tree_util.tree_map(lambda v: v.astype(training_dtype), params, is_leaf=jsa.core.is_scaled_leaf) @jit - @jsa.autoscale + @jsa.scalify def update(params, batch): grads = grad(loss)(params, batch) return [(w - step_size * dw, b - step_size * db) for (w, b), (dw, db) in zip(params, grads)] @@ -147,7 +147,7 @@ def update(params, batch): batch = jsa.as_scaled_array(batch, scale=scale_dtype(1)) batch = jax.tree_util.tree_map(lambda v: v.astype(training_dtype), batch, is_leaf=jsa.core.is_scaled_leaf) - with jsa.AutoScaleConfig(rounding_mode=jsa.Pow2RoundMode.DOWN, scale_dtype=scale_dtype): + with jsa.ScalifyConfig(rounding_mode=jsa.Pow2RoundMode.DOWN, scale_dtype=scale_dtype): params = update(params, batch) epoch_time = time.time() - start_time diff --git a/experiments/mnist/optax_cifar_training.py b/experiments/mnist/optax_cifar_training.py index eedb8c7..3a078b6 100644 --- a/experiments/mnist/optax_cifar_training.py +++ b/experiments/mnist/optax_cifar_training.py @@ -28,7 +28,7 @@ import optax from jax import grad, jit, lax -import jax_scaled_arithmetics as jsa +import jax_scalify as jsa def logsumexp(a, axis=None, keepdims=False): @@ -92,9 +92,9 @@ def accuracy(params, batch): if __name__ == "__main__": width = 256 lr = 1e-3 - use_autoscale = False + use_scalify = False training_dtype = np.float32 - autoscale = jsa.autoscale if use_autoscale else lambda f: f + scalify = jsa.scalify if use_scalify else lambda f: f layer_sizes = [3072, width, width, 10] param_scale = 1.0 @@ -123,13 +123,13 @@ def data_stream(): optimizer = optax.adam(learning_rate=lr, eps=1e-5) opt_state = optimizer.init(params) - if use_autoscale: + if use_scalify: params = jsa.as_scaled_array(params, scale=scale_dtype(param_scale)) params = jax.tree_util.tree_map(lambda v: v.astype(training_dtype), params, is_leaf=jsa.core.is_scaled_leaf) @jit - @autoscale + @scalify def update(params, batch, opt_state): grads = grad(loss)(params, batch) updates, opt_state = optimizer.update(grads, opt_state) @@ -141,11 +141,11 @@ def update(params, batch, opt_state): for _ in range(num_batches): batch = next(batches) # Scaled micro-batch + training dtype cast. - if use_autoscale: + if use_scalify: batch = jsa.as_scaled_array(batch, scale=scale_dtype(param_scale)) batch = jax.tree_util.tree_map(lambda v: v.astype(training_dtype), batch, is_leaf=jsa.core.is_scaled_leaf) - with jsa.AutoScaleConfig(rounding_mode=jsa.Pow2RoundMode.DOWN, scale_dtype=scale_dtype): + with jsa.ScalifyConfig(rounding_mode=jsa.Pow2RoundMode.DOWN, scale_dtype=scale_dtype): params, opt_state = update(params, batch, opt_state) epoch_time = time.time() - start_time diff --git a/jax_scaled_arithmetics/__init__.py b/jax_scalify/__init__.py similarity index 88% rename from jax_scaled_arithmetics/__init__.py rename to jax_scalify/__init__.py index 74eea10..838fd2c 100644 --- a/jax_scaled_arithmetics/__init__.py +++ b/jax_scalify/__init__.py @@ -2,12 +2,12 @@ from . 
import core, lax, ops from ._version import __version__ from .core import ( # noqa: F401 - AutoScaleConfig, Pow2RoundMode, ScaledArray, + ScalifyConfig, as_scaled_array, asarray, - autoscale, debug_callback, scaled_array, + scalify, ) diff --git a/jax_scaled_arithmetics/_version.py b/jax_scalify/_version.py similarity index 100% rename from jax_scaled_arithmetics/_version.py rename to jax_scalify/_version.py diff --git a/jax_scaled_arithmetics/core/__init__.py b/jax_scalify/core/__init__.py similarity index 92% rename from jax_scaled_arithmetics/core/__init__.py rename to jax_scalify/core/__init__.py index aefd0dc..d82e7f5 100644 --- a/jax_scaled_arithmetics/core/__init__.py +++ b/jax_scalify/core/__init__.py @@ -15,13 +15,13 @@ ) from .debug import debug_callback # noqa: F401 from .interpreters import ( # noqa: F401 - AutoScaleConfig, ScaledPrimitiveType, - autoscale, + ScalifyConfig, find_registered_scaled_op, - get_autoscale_config, + get_scalify_config, register_scaled_lax_op, register_scaled_op, + scalify, ) from .pow2 import Pow2RoundMode, pow2_decompose, pow2_round, pow2_round_down, pow2_round_up # noqa: F401 from .typing import Array, ArrayTypes, get_numpy_api # noqa: F401 diff --git a/jax_scaled_arithmetics/core/datatype.py b/jax_scalify/core/datatype.py similarity index 100% rename from jax_scaled_arithmetics/core/datatype.py rename to jax_scalify/core/datatype.py diff --git a/jax_scaled_arithmetics/core/debug.py b/jax_scalify/core/debug.py similarity index 100% rename from jax_scaled_arithmetics/core/debug.py rename to jax_scalify/core/debug.py diff --git a/jax_scaled_arithmetics/core/interpreters.py b/jax_scalify/core/interpreters.py similarity index 92% rename from jax_scaled_arithmetics/core/interpreters.py rename to jax_scalify/core/interpreters.py index 4fc1050..0a69705 100644 --- a/jax_scaled_arithmetics/core/interpreters.py +++ b/jax_scalify/core/interpreters.py @@ -31,11 +31,11 @@ @dataclass(frozen=True) -class AutoScaleConfig: - """AutoScale configuration/parameters when tracing a graph. +class ScalifyConfig: + """Scalify configuration/parameters when tracing a graph. NOTE: this config can be locally changed using a Python context manager: - `with AutoScaleConfig(...):` + `with ScalifyConfig(...):` Args: rounding_mode: Power-of-2 rounding mode. @@ -46,27 +46,27 @@ class AutoScaleConfig: scale_dtype: DTypeLike = None def __enter__(self): - global _autoscale_config_stack - _autoscale_config_stack.append(self) + global _scalify_config_stack + _scalify_config_stack.append(self) def __exit__(self, exc_type, exc_val, exc_tb): - global _autoscale_config_stack - _autoscale_config_stack.pop() + global _scalify_config_stack + _scalify_config_stack.pop() -# AutoScale config stack. -_autoscale_config_stack = [AutoScaleConfig()] +# Scalify config stack. +_scalify_config_stack = [ScalifyConfig()] -def get_autoscale_config() -> AutoScaleConfig: - """Get current/local autoscale config.""" - return _autoscale_config_stack[-1] +def get_scalify_config() -> ScalifyConfig: + """Get current/local scalify config.""" + return _scalify_config_stack[-1] class ScaledPrimitiveType(IntEnum): """Scale (JAX) primitive type. - This enum described the behaviour when `autoscale` is + This enum described the behaviour when `scalify` is tracing the graph. FORWARD: Forwarding scaling => only used if scaled inputs. @@ -144,7 +144,7 @@ def register_scaled_op( Args: prim: JAX primitive. scaled_func: Scaled translation of the primitive. With the same interface. 
- scaled_type: Scaled primitive type => behaviour when `autoscale` tracing. + scaled_type: Scaled primitive type => behaviour when `scalify` tracing. """ assert isinstance(prim, core.Primitive) # Can not register a jaxpr type op this way. @@ -294,10 +294,10 @@ def to_array(self) -> Array: return self.array.to_array() -def autoscale(fun): - """`autoscale` JAX graph transformation. +def scalify(fun): + """`scalify` JAX graph transformation. - The `autoscale` graph transformation works in a forwarding mode: + The `scalify` graph transformation works in a forwarding mode: scaled arrays are forwarded to scaled primitives, which will generate scaled outputs. If no inputs to a JAX primitive are scaled -> the normal primitive is then called, generating a common @@ -311,7 +311,7 @@ def autoscale(fun): @wraps(fun) def wrapped(*args, **kwargs): if len(kwargs) > 0: - raise NotImplementedError("`autoscale` JAX interpreter not supporting named tensors at present.") + raise NotImplementedError("`scalify` JAX interpreter not supporting named tensors at present.") aval_args = jax.tree_util.tree_map(_get_aval, args, is_leaf=is_scaled_leaf) # Get jaxpr of unscaled/normal graph. Getting output Pytree shape as well. @@ -325,7 +325,7 @@ def wrapped(*args, **kwargs): inputs_tracer_flat = list(map(ScalifyTracerArray, inputs_scaled_flat)) consts_tracer_flat = list(map(ScalifyTracerArray, closed_jaxpr.literals)) # Trace the graph & convert to scaled one. - outputs_tracer_flat = autoscale_jaxpr(closed_jaxpr.jaxpr, consts_tracer_flat, *inputs_tracer_flat) + outputs_tracer_flat = scalify_jaxpr(closed_jaxpr.jaxpr, consts_tracer_flat, *inputs_tracer_flat) outputs_scaled_flat = [v.array for v in outputs_tracer_flat] # Reconstruct the output Pytree, with scaled arrays. # NOTE: this step is also handling single vs multi outputs. @@ -345,14 +345,14 @@ def jaxpr_eqn_bind(eqn: core.JaxprEqn, invals: Sequence[core.ShapedArray]) -> Se return outvals -def autoscale_jaxpr( +def scalify_jaxpr( jaxpr: core.Jaxpr, consts: Sequence[ScalifyTracerArray], *args: ScalifyTracerArray ) -> Sequence[ScalifyTracerArray]: env: Dict[core.Var, ScalifyTracerArray] = {} # Check dtype consistency between normal and scaled modes. safe_check_dtypes: bool = False - # AutoScale config to use. - autoscale_cfg = get_autoscale_config() + # Scalify config to use. + scalify_cfg = get_scalify_config() def read(var: core.Var) -> ScalifyTracerArray: if type(var) is core.Literal: @@ -411,7 +411,7 @@ def write(var: core.Var, val: ScalifyTracerArray) -> None: ) else: # Using scaled primitive. Automatic promotion of inputs to scaled array, when possible. - scaled_invals = [v.to_scaled_array(autoscale_cfg.scale_dtype) for v in invals_tracer] + scaled_invals = [v.to_scaled_array(scalify_cfg.scale_dtype) for v in invals_tracer] outvals = scaled_prim_fn(*scaled_invals, **eqn.params) if not eqn.primitive.multiple_results: outvals = [outvals] @@ -437,7 +437,7 @@ def write(var: core.Var, val: ScalifyTracerArray) -> None: def scaled_pjit_translation(*args: ScalifyTracerArray, **kwargs: Any) -> Sequence[ScalifyTracerArray]: - """Scaled translation of `pjit`. Basically re-running `autoscale` on sub-jaxpr. + """Scaled translation of `pjit`. Basically re-running `scalify` on sub-jaxpr. NOTE: the `pjit` call will be kept, forwarding the proper parameters (shardings, ...). 
""" @@ -452,7 +452,7 @@ def scaled_pjit_translation(*args: ScalifyTracerArray, **kwargs: Any) -> Sequenc consts_tracer_flat = [ScalifyTracerArray(v) for v in closed_jaxpr.literals] # Generate the sub-scaled function, with proper `jax.jit` options. - subfunc = partial(autoscale_jaxpr, closed_jaxpr.jaxpr, consts_tracer_flat) + subfunc = partial(scalify_jaxpr, closed_jaxpr.jaxpr, consts_tracer_flat) subfunc.__name__ = name # type:ignore subfunc = jax.jit(subfunc, inline=inline, keep_unused=keep_unused) outvals = subfunc(*args) @@ -468,7 +468,7 @@ def scaled_pjit_translation(*args: ScalifyTracerArray, **kwargs: Any) -> Sequenc def scaled_xla_call_translation(*args: ScalifyTracerArray, **kwargs: Any) -> Sequence[ScalifyTracerArray]: - """Scaled translation of `xla_call`. Basically re-running `autoscale` on sub-jaxpr. + """Scaled translation of `xla_call`. Basically re-running `scalify` on sub-jaxpr. Useful for JAX 0.3 compatibility """ @@ -483,7 +483,7 @@ def scaled_xla_call_translation(*args: ScalifyTracerArray, **kwargs: Any) -> Seq assert len(jaxpr.constvars) == 0 # Generate the sub-scaled function, with proper `jax.jit` options. - subfunc = partial(autoscale_jaxpr, jaxpr, []) + subfunc = partial(scalify_jaxpr, jaxpr, []) subfunc.__name__ = name # type:ignore subfunc = jax.jit(subfunc, inline=inline, keep_unused=keep_unused) outputs_scaled_flat = subfunc(*args) @@ -508,7 +508,7 @@ def scaled_custom_jvp_call_translation(*args: ScalifyTracerArray, **params: Any) # JAX 0.3 compatibility. assert params.get("num_consts", 0) == 0 # FIXME: re-call the custom_jvp decorator/bind. - call_subfunc = partial(autoscale_jaxpr, call_closed_jaxpr.jaxpr, call_closed_jaxpr.literals) + call_subfunc = partial(scalify_jaxpr, call_closed_jaxpr.jaxpr, call_closed_jaxpr.literals) return call_subfunc(*args) @@ -523,7 +523,7 @@ def scaled_custom_vjp_call_translation(*args: ScalifyTracerArray, **params: Any) key_jaxpr = "fun_jaxpr" call_closed_jaxpr = params[key_jaxpr] # FIXME: re-call the custom_vjp decorator/bind. 
- call_subfunc = partial(autoscale_jaxpr, call_closed_jaxpr.jaxpr, call_closed_jaxpr.literals) + call_subfunc = partial(scalify_jaxpr, call_closed_jaxpr.jaxpr, call_closed_jaxpr.literals) return call_subfunc(*args) diff --git a/jax_scaled_arithmetics/core/pow2.py b/jax_scalify/core/pow2.py similarity index 100% rename from jax_scaled_arithmetics/core/pow2.py rename to jax_scalify/core/pow2.py diff --git a/jax_scaled_arithmetics/core/typing.py b/jax_scalify/core/typing.py similarity index 100% rename from jax_scaled_arithmetics/core/typing.py rename to jax_scalify/core/typing.py diff --git a/jax_scaled_arithmetics/core/utils.py b/jax_scalify/core/utils.py similarity index 100% rename from jax_scaled_arithmetics/core/utils.py rename to jax_scalify/core/utils.py diff --git a/jax_scaled_arithmetics/lax/__init__.py b/jax_scalify/lax/__init__.py similarity index 100% rename from jax_scaled_arithmetics/lax/__init__.py rename to jax_scalify/lax/__init__.py diff --git a/jax_scaled_arithmetics/lax/base_scaling_primitives.py b/jax_scalify/lax/base_scaling_primitives.py similarity index 92% rename from jax_scaled_arithmetics/lax/base_scaling_primitives.py rename to jax_scalify/lax/base_scaling_primitives.py index e564640..c07f1ec 100644 --- a/jax_scaled_arithmetics/lax/base_scaling_primitives.py +++ b/jax_scalify/lax/base_scaling_primitives.py @@ -7,13 +7,13 @@ from jax.interpreters import mlir from jax.interpreters.mlir import LoweringRuleContext, ir, ir_constant -from jax_scaled_arithmetics.core import ( +from jax_scalify.core import ( Array, DTypeLike, ScaledArray, ScaledPrimitiveType, asarray, - get_autoscale_config, + get_scalify_config, is_static_one_scalar, register_scaled_op, safe_div, @@ -26,7 +26,7 @@ In standard JAX, this is just an identity operation, ignoring the `scale` input, just returning unchanged the `data` component. -In JAX Scaled Arithmetics/AutoScale mode, it will rebalance the data term to +In JAX Scalify mode, it will rebalance the data term to return a ScaledArray semantically equivalent. NOTE: there is specific corner case of passing zero to `set_scaling`. In this @@ -102,8 +102,7 @@ def scaled_set_scaling(values: ScaledArray, scale: ScaledArray) -> ScaledArray: In standard JAX, this is just an identity operation (with optional casting). -In JAX Scaled Arithmetics/AutoScale mode, it will return the value tensor, -with optional casting. +In JAX Scalify mode, it will return the value tensor, with optional casting. Similar in principle to `jax.lax.stop_gradient` """ @@ -160,14 +159,14 @@ def scaled_stop_scaling(values: ScaledArray, dtype: Optional[DTypeLike] = None) In standard JAX, this is just an operation returning the input array and a constant scalar(1). -In JAX Scaled Arithmetics/AutoScale mode, it will return the pair of data and scale tensors +In JAX Scalify mode, it will return the pair of data and scale tensors from a ScaledArray. """ def get_scale_dtype() -> Optional[DTypeLike]: - """Get the scale dtype, if set in the AutoScale config.""" - return get_autoscale_config().scale_dtype + """Get the scale dtype, if set in the Scalify config.""" + return get_scalify_config().scale_dtype def get_data_scale(values: Array) -> Array: @@ -208,10 +207,10 @@ def get_data_scale_mlir_lowering( def scaled_get_data_scale(values: ScaledArray) -> Array: """Scaled `get_data_scale` implementation: return scale tensor.""" scale_dtype = get_scale_dtype() - # Mis-match may potentially create issues (i.e. not equivalent scale dtype after autoscale tracer)! 
+ # Mis-match may potentially create issues (i.e. not equivalent scale dtype after scalify tracer)! if scale_dtype != values.scale.dtype: logging.warning( - f"Autoscale config scale dtype not matching ScaledArray scale dtype: '{values.scale.dtype}' vs '{scale_dtype}'. AutoScale graph transformation may fail because of that." + f"Scalify config scale dtype not matching ScaledArray scale dtype: '{values.scale.dtype}' vs '{scale_dtype}'. Scalify graph transformation may fail because of that." ) return values.data, values.scale diff --git a/jax_scaled_arithmetics/lax/scaled_ops_common.py b/jax_scalify/lax/scaled_ops_common.py similarity index 99% rename from jax_scaled_arithmetics/lax/scaled_ops_common.py rename to jax_scalify/lax/scaled_ops_common.py index ce6ca5b..17850f2 100644 --- a/jax_scaled_arithmetics/lax/scaled_ops_common.py +++ b/jax_scalify/lax/scaled_ops_common.py @@ -7,8 +7,8 @@ import numpy as np from jax import lax -from jax_scaled_arithmetics import core -from jax_scaled_arithmetics.core import ( +from jax_scalify import core +from jax_scalify.core import ( Array, DTypeLike, ScaledArray, diff --git a/jax_scaled_arithmetics/lax/scaled_ops_l2.py b/jax_scalify/lax/scaled_ops_l2.py similarity index 95% rename from jax_scaled_arithmetics/lax/scaled_ops_l2.py rename to jax_scalify/lax/scaled_ops_l2.py index ca92247..de7bb3f 100644 --- a/jax_scaled_arithmetics/lax/scaled_ops_l2.py +++ b/jax_scalify/lax/scaled_ops_l2.py @@ -6,15 +6,8 @@ from jax import lax from jax._src.ad_util import add_any_p -from jax_scaled_arithmetics import core -from jax_scaled_arithmetics.core import ( - DTypeLike, - ScaledArray, - get_autoscale_config, - pow2_round, - register_scaled_op, - safe_div, -) +from jax_scalify import core +from jax_scalify.core import DTypeLike, ScaledArray, get_scalify_config, pow2_round, register_scaled_op, safe_div from .scaled_ops_common import check_scalar_scales, promote_scale_types @@ -27,7 +20,7 @@ def scaled_add_sub(A: ScaledArray, B: ScaledArray, binary_op: Any) -> ScaledArra A, B = promote_scale_types(A, B) assert np.issubdtype(A.scale.dtype, np.floating) # Pow2 rounding for unit scaling "rule". - pow2_rounding_mode = get_autoscale_config().rounding_mode + pow2_rounding_mode = get_scalify_config().rounding_mode # TODO: what happens to `sqrt` for non-floating scale? # More stable than direct L2 norm, to avoid scale overflow. ABscale_max = lax.max(A.scale, B.scale) @@ -75,7 +68,7 @@ def scaled_dot_general( assert len(rhs_contracting_dims) == 1 # Pow2 rounding for unit scaling "rule". - pow2_rounding_mode = get_autoscale_config().rounding_mode + pow2_rounding_mode = get_scalify_config().rounding_mode contracting_dim_size = lhs.shape[lhs_contracting_dims[0]] # "unit scaling" rule, based on the contracting axis. outscale_dtype = jnp.promote_types(lhs.scale.dtype, rhs.scale.dtype) @@ -111,7 +104,7 @@ def scaled_reduce_sum(val: ScaledArray, axes: Tuple[int]) -> ScaledArray: scale_dtype = val.scale.dtype axes_size = np.array([shape[idx] for idx in axes]) # Pow2 rounding for unit scaling "rule". - pow2_rounding_mode = get_autoscale_config().rounding_mode + pow2_rounding_mode = get_scalify_config().rounding_mode # Rescale data component following reduction axes & round to power of 2 value. 
axes_rescale = np.sqrt(np.prod(axes_size)).astype(scale_dtype) axes_rescale = pow2_round(axes_rescale, pow2_rounding_mode) diff --git a/jax_scaled_arithmetics/ops/__init__.py b/jax_scalify/ops/__init__.py similarity index 100% rename from jax_scaled_arithmetics/ops/__init__.py rename to jax_scalify/ops/__init__.py diff --git a/jax_scaled_arithmetics/ops/debug.py b/jax_scalify/ops/debug.py similarity index 93% rename from jax_scaled_arithmetics/ops/debug.py rename to jax_scalify/ops/debug.py index 1db09ab..dc18779 100644 --- a/jax_scaled_arithmetics/ops/debug.py +++ b/jax_scalify/ops/debug.py @@ -4,7 +4,7 @@ import jax -from jax_scaled_arithmetics.core import Array, debug_callback +from jax_scalify.core import Array, debug_callback @partial(jax.custom_vjp, nondiff_argnums=(0,)) diff --git a/jax_scaled_arithmetics/ops/ml_dtypes.py b/jax_scalify/ops/ml_dtypes.py similarity index 94% rename from jax_scaled_arithmetics/ops/ml_dtypes.py rename to jax_scalify/ops/ml_dtypes.py index 940dcdd..0766e6e 100644 --- a/jax_scaled_arithmetics/ops/ml_dtypes.py +++ b/jax_scalify/ops/ml_dtypes.py @@ -4,7 +4,7 @@ import jax import ml_dtypes -from jax_scaled_arithmetics.core import Array, DTypeLike +from jax_scalify.core import Array, DTypeLike from .rescaling import fn_bwd_identity_fwd, fn_fwd_identity_bwd diff --git a/jax_scaled_arithmetics/ops/rescaling.py b/jax_scalify/ops/rescaling.py similarity index 95% rename from jax_scaled_arithmetics/ops/rescaling.py rename to jax_scalify/ops/rescaling.py index f2e0325..62f6b44 100644 --- a/jax_scaled_arithmetics/ops/rescaling.py +++ b/jax_scalify/ops/rescaling.py @@ -4,8 +4,8 @@ import jax import numpy as np -from jax_scaled_arithmetics.core import ScaledArray, pow2_round, pow2_round_down -from jax_scaled_arithmetics.lax import get_data_scale, rebalance +from jax_scalify.core import ScaledArray, pow2_round, pow2_round_down +from jax_scalify.lax import get_data_scale, rebalance @partial(jax.custom_vjp, nondiff_argnums=(0,)) diff --git a/pyproject.toml b/pyproject.toml index b093c06..3d7c492 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,7 +1,7 @@ [project] -name = "jax_scaled_arithmetics" +name = "jax_scalify" version = "0.1" -description="JAX Scaled Arithmetics." +description="JAX Scalify: end-to-end scaled arithmetics." 
readme = "README.md" authors = [ { name = "Graphcore Research", email = "paulb@graphcore.ai" }, @@ -23,15 +23,15 @@ dependencies = [ ] [project.urls] -Website = "https://github.com/graphcore-research/jax-scaled-arithmetics/#readme" -"Source Code" = "https://github.com/graphcore-research/jax-scaled-arithmetics/" -"Bug Tracker" = "https://github.com/graphcore-research/jax-scaled-arithmetics/issues" +Website = "https://github.com/graphcore-research/jax-scalify/#readme" +"Source Code" = "https://github.com/graphcore-research/jax-scalify/" +"Bug Tracker" = "https://github.com/graphcore-research/jax-scalify/issues" [project.optional-dependencies] test = ["pytest"] [tool.setuptools] -packages = ["jax_scaled_arithmetics", "jax_scaled_arithmetics.core", "jax_scaled_arithmetics.lax", "jax_scaled_arithmetics.ops"] +packages = ["jax_scalify", "jax_scalify.core", "jax_scalify.lax", "jax_scalify.ops"] [tool.pytest.ini_options] minversion = "6.0" diff --git a/setup.cfg b/setup.cfg index f0b8046..b62f264 100644 --- a/setup.cfg +++ b/setup.cfg @@ -4,4 +4,4 @@ max-complexity = 20 min_python_version = 3.8 ignore = F401 per-file-ignores = - jax_scaled_arithmetics/__init__.py: F401 + jax_scalify/__init__.py: F401 diff --git a/tests/core/test_datatype.py b/tests/core/test_datatype.py index 512877e..b7a6d06 100644 --- a/tests/core/test_datatype.py +++ b/tests/core/test_datatype.py @@ -6,7 +6,7 @@ from absl.testing import parameterized from jax.core import ShapedArray -from jax_scaled_arithmetics.core import ( +from jax_scalify.core import ( Array, ScaledArray, as_scaled_array, diff --git a/tests/core/test_interpreter.py b/tests/core/test_interpreter.py index e83dda1..8235c83 100644 --- a/tests/core/test_interpreter.py +++ b/tests/core/test_interpreter.py @@ -8,19 +8,19 @@ from absl.testing import parameterized from numpy.typing import NDArray -from jax_scaled_arithmetics.core import ( +from jax_scalify.core import ( Array, - AutoScaleConfig, Pow2RoundMode, ScaledArray, + ScalifyConfig, asarray, - autoscale, - get_autoscale_config, + get_scalify_config, is_scaled_leaf, register_scaled_op, scaled_array, + scalify, ) -from jax_scaled_arithmetics.core.interpreters import ScalifyTracerArray +from jax_scalify.core.interpreters import ScalifyTracerArray class ScalifyTracerArrayTests(chex.TestCase): @@ -136,17 +136,17 @@ def test__scalify_tracer_array__to_scaled_array__broadcasted_scalar_input(self): npt.assert_array_equal(np.asarray(scaled_out), data) -class AutoScaleInterpreterTests(chex.TestCase): +class ScalifyInterpreterTests(chex.TestCase): def test__register_scaled_op__error_if_already_registered(self): with self.assertRaises(KeyError): register_scaled_op(jax.lax.mul_p, lambda a, _: a) @chex.variants(with_jit=True, without_jit=True) - def test__autoscale_interpreter__normal_jax_mode(self): + def test__scalify_interpreter__normal_jax_mode(self): def func(x): return x * 2 - func = self.variant(autoscale(func)) + func = self.variant(scalify(func)) data: NDArray[np.float32] = np.array([1, 2], dtype=np.float32) out = func(data) # Proper behaviour! 
@@ -158,11 +158,11 @@ def func(x): assert len(jaxpr.outvars) == 1 assert len(jaxpr.eqns) == 1 - def test__autoscale_interpreter__without_jit__proper_jaxpr_signature(self): + def test__scalify_interpreter__without_jit__proper_jaxpr_signature(self): def func(x): return x * 2 - scaled_func = autoscale(func) + scaled_func = scalify(func) scaled_input = scaled_array([1.0, 2.0], 3, dtype=np.float32) jaxpr = jax.make_jaxpr(scaled_func)(scaled_input).jaxpr # Need 4 equations: 1 pow2_decompose + 2 mul + 1 cast. @@ -175,11 +175,11 @@ def func(x): assert jaxpr.outvars[0].aval.shape == scaled_input.shape assert jaxpr.outvars[1].aval.shape == () - def test__autoscale_interpreter__with_jit__proper_jaxpr_signature(self): + def test__scalify_interpreter__with_jit__proper_jaxpr_signature(self): def myfunc(x): return x * 2 - scaled_func = autoscale(jax.jit(myfunc)) + scaled_func = scalify(jax.jit(myfunc)) scaled_input = scaled_array([1.0, 2.0], 3, dtype=np.float32) jaxpr = jax.make_jaxpr(scaled_func)(scaled_input).jaxpr # One main jit equation. @@ -237,9 +237,9 @@ def myfunc(x): # "inputs": [scaled_array([[-2.0], [0.5]], 0.5, dtype=np.float32)], # }, ) - def test__autoscale_decorator__proper_graph_transformation_and_result(self, fn, inputs): - # Autoscale function + (optional) jitting. - scaled_fn = self.variant(autoscale(fn)) + def test__scalify_decorator__proper_graph_transformation_and_result(self, fn, inputs): + # Scalify function + (optional) jitting. + scaled_fn = self.variant(scalify(fn)) scaled_output = scaled_fn(*inputs) # Normal JAX path, without scaled arrays. raw_inputs = jax.tree_util.tree_map(np.asarray, inputs, is_leaf=is_scaled_leaf) @@ -259,7 +259,7 @@ def test__autoscale_decorator__proper_graph_transformation_and_result(self, fn, npt.assert_array_almost_equal(scaled_out, exp_out, decimal=4) @chex.variants(with_jit=True, without_jit=True) - def test__autoscale_decorator__promotion_broadcasted_scalar_array(self): + def test__scalify_decorator__promotion_broadcasted_scalar_array(self): def fn(sa, b): # Forcing broadcasting before the `lax.mul` b = jax.lax.broadcast_in_dim(b, sa.shape, ()) @@ -268,7 +268,7 @@ def fn(sa, b): sa = scaled_array([0.5, 1.0], np.float32(4.0), dtype=np.float32) b = jnp.array(4.0, dtype=np.float16) - scaled_fn = self.variant(autoscale(fn)) + scaled_fn = self.variant(scalify(fn)) sout = scaled_fn(sa, b) expected_out = fn(np.asarray(sa), b) @@ -278,7 +278,7 @@ def fn(sa, b): npt.assert_array_equal(np.asarray(sout), expected_out) @chex.variants(with_jit=True, without_jit=True) - def test__autoscale_decorator__custom_jvp__proper_graph_transformation_and_result(self): + def test__scalify_decorator__custom_jvp__proper_graph_transformation_and_result(self): # JAX official `jvp` example. @jax.custom_jvp def f(x, y): @@ -295,12 +295,12 @@ def f_jvp(primals, tangents): def fn(x, y): return jax.jvp(f, (x, y), (x, y)) - # `autoscale` on `custom_jvp` method. + # `scalify` on `custom_jvp` method. 
scaled_inputs = ( scaled_array([-2.0, 0.5], 0.5, dtype=np.float32), scaled_array([1.5, -4.5], 2, dtype=np.float32), ) - scaled_primals, scaled_tangents = self.variant(autoscale(fn))(*scaled_inputs) + scaled_primals, scaled_tangents = self.variant(scalify(fn))(*scaled_inputs) # JAX default/expected values inputs = tuple(map(asarray, scaled_inputs)) primals, tangents = self.variant(fn)(*inputs) @@ -311,7 +311,7 @@ def fn(x, y): npt.assert_array_almost_equal(scaled_tangents, tangents) @chex.variants(with_jit=True, without_jit=True) - def test__autoscale_decorator__custom_vjp__proper_graph_transformation_and_result(self): + def test__scalify_decorator__custom_vjp__proper_graph_transformation_and_result(self): # JAX official `vjp` example. @jax.custom_vjp def f(x, y): @@ -330,12 +330,12 @@ def fn(x, y): primals, f_vjp = jax.vjp(f, x, y) return primals, f_vjp(x * y) - # `autoscale` on `custom_jvp` method. + # `scalify` on `custom_jvp` method. scaled_inputs = ( scaled_array([-2.0, 0.5], 0.5, dtype=np.float32), scaled_array([1.5, -4.5], 2, dtype=np.float32), ) - scaled_primals, scaled_grads = self.variant(autoscale(fn))(*scaled_inputs) + scaled_primals, scaled_grads = self.variant(scalify(fn))(*scaled_inputs) # JAX default/expected values inputs = tuple(map(asarray, scaled_inputs)) primals, grads = self.variant(fn)(*inputs) @@ -346,16 +346,16 @@ def fn(x, y): assert isinstance(sg, ScaledArray) npt.assert_array_almost_equal(sg, g) - def test__autoscale_config__default_values(self): - cfg = get_autoscale_config() - assert isinstance(cfg, AutoScaleConfig) + def test__scalify_config__default_values(self): + cfg = get_scalify_config() + assert isinstance(cfg, ScalifyConfig) assert cfg.rounding_mode == Pow2RoundMode.DOWN assert cfg.scale_dtype is None - def test__autoscale_config__context_manager(self): - with AutoScaleConfig(rounding_mode=Pow2RoundMode.NONE, scale_dtype=np.float32): - cfg = get_autoscale_config() - assert isinstance(cfg, AutoScaleConfig) + def test__scalify_config__context_manager(self): + with ScalifyConfig(rounding_mode=Pow2RoundMode.NONE, scale_dtype=np.float32): + cfg = get_scalify_config() + assert isinstance(cfg, ScalifyConfig) assert cfg.rounding_mode == Pow2RoundMode.NONE assert cfg.scale_dtype == np.float32 @@ -364,7 +364,7 @@ def test__autoscale_config__context_manager(self): {"scale_dtype": np.float16}, {"scale_dtype": np.float32}, ) - def test__autoscale_config__scale_dtype_used_in_interpreter_promotion(self, scale_dtype): + def test__scalify_config__scale_dtype_used_in_interpreter_promotion(self, scale_dtype): def fn(x): # Sub-normal "learning rate" => can create issue when converting to FP16 scaled array. 
# return x * 3.123283386230469e-05 @@ -373,8 +373,8 @@ def fn(x): expected_output = fn(np.float16(1)) - with AutoScaleConfig(scale_dtype=scale_dtype): + with ScalifyConfig(scale_dtype=scale_dtype): scaled_input = scaled_array(np.array(2.0, np.float16), scale=scale_dtype(0.5)) - scaled_output = self.variant(autoscale(fn))(scaled_input) + scaled_output = self.variant(scalify(fn))(scaled_input) assert scaled_output.scale.dtype == scale_dtype npt.assert_equal(np.asarray(scaled_output, dtype=np.float32), expected_output) diff --git a/tests/core/test_pow2.py b/tests/core/test_pow2.py index 7a1fdb6..7b09712 100644 --- a/tests/core/test_pow2.py +++ b/tests/core/test_pow2.py @@ -7,8 +7,8 @@ import numpy.testing as npt from absl.testing import parameterized -from jax_scaled_arithmetics.core import Pow2RoundMode, pow2_decompose, pow2_round_down, pow2_round_up -from jax_scaled_arithmetics.core.pow2 import _exponent_bits_mask, get_mantissa +from jax_scalify.core import Pow2RoundMode, pow2_decompose, pow2_round_down, pow2_round_up +from jax_scalify.core.pow2 import _exponent_bits_mask, get_mantissa class Pow2DecomposePrimitveTests(chex.TestCase): diff --git a/tests/core/test_utils.py b/tests/core/test_utils.py index bdcf71c..a64a5ff 100644 --- a/tests/core/test_utils.py +++ b/tests/core/test_utils.py @@ -5,7 +5,7 @@ import numpy.testing as npt from absl.testing import parameterized -from jax_scaled_arithmetics.core.utils import Array, python_scalar_as_numpy, safe_div, safe_reciprocal +from jax_scalify.core.utils import Array, python_scalar_as_numpy, safe_div, safe_reciprocal class SafeDivOpTests(chex.TestCase): diff --git a/tests/lax/test_base_scaling_primitives.py b/tests/lax/test_base_scaling_primitives.py index 22d75dc..d517f09 100644 --- a/tests/lax/test_base_scaling_primitives.py +++ b/tests/lax/test_base_scaling_primitives.py @@ -6,8 +6,8 @@ from absl.testing import parameterized from numpy.typing import NDArray -from jax_scaled_arithmetics.core import Array, AutoScaleConfig, ScaledArray, autoscale, scaled_array -from jax_scaled_arithmetics.lax.base_scaling_primitives import ( +from jax_scalify.core import Array, ScaledArray, ScalifyConfig, scaled_array, scalify +from jax_scalify.lax.base_scaling_primitives import ( get_data_scale, rebalance, scaled_set_scaling, @@ -43,13 +43,13 @@ def fn(arr, scale): return set_scaling(arr, scale) scale = np.array(0, dtype=arr.dtype) - out = self.variant(autoscale(fn))(arr, scale) + out = self.variant(scalify(fn))(arr, scale) assert isinstance(out, ScaledArray) npt.assert_array_almost_equal(out.scale, 0) npt.assert_array_almost_equal(out.data, 0) @chex.variants(with_jit=True, without_jit=True) - def test__set_scaling_primitive__proper_result_without_autoscale(self): + def test__set_scaling_primitive__proper_result_without_scalify(self): def fn(arr, scale): return set_scaling(arr, scale) @@ -70,11 +70,11 @@ def fn(arr, scale): {"arr": scaled_array([-1.0, 2.0], 2.0, dtype=np.float32), "scale": scaled_array(1.0, 4.0, dtype=np.float32)}, {"arr": scaled_array([-1.0, 2.0], 2.0, dtype=np.float16), "scale": scaled_array(1.0, 4.0, dtype=np.float32)}, ) - def test__set_scaling_primitive__proper_result_with_autoscale(self, arr, scale): + def test__set_scaling_primitive__proper_result_with_scalify(self, arr, scale): def fn(arr, scale): return set_scaling(arr, scale) - fn = self.variant(autoscale(fn)) + fn = self.variant(scalify(fn)) out = fn(arr, scale) # Unchanged output tensor, with proper dtype. 
assert isinstance(out, ScaledArray) @@ -114,7 +114,7 @@ def test__stop_scaling_primitive__scaled_array__eager_mode(self, npapi): npt.assert_array_equal(output, values) @chex.variants(with_jit=True, without_jit=True) - def test__stop_scaling_primitive__proper_result_without_autoscale(self): + def test__stop_scaling_primitive__proper_result_without_scalify(self): def fn(arr): # Testing both variants. return stop_scaling(arr), stop_scaling(arr, dtype=np.float16) @@ -127,12 +127,12 @@ def fn(arr): npt.assert_array_almost_equal(out1, arr) @chex.variants(with_jit=True, without_jit=True) - def test__stop_scaling_primitive__proper_result_with_autoscale(self): + def test__stop_scaling_primitive__proper_result_with_scalify(self): def fn(arr): # Testing both variants. return stop_scaling(arr), stop_scaling(arr, dtype=np.float16) - fn = self.variant(autoscale(fn)) + fn = self.variant(scalify(fn)) arr = scaled_array([-1.0, 2.0], 3.0, dtype=np.float32) out0, out1 = fn(arr) assert isinstance(out0, Array) @@ -145,10 +145,10 @@ def fn(arr): class GetDataScalePrimitiveTests(chex.TestCase): @chex.variants(with_jit=True, without_jit=True) - def test__get_data_scale_primitive__proper_result_without_autoscale(self): + def test__get_data_scale_primitive__proper_result_without_scalify(self): def fn(arr): # Set a default scale dtype. - with AutoScaleConfig(scale_dtype=np.float32): + with ScalifyConfig(scale_dtype=np.float32): return get_data_scale(arr) fn = self.variant(fn) @@ -160,11 +160,11 @@ def fn(arr): npt.assert_equal(scale, np.array(1, np.float32)) @chex.variants(with_jit=True, without_jit=True) - def test__get_data_scale_primitive__proper_result_with_autoscale(self): + def test__get_data_scale_primitive__proper_result_with_scalify(self): def fn(arr): return get_data_scale(arr) - fn = self.variant(autoscale(fn)) + fn = self.variant(scalify(fn)) arr = scaled_array([2, 3], np.float16(4), dtype=np.float16) data, scale = fn(arr) npt.assert_array_equal(data, arr.data) diff --git a/tests/lax/test_numpy_integration.py b/tests/lax/test_numpy_integration.py index 6e03bd3..6f617aa 100644 --- a/tests/lax/test_numpy_integration.py +++ b/tests/lax/test_numpy_integration.py @@ -4,7 +4,7 @@ import jax.numpy as jnp import numpy as np -from jax_scaled_arithmetics.core import ScaledArray, autoscale, scaled_array +from jax_scalify.core import ScaledArray, scaled_array, scalify class ScaledJaxNumpyFunctions(chex.TestCase): @@ -22,7 +22,7 @@ def mean_fn(x): # size = 8 * 16 input_scaled = scaled_array(self.rs.rand(8, 16).astype(np.float32), np.float32(1)) - output_grad_scaled = self.variant(autoscale(mean_fn))(input_scaled) + output_grad_scaled = self.variant(scalify(mean_fn))(input_scaled) assert isinstance(output_grad_scaled, ScaledArray) # Proper scale propagation on the backward pass (rough interval) diff --git a/tests/lax/test_scaled_ops_common.py b/tests/lax/test_scaled_ops_common.py index b3bacac..a23b65a 100644 --- a/tests/lax/test_scaled_ops_common.py +++ b/tests/lax/test_scaled_ops_common.py @@ -5,15 +5,8 @@ from absl.testing import parameterized from jax import lax -from jax_scaled_arithmetics.core import ( - Array, - ScaledArray, - autoscale, - debug_callback, - find_registered_scaled_op, - scaled_array, -) -from jax_scaled_arithmetics.lax import ( +from jax_scalify.core import Array, ScaledArray, debug_callback, find_registered_scaled_op, scaled_array, scalify +from jax_scalify.lax import ( scaled_broadcast_in_dim, scaled_concatenate, scaled_convert_element_type, @@ -49,7 +42,7 @@ def fn(a): return a x = 
scaled_array(self.rs.rand(5), 2, dtype=np.float16) - fn = self.variant(autoscale(fn)) + fn = self.variant(scalify(fn)) fn(x) assert len(host_values) == 2 @@ -206,7 +199,7 @@ def test__scaled_select_n__proper_result(self): {"scale": 8.0}, ) def test__scaled_select__relu_grad_example(self, scale): - @autoscale + @scalify def relu_grad(g): return lax.select(g > 0, g, lax.full_like(g, 0)) diff --git a/tests/lax/test_scaled_ops_l2.py b/tests/lax/test_scaled_ops_l2.py index 3ea6256..523a8d9 100644 --- a/tests/lax/test_scaled_ops_l2.py +++ b/tests/lax/test_scaled_ops_l2.py @@ -5,8 +5,8 @@ from absl.testing import parameterized from jax import lax -from jax_scaled_arithmetics.core import ScaledArray, find_registered_scaled_op, scaled_array -from jax_scaled_arithmetics.lax import scaled_div, scaled_dot_general, scaled_mul, scaled_reduce_window_sum +from jax_scalify.core import ScaledArray, find_registered_scaled_op, scaled_array +from jax_scalify.lax import scaled_div, scaled_dot_general, scaled_mul, scaled_reduce_window_sum class ScaledTranslationDotPrimitivesTests(chex.TestCase): diff --git a/tests/lax/test_scipy_integration.py b/tests/lax/test_scipy_integration.py index 58085d5..cb9a392 100644 --- a/tests/lax/test_scipy_integration.py +++ b/tests/lax/test_scipy_integration.py @@ -5,7 +5,7 @@ from absl.testing import parameterized from jax import lax -from jax_scaled_arithmetics.core import autoscale, scaled_array +from jax_scalify.core import scaled_array, scalify class ScaledScipyHighLevelMethodsTests(chex.TestCase): @@ -19,7 +19,7 @@ def fn(a): return lax.full_like(a, 0) a = scaled_array(np.random.rand(3, 5).astype(np.float32), np.float32(1)) - autoscale(fn)(a) + scalify(fn)(a) # FIMXE/TODO: what should be the expected result? @chex.variants(with_jit=False, without_jit=True) @@ -32,7 +32,7 @@ def test__scipy_logsumexp__accurate_scaled_op(self, dtype): input_scaled = scaled_array(self.rs.rand(10), 4.0, dtype=dtype) # JAX `logsumexp` Jaxpr is a non-trivial graph! - out_scaled = self.variant(autoscale(logsumexp))(input_scaled) + out_scaled = self.variant(scalify(logsumexp))(input_scaled) out_expected = logsumexp(np.asarray(input_scaled)) assert out_scaled.dtype == out_expected.dtype # Proper accuracy + keep the same scale. diff --git a/tests/ops/test_debug.py b/tests/ops/test_debug.py index 6e06ccc..22ce39d 100644 --- a/tests/ops/test_debug.py +++ b/tests/ops/test_debug.py @@ -1,8 +1,8 @@ # Copyright (c) 2023 Graphcore Ltd. All rights reserved. import numpy as np -from jax_scaled_arithmetics.core import autoscale, scaled_array -from jax_scaled_arithmetics.ops import debug_print +from jax_scalify.core import scaled_array, scalify +from jax_scalify.ops import debug_print def test__debug_print__scaled_arrays(capfd): @@ -12,7 +12,7 @@ def debug_print_fn(x): debug_print(fmt, x, x) input_scaled = scaled_array([2, 3], 2.0, dtype=np.float32) - autoscale(debug_print_fn)(input_scaled) + scalify(debug_print_fn)(input_scaled) # Check captured stdout and stderr! 
captured = capfd.readouterr() assert len(captured.err) == 0 diff --git a/tests/ops/test_ml_dtypes.py b/tests/ops/test_ml_dtypes.py index 4bb5f56..75cee9d 100644 --- a/tests/ops/test_ml_dtypes.py +++ b/tests/ops/test_ml_dtypes.py @@ -8,8 +8,8 @@ from absl.testing import parameterized from numpy.typing import NDArray -from jax_scaled_arithmetics.core import autoscale, scaled_array -from jax_scaled_arithmetics.ops import cast_ml_dtype +from jax_scalify.core import scaled_array, scalify +from jax_scalify.ops import cast_ml_dtype class CastMLDtypeTests(chex.TestCase): @@ -29,10 +29,10 @@ def test__cast_ml_dtype__consistent_rounding_down(self, ml_dtype): {"ml_dtype": ml_dtypes.float8_e4m3fn}, {"ml_dtype": ml_dtypes.float8_e5m2}, ) - def test__cast_ml_dtype__autoscale_compatiblity(self, ml_dtype): + def test__cast_ml_dtype__scalify_compatiblity(self, ml_dtype): values: NDArray[np.float16] = np.array([17, -17, 8, 1, 9, 11, 18], np.float16) arr = scaled_array(values, np.float32(1)) - out = autoscale(partial(cast_ml_dtype, dtype=ml_dtype))(arr) + out = scalify(partial(cast_ml_dtype, dtype=ml_dtype))(arr) npt.assert_array_equal(out.scale, arr.scale) npt.assert_array_equal(out, np.asarray(arr.data).astype(ml_dtype)) diff --git a/tests/ops/test_rescaling.py b/tests/ops/test_rescaling.py index 8e9e829..53a117c 100644 --- a/tests/ops/test_rescaling.py +++ b/tests/ops/test_rescaling.py @@ -4,8 +4,8 @@ import numpy.testing as npt from absl.testing import parameterized -from jax_scaled_arithmetics.core import ScaledArray, scaled_array -from jax_scaled_arithmetics.ops import dynamic_rescale_l1, dynamic_rescale_l2, dynamic_rescale_max +from jax_scalify.core import ScaledArray, scaled_array +from jax_scalify.ops import dynamic_rescale_l1, dynamic_rescale_l2, dynamic_rescale_max class DynamicRescaleOpsTests(chex.TestCase):
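
For reference after the diff, a short sketch of the renamed configuration object `ScalifyConfig` (formerly `AutoScaleConfig`). It is grounded in the defaults asserted in the interpreter tests above (rounding mode `Pow2RoundMode.DOWN`, unset scale dtype) and in the notebook's use of `Pow2RoundMode.NONE`; treat it as illustrative rather than normative:

```python
import numpy as np
import jax_scalify as jsa

@jsa.scalify
def fn(a, b):
    return a + b

sa = jsa.as_scaled_array(np.array([1.0, 2.0], dtype=np.float32), scale=np.float32(1))
sb = jsa.as_scaled_array(np.array([3.0, 4.0], dtype=np.float32), scale=np.float32(1))

# Default config: output scales are rounded down to a power of two and the
# scale dtype is left unset.
print(fn(sa, sb))

# `ScalifyConfig` is a context manager that locally overrides scale rounding
# and the dtype used for scale tensors while `scalify` traces the graph.
with jsa.ScalifyConfig(rounding_mode=jsa.Pow2RoundMode.NONE, scale_dtype=np.float32):
    print(fn(sa, sb))
```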
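A compact old-to-new name map compiled from the renames in this diff, useful when updating downstream code; the sanity check only touches attributes that the new `jax_scalify/__init__.py` and `jax_scalify/core/__init__.py` export.

```python
# Public renames introduced by this diff (old -> new):
#   jax_scaled_arithmetics           -> jax_scalify                     (package / import path)
#   jsa.autoscale                    -> jsa.scalify                     (graph transformation)
#   jsa.AutoScaleConfig              -> jsa.ScalifyConfig               (tracing configuration)
#   jsa.core.get_autoscale_config()  -> jsa.core.get_scalify_config()
#   autoscale_jaxpr                  -> scalify_jaxpr                   (internal interpreter)
import jax_scalify as jsa

# Sanity check that the renamed public names resolve.
assert hasattr(jsa, "scalify") and hasattr(jsa, "ScalifyConfig")
assert hasattr(jsa.core, "get_scalify_config")
```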