Skip to content

Commit

Permalink
Add use_fast_variance option to GroupNorm and BatchNorm to allow disa…
Browse files Browse the repository at this point in the history
…bling it.

PiperOrigin-RevId: 553390785
  • Loading branch information
Flax Team committed Aug 3, 2023
1 parent 8da8c46 commit bc6a6c1
Showing 1 changed file with 8 additions and 0 deletions.
8 changes: 8 additions & 0 deletions flax/linen/normalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,8 @@ class BatchNorm(Module):
example, `[[0, 1], [2, 3]]` would independently batch-normalize over
the examples on the first two and last two devices. See `jax.lax.psum`
for more details.
use_fast_variance: If true, use a faster, but less numerically stable,
calculation for the variance.
"""

use_running_average: Optional[bool] = None
Expand All @@ -243,6 +245,7 @@ class BatchNorm(Module):
scale_init: Callable[[PRNGKey, Shape, Dtype], Array] = initializers.ones
axis_name: Optional[str] = None
axis_index_groups: Any = None
use_fast_variance: bool = True

@compact
def __call__(self, x, use_running_average: Optional[bool] = None):
Expand Down Expand Up @@ -290,6 +293,7 @@ def __call__(self, x, use_running_average: Optional[bool] = None):
dtype=self.dtype,
axis_name=self.axis_name if not self.is_initializing() else None,
axis_index_groups=self.axis_index_groups,
use_fast_variance=self.use_fast_variance,
)

if not self.is_initializing():
Expand Down Expand Up @@ -515,6 +519,8 @@ class GroupNorm(Module):
example, `[[0, 1], [2, 3]]` would independently batch-normalize over the
examples on the first two and last two devices. See `jax.lax.psum` for
more details.
use_fast_variance: If true, use a faster, but less numerically stable,
calculation for the variance.
"""

num_groups: Optional[int] = 32
Expand All @@ -528,6 +534,7 @@ class GroupNorm(Module):
scale_init: Callable[[PRNGKey, Shape, Dtype], Array] = initializers.ones
axis_name: Optional[str] = None
axis_index_groups: Any = None
use_fast_variance: bool = True

@compact
def __call__(self, x):
Expand Down Expand Up @@ -581,6 +588,7 @@ def __call__(self, x):
self.dtype,
self.axis_name,
self.axis_index_groups,
use_fast_variance=self.use_fast_variance,
)
mean = jnp.repeat(mean, group_size, axis=-1)
var = jnp.repeat(var, group_size, axis=-1)
Expand Down

0 comments on commit bc6a6c1

Please sign in to comment.