diff --git a/flax/nnx/transforms/autodiff.py b/flax/nnx/transforms/autodiff.py
index 663b9a8ef6..3f43a19de0 100644
--- a/flax/nnx/transforms/autodiff.py
+++ b/flax/nnx/transforms/autodiff.py
@@ -19,6 +19,7 @@
 from flax import struct
+from flax.core.frozen_dict import FrozenDict
 from flax.nnx import (
   extract,
   filterlib,
@@ -362,56 +363,59 @@ def value_and_grad(
     return_value=True,
   )
 
 
+# -----------------------------------------------
+# custom_vjp
+# -----------------------------------------------
+# custom_vjp is one of the most complicated transforms as it requires
+# handling 4 different functions:
+# 1. CustomVjp: the main object that runs the outer logic, converts input graph nodes
+#   to pytrees and output pytrees to graph nodes.
+# 2. CustomVjpFnWrapper: wraps the user's function; it converts its input pytrees
+#   to graph nodes and output graph nodes to pytrees.
+# 3. FwdFn: wraps the user's fwd function; it converts its input pytrees to graph nodes
+#   and output graph nodes to pytrees. Since it might run by itself in a separate context,
+#   it needs to be aware of whether the update_context is active or not in order to update
+#   the outer references.
+# 4. BwdFn: wraps the user's bwd function; it converts its input pytrees to graph nodes
+#   and output graph nodes to pytrees. It doesn't need to be aware of the outer context
+#   since it will never update the outer references as it runs during the backward pass.
 def _custom_vjp_merge_fn(
-  ctx: graph.MergeContext,
-  path,
-  prefix: bool | DiffState,
-  value: extract.NodeStates,
-  *,
-  nondiff_states: deque[extract.GraphDefState],
+  ctx: graph.MergeContext, path, prefix: bool, value: extract.NodeStates
 ):
-  nondiff = nondiff_states.popleft()
-  return ctx.merge(nondiff.graphdef, value.state, nondiff.state)
+  return ctx.merge(value.graphdef, value.state)
 
 
-def _custom_vjp_split_fn(
-  ctx: graph.SplitContext,
-  path,
-  prefix: bool | DiffState,
-  value,
-  *,
-  nondiff_states: deque[extract.GraphDefState],
-):
+def _custom_vjp_split_fn(ctx: graph.SplitContext, path, prefix: bool, value):
   if prefix is False:
     # pure non-differentiable arg, we pass all the state through
     # but we return TreeNode.from_split with a graphdef so we can call from_tree
     # on the nondiff args during the backward pass
-    graphdef, passed = ctx.split(value)
-    broadcast = State({})  # type: ignore[var-annotated]
-    nondiff_states.append(extract.GraphDefState(graphdef, broadcast))
-    return extract.NodeStates.from_split(graphdef, passed)
+    raise TypeError('graph nodes cannot appear in non-differentiable arguments')
   elif prefix is True:
     # pure differentiable arg, we pass all the state through
     # but we return a TreeNode.from_states which doesn't have a graphdef
     # in order to keep the gradients clean from any metadata
     graphdef, passed = ctx.split(value)
-    broadcast = State({})
-    nondiff_states.append(extract.GraphDefState(graphdef, broadcast))
-    return extract.NodeStates.from_states(passed)
+    return extract.NodeStates.from_split(graphdef, passed)
   else:
     # differentiable arg with DiffState filter, we use the filter to split the state
     # as before we return a TreeNode.from_states to keep the gradients clean
     # from any metadata, the non-differentiable state is stored in a deque
     # which is broadcasted during the forward pass
-    graphdef, passed, broadcast = ctx.split(value, prefix.filter, ...)  # type: ignore[misc]
-    nondiff_states.append(extract.GraphDefState(graphdef, broadcast))
-    return extract.NodeStates.from_states(passed)
+    raise NotImplementedError('DiffState prefixes are not yet supported')
 
 
 class CustomVjpMetadata(struct.PyTreeNode):
+  nondiff_argnums: tuple[int, ...] = struct.field(pytree_node=False)
   tangent_tree_node_args: tuple[tp.Any, ...] = struct.field(pytree_node=False)
 
 
+def _extract_index_mappings(x, *, index_mappings: deque[FrozenDict]):
+  if isinstance(x, graph.NodeDef):
+    assert x.index_mapping is not None
+    index_mappings.append(x.index_mapping)
+    return dataclasses.replace(x, index_mapping=None)
+  return x
+
+
 @dataclasses.dataclass(eq=False)
 class CustomVjpFnWrapper:
@@ -419,27 +423,32 @@ class CustomVjpFnWrapper:
   ctxtag: str
 
   def __post_init__(self):
-    functools.update_wrapper(self, self.f)
+    # functools.update_wrapper(self, self.f)
+    pass
 
-  def __call__(self, *pure_args):
-    broadcast: tuple[CustomVjpMetadata, deque[extract.GraphDefState]] = (
-      extract.get_broadcast_state(self.ctxtag)
-    )
-    metadata, nondiff_states = broadcast
+  def __call__(self, metadata: CustomVjpMetadata, *pure_args):
     args = extract.from_tree(
-      pure_args,
-      merge_fn=functools.partial(
-        _custom_vjp_merge_fn, nondiff_states=nondiff_states
-      ),
-      ctxtag=self.ctxtag,
+      pure_args, merge_fn=_custom_vjp_merge_fn, ctxtag=self.ctxtag
     )
 
     out = self.f(*args)
 
-    args_out = extract.clear_non_graph_nodes(args)
+    # drop the nondiff args so they are not included in pure_args_out
+    args_out = tuple(
+      x for i, x in enumerate(args) if i not in metadata.nondiff_argnums
+    )
+    args_out = extract.clear_non_graph_nodes(args_out)
     pure_args_out, pure_out = extract.to_tree(
       (args_out, out), ctxtag=self.ctxtag
     )
+    # remove the index_mappings from the NodeDefs but store them in the global context
+    index_mappings: deque[FrozenDict] = extract.get_broadcast_state(self.ctxtag)
+
+    pure_args_out, pure_out = jax.tree.map(
+      functools.partial(_extract_index_mappings, index_mappings=index_mappings),
+      (pure_args_out, pure_out),
+      is_leaf=lambda x: isinstance(x, graph.NodeDef),
+    )
 
     return pure_args_out, pure_out
 
@@ -452,28 +461,47 @@ class FwdFn:
   def __post_init__(self):
     functools.update_wrapper(self, self.fwd)
 
-  def __call__(self, *pure_args):
-    broadcast: tuple[CustomVjpMetadata, deque[extract.GraphDefState]] = (
-      extract.get_broadcast_state(self.ctxtag)
+  def __call__(self, metadata: CustomVjpMetadata, *pure_args):
+    # here we need to be aware of whether the update_context is active or not:
+    # when it's not active, index_mappings will be None
+    # when it's active, we will remove the index_mappings from the NodeDefs and store them
+    # in the index_mappings deque created by CustomVjp
+    update_context_active = (
+      self.ctxtag in graph.GRAPH_CONTEXT.update_context_stacks
    )
-    metadata, nondiff_states = broadcast
     args = extract.from_tree(
       pure_args,
-      merge_fn=functools.partial(
-        _custom_vjp_merge_fn, nondiff_states=nondiff_states
-      ),
-      ctxtag=self.ctxtag,
+      merge_fn=_custom_vjp_merge_fn,
+      ctxtag=self.ctxtag if update_context_active else None,
     )
 
     out, residual = self.fwd(*args)
 
-    args_out = extract.clear_non_graph_nodes(args)
+    # drop the nondiff args so they are not included in pure_args_out
+    args_out = tuple(
+      x for i, x in enumerate(args) if i not in metadata.nondiff_argnums
+    )
+    args_out = extract.clear_non_graph_nodes(args_out)
     pure_args_out, pure_out = extract.to_tree(
-      (args_out, out), ctxtag=self.ctxtag
+      (args_out, out),
+      ctxtag=self.ctxtag if update_context_active else None,
    )
     pure_residual = extract.to_tree(residual)
 
-    return (pure_args_out, pure_out), (metadata, pure_residual)
+    if update_context_active:
+      # remove the index_mappings from the NodeDefs but store them in the global context
+      index_mappings: deque[FrozenDict] = extract.get_broadcast_state(
+        self.ctxtag
+      )
+      pure_args_out, pure_out = jax.tree.map(
+        functools.partial(
+          _extract_index_mappings, index_mappings=index_mappings
+        ),
+        (pure_args_out, pure_out),
+        is_leaf=lambda x: isinstance(x, graph.NodeDef),
+      )
+
+    return (pure_args_out, pure_out), pure_residual
 
 
 @dataclasses.dataclass(eq=False)
@@ -484,29 +512,30 @@ def __post_init__(self):
     functools.update_wrapper(self, self.bwd)
 
   def __call__(self, *args):
-    res: tuple[CustomVjpMetadata, tp.Any]
-    pure_g: tuple[tp.Any, tp.Any]
-    *nondiff, res, pure_g = args
-    metadata, pure_residual = res
-    nondiff = extract.from_tree(nondiff)
+    metadata: CustomVjpMetadata
+    *nondiff, pure_residual, (pure_args_out_g, pure_out_g) = args
+    metadata, *nondiff = nondiff
+    # nondiff = extract.from_tree(nondiff)
     residual = extract.from_tree(pure_residual)
-    pure_g = jax.tree.map(
+    (pure_args_out_g, pure_out_g) = jax.tree.map(
       lambda x: x.state if isinstance(x, extract.NodeStates) else x,
-      pure_g,
+      (pure_args_out_g, pure_out_g),
       is_leaf=lambda x: isinstance(x, extract.NodeStates),
     )
 
-    tangent = self.bwd(*nondiff, residual, pure_g)
+    tangent = self.bwd(*nondiff, residual, (pure_args_out_g, pure_out_g))
 
-    def state_to_tree_node(is_tree_node: bool, x):
-      if is_tree_node:
-        if not isinstance(x, State):
+    def state_to_node_states(is_differentiable: bool, x):
+      if is_differentiable:
+        if isinstance(x, jax.Array):
+          return x
+        elif not isinstance(x, State):
          raise ValueError(f'Expected State, got {type(x)}')
         return extract.NodeStates.from_states(x)
       return x
 
     pure_tangent = jax.tree.map(
-      state_to_tree_node,
+      state_to_node_states,
       metadata.tangent_tree_node_args,
       tangent,
      is_leaf=lambda x: isinstance(x, State),
@@ -518,30 +547,26 @@ class CustomVjp(tp.Generic[A]):
   def __init__(
     self,
     fun: tp.Callable[..., A],
-    nondiff_argnums: tuple[int | DiffState, ...],
+    nondiff_argnums: tuple[int, ...],
   ):
     functools.update_wrapper(self, fun)
-    jax_nondiff_argnums = tuple(
-      x.argnum if isinstance(x, DiffState) else x for x in nondiff_argnums
-    )
+    # first argument is metadata
+    jax_nondiff_argnums = (0,) + tuple(1 + x for x in nondiff_argnums)
     self.ctxtag = f'custom_vjp_{fun.__name__}_{id(fun)}'
     self.custom_vjp_fn = jax.custom_vjp(
       CustomVjpFnWrapper(fun, self.ctxtag),
       nondiff_argnums=jax_nondiff_argnums,
     )
     self.nondiff_argnums = nondiff_argnums
-    self.diff_filter: dict[int, tp.Literal[False] | DiffState] = {}
-    for argnum in self.nondiff_argnums:
-      index = argnum.argnum if isinstance(argnum, DiffState) else argnum
+    self.diff_filter: dict[int, tp.Literal[False]] = {}
+    for index in self.nondiff_argnums:
      if index in self.diff_filter:
        raise ValueError(f'argnum {index} is repeated in nondiff_argnums')
-      self.diff_filter[index] = (
-        dataclasses.replace(argnum, argnum=-1)
-        if isinstance(argnum, DiffState)
-        else False
-      )
+      self.diff_filter[index] = False
 
   def __getattr__(self, name: str) -> tp.Any:
+    if not hasattr(self.custom_vjp_fn, name):
+      raise AttributeError(f'{self.__class__.__name__} has no attribute {name}')
     return getattr(self.custom_vjp_fn, name)
 
   def __call__(
@@ -550,20 +575,17 @@ def __call__(
     with graph.update_context(self.ctxtag):
       args = resolve_kwargs(self.custom_vjp_fn, args, kwargs)
       del kwargs
-      nondiff_states: deque[extract.GraphDefState] = deque()
       arg_filters = tuple(
         self.diff_filter.get(i, True) for i in range(len(args))
       )
       pure_args = extract.to_tree(
         args,
         prefix=arg_filters,
-        split_fn=functools.partial(
-          _custom_vjp_split_fn, nondiff_states=nondiff_states
-        ),
+        split_fn=_custom_vjp_split_fn,
         ctxtag=self.ctxtag,
       )
       tangent_args = tp.cast(
-        tuple[tp.Literal[True] | DiffState, ...],
+        tuple[tp.Literal[True], ...],
         tuple(x for x in arg_filters if x is not False),
       )
       tree_node_args = jax.tree.map(
@@ -571,15 +593,29 @@ def __call__(
         pure_args,
         is_leaf=lambda x: isinstance(x, extract.NodeStates),
       )
+      # TODO(cgarciae): why is this unused?
       tangent_tree_node_args = tuple(
         arg
         for arg, is_tree_node in zip(args, tree_node_args)
         if is_tree_node is not False
       )
-      metadata = CustomVjpMetadata(tangent_args)
-
-      with extract.broadcast_state(self.ctxtag, (metadata, nondiff_states)):
-        pure_args_out, pure_out = self.custom_vjp_fn(*pure_args)
+      index_mappings: deque[FrozenDict] = deque()
+      metadata = CustomVjpMetadata(self.nondiff_argnums, tangent_args)
+      with extract.broadcast_state(self.ctxtag, index_mappings):
+        pure_args_out, pure_out = self.custom_vjp_fn(metadata, *pure_args)
+
+      # insert the index_mappings back into the NodeDefs
+      def _insert_index_mappings(x):
+        if isinstance(x, graph.NodeDef):
+          index_mapping: FrozenDict = index_mappings.popleft()
+          return dataclasses.replace(x, index_mapping=index_mapping)
+        return x
+
+      pure_args_out, pure_out = jax.tree_util.tree_map(
+        _insert_index_mappings,
+        (pure_args_out, pure_out),
+        is_leaf=lambda x: isinstance(x, graph.NodeDef),
+      )
 
       args_out, out = extract.from_tree(
         (pure_args_out, pure_out), ctxtag=self.ctxtag
@@ -679,17 +715,17 @@ def f_bwd(res, g):
 def custom_vjp(
   fun: tp.Callable[..., A],
   *,
-  nondiff_argnums: tuple[int | DiffState, ...] = (),
+  nondiff_argnums: tuple[int, ...] = (),
 ) -> CustomVjp[A]: ...
 @tp.overload
 def custom_vjp(
   *,
-  nondiff_argnums: tuple[int | DiffState, ...] = (),
+  nondiff_argnums: tuple[int, ...] = (),
 ) -> tp.Callable[[tp.Callable[..., A]], CustomVjp[A]]: ...
 def custom_vjp(
   fun: tp.Callable[..., A] | Missing = MISSING,
   *,
-  nondiff_argnums: tuple[int | DiffState, ...] = (),
+  nondiff_argnums: tuple[int, ...] = (),
 ) -> CustomVjp[A] | tp.Callable[[tp.Callable[..., A]], CustomVjp[A]]:
   """Reference aware version of `jax.custom_vjp
   <https://jax.readthedocs.io/en/latest/_autosummary/jax.custom_vjp.html>`__.
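Usage sketch: the condensed example below (based on the tests that follow; `scale`, `Foo`, and `loss_fn` are illustrative names, not part of the patch) shows how the reworked integer-only nondiff_argnums API is meant to be used. Cotangents for graph-node arguments arrive in the bwd function as nnx.State objects, grouped in a tuple that mirrors the differentiable arguments.

# Condensed usage sketch of nnx.custom_vjp with an integer nondiff_argnums entry.
# A minimal sketch assuming the semantics exercised by the tests below.
import dataclasses

import jax
import jax.numpy as jnp
from flax import nnx


@dataclasses.dataclass
class Foo(nnx.Module):
  x: nnx.Param[jax.Array]
  y: nnx.Param[jax.Array]


@nnx.custom_vjp(nondiff_argnums=(0,))
def f(scale, m: Foo):
  # `scale` (argnum 0) is non-differentiable; `m` is a graph node argument
  return scale * jnp.sin(m.x.value) * m.y.value


def f_fwd(scale, m: Foo):
  # the fwd function receives all arguments, including the non-differentiable ones
  y = f(scale, m)
  res = (jnp.cos(m.x.value), jnp.sin(m.x.value), m)
  return y, res


def f_bwd(scale, res, g):
  # nondiff args come first, then the residuals and the cotangents;
  # g = ((cotangents of the differentiable args as nnx.State,), cotangent of the output)
  (m_g,), out_g = g
  cos_x, sin_x, m = res
  m_g.x.value = scale * cos_x * out_g * m.y.value
  m_g.y.value = scale * sin_x * out_g
  return (m_g,)


f.defvjp(f_fwd, f_bwd)

m = Foo(nnx.Param(jnp.array(1.0)), nnx.Param(jnp.array(2.0)))


def loss_fn(m):
  return f(3.0, m)


grads = nnx.grad(loss_fn, argnums=nnx.DiffState(0, ...))(m)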
diff --git a/tests/nnx/transforms_test.py b/tests/nnx/transforms_test.py
index 5f478c4328..090788f8a3 100644
--- a/tests/nnx/transforms_test.py
+++ b/tests/nnx/transforms_test.py
@@ -598,7 +598,7 @@ def loss_fn(l1: list[nnx.Linear], l2: list[nnx.Linear]):
     self.assertIn('bias', grads_m2[0])
 
 
-class TestCustomVJP(absltest.TestCase):
+class TestCustomVJP(parameterized.TestCase):
 
   def test_basic_call(self):
     m1 = nnx.Linear(1, 1, rngs=nnx.Rngs(0))
@@ -644,16 +644,16 @@ def f_fwd(m: Foo):
       return y, res
 
     def f_bwd(res, g):
-      inputs_g, out_g = g
+      (m_g,), out_g = g
       cos_x, sin_x, m = res
 
-      self.assertIsInstance(inputs_g, tuple)
-      self.assertLen(inputs_g, 1)
-      self.assertIsInstance(inputs_g[0], nnx.State)
+      self.assertIsInstance(m_g, nnx.State)
       self.assertEqual(out_g.shape, ())
       self.assertIsInstance(m, Foo)
 
-      m_g = nnx.State({'x': cos_x * out_g * m.y, 'y': sin_x * out_g})
+      # m_g = nnx.State({'x': cos_x * out_g * m.y, 'y': sin_x * out_g})
+      m_g.x.value = cos_x * out_g * m.y
+      m_g.y.value = sin_x * out_g
       return (m_g,)
 
     f.defvjp(f_fwd, f_bwd)
@@ -666,6 +666,51 @@ def f_bwd(res, g):
     np.testing.assert_allclose(grad['y'].value, jnp.sin(1.0))  # type: ignore
     self.assertEqual(m.z, 1)
 
+  def test_jax_example_with_remat(self):
+    @dataclasses.dataclass
+    class Foo(nnx.Module):
+      x: nnx.Param[jax.Array]
+      y: nnx.Param[jax.Array]
+      z: int
+
+    @nnx.custom_vjp
+    @nnx.remat
+    def f(m: Foo):
+      m.z += 1
+      return jnp.sin(m.x.value) * m.y  # type: ignore
+
+    def f_fwd(m: Foo):
+      y = f(m)
+      res = (jnp.cos(m.x.value), jnp.sin(m.x.value), m)  # type: ignore
+      return y, res
+
+    def f_bwd(res, g):
+      (m_g,), out_g = g
+      cos_x, sin_x, m = res
+
+      self.assertIsInstance(m_g, nnx.State)
+      self.assertEqual(out_g.shape, ())
+      self.assertIsInstance(m, Foo)
+
+      # m_g = nnx.State({'x': cos_x * out_g * m.y, 'y': sin_x * out_g})
+      m_g.x.value = cos_x * out_g * m.y
+      m_g.y.value = sin_x * out_g
+      return (m_g,)
+
+    f.defvjp(f_fwd, f_bwd)
+
+    m = Foo(nnx.Param(jnp.array(1.0)), nnx.Param(jnp.array(2.0)), 0)
+
+    @nnx.jit
+    def loss_fn(m):
+      return f(m)
+
+    grad: nnx.State = nnx.grad(loss_fn, argnums=nnx.DiffState(0, ...))(m)
+
+    np.testing.assert_allclose(grad['x'].value, jnp.cos(1.0) * 2.0)  # type: ignore
+    np.testing.assert_allclose(grad['y'].value, jnp.sin(1.0))  # type: ignore
+    self.assertEqual(m.z, 1)
+
   def test_two_args(self):
     @dataclasses.dataclass
     class Foo(nnx.Module):
@@ -726,45 +771,49 @@ class Foo(nnx.Module):
       y: nnx.Param[jax.Array]
       z: int
 
-    @nnx.custom_vjp(nondiff_argnums=(1, 2))
-    def f(m1: Foo, m2: Foo, m3):
-      m1.z += 1
-      y = jnp.sin(m1.x) * m1.y  # type: ignore
-      return y, m2
+    @nnx.custom_vjp(nondiff_argnums=(0, 2))
+    def f(a, m: Foo, b):
+      self.assertEqual(a, 1)
+      self.assertEqual(b, 2)
+      m.z += 1
+      return jnp.sin(m.x) * m.y  # type: ignore
 
-    def f_fwd(m1: Foo, m2: Foo, m3):
-      y, m2 = f(m1, m2, m3)
-      res = (jnp.cos(m1.x), jnp.sin(m1.x), m1)  # type: ignore
-      return (y, m2), res
+    def f_fwd(a, m: Foo, b):
+      self.assertEqual(a, 1)
+      self.assertEqual(b, 2)
+      y = f(a, m, b)
+      res = (jnp.cos(m.x), jnp.sin(m.x), m)  # type: ignore
+      return y, res
 
-    def f_bwd(m2, m3, res, g):
-      (m1_g, m2_g, m3_g), (y_g, _) = g
+    def f_bwd(a, b, res, g):
+      (m_g,), out_g = g
       cos_x, sin_x, m = res
 
-      self.assertIsInstance(m1_g, nnx.State)
-      self.assertIsInstance(m2_g, nnx.State)
-      self.assertEqual(y_g.shape, ())
+      self.assertEqual(a, 1)
+      self.assertEqual(b, 2)
+      self.assertIsInstance(m_g, nnx.State)
+      self.assertEqual(out_g.shape, ())
       self.assertIsInstance(m, Foo)
 
-      m1_g = nnx.State(dict(x=cos_x * y_g * m.y, y=sin_x * y_g))
-
-      return (m1_g,)
+      # m_g = nnx.State({'x': cos_x * out_g * m.y, 'y': sin_x * out_g})
+      m_g.x.value = cos_x * out_g * m.y
+      m_g.y.value = sin_x * out_g
+      return (m_g,)
 
     f.defvjp(f_fwd, f_bwd)
 
-    m1 = Foo(nnx.Param(jnp.array(1.0)), nnx.Param(jnp.array(2.0)), 0)
-    m2 = Foo(nnx.Param(jnp.array(3.0)), nnx.Param(jnp.array(4.0)), 0)
+    m = Foo(nnx.Param(jnp.array(1.0)), nnx.Param(jnp.array(2.0)), 0)
 
-    def loss_fn(m1, m2, m3):
-      y, m2 = f(m1, m2, m3)
-      return y + m2.x * m2.y
+    def loss_fn(m):
+      a = 1
+      b = 2
+      return f(a, m, b)
 
-    m1_grad: nnx.State
-    m1_grad = nnx.grad(loss_fn, argnums=nnx.DiffState(0, ...))(m1, m2, m2)
+    grad: nnx.State = nnx.grad(loss_fn, argnums=nnx.DiffState(0, ...))(m)
 
-    np.testing.assert_allclose(m1_grad['x'].value, jnp.cos(1.0) * 2.0)  # type: ignore
-    np.testing.assert_allclose(m1_grad['y'].value, jnp.sin(1.0))  # type: ignore
-    self.assertEqual(m1.z, 1)
+    np.testing.assert_allclose(grad['x'].value, jnp.cos(1.0) * 2.0)  # type: ignore
+    np.testing.assert_allclose(grad['y'].value, jnp.sin(1.0))  # type: ignore
+    self.assertEqual(m.z, 1)
 
   def test_docs_example(self):
     import jax.numpy as jnp
@@ -794,6 +843,56 @@ def f_bwd(res, g):
     m = Foo(x=jnp.array(1.0), y=jnp.array(2.0))
     grads = nnx.grad(f)(m)
 
+  @parameterized.parameters(
+    {'use_custom_vjp': False},
+    {'use_custom_vjp': True},
+  )
+  def test_issue(self, use_custom_vjp: bool):
+    class MyLinear(nnx.Module):
+      def __init__(
+        self, in_features: int, out_features: int, *, rngs: nnx.Rngs
+      ):
+        kernel_init = nnx.initializers.normal(in_features**-0.5)
+        self.kernel = nnx.Param(
+          kernel_init(rngs.params(), (in_features, out_features), jnp.float32)
+        )
+        self.bias = nnx.Param(jnp.zeros((out_features,), jnp.float32))
+
+    def linear(m: MyLinear, x: jax.Array) -> jax.Array:
+      y = x @ m.kernel + m.bias
+      return y
+
+    def linear_fwd(m: MyLinear, x: jax.Array):
+      return linear(m, x), (m, x)
+
+    def linear_bwd(res, g):
+      m, x = res
+      (m_g, _x_grad), outputs_g = g
+      kernel_grad = outputs_g[None, :] * x[:, None]
+      bias_grad = outputs_g
+      x_grad = m.kernel @ outputs_g
+      assert x_grad.shape == x.shape, 'Shape mismatch for x'
+      assert (
+        m.kernel.value.shape == kernel_grad.shape
+      ), 'Shape mismatch for kernel'
+      assert m.bias.value.shape == bias_grad.shape, 'Shape mismatch for bias'
+      return (m_g, x_grad)
+
+    if use_custom_vjp:
+      linear = nnx.custom_vjp(linear)
+      linear.defvjp(linear_fwd, linear_bwd)
+
+    @nnx.jit
+    def loss_fn(x):
+      mod = MyLinear(10, 5, rngs=nnx.Rngs(0))
+      y = linear(mod, x)
+      return y.mean()
+
+    x = jax.random.normal(jax.random.key(0), (10,))
+    loss, grad = nnx.value_and_grad(loss_fn)(x)
+    self.assertEqual(loss.shape, ())
+    self.assertEqual(grad.shape, (10,))
+
 
 class TestScan(absltest.TestCase):
 
   def test_basic(self):
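A central piece of the change in autodiff.py above is the index_mapping round-trip: CustomVjpFnWrapper and FwdFn strip the index_mapping from each graph.NodeDef before returning through jax.custom_vjp and stash it in the deque broadcast under ctxtag, and CustomVjp.__call__ re-inserts it once the call returns. The self-contained sketch below illustrates the same pattern with hypothetical stand-ins: ToyNodeDef, strip_index_mappings, and restore_index_mappings are not flax APIs, they only mimic graph.NodeDef, _extract_index_mappings, and _insert_index_mappings.

# Toy analogue of the index_mapping round-trip (all names here are stand-ins,
# not flax APIs): strip a piece of metadata from every "NodeDef" leaf before a
# pytree crosses the jax.custom_vjp boundary, stash it in a deque, and put it
# back afterwards in the same traversal order.
import dataclasses
from collections import deque

import jax


@dataclasses.dataclass(frozen=True)
class ToyNodeDef:
  name: str
  index_mapping: dict | None = None


def strip_index_mappings(tree, *, stash: deque):
  # pop the index_mapping off every ToyNodeDef leaf and remember it in `stash`
  def strip(x):
    if isinstance(x, ToyNodeDef) and x.index_mapping is not None:
      stash.append(x.index_mapping)
      return dataclasses.replace(x, index_mapping=None)
    return x

  return jax.tree.map(strip, tree, is_leaf=lambda x: isinstance(x, ToyNodeDef))


def restore_index_mappings(tree, *, stash: deque):
  # re-attach the stashed index_mappings, relying on the deterministic traversal order
  def restore(x):
    if isinstance(x, ToyNodeDef):
      return dataclasses.replace(x, index_mapping=stash.popleft())
    return x

  return jax.tree.map(restore, tree, is_leaf=lambda x: isinstance(x, ToyNodeDef))


stash: deque = deque()
tree = {'a': ToyNodeDef('m1', {0: 1}), 'b': ToyNodeDef('m2', {2: 3})}
stripped = strip_index_mappings(tree, stash=stash)  # index_mapping is now None
restored = restore_index_mappings(stripped, stash=stash)
assert restored == tree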