Fix JAX 0.4.28 regression in SciPy logsumexp scale propagation.
The new JAX version introduces an additional `max` operation with `-inf`.
This special value needs proper handling in scalify.
balancap committed Jun 12, 2024
1 parent 68cf268 commit 3fcdccb
Showing 7 changed files with 69 additions and 9 deletions.
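
For context, a minimal sketch of the operation at issue (values are illustrative; the internal `-inf` clamp is taken from the commit message, not from the JAX source):

import jax.numpy as jnp
from jax.scipy.special import logsumexp

a = jnp.array([0.5, -1.0, 2.0], dtype=jnp.float32)
# The stable log-sum-exp reduction subtracts the running max:
#     m = max(a);  logsumexp(a) = m + log(sum(exp(a - m)))
# Per the commit message, JAX 0.4.28 additionally applies a `max`
# against -inf, and scalify must recognize that special constant
# during scale propagation (see is_static_anyscale below).
out = logsumexp(a)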
1 change: 1 addition & 0 deletions jax_scaled_arithmetics/core/__init__.py
@@ -7,6 +7,7 @@
    asarray,
    get_scale_dtype,
    is_scaled_leaf,
    is_static_anyscale,
    is_static_one_scalar,
    is_static_zero,
    make_scaled_scalar,
26 changes: 26 additions & 0 deletions jax_scaled_arithmetics/core/datatype.py
@@ -268,6 +268,32 @@ def is_static_zero(val: Union[Array, ScaledArray]) -> Array:
    return np.zeros(val.shape, dtype=np.bool_)


def is_static_anyscale(val: Union[Array, ScaledArray]) -> Array:
    """Is a scaled array a static anyscale value (i.e. 0/inf/-inf during JAX tracing as well)?

    Returns a boolean NumPy array of the shape of the input.
    """

    def np_anyscale(arr):
        # Check if 0, np.inf or -np.inf.
        absarr = np.abs(arr)
        return np.logical_or(np.equal(absarr, 0), np.equal(absarr, np.inf))

    if is_numpy_scalar_or_array(val):
        return np_anyscale(val)
    if isinstance(val, ScaledArray):
        # TODO: deal with 0 * inf issue?
        data_mask = (
            np_anyscale(val.data) if is_numpy_scalar_or_array(val.data) else np.zeros(val.data.shape, dtype=np.bool_)
        )
        scale_mask = (
            np_anyscale(val.scale) if is_numpy_scalar_or_array(val.scale) else np.zeros(val.scale.shape, dtype=np.bool_)
        )
        return np.logical_or(data_mask, scale_mask)
    # By default: can't decide.
    return np.zeros(val.shape, dtype=np.bool_)


def is_static_one_scalar(val: Array) -> Union[bool, np.bool_]:
    """Is a scaled array a static one scalar value (i.e. one during JAX tracing as well)?"""
    if isinstance(val, (int, float)):
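
A small usage sketch of the new predicate, mirroring the tests added below (import paths assumed from the package layout above):

import numpy as np
import jax.numpy as jnp
from jax_scaled_arithmetics.core import ScaledArray, is_static_anyscale

# Static NumPy inputs: 0 and +/-inf are decidable at trace time.
assert np.all(is_static_anyscale(np.array([0.0, -np.inf, np.inf])))
# A static inf scale marks the whole scaled array.
assert np.all(is_static_anyscale(ScaledArray(jnp.array([3.0, 4.0]), np.array(np.inf))))
# Traced (JAX) values cannot be decided statically => conservatively False.
assert not np.any(is_static_anyscale(jnp.array([0.0, np.inf])))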
4 changes: 2 additions & 2 deletions jax_scaled_arithmetics/core/typing.py
@@ -10,10 +10,10 @@
# Type aliasing. To be compatible with JAX 0.3 as well.
if jax.__version_info__[1] > 3:
    Array = jax.Array
-    ArrayTypes = (jax.Array,)
+    ArrayTypes = (jax.Array, jax.stages.ArgInfo)
else:
    Array = jaxlib.xla_extension.DeviceArray
-    ArrayTypes = (jaxlib.xla_extension.DeviceArray, jax.interpreters.partial_eval.DynamicJaxprTracer)  # type:ignore
+    ArrayTypes = (jaxlib.xla_extension.DeviceArray, jax.interpreters.partial_eval.DynamicJaxprTracer)


def get_numpy_api(val: Any) -> Any:
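
For reference, `ArrayTypes` is the tuple the codebase uses for isinstance checks on raw (non-scaled) JAX values; a hedged sketch of that pattern (the helper name is hypothetical, not from the repo):

from typing import Any

from jax_scaled_arithmetics.core.typing import ArrayTypes

def is_raw_jax_value(val: Any) -> bool:
    # After this commit, also True for jax.stages.ArgInfo, the
    # argument-info placeholder from JAX's AOT stages API.
    return isinstance(val, ArrayTypes)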
7 changes: 4 additions & 3 deletions jax_scaled_arithmetics/lax/scaled_ops_common.py
@@ -15,6 +15,7 @@
    Shape,
    as_scaled_array,
    get_scale_dtype,
    is_static_anyscale,
    is_static_zero,
    safe_div,
)
@@ -223,10 +224,10 @@ def scaled_le(lhs: ScaledArray, rhs: ScaledArray) -> Array:
def scaled_minmax(prim: jax.core.Primitive, lhs: ScaledArray, rhs: ScaledArray) -> ScaledArray:
    """General min/max scaled translation: propagating the largest input scale."""
    check_scalar_scales(lhs, rhs)
-    # Specific rule if lhs/rhs is zero => propagate the other term scale.
-    if np.all(is_static_zero(lhs)):
+    # Specific rule if lhs/rhs is zero or inf => propagate the other term scale.
+    if np.all(is_static_anyscale(lhs)):
        return ScaledArray(prim.bind(lhs.data, rhs.data), rhs.scale)
-    if np.all(is_static_zero(rhs)):
+    if np.all(is_static_anyscale(rhs)):
        return ScaledArray(prim.bind(lhs.data, rhs.data), lhs.scale)

    # Power-of-2 stable!
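
Why extending the rule from zero to any 0/±inf value matters: a ScaledArray represents value = data * scale, so letting an infinite scale win the min/max propagation would wipe out the finite operand once re-normalized. A tiny plain-NumPy illustration of that failure mode (assuming the data * scale semantics):

import numpy as np

# lhs: finite values with scale 4.0; rhs would be a static +/-inf array
# whose scale is inf. Naively propagating the larger scale picks inf:
lhs_data = np.array([-1.0, 2.0], np.float32)
lhs_scale = np.float32(4.0)
bad_out_scale = np.float32(np.inf)
# Re-expressing the lhs under an inf output scale loses all information:
print((lhs_data * lhs_scale) / bad_out_scale)  # -> [-0.  0.]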
20 changes: 20 additions & 0 deletions tests/core/test_datatype.py
@@ -13,6 +13,7 @@
    asarray,
    get_scale_dtype,
    is_scaled_leaf,
    is_static_anyscale,
    is_static_one_scalar,
    is_static_zero,
    make_scaled_scalar,
@@ -304,6 +305,25 @@ def test__is_static_zero__proper_all_result(self, val, result):
        all_zero = np.all(is_static_zero(val))
        assert all_zero == result

    @parameterized.parameters(
        {"val": 0, "result": True},
        {"val": 0.0, "result": True},
        {"val": np.inf, "result": True},
        {"val": np.int32(0), "result": True},
        {"val": np.float16(0), "result": True},
        {"val": np.float16(-np.inf), "result": True},
        {"val": np.array([1, 2]), "result": False},
        {"val": np.array([0, 0.0, -np.inf, np.inf]), "result": True},
        {"val": jnp.array([0, 0.0, np.inf]), "result": False},
        {"val": ScaledArray(np.array([0, 0.0]), jnp.array(2.0)), "result": True},
        {"val": ScaledArray(jnp.array([3, 4.0]), np.array(0.0)), "result": True},
        {"val": ScaledArray(jnp.array([3, 4.0]), np.array(np.inf)), "result": True},
        {"val": ScaledArray(jnp.array([3, 4.0]), jnp.array(0.0)), "result": False},
    )
    def test__is_static_anyscale__proper_all_result(self, val, result):
        all_zero = np.all(is_static_anyscale(val))
        assert all_zero == result

    @parameterized.parameters(
        {"val": 0, "result": False},
        {"val": 1, "result": True},
18 changes: 15 additions & 3 deletions tests/lax/test_scaled_ops_l2.py
@@ -158,9 +158,8 @@ def test__scaled_addsub__not_overflowing_scale(self, prim):
        assert np.isfinite(z.scale)
        npt.assert_array_almost_equal(z, prim.bind(np.asarray(x, np.float32), np.asarray(y, np.float32)), decimal=6)

-    @parameterized.parameters(
-        {"prim": lax.max_p},
-        {"prim": lax.min_p},
+    @parameterized.product(
+        prim=[lax.min_p, lax.max_p],
    )
    def test__scaled_minmax__static_zero_scale_propagation(self, prim):
        scaled_op, _ = find_registered_scaled_op(prim)
@@ -172,6 +171,19 @@ def test__scaled_minmax__static_zero_scale_propagation(self, prim):
        # Keep the lhs scale.
        npt.assert_almost_equal(z.scale, 4.0)

    @parameterized.product(
        prim=[lax.min_p, lax.max_p],
    )
    def test__scaled_minmax__static_inf_scale_propagation(self, prim):
        scaled_op, _ = find_registered_scaled_op(prim)
        x = scaled_array([-1.0, 2.0], 4.0, dtype=np.float32, npapi=np)
        y = scaled_array([-np.inf, np.inf], np.inf, dtype=np.float32, npapi=np)
        z = scaled_op(x, y)
        assert isinstance(z, ScaledArray)
        assert z.dtype == x.dtype
        # Keep the lhs scale.
        npt.assert_almost_equal(z.scale, 4.0)

    def test__scaled_mul__proper_scaling(self):
        x = scaled_array([-2.0, 2.0], 3, dtype=np.float32)
        y = scaled_array([1.5, 1.5], 2, dtype=np.float32)
2 changes: 1 addition & 1 deletion tests/lax/test_scipy_integration.py
@@ -22,7 +22,7 @@ def fn(a):
        autoscale(fn)(a)
        # FIXME/TODO: what should be the expected result?

-    @chex.variants(with_jit=True, without_jit=True)
+    @chex.variants(with_jit=False, without_jit=True)
    @parameterized.parameters(
        {"dtype": np.float32},
        {"dtype": np.float16},
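
Finally, a hedged sketch of the end-to-end behaviour this integration test guards (import paths are assumed; `autoscale` and `as_scaled_array` appear in the files above):

import numpy as np
import jax.numpy as jnp
from jax.scipy.special import logsumexp

from jax_scaled_arithmetics import as_scaled_array, autoscale  # assumed import paths

def fn(a):
    return logsumexp(a)

vals = jnp.array([0.5, -1.0, 2.0], dtype=np.float32)
out = autoscale(fn)(as_scaled_array(vals))
# With the -inf handling above, `out` stays finite and matches logsumexp(vals).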
