From 25d1d5dc6487259edfc6ccb0718245fa47c5b36c Mon Sep 17 00:00:00 2001 From: Dougal Maclaurin Date: Mon, 23 Sep 2024 10:14:31 -0700 Subject: [PATCH] Hide JAX's internal tracing state and update libraries to use limited trace-state-querying APIs as needed. This is prep work for stackless which will change those internals while preserving the API. PiperOrigin-RevId: 677843398 --- flax/core/tracers.py | 9 +++------ flax/nnx/tracers.py | 12 ++++++------ 2 files changed, 9 insertions(+), 12 deletions(-) diff --git a/flax/core/tracers.py b/flax/core/tracers.py index 9d8472bdc7..c30d6bd708 100644 --- a/flax/core/tracers.py +++ b/flax/core/tracers.py @@ -20,15 +20,12 @@ def current_trace(): - """Returns the innermost Jax tracer.""" - return jax.core.find_top_trace(()) + """Returns the current JAX state tracer.""" + return jax.extend.core.get_opaque_trace_state(convention="flax") def trace_level(main): - """Returns the level of the trace of -infinity if it is None.""" - if main: - return main.level - return float('-inf') + return main def check_trace_level(base_level): diff --git a/flax/nnx/tracers.py b/flax/nnx/tracers.py index 3db066376b..3555472b2a 100644 --- a/flax/nnx/tracers.py +++ b/flax/nnx/tracers.py @@ -15,14 +15,14 @@ # Taken from flax/core/tracer.py 🏴‍☠️ -from jax.core import MainTrace, thread_local_state +from jax.extend.core import get_opaque_trace_state, OpaqueTraceState from flax.nnx import reprlib -def current_jax_trace() -> MainTrace: - """Returns the innermost Jax tracer.""" - return thread_local_state.trace_state.trace_stack.dynamic +def current_jax_trace() -> OpaqueTraceState: + """Returns the Jax tracing state.""" + return get_opaque_trace_state(convention="nnx") class TraceState(reprlib.Representable): @@ -36,7 +36,7 @@ def jax_trace(self): return self._jax_trace def is_valid(self) -> bool: - return self._jax_trace is current_jax_trace() + return self._jax_trace == current_jax_trace() def __nnx_repr__(self): yield reprlib.Object(f'{type(self).__name__}') @@ -52,4 +52,4 @@ def __treescope_repr__(self, path, subtree_renderer): ) def __eq__(self, other): - return isinstance(other, TraceState) and self._jax_trace is other._jax_trace + return isinstance(other, TraceState) and self._jax_trace == other._jax_trace