Restore previous _compute_stats behavior in local case.
Some earlier code refactors introduced a concat+split "no-op" into the non-distributed code path
that turns out not to be such a no-op and can cause issues in some compiled code. We polish
the refactor a bit to avoid this unintended change.

PiperOrigin-RevId: 555039226
levskaya authored and Flax Authors committed Aug 9, 2023
1 parent d826006 commit 976906f
Showing 1 changed file with 35 additions and 25 deletions.
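To make the concat+split issue from the commit message concrete, the sketch below (a minimal illustration, not the Flax source; jnp.square stands in for _abs_sq on real inputs, and the function names are hypothetical) contrasts the two shapes of the local fast-variance computation: the earlier refactor stacked x and its square and reduced the stacked array even when no cross-device reduction was requested, while the restored behavior takes the two means directly and skips the stack.

import jax.numpy as jnp

def stats_via_stack(x, axes=(0,)):
  # Previous shape of the local computation: stack x and x**2, reduce the
  # stacked array, then split the two results apart again. Numerically a
  # no-op, but it adds a concat+split pair to the compiled graph.
  stacked = jnp.stack([x, jnp.square(x)])
  mu, mu2 = stacked.mean(axis=tuple(a + 1 for a in axes))
  return mu, jnp.maximum(0.0, mu2 - jnp.square(mu))

def stats_direct(x, axes=(0,)):
  # Restored shape of the local computation: take the two means directly,
  # with no intermediate stacking.
  mu = x.mean(axes)
  mu2 = jnp.square(x).mean(axes)
  return mu, jnp.maximum(0.0, mu2 - jnp.square(mu))

x = jnp.arange(12.0).reshape(4, 3)
print(stats_via_stack(x))  # same statistics ...
print(stats_direct(x))     # ... without the concat+split detour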
flax/linen/normalization.py (60 changes: 35 additions & 25 deletions)
@@ -87,25 +87,35 @@ def _compute_stats(
   # but preserves double or complex floating points
   dtype = jnp.promote_types(dtype, jnp.float32)
   x = jnp.asarray(x, dtype)
+  axes = _canonicalize_axes(x.ndim, axes)

-  def mean(x, axes=axes):
-    mu = x.mean(axes)
+  def maybe_distributed_mean(*xs):
+    mus = tuple(x.mean(axes) for x in xs)
     if axis_name is None:
-      return mu
-    return lax.pmean(mu, axis_name, axis_index_groups=axis_index_groups)
+      return mus if len(xs) > 1 else mus[0]
+    else:
+      # In the distributed case we stack multiple arrays to speed comms.
+      if len(xs) > 1:
+        reduced_mus = lax.pmean(
+            jnp.stack(mus, axis=0),
+            axis_name,
+            axis_index_groups=axis_index_groups,
+        )
+        return tuple(reduced_mus[i] for i in range(len(xs)))
+      else:
+        return lax.pmean(mus[0], axis_name, axis_index_groups=axis_index_groups)

   if use_mean:
     if use_fast_variance:
-      axes = _canonicalize_axes(x.ndim, axes)
-      mu, mu2 = mean(jnp.stack([x, _abs_sq(x)]), axes=[a + 1 for a in axes])
+      mu, mu2 = maybe_distributed_mean(x, _abs_sq(x))
       # mean2 - _abs_sq(mean) is not guaranteed to be non-negative due
       # to floating point round-off errors.
       var = jnp.maximum(0.0, mu2 - _abs_sq(mu))
     else:
-      mu = mean(x)
-      var = mean(_abs_sq(x - jnp.expand_dims(mu, axes)))
+      mu = maybe_distributed_mean(x)
+      var = maybe_distributed_mean(_abs_sq(x - jnp.expand_dims(mu, axes)))
   else:
-    var = mean(_abs_sq(x))
+    var = maybe_distributed_mean(_abs_sq(x))
     mu = jnp.zeros_like(var)
   return mu, var
@@ -125,7 +135,7 @@ def _normalize(
     bias_init: Callable[[PRNGKey, Shape, Dtype], Array],
     scale_init: Callable[[PRNGKey, Shape, Dtype], Array],
 ):
-  """ "Normalizes the input of a normalization layer and optionally applies a learned scale and bias.
+  """Normalizes the input of a normalization layer and optionally applies a learned scale and bias.

   Arguments:
     mdl: Module to apply the normalization in (normalization params will reside
@@ -208,27 +218,27 @@ class BatchNorm(Module):
     y = BN.apply(vars_in, x)

   Attributes:
-    use_running_average: if True, the statistics stored in batch_stats
-      will be used instead of computing the batch statistics on the input.
+    use_running_average: if True, the statistics stored in batch_stats will be
+      used instead of computing the batch statistics on the input.
     axis: the feature or non-batch axis of the input.
-    momentum: decay rate for the exponential moving average of
-      the batch statistics.
+    momentum: decay rate for the exponential moving average of the batch
+      statistics.
     epsilon: a small float added to variance to avoid dividing by zero.
     dtype: the dtype of the result (default: infer from input and params).
     param_dtype: the dtype passed to parameter initializers (default: float32).
     use_bias: if True, bias (beta) is added.
-    use_scale: if True, multiply by scale (gamma).
-      When the next layer is linear (also e.g. nn.relu), this can be disabled
-      since the scaling will be done by the next layer.
+    use_scale: if True, multiply by scale (gamma). When the next layer is linear
+      (also e.g. nn.relu), this can be disabled since the scaling will be done
+      by the next layer.
     bias_init: initializer for bias, by default, zero.
     scale_init: initializer for scale, by default, one.
     axis_name: the axis name used to combine batch statistics from multiple
       devices. See `jax.pmap` for a description of axis names (default: None).
     axis_index_groups: groups of axis indices within that named axis
       representing subsets of devices to reduce over (default: None). For
-      example, `[[0, 1], [2, 3]]` would independently batch-normalize over
-      the examples on the first two and last two devices. See `jax.lax.psum`
-      for more details.
+      example, `[[0, 1], [2, 3]]` would independently batch-normalize over the
+      examples on the first two and last two devices. See `jax.lax.psum` for
+      more details.
     use_fast_variance: If true, use a faster, but less numerically stable,
       calculation for the variance.
   """
@@ -260,8 +270,8 @@ def __call__(self, x, use_running_average: Optional[bool] = None):

     Args:
       x: the input to be normalized.
-      use_running_average: if true, the statistics stored in batch_stats
-        will be used instead of computing the batch statistics on the input.
+      use_running_average: if true, the statistics stored in batch_stats will be
+        used instead of computing the batch statistics on the input.

     Returns:
       Normalized inputs (the same shape as inputs).
@@ -436,9 +446,9 @@ class RMSNorm(Module):
       array being normalized is sharded across devices within a pmap.
     axis_index_groups: groups of axis indices within that named axis
       representing subsets of devices to reduce over (default: None). For
-      example, `[[0, 1], [2, 3]]` would independently batch-normalize over
-      the examples on the first two and last two devices. See `jax.lax.psum`
-      for more details.
+      example, `[[0, 1], [2, 3]]` would independently batch-normalize over the
+      examples on the first two and last two devices. See `jax.lax.psum` for
+      more details.
   """

   epsilon: float = 1e-6

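The distributed branch that the diff keeps is the one place where stacking still pays off: the per-device means are packed into a single array so that one lax.pmean collective moves both of them, as the comment in maybe_distributed_mean notes. The sketch below is a hedged illustration of that pattern under pmap, not the Flax source; jnp.square again stands in for _abs_sq, the function name is hypothetical, and it reuses the docstring's `[[0, 1], [2, 3]]` axis_index_groups example only when exactly four local devices participate (the groups must partition the mapped axis).

import jax
import jax.numpy as jnp
from jax import lax

n = jax.local_device_count()
# Mirror the docstring example: devices {0, 1} and {2, 3} reduce separately.
# Only valid when exactly four devices are mapped; otherwise reduce over all.
groups = [[0, 1], [2, 3]] if n == 4 else None

def stacked_pmean_stats(x):
  # Stack the per-device means of x and x**2 so a single pmean handles both,
  # then unstack the reduced results.
  mus = (x.mean(0), jnp.square(x).mean(0))
  reduced = lax.pmean(jnp.stack(mus, axis=0), 'batch', axis_index_groups=groups)
  mu, mu2 = reduced[0], reduced[1]
  return mu, jnp.maximum(0.0, mu2 - jnp.square(mu))

x = jnp.arange(n * 4 * 3, dtype=jnp.float32).reshape(n, 4, 3)
mu, var = jax.pmap(stacked_pmean_stats, axis_name='batch')(x)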