Skip to content

Commit

Permalink
Hide JAX's internal tracing state and update libraries to use limited…
Browse files Browse the repository at this point in the history
… trace-state-querying APIs as needed. This is prep work for stackless which will change those internals while preserving the API.

PiperOrigin-RevId: 677843398
  • Loading branch information
dougalm authored and Flax Authors committed Sep 23, 2024
1 parent b2277ab commit 3e1749a
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 12 deletions.
9 changes: 3 additions & 6 deletions flax/core/tracers.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,12 @@


def current_trace():
"""Returns the innermost Jax tracer."""
return jax.core.find_top_trace(())
"""Returns the current JAX state tracer."""
return jax.core.get_opaque_trace_state()


def trace_level(main):
"""Returns the level of the trace of -infinity if it is None."""
if main:
return main.level
return float('-inf')
return main


def check_trace_level(base_level):
Expand Down
12 changes: 6 additions & 6 deletions flax/nnx/tracers.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,14 @@
# Taken from flax/core/tracer.py 🏴‍☠️


from jax.core import MainTrace, thread_local_state
from jax.core import get_opaque_trace_state, OpaqueTraceState

from flax.nnx import reprlib


def current_jax_trace() -> MainTrace:
"""Returns the innermost Jax tracer."""
return thread_local_state.trace_state.trace_stack.dynamic
def current_jax_trace() -> OpaqueTraceState:
"""Returns the Jax tracing state."""
return get_opaque_trace_state()


class TraceState(reprlib.Representable):
Expand All @@ -36,7 +36,7 @@ def jax_trace(self):
return self._jax_trace

def is_valid(self) -> bool:
return self._jax_trace is current_jax_trace()
return self._jax_trace == current_jax_trace()

def __nnx_repr__(self):
yield reprlib.Object(f'{type(self).__name__}')
Expand All @@ -52,4 +52,4 @@ def __treescope_repr__(self, path, subtree_renderer):
)

def __eq__(self, other):
return isinstance(other, TraceState) and self._jax_trace is other._jax_trace
return isinstance(other, TraceState) and self._jax_trace == other._jax_trace

0 comments on commit 3e1749a

Please sign in to comment.