diff --git a/flax/linen/normalization.py b/flax/linen/normalization.py
index 7caf62fc39..68594daa1d 100644
--- a/flax/linen/normalization.py
+++ b/flax/linen/normalization.py
@@ -87,25 +87,32 @@ def _compute_stats(
   # but preserves double or complex floating points
   dtype = jnp.promote_types(dtype, jnp.float32)
   x = jnp.asarray(x, dtype)
+  axes = _canonicalize_axes(x.ndim, axes)
 
-  def mean(x, axes=axes):
-    mu = x.mean(axes)
+  def maybe_distributed_mean(*xs):
+    mus = tuple(x.mean(axes) for x in xs)
     if axis_name is None:
-      return mu
-    return lax.pmean(mu, axis_name, axis_index_groups=axis_index_groups)
+      return mus if len(xs) > 1 else mus[0]
+    else:
+      # Stack the local statistics so a single pmean reduces all of them.
+      stacked_mus = jnp.stack(mus, axis=0)
+      reduced_mus = lax.pmean(
+          stacked_mus, axis_name, axis_index_groups=axis_index_groups)
+      # Index rather than jnp.split, which would keep a size-1 stacking axis.
+      mus = tuple(reduced_mus[i] for i in range(len(xs)))
+      return mus if len(xs) > 1 else mus[0]
 
   if use_mean:
     if use_fast_variance:
-      axes = _canonicalize_axes(x.ndim, axes)
-      mu, mu2 = mean(jnp.stack([x, _abs_sq(x)]), axes=[a + 1 for a in axes])
+      mu, mu2 = maybe_distributed_mean(x, _abs_sq(x))
       # mean2 - _abs_sq(mean) is not guaranteed to be non-negative due
       # to floating point round-off errors.
       var = jnp.maximum(0.0, mu2 - _abs_sq(mu))
     else:
-      mu = mean(x)
-      var = mean(_abs_sq(x - jnp.expand_dims(mu, axes)))
+      mu = maybe_distributed_mean(x)
+      var = maybe_distributed_mean(_abs_sq(x - jnp.expand_dims(mu, axes)))
   else:
-    var = mean(_abs_sq(x))
+    var = maybe_distributed_mean(_abs_sq(x))
     mu = jnp.zeros_like(var)
   return mu, var
 
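
The point of the patch is to cut the number of collectives in the distributed case (when `axis_name` is set): rather than issuing one `lax.pmean` per statistic (E[x] and E[x²] on the fast-variance path), the local means are stacked and reduced with a single `pmean`, then unstacked. Below is a minimal, self-contained sketch of that trick, not part of the patch; the function names and the demo setup are invented. On a CPU-only machine it can be exercised by forcing several virtual devices:

```python
# Sketch of the single-pmean trick from the patch above; names are invented.
# Run on CPU with multiple virtual devices, e.g.:
#   XLA_FLAGS=--xla_force_host_platform_device_count=4 python pmean_demo.py
import functools

import jax
import jax.numpy as jnp
from jax import lax


@functools.partial(jax.pmap, axis_name='batch')
def separate_pmeans(x):
  # Two collectives, one per statistic.
  mu = lax.pmean(x.mean(), 'batch')
  mu2 = lax.pmean(jnp.square(x).mean(), 'batch')
  return mu, mu2


@functools.partial(jax.pmap, axis_name='batch')
def stacked_pmean(x):
  # One collective: stack the local statistics, reduce once, then index.
  # Indexing (unlike jnp.split) leaves no size-1 stacking axis behind.
  stats = lax.pmean(jnp.stack([x.mean(), jnp.square(x).mean()]), 'batch')
  return stats[0], stats[1]


n = jax.local_device_count()
x = jnp.arange(float(8 * n)).reshape(n, 8)
print(separate_pmeans(x))  # (E[x], E[x**2]), replicated across devices
print(stacked_pmean(x))    # the same values from a single pmean
```

With equally sized per-device shards, the mean of the per-device means equals the global mean, so both functions print identical, per-device-replicated results.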
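
The `use_fast_variance` branch continues to rely on the one-pass identity Var[x] = E[x²] − E[x]², clamped at zero because round-off can push the difference slightly negative (the comment retained in the hunk above). A tiny numeric check of the identity, with arbitrarily chosen illustrative values:

```python
# One-pass vs. two-pass variance; values chosen only for illustration.
import jax.numpy as jnp

x = jnp.array([0.5, 1.5, 2.0, 4.0])
mu = x.mean()
fast_var = jnp.maximum(0.0, jnp.square(x).mean() - jnp.square(mu))  # E[x^2] - E[x]^2
two_pass_var = jnp.square(x - mu).mean()
print(fast_var, two_pass_var)  # both 1.625
```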