Skip to content

Commit

Permalink
Add scaled translation rules for trivial LAX primitives.
Browse files Browse the repository at this point in the history
  • Loading branch information
balancap committed Nov 10, 2023
1 parent c17877b commit 0b9c32a
Show file tree
Hide file tree
Showing 4 changed files with 35 additions and 16 deletions.
37 changes: 26 additions & 11 deletions jax_scaled_arithmetics/core/interpreters.py
Original file line number Diff line number Diff line change
@@ -1,36 +1,51 @@
# Copyright (c) 2023 Graphcore Ltd. All rights reserved.

from functools import wraps
from typing import Dict
from typing import Any, Dict

import jax
from jax import core
from jax._src.util import safe_map

from ..core import ScaledArray

_scaled_ops_registry = {}
_scaled_ops_registry: Dict[core.Primitive, Any] = {}


def register_scaled_op(lax_func, scaled_func):
_scaled_ops_registry[lax_func] = scaled_func
def register_scaled_op(prim: core.Primitive, scaled_func: Any) -> None:
"""Register the scaled translation of JAX primitive.
Raises an error if a scaled translation is already existing for this primitive.
def _get_lax_prim(scaled_func):
Args:
prim: JAX primitive.
scaled_fund: Scaled translation of the primitive. With the same interface.
"""
assert isinstance(prim, core.Primitive)
if prim in _scaled_ops_registry:
raise KeyError(f"A scaled translation is already registered for the JAX primitive '{prim}'.")
_scaled_ops_registry[prim] = scaled_func


def _get_lax_prim(scaled_func: Any) -> core.Primitive:
try:
op = getattr(jax.lax, scaled_func.__name__.replace("scaled_", ""))
prim_name = scaled_func.__name__.replace("scaled_", "") + "_p"
prim = getattr(jax.lax, prim_name)
except AttributeError:
raise AttributeError(f"Could not find corresponding jax.lax primitive for {scaled_func.__name__}")
return op
raise AttributeError(f"Could not find corresponding 'jax.lax' primitive for '{scaled_func.__name__}'.")
# Check as well it is a proper primitive! And not something else also in `jax.lax`
if not isinstance(prim, core.Primitive):
raise AttributeError(f"The object `{prim}` is not a proper JAX primitive for '{scaled_func.__name__}'.")
return prim


def register_scaled_lax_op(scaled_func):
"""
Registers a scaled function into the scaled_ops_registry by matching
the function name with pattern `scaled_{func_name}` to a function in the
Registers a scaled function/translation into the scaled_ops_registry by matching
the function name with pattern `scaled_{func_name}` to a primitive in the
`jax.lax` namespace.
Example: `scaled_mul_p` is matched to `jax.lax.mul_p`
Example: `scaled_mul` is matched to `jax.lax.mul_p`
"""
lax_prim = _get_lax_prim(scaled_func)
register_scaled_op(lax_prim, scaled_func)
Expand Down
5 changes: 1 addition & 4 deletions jax_scaled_arithmetics/lax/scaled_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,5 @@


@core.register_scaled_lax_op
def scaled_mul_p(A: ScaledArray, B: ScaledArray) -> ScaledArray:
def scaled_mul(A: ScaledArray, B: ScaledArray) -> ScaledArray:
return ScaledArray(A.data * B.data, A.scale * B.scale)


__all__ = ["scaled_mul_p"]
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@ Website = "https://github.com/graphcore-research/jax-scaled-arithmetics/#readme"
[project.optional-dependencies]
test = ["pytest"]

[tool.setuptools]
packages = ["jax_scaled_arithmetics"]

[tool.pytest.ini_options]
minversion = "6.0"
addopts = ["-ra", "--showlocals", "--strict-config", "-p no:hypothesispytest"]
Expand Down
6 changes: 5 additions & 1 deletion tests/core/test_interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,14 @@
import numpy as np
import numpy.testing as npt

from jax_scaled_arithmetics.core import ScaledArray, autoscale, scaled_array
from jax_scaled_arithmetics.core import ScaledArray, autoscale, register_scaled_op, scaled_array


class AutoScaleInterpreterTests(chex.TestCase):
def test__register_scaled_op__error_if_already_registered(self):
with self.assertRaises(KeyError):
register_scaled_op(jax.lax.mul_p, lambda a, _: a)

@chex.variants(with_jit=True, without_jit=True)
def test__scaled_identity_function(self):
def func(x):
Expand Down

0 comments on commit 0b9c32a

Please sign in to comment.