Skip to content

Commit

Permalink
add support for initializing carry variables in scan
Browse files Browse the repository at this point in the history
  • Loading branch information
cgarciae committed Jul 26, 2023
1 parent 2497f82 commit b12efcc
Show file tree
Hide file tree
Showing 7 changed files with 236 additions and 51 deletions.
14 changes: 8 additions & 6 deletions flax/core/axes_scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ def body_fn(b, c, x):
the function that performs the scan of the form:
(broadcast_in, carry_in, *args) -> (broadcast_out, carry_out, scan_out).
"""
from flax.linen.module import tabulate_context

def transpose_to_front(ax, xs):
if ax is broadcast:
Expand Down Expand Up @@ -142,12 +143,13 @@ def body_fn(c, xs, init_mode=False):
)
input_avals = (carry_avals, scan_avals)

in_avals, in_tree = jax.tree_util.tree_flatten(input_avals)
f_flat, out_tree = jax.api_util.flatten_fun_nokwargs(
lu.wrap_init(broadcast_body), in_tree
)
in_pvals = list(map(pe.PartialVal.unknown, in_avals))
_, out_pvals, _ = pe.trace_to_jaxpr_nounits(f_flat, in_pvals)
with tabulate_context(add_call_info=False):
in_avals, in_tree = jax.tree_util.tree_flatten(input_avals)
f_flat, out_tree = jax.api_util.flatten_fun_nokwargs(
lu.wrap_init(broadcast_body), in_tree
)
in_pvals = list(map(pe.PartialVal.unknown, in_avals))
_, out_pvals, _ = pe.trace_to_jaxpr_nounits(f_flat, in_pvals)

out_flat = []
for pv, const in out_pvals:
Expand Down
190 changes: 167 additions & 23 deletions flax/core/lift.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,19 +33,20 @@
)
import warnings

from flax import traceback_util
import jax
from jax import random

from . import axes_scan
from . import meta
from flax import traceback_util, traverse_util
from .frozen_dict import freeze
from .frozen_dict import unfreeze
import jax
import jax.numpy as jnp
from jax import random
from .scope import (
CollectionFilter,
DenyList, # pylint: disable=g-multiple-import
DenyList,
PRNGSequenceFilter, # pylint: disable=g-multiple-import
Filter,
PRNGSequenceFilter,
Scope,
group_collections,
in_filter,
Expand Down Expand Up @@ -575,7 +576,11 @@ def wrapper(vars_primals, args):
treedef = jax.tree_util.tree_structure(scope)

variable_tangents = tuple(
{k: v for k, v in vt.items() if v} # pylint: disable=g-complex-comprehension
{
k: v # pylint: disable=g-complex-comprehension
for k, v in vt.items()
if v
}
for vt in treedef.flatten_up_to(variable_tangents)
)
target = tuple(variable_tangents[0].keys())
Expand Down Expand Up @@ -817,6 +822,8 @@ def body_fn(scope, c, x):
``(scope, carry, *xxs) -> (carry, yys)``, where ``xxs`` and ``yys`` are the
scan values that go in and out of the loop.
"""
from flax.linen.module import tabulate_context

variable_in_axes, variable_out_axes = _split_in_out_axes(variable_axes)
variable_in_groups, variable_in_axes = _unzip2(variable_in_axes.items())
variable_out_groups, variable_out_axes = _unzip2(variable_out_axes.items())
Expand Down Expand Up @@ -853,16 +860,19 @@ def find_length(axis, x):
for rng_group, split in zip(rng_groups, rng_splits)
)

