Restore previous _compute_stats behavior in local case.
An earlier refactor introduced a concat+split "no-op" into the non-distributed
codepath that turns out not to be a no-op after all and can cause issues in some
compiled code. We polish the refactor to avoid this unintended change; a sketch
of the old and restored local paths follows below.
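
For context, here is a minimal sketch of the pattern in question, under assumed example values (the array `x` and `axes = (0,)` are illustrative, not taken from the commit). The fast-variance path previously stacked `x` and `_abs_sq(x)`, reduced both with a single mean over the shifted axes, and unpacked the result; numerically this matches taking the two means directly, but the extra stack/unstack ops survive into the traced computation.

```python
import jax.numpy as jnp

def _abs_sq(x):
  # Squared magnitude, as the helper of the same name in Flax computes it.
  return jnp.square(jnp.abs(x))

x = jnp.arange(6.0).reshape(2, 3)
axes = (0,)  # illustrative reduction axes

# Old local path (the concat+split "no-op"): stack, reduce over the
# shifted axes, then unpack along the new leading axis.
mu_old, mu2_old = jnp.stack([x, _abs_sq(x)]).mean(
    axis=tuple(a + 1 for a in axes)
)

# Restored local path: compute each mean directly, with no stack/unstack.
mu_new = x.mean(axes)
mu2_new = _abs_sq(x).mean(axes)

assert jnp.allclose(mu_old, mu_new) and jnp.allclose(mu2_old, mu2_new)
```

The diff below moves the `_canonicalize_axes` call up and replaces `mean` with `maybe_distributed_mean`, which only stacks in the distributed case.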

PiperOrigin-RevId: 554682671
levskaya authored and Flax Authors committed Aug 9, 2023
1 parent fe54d39 commit cafb5ad
Showing 1 changed file with 20 additions and 10 deletions.
flax/linen/normalization.py: 20 additions & 10 deletions
@@ -87,25 +87,35 @@ def _compute_stats(
   # but preserves double or complex floating points
   dtype = jnp.promote_types(dtype, jnp.float32)
   x = jnp.asarray(x, dtype)
+  axes = _canonicalize_axes(x.ndim, axes)
 
-  def mean(x, axes=axes):
-    mu = x.mean(axes)
+  def maybe_distributed_mean(*xs):
+    mus = tuple(x.mean(axes) for x in xs)
     if axis_name is None:
-      return mu
-    return lax.pmean(mu, axis_name, axis_index_groups=axis_index_groups)
+      return mus if len(xs) > 1 else mus[0]
+    else:
+      # In the distributed case we stack multiple arrays to speed comms.
+      if len(xs) > 1:
+        reduced_mus = lax.pmean(jnp.stack(mus, axis=0),
+                                axis_name,
+                                axis_index_groups=axis_index_groups)
+        return tuple(reduced_mus[i] for i in range(len(xs)))
+      else:
+        return lax.pmean(mus[0],
+                         axis_name,
+                         axis_index_groups=axis_index_groups)
 
   if use_mean:
     if use_fast_variance:
-      axes = _canonicalize_axes(x.ndim, axes)
-      mu, mu2 = mean(jnp.stack([x, _abs_sq(x)]), axes=[a + 1 for a in axes])
+      mu, mu2 = maybe_distributed_mean(x, _abs_sq(x))
       # mean2 - _abs_sq(mean) is not guaranteed to be non-negative due
       # to floating point round-off errors.
       var = jnp.maximum(0.0, mu2 - _abs_sq(mu))
     else:
-      mu = mean(x)
-      var = mean(_abs_sq(x - jnp.expand_dims(mu, axes)))
+      mu = maybe_distributed_mean(x)
+      var = maybe_distributed_mean(_abs_sq(x - jnp.expand_dims(mu, axes)))
   else:
-    var = mean(_abs_sq(x))
+    var = maybe_distributed_mean(_abs_sq(x))
     mu = jnp.zeros_like(var)
   return mu, var

@@ -125,7 +135,7 @@ def _normalize(
     bias_init: Callable[[PRNGKey, Shape, Dtype], Array],
     scale_init: Callable[[PRNGKey, Shape, Dtype], Array],
 ):
-  """ "Normalizes the input of a normalization layer and optionally applies a learned scale and bias.
+  """Normalizes the input of a normalization layer and optionally applies a learned scale and bias.
   Arguments:
     mdl: Module to apply the normalization in (normalization params will reside
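
For readers who want to run the new logic outside Flax, here is a self-contained sketch of the helper added in the first hunk. The `_sketch` suffix and the keyword parameters are assumptions for illustration (in the real code, `axes`, `axis_name`, and `axis_index_groups` are closed over from `_compute_stats`); the branching mirrors the diff above.

```python
import jax.numpy as jnp
from jax import lax

def _abs_sq(x):
  # Squared magnitude, as the helper of the same name in Flax computes it.
  return jnp.square(jnp.abs(x))

def maybe_distributed_mean_sketch(
    *xs, axes=(0,), axis_name=None, axis_index_groups=None
):
  # Mirrors the helper added in this commit; here `axes`, `axis_name`, and
  # `axis_index_groups` are keywords rather than closed-over variables.
  mus = tuple(x.mean(axes) for x in xs)
  if axis_name is None:
    # Local case: no stacking or splitting; each mean is returned as-is.
    return mus if len(xs) > 1 else mus[0]
  # Distributed case: stack the per-array means so a single pmean covers
  # them all, then unstack.
  if len(xs) > 1:
    reduced_mus = lax.pmean(
        jnp.stack(mus, axis=0), axis_name,
        axis_index_groups=axis_index_groups,
    )
    return tuple(reduced_mus[i] for i in range(len(xs)))
  return lax.pmean(mus[0], axis_name, axis_index_groups=axis_index_groups)

# Local fast-variance path, as in the diff above:
x = jnp.arange(6.0).reshape(2, 3)
mu, mu2 = maybe_distributed_mean_sketch(x, _abs_sq(x))
var = jnp.maximum(0.0, mu2 - _abs_sq(mu))  # E[x^2] - E[x]^2, clamped at 0
```

Stacking the per-array means before a single `lax.pmean` keeps the distributed fast-variance path down to one collective, while the local path now returns each mean untouched.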
