Skip to content

Commit

Permalink
Restore previous _compute_stats behavior in local case.
Browse files Browse the repository at this point in the history
Some earlier code refactors introduced a concat+split "no-op" to the non-distributed codepath
that turns out not to be such a no-op and can cause issues in some compiled code.  We polish
the refactor a bit to avoid this introduced change.

PiperOrigin-RevId: 554682671
  • Loading branch information
levskaya authored and Flax Authors committed Aug 8, 2023
1 parent fe54d39 commit b6d4bf3
Showing 1 changed file with 13 additions and 9 deletions.
22 changes: 13 additions & 9 deletions flax/linen/normalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,25 +87,29 @@ def _compute_stats(
# but preserves double or complex floating points
dtype = jnp.promote_types(dtype, jnp.float32)
x = jnp.asarray(x, dtype)
axes = _canonicalize_axes(x.ndim, axes)

def mean(x, axes=axes):
mu = x.mean(axes)
def maybe_distributed_mean(*xs):
mus = tuple(map(lambda x: x.mean(axes), xs))
if axis_name is None:
return mu
return lax.pmean(mu, axis_name, axis_index_groups=axis_index_groups)
return mus
else:
stacked_mus = jnp.stack(mus, axis=0)
reduced_mus = lax.pmean(
stacked_mus, axis_name, axis_index_groups=axis_index_groups)
return jnp.split(reduced_mus, reduced_mus.shape[0], axis=0)

if use_mean:
if use_fast_variance:
axes = _canonicalize_axes(x.ndim, axes)
mu, mu2 = mean(jnp.stack([x, _abs_sq(x)]), axes=[a + 1 for a in axes])
mu, mu2 = maybe_distributed_mean(x, _abs_sq(x))
# mean2 - _abs_sq(mean) is not guaranteed to be non-negative due
# to floating point round-off errors.
var = jnp.maximum(0.0, mu2 - _abs_sq(mu))
else:
mu = mean(x)
var = mean(_abs_sq(x - jnp.expand_dims(mu, axes)))
mu = maybe_distributed_mean(x)
var = maybe_distributed_mean(_abs_sq(x - jnp.expand_dims(mu, axes)))
else:
var = mean(_abs_sq(x))
var = maybe_distributed_mean(_abs_sq(x))
mu = jnp.zeros_like(var)
return mu, var

Expand Down

0 comments on commit b6d4bf3

Please sign in to comment.