From c6b094921ca99362c93473913313e13c127283a5 Mon Sep 17 00:00:00 2001 From: Jun Xu Date: Tue, 18 Jul 2023 15:19:41 -0700 Subject: [PATCH] Add a Haiku-like methods interceptor API. Note: 1. Methods decorated by nn.nowrap could not be intercepted. 2. Module dataclass fields and dunder methods, including __eq__, __repr__, __init__, __hash__, and __post_init__ could not be intercepted. 3. Module descriptors could not be intercepted. 4. InterceptorContext currently contains the module object, method name, and original method. More fields could be added in the future. PiperOrigin-RevId: 549126112 --- flax/linen/__init__.py | 1 + flax/linen/module.py | 538 ++++++++++++++++++++----------- tests/linen/linen_module_test.py | 198 ++++++++++++ 3 files changed, 557 insertions(+), 180 deletions(-) diff --git a/flax/linen/__init__.py b/flax/linen/__init__.py index ffc631933d..74640b1a69 100644 --- a/flax/linen/__init__.py +++ b/flax/linen/__init__.py @@ -92,6 +92,7 @@ enable_named_call as enable_named_call, init_with_output as init_with_output, init as init, + intercept_methods as intercept_methods, merge_param as merge_param, nowrap as nowrap, override_named_call as override_named_call, diff --git a/flax/linen/module.py b/flax/linen/module.py index 1d248892a8..d96f67683e 100644 --- a/flax/linen/module.py +++ b/flax/linen/module.py @@ -28,6 +28,7 @@ Callable, Dict, Iterable, + Iterator, List, Literal, Mapping, @@ -93,6 +94,10 @@ # pylint: disable=protected-access,attribute-defined-outside-init +def _get_fn_name(fn): + if isinstance(fn, functools.partial): + return _get_fn_name(fn.func) + return getattr(fn, '__name__', 'unnamed_function') def _indent(x: str, num_spaces: int): @@ -156,7 +161,9 @@ def _module_repr(module: 'Module', num_spaces: int = 4): def _fix_path_part(part: str): """Fixes a path part by removing transformation name and parenthesis sometimes - inserted by lifted transformations""" + + inserted by lifted transformations + """ match =
_find_non_lifted_module.match(part) if match: return match.group(1) @@ -233,11 +240,6 @@ def __deepcopy__(self, memo): def _derive_profiling_name(module, fn): - def _get_fn_name(fn): - if isinstance(fn, functools.partial): - return _get_fn_name(fn.func) - return fn.__name__ - fn_name = _get_fn_name(fn) method_suffix = f'.{fn_name}' if fn_name != '__call__' else '' module_name = module.name or module.__class__.__name__ @@ -287,6 +289,153 @@ def override_named_call(enable: bool = True): _use_named_call = use_named_call_prev +# Intercept module methods. +# ----------------------------------------------------------------------------- +@dataclasses.dataclass(frozen=True) +class InterceptorContext: + """Read only state showing the calling context for method interceptors. + + Attributes: + module: The Module instance whose method is being called. + method_name: The name of the method being called on the module. + orig_method: The original method defined on the module. Calling it will + short circuit all other interceptors. + """ + + module: 'Module' + method_name: str + orig_method: Callable[..., Any] + + +class ThreadLocalStack(threading.local): + """Thread-local stack.""" + + def __init__(self): + self._storage = [] + + def push(self, elem: Any): + self._storage.append(elem) + + def pop(self): + return self._storage.pop() + + def __iter__(self) -> Iterator[Any]: + return iter(reversed(self._storage)) + + def __len__(self) -> int: + return len(self._storage) + + def __repr__(self) -> str: + return f'{self.__class__.__name__}({self._storage})' + + +Args = Tuple[Any] +Kwargs = Dict[str, Any] +NextGetter = Callable[..., Any] +Interceptor = Callable[[NextGetter, Args, Kwargs, InterceptorContext], Any] +_global_interceptor_stack = ThreadLocalStack() + + +@contextlib.contextmanager +def intercept_methods(interceptor: Interceptor): + # pylint: disable=g-doc-return-or-yield + r"""Registers a new method interceptor. 
+ + Method interceptors allow you to (at a distance) intercept method calls to + modules. It works similarly to decorators. You could modify args/kwargs before + calling the underlying method and/or modify the result returning from calling + the underlying method. Or you could completely skip calling the underlying + method and decide to do something differently. For example:: + + >>> import flax.linen as nn + >>> import jax.numpy as jnp + + >>> class Foo(nn.Module): + ... def __call__(self, x): + ... return x + + >>> def my_interceptor1(next_fun, args, kwargs, context): + ... print('calling my_interceptor1') + ... return next_fun(*args, **kwargs) + + >>> foo = Foo() + >>> with nn.intercept_methods(my_interceptor1): + ... _ = foo(jnp.ones([1])) + calling my_interceptor1 + + You could also register multiple interceptors on the same method. Interceptors + will run in order. For example:: + + >>> def my_interceptor2(next_fun, args, kwargs, context): + ... print('calling my_interceptor2') + ... return next_fun(*args, **kwargs) + + >>> with nn.intercept_methods(my_interceptor1), \ + ... nn.intercept_methods(my_interceptor2): + ... _ = foo(jnp.ones([1])) + calling my_interceptor1 + calling my_interceptor2 + + You could skip other interceptors by directly calling the + ``context.orig_method``. For example:: + + >>> def my_interceptor3(next_fun, args, kwargs, context): + ... print('calling my_interceptor3') + ... return context.orig_method(*args, **kwargs) + >>> with nn.intercept_methods(my_interceptor3), \ + ... nn.intercept_methods(my_interceptor1), \ + ... nn.intercept_methods(my_interceptor2): + ... _ = foo(jnp.ones([1])) + calling my_interceptor3 + + The following methods couldn't be intercepted: + + 1. Methods decorated with ``nn.nowrap``. + 2. Dunder methods including '__eq__', '__repr__', '__init__', '__hash__', + and '__post_init__'. + 3. Module dataclass fields. + 4. Module descriptors. + + Args: + interceptor: A method interceptor.
+ """ + _global_interceptor_stack.push(interceptor) + try: + yield + finally: + assert _global_interceptor_stack.pop() is interceptor + + +def run_interceptors( + orig_method: Callable[..., Any], + module: 'Module', + *args, + **kwargs, +) -> Any: + """Runs method interceptors or `orig_method`.""" + if not _global_interceptor_stack: + return orig_method(module, *args, **kwargs) + + method_name = _get_fn_name(orig_method) + fun = functools.partial(orig_method, module) + context = InterceptorContext(module, method_name, fun) + + def wrap_interceptor(interceptor, fun): + """Wraps `fun` with `interceptor`.""" + + @functools.wraps(fun) + def wrapped(*args, **kwargs): + return interceptor(fun, args, kwargs, context) + + return wrapped + + # Wraps interceptors around the original method. The innermost interceptor is + # the last one added and directly wrapped around the original bound method. + for interceptor in _global_interceptor_stack: + fun = wrap_interceptor(interceptor, fun) + return fun(*args, **kwargs) + + # Utilities for pytrees of Modules defined inside setup() # ----------------------------------------------------------------------------- @@ -358,6 +507,7 @@ def compact(fun: _CallableT) -> _CallableT: Args: fun: The Module method to mark as compact. + Returns: The given function `fun` marked as compact. """ @@ -391,6 +541,7 @@ def __call__(self, x): Args: fun: The Module method to mark as nowrap. + Returns: The given function `fun` marked as nowrap. """ @@ -406,6 +557,7 @@ def _get_local_method_names( Args: cls: The class to get method names for. exclude: Names to exclude from output. + Returns: A list of method names. """ @@ -426,6 +578,7 @@ def _get_local_descriptor_names( Args: cls: The class to get property names for. exclude: Names to exclude from output. + Returns: A list of property names. """ @@ -447,6 +600,7 @@ def wrap_method_once(fun: Callable[..., Any]) -> Callable[..., Any]: Args: fun: User-defined Module method to manage state for. 
+ Returns: Wrapped method. """ @@ -476,6 +630,7 @@ def wrap_descriptor_once(descriptor) -> 'DescriptorWrapper': Args: prop: User-defined Module attribute descriptor. + Returns: Wrapped descriptor. """ @@ -515,6 +670,7 @@ def _get_unbound_fn(method_or_fn: Callable[..., Any]) -> Callable[..., Any]: Args: method_or_fn: A class method or function. + Returns: An unbound version of input function. """ @@ -744,7 +900,9 @@ class ModuleBase: class Module(ModuleBase): - """Base class for all neural network modules. Layers and models should subclass this class. + """Base class for all neural network modules. + + Layers and models should subclass this class. All Flax Modules are Python 3.7 `dataclasses `_. Since @@ -889,6 +1047,7 @@ def _verify_single_or_no_compact(cls): @classmethod def _wrap_module_attributes(cls): """Wraps user-defined non-inherited methods and descriptors with state + management functions. """ # wrap methods @@ -919,7 +1078,7 @@ def _wrap_module_attributes(cls): return cls def _call_wrapped_method(self, fun, args, kwargs): - """ "Calls a wrapped method. + """Calls a wrapped method. This function is responsible for setting up the thread local state correctly before calling the method and cleaning up afterwards. @@ -935,7 +1094,7 @@ def _call_wrapped_method(self, fun, args, kwargs): The results of calling ``fun``. """ is_compact_method = hasattr(fun, 'compact') - fun_name = getattr(fun, '__name__', 'unnamed_function') + fun_name = _get_fn_name(fun) is_setup_method = fun_name == 'setup' add_call_info = not is_setup_method and len(_context.call_info_stack) > 0 # We lazily call setup() only when needed. 
@@ -963,9 +1122,9 @@ def _call_wrapped_method(self, fun, args, kwargs): # call method if _use_named_call: with jax.named_scope(_derive_profiling_name(self, fun)): - y = fun(self, *args, **kwargs) + y = run_interceptors(fun, self, *args, **kwargs) else: - y = fun(self, *args, **kwargs) + y = run_interceptors(fun, self, *args, **kwargs) if _context.capture_stack: filter_fn = _context.capture_stack[-1] @@ -1280,11 +1439,11 @@ def clone( parent: The parent of the clone. The clone will have no parent if no explicit parent is specified. _deep_clone: A boolean or a weak value dictionary to control deep cloning - of submodules. If True, submodules will be cloned recursively. If a - weak value dictionary is passed, it will be used to cache cloned - submodules. This flag is used by init/apply/bind to avoid scope - leakage. + of submodules. If True, submodules will be cloned recursively. If a weak + value dictionary is passed, it will be used to cache cloned submodules. + This flag is used by init/apply/bind to avoid scope leakage. **updates: Attribute updates. + Returns: A clone of the this Module with the updated attributes and parent. """ @@ -1401,10 +1560,10 @@ def variable( Args: col: The variable collection name. name: The variable name. - init_fn: The function that will be called to compute the initial value - of this variable. This function will only be called the first time - this variable is used in this module. If None, the variable must - already be initialized otherwise an error is raised. + init_fn: The function that will be called to compute the initial value of + this variable. This function will only be called the first time this + variable is used in this module. If None, the variable must already be + initialized otherwise an error is raised. *init_args: The arguments to pass to init_fn. unbox: If True, ``AxisMetadata`` instances are replaced by their unboxed value, see ``flax.nn.meta.unbox`` (default: True). 
@@ -1475,9 +1634,9 @@ def param( Args: name: The parameter name. - init_fn: The function that will be called to compute the initial value - of this variable. This function will only be called the first time - this parameter is used in this module. + init_fn: The function that will be called to compute the initial value of + this variable. This function will only be called the first time this + parameter is used in this module. *init_args: The arguments to pass to init_fn. unbox: If True, ``AxisMetadata`` instances are replaced by their unboxed value, see ``flax.nn.meta.unbox`` (default: True). @@ -1507,6 +1666,7 @@ def has_variable(self, col: str, name: str) -> bool: Args: col: The variable collection name. name: The name of the variable. + Returns: True if the variable exists. """ @@ -1537,6 +1697,7 @@ def make_rng(self, name: str) -> KeyArray: Args: name: The RNG sequence name. + Returns: The newly generated RNG key. """ @@ -1612,15 +1773,14 @@ def __call__(self, x): Args: variables: A dictionary containing variables keyed by variable - collections. See :mod:`flax.core.variables` for more details - about variables. + collections. See :mod:`flax.core.variables` for more details about + variables. *args: Named arguments (not used). rngs: a dict of PRNGKeys to initialize the PRNG sequences. mutable: Can be bool, str, or list. Specifies which collections should be - treated as mutable: - ``bool``: all/no collections are mutable. - ``str``: The name of a single mutable collection. - ``list``: A list of names of mutable collections. + treated as mutable: ``bool``: all/no collections are mutable. ``str``: + The name of a single mutable collection. ``list``: A list of names of + mutable collections. Returns: A copy of this instance with bound variables and RNGs. @@ -1637,9 +1797,9 @@ def unbind(self: M) -> Tuple[M, VariableDict]: ``unbind`` helps create a stateless version of a bound Module. 
An example of a common use case: to extract a sub-Module defined inside - ``setup()`` and its corresponding variables: 1) temporarily ``bind`` the parent - Module; and then 2) ``unbind`` the desired sub-Module. (Recall that ``setup()`` - is only called when the Module is bound.):: + ``setup()`` and its corresponding variables: 1) temporarily ``bind`` the + parent Module; and then 2) ``unbind`` the desired sub-Module. (Recall that + ``setup()`` is only called when the Module is bound.):: class AutoEncoder(nn.Module): def setup(self): @@ -1712,26 +1872,27 @@ def other_fn(instance, ...): Args: variables: A dictionary containing variables keyed by variable - collections. See :mod:`flax.core.variables` for more details - about variables. + collections. See :mod:`flax.core.variables` for more details about + variables. *args: Named arguments passed to the specified apply method. - rngs: a dict of PRNGKeys to initialize the PRNG sequences. - The "params" PRNG sequence is used to initialize parameters. + rngs: a dict of PRNGKeys to initialize the PRNG sequences. The "params" + PRNG sequence is used to initialize parameters. method: A function to call apply on. This is generally a function in the module. If provided, applies this method. If not provided, applies the ``__call__`` method of the module. A string can also be provided to specify a method by name. mutable: Can be bool, str, or list. Specifies which collections should be - treated as mutable: ``bool``: all/no collections are mutable. - ``str``: The name of a single mutable collection. ``list``: A - list of names of mutable collections. - capture_intermediates: If `True`, captures intermediate return values - of all Modules inside the "intermediates" collection. By default only - the return values of all ``__call__`` methods are stored. A function can - be passed to change the filter behavior. 
The filter function takes - the Module instance and method name and returns a bool indicating - whether the output of that method invocation should be stored. + treated as mutable: ``bool``: all/no collections are mutable. ``str``: + The name of a single mutable collection. ``list``: A list of names of + mutable collections. + capture_intermediates: If `True`, captures intermediate return values of + all Modules inside the "intermediates" collection. By default only the + return values of all ``__call__`` methods are stored. A function can be + passed to change the filter behavior. The filter function takes the + Module instance and method name and returns a bool indicating whether + the output of that method invocation should be stored. **kwargs: Keyword arguments passed to the specified apply method. + Returns: If ``mutable`` is False, returns output. If any collections are mutable, returns ``(output, vars)``, where ``vars`` are is a dict @@ -1779,17 +1940,18 @@ def init_with_output( provided, applies the ``__call__`` method. A string can also be' provided to specify a method by name. mutable: Can be bool, str, or list. Specifies which collections should be - treated as mutable: ``bool``: all/no collections are mutable. - ``str``: The name of a single mutable collection. ``list``: A - list of names of mutable collections. By default all collections - except "intermediates" are mutable. - capture_intermediates: If `True`, captures intermediate return values - of all Modules inside the "intermediates" collection. By default only - the return values of all ``__call__`` methods are stored. A function can - be passed to change the filter behavior. The filter function takes - the Module instance and method name and returns a bool indicating - whether the output of that method invocation should be stored. + treated as mutable: ``bool``: all/no collections are mutable. ``str``: + The name of a single mutable collection. ``list``: A list of names of + mutable collections. 
By default all collections except "intermediates" + are mutable. + capture_intermediates: If `True`, captures intermediate return values of + all Modules inside the "intermediates" collection. By default only the + return values of all ``__call__`` methods are stored. A function can be + passed to change the filter behavior. The filter function takes the + Module instance and method name and returns a bool indicating whether + the output of that method invocation should be stored. **kwargs: Keyword arguments passed to the init function. + Returns: `(output, vars)``, where ``vars`` are is a dict of the modified collections. @@ -1837,7 +1999,10 @@ def init( ) -> Union[FrozenVariableDict, Dict[str, Any]]: """Initializes a module method with variables and returns modified variables. - ``init`` takes as first argument either a single ``PRNGKey``, or a dictionary mapping variable collections names to their ``PRNGKeys``, and will call ``method`` (which is the module's ``__call__`` function by default) passing ``*args`` and ``**kwargs``, and returns + ``init`` takes as first argument either a single ``PRNGKey``, or a + dictionary mapping variable collections names to their ``PRNGKeys``, and + will call ``method`` (which is the module's ``__call__`` function by + default) passing ``*args`` and ``**kwargs``, and returns a dictionary of initialized variables. Example:: @@ -1858,9 +2023,10 @@ def init( >>> key = jax.random.PRNGKey(0) >>> variables = module.init(key, jnp.empty((1, 7)), train=False) - If you pass a single ``PRNGKey``, Flax will use it to feed the ``'params'`` RNG stream. - If you want to use a different RNG stream or need to use multiple streams, you must pass a - dictionary mapping each RNG stream name to its corresponding ``PRNGKey`` to ``init``. + If you pass a single ``PRNGKey``, Flax will use it to feed the ``'params'`` + RNG stream. 
If you want to use a different RNG stream or need to use + multiple streams, you must pass a dictionary mapping each RNG stream name + to its corresponding ``PRNGKey`` to ``init``. Example:: @@ -1878,7 +2044,8 @@ def init( ... return nn.Dense(1)(x) ... >>> module = Foo() - >>> rngs = {'params': jax.random.PRNGKey(0), 'noise': jax.random.PRNGKey(1)} + >>> rngs = {'params': jax.random.PRNGKey(0), + ... 'noise': jax.random.PRNGKey(1)} >>> variables = module.init(rngs, jnp.empty((1, 7)), train=False) Jitting `init` initializes a model lazily using only the shapes of the @@ -1889,27 +2056,29 @@ def init( >>> init_jit = jax.jit(module.init) >>> variables = init_jit(jax.random.PRNGKey(0), jnp.empty((1, 7))) - ``init`` is a light wrapper over ``apply``, so other ``apply`` arguments like - ``method``, ``mutable``, and ``capture_intermediates`` are also available. + ``init`` is a light wrapper over ``apply``, so other ``apply`` arguments + like ``method``, ``mutable``, and ``capture_intermediates`` are also + available. Args: rngs: The rngs for the variable collections. *args: Named arguments passed to the init function. method: An optional method. If provided, applies this method. If not - provided, applies the ``__call__`` method. A string can also be - provided to specify a method by name. + provided, applies the ``__call__`` method. A string can also be provided + to specify a method by name. mutable: Can be bool, str, or list. Specifies which collections should be - treated as mutable: ``bool``: all/no collections are mutable. - ``str``: The name of a single mutable collection. ``list``: A - list of names of mutable collections. By default all collections - except "intermediates" are mutable. - capture_intermediates: If `True`, captures intermediate return values - of all Modules inside the "intermediates" collection. By default only - the return values of all ``__call__`` methods are stored. A function can - be passed to change the filter behavior. 
The filter function takes - the Module instance and method name and returns a bool indicating - whether the output of that method invocation should be stored. + treated as mutable: ``bool``: all/no collections are mutable. ``str``: + The name of a single mutable collection. ``list``: A list of names of + mutable collections. By default all collections except "intermediates" + are mutable. + capture_intermediates: If `True`, captures intermediate return values of + all Modules inside the "intermediates" collection. By default only the + return values of all ``__call__`` methods are stored. A function can be + passed to change the filter behavior. The filter function takes the + Module instance and method name and returns a bool indicating whether + the output of that method invocation should be stored. **kwargs: Keyword arguments passed to the init function. + Returns: The initialized variable dict. """ @@ -1937,17 +2106,18 @@ def lazy_init( """Initializes a module without computing on an actual input. lazy_init will initialize the variables without doing unnecessary compute. - The input data should be passed as a ``jax.ShapeDtypeStruct`` which specifies - the shape and dtype of the input but no concrete data. + The input data should be passed as a ``jax.ShapeDtypeStruct`` which + specifies the shape and dtype of the input but no concrete data. Example:: model = nn.Dense(features=256) - variables = model.lazy_init(rng, jax.ShapeDtypeStruct((1, 128), jnp.float32)) + variables = model.lazy_init( + rng, jax.ShapeDtypeStruct((1, 128), jnp.float32)) The args and kwargs args passed to ``lazy_init`` can be a mix of - concrete (jax arrays, scalars, bools) and abstract (ShapeDtypeStruct) values. - Concrete values are only necessary for arguments that affect + concrete (jax arrays, scalars, bools) and abstract (ShapeDtypeStruct) + values. Concrete values are only necessary for arguments that affect the initialization of variables. 
For example, the model might expect a keyword arg that enables/disables a subpart of the model. In this case, an explicit value (True/Flase) should be passed otherwise @@ -1959,11 +2129,12 @@ def lazy_init( method: An optional method. If provided, applies this method. If not provided, applies the ``__call__`` method. mutable: Can be bool, str, or list. Specifies which collections should be - treated as mutable: ``bool``: all/no collections are mutable. - ``str``: The name of a single mutable collection. ``list``: A - list of names of mutable collections. By default all collections - except "intermediates" are mutable. + treated as mutable: ``bool``: all/no collections are mutable. ``str``: + The name of a single mutable collection. ``list``: A list of names of + mutable collections. By default all collections except "intermediates" + are mutable. **kwargs: Keyword arguments passed to the init function. + Returns: The initialized variable dict. """ @@ -2078,18 +2249,19 @@ def __call__(self, x): model = Foo2() variables = model.init(jax.random.PRNGKey(0), x) - y, state = model.apply(variables, jnp.ones((1, 1)), mutable=['intermediates']) + y, state = model.apply( + variables, jnp.ones((1, 1)), mutable=['intermediates']) print(state['intermediates']) # ==> {'h': [[3.]]} Args: col: The name of the variable collection. name: The name of the variable. value: The value of the variable. - reduce_fn: The function used to combine the existing value with - the new value. The default is to append the value to a tuple. - init_fn: For the first value stored, `reduce_fn` will be passed - the result of `init_fn` together with the value to be stored. - The default is an empty tuple. + reduce_fn: The function used to combine the existing value with the new + value. The default is to append the value to a tuple. + init_fn: For the first value stored, `reduce_fn` will be passed the result + of `init_fn` together with the value to be stored. The default is an + empty tuple. 
Returns: `True` if the value has been stored successfully, `False` otherwise. @@ -2146,16 +2318,18 @@ def loss(params, perturbations, inputs, targets): y = jnp.ones((2, 2)) model = Foo() variables = model.init(jax.random.PRNGKey(0), x) - intm_grads = jax.grad(loss, argnums=1)(variables['params'], variables['perturbations'], x, y) + intm_grads = jax.grad(loss, argnums=1)( + variables['params'], variables['perturbations'], x, y) print(intm_grads['dense3']) # ==> [[-1.456924 -0.44332537 0.02422847] # [-1.456924 -0.44332537 0.02422847]] If perturbations are not passed to `apply`, `perturb` behaves like a no-op so you can easily disable the behavior when not needed:: - model.apply({'params': params, 'perturbations': perturbations}, x) # works as expected + model.apply( + {'params': params, 'perturbations': perturbations}, + x) # works as expected model.apply({'params': params}, x) # behaves like a no-op - """ def _root_has_collection(): @@ -2189,8 +2363,9 @@ def tabulate( the Module in a table. `tabulate` uses `jax.eval_shape` to run the forward computation without consuming any FLOPs or allocating memory. - Additional arguments can be passed into the `console_kwargs` argument, for example, - `{'width': 120}`. For a full list of `console_kwargs` arguments, see: + Additional arguments can be passed into the `console_kwargs` argument, for + example, `{'width': 120}`. 
For a full list of `console_kwargs` arguments, + see: https://rich.readthedocs.io/en/stable/reference/console.html#rich.console.Console Example:: @@ -2213,23 +2388,23 @@ def __call__(self, x): This gives the following output:: Foo Summary - ┏━━━━━━━━━┳━━━━━━━━┳━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━┓ - ┃ path ┃ module ┃ inputs ┃ outputs ┃ params ┃ - ┡━━━━━━━━━╇━━━━━━━━╇━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━┩ - │ │ Foo │ float32[16,9] │ float32[16,2] │ │ - ├─────────┼────────┼───────────────┼───────────────┼──────────────────────┤ - │ Dense_0 │ Dense │ float32[16,9] │ float32[16,4] │ bias: float32[4] │ - │ │ │ │ │ kernel: float32[9,4] │ - │ │ │ │ │ │ - │ │ │ │ │ 40 (160 B) │ - ├─────────┼────────┼───────────────┼───────────────┼──────────────────────┤ - │ Dense_1 │ Dense │ float32[16,4] │ float32[16,2] │ bias: float32[2] │ - │ │ │ │ │ kernel: float32[4,2] │ - │ │ │ │ │ │ - │ │ │ │ │ 10 (40 B) │ - ├─────────┼────────┼───────────────┼───────────────┼──────────────────────┤ - │ │ │ │ Total │ 50 (200 B) │ - └─────────┴────────┴───────────────┴───────────────┴──────────────────────┘ + ┏━━━━━━━━━┳━━━━━━━━┳━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━┓ + ┃ path ┃ module ┃ inputs ┃ outputs ┃ params ┃ + ┡━━━━━━━━━╇━━━━━━━━╇━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━┩ + │ │ Foo │ float32[16,9] │ float32[16,2] │ │ + ├─────────┼────────┼───────────────┼───────────────┼─────────────────────┤ + │ Dense_0 │ Dense │ float32[16,9] │ float32[16,4] │ bias: float32[4] │ + │ │ │ │ │ kernel: float32[9,4]│ + │ │ │ │ │ │ + │ │ │ │ │ 40 (160 B) │ + ├─────────┼────────┼───────────────┼───────────────┼─────────────────────┤ + │ Dense_1 │ Dense │ float32[16,4] │ float32[16,2] │ bias: float32[2] │ + │ │ │ │ │ kernel: float32[4,2]│ + │ │ │ │ │ │ + │ │ │ │ │ 10 (40 B) │ + ├─────────┼────────┼───────────────┼───────────────┼─────────────────────┤ + │ │ │ │ Total │ 50 (200 B) │ + 
└─────────┴────────┴───────────────┴───────────────┴─────────────────────┘ Total Parameters: 50 (200 B) @@ -2242,24 +2417,26 @@ def __call__(self, x): *args: The arguments to the forward computation. depth: controls how many submodule deep the summary can go. By default its `None` which means no limit. If a submodule is not shown because of the - depth limit, its parameter count and bytes will be added to the row of its - first shown ancestor such that the sum of all rows always adds up to the - total number of parameters of the Module. + depth limit, its parameter count and bytes will be added to the row of + its first shown ancestor such that the sum of all rows always adds up to + the total number of parameters of the Module. show_repeated: If `True`, repeated calls to the same module will be shown in the table, otherwise only the first call will be shown. Default is `False`. mutable: Can be bool, str, or list. Specifies which collections should be - treated as mutable: ``bool``: all/no collections are mutable. ``str``: The - name of a single mutable collection. ``list``: A list of names of mutable - collections. By default all collections except 'intermediates' are - mutable. - console_kwargs: An optional dictionary with additional keyword arguments that - are passed to `rich.console.Console` when rendering the table. Default arguments - are `{'force_terminal': True, 'force_jupyter': False}`. - table_kwargs: An optional dictionary with additional keyword arguments that - are passed to `rich.table.Table` constructor. - column_kwargs: An optional dictionary with additional keyword arguments that - are passed to `rich.table.Table.add_column` when adding columns to the table. + treated as mutable: ``bool``: all/no collections are mutable. ``str``: + The name of a single mutable collection. ``list``: A list of names of + mutable collections. By default all collections except 'intermediates' + are mutable. 
+ console_kwargs: An optional dictionary with additional keyword arguments + that are passed to `rich.console.Console` when rendering the table. + Default arguments are `{'force_terminal': True, 'force_jupyter': + False}`. + table_kwargs: An optional dictionary with additional keyword arguments + that are passed to `rich.table.Table` constructor. + column_kwargs: An optional dictionary with additional keyword arguments + that are passed to `rich.table.Table.add_column` when adding columns to + the table. **kwargs: keyword arguments to pass to the forward computation. Returns: @@ -2305,9 +2482,9 @@ def __call__(self, train: Optional[bool] = None): name: the name of the parameter. Used for error messages. a: option a b: option b + Returns: a or b whichever is not `None`. - """ if a is None and b is None: raise ValueError( @@ -2333,10 +2510,10 @@ def apply( ) -> Callable[..., Any]: """Creates an apply function to call ``fn`` with a bound module. - Unlike ``Module.apply`` this function returns a new function with the signature - ``(variables, *args, rngs=None, **kwargs) -> T`` where `T` is the return type - of ``fn``. If ``mutable`` is not ``False`` the return type is a tuple where the - second item is a ``FrozenDict`` with the mutated variables. + Unlike ``Module.apply`` this function returns a new function with the + signature ``(variables, *args, rngs=None, **kwargs) -> T`` where `T` is the + return type of ``fn``. If ``mutable`` is not ``False`` the return type is a + tuple where the second item is a ``FrozenDict`` with the mutated variables. The apply function that is returned can be directly composed with JAX transformations like ``jax.jit``:: @@ -2352,22 +2529,22 @@ def f(foo, x): f_jitted(variables, x) Args: - fn: The function that should be applied. The first argument passed will - be an module instance of the ``module`` with variables and RNGs bound - to it. - module: The ``Module`` that will be used to bind variables and RNGs to. 
- The ``Module`` passed as the first argument to ``fn`` will be a clone - of module. + fn: The function that should be applied. The first argument passed will be + an module instance of the ``module`` with variables and RNGs bound to it. + module: The ``Module`` that will be used to bind variables and RNGs to. The + ``Module`` passed as the first argument to ``fn`` will be a clone of + module. mutable: Can be bool, str, or list. Specifies which collections should be - treated as mutable: ``bool``: all/no collections are mutable. - ``str``: The name of a single mutable collection. ``list``: A - list of names of mutable collections. - capture_intermediates: If `True`, captures intermediate return values - of all Modules inside the "intermediates" collection. By default only - the return values of all `__call__` methods are stored. A function can - be passed to change the filter behavior. The filter function takes - the Module instance and method name and returns a bool indicating - whether the output of that method invocation should be stored. + treated as mutable: ``bool``: all/no collections are mutable. ``str``: The + name of a single mutable collection. ``list``: A list of names of mutable + collections. + capture_intermediates: If `True`, captures intermediate return values of all + Modules inside the "intermediates" collection. By default only the return + values of all `__call__` methods are stored. A function can be passed to + change the filter behavior. The filter function takes the Module instance + and method name and returns a bool indicating whether the output of that + method invocation should be stored. + Returns: The apply function wrapping ``fn``. """ @@ -2396,10 +2573,11 @@ def init_with_output( ) -> Callable[..., Tuple[Any, Union[FrozenVariableDict, Dict[str, Any]]]]: """Creates an init function to call ``fn`` with a bound module that also returns the function outputs. 
- Unlike ``Module.init_with_output`` this function returns a new function with the signature - ``(rngs, *args, **kwargs) -> (T, variables)`` where `T` is the return type of ``fn``. - The rngs can be a dict of PRNGKeys or a single ```PRNGKey`` which is - equivalent to passing a dict with one PRNGKey with the name "params". + Unlike ``Module.init_with_output`` this function returns a new function with + the signature ``(rngs, *args, **kwargs) -> (T, variables)`` where `T` is the + return type of ``fn``. The rngs can be a dict of PRNGKeys or a single + ``PRNGKey`` which is equivalent to passing a dict with one PRNGKey with the + name "params". The init function that is returned can be directly composed with JAX transformations like ``jax.jit``:: @@ -2415,23 +2593,23 @@ def f(foo, x): y, variables = f_jitted(rng, x) Args: - fn: The function that should be applied. The first argument passed will - be an module instance of the ``module`` with variables and RNGs bound - to it. - module: The ``Module`` that will be used to bind variables and RNGs to. - The ``Module`` passed as the first argument to ``fn`` will be a clone - of module. + fn: The function that should be applied. The first argument passed will be + an module instance of the ``module`` with variables and RNGs bound to it. + module: The ``Module`` that will be used to bind variables and RNGs to. The + ``Module`` passed as the first argument to ``fn`` will be a clone of + module. mutable: Can be bool, str, or list. Specifies which collections should be - treated as mutable: ``bool``: all/no collections are mutable. - ``str``: The name of a single mutable collection. ``list``: A - list of names of mutable collections. By default all collections - except "intermediates" are mutable. - capture_intermediates: If `True`, captures intermediate return values - of all Modules inside the "intermediates" collection. By default only - the return values of all `__call__` methods are stored.
A function can - be passed to change the filter behavior. The filter function takes - the Module instance and method name and returns a bool indicating - whether the output of that method invocation should be stored. + treated as mutable: ``bool``: all/no collections are mutable. ``str``: The + name of a single mutable collection. ``list``: A list of names of mutable + collections. By default all collections except "intermediates" are + mutable. + capture_intermediates: If `True`, captures intermediate return values of all + Modules inside the "intermediates" collection. By default only the return + values of all `__call__` methods are stored. A function can be passed to + change the filter behavior. The filter function takes the Module instance + and method name and returns a bool indicating whether the output of that + method invocation should be stored. + Returns: The init function wrapping ``fn``. """ @@ -2479,23 +2657,23 @@ def f(foo, x): variables = f_jitted(rng, x) Args: - fn: The function that should be applied. The first argument passed will - be an module instance of the ``module`` with variables and RNGs bound - to it. - module: The ``Module`` that will be used to bind variables and RNGs to. - The ``Module`` passed as the first argument to ``fn`` will be a clone - of module. + fn: The function that should be applied. The first argument passed will be + an module instance of the ``module`` with variables and RNGs bound to it. + module: The ``Module`` that will be used to bind variables and RNGs to. The + ``Module`` passed as the first argument to ``fn`` will be a clone of + module. mutable: Can be bool, str, or list. Specifies which collections should be - treated as mutable: ``bool``: all/no collections are mutable. - ``str``: The name of a single mutable collection. ``list``: A - list of names of mutable collections. By default all collections - except "intermediates" are mutable. 
- capture_intermediates: If `True`, captures intermediate return values - of all Modules inside the "intermediates" collection. By default only - the return values of all `__call__` methods are stored. A function can - be passed to change the filter behavior. The filter function takes - the Module instance and method name and returns a bool indicating - whether the output of that method invocation should be stored. + treated as mutable: ``bool``: all/no collections are mutable. ``str``: The + name of a single mutable collection. ``list``: A list of names of mutable + collections. By default all collections except "intermediates" are + mutable. + capture_intermediates: If `True`, captures intermediate return values of all + Modules inside the "intermediates" collection. By default only the return + values of all `__call__` methods are stored. A function can be passed to + change the filter behavior. The filter function takes the Module instance + and method name and returns a bool indicating whether the output of that + method invocation should be stored. + Returns: The init function wrapping ``fn``. """ diff --git a/tests/linen/linen_module_test.py b/tests/linen/linen_module_test.py index 562bd03224..c0f1b211fc 100644 --- a/tests/linen/linen_module_test.py +++ b/tests/linen/linen_module_test.py @@ -88,6 +88,18 @@ def __call__(self, x): return y +class IdentityModule(nn.Module): + + def __call__(self, x): + return x + + +class RaisesModule(nn.Module): + + def __call__(self): + assert False + + class ModuleTest(absltest.TestCase): def test_init_module(self): @@ -2201,6 +2213,192 @@ def __call__(self, x): # take positional arg. It takes BaseLayer's default kwargs though. 
np.testing.assert_equal(ChildLayer(8)(np.ones(10)), -8 * np.ones(10)) + def test_intercept_methods(self): + mod = IdentityModule(parent=None) + x = jnp.ones([]) + call_count = [] + + def add_one_interceptor(f, args, kwargs, context): + call_count.append(None) + self.assertLen(dataclasses.fields(context), 3) + self.assertIs(context.module, mod) + self.assertEqual(context.method_name, '__call__') + self.assertEqual(context.orig_method(3), 3) + self.assertEqual(args, (x,)) + self.assertEmpty(kwargs) + y = f(*args, **kwargs) + return y + 1 + + y1 = mod(x) + with nn.intercept_methods(add_one_interceptor): + y2 = mod(x) + y3 = mod(x) + + self.assertLen(call_count, 1) + self.assertEqual(y1, 1) + self.assertEqual(y2, 2) + self.assertEqual(y3, 1) + + def test_intercept_methods_compact(self): + class CompactModule(nn.Module): + + @compact + def __call__(self, x): + return nn.Dense(2)(x) + + mod = CompactModule() + x = jnp.ones(shape=(1, 3)) + variables = mod.init(jax.random.PRNGKey(0), x) + call_modules = [] + + def log_interceptor(f, args, kwargs, context): + call_modules.append(context.module) + self.assertLen(dataclasses.fields(context), 3) + self.assertEqual(context.method_name, '__call__') + self.assertEqual(args, (x,)) + self.assertEmpty(kwargs) + return f(*args, **kwargs) + + with nn.intercept_methods(log_interceptor): + _ = mod.apply(variables, x) + + self.assertLen(call_modules, 2) + self.assertIsInstance(call_modules[0], CompactModule) + self.assertIsInstance(call_modules[1], nn.Dense) + + def test_intercept_methods_setup(self): + class SetupModule(nn.Module): + + def setup(self): + self.layer = nn.Dense(2) + + def __call__(self, x): + return self.layer(x) + + mod = SetupModule() + x = jnp.ones(shape=(1, 3)) + variables = mod.init(jax.random.PRNGKey(0), x) + call_modules = [] + log = [] + + def log_interceptor(f, args, kwargs, context): + call_modules.append(context.module) + log.append((context.method_name, args, kwargs)) + return f(*args, **kwargs) + + with 
nn.intercept_methods(log_interceptor): + _ = mod.apply(variables, x) + + self.assertLen(call_modules, 3) + self.assertIsInstance(call_modules[0], SetupModule) + self.assertIsInstance(call_modules[1], SetupModule) + self.assertIsInstance(call_modules[2], nn.Dense) + self.assertEqual( + log, [('setup', (), {}), ('__call__', (x,), {}), ('__call__', (x,), {})] + ) + + def test_intercept_methods_calling_underlying_optional(self): + def do_nothing_interceptor(f, args, kwargs, context): + del f, context + self.assertEmpty(args) + self.assertEmpty(kwargs) + + m = RaisesModule() + with nn.intercept_methods(do_nothing_interceptor): + m() + + with self.assertRaises(AssertionError): + m() + + with nn.intercept_methods(do_nothing_interceptor): + m() + + def test_intercept_methods_run_in_lifo_order(self): + def op_interceptor(op): + def _interceptor(f, args, kwargs, context): + del context + y = f(*args, **kwargs) + return op(y) + + return _interceptor + + mod = IdentityModule(parent=None) + x = 7 + with nn.intercept_methods( + op_interceptor(lambda a: a + 1) + ), nn.intercept_methods(op_interceptor(lambda a: a**2)): + y = mod(x) + + self.assertEqual(y, (x**2) + 1) + + with nn.intercept_methods( + op_interceptor(lambda a: a**2) + ), nn.intercept_methods(op_interceptor(lambda a: a + 1)): + y = mod(x) + + self.assertEqual(y, (x + 1) ** 2) + + def test_intercept_methods_subclasses(self): + class Foo(IdentityModule): + + def __call__(self, x): # pylint: disable=useless-parent-delegation + return super().__call__(x) + + class Bar(Foo): + + def __call__(self, x): # pylint: disable=useless-parent-delegation + return super().__call__(x) + + bar = Bar(parent=None) + x = jnp.ones([]) + called = [] + + def record_interceptor(f, args, kwargs, context): + called.append(None) + self.assertIs(context.module, bar) + self.assertEqual(context.method_name, '__call__') + self.assertEqual(args, (x,)) + self.assertEmpty(kwargs) + return f(*args, **kwargs) + + with 
nn.intercept_methods(record_interceptor): + bar(x) + + # Bar.__call__, Foo.__call__ and IdentityModule.__call__ + self.assertLen(called, 3) + + def test_intercept_methods_nested_module(self): + class Foo(nn.Module): + + def __call__(self, x): + return x + + class Bar(nn.Module): + sub: nn.Module + + def __call__(self, x): + return self.sub(x) + + foo = Foo() + bar = Bar(sub=foo) + x = jnp.ones([]) + called = [] + + def record_interceptor(f, args, kwargs, context): + called.append(context.module) + self.assertEqual(context.method_name, '__call__') + self.assertEqual(args, (x,)) + self.assertEmpty(kwargs) + return f(*args, **kwargs) + + with nn.intercept_methods(record_interceptor): + bar(x) + + # bar.__call__ and foo.__call__ + self.assertLen(called, 2) + self.assertIs(called[0], bar) + self.assertIs(called[1], foo) + class LeakTests(absltest.TestCase):