Skip to content

Commit

Permalink
Directly call original method if method interceptor stack is empty.
Browse files Browse the repository at this point in the history
So the run_interceptors call won't show up in the stack trace.

PiperOrigin-RevId: 556845700
  • Loading branch information
JXRiver authored and Flax Authors committed Aug 15, 2023
1 parent 3445296 commit bd36c53
Showing 1 changed file with 10 additions and 8 deletions.
18 changes: 10 additions & 8 deletions flax/linen/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,10 +313,10 @@ class ThreadLocalStack(threading.local):
def __init__(self):
self._storage = []

def push(self, elem: Any):
def push(self, elem: Any) -> None:
self._storage.append(elem)

def pop(self):
def pop(self) -> Any:
return self._storage.pop()

def __iter__(self) -> Iterator[Any]:
Expand Down Expand Up @@ -412,10 +412,7 @@ def run_interceptors(
*args,
**kwargs,
) -> Any:
"""Runs method interceptors or `orig_method`."""
if not _global_interceptor_stack:
return orig_method(module, *args, **kwargs)

"""Runs method interceptors."""
method_name = _get_fn_name(orig_method)
fun = functools.partial(orig_method, module)
context = InterceptorContext(module, method_name, fun)
Expand Down Expand Up @@ -1119,12 +1116,17 @@ def _call_wrapped_method(self, fun, args, kwargs):
call_index = _context.call_info_stack[-1].get_call_index(self)
scope_path = jax.tree_util.tree_map(_fix_path_part, self.scope.path)

if _global_interceptor_stack:
run_fun = functools.partial(run_interceptors, fun)
else:
run_fun = fun

# call method
if _use_named_call:
with jax.named_scope(_derive_profiling_name(self, fun)):
y = run_interceptors(fun, self, *args, **kwargs)
y = run_fun(self, *args, **kwargs)
else:
y = run_interceptors(fun, self, *args, **kwargs)
y = run_fun(self, *args, **kwargs)

if _context.capture_stack:
filter_fn = _context.capture_stack[-1]
Expand Down

0 comments on commit bd36c53

Please sign in to comment.