Add scaled translation rules for trivial LAX primitives. (#14)
Adds scaled translations for the primitives `broadcast_in_dim`, `convert_element_type`, `slice` and `transpose`.
Additionally, the autoscale interpreter is made more robust, and primitive attributes (`eqn.params`) are now properly forwarded to the scaled translations.
balancap authored Nov 10, 2023
1 parent c17877b commit 006f321
Showing 6 changed files with 135 additions and 20 deletions.
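
For context, a minimal usage sketch (not part of the commit; shapes and values are illustrative) showing the newly translated primitives flowing through `autoscale`:

import jax
import numpy as np
from jax_scaled_arithmetics.core import ScaledArray, autoscale, scaled_array

def func(x):
    # transpose, convert_element_type and slice are all newly supported primitives.
    y = jax.lax.transpose(x, (1, 0))
    y = jax.lax.convert_element_type(y, np.float16)
    return jax.lax.slice(y, (0, 0), (2, 2))

x = scaled_array(np.random.rand(3, 2), 0.5, dtype=np.float32)
out = autoscale(func)(x)
assert isinstance(out, ScaledArray)  # the scale is simply propagated through these ops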
2 changes: 1 addition & 1 deletion jax_scaled_arithmetics/core/__init__.py
@@ -1,3 +1,3 @@
# Copyright (c) 2023 Graphcore Ltd. All rights reserved.
-from .datatype import ScaledArray, scaled_array  # noqa: F401
+from .datatype import DTypeLike, ScaledArray, Shape, scaled_array  # noqa: F401
from .interpreters import autoscale, register_scaled_lax_op, register_scaled_op # noqa: F401
43 changes: 30 additions & 13 deletions jax_scaled_arithmetics/core/interpreters.py
@@ -1,39 +1,56 @@
# Copyright (c) 2023 Graphcore Ltd. All rights reserved.

from functools import wraps
-from typing import Dict
+from typing import Any, Dict

import jax
from jax import core
from jax._src.util import safe_map

from ..core import ScaledArray

-_scaled_ops_registry = {}
+_scaled_ops_registry: Dict[core.Primitive, Any] = {}


-def register_scaled_op(lax_func, scaled_func):
-    _scaled_ops_registry[lax_func] = scaled_func
+def register_scaled_op(prim: core.Primitive, scaled_func: Any) -> None:
+    """Register the scaled translation of a JAX primitive.
+
+    Raises an error if a scaled translation is already registered for this primitive.
+
+    Args:
+        prim: JAX primitive.
+        scaled_func: Scaled translation of the primitive, with the same interface.
+    """
+    assert isinstance(prim, core.Primitive)
+    if prim in _scaled_ops_registry:
+        raise KeyError(f"A scaled translation is already registered for the JAX primitive '{prim}'.")
+    _scaled_ops_registry[prim] = scaled_func


-def _get_lax_prim(scaled_func):
+def _get_lax_prim(scaled_func: Any) -> core.Primitive:
    try:
-        op = getattr(jax.lax, scaled_func.__name__.replace("scaled_", ""))
+        prim_name = scaled_func.__name__.replace("scaled_", "") + "_p"
+        prim = getattr(jax.lax, prim_name)
    except AttributeError:
-        raise AttributeError(f"Could not find corresponding jax.lax primitive for {scaled_func.__name__}")
-    return op
+        raise AttributeError(f"Could not find corresponding 'jax.lax' primitive for '{scaled_func.__name__}'.")
+    # Check it is a proper primitive, and not something else also living in `jax.lax`.
+    if not isinstance(prim, core.Primitive):
+        raise AttributeError(f"The object `{prim}` is not a proper JAX primitive for '{scaled_func.__name__}'.")
+    return prim


def register_scaled_lax_op(scaled_func):
"""
Registers a scaled function into the scaled_ops_registry by matching
the function name with pattern `scaled_{func_name}` to a function in the
Registers a scaled function/translation into the scaled_ops_registry by matching
the function name with pattern `scaled_{func_name}` to a primitive in the
`jax.lax` namespace.
Example: `scaled_mul_p` is matched to `jax.lax.mul_p`
Example: `scaled_mul` is matched to `jax.lax.mul_p`
"""
lax_prim = _get_lax_prim(scaled_func)
register_scaled_op(lax_prim, scaled_func)
# Always return the function in the case of decorator use.
return scaled_func


def autoscale(fun):
@@ -49,7 +66,7 @@ def wrapped(*args, **kwargs):
    return wrapped


-def autoscale_jaxpr(jaxpr, consts, *args):
+def autoscale_jaxpr(jaxpr: core.Jaxpr, consts, *args):
    env: Dict[core.Var, ScaledArray] = {}

    def read(var):
@@ -67,7 +84,7 @@ def write(var, val):
        invals = safe_map(read, eqn.invars)
        if eqn.primitive not in _scaled_ops_registry:
            raise NotImplementedError(f"{eqn.primitive} does not have an implementation for ScaledArray inputs yet")
-        outvals = _scaled_ops_registry[eqn.primitive](*invals)
+        outvals = _scaled_ops_registry[eqn.primitive](*invals, **eqn.params)
        if not eqn.primitive.multiple_results:
            outvals = [outvals]
        safe_map(write, eqn.outvars, outvals)
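
To illustrate the registration mechanics above (this example is not part of the commit): the decorator strips the `scaled_` prefix, appends `_p`, and looks the primitive up in `jax.lax`, so a translation for `jax.lax.neg_p` could be registered as follows; registering the same primitive twice now raises a `KeyError`.

from jax_scaled_arithmetics import core
from jax_scaled_arithmetics.core import ScaledArray

@core.register_scaled_lax_op
def scaled_neg(A: ScaledArray) -> ScaledArray:
    # `scaled_neg` is matched to `jax.lax.neg_p`: negate the data, keep the scale unchanged.
    return ScaledArray(-A.data, A.scale)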
35 changes: 31 additions & 4 deletions jax_scaled_arithmetics/lax/scaled_ops.py
@@ -1,12 +1,39 @@
# Copyright (c) 2023 Graphcore Ltd. All rights reserved.
+from typing import Optional, Sequence
+
+from jax import lax

from jax_scaled_arithmetics import core
-from jax_scaled_arithmetics.core import ScaledArray
+from jax_scaled_arithmetics.core import DTypeLike, ScaledArray, Shape


@core.register_scaled_lax_op
-def scaled_mul_p(A: ScaledArray, B: ScaledArray) -> ScaledArray:
-    return ScaledArray(A.data * B.data, A.scale * B.scale)
+def scaled_broadcast_in_dim(A: ScaledArray, shape: Shape, broadcast_dimensions: Sequence[int]) -> ScaledArray:
+    return ScaledArray(lax.broadcast_in_dim(A.data, shape=shape, broadcast_dimensions=broadcast_dimensions), A.scale)


-__all__ = ["scaled_mul_p"]
+@core.register_scaled_lax_op
+def scaled_convert_element_type(A: ScaledArray, new_dtype: DTypeLike, weak_type: bool = False) -> ScaledArray:
+    # NOTE: by default, no rescaling is done before casting.
+    # Adding an optional rescaling op beforehand (and choosing its strategy) is left to the user.
+    # NOTE bis: the scale is not cast by default either!
+    return ScaledArray(lax.convert_element_type(A.data, new_dtype=new_dtype), A.scale)
+
+
+@core.register_scaled_lax_op
+def scaled_slice(
+    A: ScaledArray, start_indices: Sequence[int], limit_indices: Sequence[int], strides: Optional[Sequence[int]] = None
+) -> ScaledArray:
+    return ScaledArray(
+        lax.slice(A.data, start_indices=start_indices, limit_indices=limit_indices, strides=strides), A.scale
+    )
+
+
+@core.register_scaled_lax_op
+def scaled_transpose(A: ScaledArray, permutation: Sequence[int]) -> ScaledArray:
+    return ScaledArray(lax.transpose(A.data, permutation=permutation), A.scale)
+
+
+@core.register_scaled_lax_op
+def scaled_mul(A: ScaledArray, B: ScaledArray) -> ScaledArray:
+    return ScaledArray(A.data * B.data, A.scale * B.scale)
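
The NOTE in `scaled_convert_element_type` leaves any pre-cast rescaling to the user. A hypothetical helper (not provided by the library; the name `rescale_before_cast` and the max-abs strategy are illustrative assumptions) could fold the data magnitude into the scale before casting to a lower-precision dtype:

import jax.numpy as jnp
from jax_scaled_arithmetics.core import ScaledArray
from jax_scaled_arithmetics.lax import scaled_convert_element_type

def rescale_before_cast(A: ScaledArray, new_dtype) -> ScaledArray:
    # Fold the max absolute data value into the scale, so the data lies roughly in [-1, 1].
    amax = jnp.maximum(jnp.max(jnp.abs(A.data)), 1e-8)
    rescaled = ScaledArray(A.data / amax, A.scale * amax)
    return scaled_convert_element_type(rescaled, new_dtype=new_dtype)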
3 changes: 3 additions & 0 deletions pyproject.toml
@@ -29,6 +29,9 @@ Website = "https://github.com/graphcore-research/jax-scaled-arithmetics/#readme"
[project.optional-dependencies]
test = ["pytest"]

+[tool.setuptools]
+packages = ["jax_scaled_arithmetics"]

[tool.pytest.ini_options]
minversion = "6.0"
addopts = ["-ra", "--showlocals", "--strict-config", "-p no:hypothesispytest"]
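
With the packaging entry above and the `test` extra (which pulls in `pytest`), an editable install such as `pip install -e .[test]` followed by `pytest` should be enough to run the new test files below, assuming a local checkout of the repository.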
21 changes: 19 additions & 2 deletions tests/core/test_interpreter.py
@@ -6,10 +6,14 @@
import numpy as np
import numpy.testing as npt

-from jax_scaled_arithmetics.core import ScaledArray, autoscale, scaled_array
+from jax_scaled_arithmetics.core import ScaledArray, autoscale, register_scaled_op, scaled_array


class AutoScaleInterpreterTests(chex.TestCase):
+    def test__register_scaled_op__error_if_already_registered(self):
+        with self.assertRaises(KeyError):
+            register_scaled_op(jax.lax.mul_p, lambda a, _: a)
+
    @chex.variants(with_jit=True, without_jit=True)
    def test__scaled_identity_function(self):
        def func(x):
@@ -34,7 +38,7 @@ def func(x):
        assert jaxpr.outvars[1].aval.shape == ()

    @chex.variants(with_jit=True, without_jit=True)
-    def test__scaled_mul_function(self):
+    def test__scaled_mul__no_attributes(self):
        def func(x, y):
            return x * y

@@ -48,3 +52,16 @@ def func(x, y):
        out = asfunc(x, y)
        assert isinstance(out, ScaledArray)
        npt.assert_array_almost_equal(out, expected)

+    @chex.variants(with_jit=True, without_jit=True)
+    def test__scaled_convert_element_type__attributes_passing(self):
+        def func(x):
+            return jax.lax.convert_element_type(x, np.float16)
+
+        # Autoscale + (optional) jitting.
+        asfunc = self.variant(autoscale(func))
+        x = scaled_array([-4.0, 2.0], 0.5, dtype=np.float32)
+        out = asfunc(x)
+        assert isinstance(out, ScaledArray)
+        assert out.dtype == np.float16
+        npt.assert_array_almost_equal(out, x)
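
The attribute forwarding being tested here can be inspected directly on the jaxpr; a small illustrative sketch (not part of the tests; the exact `params` contents depend on the JAX version):

import jax
import numpy as np

jaxpr = jax.make_jaxpr(lambda v: jax.lax.convert_element_type(v, np.float16))(np.zeros(2, np.float32))
eqn = jaxpr.jaxpr.eqns[0]
print(eqn.primitive)  # convert_element_type
print(eqn.params)     # e.g. {'new_dtype': float16, 'weak_type': False} -- forwarded via **eqn.params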
51 changes: 51 additions & 0 deletions tests/lax/test_scaled_ops.py
@@ -0,0 +1,51 @@
# Copyright (c) 2023 Graphcore Ltd. All rights reserved.
import chex
import numpy as np
import numpy.testing as npt

from jax_scaled_arithmetics.core import ScaledArray, scaled_array
from jax_scaled_arithmetics.lax import (
    scaled_broadcast_in_dim,
    scaled_convert_element_type,
    scaled_mul,
    scaled_slice,
    scaled_transpose,
)


class ScaledTranslationPrimitivesTests(chex.TestCase):
    def test__scaled_broadcast_in_dim__proper_scaling(self):
        x = scaled_array(np.random.rand(5), 2, dtype=np.float32)
        z = scaled_broadcast_in_dim(x, shape=(5, 1), broadcast_dimensions=(0,))
        assert isinstance(z, ScaledArray)
        npt.assert_array_equal(z.scale, x.scale)
        npt.assert_array_almost_equal(z.data, x.data.reshape((5, 1)))

    def test__scaled_convert_element_type__proper_scaling(self):
        x = scaled_array(np.random.rand(5), 2, dtype=np.float32)
        z = scaled_convert_element_type(x, new_dtype=np.float16)
        assert isinstance(z, ScaledArray)
        npt.assert_array_equal(z.scale, x.scale)
        npt.assert_array_almost_equal(z.data, x.data.astype(z.dtype))

    def test__scaled_transpose__proper_scaling(self):
        x = scaled_array(np.random.rand(3, 5), 2, dtype=np.float32)
        z = scaled_transpose(x, (1, 0))
        assert isinstance(z, ScaledArray)
        assert z.scale == x.scale
        npt.assert_array_almost_equal(z.data, x.data.T)

    def test__scaled_slice__proper_scaling(self):
        x = scaled_array(np.random.rand(5), 2, dtype=np.float32)
        z = scaled_slice(x, (1,), (4,), (2,))
        assert isinstance(z, ScaledArray)
        assert z.scale == x.scale
        npt.assert_array_almost_equal(z.data, x.data[1:4:2])

    def test__scaled_mul__proper_scaling(self):
        x = scaled_array([-2.0, 2.0], 3, dtype=np.float32)
        y = scaled_array([1.5, 1.5], 2, dtype=np.float32)
        z = scaled_mul(x, y)
        assert isinstance(z, ScaledArray)
        assert z.scale == 6
        npt.assert_array_almost_equal(z, np.asarray(x) * np.asarray(y))
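
As a quick sanity check of the last test above, worked out by hand (it relies on a `ScaledArray` materializing as `data * scale`, which is exactly what the final equality asserts):

# x ~ [-2., 2.] * 3 = [-6., 6.]      y ~ [1.5, 1.5] * 2 = [3., 3.]
# scaled_mul multiplies data and scale separately:
#   z.data  = [-2., 2.] * [1.5, 1.5] = [-3., 3.]
#   z.scale = 3 * 2 = 6
# np.asarray(z) = [-3., 3.] * 6 = [-18., 18.] == np.asarray(x) * np.asarray(y)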
