diff --git a/flax/linen/__init__.py b/flax/linen/__init__.py index ffc631933d..a9618cfcbc 100644 --- a/flax/linen/__init__.py +++ b/flax/linen/__init__.py @@ -84,17 +84,18 @@ Embed as Embed, ) from .module import ( - Module as Module, - Variable as Variable, - apply as apply, - compact as compact, - disable_named_call as disable_named_call, - enable_named_call as enable_named_call, - init_with_output as init_with_output, - init as init, - merge_param as merge_param, - nowrap as nowrap, - override_named_call as override_named_call, + Module as Module, + Variable as Variable, + apply as apply, + compact as compact, + disable_named_call as disable_named_call, + enable_named_call as enable_named_call, + init_with_output as init_with_output, + init as init, + merge_param as merge_param, + nowrap as nowrap, + override_named_call as override_named_call, + intercept_methods as intercept_methods, ) from .normalization import ( BatchNorm as BatchNorm, diff --git a/flax/linen/module.py b/flax/linen/module.py index ac9b58eb95..bb32a73a8b 100644 --- a/flax/linen/module.py +++ b/flax/linen/module.py @@ -28,12 +28,10 @@ Callable, Dict, Iterable, + Iterator, List, Mapping, - NamedTuple, Optional, - Sequence, - Set, Tuple, Type, TypeVar, @@ -94,6 +92,10 @@ # pylint: disable=protected-access,attribute-defined-outside-init +def _get_fn_name(fn): + if isinstance(fn, functools.partial): + return _get_fn_name(fn.func) + return getattr(fn, '__name__', 'unnamed_function') def _indent(x: str, num_spaces: int): @@ -234,11 +236,6 @@ def __deepcopy__(self, memo): def _derive_profiling_name(module, fn): - def _get_fn_name(fn): - if isinstance(fn, functools.partial): - return _get_fn_name(fn.func) - return fn.__name__ - fn_name = _get_fn_name(fn) method_suffix = f'.{fn_name}' if fn_name != '__call__' else '' module_name = module.name or module.__class__.__name__ @@ -288,6 +285,148 @@ def override_named_call(enable: bool = True): _use_named_call = use_named_call_prev +# Intercept module methods. +# ----------------------------------------------------------------------------- +@dataclasses.dataclass(frozen=True) +class InterceptorContext: + """Read only state showing the calling context for method interceptors. + + Attributes: + module: The Module instance whose method is being called. + method_name: The name of the method being called on the module. + orig_method: The original method defined on the module. Calling it will + short circuit all other interceptors. + """ + module: 'Module' + method_name: str + orig_method: Callable[..., Any] + + +class ThreadLocalStack(threading.local): + """Thread-local stack.""" + + def __init__(self): + self._storage = [] + + def push(self, elem: Any): + self._storage.append(elem) + + def pop(self): + return self._storage.pop() + + def __iter__(self) -> Iterator[Any]: + return iter(reversed(self._storage)) + + def __len__(self) -> int: + return len(self._storage) + + def __repr__(self) -> str: + return f'{self.__class__.__name__}({self._storage})' + + +Args = Tuple[Any] +Kwargs = Dict[str, Any] +NextGetter = Callable[..., Any] +Interceptor = Callable[[NextGetter, Args, Kwargs, InterceptorContext], Any] +_global_interceptor_stack = ThreadLocalStack() + + +@contextlib.contextmanager +def intercept_methods(interceptor: Interceptor): + # pylint: disable=g-doc-return-or-yield + r"""Registers a new method interceptor. + + Method interceptors allow you to (at a distance) intercept method calls to + modules. It works similarly to decorators. You could modify args/kwargs before + calling the underlying method and/or modify the result returning from calling + the underlying method. Or you could completely skip calling the underlying + method and decide to do something differently. For example:: + + >>> import flax.linen as nn + >>> import jax.numpy as jnp + + >>> class Foo(nn.Module): + ... def __call__(self, x): + ... return x + + >>> def my_interceptor1(next_fun, args, kwargs, context): + ... print('calling my_interceptor1') + ... return next_fun(*args, **kwargs) + + >>> foo = Foo() + >>> with nn.intercept_methods(my_interceptor1): + ... _ = foo(jnp.ones([1])) + calling my_interceptor1 + + You could also register multiple interceptors on the same method. Interceptors + will run in order. For example:: + + >>> def my_interceptor2(next_fun, args, kwargs, context): + ... print('calling my_interceptor2') + ... return next_fun(*args, **kwargs) + + >>> with nn.intercept_methods(my_interceptor1), \ + ... nn.intercept_methods(my_interceptor2): + ... _ = foo(jnp.ones([1])) + calling my_interceptor1 + calling my_interceptor2 + + You could skip other interceptors by directly calling the + ``context.orig_method``. For example:: + + >>> def my_interceptor3(next_fun, args, kwargs, context): + ... print('calling my_interceptor3') + ... return context.orig_method(*args, **kwargs) + >>> with nn.intercept_methods(my_interceptor3), \ + ... nn.intercept_methods(my_interceptor1), \ + ... nn.intercept_methods(my_interceptor2): + ... _ = foo(jnp.ones([1])) + calling my_interceptor3 + + The following methods couldn't be intercepted: + + 1. Methods decoratored with ``nn.nowrap``. + 2. Dunder methods including '__eq__', '__repr__', '__init__', '__hash__', + and '__post_init__'. + 3. Module dataclass fields. + 4. Module descriptors. + + Args: + interceptor: A method interceptor. + """ + _global_interceptor_stack.push(interceptor) + try: + yield + finally: + assert _global_interceptor_stack.pop() is interceptor + + +def run_interceptors( + orig_method: Callable[..., Any], + module: 'Module', + *args, + **kwargs, +) -> Any: + """Runs method interceptors or `orig_method`.""" + if not _global_interceptor_stack: + return orig_method(module, *args, **kwargs) + + method_name = _get_fn_name(orig_method) + fun = functools.partial(orig_method, module) + context = InterceptorContext(module, method_name, fun) + def wrap_interceptor(interceptor, fun): + """Wraps `fun` with `interceptor`.""" + @functools.wraps(fun) + def wrapped(*args, **kwargs): + return interceptor(fun, args, kwargs, context) + return wrapped + + # Wraps interceptors around the original method. The innermost interceptor is + # the last one added and directly wrapped around the original bound method. + for interceptor in _global_interceptor_stack: + fun = wrap_interceptor(interceptor, fun) + return fun(*args, **kwargs) + # Utilities for pytrees of Modules defined inside setup() # ----------------------------------------------------------------------------- @@ -920,7 +1059,7 @@ def _wrap_module_attributes(cls): return cls def _call_wrapped_method(self, fun, args, kwargs): - """ "Calls a wrapped method. + """Calls a wrapped method. This function is responsible for setting up the thread local state correctly before calling the method and cleaning up afterwards. @@ -936,7 +1075,7 @@ def _call_wrapped_method(self, fun, args, kwargs): The results of calling ``fun``. """ is_compact_method = hasattr(fun, 'compact') - fun_name = getattr(fun, '__name__', 'unnamed_function') + fun_name = _get_fn_name(fun) is_setup_method = fun_name == 'setup' add_call_info = not is_setup_method and len(_context.call_info_stack) > 0 # We lazily call setup() only when needed. @@ -964,9 +1103,9 @@ def _call_wrapped_method(self, fun, args, kwargs): # call method if _use_named_call: with jax.named_scope(_derive_profiling_name(self, fun)): - y = fun(self, *args, **kwargs) + y = run_interceptors(fun, self, *args, **kwargs) else: - y = fun(self, *args, **kwargs) + y = run_interceptors(fun, self, *args, **kwargs) if _context.capture_stack: filter_fn = _context.capture_stack[-1] diff --git a/tests/linen/linen_module_test.py b/tests/linen/linen_module_test.py index 562bd03224..ca93fb5836 100644 --- a/tests/linen/linen_module_test.py +++ b/tests/linen/linen_module_test.py @@ -88,6 +88,18 @@ def __call__(self, x): return y +class IdentityModule(nn.Module): + + def __call__(self, x): + return x + + +class RaisesModule(nn.Module): + + def __call__(self): + assert False + + class ModuleTest(absltest.TestCase): def test_init_module(self): @@ -2201,6 +2213,188 @@ def __call__(self, x): # take positional arg. It takes BaseLayer's default kwargs though. np.testing.assert_equal(ChildLayer(8)(np.ones(10)), -8 * np.ones(10)) + def test_intercept_methods(self): + mod = IdentityModule(parent=None) + x = jnp.ones([]) + call_count = [] + + def add_one_interceptor(f, args, kwargs, context): + call_count.append(None) + self.assertLen(dataclasses.fields(context), 3) + self.assertIs(context.module, mod) + self.assertEqual(context.method_name, '__call__') + self.assertEqual(context.orig_method(3), 3) + self.assertEqual(args, (x,)) + self.assertEmpty(kwargs) + y = f(*args, **kwargs) + return y + 1 + + y1 = mod(x) + with nn.intercept_methods(add_one_interceptor): + y2 = mod(x) + y3 = mod(x) + + self.assertLen(call_count, 1) + self.assertEqual(y1, 1) + self.assertEqual(y2, 2) + self.assertEqual(y3, 1) + + def test_intercept_methods_compact(self): + class CompactModule(nn.Module): + + @compact + def __call__(self, x): + return nn.Dense(2)(x) + + mod = CompactModule() + x = jnp.ones(shape=(1, 3)) + variables = mod.init(jax.random.PRNGKey(0), x) + call_modules = [] + + def log_interceptor(f, args, kwargs, context): + call_modules.append(context.module) + self.assertLen(dataclasses.fields(context), 3) + self.assertEqual(context.method_name, '__call__') + self.assertEqual(args, (x,)) + self.assertEmpty(kwargs) + return f(*args, **kwargs) + + with nn.intercept_methods(log_interceptor): + _ = mod.apply(variables, x) + + self.assertLen(call_modules, 2) + self.assertIsInstance(call_modules[0], CompactModule) + self.assertIsInstance(call_modules[1], nn.Dense) + + def test_intercept_methods_setup(self): + class SetupModule(nn.Module): + def setup(self): + self.layer = nn.Dense(2) + + def __call__(self, x): + return self.layer(x) + + mod = SetupModule() + x = jnp.ones(shape=(1, 3)) + variables = mod.init(jax.random.PRNGKey(0), x) + call_modules = [] + log = [] + + def log_interceptor(f, args, kwargs, context): + call_modules.append(context.module) + log.append((context.method_name, args, kwargs)) + return f(*args, **kwargs) + + with nn.intercept_methods(log_interceptor): + _ = mod.apply(variables, x) + + self.assertLen(call_modules, 3) + self.assertIsInstance(call_modules[0], SetupModule) + self.assertIsInstance(call_modules[1], SetupModule) + self.assertIsInstance(call_modules[2], nn.Dense) + self.assertEqual( + log, + [('setup', (), {}), ('__call__', (x,), {}), ('__call__', (x,), {})] + ) + + def test_intercept_methods_calling_underlying_optional(self): + def do_nothing_interceptor(f, args, kwargs, context): + del f, context + self.assertEmpty(args) + self.assertEmpty(kwargs) + + m = RaisesModule() + with nn.intercept_methods(do_nothing_interceptor): + m() + + with self.assertRaises(AssertionError): + m() + + with nn.intercept_methods(do_nothing_interceptor): + m() + + def test_intercept_methods_run_in_lifo_order(self): + def op_interceptor(op): + def _interceptor(f, args, kwargs, context): + del context + y = f(*args, **kwargs) + return op(y) + return _interceptor + + mod = IdentityModule(parent=None) + x = 7 + with nn.intercept_methods(op_interceptor(lambda a: a + 1)), \ + nn.intercept_methods(op_interceptor(lambda a: a ** 2)): + y = mod(x) + + self.assertEqual(y, (x ** 2) + 1) + + with nn.intercept_methods(op_interceptor(lambda a: a ** 2)), \ + nn.intercept_methods(op_interceptor(lambda a: a + 1)): + y = mod(x) + + self.assertEqual(y, (x + 1) ** 2) + + def test_intercept_methods_subclasses(self): + class Foo(IdentityModule): + + def __call__(self, x): # pylint: disable=useless-parent-delegation + return super().__call__(x) + + class Bar(Foo): + + def __call__(self, x): # pylint: disable=useless-parent-delegation + return super().__call__(x) + + bar = Bar(parent=None) + x = jnp.ones([]) + called = [] + + def record_interceptor(f, args, kwargs, context): + called.append(None) + self.assertIs(context.module, bar) + self.assertEqual(context.method_name, '__call__') + self.assertEqual(args, (x,)) + self.assertEmpty(kwargs) + return f(*args, **kwargs) + + with nn.intercept_methods(record_interceptor): + bar(x) + + # Bar.__call__, Foo.__call__ and IdenityModule.__call__ + self.assertLen(called, 3) + + def test_intercept_methods_nested_module(self): + class Foo(nn.Module): + def __call__(self, x): + return x + + class Bar(nn.Module): + sub: nn.Module + + def __call__(self, x): + return self.sub(x) + + foo = Foo() + bar = Bar(sub=foo) + x = jnp.ones([]) + called = [] + + def record_interceptor(f, args, kwargs, context): + called.append(context.module) + self.assertEqual(context.method_name, '__call__') + self.assertEqual(args, (x,)) + self.assertEmpty(kwargs) + return f(*args, **kwargs) + + with nn.intercept_methods(record_interceptor): + bar(x) + + # bar.__call__ and foo.__call__ + self.assertLen(called, 2) + self.assertIs(called[0], bar) + self.assertIs(called[1], foo) + class LeakTests(absltest.TestCase):