Commit

wip
balancap committed Jun 12, 2024
1 parent 45f634b commit 95312db
Showing 3 changed files with 20 additions and 13 deletions.
7 changes: 4 additions & 3 deletions jax_scaled_arithmetics/lax/scaled_ops_common.py
@@ -15,6 +15,7 @@
     Shape,
     as_scaled_array,
     get_scale_dtype,
+    is_static_anyscale,
     is_static_zero,
     safe_div,
 )
@@ -223,10 +224,10 @@ def scaled_le(lhs: ScaledArray, rhs: ScaledArray) -> Array:
 def scaled_minmax(prim: jax.core.Primitive, lhs: ScaledArray, rhs: ScaledArray) -> ScaledArray:
     """General min/max scaled translation: propagating the largest input scale."""
     check_scalar_scales(lhs, rhs)
-    # Specific rule if lhs/rhs is zero => propagate the other term scale.
-    if np.all(is_static_zero(lhs)):
+    # Specific rule if lhs/rhs is zero or inf => propagate the other term scale.
+    if np.all(is_static_anyscale(lhs)):
         return ScaledArray(prim.bind(lhs.data, rhs.data), rhs.scale)
-    if np.all(is_static_zero(rhs)):
+    if np.all(is_static_anyscale(rhs)):
         return ScaledArray(prim.bind(lhs.data, rhs.data), lhs.scale)
 
     # Power-of-2 stable!
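For context, a minimal sketch of the behaviour this hunk targets, mirroring the new test added in tests/lax/test_scaled_ops_l2.py below. It assumes ScaledArray, scaled_array and find_registered_scaled_op are importable as in that test module (their exact import paths are not shown in this diff):

    import numpy as np
    from jax import lax

    # Assumed in scope, as used by the tests below (import paths not shown in this diff):
    # ScaledArray, scaled_array, find_registered_scaled_op

    scaled_max, _ = find_registered_scaled_op(lax.max_p)

    x = scaled_array([-1.0, 2.0], 4.0, dtype=np.float32, npapi=np)           # finite scale
    y = scaled_array([-np.inf, np.inf], np.inf, dtype=np.float32, npapi=np)  # inf data, inf scale

    z = scaled_max(x, y)
    # Previously only a statically-zero operand triggered the special rule, so the
    # largest (here infinite) input scale would propagate; with is_static_anyscale
    # the finite lhs scale is kept instead.
    assert isinstance(z, ScaledArray)
    assert float(z.scale) == 4.0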
18 changes: 15 additions & 3 deletions tests/lax/test_scaled_ops_l2.py
@@ -158,9 +158,8 @@ def test__scaled_addsub__not_overflowing_scale(self, prim):
         assert np.isfinite(z.scale)
         npt.assert_array_almost_equal(z, prim.bind(np.asarray(x, np.float32), np.asarray(y, np.float32)), decimal=6)
 
-    @parameterized.parameters(
-        {"prim": lax.max_p},
-        {"prim": lax.min_p},
+    @parameterized.product(
+        prim=[lax.min_p, lax.max_p],
     )
     def test__scaled_minmax__static_zero_scale_propagation(self, prim):
         scaled_op, _ = find_registered_scaled_op(prim)
@@ -172,6 +171,19 @@ def test__scaled_minmax__static_zero_scale_propagation(self, prim):
         # Keep the lhs scale.
         npt.assert_almost_equal(z.scale, 4.0)
 
+    @parameterized.product(
+        prim=[lax.min_p, lax.max_p],
+    )
+    def test__scaled_minmax__static_inf_scale_propagation(self, prim):
+        scaled_op, _ = find_registered_scaled_op(prim)
+        x = scaled_array([-1.0, 2.0], 4.0, dtype=np.float32, npapi=np)
+        y = scaled_array([-np.inf, np.inf], np.inf, dtype=np.float32, npapi=np)
+        z = scaled_op(x, y)
+        assert isinstance(z, ScaledArray)
+        assert z.dtype == x.dtype
+        # Keep the lhs scale.
+        npt.assert_almost_equal(z.scale, 4.0)
+
     def test__scaled_mul__proper_scaling(self):
         x = scaled_array([-2.0, 2.0], 3, dtype=np.float32)
         y = scaled_array([1.5, 1.5], 2, dtype=np.float32)
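A side note on the decorator switch above: absl.testing.parameterized.product takes keyword arguments whose values are sequences and generates one test case per element of their Cartesian product, so with a single prim axis it is equivalent to listing each case explicitly with parameterized.parameters. A small standalone sketch (class name, test names and values are illustrative, not from this repository):

    import numpy as np
    from absl.testing import absltest, parameterized


    class DecoratorStylesTest(parameterized.TestCase):
        # Explicit cases: each positional dict becomes the kwargs of one generated test.
        @parameterized.parameters({"value": 0.0}, {"value": np.inf})
        def test_explicit_cases(self, value):
            self.assertGreaterEqual(value, 0.0)

        # Cartesian product of the keyword sequences: 2 values x 2 signs = 4 generated tests.
        @parameterized.product(value=[0.0, np.inf], sign=[-1.0, 1.0])
        def test_product_cases(self, value, sign):
            self.assertFalse(np.isnan(sign * value))


    if __name__ == "__main__":
        absltest.main()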
8 changes: 1 addition & 7 deletions tests/lax/test_scipy_integration.py
@@ -25,22 +25,16 @@ def fn(a):
     @chex.variants(with_jit=False, without_jit=True)
     @parameterized.parameters(
         {"dtype": np.float32},
-        # {"dtype": np.float16},
+        {"dtype": np.float16},
     )
     def test__scipy_logsumexp__accurate_scaled_op(self, dtype):
-        import jax
         from jax.scipy.special import logsumexp
 
         input_scaled = scaled_array(self.rs.rand(10), 4.0, dtype=dtype)
-
-        print(jax.make_jaxpr(logsumexp)(input_scaled.data))
-
         # JAX `logsumexp` Jaxpr is a non-trivial graph!
         out_scaled = self.variant(autoscale(logsumexp))(input_scaled)
         out_expected = logsumexp(np.asarray(input_scaled))
         assert out_scaled.dtype == out_expected.dtype
         # Proper accuracy + keep the same scale.
         npt.assert_array_equal(out_scaled.scale, input_scaled.scale)
         npt.assert_array_almost_equal(out_scaled, out_expected, decimal=5)
-
-        assert False
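For reference, the deleted debug line above used jax.make_jaxpr to display the graph that autoscale has to trace through. A standalone version of that inspection, using plain JAX without any scaled arrays:

    import numpy as np
    import jax
    from jax.scipy.special import logsumexp

    x = np.random.rand(10).astype(np.float32)
    # logsumexp lowers to a non-trivial jaxpr (reduce_max, sub, exp, reduce_sum, log, add, among others).
    print(jax.make_jaxpr(logsumexp)(x))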
