Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
balancap committed Jan 2, 2024
1 parent 2902549 commit a8d2bdd
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 3 deletions.
4 changes: 3 additions & 1 deletion jax_scaled_arithmetics/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,9 @@ def pow2_round_up(val: Array) -> Array:

def pow2_round(val: Array, mode: Pow2RoundMode = Pow2RoundMode.DOWN) -> Array:
"""Power-of-two rounding."""
if mode == Pow2RoundMode.DOWN:
if mode == Pow2RoundMode.NONE:
return val
elif mode == Pow2RoundMode.DOWN:
return pow2_round_down(val)
elif mode == Pow2RoundMode.UP:
return pow2_round_up(val)
Expand Down
6 changes: 5 additions & 1 deletion jax_scaled_arithmetics/lax/scaled_ops_l2.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,8 +108,11 @@ def scaled_reduce_sum(val: ScaledArray, axes: Tuple[int]) -> ScaledArray:
assert isinstance(val, ScaledArray)
shape = val.shape
axes_size = np.array([shape[idx] for idx in axes])
# Rescale data component following reduction axes.
# Pow2 rounding for unit scaling "rule".
pow2_rounding_mode = get_autoscale_config().rounding_mode
# Rescale data component following reduction axes & round to power of 2 value.
axes_rescale = np.sqrt(np.prod(axes_size))
axes_rescale = pow2_round(axes_rescale, pow2_rounding_mode)
data = lax.reduce_sum_p.bind(val.data, axes=axes) / axes_rescale.astype(val.data.dtype)
outscale = val.scale * axes_rescale.astype(val.scale.dtype)
return ScaledArray(data, outscale)
Expand All @@ -121,6 +124,7 @@ def scaled_reduce_prod(val: ScaledArray, axes: Tuple[int]) -> ScaledArray:
shape = val.shape
data = lax.reduce_prod_p.bind(val.data, axes=axes)
axes_size = np.prod(np.array([shape[idx] for idx in axes]))
# Stable for power of 2.
scale = lax.integer_pow(val.scale, axes_size)
return ScaledArray(data, scale)

Expand Down
2 changes: 1 addition & 1 deletion tests/lax/test_scaled_ops_l2.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ def setUp(self):
self.rs = np.random.RandomState(42)

@parameterized.parameters(
{"reduce_prim": lax.reduce_sum_p, "expected_scale": 2 * np.sqrt(5)},
{"reduce_prim": lax.reduce_sum_p, "expected_scale": 2 * 2},
{"reduce_prim": lax.reduce_prod_p, "expected_scale": 2**5},
{"reduce_prim": lax.reduce_min_p, "expected_scale": 2},
{"reduce_prim": lax.reduce_max_p, "expected_scale": 2},
Expand Down

0 comments on commit a8d2bdd

Please sign in to comment.