diff --git a/flax/linen/module.py b/flax/linen/module.py index da1ec8b256..fdab1cd93a 100644 --- a/flax/linen/module.py +++ b/flax/linen/module.py @@ -158,33 +158,15 @@ def _tabulate_context(): # Track parent relationship across Modules. # ----------------------------------------------------------------------------- -class _DynamicContext: +class _DynamicContext(threading.local): """Dynamic context.""" # TODO(marcvanzee): switch to using contextvars once minimum python version is # 3.7 def __init__(self): - self._thread_data = threading.local() - - @property - def module_stack(self): - if not hasattr(self._thread_data, 'module_stack'): - self._thread_data.module_stack = [None,] - return self._thread_data.module_stack - - @property - def capture_stack(self): - """Keeps track of the active capture_intermediates filter functions.""" - if not hasattr(self._thread_data, 'capture_stack'): - self._thread_data.capture_stack = [] - return self._thread_data.capture_stack - - @property - def call_info_stack(self) -> List[_CallInfoContext]: - """Keeps track of the active call_info_context.""" - if not hasattr(self._thread_data, 'call_info_stack'): - self._thread_data.call_info_stack = [] - return self._thread_data.call_info_stack + self.module_stack = [None,] + self.capture_stack = [] + self.call_info_stack = [] # The global context _context = _DynamicContext()