Commit 342adde
Moved everything to kwargs
jlperla committed Oct 31, 2024
1 parent cea33da commit 342adde
Showing 2 changed files with 16 additions and 38 deletions.
41 changes: 6 additions & 35 deletions flax/nnx/training/optimizer.py
@@ -193,7 +193,7 @@ def __init__(
     self.opt_state = _wrap_optimizer_state(tx.init(nnx.state(model, wrt)))
     self.wrt = wrt
 
-  def update(self, grads, value=None, value_fn=None, model_static=None, **kwargs):
+  def update(self, grads, **kwargs):
     """Updates ``step``, ``params``, ``opt_state`` and ``**kwargs`` in return value.
     The ``grads`` must be derived from ``nnx.grad(..., wrt=self.wrt)``, where the
     gradients are with respect to the same :class:`Variable` types as defined in
@@ -245,49 +245,20 @@ def update(self, grads, value=None, value_fn=None, model_static=None, **kwargs):
       ...   state.update(grads=grads)
     Note that internally this function calls ``.tx.update()`` followed by a call
-    to ``optax.apply_updates()`` to update ``params`` and ``opt_state``. For
-    ``optax.GradientTransformationExtraArgs`` such as ``optax.scale_by_zoom_linesearch``,
-    the optional ``value, value_fn`` and ``**kwargs`` are passed to ``.tx.update()``.
-    The ``value_fn`` is assumed to be a univariate function of ``state.model`` if
-    ``model_static`` is not provided. Otherwise, ``model_static`` is assumed to be the result of
-    ``model_static, model_state = nnx.split(state.model, self.wrt)``, and ``value_fn`` is a
-    univariate function of ``model_state``.
+    to ``optax.apply_updates()`` to update ``params`` and ``opt_state``.
     Args:
       grads: the gradients derived from ``nnx.grad``.
-      value_fn (optional): function to evaluate the objective given the model, used by linesearch optimizers.
-      value (optional): value of the objective associated with the current grads update.
-      model_static (optional): graph of static elements from ``nnx.split(state.model, self.wrt)``.
-      **kwargs: additional keyword arguments passed to the tx.update.
+      **kwargs: additional keyword arguments passed to the tx.update, to support
+        ``GradientTransformationExtraArgs``, such as ``optax.scale_by_backtracking_linesearch``.
     """
     params = nnx.state(self.model, self.wrt)
     opt_state = _opt_state_variables_to_state(self.opt_state)
 
-    if value is None or value_fn is None:
-      updates, new_opt_state = self.tx.update(grads, opt_state, params)
-    else:
-      if model_static is None:
-        graphdef, _ = nnx.split(self.model, self.wrt)
-        def value_fn_wrapped(state):
-          model = nnx.merge(graphdef, state)
-          return value_fn(model)
-      else:
-        value_fn_wrapped = value_fn
-
-      updates, new_opt_state = self.tx.update(
-        grads,
-        opt_state,
-        params,
-        grad=grads,
-        value=value,
-        value_fn=value_fn_wrapped,
-        **kwargs,
-      )
-
+    updates, new_opt_state = self.tx.update(grads, opt_state, params, **kwargs)
     new_params = optax.apply_updates(params, updates)
     assert isinstance(new_params, nnx.State)
 
     self.step.value += 1
     nnx.update(self.model, new_params)
-    _update_opt_state(self.opt_state, new_opt_state)
+    _update_opt_state(self.opt_state, new_opt_state)
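
With this change, extra-argument transforms are driven purely through keyword arguments that ``update`` forwards to ``tx.update``. Below is a minimal sketch of the new calling convention, assuming an ``optax.chain`` with ``optax.scale_by_backtracking_linesearch``; the model, data, and loss are illustrative and not part of the commit.

import jax.numpy as jnp
import optax
from flax import nnx

# Illustrative model and data, assumed for this sketch (not from the commit).
model = nnx.Linear(2, 1, rngs=nnx.Rngs(0))
x = jnp.ones((8, 2))
y = jnp.zeros((8, 1))

# A transform chain whose update() consumes extra arguments (grad, value, value_fn).
tx = optax.chain(
  optax.sgd(learning_rate=1.0),
  optax.scale_by_backtracking_linesearch(max_backtracking_steps=15),
)
optimizer = nnx.Optimizer(model, tx)

loss_fn = lambda m: ((m(x) - y) ** 2).mean()
loss, grads = nnx.value_and_grad(loss_fn)(optimizer.model)

# update() now simply forwards these keywords to tx.update(); value_fn must
# accept the split nnx.State, so merge the graphdef back in before evaluating.
graphdef = nnx.graphdef(optimizer.model)
optimizer.update(
  grads,
  grad=grads,
  value=loss,
  value_fn=lambda state: loss_fn(nnx.merge(graphdef, state)),
)

Compared with the previous API, the caller now supplies ``grad`` explicitly and wraps ``value_fn`` over the split state; the ``model_static`` argument is gone.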
13 changes: 10 additions & 3 deletions tests/nnx/optimizer_test.py
@@ -154,7 +154,7 @@ def jax_jit_train_step(graphdef, state, x, y):
         state = nnx.merge(graphdef, state)
         model_static, model_state = nnx.split(state.model)
         grads = jax.grad(loss_fn, argnums=1)(model_static, model_state, x, y)
-        state.update(grads, value = initial_loss, model_static = model_static, value_fn = lambda state: loss_fn(model_static, state, x, y))
+        state.update(grads, grad = grads, value = initial_loss, value_fn = lambda state: loss_fn(model_static, state, x, y))
         return nnx.split(state)
 
       graphdef, state = jit_decorator(jax_jit_train_step)(
@@ -164,12 +164,16 @@ def jax_jit_train_step(graphdef, state, x, y):
       new_loss = loss_fn(*nnx.split(state.model), x, y)
 
     else:
+      graphdef = nnx.graphdef(model)
       loss_fn = lambda model, x, y: ((model(x) - y) ** 2).mean()
+
+      loss_fn_split = lambda state: loss_fn(nnx.merge(graphdef, state), x, y)
+
       initial_loss = loss_fn(state.model, x, y)
 
       def nnx_jit_train_step(optimizer: nnx.Optimizer, x, y):
         grads = nnx.grad(loss_fn)(optimizer.model, x, y)
-        optimizer.update(grads, value = initial_loss,value_fn = lambda model: loss_fn(model, x, y))
+        optimizer.update(grads, grad = grads, value = initial_loss, value_fn = loss_fn_split)
 
      jit_decorator(nnx_jit_train_step)(state, x, y)
      new_loss = loss_fn(state.model, x, y)
@@ -279,7 +283,10 @@ def test_wrt_update_linesearch(self, variable):
      state.model, x, y
    )
    initial_loss = loss_fn(model, x, y)
-    state.update(grads=grads, value_fn = lambda model: loss_fn(model, x, y), value = initial_loss)
+    graphdef = nnx.graphdef(model)
+    loss_fn_split = lambda state: loss_fn(nnx.merge(graphdef, state), x, y)
+
+    state.update(grads, grad=grads, value_fn = loss_fn_split, value = initial_loss)
    self.assertTrue(loss_fn(model, x, y) < initial_loss)
 
    # make sure only the Variable's filtered in `wrt` are changed, and the others are unchanged
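
The updated tests show the wrapper pattern the keyword-only API expects: because ``update`` hands ``nnx.state(model, wrt)`` to ``tx.update``, a ``value_fn`` passed through ``**kwargs`` has to evaluate the loss from that state rather than from the model. A small sketch of that pattern, factored into a helper; the helper name is illustrative, not part of the commit.

from flax import nnx

def make_value_fn(model, loss_fn, *args):
  # Split the graph definition once; the returned closure merges it back with
  # whatever state optax hands to value_fn during the linesearch.
  graphdef = nnx.graphdef(model)
  return lambda state: loss_fn(nnx.merge(graphdef, state), *args)

# e.g. value_fn = make_value_fn(optimizer.model, loss_fn, x, y)
#      optimizer.update(grads, grad=grads, value=initial_loss, value_fn=value_fn)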
