-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add scaled translation rules for trivial LAX primitives. (#14)
Translation for primitives: `broadcast_in_dim`, `convert_element_type`, `slice` and `transpose`. Additionally, improvements to the autoscale interpreter for making it more robusts + proper forwarding of attributes.
- Loading branch information
Showing
6 changed files
with
135 additions
and
20 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,3 @@ | ||
# Copyright (c) 2023 Graphcore Ltd. All rights reserved. | ||
from .datatype import ScaledArray, scaled_array # noqa: F401 | ||
from .datatype import DTypeLike, ScaledArray, Shape, scaled_array # noqa: F401 | ||
from .interpreters import autoscale, register_scaled_lax_op, register_scaled_op # noqa: F401 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,12 +1,39 @@ | ||
# Copyright (c) 2023 Graphcore Ltd. All rights reserved. | ||
from typing import Optional, Sequence | ||
|
||
from jax import lax | ||
|
||
from jax_scaled_arithmetics import core | ||
from jax_scaled_arithmetics.core import ScaledArray | ||
from jax_scaled_arithmetics.core import DTypeLike, ScaledArray, Shape | ||
|
||
|
||
@core.register_scaled_lax_op | ||
def scaled_mul_p(A: ScaledArray, B: ScaledArray) -> ScaledArray: | ||
return ScaledArray(A.data * B.data, A.scale * B.scale) | ||
def scaled_broadcast_in_dim(A: ScaledArray, shape: Shape, broadcast_dimensions: Sequence[int]) -> ScaledArray: | ||
return ScaledArray(lax.broadcast_in_dim(A.data, shape=shape, broadcast_dimensions=broadcast_dimensions), A.scale) | ||
|
||
|
||
__all__ = ["scaled_mul_p"] | ||
@core.register_scaled_lax_op | ||
def scaled_convert_element_type(A: ScaledArray, new_dtype: DTypeLike, weak_type: bool = False) -> ScaledArray: | ||
# NOTE: by default, no rescaling done before casting. | ||
# Choice of adding an optional rescaling op before is up to the user (and which strategy to use). | ||
# NOTE bis: scale not casted as well by default! | ||
return ScaledArray(lax.convert_element_type(A.data, new_dtype=new_dtype), A.scale) | ||
|
||
|
||
@core.register_scaled_lax_op | ||
def scaled_slice( | ||
A: ScaledArray, start_indices: Sequence[int], limit_indices: Sequence[int], strides: Optional[Sequence[int]] = None | ||
) -> ScaledArray: | ||
return ScaledArray( | ||
lax.slice(A.data, start_indices=start_indices, limit_indices=limit_indices, strides=strides), A.scale | ||
) | ||
|
||
|
||
@core.register_scaled_lax_op | ||
def scaled_transpose(A: ScaledArray, permutation: Sequence[int]) -> ScaledArray: | ||
return ScaledArray(lax.transpose(A.data, permutation=permutation), A.scale) | ||
|
||
|
||
@core.register_scaled_lax_op | ||
def scaled_mul(A: ScaledArray, B: ScaledArray) -> ScaledArray: | ||
return ScaledArray(A.data * B.data, A.scale * B.scale) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
# Copyright (c) 2023 Graphcore Ltd. All rights reserved. | ||
import chex | ||
import numpy as np | ||
import numpy.testing as npt | ||
|
||
from jax_scaled_arithmetics.core import ScaledArray, scaled_array | ||
from jax_scaled_arithmetics.lax import ( | ||
scaled_broadcast_in_dim, | ||
scaled_convert_element_type, | ||
scaled_mul, | ||
scaled_slice, | ||
scaled_transpose, | ||
) | ||
|
||
|
||
class ScaledTranslationPrimitivesTests(chex.TestCase): | ||
def test__scaled_broadcast_in_dim__proper_scaling(self): | ||
x = scaled_array(np.random.rand(5), 2, dtype=np.float32) | ||
z = scaled_broadcast_in_dim(x, shape=(5, 1), broadcast_dimensions=(0,)) | ||
assert isinstance(z, ScaledArray) | ||
npt.assert_array_equal(z.scale, x.scale) | ||
npt.assert_array_almost_equal(z.data, x.data.reshape((5, 1))) | ||
|
||
def test__scaled_convert_element_type__proper_scaling(self): | ||
x = scaled_array(np.random.rand(5), 2, dtype=np.float32) | ||
z = scaled_convert_element_type(x, new_dtype=np.float16) | ||
assert isinstance(z, ScaledArray) | ||
npt.assert_array_equal(z.scale, x.scale) | ||
npt.assert_array_almost_equal(z.data, x.data.astype(z.dtype)) | ||
|
||
def test__scaled_transpose__proper_scaling(self): | ||
x = scaled_array(np.random.rand(3, 5), 2, dtype=np.float32) | ||
z = scaled_transpose(x, (1, 0)) | ||
assert isinstance(z, ScaledArray) | ||
assert z.scale == x.scale | ||
npt.assert_array_almost_equal(z.data, x.data.T) | ||
|
||
def test__scaled_slice__proper_scaling(self): | ||
x = scaled_array(np.random.rand(5), 2, dtype=np.float32) | ||
z = scaled_slice(x, (1,), (4,), (2,)) | ||
assert isinstance(z, ScaledArray) | ||
assert z.scale == x.scale | ||
npt.assert_array_almost_equal(z.data, x.data[1:4:2]) | ||
|
||
def test__scaled_mul__proper_scaling(self): | ||
x = scaled_array([-2.0, 2.0], 3, dtype=np.float32) | ||
y = scaled_array([1.5, 1.5], 2, dtype=np.float32) | ||
z = scaled_mul(x, y) | ||
assert isinstance(z, ScaledArray) | ||
assert z.scale == 6 | ||
npt.assert_array_almost_equal(z, np.asarray(x) * np.asarray(y)) |