From 03486b3e55ed97c7211f2b0fbe95550cc8da61ee Mon Sep 17 00:00:00 2001
From: Anselm Levskaya
Date: Mon, 7 Aug 2023 21:05:36 -0700
Subject: [PATCH] Restore previous _compute_stats behavior in local case.

Some earlier code refactors introduced a concat+split "no-op" to the
non-distributed codepath that turns out not to be such a no-op and can
cause issues in some compiled code. We polish the refactor a bit to
avoid this introduced change.

PiperOrigin-RevId: 554682671
---
 flax/linen/normalization.py | 60 +++++++++++++++++++++----------------
 1 file changed, 35 insertions(+), 25 deletions(-)

diff --git a/flax/linen/normalization.py b/flax/linen/normalization.py
index 7caf62fc39..2be694634a 100644
--- a/flax/linen/normalization.py
+++ b/flax/linen/normalization.py
@@ -87,25 +87,35 @@ def _compute_stats(
   # but preserves double or complex floating points
   dtype = jnp.promote_types(dtype, jnp.float32)
   x = jnp.asarray(x, dtype)
+  axes = _canonicalize_axes(x.ndim, axes)
 
-  def mean(x, axes=axes):
-    mu = x.mean(axes)
+  def maybe_distributed_mean(*xs):
+    mus = tuple(x.mean(axes) for x in xs)
     if axis_name is None:
-      return mu
-    return lax.pmean(mu, axis_name, axis_index_groups=axis_index_groups)
+      return mus if len(xs) > 1 else mus[0]
+    else:
+      # In the distributed case we stack multiple arrays to speed comms.
+      if len(xs) > 1:
+        reduced_mus = lax.pmean(
+            jnp.stack(mus, axis=0),
+            axis_name,
+            axis_index_groups=axis_index_groups,
+        )
+        return tuple(reduced_mus[i] for i in range(len(xs)))
+      else:
+        return lax.pmean(mus[0], axis_name, axis_index_groups=axis_index_groups)
 
   if use_mean:
     if use_fast_variance:
-      axes = _canonicalize_axes(x.ndim, axes)
-      mu, mu2 = mean(jnp.stack([x, _abs_sq(x)]), axes=[a + 1 for a in axes])
+      mu, mu2 = maybe_distributed_mean(x, _abs_sq(x))
       # mean2 - _abs_sq(mean) is not guaranteed to be non-negative due
       # to floating point round-off errors.
       var = jnp.maximum(0.0, mu2 - _abs_sq(mu))
     else:
-      mu = mean(x)
-      var = mean(_abs_sq(x - jnp.expand_dims(mu, axes)))
+      mu = maybe_distributed_mean(x)
+      var = maybe_distributed_mean(_abs_sq(x - jnp.expand_dims(mu, axes)))
   else:
-    var = mean(_abs_sq(x))
+    var = maybe_distributed_mean(_abs_sq(x))
     mu = jnp.zeros_like(var)
   return mu, var
 
@@ -125,7 +135,7 @@ def _normalize(
     bias_init: Callable[[PRNGKey, Shape, Dtype], Array],
     scale_init: Callable[[PRNGKey, Shape, Dtype], Array],
 ):
-  """ "Normalizes the input of a normalization layer and optionally applies a learned scale and bias.
+  """Normalizes the input of a normalization layer and optionally applies a learned scale and bias.
 
   Arguments:
     mdl: Module to apply the normalization in (normalization params will reside
@@ -208,27 +218,27 @@ class BatchNorm(Module):
       y = BN.apply(vars_in, x)
 
   Attributes:
-    use_running_average: if True, the statistics stored in batch_stats
-      will be used instead of computing the batch statistics on the input.
+    use_running_average: if True, the statistics stored in batch_stats will be
+      used instead of computing the batch statistics on the input.
     axis: the feature or non-batch axis of the input.
-    momentum: decay rate for the exponential moving average of
-      the batch statistics.
+    momentum: decay rate for the exponential moving average of the batch
+      statistics.
     epsilon: a small float added to variance to avoid dividing by zero.
     dtype: the dtype of the result (default: infer from input and params).
     param_dtype: the dtype passed to parameter initializers (default: float32).
     use_bias:  if True, bias (beta) is added.
-    use_scale: if True, multiply by scale (gamma).
-      When the next layer is linear (also e.g. nn.relu), this can be disabled
-      since the scaling will be done by the next layer.
+    use_scale: if True, multiply by scale (gamma). When the next layer is linear
+      (also e.g. nn.relu), this can be disabled since the scaling will be done
+      by the next layer.
     bias_init: initializer for bias, by default, zero.
     scale_init: initializer for scale, by default, one.
     axis_name: the axis name used to combine batch statistics from multiple
       devices. See `jax.pmap` for a description of axis names (default: None).
     axis_index_groups: groups of axis indices within that named axis
       representing subsets of devices to reduce over (default: None). For
-      example, `[[0, 1], [2, 3]]` would independently batch-normalize over
-      the examples on the first two and last two devices. See `jax.lax.psum`
-      for more details.
+      example, `[[0, 1], [2, 3]]` would independently batch-normalize over the
+      examples on the first two and last two devices. See `jax.lax.psum` for
+      more details.
     use_fast_variance: If true, use a faster, but less numerically stable,
       calculation for the variance.
   """
@@ -260,8 +270,8 @@ def __call__(self, x, use_running_average: Optional[bool] = None):
 
     Args:
       x: the input to be normalized.
-      use_running_average: if true, the statistics stored in batch_stats
-        will be used instead of computing the batch statistics on the input.
+      use_running_average: if true, the statistics stored in batch_stats will be
+        used instead of computing the batch statistics on the input.
 
     Returns:
       Normalized inputs (the same shape as inputs).
@@ -436,9 +446,9 @@ class RMSNorm(Module):
       array being normalized is sharded across devices within a pmap.
     axis_index_groups: groups of axis indices within that named axis
       representing subsets of devices to reduce over (default: None). For
-      example, `[[0, 1], [2, 3]]` would independently batch-normalize over
-      the examples on the first two and last two devices. See `jax.lax.psum`
-      for more details.
+      example, `[[0, 1], [2, 3]]` would independently batch-normalize over the
+      examples on the first two and last two devices. See `jax.lax.psum` for
+      more details.
   """
 
   epsilon: float = 1e-6
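Note: the snippet below is a minimal, standalone sketch of the idea behind this change, not the Flax source itself. The helper name mirrors the patch, but the axes/axis_name/axis_index_groups keywords are folded into the signature, and jnp.square stands in for Flax's _abs_sq, purely so the example is self-contained and runnable. It illustrates the behavior the commit message describes: in the local case (axis_name=None) each statistic is reduced on its own, so no stack/concat of the full activations is traced into the compiled graph; only in the distributed case are the already-reduced (feature-sized) means stacked so that a single lax.pmean covers all of them.

import jax.numpy as jnp
from jax import lax


def maybe_distributed_mean(*xs, axes=(0,), axis_name=None,
                           axis_index_groups=None):
  # Reduce each array locally first; the per-array means are small
  # (feature-shaped), unlike the stacked full activations the old path built.
  mus = tuple(x.mean(axes) for x in xs)
  if axis_name is None:
    # Local case: return the means directly -- no stack/split is emitted.
    return mus if len(xs) > 1 else mus[0]
  if len(xs) > 1:
    # Distributed case: stack the small reduced means so one pmean call
    # reduces all of them, keeping cross-device communication cheap.
    reduced = lax.pmean(
        jnp.stack(mus, axis=0), axis_name, axis_index_groups=axis_index_groups
    )
    return tuple(reduced[i] for i in range(len(xs)))
  return lax.pmean(mus[0], axis_name, axis_index_groups=axis_index_groups)


# Local fast-variance path, mirroring the patched _compute_stats: mean and
# mean-of-squares in one call, with the variance clamped at zero to absorb
# floating-point round-off.
x = jnp.arange(12.0).reshape(3, 4)
mu, mu2 = maybe_distributed_mean(x, jnp.square(x), axes=(0,))
var = jnp.maximum(0.0, mu2 - jnp.square(mu))

The patched _compute_stats follows the same control flow but uses _abs_sq rather than a plain square, which also covers complex inputs.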