Skip to content

Commit

Permalink
Add a Haiku-like methods interceptor API.
Browse files Browse the repository at this point in the history
Note:
1. Methods decorated by nn.nowrap could not be intercepted.
2. Module dataclass fields and dunder methods, including __eq__, __repr__, __init__, __hash__, and __post__init__ could not be intercepted.
3. Module descriptors could not be intercepted.
4. InterceptorContext currently contains the module object, method name, and original method. More fields could be added in the future.
PiperOrigin-RevId: 549126112
  • Loading branch information
JXRiver authored and Flax Authors committed Aug 10, 2023
1 parent 3ea6381 commit 24f6b17
Show file tree
Hide file tree
Showing 3 changed files with 357 additions and 23 deletions.
23 changes: 12 additions & 11 deletions flax/linen/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,17 +84,18 @@
Embed as Embed,
)
from .module import (
Module as Module,
Variable as Variable,
apply as apply,
compact as compact,
disable_named_call as disable_named_call,
enable_named_call as enable_named_call,
init_with_output as init_with_output,
init as init,
merge_param as merge_param,
nowrap as nowrap,
override_named_call as override_named_call,
Module as Module,
Variable as Variable,
apply as apply,
compact as compact,
disable_named_call as disable_named_call,
enable_named_call as enable_named_call,
init_with_output as init_with_output,
init as init,
merge_param as merge_param,
nowrap as nowrap,
override_named_call as override_named_call,
intercept_methods as intercept_methods,
)
from .normalization import (
BatchNorm as BatchNorm,
Expand Down
163 changes: 151 additions & 12 deletions flax/linen/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,10 @@
Callable,
Dict,
Iterable,
Iterator,
List,
Mapping,
NamedTuple,
Optional,
Sequence,
Set,
Tuple,
Type,
TypeVar,
Expand Down Expand Up @@ -94,6 +92,10 @@


# pylint: disable=protected-access,attribute-defined-outside-init
def _get_fn_name(fn):
if isinstance(fn, functools.partial):
return _get_fn_name(fn.func)
return getattr(fn, '__name__', 'unnamed_function')


def _indent(x: str, num_spaces: int):
Expand Down Expand Up @@ -234,11 +236,6 @@ def __deepcopy__(self, memo):


def _derive_profiling_name(module, fn):
def _get_fn_name(fn):
if isinstance(fn, functools.partial):
return _get_fn_name(fn.func)
return fn.__name__

fn_name = _get_fn_name(fn)
method_suffix = f'.{fn_name}' if fn_name != '__call__' else ''
module_name = module.name or module.__class__.__name__
Expand Down Expand Up @@ -288,6 +285,148 @@ def override_named_call(enable: bool = True):
_use_named_call = use_named_call_prev


# Intercept module methods.
# -----------------------------------------------------------------------------
@dataclasses.dataclass(frozen=True)
class InterceptorContext:
"""Read only state showing the calling context for method interceptors.
Attributes:
module: The Module instance whose method is being called.
method_name: The name of the method being called on the module.
orig_method: The original method defined on the module. Calling it will
short circuit all other interceptors.
"""
module: 'Module'
method_name: str
orig_method: Callable[..., Any]


class ThreadLocalStack(threading.local):
"""Thread-local stack."""

def __init__(self):
self._storage = []

def push(self, elem: Any):
self._storage.append(elem)

def pop(self):
return self._storage.pop()

def __iter__(self) -> Iterator[Any]:
return iter(reversed(self._storage))

def __len__(self) -> int:
return len(self._storage)

def __repr__(self) -> str:
return f'{self.__class__.__name__}({self._storage})'


Args = Tuple[Any]
Kwargs = Dict[str, Any]
NextGetter = Callable[..., Any]
Interceptor = Callable[[NextGetter, Args, Kwargs, InterceptorContext], Any]
_global_interceptor_stack = ThreadLocalStack()


@contextlib.contextmanager
def intercept_methods(interceptor: Interceptor):
# pylint: disable=g-doc-return-or-yield
r"""Registers a new method interceptor.
Method interceptors allow you to (at a distance) intercept method calls to
modules. It works similarly to decorators. You could modify args/kwargs before
calling the underlying method and/or modify the result returning from calling
the underlying method. Or you could completely skip calling the underlying
method and decide to do something differently. For example::
>>> import flax.linen as nn
>>> import jax.numpy as jnp
>>> class Foo(nn.Module):
... def __call__(self, x):
... return x
>>> def my_interceptor1(next_fun, args, kwargs, context):
... print('calling my_interceptor1')
... return next_fun(*args, **kwargs)
>>> foo = Foo()
>>> with nn.intercept_methods(my_interceptor1):
... _ = foo(jnp.ones([1]))
calling my_interceptor1
You could also register multiple interceptors on the same method. Interceptors
will run in order. For example::
>>> def my_interceptor2(next_fun, args, kwargs, context):
... print('calling my_interceptor2')
... return next_fun(*args, **kwargs)
>>> with nn.intercept_methods(my_interceptor1), \
... nn.intercept_methods(my_interceptor2):
... _ = foo(jnp.ones([1]))
calling my_interceptor1
calling my_interceptor2
You could skip other interceptors by directly calling the
``context.orig_method``. For example::
>>> def my_interceptor3(next_fun, args, kwargs, context):
... print('calling my_interceptor3')
... return context.orig_method(*args, **kwargs)
>>> with nn.intercept_methods(my_interceptor3), \
... nn.intercept_methods(my_interceptor1), \
... nn.intercept_methods(my_interceptor2):
... _ = foo(jnp.ones([1]))
calling my_interceptor3
The following methods couldn't be intercepted:
1. Methods decoratored with ``nn.nowrap``.
2. Dunder methods including '__eq__', '__repr__', '__init__', '__hash__',
and '__post_init__'.
3. Module dataclass fields.
4. Module descriptors.
Args:
interceptor: A method interceptor.
"""
_global_interceptor_stack.push(interceptor)
try:
yield
finally:
assert _global_interceptor_stack.pop() is interceptor


def run_interceptors(
orig_method: Callable[..., Any],
module: 'Module',
*args,
**kwargs,
) -> Any:
"""Runs method interceptors or `orig_method`."""
if not _global_interceptor_stack:
return orig_method(module, *args, **kwargs)

method_name = _get_fn_name(orig_method)
fun = functools.partial(orig_method, module)
context = InterceptorContext(module, method_name, fun)
def wrap_interceptor(interceptor, fun):
"""Wraps `fun` with `interceptor`."""
@functools.wraps(fun)
def wrapped(*args, **kwargs):
return interceptor(fun, args, kwargs, context)
return wrapped

# Wraps interceptors around the original method. The innermost interceptor is
# the last one added and directly wrapped around the original bound method.
for interceptor in _global_interceptor_stack:
fun = wrap_interceptor(interceptor, fun)
return fun(*args, **kwargs)

# Utilities for pytrees of Modules defined inside setup()
# -----------------------------------------------------------------------------

Expand Down Expand Up @@ -920,7 +1059,7 @@ def _wrap_module_attributes(cls):
return cls

def _call_wrapped_method(self, fun, args, kwargs):
""" "Calls a wrapped method.
"""Calls a wrapped method.
This function is responsible for setting up the thread local state
correctly before calling the method and cleaning up afterwards.
Expand All @@ -936,7 +1075,7 @@ def _call_wrapped_method(self, fun, args, kwargs):
The results of calling ``fun``.
"""
is_compact_method = hasattr(fun, 'compact')
fun_name = getattr(fun, '__name__', 'unnamed_function')
fun_name = _get_fn_name(fun)
is_setup_method = fun_name == 'setup'
add_call_info = not is_setup_method and len(_context.call_info_stack) > 0
# We lazily call setup() only when needed.
Expand Down Expand Up @@ -964,9 +1103,9 @@ def _call_wrapped_method(self, fun, args, kwargs):
# call method
if _use_named_call:
with jax.named_scope(_derive_profiling_name(self, fun)):
y = fun(self, *args, **kwargs)
y = run_interceptors(fun, self, *args, **kwargs)
else:
y = fun(self, *args, **kwargs)
y = run_interceptors(fun, self, *args, **kwargs)

if _context.capture_stack:
filter_fn = _context.capture_stack[-1]
Expand Down
Loading

0 comments on commit 24f6b17

Please sign in to comment.