diff --git a/flax/core/axes_scan.py b/flax/core/axes_scan.py
index b8ae5e417b..0900b9a91e 100644
--- a/flax/core/axes_scan.py
+++ b/flax/core/axes_scan.py
@@ -80,6 +80,7 @@ def body_fn(b, c, x):
     the function that performs the scan of the form:
     (broadcast_in, carry_in, *args) -> (broadcast_out, carry_out, scan_out).
   """
+  from flax.linen.module import tabulate_context
 
   def transpose_to_front(ax, xs):
     if ax is broadcast:
@@ -142,12 +143,13 @@ def body_fn(c, xs, init_mode=False):
     )
     input_avals = (carry_avals, scan_avals)
 
-    in_avals, in_tree = jax.tree_util.tree_flatten(input_avals)
-    f_flat, out_tree = jax.api_util.flatten_fun_nokwargs(
-      lu.wrap_init(broadcast_body), in_tree
-    )
-    in_pvals = list(map(pe.PartialVal.unknown, in_avals))
-    _, out_pvals, _ = pe.trace_to_jaxpr_nounits(f_flat, in_pvals)
+    with tabulate_context(add_call_info=False):
+      in_avals, in_tree = jax.tree_util.tree_flatten(input_avals)
+      f_flat, out_tree = jax.api_util.flatten_fun_nokwargs(
+        lu.wrap_init(broadcast_body), in_tree
+      )
+      in_pvals = list(map(pe.PartialVal.unknown, in_avals))
+      _, out_pvals, _ = pe.trace_to_jaxpr_nounits(f_flat, in_pvals)
 
     out_flat = []
     for pv, const in out_pvals:
diff --git a/flax/core/lift.py b/flax/core/lift.py
index 1bc132fddf..54e5156028 100644
--- a/flax/core/lift.py
+++ b/flax/core/lift.py
@@ -33,19 +33,20 @@
 )
 import warnings
 
-from flax import traceback_util
-import jax
-from jax import random
 from . import axes_scan
 from . import meta
+from flax import traceback_util, traverse_util
 from .frozen_dict import freeze
 from .frozen_dict import unfreeze
+import jax
+import jax.numpy as jnp
+from jax import random
 from .scope import (
   CollectionFilter,
-  DenyList,  # pylint: disable=g-multiple-import
+  DenyList,
+  PRNGSequenceFilter,  # pylint: disable=g-multiple-import
   Filter,
-  PRNGSequenceFilter,
   Scope,
   group_collections,
   in_filter,
@@ -575,7 +576,11 @@ def wrapper(vars_primals, args):
     treedef = jax.tree_util.tree_structure(scope)
 
     variable_tangents = tuple(
-      {k: v for k, v in vt.items() if v}  # pylint: disable=g-complex-comprehension
+      {
+        k: v  # pylint: disable=g-complex-comprehension
+        for k, v in vt.items()
+        if v
+      }
       for vt in treedef.flatten_up_to(variable_tangents)
     )
     target = tuple(variable_tangents[0].keys())
@@ -817,6 +822,8 @@ def body_fn(scope, c, x):
     ``(scope, carry, *xxs) -> (carry, yys)``, where ``xxs`` and ``yys`` are the
     scan values that go in and out of the loop.
""" + from flax.linen.module import tabulate_context + variable_in_axes, variable_out_axes = _split_in_out_axes(variable_axes) variable_in_groups, variable_in_axes = _unzip2(variable_in_axes.items()) variable_out_groups, variable_out_axes = _unzip2(variable_out_axes.items()) @@ -853,16 +860,19 @@ def find_length(axis, x): for rng_group, split in zip(rng_groups, rng_splits) ) - @functools.partial( - axes_scan.scan, + carry_vars_new_axes = 0 + scan_partial = lambda length, unroll: axes_scan.scan( + scanned, in_axes=(variable_in_axes, rng_axes, in_axes), - out_axes=(out_axes, variable_out_axes), - length=length, + out_axes=(out_axes, variable_out_axes, carry_vars_new_axes), reverse=reverse, unroll=unroll, + length=length, ) + def scanned(broadcast_vars, carry, scan_variable_groups, rng_groups, args): carry_vars, c = carry + variable_groups = (broadcast_vars, carry_vars) + scan_variable_groups if data_transform is not None: variable_groups, rng_groups = data_transform( @@ -872,15 +882,27 @@ def scanned(broadcast_vars, carry, scan_variable_groups, rng_groups, args): c, y = fn(scope, c, *args) out_vars = repack_fn(scope) broadcast_vars_out = out_vars[0] - carry_vars = out_vars[1] + carry_vars_out = out_vars[1] scan_vars = out_vars[2:] + + # compute new carry vars, these will be handled as outputs + carry_vars_new = tuple( + vars_diff(outputs, inputs) + for outputs, inputs in zip(carry_vars_out, carry_vars) + ) + # remove new carry vars to maintain input shape + carry_vars = tuple( + vars_diff(outputs, new) + for outputs, new in zip(carry_vars_out, carry_vars_new) + ) + # add immutable broadcast vars back to broadcast output # otherwise they won't be fed to the actual scan body for in_group, out_group in zip(broadcast_vars, broadcast_vars_out): for col in in_group: if col not in out_group: out_group[col] = in_group[col] - return broadcast_vars_out, (carry_vars, c), (y, scan_vars) + return broadcast_vars_out, (carry_vars, c), (y, scan_vars, carry_vars_new) broadcast_vars = variable_groups[0] carry_vars = variable_groups[1] @@ -888,21 +910,114 @@ def scanned(broadcast_vars, carry, scan_variable_groups, rng_groups, args): new_scan_vars = [] for scan_group, axis in zip(scan_vars, variable_in_axes): new_scan_vars.append(meta.remove_axis(scan_group, axis, metadata_params)) - broadcast_vars, (carry_vars, c), (ys, scan_vars) = scanned( - broadcast_vars, - (carry_vars, init), - tuple(new_scan_vars), - rng_groups, - args, - ) + + # compute new carry vars + with tabulate_context( + add_call_info=False + ): # dont add call info while tracing + carry_vars_new = jax.eval_shape( + scan_partial(length, unroll), + broadcast_vars, + (carry_vars, init), + tuple(new_scan_vars), + rng_groups, + args, + )[2][2] + has_new_carry_vars = len(jax.tree_util.tree_leaves(carry_vars_new)) > 0 + + if has_new_carry_vars: + new_scan_vars0, rng_groups0, args0 = tree_map_upto_left( + lambda axis, tree: jax.tree_map( + lambda x: jax.lax.dynamic_slice_in_dim(x, 0, 1, axis), + tree, + ), + left=(variable_in_axes, rng_axes, in_axes), + right=(tuple(new_scan_vars), rng_groups, args), + ) + # run scan for 1 step + partial_length = 1 if length is not None else None + with tabulate_context( + add_call_info=False + ): # dont add call info on first step + ( + broadcast_vars, + (carry_vars, init), + (ys1, scan_vars1, carry_vars_new), + ) = scan_partial(partial_length, 1)( + broadcast_vars, + (carry_vars, init), + new_scan_vars0, + rng_groups0, + args0, + ) + # slice new carry vars and merge with existing + carry_vars_new = jax.tree_map(lambda 
+      carry_vars_new = jax.tree_map(lambda x: x[0], carry_vars_new)
+      carry_vars = tuple(
+        vars_merge(existing, new)
+        for existing, new in zip(carry_vars, carry_vars_new)
+      )
+      # slice rest of the inputs
+      new_scan_vars_rest, rng_groups_rest, args_rest = tree_map_upto_left(
+        lambda axis, tree: jax.tree_map(
+          lambda x: jax.lax.dynamic_slice_in_dim(
+            x, 1, x.shape[axis] - 1, axis
+          ),
+          tree,
+        ),
+        left=(variable_in_axes, rng_axes, in_axes),
+        right=(tuple(new_scan_vars), rng_groups, args),
+      )
+      # run scan on the rest of the inputs
+      partial_length = length - 1 if length is not None else None
+      (
+        broadcast_vars,
+        (carry_vars, c),
+        (ys_rest, scan_vars_rest, carry_vars_new),
+      ) = scan_partial(partial_length, unroll)(
+        broadcast_vars,
+        (carry_vars, init),
+        new_scan_vars_rest,
+        rng_groups_rest,
+        args_rest,
+      )
+      # concat ys and scan_vars
+      ys = tree_map_upto_left(
+        lambda axis, tuple_tree: jax.tree_map(
+          lambda a, b: jnp.concatenate((a, b), axis=axis),
+          *tuple_tree,
+        ),
+        left=out_axes,
+        right=(ys1, ys_rest),
+      )
+      scan_vars = tree_map_upto_left(
+        lambda axis, tuple_tree: jax.tree_map(
+          lambda a, b: jnp.concatenate((a, b), axis=axis),
+          *tuple_tree,
+        ),
+        left=variable_out_axes,
+        right=((scan_vars1, scan_vars_rest),),
+      )[0]
+    else:
+      (
+        broadcast_vars,
+        (carry_vars, c),
+        (ys, scan_vars, carry_vars_new),
+      ) = scan_partial(length, unroll)(
+        broadcast_vars,
+        (carry_vars, init),
+        tuple(new_scan_vars),
+        rng_groups,
+        args,
+      )
+
+      has_new_carry_vars = len(jax.tree_util.tree_leaves(carry_vars_new)) > 0
+      assert not has_new_carry_vars
+
     new_scan_vars = []
     for scan_group, axis in zip(scan_vars, variable_out_axes):
       new_scan_vars.append(meta.add_axis(scan_group, axis, metadata_params))
     scan_vars = tuple(new_scan_vars)
-    out_vars = (
-      broadcast_vars,
-      carry_vars,
-    ) + scan_vars
+    out_vars = (broadcast_vars, carry_vars) + scan_vars
     return (c, ys), out_vars
 
   return pack(
@@ -1528,3 +1643,32 @@ def inner_loop(scope, carry):
 def _unzip2(xs):
   ys = tuple(zip(*xs))
   return ys if ys else ((), ())
+
+
+def vars_diff(a, b):
+  a = traverse_util.flatten_dict(a, sep='/')
+  b = traverse_util.flatten_dict(b, sep='/')
+
+  c = {path: value for path, value in a.items() if path not in b}
+  c = traverse_util.unflatten_dict(c, sep='/')
+  return c
+
+
+def vars_merge(a, b):
+  a = traverse_util.flatten_dict(a, sep='/')
+  b = traverse_util.flatten_dict(b, sep='/')
+  a.update(b)
+  c = traverse_util.unflatten_dict(a, sep='/')
+  return c
+
+
+def tree_map_upto_left(
+  f: Callable[[Any, Any], Any], left: Any, right: Any
+) -> Any:
+  leaves_left, treedef = jax.tree_util.tree_flatten(left)
+  leaves_right = treedef.flatten_up_to(right)
+
+  return treedef.unflatten(
+    f(left_leaf, right_leaf)
+    for left_leaf, right_leaf in zip(leaves_left, leaves_right)
+  )
diff --git a/flax/linen/module.py b/flax/linen/module.py
index 39bcadadc5..0e96643a23 100644
--- a/flax/linen/module.py
+++ b/flax/linen/module.py
@@ -187,8 +187,13 @@ def get_call_index(self, module: 'Module') -> int:
 
 
 @contextlib.contextmanager
-def _tabulate_context():
-  _context.call_info_stack.append(_CallInfoContext(0, []))
+def tabulate_context(add_call_info: bool = True):
+  if _context.call_info_stack and _context.call_info_stack[-1] is None:
+    _context.call_info_stack.append(None)
+  elif add_call_info:
+    _context.call_info_stack.append(_CallInfoContext(0, []))
+  else:
+    _context.call_info_stack.append(None)
   try:
     yield
   finally:
@@ -204,11 +209,9 @@ class _DynamicContext(threading.local):
   # 3.7
   def __init__(self):
-    self.module_stack = [
-      None,
-    ]
+    self.module_stack = [None]
     self.capture_stack = []
-    self.call_info_stack = []
+    self.call_info_stack: list[Optional[_CallInfoContext]] = []
 
 
 # The global context
@@ -937,7 +940,11 @@ def _call_wrapped_method(self, fun, args, kwargs):
     is_compact_method = hasattr(fun, 'compact')
     fun_name = getattr(fun, '__name__', 'unnamed_function')
     is_setup_method = fun_name == 'setup'
-    add_call_info = not is_setup_method and len(_context.call_info_stack) > 0
+    add_call_info = (
+      not is_setup_method
+      and _context.call_info_stack
+      and _context.call_info_stack[-1] is not None
+    )
     # We lazily call setup() only when needed.
     if is_setup_method:
       if self.scope is None:
@@ -957,6 +964,7 @@ def _call_wrapped_method(self, fun, args, kwargs):
       # get call info
       if add_call_info:
         assert self.scope is not None
+        assert _context.call_info_stack[-1] is not None
         call_index = _context.call_info_stack[-1].get_call_index(self)
         scope_path = jax.tree_util.tree_map(_fix_path_part, self.scope.path)
 
@@ -972,6 +980,7 @@ def _call_wrapped_method(self, fun, args, kwargs):
       if filter_fn and filter_fn(self, fun_name):
         self.sow('intermediates', fun_name, y)
       if add_call_info:
+        assert _context.call_info_stack[-1] is not None
         _args, _kwargs, _y = flax.linen.summary._represent_tree(
           (args, kwargs, y)
         )
diff --git a/flax/linen/summary.py b/flax/linen/summary.py
index 67c195c337..a7f7af7ab1 100644
--- a/flax/linen/summary.py
+++ b/flax/linen/summary.py
@@ -271,7 +271,7 @@ def _get_module_table(
   but returns the Table representation of the Module."""
 
   def _get_table_fn(*args, **kwargs):
-    with module_lib._tabulate_context():
+    with module_lib.tabulate_context():
 
       def _get_variables():
         return module.init(*args, **kwargs)
diff --git a/tests/linen/linen_recurrent_test.py b/tests/linen/linen_recurrent_test.py
index 2f09a487bf..694fa7f329 100644
--- a/tests/linen/linen_recurrent_test.py
+++ b/tests/linen/linen_recurrent_test.py
@@ -289,10 +289,10 @@ def test_numerical_equivalence_single_batch_nn_scan(self):
           cell_carry,
           xs[batch_idx : batch_idx + 1, i, :],
         )
-        np.testing.assert_allclose(y[0], ys[batch_idx, i, :], rtol=1e-5)
+        np.testing.assert_allclose(y[0], ys[batch_idx, i, :], rtol=1e-4)
 
       carry_i = jax.tree_map(lambda x: x[batch_idx : batch_idx + 1], carry)
-      np.testing.assert_allclose(cell_carry, carry_i, rtol=1e-5)
+      np.testing.assert_allclose(cell_carry, carry_i, rtol=1e-4)
 
   def test_numerical_equivalence_single_batch_jax_scan(self):
     batch_size = 3
diff --git a/tests/linen/linen_transforms_test.py b/tests/linen/linen_transforms_test.py
index 6c9b5c03ff..baefdd8d8a 100644
--- a/tests/linen/linen_transforms_test.py
+++ b/tests/linen/linen_transforms_test.py
@@ -1034,6 +1034,38 @@ def __call__(self, x):
     }
     self.assertTrue(tree_equals(init_vars_shapes, ref_var_shapes))
 
+  def test_scan_carry_init(self):
+    class Block(nn.Module):
+      train: bool
+
+      @nn.compact
+      def __call__(self, x, _):
+        x = nn.Dense(3)(x)
+        x = nn.BatchNorm(use_running_average=not self.train)(x)
+        x = nn.relu(x)
+        return x, None
+
+    MLP = nn.scan(
+      Block,
+      variable_axes={'params': 0},
+      split_rngs={'params': True},
+      variable_carry='batch_stats',
+      length=5,
+    )
+
+    mlp = MLP(train=True)
+
+    variables = mlp.init(random.PRNGKey(0), jnp.ones((1, 3)), None)
+
+    self.assertEqual(
+      variables['batch_stats']['BatchNorm_0']['mean'].shape, (3,)
+    )
+    self.assertEqual(variables['batch_stats']['BatchNorm_0']['var'].shape, (3,))
+    self.assertEqual(variables['params']['BatchNorm_0']['scale'].shape, (5, 3))
+    self.assertEqual(variables['params']['BatchNorm_0']['bias'].shape, (5, 3))
+    self.assertEqual(variables['params']['Dense_0']['kernel'].shape, (5, 3, 3))
+    self.assertEqual(variables['params']['Dense_0']['bias'].shape, (5, 3))
+
   def test_variable_in_args_transform(self):
     class Test(nn.Module):
diff --git a/tests/linen/summary_test.py b/tests/linen/summary_test.py
index eb2f5624b7..77ea1874e0 100644
--- a/tests/linen/summary_test.py
+++ b/tests/linen/summary_test.py
@@ -417,12 +417,11 @@ def __call__(self, x):
 
     lstm = LSTM(features=128)
 
-    with jax.check_tracer_leaks(True):
-      module_repr = lstm.tabulate(
-        random.PRNGKey(0),
-        x=jnp.ones((32, 128, 64)),
-        console_kwargs=CONSOLE_TEST_KWARGS,
-      )
+    module_repr = lstm.tabulate(
+      random.PRNGKey(0),
+      x=jnp.ones((32, 128, 64)),
+      console_kwargs=CONSOLE_TEST_KWARGS,
+    )
 
     lines = module_repr.splitlines()
 
@@ -452,12 +451,11 @@ def __call__(self, x):
 
     lstm = LSTM(features=128)
 
-    with jax.check_tracer_leaks(True):
-      module_repr = lstm.tabulate(
-        random.PRNGKey(0),
-        x=jnp.ones((32, 128, 64)),
-        console_kwargs=CONSOLE_TEST_KWARGS,
-      )
+    module_repr = lstm.tabulate(
+      random.PRNGKey(0),
+      x=jnp.ones((32, 128, 64)),
+      console_kwargs=CONSOLE_TEST_KWARGS,
+    )
 
     lines = module_repr.splitlines()
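
Usage sketch (not part of the patch): a minimal example of the behavior this
change enables, mirroring the added test_scan_carry_init above. Collections
listed in variable_carry may now be created inside the scanned body: scan first
traces the body with jax.eval_shape to detect newly initialized carry
variables, runs one real step to materialize them, then scans the remaining
steps. All module and collection names below come directly from the test.

    import jax.numpy as jnp
    from jax import random
    import flax.linen as nn

    class Block(nn.Module):
      train: bool

      @nn.compact
      def __call__(self, x, _):
        x = nn.Dense(3)(x)
        # BatchNorm's 'batch_stats' are initialized inside the scan body and
        # carried across steps rather than stacked along a new axis.
        x = nn.BatchNorm(use_running_average=not self.train)(x)
        return nn.relu(x), None

    MLP = nn.scan(
      Block,
      variable_axes={'params': 0},
      split_rngs={'params': True},
      variable_carry='batch_stats',
      length=5,
    )

    variables = MLP(train=True).init(random.PRNGKey(0), jnp.ones((1, 3)), None)
    # 'batch_stats' keep their per-step shape (3,); 'params' gain a leading
    # scan axis of length 5, e.g. Dense_0's kernel has shape (5, 3, 3).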