diff --git a/jax_scaled_arithmetics/core/__init__.py b/jax_scaled_arithmetics/core/__init__.py index 199a74a..aefd0dc 100644 --- a/jax_scaled_arithmetics/core/__init__.py +++ b/jax_scaled_arithmetics/core/__init__.py @@ -7,6 +7,7 @@ asarray, get_scale_dtype, is_scaled_leaf, + is_static_anyscale, is_static_one_scalar, is_static_zero, make_scaled_scalar, diff --git a/jax_scaled_arithmetics/core/datatype.py b/jax_scaled_arithmetics/core/datatype.py index 2d293fa..dba692f 100644 --- a/jax_scaled_arithmetics/core/datatype.py +++ b/jax_scaled_arithmetics/core/datatype.py @@ -268,6 +268,32 @@ def is_static_zero(val: Union[Array, ScaledArray]) -> Array: return np.zeros(val.shape, dtype=np.bool_) +def is_static_anyscale(val: Union[Array, ScaledArray]) -> Array: + """Is a scaled array a static anyscale values (i.e. 0/inf/-inf during JAX tracing as well)? + + Returns a boolean Numpy array of the shape of the input. + """ + + def np_anyscale(arr): + # Check if 0, np.inf or -np.inf + absarr = np.abs(arr) + return np.logical_or(np.equal(absarr, 0), np.equal(absarr, np.inf)) + + if is_numpy_scalar_or_array(val): + return np_anyscale(val) + if isinstance(val, ScaledArray): + # TODO: deal with 0 * inf issue? + data_mask = ( + np_anyscale(val.data) if is_numpy_scalar_or_array(val.data) else np.zeros(val.data.shape, dtype=np.bool_) + ) + scale_mask = ( + np_anyscale(val.scale) if is_numpy_scalar_or_array(val.scale) else np.zeros(val.scale.shape, dtype=np.bool_) + ) + return np.logical_or(data_mask, scale_mask) + # By default: can't decide. + return np.zeros(val.shape, dtype=np.bool_) + + def is_static_one_scalar(val: Array) -> Union[bool, np.bool_]: """Is a scaled array a static one scalar value (i.e. one during JAX tracing as well)?""" if isinstance(val, (int, float)): diff --git a/tests/core/test_datatype.py b/tests/core/test_datatype.py index 1d664f6..512877e 100644 --- a/tests/core/test_datatype.py +++ b/tests/core/test_datatype.py @@ -13,6 +13,7 @@ asarray, get_scale_dtype, is_scaled_leaf, + is_static_anyscale, is_static_one_scalar, is_static_zero, make_scaled_scalar, @@ -304,6 +305,25 @@ def test__is_static_zero__proper_all_result(self, val, result): all_zero = np.all(is_static_zero(val)) assert all_zero == result + @parameterized.parameters( + {"val": 0, "result": True}, + {"val": 0.0, "result": True}, + {"val": np.inf, "result": True}, + {"val": np.int32(0), "result": True}, + {"val": np.float16(0), "result": True}, + {"val": np.float16(-np.inf), "result": True}, + {"val": np.array([1, 2]), "result": False}, + {"val": np.array([0, 0.0, -np.inf, np.inf]), "result": True}, + {"val": jnp.array([0, 0.0, np.inf]), "result": False}, + {"val": ScaledArray(np.array([0, 0.0]), jnp.array(2.0)), "result": True}, + {"val": ScaledArray(jnp.array([3, 4.0]), np.array(0.0)), "result": True}, + {"val": ScaledArray(jnp.array([3, 4.0]), np.array(np.inf)), "result": True}, + {"val": ScaledArray(jnp.array([3, 4.0]), jnp.array(0.0)), "result": False}, + ) + def test__is_static_anyscale__proper_all_result(self, val, result): + all_zero = np.all(is_static_anyscale(val)) + assert all_zero == result + @parameterized.parameters( {"val": 0, "result": False}, {"val": 1, "result": True},