Skip to content

Commit

Permalink
[linen] Use stack instead of concatenate in compute_stats, to h…
Browse files Browse the repository at this point in the history
…andle scalar stats case.

Added test for scalar stats, which were broken by previous change.

PiperOrigin-RevId: 549263368
  • Loading branch information
chr1sj0nes authored and Flax Authors committed Jul 19, 2023
1 parent 881c449 commit 15d6857
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 6 deletions.
7 changes: 1 addition & 6 deletions flax/linen/normalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,12 +95,7 @@ def pmean(x):

if use_mean:
if use_fast_variance:
mean = x.mean(axes)
mean2 = _abs_sq(x).mean(axes)
if mean.ndim > 0:
mean, mean2 = jnp.split(pmean(jnp.concatenate([mean, mean2])), 2)
else:
mean, mean2 = pmean(mean), pmean(mean2)
mean, mean2 = pmean(jnp.stack([x.mean(axes), _abs_sq(x).mean(axes)]))
# mean2 - _abs_sq(mean) is not guaranteed to be non-negative due
# to floating point round-off errors.
var = jnp.maximum(0.0, mean2 - _abs_sq(mean))
Expand Down
1 change: 1 addition & 0 deletions tests/linen/linen_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,7 @@ def test_batch_norm_complex(self):
{'reduction_axes': -1},
{'reduction_axes': 1},
{'reduction_axes': (1, 2)},
{'reduction_axes': (0, 1, 2)},
{'reduction_axes': -1, 'use_fast_variance': False},
)
def test_layer_norm(self, reduction_axes, use_fast_variance=True):
Expand Down

0 comments on commit 15d6857

Please sign in to comment.