Skip to content

Commit

Permalink
Implement scaled_concatenate LAX operation. (#15)
Browse files Browse the repository at this point in the history
`scaled_concatenate` is using a max rescaling.
  • Loading branch information
balancap authored Nov 14, 2023
1 parent 006f321 commit bba779d
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 0 deletions.
13 changes: 13 additions & 0 deletions jax_scaled_arithmetics/lax/scaled_ops.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Copyright (c) 2023 Graphcore Ltd. All rights reserved.
from typing import Optional, Sequence

import jax.numpy as jnp
from jax import lax

from jax_scaled_arithmetics import core
Expand All @@ -20,6 +21,18 @@ def scaled_convert_element_type(A: ScaledArray, new_dtype: DTypeLike, weak_type:
return ScaledArray(lax.convert_element_type(A.data, new_dtype=new_dtype), A.scale)


@core.register_scaled_lax_op
def scaled_concatenate(operands: Sequence[ScaledArray], dimension: int) -> ScaledArray:
# TODO: inputs checking (dtype and cie).
scales = jnp.array([v.scale for v in operands])
# Max rescaling of the collection of operands.
# TODO: explore alternative strategies?
scale_max = jnp.max(scales)
datas = [v.data * (v.scale / scale_max) for v in operands]
data_concat = lax.concatenate(datas, dimension=dimension)
return ScaledArray(data_concat, scale_max)


@core.register_scaled_lax_op
def scaled_slice(
A: ScaledArray, start_indices: Sequence[int], limit_indices: Sequence[int], strides: Optional[Sequence[int]] = None
Expand Down
9 changes: 9 additions & 0 deletions tests/lax/test_scaled_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from jax_scaled_arithmetics.core import ScaledArray, scaled_array
from jax_scaled_arithmetics.lax import (
scaled_broadcast_in_dim,
scaled_concatenate,
scaled_convert_element_type,
scaled_mul,
scaled_slice,
Expand All @@ -21,6 +22,14 @@ def test__scaled_broadcast_in_dim__proper_scaling(self):
npt.assert_array_equal(z.scale, x.scale)
npt.assert_array_almost_equal(z.data, x.data.reshape((5, 1)))

def test__scaled_concatenate__proper_scaling(self):
x = scaled_array(np.random.rand(2, 3), 0.5, dtype=np.float32)
y = scaled_array(np.random.rand(5, 3), 2, dtype=np.float32)
z = scaled_concatenate([x, y], dimension=0)
assert isinstance(z, ScaledArray)
npt.assert_array_equal(z.scale, y.scale)
npt.assert_array_almost_equal(z, np.concatenate([x, y], axis=0))

def test__scaled_convert_element_type__proper_scaling(self):
x = scaled_array(np.random.rand(5), 2, dtype=np.float32)
z = scaled_convert_element_type(x, new_dtype=np.float16)
Expand Down

0 comments on commit bba779d

Please sign in to comment.