Commit
Modify clipping implementation to avoid jnp.moveaxis, which causes undesirable all-to-all's in distributed environments.

Also improve documentation on per_example_global_norm_clip and per_example_layer_norm_clip, and refactor them to consume PyTrees instead of lists.

PiperOrigin-RevId: 676485567
OptaxDev committed Sep 19, 2024
1 parent ee63e45 commit 2a88336
Showing 2 changed files with 101 additions and 55 deletions.
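
For context on the change: the previous implementation summed clipped per-example gradients by first permuting the batch axis to the back with jnp.moveaxis, which can trigger all-to-all communication when the batch axis is sharded across devices, while the new code contracts the batch axis in place with jnp.tensordot. The sketch below is illustrative only (the shapes and weights are hypothetical, not taken from the commit); the weights play the role of the new per-example multipliers, i.e. the reciprocal of the old divisors.

>>> import jax.numpy as jnp
>>> g = jnp.arange(24.0).reshape(4, 2, 3)       # one gradient leaf, batch size 4
>>> w = jnp.array([1.0, 0.5, 0.25, 0.0])        # hypothetical per-example multipliers
>>> old_style = (jnp.moveaxis(g, 0, -1) * w).sum(-1)  # permutes the batch axis
>>> new_style = jnp.tensordot(w, g, axes=1)           # contracts it where it is
>>> bool(jnp.allclose(old_style, new_style))
True
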
129 changes: 74 additions & 55 deletions optax/transforms/_clipping.py
@@ -78,7 +78,8 @@ def clip_by_global_norm(max_norm: float) -> base.GradientTransformation:
"""Clips updates using their global norm.
References:
[Pascanu et al, 2012](https://arxiv.org/abs/1211.5063)
Pascanu et al., `On the difficulty of training Recurrent Neural Networks
<https://arxiv.org/abs/1211.5063>`_, 2012
Args:
max_norm: The maximum global norm for an update.
@@ -106,47 +107,87 @@ def clip_fn(t):
return base.GradientTransformation(base.init_empty_state, update_fn)


def _check_arrays_have_batch_dim(grads: chex.ArrayTree) -> bool:
"""Checks that each array in grads has a batch dimension in the 0th axis."""
grads = jax.tree.flatten(grads)[0]
batch_size = grads[0].shape[0]
return all(g.ndim >= 1 and batch_size == g.shape[0] for g in grads)


def per_example_global_norm_clip(
grads: list[chex.Array], l2_norm_clip: float
) -> tuple[list[chex.Array], jax.Array]:
grads: chex.ArrayTree, l2_norm_clip: float
) -> tuple[chex.ArrayTree, jax.Array]:
"""Applies gradient clipping per-example using their global norm.
Example:
>>> import optax
>>> import jax.numpy as jnp
>>> grads = [jnp.array([[0, 0, 0], [0, 3, 4], [4, 0, 3], [3, 4, 0]])]
>>> optax.per_example_global_norm_clip(grads, jnp.inf)
([Array([7., 7., 7.], dtype=float32)], Array(0, dtype=int32))
>>> optax.per_example_global_norm_clip(grads, 0.0)
([Array([0., 0., 0.], dtype=float32)], Array(3, dtype=int32))
>>> optax.per_example_global_norm_clip(grads, 1.25)
([Array([1.75, 1.75, 1.75], dtype=float32)], Array(3, dtype=int32))
See optax.contrib.differentially_private_aggregate for more realistic usage
examples.
References:
[Abadi et al, 2016](https://arxiv.org/abs/1607.00133)
Abadi et al., `Deep Learning with Differential Privacy
<https://arxiv.org/abs/1607.00133>`_, 2016
Args:
grads: flattened update; the function expects these to have a batch
dimension on the 0th axis.
grads: flattened update; the function expects each array in this PyTree to
have a batch dimension on the 0th axis.
l2_norm_clip: maximum L2 norm of the per-example gradients.
Returns:
A tuple containing sum of the clipped per-example grads, and the number of
per-example grads that were clipped.
"""
bsize = grads[0].shape[0]

if any(g.ndim == 0 or bsize != g.shape[0] for g in grads):
if not _check_arrays_have_batch_dim(grads):
raise ValueError(
'Unlike other transforms, `per_example_global_norm_clip` expects'
' `grads` to have a batch dimension in the 0th axis.')

global_grad_norms = jax.vmap(linear_algebra.global_norm)(grads)
divisors = jnp.maximum(global_grad_norms / l2_norm_clip, 1.0)
num_clipped = jnp.greater(divisors, 1.0).sum()
clipped_sum = [(jnp.moveaxis(g, 0, -1) / divisors).sum(-1) for g in grads]
multipliers = jnp.nan_to_num(
jnp.minimum(l2_norm_clip / global_grad_norms, 1.0), nan=1.0
)
num_clipped = jnp.sum(multipliers < 1.0)
clipped_sum = jax.tree.map(
lambda g: jnp.tensordot(multipliers, g, axes=1), grads
)
return clipped_sum, num_clipped
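
A minimal usage sketch of the refactored PyTree interface (the parameter names, shapes, and values below are hypothetical and chosen only for illustration). Every leaf carries the batch dimension of size 4 on its 0th axis; the third example has global norm 10, so it is scaled by 0.5, while the others pass through unchanged.

>>> import optax
>>> import jax.numpy as jnp
>>> grads = {
...     'w': jnp.array([[0., 0.], [3., 4.], [6., 8.], [0., 0.]]),
...     'b': jnp.array([[0.], [0.], [0.], [0.]]),
... }
>>> clipped_sum, num_clipped = optax.per_example_global_norm_clip(
...     grads, l2_norm_clip=5.0)
>>> clipped_sum['w']   # [3, 4] + 0.5 * [6, 8]
Array([6., 8.], dtype=float32)
>>> num_clipped
Array(1, dtype=int32)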


def per_example_layer_norm_clip(
grads: list[chex.Array],
grads: chex.ArrayTree,
global_l2_norm_clip: float,
uniform: bool = True,
eps: float = 1e-8,
) -> tuple[list[chex.Array], list[chex.Array]]:
uniform: bool = True
) -> tuple[chex.ArrayTree, chex.ArrayTree]:
"""Applies gradient clipping per-example using per-layer norms.
If len(grads) == 1, this function is equivalent to
optax.per_example_global_norm_clip. If len(grads) > 1, each array in grads
will be independently clipped to a value ``C_i`` documented below.
Example:
>>> import optax
>>> import jax.numpy as jnp
>>> grads = [jnp.array([[0, 0, 0], [0, 3, 4], [4, 0, 3], [3, 4, 0]])]
>>> optax.per_example_layer_norm_clip(grads, jnp.inf)
([Array([7., 7., 7.], dtype=float32)], [Array(0, dtype=int32)])
>>> optax.per_example_layer_norm_clip(grads, 0.0)
([Array([0., 0., 0.], dtype=float32)], [Array(3, dtype=int32)])
>>> optax.per_example_layer_norm_clip(grads, 1.25)
([Array([1.75, 1.75, 1.75], dtype=float32)], [Array(3, dtype=int32)])
References:
[McMahan et al, 2012](https://arxiv.org/abs/1710.06963)]
McMahan et al., `Learning Differentially Private Recurrent Language Models
<https://arxiv.org/abs/1710.06963>`_, 2017
Args:
grads: flattened update; i.e. a list of gradients in which each item is
@@ -157,8 +198,6 @@ def per_example_layer_norm_clip(
where ``L`` is the number of layers. Otherwise, per-layer clip norm is
``global_l2_norm_clip * sqrt(f)``, where ``f`` is the fraction of total
model parameters that are in this layer.
eps: Small positive value to add to norms to avoid possible division by
zero.
Let ``C`` be the ``global_l2_norm_clip`` value. Then per-layer clipping is done
as follows:
@@ -174,54 +213,34 @@ def per_example_layer_norm_clip(
A tuple containing sum of the clipped per-example grads and the number of
per-example grads that were clipped for each layer.
"""
bsize = grads[0].shape[0]

if any(g.ndim == 0 or bsize != g.shape[0] for g in grads):
if not _check_arrays_have_batch_dim(grads):
raise ValueError(
'Unlike other transforms, `per_example_layer_norm_clip` expects'
' `grads` to have a batch dimension in the 0th axis; got shapes:'
f' {(g.shape for g in grads)}.'
f' {jax.tree.map(jnp.shape, grads)}.'
)

num_layers = len(grads)

# Compute per-layer clip norms, based on whether we are using uniform
# variant or not.
if uniform:
# Create list of length `num_layers` of per-layer clip norm.
layer_clip_norms = (
global_l2_norm_clip * (1.0 / num_layers) ** 0.5,
) * num_layers
num_layers = len(jax.tree.leaves(grads))
layer_clip_norms = jax.tree.map(
lambda _: global_l2_norm_clip * (1.0 / num_layers) ** 0.5,
grads
)
else:
total_params = sum(g[0].size for g in grads)
layer_clip_norms = tuple(
global_l2_norm_clip * (g[0].size / float(total_params)) ** 0.5
for g in grads
total_params = jax.tree.reduce(lambda x, g: x + g[0].size, grads, 0)
layer_clip_norms = jax.tree.map(
lambda g: global_l2_norm_clip * (g[0].size / total_params) ** 0.5,
grads
)

# Compute per-layer grad norms.
def map_layer_norm(grads_list):
return [jnp.linalg.norm(g, ord=None, axis=None) for g in grads_list]

layer_grad_norms_per_example = jax.vmap(map_layer_norm)(grads)

# Perform clipping.
divisors = (
tuple(
jnp.maximum(
layer_grad_norm / (layer_clip_norm + eps), 1.0
)
for layer_grad_norm, layer_clip_norm in zip(
layer_grad_norms_per_example, layer_clip_norms
)
)
)
num_clipped = [jnp.greater(divisor, 1.0).sum() for divisor in divisors]
clipped_sum = [
(g / jnp.expand_dims(d, axis=[i for i in range(1, g.ndim)])).sum(0)
for g, d in zip(grads, divisors)
]
return clipped_sum, num_clipped
result = jax.tree.map(per_example_global_norm_clip, grads, layer_clip_norms)
return jax.tree.transpose(outer_treedef=jax.tree.structure(grads),
inner_treedef=jax.tree.structure((0, 0)),
pytree_to_transpose=result)
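
To unpack the return step: jax.tree.map over per_example_global_norm_clip yields a tree with the same structure as grads whose leaves are (clipped_sum, num_clipped) pairs, and jax.tree.transpose flips that tree-of-pairs into the documented pair-of-trees. A small illustrative sketch with hypothetical scalar leaves (not taken from the commit):

>>> import jax
>>> per_leaf = {'w': (1.0, 0), 'b': (2.0, 3)}   # leaf -> (clipped sum, num clipped)
>>> sums, counts = jax.tree.transpose(
...     outer_treedef=jax.tree.structure({'w': 0, 'b': 0}),
...     inner_treedef=jax.tree.structure((0, 0)),
...     pytree_to_transpose=per_leaf)
>>> sums == {'w': 1.0, 'b': 2.0} and counts == {'w': 0, 'b': 3}
True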


def unitwise_norm(x: chex.Array) -> chex.Array:
@@ -259,8 +278,8 @@ def adaptive_grad_clip(clipping: float,
"""Clips updates to be at most ``clipping * parameter_norm``, unit-wise.
References:
[Brock, Smith, De, Simonyan 2021] High-Performance Large-Scale Image
Recognition Without Normalization. (https://arxiv.org/abs/2102.06171)
Brock et al., `High-Performance Large-Scale Image Recognition Without
Normalization <https://arxiv.org/abs/2102.06171>`_, 2021
Args:
clipping: The maximum allowed ratio of update norm to parameter norm.
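For completeness, a minimal usage sketch of adaptive_grad_clip (the parameter shapes, clipping value, and learning rate are hypothetical); it is typically composed with an optimizer via optax.chain, and, unlike most transforms, it needs params to be passed to update:

>>> import optax
>>> import jax.numpy as jnp
>>> params = {'w': jnp.ones((3, 4))}
>>> grads = {'w': 100.0 * jnp.ones((3, 4))}
>>> tx = optax.chain(optax.adaptive_grad_clip(clipping=0.01), optax.sgd(1e-3))
>>> opt_state = tx.init(params)
>>> updates, opt_state = tx.update(grads, opt_state, params)
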
27 changes: 27 additions & 0 deletions optax/transforms/_clipping_test.py
@@ -18,6 +18,7 @@
import chex
import jax
import jax.numpy as jnp
import numpy as np

from optax._src import linear_algebra
from optax.transforms import _clipping
@@ -91,6 +92,32 @@ def test_adaptive_grad_clip(self):
updates_step, _ = clipper.update(self.per_step_updates, None, params)
chex.assert_trees_all_close(updates, updates_step)

def test_per_example_global_norm_clip(self):
grads = [ # 4 users, 2 components
jnp.array([
[0, -0.5], # norm = sqrt(0^2 + 0.5^2 + 0^2)
[3, 4], # norm = sqrt(3^2 + 4^2 + 5^2)
[5, 6], # norm = sqrt(5^2 + 6^2 + 3^2)
[0, 0], # norm = 0
]),
jnp.array([[0], [5], [-3], [0]]),
]
answer = [
jnp.array([0, -0.5])
+ jnp.array([3, 4]) / jnp.sqrt(50)
+ jnp.array([5, 6]) / jnp.sqrt(70),
jnp.array([0])
+ jnp.array([5]) / jnp.sqrt(50)
+ jnp.array([-3]) / jnp.sqrt(70),
]
sum_clipped_grads, num_clipped = _clipping.per_example_global_norm_clip(
grads, l2_norm_clip=1.0
)

for actual, expected in zip(sum_clipped_grads, answer):
np.testing.assert_allclose(actual, expected, atol=1e-6)
self.assertEqual(num_clipped, 2)

def test_per_example_layer_norm_clip(self):
# Test data for a model with two layers and a batch size of 4. The
# 0th layer has one parameter (shape (1)), and the 1st layer has shape