@functools.partial(
axes_scan.scan,
carry_vars_new_axes = 0
scan_partial = lambda length, unroll: axes_scan.scan(
scanned,
in_axes=(variable_in_axes, rng_axes, in_axes),
out_axes=(out_axes, variable_out_axes),
length=length,
out_axes=(out_axes, variable_out_axes, carry_vars_new_axes),
reverse=reverse,
unroll=unroll,
length=length,
)

def scanned(broadcast_vars, carry, scan_variable_groups, rng_groups, args):
carry_vars, c = carry

variable_groups = (broadcast_vars, carry_vars) + scan_variable_groups
if data_transform is not None:
variable_groups, rng_groups = data_transform(
Expand All @@ -872,37 +882,142 @@ def scanned(broadcast_vars, carry, scan_variable_groups, rng_groups, args):
c, y = fn(scope, c, *args)
out_vars = repack_fn(scope)
broadcast_vars_out = out_vars[0]
carry_vars = out_vars[1]
carry_vars_out = out_vars[1]
scan_vars = out_vars[2:]

# compute new carry vars, these will be handled as outputs
carry_vars_new = tuple(
vars_diff(outputs, inputs)
for outputs, inputs in zip(carry_vars_out, carry_vars)
)
# remove new carry vars to maintain input shape
carry_vars = tuple(
vars_diff(outputs, new)
for outputs, new in zip(carry_vars_out, carry_vars_new)
)

# add immutable broadcast vars back to broadcast output
# otherwise they won't be fed to the actual scan body
for in_group, out_group in zip(broadcast_vars, broadcast_vars_out):
for col in in_group:
if col not in out_group:
out_group[col] = in_group[col]
return broadcast_vars_out, (carry_vars, c), (y, scan_vars)
return broadcast_vars_out, (carry_vars, c), (y, scan_vars, carry_vars_new)

broadcast_vars = variable_groups[0]
carry_vars = variable_groups[1]
scan_vars = variable_groups[2:]
new_scan_vars = []
for scan_group, axis in zip(scan_vars, variable_in_axes):
new_scan_vars.append(meta.remove_axis(scan_group, axis, metadata_params))
broadcast_vars, (carry_vars, c), (ys, scan_vars) = scanned(
broadcast_vars,
(carry_vars, init),
tuple(new_scan_vars),
rng_groups,
args,
)

# compute new carry vars
with tabulate_context(
add_call_info=False
): # don't add call info while tracing
carry_vars_new = jax.eval_shape(
scan_partial(length, unroll),
broadcast_vars,
(carry_vars, init),
tuple(new_scan_vars),
rng_groups,
args,
)[2][2]
has_new_carry_vars = len(jax.tree_util.tree_leaves(carry_vars_new)) > 0

if has_new_carry_vars:
new_scan_vars0, rng_groups0, args0 = tree_map_upto_left(
lambda axis, tree: jax.tree_map(
lambda x: jax.lax.dynamic_slice_in_dim(x, 0, 1, axis),
tree,
),
left=(variable_in_axes, rng_axes, in_axes),
right=(tuple(new_scan_vars), rng_groups, args),
)
# run scan for 1 step
partial_length = 1 if length is not None else None
with tabulate_context(
add_call_info=False
): # don't add call info on first step
(
broadcast_vars,
(carry_vars, init),
(ys1, scan_vars1, carry_vars_new),
) = scan_partial(partial_length, 1)(
broadcast_vars,
(carry_vars, init),
new_scan_vars0,
rng_groups0,
args0,
)
# slice new carry vars and merge with existing
carry_vars_new = jax.tree_map(lambda x: x[0], carry_vars_new)
carry_vars = tuple(
vars_merge(existing, new)
for existing, new in zip(carry_vars, carry_vars_new)
)
# slice rest of the inputs
new_scan_vars_rest, rng_groups_rest, args_rest = tree_map_upto_left(
lambda axis, tree: jax.tree_map(
lambda x: jax.lax.dynamic_slice_in_dim(
x, 1, x.shape[axis] - 1, axis
),
tree,
),
left=(variable_in_axes, rng_axes, in_axes),
right=(tuple(new_scan_vars), rng_groups, args),
)
# run scan on the rest of the inputs
partial_length = length - 1 if length is not None else None
(
broadcast_vars,
(carry_vars, c),
(ys_rest, scan_vars_rest, carry_vars_new),
) = scan_partial(partial_length, unroll)(
broadcast_vars,
(carry_vars, init),
new_scan_vars_rest,
rng_groups_rest,
args_rest,
)
# concat ys and scan_vars
ys = tree_map_upto_left(
lambda axis, tuple_tree: jax.tree_map(
lambda a, b: jnp.concatenate((a, b), axis=axis),
*tuple_tree,
),
left=out_axes,
right=(ys1, ys_rest),
)
scan_vars = tree_map_upto_left(
lambda axis, tuple_tree: jax.tree_map(
lambda a, b: jnp.concatenate((a, b), axis=axis),
*tuple_tree,
),
left=variable_out_axes,
right=((scan_vars1, scan_vars_rest),),
)[0]
else:
(
broadcast_vars,
(carry_vars, c),
(ys, scan_vars, carry_vars_new),
) = scan_partial(length, unroll)(
broadcast_vars,
(carry_vars, init),
tuple(new_scan_vars),
rng_groups,
args,
)

has_new_carry_vars = len(jax.tree_util.tree_leaves(carry_vars_new)) > 0
assert not has_new_carry_vars

new_scan_vars = []
for scan_group, axis in zip(scan_vars, variable_out_axes):
new_scan_vars.append(meta.add_axis(scan_group, axis, metadata_params))
scan_vars = tuple(new_scan_vars)
out_vars = (
broadcast_vars,
carry_vars,
) + scan_vars
out_vars = (broadcast_vars, carry_vars) + scan_vars
return (c, ys), out_vars

return pack(
Expand Down Expand Up @@ -1528,3 +1643,32 @@ def inner_loop(scope, carry):
def _unzip2(xs):
ys = tuple(zip(*xs))
return ys if ys else ((), ())


def vars_diff(a, b):
  """Returns the leaves of nested dict `a` whose paths do not occur in `b`.

  Both inputs are flattened to `{path: leaf}` mappings, the paths present in
  `b` are removed from `a`, and the remainder is rebuilt as a nested dict.

  Paths are kept as key tuples instead of '/'-joined strings. A joined string
  cannot tell the single key 'x/y' apart from the nested path ('x', 'y'), so
  such keys would collide during the comparison and be re-split into extra
  nesting levels on unflatten. Tuple paths round-trip exactly and also allow
  non-string keys.

  Args:
    a: nested dict to filter.
    b: nested dict whose leaf paths are excluded from the result.

  Returns:
    A nested dict holding only the leaves of `a` that are missing from `b`.
  """
  flat_a = traverse_util.flatten_dict(a)
  flat_b = traverse_util.flatten_dict(b)

  only_in_a = {
      path: value for path, value in flat_a.items() if path not in flat_b
  }
  return traverse_util.unflatten_dict(only_in_a)


def vars_merge(a, b):
  """Merges nested dicts `a` and `b` leaf by leaf, with `b` taking precedence.

  Both inputs are flattened to `{path: leaf}` mappings, the entries of `b` are
  written over those of `a`, and the result is rebuilt as a nested dict. The
  inputs themselves are not modified; only the flattened copy is updated.

  Paths are kept as key tuples instead of '/'-joined strings. A joined string
  cannot tell the single key 'x/y' apart from the nested path ('x', 'y'), so
  such keys would overwrite each other and be re-split into extra nesting
  levels on unflatten. Tuple paths round-trip exactly and also allow
  non-string keys.

  Args:
    a: nested dict providing the base entries.
    b: nested dict whose entries are added to, or override, those of `a`.

  Returns:
    A new nested dict containing the leaves of both inputs.
  """
  merged = traverse_util.flatten_dict(a)
  merged.update(traverse_util.flatten_dict(b))
  return traverse_util.unflatten_dict(merged)


def tree_map_upto_left(
    f: Callable[[Any, Any], Any], left: Any, right: Any
) -> Any:
  """Maps `f` over the leaves of `left` paired with matching parts of `right`.

  `left` is flattened completely. `right` is flattened only as deep as the
  structure of `left`, so every leaf of `left` is paired with the whole
  subtree of `right` found at the same position.

  Args:
    f: function called as `f(left_leaf, right_subtree)` for every pair.
    left: pytree whose structure determines how deep `right` is split.
    right: pytree that has the structure of `left` as a prefix.

  Returns:
    A pytree with the structure of `left` holding the results of `f`.
  """
  left_leaves, left_treedef = jax.tree_util.tree_flatten(left)
  right_subtrees = left_treedef.flatten_up_to(right)

  mapped = [
      f(leaf, subtree) for leaf, subtree in zip(left_leaves, right_subtrees)
  ]
  return left_treedef.unflatten(mapped)
23 changes: 16 additions & 7 deletions flax/linen/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,8 +187,13 @@ def get_call_index(self, module: 'Module') -> int:


@contextlib.contextmanager
def _tabulate_context():
_context.call_info_stack.append(_CallInfoContext(0, []))
def tabulate_context(add_call_info: bool = True):
if _context.call_info_stack and _context.call_info_stack[-1] is None:
_context.call_info_stack.append(None)
elif add_call_info:
_context.call_info_stack.append(_CallInfoContext(0, []))
else:
_context.call_info_stack.append(None)
try:
yield
finally:
Expand All @@ -204,11 +209,9 @@ class _DynamicContext(threading.local):
# 3.7

def __init__(self):
self.module_stack = [
None,
]
self.module_stack = [None]
self.capture_stack = []
self.call_info_stack = []
self.call_info_stack: list[Optional[_CallInfoContext]] = []


# The global context
Expand Down Expand Up @@ -937,7 +940,11 @@ def _call_wrapped_method(self, fun, args, kwargs):
is_compact_method = hasattr(fun, 'compact')
fun_name = getattr(fun, '__name__', 'unnamed_function')
is_setup_method = fun_name == 'setup'
add_call_info = not is_setup_method and len(_context.call_info_stack) > 0
add_call_info = (
not is_setup_method
and _context.call_info_stack
and _context.call_info_stack[-1] is not None
)
# We lazily call setup() only when needed.
if is_setup_method:
if self.scope is None:
Expand All @@ -957,6 +964,7 @@ def _call_wrapped_method(self, fun, args, kwargs):
# get call info
if add_call_info:
assert self.scope is not None
assert _context.call_info_stack[-1] is not None
call_index = _context.call_info_stack[-1].get_call_index(self)
scope_path = jax.tree_util.tree_map(_fix_path_part, self.scope.path)

Expand All @@ -972,6 +980,7 @@ def _call_wrapped_method(self, fun, args, kwargs):
if filter_fn and filter_fn(self, fun_name):
self.sow('intermediates', fun_name, y)
if add_call_info:
assert _context.call_info_stack[-1] is not None
_args, _kwargs, _y = flax.linen.summary._represent_tree(
(args, kwargs, y)
)
Expand Down
2 changes: 1 addition & 1 deletion flax/linen/summary.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,7 @@ def _get_module_table(
but returns the Table representation of the Module."""

def _get_table_fn(*args, **kwargs):
with module_lib._tabulate_context():
with module_lib.tabulate_context():

def _get_variables():
return module.init(*args, **kwargs)
Expand Down
4 changes: 2 additions & 2 deletions tests/linen/linen_recurrent_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,10 +289,10 @@ def test_numerical_equivalence_single_batch_nn_scan(self):
cell_carry,
xs[batch_idx : batch_idx + 1, i, :],
)
np.testing.assert_allclose(y[0], ys[batch_idx, i, :], rtol=1e-5)
np.testing.assert_allclose(y[0], ys[batch_idx, i, :], rtol=1e-4)

carry_i = jax.tree_map(lambda x: x[batch_idx : batch_idx + 1], carry)
np.testing.assert_allclose(cell_carry, carry_i, rtol=1e-5)
np.testing.assert_allclose(cell_carry, carry_i, rtol=1e-4)

def test_numerical_equivalence_single_batch_jax_scan(self):
batch_size = 3
Expand Down
Loading

0 comments on commit b12efcc

Please sign in to comment.