Skip to content

Commit

Permalink
Update libraries to use JAX's limited (and ill-advised) trace-state-q…
Browse files Browse the repository at this point in the history
…uerying APIs rather than depending on JAX's deeper internals, which are about to change.

PiperOrigin-RevId: 678351335
  • Loading branch information
dougalm authored and Flax Authors committed Sep 24, 2024
1 parent e3772b2 commit 6d2355d
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 19 deletions.
3 changes: 1 addition & 2 deletions flax/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,8 @@
from .tracers import (
check_trace_level as check_trace_level,
current_trace as current_trace,
trace_level as trace_level,
)

from flax.typing import (
Array as Array,
)
)
2 changes: 1 addition & 1 deletion flax/core/scope.py
Original file line number Diff line number Diff line change
Expand Up @@ -454,7 +454,7 @@ def __init__(
self.flags = freeze({} if flags is None else flags)

self._root = parent.root if parent else None
self.trace_level = tracers.trace_level(tracers.current_trace())
self.trace_level = tracers.current_trace()

self.rng_counters = {key: 0 for key in self.rngs}
self.reservations = collections.defaultdict(set)
Expand Down
19 changes: 9 additions & 10 deletions flax/core/tracers.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,18 +20,17 @@


def current_trace():
"""Returns the innermost Jax tracer."""
return jax.core.find_top_trace(())


def trace_level(main):
"""Returns the level of the trace of -infinity if it is None."""
if main:
return main.level
return float('-inf')
"""Returns the current JAX state tracer."""
if jax.__version_info__ <= (0, 4, 33):
top = jax.core.find_top_trace(())
if top:
return top.level
else:
return float('-inf')

return jax.core.get_opaque_trace_state(convention="flax")

def check_trace_level(base_level):
level = trace_level(current_trace())
level = current_trace()
if level != base_level:
raise errors.JaxTransformError()
22 changes: 16 additions & 6 deletions flax/nnx/tracers.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,18 @@
# Taken from flax/core/tracer.py 🏴‍☠️


from jax.core import MainTrace, thread_local_state
import jax
import jax.core
from jax.core import get_opaque_trace_state, OpaqueTraceState

from flax.nnx import reprlib


def current_jax_trace() -> MainTrace:
"""Returns the innermost Jax tracer."""
return thread_local_state.trace_state.trace_stack.dynamic
def current_jax_trace() -> OpaqueTraceState:
"""Returns the Jax tracing state."""
if jax.__version_info__ <= (0, 4, 33):
return jax.core.thread_local_state.trace_state.trace_stack.dynamic
return get_opaque_trace_state(convention="nnx")


class TraceState(reprlib.Representable):
Expand All @@ -36,7 +40,10 @@ def jax_trace(self):
return self._jax_trace

def is_valid(self) -> bool:
return self._jax_trace is current_jax_trace()
if jax.__version_info__ <= (0, 4, 33):
return self._jax_trace is current_jax_trace()

return self._jax_trace == current_jax_trace()

def __nnx_repr__(self):
yield reprlib.Object(f'{type(self).__name__}')
Expand All @@ -52,4 +59,7 @@ def __treescope_repr__(self, path, subtree_renderer):
)

def __eq__(self, other):
return isinstance(other, TraceState) and self._jax_trace is other._jax_trace
if jax.__version_info__ <= (0, 4, 33):
return isinstance(other, TraceState) and self._jax_trace is other._jax_trace

return isinstance(other, TraceState) and self._jax_trace == other._jax_trace

0 comments on commit 6d2355d

Please sign in to comment.