Fix JAX 0.4.28 regression in SciPy logsumexp scale propagation. (#108)
The new JAX version introduces an additional `max` operation with `-inf`.
This special value needs proper handling in scalify.
balancap committed Jun 12, 2024
1 parent 68cf268 commit 8d94ca0
Showing 7 changed files with 69 additions and 9 deletions.
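For context, a minimal sketch of the regression being fixed (illustrative only; it assumes the package exposes `autoscale` and `as_scaled_array` at the top level, as the tests further down suggest):

import jax
import numpy as np

from jax_scaled_arithmetics import autoscale, as_scaled_array  # assumed import path

def fn(a):
    return jax.scipy.special.logsumexp(a)

a = as_scaled_array(np.array([1.0, 2.0], np.float32))
# JAX 0.4.28's logsumexp emits an additional `max` against -inf; without
# the handling added below, the scaled min/max rule can propagate a
# non-finite scale through that operation.
out = autoscale(fn)(a)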
1 change: 1 addition & 0 deletions jax_scaled_arithmetics/core/__init__.py
@@ -7,6 +7,7 @@
     asarray,
     get_scale_dtype,
     is_scaled_leaf,
+    is_static_anyscale,
     is_static_one_scalar,
     is_static_zero,
     make_scaled_scalar,
26 changes: 26 additions & 0 deletions jax_scaled_arithmetics/core/datatype.py
@@ -268,6 +268,32 @@ def is_static_zero(val: Union[Array, ScaledArray]) -> Array:
     return np.zeros(val.shape, dtype=np.bool_)
 
 
+def is_static_anyscale(val: Union[Array, ScaledArray]) -> Array:
+    """Is a scaled array a static "anyscale" value (i.e. 0, +inf or -inf during JAX tracing as well)?
+    Returns a boolean Numpy array of the shape of the input.
+    """
+
+    def np_anyscale(arr):
+        # Check if 0, np.inf or -np.inf.
+        absarr = np.abs(arr)
+        return np.logical_or(np.equal(absarr, 0), np.equal(absarr, np.inf))
+
+    if is_numpy_scalar_or_array(val):
+        return np_anyscale(val)
+    if isinstance(val, ScaledArray):
+        # TODO: deal with 0 * inf issue?
+        data_mask = (
+            np_anyscale(val.data) if is_numpy_scalar_or_array(val.data) else np.zeros(val.data.shape, dtype=np.bool_)
+        )
+        scale_mask = (
+            np_anyscale(val.scale) if is_numpy_scalar_or_array(val.scale) else np.zeros(val.scale.shape, dtype=np.bool_)
+        )
+        return np.logical_or(data_mask, scale_mask)
+    # By default: can't decide.
+    return np.zeros(val.shape, dtype=np.bool_)
+
+
 def is_static_one_scalar(val: Array) -> Union[bool, np.bool_]:
     """Is a scaled array a static one scalar value (i.e. one during JAX tracing as well)?"""
     if isinstance(val, (int, float)):
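To illustrate the semantics of the new helper (a sketch; the import path mirrors the test file further down and is assumed here):

import numpy as np

from jax_scaled_arithmetics.core import ScaledArray, is_static_anyscale  # assumed exports

# Concrete numpy inputs can be decided statically, entry by entry:
print(is_static_anyscale(np.array([0.0, 1.0, -np.inf])))  # [ True False  True]

# A ScaledArray counts as "anyscale" wherever its numpy data or scale is 0 or +/-inf:
sa = ScaledArray(np.array([3.0, 4.0]), np.array(np.inf))
print(np.all(is_static_anyscale(sa)))  # True, since the scale is +inf

# Traced (jax.numpy) inputs cannot be decided at trace time: all-False mask.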
4 changes: 2 additions & 2 deletions jax_scaled_arithmetics/core/typing.py
@@ -10,10 +10,10 @@
 # Type aliasing. To be compatible with JAX 0.3 as well.
 if jax.__version_info__[1] > 3:
     Array = jax.Array
-    ArrayTypes = (jax.Array,)
+    ArrayTypes = (jax.Array, jax.stages.ArgInfo)
 else:
     Array = jaxlib.xla_extension.DeviceArray
-    ArrayTypes = (jaxlib.xla_extension.DeviceArray, jax.interpreters.partial_eval.DynamicJaxprTracer)  # type:ignore
+    ArrayTypes = (jaxlib.xla_extension.DeviceArray, jax.interpreters.partial_eval.DynamicJaxprTracer)
 
 
 def get_numpy_api(val: Any) -> Any:
7 changes: 4 additions & 3 deletions jax_scaled_arithmetics/lax/scaled_ops_common.py
@@ -15,6 +15,7 @@
     Shape,
     as_scaled_array,
     get_scale_dtype,
+    is_static_anyscale,
     is_static_zero,
     safe_div,
 )
@@ -223,10 +224,10 @@ def scaled_le(lhs: ScaledArray, rhs: ScaledArray) -> Array:
 def scaled_minmax(prim: jax.core.Primitive, lhs: ScaledArray, rhs: ScaledArray) -> ScaledArray:
     """General min/max scaled translation: propagating the largest input scale."""
     check_scalar_scales(lhs, rhs)
-    # Specific rule if lhs/rhs is zero => propagate the other term scale.
-    if np.all(is_static_zero(lhs)):
+    # Specific rule if lhs/rhs is zero or inf => propagate the other term scale.
+    if np.all(is_static_anyscale(lhs)):
         return ScaledArray(prim.bind(lhs.data, rhs.data), rhs.scale)
-    if np.all(is_static_zero(rhs)):
+    if np.all(is_static_anyscale(rhs)):
         return ScaledArray(prim.bind(lhs.data, rhs.data), lhs.scale)
 
     # Power-of-2 stable!
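The intuition behind extending the rule from zero to 0/±inf, as a rough numpy sketch (illustrative only, not the library's code). A ScaledArray represents value = data * scale, so when one operand is entirely 0 or ±inf its scale carries no magnitude information and should not be propagated:

import numpy as np

lhs_data, lhs_scale = np.array([-1.0, 2.0], np.float32), np.float32(4.0)
rhs_data, rhs_scale = np.array([-np.inf, -np.inf], np.float32), np.float32(np.inf)

# A generic "largest input scale" rule would pick scale = inf and poison the
# output. Since -inf never wins a max, taking max over the raw data and
# keeping the finite lhs scale gives the right values: data * scale = [-4., 8.]
out_data = np.maximum(lhs_data, rhs_data)  # [-1.,  2.]
out_scale = lhs_scale                      # 4.0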
20 changes: 20 additions & 0 deletions tests/core/test_datatype.py
@@ -13,6 +13,7 @@
     asarray,
     get_scale_dtype,
     is_scaled_leaf,
+    is_static_anyscale,
     is_static_one_scalar,
     is_static_zero,
     make_scaled_scalar,
@@ -304,6 +305,25 @@ def test__is_static_zero__proper_all_result(self, val, result):
         all_zero = np.all(is_static_zero(val))
         assert all_zero == result
 
+    @parameterized.parameters(
+        {"val": 0, "result": True},
+        {"val": 0.0, "result": True},
+        {"val": np.inf, "result": True},
+        {"val": np.int32(0), "result": True},
+        {"val": np.float16(0), "result": True},
+        {"val": np.float16(-np.inf), "result": True},
+        {"val": np.array([1, 2]), "result": False},
+        {"val": np.array([0, 0.0, -np.inf, np.inf]), "result": True},
+        {"val": jnp.array([0, 0.0, np.inf]), "result": False},
+        {"val": ScaledArray(np.array([0, 0.0]), jnp.array(2.0)), "result": True},
+        {"val": ScaledArray(jnp.array([3, 4.0]), np.array(0.0)), "result": True},
+        {"val": ScaledArray(jnp.array([3, 4.0]), np.array(np.inf)), "result": True},
+        {"val": ScaledArray(jnp.array([3, 4.0]), jnp.array(0.0)), "result": False},
+    )
+    def test__is_static_anyscale__proper_all_result(self, val, result):
+        all_anyscale = np.all(is_static_anyscale(val))
+        assert all_anyscale == result
+
     @parameterized.parameters(
         {"val": 0, "result": False},
         {"val": 1, "result": True},
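Note the `jnp.array` cases above: traced inputs (including a ScaledArray whose data or scale is a traced `jnp` value) report False even when they happen to hold 0 or ±inf, because `is_static_anyscale` can only decide from concrete numpy values and otherwise falls back to an all-False mask.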
18 changes: 15 additions & 3 deletions tests/lax/test_scaled_ops_l2.py
@@ -158,9 +158,8 @@ def test__scaled_addsub__not_overflowing_scale(self, prim):
         assert np.isfinite(z.scale)
         npt.assert_array_almost_equal(z, prim.bind(np.asarray(x, np.float32), np.asarray(y, np.float32)), decimal=6)
 
-    @parameterized.parameters(
-        {"prim": lax.max_p},
-        {"prim": lax.min_p},
+    @parameterized.product(
+        prim=[lax.min_p, lax.max_p],
     )
     def test__scaled_minmax__static_zero_scale_propagation(self, prim):
         scaled_op, _ = find_registered_scaled_op(prim)
@@ -172,6 +171,19 @@ def test__scaled_minmax__static_zero_scale_propagation(self, prim):
         # Keep the lhs scale.
         npt.assert_almost_equal(z.scale, 4.0)
 
+    @parameterized.product(
+        prim=[lax.min_p, lax.max_p],
+    )
+    def test__scaled_minmax__static_inf_scale_propagation(self, prim):
+        scaled_op, _ = find_registered_scaled_op(prim)
+        x = scaled_array([-1.0, 2.0], 4.0, dtype=np.float32, npapi=np)
+        y = scaled_array([-np.inf, np.inf], np.inf, dtype=np.float32, npapi=np)
+        z = scaled_op(x, y)
+        assert isinstance(z, ScaledArray)
+        assert z.dtype == x.dtype
+        # Keep the lhs scale.
+        npt.assert_almost_equal(z.scale, 4.0)
+
     def test__scaled_mul__proper_scaling(self):
         x = scaled_array([-2.0, 2.0], 3, dtype=np.float32)
         y = scaled_array([1.5, 1.5], 2, dtype=np.float32)
2 changes: 1 addition & 1 deletion tests/lax/test_scipy_integration.py
@@ -22,7 +22,7 @@ def fn(a):
         autoscale(fn)(a)
         # FIXME/TODO: what should be the expected result?
 
-    @chex.variants(with_jit=True, without_jit=True)
+    @chex.variants(with_jit=False, without_jit=True)
     @parameterized.parameters(
         {"dtype": np.float32},
         {"dtype": np.float16},
