Skip to content

Commit

Permalink
[nnx] fix custom_vjp
Browse files Browse the repository at this point in the history
  • Loading branch information
cgarciae committed Oct 21, 2024
1 parent c692114 commit ed402f6
Show file tree
Hide file tree
Showing 2 changed files with 253 additions and 118 deletions.
206 changes: 121 additions & 85 deletions flax/nnx/transforms/autodiff.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@


from flax import struct
from flax.core.frozen_dict import FrozenDict
from flax.nnx import (
extract,
filterlib,
Expand Down Expand Up @@ -362,84 +363,92 @@ def value_and_grad(
return_value=True,
)

# -----------------------------------------------
# custom_vjp
# -----------------------------------------------
# custom_vjp is one of the most complicated transforms as it requires
# to handle 4 different functions:
# 1. CustomVJP: the main object that runs the outer logic, converts input graph nodes
# to pytrees and output pytrees to graph nodes.
# 2. CustomVjpFnWrapper: function that wraps the user's function, it converts
# its input pytrees to graph nodes and output graph nodes to pytrees.
# 3. FwdFn: wraps the user's fwd function, it converts its input pytrees to graph nodes
# and output graph nodes to pytrees. Since it might run by itself in a separate context,
# it needs to be aware if the update_context is active or not in order to update the outer
# referenes.
# 4. BwdFn: wraps the user's bwd function, it converts its input pytrees to graph nodes
# and output graph nodes to pytrees. It doesn't need to be aware of the outer context
# since it will never update the outer references as it runs during the backward pass.

def _custom_vjp_merge_fn(
ctx: graph.MergeContext,
path,
prefix: bool | DiffState,
value: extract.NodeStates,
*,
nondiff_states: deque[extract.GraphDefState],
ctx: graph.MergeContext, path, prefix: bool, value: extract.NodeStates
):
nondiff = nondiff_states.popleft()
return ctx.merge(nondiff.graphdef, value.state, nondiff.state)
return ctx.merge(value.graphdef, value.state)


def _custom_vjp_split_fn(
ctx: graph.SplitContext,
path,
prefix: bool | DiffState,
value,
*,
nondiff_states: deque[extract.GraphDefState],
):
def _custom_vjp_split_fn(ctx: graph.SplitContext, path, prefix: bool, value):
if prefix is False:
# pure non-differentiable arg, we pass all the state through
# but we return TreeNode.from_split with a graphdef to we can call from_tree
# on the nondiff args during the backward pass
graphdef, passed = ctx.split(value)
broadcast = State({}) # type: ignore[var-annotated]
nondiff_states.append(extract.GraphDefState(graphdef, broadcast))
return extract.NodeStates.from_split(graphdef, passed)
raise TypeError('graph nodes cannot appear in non-differentiable arguments')
elif prefix is True:
# pure differentiable arg, we pass all the state through
# but we return a TreeNode.from_states which doesn't have a graphdef
# in order to keep the gradients clean from any metadata
graphdef, passed = ctx.split(value)
broadcast = State({})
nondiff_states.append(extract.GraphDefState(graphdef, broadcast))
return extract.NodeStates.from_states(passed)
return extract.NodeStates.from_split(graphdef, passed)
else:
# differentiable arg with DiffState filter, we use the filter to split the state
# as before we return a TreeNode.from_states to keep the gradients clean
# from any metadata, the non-differentiable state is stored in a deque
# which is broadcasted during the forward pass
graphdef, passed, broadcast = ctx.split(value, prefix.filter, ...) # type: ignore[misc]
nondiff_states.append(extract.GraphDefState(graphdef, broadcast))
return extract.NodeStates.from_states(passed)
raise NotImplementedError('DiffState prefix are not yet supported')


class CustomVjpMetadata(struct.PyTreeNode):
nondiff_argnums: tuple[int, ...] = struct.field(pytree_node=False)
tangent_tree_node_args: tuple[tp.Any, ...] = struct.field(pytree_node=False)

def _extract_index_mappings(x, *, index_mappings: deque[FrozenDict]):
if isinstance(x, graph.NodeDef):
assert x.index_mapping is not None
index_mappings.append(x.index_mapping)
return dataclasses.replace(x, index_mapping=None)
return x

@dataclasses.dataclass(eq=False)
class CustomVjpFnWrapper:
f: tp.Callable[..., tp.Any]
ctxtag: str

def __post_init__(self):
functools.update_wrapper(self, self.f)
# functools.update_wrapper(self, self.f)
pass

def __call__(self, *pure_args):
broadcast: tuple[CustomVjpMetadata, deque[extract.GraphDefState]] = (
extract.get_broadcast_state(self.ctxtag)
)
metadata, nondiff_states = broadcast
def __call__(self, metadata: CustomVjpMetadata, *pure_args):
args = extract.from_tree(
pure_args,
merge_fn=functools.partial(
_custom_vjp_merge_fn, nondiff_states=nondiff_states
),
ctxtag=self.ctxtag,
pure_args, merge_fn=_custom_vjp_merge_fn, ctxtag=self.ctxtag
)

out = self.f(*args)

args_out = extract.clear_non_graph_nodes(args)
# remove nondiff from pure_args_out_g
args_out = tuple(
x for i, x in enumerate(args) if i not in metadata.nondiff_argnums
)
args_out = extract.clear_non_graph_nodes(args_out)
pure_args_out, pure_out = extract.to_tree(
(args_out, out), ctxtag=self.ctxtag
)
# remove index_mapping from NodeDef's but store them in global context
index_mappings: deque[FrozenDict] = extract.get_broadcast_state(self.ctxtag)

pure_args_out, pure_out = jax.tree.map(
functools.partial(_extract_index_mappings, index_mappings=index_mappings),
(pure_args_out, pure_out),
is_leaf=lambda x: isinstance(x, graph.NodeDef),
)

return pure_args_out, pure_out

Expand All @@ -452,28 +461,47 @@ class FwdFn:
def __post_init__(self):
functools.update_wrapper(self, self.fwd)

def __call__(self, *pure_args):
broadcast: tuple[CustomVjpMetadata, deque[extract.GraphDefState]] = (
extract.get_broadcast_state(self.ctxtag)
def __call__(self, metadata: CustomVjpMetadata, *pure_args):
# here we need to be aware if the update_context is active or not
# when its not active, index_mappings will be None
# when its active, we will remove the index_mappings from the NodeDef's and store them
# in the index_mappings deque created by CustomVjp
update_context_active = (
self.ctxtag in graph.GRAPH_CONTEXT.update_context_stacks
)
metadata, nondiff_states = broadcast
args = extract.from_tree(
pure_args,
merge_fn=functools.partial(
_custom_vjp_merge_fn, nondiff_states=nondiff_states
),
ctxtag=self.ctxtag,
merge_fn=_custom_vjp_merge_fn,
ctxtag=self.ctxtag if update_context_active else None,
)

out, residual = self.fwd(*args)

args_out = extract.clear_non_graph_nodes(args)
# remove nondiff from pure_args_out_g
args_out = tuple(
x for i, x in enumerate(args) if i not in metadata.nondiff_argnums
)
args_out = extract.clear_non_graph_nodes(args_out)
pure_args_out, pure_out = extract.to_tree(
(args_out, out), ctxtag=self.ctxtag
(args_out, out),
ctxtag=self.ctxtag if update_context_active else None,
)
pure_residual = extract.to_tree(residual)

return (pure_args_out, pure_out), (metadata, pure_residual)
if update_context_active:
# remove index_mapping from NodeDef's but store them in global context
index_mappings: deque[FrozenDict] = extract.get_broadcast_state(
self.ctxtag
)
pure_args_out, pure_out = jax.tree.map(
functools.partial(
_extract_index_mappings, index_mappings=index_mappings
),
(pure_args_out, pure_out),
is_leaf=lambda x: isinstance(x, graph.NodeDef),
)

return (pure_args_out, pure_out), pure_residual


@dataclasses.dataclass(eq=False)
Expand All @@ -484,29 +512,30 @@ def __post_init__(self):
functools.update_wrapper(self, self.bwd)

def __call__(self, *args):
res: tuple[CustomVjpMetadata, tp.Any]
pure_g: tuple[tp.Any, tp.Any]
*nondiff, res, pure_g = args
metadata, pure_residual = res
nondiff = extract.from_tree(nondiff)
metadata: CustomVjpMetadata
*nondiff, pure_residual, (pure_args_out_g, pure_out_g) = args
metadata, *nondiff = nondiff
# nondiff = extract.from_tree(nondiff)
residual = extract.from_tree(pure_residual)
pure_g = jax.tree.map(
(pure_args_out_g, pure_out_g) = jax.tree.map(
lambda x: x.state if isinstance(x, extract.NodeStates) else x,
pure_g,
(pure_args_out_g, pure_out_g),
is_leaf=lambda x: isinstance(x, extract.NodeStates),
)

tangent = self.bwd(*nondiff, residual, pure_g)
tangent = self.bwd(*nondiff, residual, (pure_args_out_g, pure_out_g))

def state_to_tree_node(is_tree_node: bool, x):
if is_tree_node:
if not isinstance(x, State):
def state_to_node_states(is_differentiable: bool, x):
if is_differentiable:
if isinstance(x, jax.Array):
return x
elif not isinstance(x, State):
raise ValueError(f'Expected State, got {type(x)}')
return extract.NodeStates.from_states(x)
return x

pure_tangent = jax.tree.map(
state_to_tree_node,
state_to_node_states,
metadata.tangent_tree_node_args,
tangent,
is_leaf=lambda x: isinstance(x, State),
Expand All @@ -518,30 +547,26 @@ class CustomVjp(tp.Generic[A]):
def __init__(
self,
fun: tp.Callable[..., A],
nondiff_argnums: tuple[int | DiffState, ...],
nondiff_argnums: tuple[int, ...],
):
functools.update_wrapper(self, fun)
jax_nondiff_argnums = tuple(
x.argnum if isinstance(x, DiffState) else x for x in nondiff_argnums
)
# first argument is metadata
jax_nondiff_argnums = (0,) + tuple(1 + x for x in nondiff_argnums)
self.ctxtag = f'custom_vjp_{fun.__name__}_{id(fun)}'
self.custom_vjp_fn = jax.custom_vjp(
CustomVjpFnWrapper(fun, self.ctxtag),
nondiff_argnums=jax_nondiff_argnums,
)
self.nondiff_argnums = nondiff_argnums
self.diff_filter: dict[int, tp.Literal[False] | DiffState] = {}
for argnum in self.nondiff_argnums:
index = argnum.argnum if isinstance(argnum, DiffState) else argnum
self.diff_filter: dict[int, tp.Literal[False]] = {}
for index in self.nondiff_argnums:
if index in self.diff_filter:
raise ValueError(f'argnum {index} is repeated in nondiff_argnums')
self.diff_filter[index] = (
dataclasses.replace(argnum, argnum=-1)
if isinstance(argnum, DiffState)
else False
)
self.diff_filter[index] = False

def __getattr__(self, name: str) -> tp.Any:
if not hasattr(self.custom_vjp_fn, name):
raise AttributeError(f'{self.__class__.__name__} has no attribute {name}')
return getattr(self.custom_vjp_fn, name)

def __call__(
Expand All @@ -550,36 +575,47 @@ def __call__(
with graph.update_context(self.ctxtag):
args = resolve_kwargs(self.custom_vjp_fn, args, kwargs)
del kwargs
nondiff_states: deque[extract.GraphDefState] = deque()
arg_filters = tuple(
self.diff_filter.get(i, True) for i in range(len(args))
)
pure_args = extract.to_tree(
args,
prefix=arg_filters,
split_fn=functools.partial(
_custom_vjp_split_fn, nondiff_states=nondiff_states
),
split_fn=_custom_vjp_split_fn,
ctxtag=self.ctxtag,
)
tangent_args = tp.cast(
tuple[tp.Literal[True] | DiffState, ...],
tuple[tp.Literal[True], ...],
tuple(x for x in arg_filters if x is not False),
)
tree_node_args = jax.tree.map(
lambda x: isinstance(x, extract.NodeStates),
pure_args,
is_leaf=lambda x: isinstance(x, extract.NodeStates),
)
# TODO(cgarciae): why is this unused?
tangent_tree_node_args = tuple(
arg
for arg, is_tree_node in zip(args, tree_node_args)
if is_tree_node is not False
)
metadata = CustomVjpMetadata(tangent_args)

with extract.broadcast_state(self.ctxtag, (metadata, nondiff_states)):
pure_args_out, pure_out = self.custom_vjp_fn(*pure_args)
index_mappings: deque[FrozenDict] = deque()
metadata = CustomVjpMetadata(self.nondiff_argnums, tangent_args)
with extract.broadcast_state(self.ctxtag, index_mappings):
pure_args_out, pure_out = self.custom_vjp_fn(metadata, *pure_args)

# insert index_mappings
def _insert_index_mappings(x):
if isinstance(x, graph.NodeDef):
index_mapping: FrozenDict = index_mappings.popleft()
return dataclasses.replace(x, index_mapping=index_mapping)
return x

pure_args_out, pure_out = jax.tree_util.tree_map(
_insert_index_mappings,
(pure_args_out, pure_out),
is_leaf=lambda x: isinstance(x, graph.NodeDef),
)

args_out, out = extract.from_tree(
(pure_args_out, pure_out), ctxtag=self.ctxtag
Expand Down Expand Up @@ -679,17 +715,17 @@ def f_bwd(res, g):
def custom_vjp(
fun: tp.Callable[..., A],
*,
nondiff_argnums: tuple[int | DiffState, ...] = (),
nondiff_argnums: tuple[int, ...] = (),
) -> CustomVjp[A]: ...
@tp.overload
def custom_vjp(
*,
nondiff_argnums: tuple[int | DiffState, ...] = (),
nondiff_argnums: tuple[int, ...] = (),
) -> tp.Callable[[tp.Callable[..., A]], CustomVjp[A]]: ...
def custom_vjp(
fun: tp.Callable[..., A] | Missing = MISSING,
*,
nondiff_argnums: tuple[int | DiffState, ...] = (),
nondiff_argnums: tuple[int, ...] = (),
) -> CustomVjp[A] | tp.Callable[[tp.Callable[..., A]], CustomVjp[A]]:
"""Reference aware version of
`jax.custom_vjp <https://jax.readthedocs.io/en/latest/_autosummary/jax.custom_vjp.html>`__.
Expand Down
Loading

0 comments on commit ed402f6

Please sign in to comment.