From 0b9c32a126a1c215eb4717fc9f8ed4baf7dcc38f Mon Sep 17 00:00:00 2001 From: Paul Balanca Date: Fri, 10 Nov 2023 16:01:14 +0000 Subject: [PATCH] Add scaled translation rules for trivial LAX primitives. --- jax_scaled_arithmetics/core/interpreters.py | 37 +++++++++++++++------ jax_scaled_arithmetics/lax/scaled_ops.py | 5 +-- pyproject.toml | 3 ++ tests/core/test_interpreter.py | 6 +++- 4 files changed, 35 insertions(+), 16 deletions(-) diff --git a/jax_scaled_arithmetics/core/interpreters.py b/jax_scaled_arithmetics/core/interpreters.py index 3e32bac..d00d35c 100644 --- a/jax_scaled_arithmetics/core/interpreters.py +++ b/jax_scaled_arithmetics/core/interpreters.py @@ -1,7 +1,7 @@ # Copyright (c) 2023 Graphcore Ltd. All rights reserved. from functools import wraps -from typing import Dict +from typing import Any, Dict import jax from jax import core @@ -9,28 +9,43 @@ from ..core import ScaledArray -_scaled_ops_registry = {} +_scaled_ops_registry: Dict[core.Primitive, Any] = {} -def register_scaled_op(lax_func, scaled_func): - _scaled_ops_registry[lax_func] = scaled_func +def register_scaled_op(prim: core.Primitive, scaled_func: Any) -> None: + """Register the scaled translation of JAX primitive. + Raises an error if a scaled translation is already existing for this primitive. -def _get_lax_prim(scaled_func): + Args: + prim: JAX primitive. + scaled_fund: Scaled translation of the primitive. With the same interface. + """ + assert isinstance(prim, core.Primitive) + if prim in _scaled_ops_registry: + raise KeyError(f"A scaled translation is already registered for the JAX primitive '{prim}'.") + _scaled_ops_registry[prim] = scaled_func + + +def _get_lax_prim(scaled_func: Any) -> core.Primitive: try: - op = getattr(jax.lax, scaled_func.__name__.replace("scaled_", "")) + prim_name = scaled_func.__name__.replace("scaled_", "") + "_p" + prim = getattr(jax.lax, prim_name) except AttributeError: - raise AttributeError(f"Could not find corresponding jax.lax primitive for {scaled_func.__name__}") - return op + raise AttributeError(f"Could not find corresponding 'jax.lax' primitive for '{scaled_func.__name__}'.") + # Check as well it is a proper primitive! And not something else also in `jax.lax` + if not isinstance(prim, core.Primitive): + raise AttributeError(f"The object `{prim}` is not a proper JAX primitive for '{scaled_func.__name__}'.") + return prim def register_scaled_lax_op(scaled_func): """ - Registers a scaled function into the scaled_ops_registry by matching - the function name with pattern `scaled_{func_name}` to a function in the + Registers a scaled function/translation into the scaled_ops_registry by matching + the function name with pattern `scaled_{func_name}` to a primitive in the `jax.lax` namespace. - Example: `scaled_mul_p` is matched to `jax.lax.mul_p` + Example: `scaled_mul` is matched to `jax.lax.mul_p` """ lax_prim = _get_lax_prim(scaled_func) register_scaled_op(lax_prim, scaled_func) diff --git a/jax_scaled_arithmetics/lax/scaled_ops.py b/jax_scaled_arithmetics/lax/scaled_ops.py index 75e6f21..013754d 100644 --- a/jax_scaled_arithmetics/lax/scaled_ops.py +++ b/jax_scaled_arithmetics/lax/scaled_ops.py @@ -5,8 +5,5 @@ @core.register_scaled_lax_op -def scaled_mul_p(A: ScaledArray, B: ScaledArray) -> ScaledArray: +def scaled_mul(A: ScaledArray, B: ScaledArray) -> ScaledArray: return ScaledArray(A.data * B.data, A.scale * B.scale) - - -__all__ = ["scaled_mul_p"] diff --git a/pyproject.toml b/pyproject.toml index d90e1e9..728e247 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -29,6 +29,9 @@ Website = "https://github.com/graphcore-research/jax-scaled-arithmetics/#readme" [project.optional-dependencies] test = ["pytest"] +[tool.setuptools] +packages = ["jax_scaled_arithmetics"] + [tool.pytest.ini_options] minversion = "6.0" addopts = ["-ra", "--showlocals", "--strict-config", "-p no:hypothesispytest"] diff --git a/tests/core/test_interpreter.py b/tests/core/test_interpreter.py index b906020..6961fbe 100644 --- a/tests/core/test_interpreter.py +++ b/tests/core/test_interpreter.py @@ -6,10 +6,14 @@ import numpy as np import numpy.testing as npt -from jax_scaled_arithmetics.core import ScaledArray, autoscale, scaled_array +from jax_scaled_arithmetics.core import ScaledArray, autoscale, register_scaled_op, scaled_array class AutoScaleInterpreterTests(chex.TestCase): + def test__register_scaled_op__error_if_already_registered(self): + with self.assertRaises(KeyError): + register_scaled_op(jax.lax.mul_p, lambda a, _: a) + @chex.variants(with_jit=True, without_jit=True) def test__scaled_identity_function(self): def func(x):