Merge pull request #1066 from carlosgmartin:jax_tree_util_legacy
PiperOrigin-RevId: 677810392
OptaxDev committed Sep 23, 2024
2 parents 84bd835 + e95a076 commit b06f6c5
Showing 61 changed files with 331 additions and 342 deletions.
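The diff below is a mechanical migration: calls to the legacy `jax.tree_util` helpers (`tree_map`, `tree_leaves`) are replaced by the newer `jax.tree` aliases. A minimal sketch of the equivalence, assuming a JAX release recent enough to ship the `jax.tree` module:

```python
import jax
import jax.numpy as jnp

params = {"w": jnp.ones((2, 3)), "b": jnp.zeros(3)}

# Legacy spelling, used throughout optax before this PR.
doubled_legacy = jax.tree_util.tree_map(lambda x: 2 * x, params)

# New spelling adopted by this PR; operates on pytrees the same way.
doubled = jax.tree.map(lambda x: 2 * x, params)

# Likewise, jax.tree.leaves replaces jax.tree_util.tree_leaves.
leaves = jax.tree.leaves(params)
```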
4 changes: 2 additions & 2 deletions docs/getting_started.ipynb
@@ -717,9 +717,9 @@
"source": [
"### Applying updates ([update.py](https://github.com/google-deepmind/optax/blob/main/optax/_src/update.py))\n",
"\n",
"After transforming an update using a {py:class}`GradientTransformation <optax.GradientTransformation>` or any custom manipulation of the update, you will typically apply the update to a set of parameters. This can be done trivially using `tree_map`.\n",
"After transforming an update using a {py:class}`GradientTransformation <optax.GradientTransformation>` or any custom manipulation of the update, you will typically apply the update to a set of parameters. This can be done trivially using `jax.tree.map`.\n",
"\n",
"For convenience, we expose an {py:class}`apply_updates <optax.apply_updates>` function to apply updates to parameters. The function just adds the updates and the parameters together, i.e. `tree_map(lambda p, u: p + u, params, updates)`."
"For convenience, we expose an {py:class}`apply_updates <optax.apply_updates>` function to apply updates to parameters. The function just adds the updates and the parameters together, i.e. `jax.tree.map(lambda p, u: p + u, params, updates)`."
]
},
{
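To make the notebook text above concrete, here is a minimal sketch of applying updates with `optax.apply_updates`; the parameter and update values are invented for illustration, and the manual `jax.tree.map` line spells out the equivalence the docs mention:

```python
import jax
import jax.numpy as jnp
import optax

params = {"w": jnp.array([1.0, 2.0]), "b": jnp.array(0.5)}
updates = {"w": jnp.array([-0.1, 0.1]), "b": jnp.array(-0.05)}

# The convenience helper simply adds updates and parameters leaf-wise...
new_params = optax.apply_updates(params, updates)
# ...which is the same as mapping addition over both pytrees.
new_params_manual = jax.tree.map(lambda p, u: p + u, params, updates)
```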
2 changes: 1 addition & 1 deletion examples/cifar10_resnet.ipynb
@@ -573,7 +573,7 @@
" logits=logits, labels=labels\n",
" ).mean()\n",
" accuracy = jnp.mean(jnp.argmax(logits, axis=-1) == labels)\n",
" l2_params = jax.tree_util.tree_leaves(params)\n",
" l2_params = jax.tree.leaves(params)\n",
" # Computes regularization on all except batchnorm parameters.\n",
" weight_l2 = sum(jnp.sum(x**2) for x in l2_params if x.ndim \u003e 1)\n",
" loss = mean_loss + 0.5 * L2_REG * weight_l2\n",
2 changes: 1 addition & 1 deletion examples/contrib/differentially_private_sgd.ipynb
@@ -225,7 +225,7 @@
" grad_fn = jax.grad(loss_fn, has_aux=True)\n",
" if DPSGD:\n",
" # Inserts a dimension in axis 1 to use jax.vmap over the batch.\n",
" batch = jax.tree_util.tree_map(lambda x: x[:, None], batch)\n",
" batch = jax.tree.map(lambda x: x[:, None], batch)\n",
" # Uses jax.vmap across the batch to extract per-example gradients.\n",
" grad_fn = jax.vmap(grad_fn, in_axes=(None, 0))\n",
"\n",
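The comments in the hunk above describe the per-example gradient trick used for DP-SGD: insert a singleton batch axis, then `jax.vmap` the gradient function over the data only. A self-contained sketch of that pattern, with a toy loss and data invented for illustration:

```python
import jax
import jax.numpy as jnp

def loss_fn(params, batch):
  x, y = batch  # expects arrays with a leading batch axis
  return jnp.mean((params["w"] * x - y) ** 2)

params = {"w": jnp.array(1.5)}
data = (jnp.arange(4.0), 2.0 * jnp.arange(4.0))

# Insert a singleton axis so each per-example call still sees a batch axis.
data = jax.tree.map(lambda x: x[:, None], data)

grad_fn = jax.grad(loss_fn)
# vmap over the example axis of the data; params are shared (in_axes=None).
per_example_grads = jax.vmap(grad_fn, in_axes=(None, 0))(params, data)
# per_example_grads["w"] has shape (4,): one gradient per example.
```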
2 changes: 1 addition & 1 deletion examples/nanolm.ipynb
@@ -496,7 +496,7 @@
}
],
"source": [
"n_params = sum(p.size for p in jax.tree_util.tree_leaves(var_params))\n",
"n_params = sum(p.size for p in jax.tree.leaves(var_params))\n",
"\n",
"print(f\"Total number of parameters: {n_params:_}\")"
]
11 changes: 5 additions & 6 deletions optax/_src/alias_test.py
@@ -24,7 +24,6 @@
from jax import flatten_util
import jax.numpy as jnp
import jax.random as jrd
-import jax.tree_util as jtu
import numpy as np

from optax._src import alias
@@ -153,7 +152,7 @@ def step(params, state):
value, updates = jax.value_and_grad(objective)(params)
# Complex gradients need to be conjugated before being added to parameters
# https://gist.github.com/wdphy16/118aef6fb5f82c49790d7678cf87da29
-updates = jax.tree_util.tree_map(lambda x: x.conj(), updates)
+updates = jax.tree.map(lambda x: x.conj(), updates)
if opt_name == 'polyak_sgd':
update_kwargs = {'value': value}
else:
@@ -548,11 +547,11 @@ def test_preconditioning_by_lbfgs_on_trees(self, idx: int):
)

flat_dws = [
-flatten_util.ravel_pytree(jtu.tree_map(lambda dw: dw[i], dws))[0] # pylint: disable=cell-var-from-loop
+flatten_util.ravel_pytree(jax.tree.map(lambda dw: dw[i], dws))[0] # pylint: disable=cell-var-from-loop
for i in range(m)
]
flat_dus = [
-flatten_util.ravel_pytree(jtu.tree_map(lambda du: du[i], dus))[0] # pylint: disable=cell-var-from-loop
+flatten_util.ravel_pytree(jax.tree.map(lambda du: du[i], dus))[0] # pylint: disable=cell-var-from-loop
for i in range(m)
]
flat_dws, flat_dus = jnp.stack(flat_dws), jnp.stack(flat_dus)
@@ -631,7 +630,7 @@ def fun_(x):
)

def fun(x):
-return otu.tree_sum(jtu.tree_map(fun_, x))
+return otu.tree_sum(jax.tree.map(fun_, x))

key = jrd.PRNGKey(0)
init_array = jrd.normal(key, (2, 4))
@@ -677,7 +676,7 @@ def test_binary_logreg(self, scale_init_precond):
def fun(weights):
inputs, labels = data
logits = jnp.dot(inputs, weights)
-losses = jtu.tree_map(
+losses = jax.tree.map(
lambda z, y: jax.nn.softplus(jnp.where(y, -z, z)), logits, labels
)
return jnp.mean(losses)
8 changes: 4 additions & 4 deletions optax/_src/base.py
@@ -262,7 +262,7 @@ def set_to_zero() -> GradientTransformation:

def update_fn(updates, state, params=None):
del params # Unused by the zero transform.
-return jax.tree_util.tree_map(jnp.zeros_like, updates), state
+return jax.tree.map(jnp.zeros_like, updates), state

return GradientTransformation(init_empty_state, update_fn)

@@ -297,7 +297,7 @@ def stateless_with_tree_map(
This wrapper eliminates the boilerplate needed to create a transformation that
does not require saved state between iterations, just like optax.stateless.
-In addition, this function will apply the tree_map over update/params for you.
+In addition, this function will apply the tree map over update/params for you.
Args:
f: Update function that takes in an update array (e.g. gradients) and
@@ -311,10 +311,10 @@
def update_fn(updates, state, params=None):
del state
if params is not None:
-return jax.tree_util.tree_map(f, updates, params), EmptyState()
+return jax.tree.map(f, updates, params), EmptyState()
else:
f_ = lambda u: f(u, None)
-return jax.tree_util.tree_map(f_, updates), EmptyState()
+return jax.tree.map(f_, updates), EmptyState()

return GradientTransformation(init_empty_state, update_fn)

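For context on the wrapper edited above, here is a rough usage sketch, assuming it is exported as `optax.stateless_with_tree_map`; the 0.1 decay constant is arbitrary. The wrapper builds a stateless `GradientTransformation` and applies the per-leaf function over the update/param pytrees via `jax.tree.map`, so callers never handle trees directly:

```python
import jax.numpy as jnp
import optax

# Leaf-wise additive weight decay; the tree mapping is handled by the wrapper.
weight_decay = optax.stateless_with_tree_map(lambda g, p: g + 0.1 * p)

params = {"w": jnp.ones(3), "b": jnp.zeros(2)}
grads = {"w": jnp.full(3, 0.5), "b": jnp.ones(2)}

state = weight_decay.init(params)  # empty state, nothing carried between steps
updates, state = weight_decay.update(grads, state, params)
```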
6 changes: 3 additions & 3 deletions optax/_src/base_test.py
@@ -132,7 +132,7 @@ def test_stateless(self):

@base.stateless
def opt(g, p):
-return jax.tree_util.tree_map(lambda g_, p_: g_ + 0.1 * p_, g, p)
+return jax.tree.map(lambda g_, p_: g_ + 0.1 * p_, g, p)

state = opt.init(params)
update_fn = self.variant(opt.update)
@@ -146,7 +146,7 @@ def test_stateless_no_params(self):

@base.stateless
def opt(g, _):
-return jax.tree_util.tree_map(lambda g_: g_ * 2, g)
+return jax.tree.map(lambda g_: g_ * 2, g)

state = opt.init(None) # pytype: disable=wrong-arg-types # numpy-scalars
update_fn = self.variant(opt.update)
@@ -156,7 +156,7 @@ def opt(g, _):

def test_init_returns_emptystate(self):
def weight_decay(g, p):
-return jax.tree_util.tree_map(lambda g_, p_: g_ + 0.1 * p_, g, p)
+return jax.tree.map(lambda g_, p_: g_ + 0.1 * p_, g, p)

opt = base.stateless(weight_decay)
state = opt.init(None) # pytype: disable=wrong-arg-types # numpy-scalars
10 changes: 5 additions & 5 deletions optax/_src/combine_test.py
@@ -155,8 +155,8 @@ class MultiTransformTest(chex.TestCase):
@parameterized.parameters(True, False)
def test_multi_transform(self, use_fn):
params = {'a1': 1., 'b1': 2., 'z1': {'a2': 3., 'z2': {'c1': 4.}}}
-params = jax.tree_util.tree_map(jnp.asarray, params)
-input_updates = jax.tree_util.tree_map(lambda x: x / 10.0, params)
+params = jax.tree.map(jnp.asarray, params)
+input_updates = jax.tree.map(lambda x: x / 10.0, params)
tx_dict = {'a': transform.scale(-1.0),
'b': transform.ema(0.0), # stateful
'c': transform.scale(2.0)}
@@ -230,7 +230,7 @@ def test_empty(self, container):
def test_labels_mismatch(self, use_extra_label, use_fn):
# The labels from label_fn must be a subet of the keys for the tx.
params = {'a': 1., 'b': [2., 3.], 'c': {'d': 4., 'e': (5., 6.)}}
-params = jax.tree_util.tree_map(jnp.asarray, params)
+params = jax.tree.map(jnp.asarray, params)
label_tree = {'a': 0, 'b': [1, 0], 'c': 1} # prefix of params

if use_extra_label:
@@ -247,7 +247,7 @@ def test_labels_mismatch(self, use_extra_label, use_fn):
self.variant(init_fn)(params)
else:
state = self.variant(init_fn)(params)
-updates = jax.tree_util.tree_map(lambda x: x / 10.0, params)
+updates = jax.tree.map(lambda x: x / 10.0, params)
self.variant(update_fn)(updates, state)


@@ -260,7 +260,7 @@ def init_fn(params):

def update_fn(updates, state, params, *, loss, **extra_args):
del params, extra_args
-updates = jax.tree_util.tree_map(
+updates = jax.tree.map(
lambda u: u / loss, updates)
return updates, state

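The tests above exercise optax's partitioned optimization. As a rough usage sketch (labels, transforms, and parameter names invented for illustration, assuming the combinator is exposed as `optax.multi_transform`), different subtrees of the parameters can be routed to different gradient transformations via a label tree that is a prefix of the params:

```python
import jax.numpy as jnp
import optax

params = {"encoder": {"w": jnp.ones(3)}, "head": {"w": jnp.ones(2)}}
grads = {"encoder": {"w": jnp.full(3, 0.1)}, "head": {"w": jnp.full(2, 0.1)}}

# Label each top-level group; the label tree may be a prefix of params.
labels = {"encoder": "frozen", "head": "trainable"}
tx = optax.multi_transform(
    {"frozen": optax.set_to_zero(), "trainable": optax.sgd(1e-2)},
    labels,
)

state = tx.init(params)
updates, state = tx.update(grads, state, params)
```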
14 changes: 7 additions & 7 deletions optax/_src/factorized.py
@@ -63,7 +63,7 @@ def _factored_dims(

@dataclasses.dataclass
class _UpdateResult:
"""Opaque containter that is not traversed by jax.tree_util.tree_map."""
"""Opaque containter that is not traversed by jax.tree.map."""
update: chex.Array # the update to apply to params
v_row: chex.Array # used for factored params.
v_col: chex.Array # used for factored params.
@@ -118,9 +118,9 @@ def _to_state(count: chex.Array, result_tree):
"""Maps from a tree of (factored) values to separate trees of values."""
return FactoredState(
count=count,
-v_row=jax.tree_util.tree_map(lambda o: o.v_row, result_tree),
-v_col=jax.tree_util.tree_map(lambda o: o.v_col, result_tree),
-v=jax.tree_util.tree_map(lambda o: o.v, result_tree))
+v_row=jax.tree.map(lambda o: o.v_row, result_tree),
+v_col=jax.tree.map(lambda o: o.v_col, result_tree),
+v=jax.tree.map(lambda o: o.v, result_tree))

def init_fn(params):
"""Initialise the optimiser's state."""
@@ -145,7 +145,7 @@ def _init(param):
v=jnp.zeros(param.shape))

return _to_state(
-jnp.zeros([], jnp.int32), jax.tree_util.tree_map(_init, params))
+jnp.zeros([], jnp.int32), jax.tree.map(_init, params))

def update_fn(grads, state, params):
"""Apply gradient transformation."""
@@ -187,12 +187,12 @@ def _update(grad, v_row, v_col, v, param, step):
return _UpdateResult(update, new_v_row, new_v_col, new_v)

# Transform grad and compute new per-parameter stats.
-output = jax.tree_util.tree_map(
+output = jax.tree.map(
lambda *args: _update(*args, state.count),
grads, state.v_row, state.v_col, state.v, params)

# Unpack updates / stats and return.
-updates = jax.tree_util.tree_map(lambda o: o.update, output)
+updates = jax.tree.map(lambda o: o.update, output)
return updates, _to_state(numerics.safe_increment(state.count), output)

return base.GradientTransformation(init_fn, update_fn)
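The factorized optimizer above relies on a `jax.tree.map` feature worth spelling out: when given several pytrees of identical structure, it calls the function with one leaf from each tree at the corresponding position. A small sketch of that multi-tree pattern, with toy trees unrelated to the real `FactoredState`:

```python
import jax
import jax.numpy as jnp

grads = {"w": jnp.array([1.0, 2.0]), "b": jnp.array(3.0)}
second_moments = {"w": jnp.array([4.0, 4.0]), "b": jnp.array(9.0)}

# One leaf from each tree is passed to the lambda, position by position.
scaled = jax.tree.map(
    lambda g, v: g / (jnp.sqrt(v) + 1e-8), grads, second_moments
)
```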
4 changes: 2 additions & 2 deletions optax/_src/float64_test.py
@@ -65,8 +65,8 @@
class Float64Test(parameterized.TestCase):

def _assert_dtype_equals(self, tree1, tree2):
-tree1_types = jax.tree_util.tree_map(lambda t: t.dtype, tree1)
-tree2_types = jax.tree_util.tree_map(lambda t: t.dtype, tree2)
+tree1_types = jax.tree.map(lambda t: t.dtype, tree1)
+tree2_types = jax.tree.map(lambda t: t.dtype, tree2)
self.assertEqual(tree1_types, tree2_types)

@chex.all_variants
2 changes: 1 addition & 1 deletion optax/_src/linear_algebra.py
@@ -34,7 +34,7 @@ def _normalize_tree(x):
def global_norm(updates: base.PyTree) -> chex.Array:
"""Compute the global norm across a nested structure of tensors."""
return jnp.sqrt(sum(
-jnp.sum(numerics.abs_sq(x)) for x in jax.tree_util.tree_leaves(updates)))
+jnp.sum(numerics.abs_sq(x)) for x in jax.tree.leaves(updates)))


def _power_iteration_cond_fun(error_tolerance, num_iters, loop_vars):
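`global_norm` above is the Euclidean norm of the entire update pytree: the square root of the summed squared entries across all leaves. A small worked check with values chosen so the result is exact, assuming the helper is exposed as `optax.global_norm`:

```python
import jax
import jax.numpy as jnp
import optax

updates = {"w": jnp.array([3.0, 0.0]), "b": jnp.array(4.0)}

norm = optax.global_norm(updates)  # sqrt(3**2 + 0**2 + 4**2) = 5.0
manual = jnp.sqrt(sum(jnp.sum(x ** 2) for x in jax.tree.leaves(updates)))
```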
8 changes: 4 additions & 4 deletions optax/_src/lookahead.py
@@ -123,7 +123,7 @@ def update_fn(
# Jittable way of resetting the fast optimizer state if parameters will be
# synchronized after this update step.
initial_state = fast_optimizer.init(params.fast)
-fast_state = jax.tree_util.tree_map(
+fast_state = jax.tree.map(
lambda current, init: (1 - sync_next) * current + sync_next * init,
fast_state,
initial_state,
@@ -186,13 +186,13 @@ def _lookahead_update(
# slow_updates = slow_step_size * sync_next * last_difference
# fast_updates = updates - (
# 1 - slow_step_size) * sync_next * last_difference
-last_difference = jax.tree_util.tree_map(
+last_difference = jax.tree.map(
lambda f, u, s: f + u - s, params.fast, updates, params.slow
)
-slow_updates = jax.tree_util.tree_map(
+slow_updates = jax.tree.map(
lambda diff: slow_step_size * sync_next * diff, last_difference
)
-fast_updates = jax.tree_util.tree_map(
+fast_updates = jax.tree.map(
lambda up, diff: up - sync_next * (1 - slow_step_size) * diff,
updates,
last_difference,
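The expressions in the lookahead hunks implement the synchronization step: on a sync step the slow weights move a fraction `slow_step_size` of the way toward the updated fast weights, and the fast weights are reset to the new slow weights. A scalar sketch of that arithmetic, with made-up numbers and the pytree plumbing stripped out:

```python
slow_step_size = 0.5
sync_next = 1.0  # 1.0 on a synchronization step, 0.0 otherwise
fast, slow, update = 2.0, 1.0, 0.2

# Updated fast weights minus the current slow weights.
last_difference = fast + update - slow

# Slow weights move a fraction of the way toward the new fast weights.
slow_update = slow_step_size * sync_next * last_difference
# Fast weights end up equal to the new slow weights after a sync step.
fast_update = update - sync_next * (1 - slow_step_size) * last_difference

new_slow = slow + slow_update  # 1.6
new_fast = fast + fast_update  # 1.6, matching new_slow when sync_next == 1.0
```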
4 changes: 2 additions & 2 deletions optax/_src/lookahead_test.py
@@ -49,15 +49,15 @@ def _test_optimizer(step_size: float) -> base.GradientTransformation:
# Use SGD for simplicity but add non-trivial optimizer state so that the
# resetting behaviour of lookahead can be tested.
def init_fn(params):
-aggregate_grads = jax.tree_util.tree_map(jnp.zeros_like, params)
+aggregate_grads = jax.tree.map(jnp.zeros_like, params)
return OptimizerTestState(aggregate_grads, is_reset=True)

def update_fn(updates, state, params):
# The test optimizer does not use the parameters, but we check that they
# have been passed correctly.
chex.assert_trees_all_equal_shapes(updates, params)
aggregate_grads = update.apply_updates(state.aggregate_grads, updates)
-updates = jax.tree_util.tree_map(lambda u: step_size * u, updates)
+updates = jax.tree.map(lambda u: step_size * u, updates)
return updates, OptimizerTestState(aggregate_grads, is_reset=False)

return base.GradientTransformation(init_fn, update_fn)