Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
balancap committed Jun 12, 2024
1 parent c91009c commit 45f634b
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 0 deletions.
1 change: 1 addition & 0 deletions jax_scaled_arithmetics/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
asarray,
get_scale_dtype,
is_scaled_leaf,
is_static_anyscale,
is_static_one_scalar,
is_static_zero,
make_scaled_scalar,
Expand Down
26 changes: 26 additions & 0 deletions jax_scaled_arithmetics/core/datatype.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,32 @@ def is_static_zero(val: Union[Array, ScaledArray]) -> Array:
return np.zeros(val.shape, dtype=np.bool_)


def is_static_anyscale(val: Union[Array, ScaledArray]) -> Array:
"""Is a scaled array a static anyscale values (i.e. 0/inf/-inf during JAX tracing as well)?
Returns a boolean Numpy array of the shape of the input.
"""

def np_anyscale(arr):
# Check if 0, np.inf or -np.inf
absarr = np.abs(arr)
return np.logical_or(np.equal(absarr, 0), np.equal(absarr, np.inf))

if is_numpy_scalar_or_array(val):
return np_anyscale(val)
if isinstance(val, ScaledArray):
# TODO: deal with 0 * inf issue?
data_mask = (
np_anyscale(val.data) if is_numpy_scalar_or_array(val.data) else np.zeros(val.data.shape, dtype=np.bool_)
)
scale_mask = (
np_anyscale(val.scale) if is_numpy_scalar_or_array(val.scale) else np.zeros(val.scale.shape, dtype=np.bool_)
)
return np.logical_or(data_mask, scale_mask)
# By default: can't decide.
return np.zeros(val.shape, dtype=np.bool_)


def is_static_one_scalar(val: Array) -> Union[bool, np.bool_]:
"""Is a scaled array a static one scalar value (i.e. one during JAX tracing as well)?"""
if isinstance(val, (int, float)):
Expand Down
20 changes: 20 additions & 0 deletions tests/core/test_datatype.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
asarray,
get_scale_dtype,
is_scaled_leaf,
is_static_anyscale,
is_static_one_scalar,
is_static_zero,
make_scaled_scalar,
Expand Down Expand Up @@ -304,6 +305,25 @@ def test__is_static_zero__proper_all_result(self, val, result):
all_zero = np.all(is_static_zero(val))
assert all_zero == result

@parameterized.parameters(
{"val": 0, "result": True},
{"val": 0.0, "result": True},
{"val": np.inf, "result": True},
{"val": np.int32(0), "result": True},
{"val": np.float16(0), "result": True},
{"val": np.float16(-np.inf), "result": True},
{"val": np.array([1, 2]), "result": False},
{"val": np.array([0, 0.0, -np.inf, np.inf]), "result": True},
{"val": jnp.array([0, 0.0, np.inf]), "result": False},
{"val": ScaledArray(np.array([0, 0.0]), jnp.array(2.0)), "result": True},
{"val": ScaledArray(jnp.array([3, 4.0]), np.array(0.0)), "result": True},
{"val": ScaledArray(jnp.array([3, 4.0]), np.array(np.inf)), "result": True},
{"val": ScaledArray(jnp.array([3, 4.0]), jnp.array(0.0)), "result": False},
)
def test__is_static_anyscale__proper_all_result(self, val, result):
all_zero = np.all(is_static_anyscale(val))
assert all_zero == result

@parameterized.parameters(
{"val": 0, "result": False},
{"val": 1, "result": True},
Expand Down

0 comments on commit 45f634b

Please sign in to comment.