Stackless [WIP] #23299

Draft · wants to merge 204 commits into base: main

Commits (204)
cf94b57
controlled burn
dougalm Jul 8, 2024
1ae1fe9
CoreTest.test_jit passing
dougalm Jul 9, 2024
2a930e0
CoreTest.test_jvp passing
dougalm Jul 9, 2024
06d48fe
Most of the rest of core tests passing
dougalm Jul 9, 2024
ec60fb0
more
dougalm Jul 9, 2024
5825377
vmap
dougalm Jul 10, 2024
c806225
More vmap
dougalm Jul 10, 2024
1f960d8
Some custom vjp/jvp
dougalm Jul 15, 2024
f848128
Batching custom vjp
dougalm Jul 16, 2024
6099f29
More custom vjp tests
dougalm Jul 16, 2024
38fae8f
more custom jvp
dougalm Jul 16, 2024
fdb6975
some control flow tests
dougalm Jul 17, 2024
bcedd70
Remove custom bind
dougalm Jul 17, 2024
8f0d867
more control flow tests
dougalm Jul 17, 2024
bfc9f9e
WIP - mattjj to fix :)
dougalm Jul 19, 2024
e86ad4a
fix all_gather on constants
mattjj Jul 20, 2024
a3ddef6
fix up some confusing logic in all-reduce primitives
mattjj Jul 21, 2024
a16ff7b
update all_to_all batcher to take axis_data
mattjj Jul 21, 2024
8444027
one failing test left!
mattjj Jul 21, 2024
57a4bac
progress on control_flow_test.py
mattjj Jul 24, 2024
ecbe339
attrs
dougalm Jul 24, 2024
87a7c2b
checkify
dougalm Jul 24, 2024
a687f16
more custom jvp
dougalm Jul 24, 2024
2483cda
Merge remote-tracking branch 'origin/main' into stackless
mattjj Jul 25, 2024
f3aaa14
fix
mattjj Jul 25, 2024
21627e9
more
dougalm Jul 25, 2024
b45ca67
short-circuit psum path and some pmap tests passing
dougalm Jul 26, 2024
753feb8
more pmap
dougalm Jul 26, 2024
22f6f63
Merge remote-tracking branch 'origin/main' into stackless
mattjj Jul 27, 2024
681f70e
fix bug introduced in merge
mattjj Jul 27, 2024
00f4495
fix custom_solve batching rule with axis_data
mattjj Jul 27, 2024
e23b8b1
final style core_test! test_jit_43 done!!!
mattjj Jul 29, 2024
f8194e2
DynamicJaxprTrace.process_call should call to_jaxpr_tracer
mattjj Jul 30, 2024
9777c27
fix refs to DynamicJaxprTrace.instantiate_const
dougalm Jul 30, 2024
39ad05c
Add a `take_current_trace ctx manager to mostly replace find_cur_trac…
dougalm Jul 30, 2024
5cfb8e1
started working on shard_map, one test working
mattjj Jul 30, 2024
7dcd2b2
Merge branch 'stackless' of github.com:google/jax into stackless
dougalm Jul 30, 2024
7c8010a
Shard map tests
dougalm Jul 30, 2024
1172dd0
batching tests
dougalm Jul 31, 2024
67a15ff
Avoid extending the axis mesh twice (todo: figure out a proper story …
dougalm Aug 1, 2024
9c6ca68
Custom vmap tests
dougalm Aug 5, 2024
ae0b448
custom transpose tests
dougalm Aug 5, 2024
b75ac75
more custom vjp/jvp rules
dougalm Aug 5, 2024
18af559
more custom ad
dougalm Aug 5, 2024
db38066
Avoid adding no_axis_name to env
dougalm Aug 5, 2024
612cd17
reinstate leak checker
dougalm Aug 6, 2024
450282c
avoid false positive tracer leak errors due to frame.constid_to_tracer
dougalm Aug 6, 2024
038217f
jet
dougalm Aug 6, 2024
778b3bb
sparsify tests
dougalm Aug 6, 2024
b937243
pmap tests
dougalm Aug 6, 2024
3bf4a2a
more pmap
dougalm Aug 6, 2024
cf7da64
pallas call tests
dougalm Aug 7, 2024
3ab29f5
pmap tests
dougalm Aug 7, 2024
0136695
more pmap
dougalm Aug 7, 2024
b7244b2
callback tests
dougalm Aug 8, 2024
4a624d8
fix some pmap tests, 161 -> 126 pmap_test failures
mattjj Aug 8, 2024
25e105b
remove breakpoints (sorry)
mattjj Aug 8, 2024
9b428ae
sparse transform infinite recursion
dougalm Aug 9, 2024
babe549
use correct tangent space dtype for zero in sparse array constructors
dougalm Aug 9, 2024
9ed34db
Avoid cyclic refs so we don't have to run gc during trace leak checking
dougalm Aug 9, 2024
2f8c03f
fix all non-eager pmap tests
mattjj Aug 9, 2024
1320052
fix pmap bind logic, down to 24 failing pmap tests
mattjj Aug 10, 2024
44087f5
fix eager pmap axis_index
mattjj Aug 10, 2024
af30a2a
fix nested eager pmap
mattjj Aug 10, 2024
bbad195
more eager pmap fixes
mattjj Aug 10, 2024
59ea325
fix last of the eager pmap tests
mattjj Aug 10, 2024
f382187
skip some for_loop tests that were timing out
dougalm Aug 11, 2024
5ab7e3c
Extra short-circuit for not-implemented batching rules
dougalm Aug 12, 2024
28e9383
Add a special case for DimExpr
dougalm Aug 12, 2024
e83bb5b
set current trace in pjit dynamic shapes staging rule
dougalm Aug 12, 2024
4bf0fe1
more dynamic_api_test
dougalm Aug 12, 2024
5a40ee5
reset trace state between tests
dougalm Aug 12, 2024
0cf761c
pallas tests
dougalm Aug 13, 2024
f52aa26
skipping infeed tests
dougalm Aug 13, 2024
3120391
xla_computation
dougalm Aug 13, 2024
26e5fb5
dead trace invalidation
dougalm Aug 13, 2024
3442473
Reset name stack when tracing out a jaxpr
dougalm Aug 14, 2024
ccf9877
Add the eval context back during callbacks
dougalm Aug 14, 2024
296b299
Custom vjp/jvp closing over stuff
dougalm Aug 14, 2024
fa29f36
pallas tests
dougalm Aug 14, 2024
bdd18c4
Sparse transform test
dougalm Aug 14, 2024
99d905f
mutable arrays
dougalm Aug 14, 2024
30376dd
test fix
dougalm Aug 14, 2024
77c09f1
effects test
dougalm Aug 14, 2024
b2edd22
Check whether a tracer's trace is invalidated
dougalm Aug 26, 2024
aeae13d
tweak escaped tracer tests
dougalm Aug 26, 2024
16af376
Remove CustomJVPException (only relevant for post_process_call)
dougalm Aug 26, 2024
10f71ca
more tweaks to escaped tracer tests
dougalm Aug 26, 2024
9015af1
Use primal-type-to-tangent-type mapping when creating symbolic zeros …
dougalm Aug 28, 2024
6b9b191
wip
dougalm Aug 28, 2024
f4200ba
Merge branch 'main' into stackless
dougalm Aug 28, 2024
c1f871a
fix aval error
dougalm Aug 28, 2024
4944e2f
update remat opt
dougalm Aug 28, 2024
2d854f5
Use correct tangent space avals in custom jvp thunk
dougalm Aug 29, 2024
5d427aa
move promotion-to-float outside of logaddexp primitive to avoid float…
dougalm Aug 29, 2024
df488b0
update attrs to use take_current_trace
dougalm Aug 30, 2024
3a56e67
Organize tracing context, hoping to get caching working again
dougalm Sep 3, 2024
eb64ddc
Add axis size to cache key
dougalm Sep 3, 2024
4ede817
Use weak refs to traces in cache key
dougalm Sep 3, 2024
6ea0c68
Fix
dougalm Sep 3, 2024
7354526
fixes
dougalm Sep 3, 2024
1a32b9a
set eval trace for key reuse checker
dougalm Sep 3, 2024
67c38cf
set trace for dimension_as_value
dougalm Sep 3, 2024
420f40d
fix pallas primitive custom bind logic
dougalm Sep 3, 2024
46af63b
Fix attribute errors
dougalm Sep 3, 2024
4167de0
Add back partial_eval._memoize with a patently incorrect new implemen…
dougalm Sep 4, 2024
12cd781
Use an empty context when forcing jaxpr thunks
dougalm Sep 4, 2024
2566b6a
fix invalidation negation
dougalm Sep 4, 2024
acd4bb7
unexpected tracer error messages
dougalm Sep 4, 2024
49fc277
small fixes
dougalm Sep 4, 2024
5351c62
tangent dtype conversion in test
dougalm Sep 4, 2024
34e1fe9
Add back a deleted check in `bind`
dougalm Sep 4, 2024
7159909
revert - need to handle function arguments
dougalm Sep 4, 2024
e89afc8
don't tempt fate by creating jvp tracers with symbolic zero tangents …
dougalm Sep 4, 2024
64d10d7
more float0
dougalm Sep 4, 2024
2e96464
skip a test
dougalm Sep 5, 2024
b237bd2
Merge branch 'main' into stackless
dougalm Sep 5, 2024
f26c578
Add back _convert_element_type custom bind
dougalm Sep 5, 2024
b9f1f23
deleted xmap cruft
dougalm Sep 5, 2024
69cf62f
more deletion
dougalm Sep 5, 2024
2abe624
delete whitespace
dougalm Sep 5, 2024
d0a28a0
Fix pytype errors
dougalm Sep 6, 2024
500cb43
Put `TraceTag` in core
dougalm Sep 6, 2024
0155827
fix
dougalm Sep 6, 2024
2d03ea6
more
dougalm Sep 6, 2024
75be2e3
process_custom_jvp for rewrite trace
dougalm Sep 6, 2024
f22d8c0
shard map rewrite custom_vjp
dougalm Sep 6, 2024
9a8beab
shard_map batching
dougalm Sep 6, 2024
d242e81
more shard map rewrites
dougalm Sep 6, 2024
6926716
even more rewrite
dougalm Sep 6, 2024
e32c5b9
axis env
dougalm Sep 7, 2024
7e5a6b4
axis index custom bind
dougalm Sep 9, 2024
4f2fc87
ptype errors
dougalm Sep 9, 2024
ed32420
lint errors
dougalm Sep 9, 2024
9dee9f1
fix to lint fix
dougalm Sep 9, 2024
49e5c57
revert core test monkey patch
dougalm Sep 9, 2024
ceefd62
remove breakpoints
dougalm Sep 9, 2024
a591af0
tweak docstring for test
dougalm Sep 9, 2024
aa5bcbd
fix pytype error
dougalm Sep 9, 2024
bacd510
keep mypy happy by removing type:ignore ?
dougalm Sep 9, 2024
ee130dd
Merge branch 'main' into stackless
dougalm Sep 9, 2024
8dd89ac
helper method for overriding bind_with_trace
dougalm Sep 10, 2024
873550a
comment out jax2tf stuff that's not ready yet to satisfy mypy
dougalm Sep 10, 2024
640ea0f
reword a line to try to make mypy happy?
dougalm Sep 10, 2024
788c37a
check for shapedarray
dougalm Sep 10, 2024
c83e492
mypy
dougalm Sep 10, 2024
758879d
jax2tf
dougalm Sep 10, 2024
fe49a2c
Merge branch 'main' into stackless
dougalm Sep 10, 2024
96d751b
fix all_to_all batching rule
dougalm Sep 10, 2024
612e941
tf tracing
dougalm Sep 11, 2024
d3bf703
missed one
dougalm Sep 11, 2024
0f52c8a
trace tag hash hack
dougalm Sep 12, 2024
2af9167
Merge branch 'main' into stackless
dougalm Sep 13, 2024
6b674c2
lint
dougalm Sep 13, 2024
6e8828d
skip test
dougalm Sep 13, 2024
885b3a0
Add leak checker to batch tracing
dougalm Sep 13, 2024
874fa26
oof leak checker false positives. this work?
dougalm Sep 13, 2024
3676cc9
skip const-forwarding test
dougalm Sep 13, 2024
fe3e6b4
tweak a test
dougalm Sep 13, 2024
b7cfa55
use tangent type in custom lin
dougalm Sep 14, 2024
8f2c429
Avoid blowing away custom jvp rule during partial eval
dougalm Sep 16, 2024
65e0922
Merge branch 'main' into stackless
dougalm Sep 16, 2024
bef2066
unused import
dougalm Sep 16, 2024
2604e50
Rename `at_least_vspace` -> `to_tangent_aval`, `from_value` -> `from_…
dougalm Sep 16, 2024
227e50f
Fix return type
dougalm Sep 16, 2024
62b3ac1
Use tangent types in appropriate places
dougalm Sep 16, 2024
be0cc23
fix tests (remove replace/recast float0)
dougalm Sep 16, 2024
00abb18
missed one
dougalm Sep 16, 2024
c03df59
fix another dtype bug in tests
dougalm Sep 16, 2024
99760a0
Maybe we can put this check back (but really I just want to trigger c…
dougalm Sep 17, 2024
a5c4459
revert
dougalm Sep 17, 2024
7d5c660
Merge branch 'tangent-avals' into stackless
dougalm Sep 17, 2024
b8c17ca
docstring tweak for test
dougalm Sep 17, 2024
c80594f
tweak xla_metadata_test to avoid using impl path directly
dougalm Sep 17, 2024
a4117e2
Add back NamedAxisEffect
dougalm Sep 18, 2024
b17e5af
fix
dougalm Sep 18, 2024
1b8d743
named axis reversion fixes
dougalm Sep 18, 2024
13139fe
more fixes
dougalm Sep 18, 2024
8b72dae
Merge branch 'main' into stackless
dougalm Sep 18, 2024
f9dd1e6
missed a filter_named_axis_effects
dougalm Sep 19, 2024
a131df7
remove dead code
dougalm Sep 19, 2024
cc5c9d6
Merge branch 'main' into stackless
dougalm Sep 19, 2024
ca6e762
Merge branch 'main' into stackless
dougalm Sep 19, 2024
f9b3c49
fix bad merge
dougalm Sep 20, 2024
13eb6e2
rearrange to shrink diff a bit
dougalm Sep 20, 2024
d5f6166
more diff tweaks
dougalm Sep 20, 2024
efd6c21
more minor stuff
dougalm Sep 20, 2024
3d3dd93
fix
dougalm Sep 20, 2024
e867e66
Update batching axis size as you go under a shard map
dougalm Sep 21, 2024
a32ab6a
Merge branch 'main' into stackless
dougalm Sep 21, 2024
3dffee1
more batching cleanup
dougalm Sep 21, 2024
2df3c20
more batching
dougalm Sep 21, 2024
d8303a4
more batching
dougalm Sep 21, 2024
abf4d98
fix
dougalm Sep 21, 2024
213733d
jax2tf fix
dougalm Sep 21, 2024
44709e0
Merge branch 'main' into stackless
dougalm Sep 23, 2024
74d21ec
small backwards compat shims
dougalm Sep 23, 2024
b0533c3
fix merge
dougalm Sep 23, 2024
6d068f2
Merge branch 'main' into stackless
dougalm Sep 24, 2024
63523c7
fix bad merge
dougalm Sep 25, 2024
5df8920
more fix
dougalm Sep 25, 2024
0b24498
implement unsafe trace-querying APIs
dougalm Sep 25, 2024
23d3e01
lint
dougalm Sep 25, 2024
6b90f6d
Merge branch 'main' into stackless
dougalm Sep 25, 2024
1 change: 0 additions & 1 deletion .github/workflows/ci-build.yaml
@@ -210,4 +210,3 @@ jobs:
echo "JAX_ENABLE_CHECKS=$JAX_ENABLE_CHECKS"
echo "JAX_SKIP_SLOW_TESTS=$JAX_SKIP_SLOW_TESTS"
pytest -n auto --tb=short --maxfail=20 jax/experimental/jax2tf/tests/jax2tf_test.py

14 changes: 6 additions & 8 deletions jax/_src/ad_checkpoint.py
@@ -496,7 +496,8 @@ def print_saved_residuals(f, *args, **kwargs):
@remat_p.def_impl
def remat_impl(*args, jaxpr, prevent_cse, differentiated, policy):
del prevent_cse, differentiated, policy # Unused.
return core.eval_jaxpr(jaxpr, (), *args)
with core.concrete_eval():
return core.eval_jaxpr(jaxpr, (), *args)

@remat_p.def_effectful_abstract_eval
def remat_abstract_eval(*args, jaxpr, prevent_cse, differentiated, policy):
@@ -701,20 +702,17 @@ def transposed(*args_flat):
transposed_jaxpr = core.ClosedJaxpr(transposed_jaxpr_, consts)
return transposed_jaxpr, cell.in_cts_zero # pytype: disable=attribute-error

def remat_vmap(spmd_axis_name, axis_size, axis_name, main_type, args, dims, *,
jaxpr, **params):
def remat_vmap(axis_data, args, dims, *, jaxpr, **params):
assert not jaxpr.constvars
jaxpr_batched_, out_batched = batching.batch_jaxpr_axes(
pe.close_jaxpr(jaxpr), axis_size, dims,
[batching.zero_if_mapped] * len(jaxpr.outvars),
axis_name=axis_name, spmd_axis_name=spmd_axis_name, main_type=main_type)
pe.close_jaxpr(jaxpr), axis_data, dims,
[batching.zero_if_mapped] * len(jaxpr.outvars))
jaxpr_batched, consts = jaxpr_batched_.jaxpr, jaxpr_batched_.consts
if consts:
jaxpr_batched = pe.convert_constvars_jaxpr(jaxpr_batched)
out_dims = [0 if b else None for b in out_batched]
return remat_p.bind(*consts, *args, jaxpr=jaxpr_batched, **params), out_dims
batching.axis_primitive_batchers[remat_p] = partial(remat_vmap, None)
batching.spmd_axis_primitive_batchers[remat_p] = remat_vmap
batching.fancy_primitive_batchers[remat_p] = remat_vmap

# TODO(mattjj,sharadmv): de-duplicate with pe.dce_jaxpr_call_rule
def remat_dce(used_outputs: list[bool], eqn: core.JaxprEqn
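
As context for the remat_vmap change above: the old batching rules took `spmd_axis_name`, `axis_size`, `axis_name`, and `main_type` as separate leading arguments and were registered in two tables (`axis_primitive_batchers` and `spmd_axis_primitive_batchers`); the new convention bundles them into one `axis_data` value and a single registry. A minimal sketch of a rule written against the new convention (the primitive `my_p` and its logic are hypothetical, and the exact `AxisData` fields are an assumption based on the constructor call in jax/_src/api.py below):

```python
# Sketch only: shape of an axis_data-style batching rule under this PR.
# `my_p` is a hypothetical primitive; real rules do primitive-specific work.
from jax._src import core
from jax._src.interpreters import batching

my_p = core.Primitive("my_prim")

def my_prim_batcher(axis_data, args, dims, **params):
  # axis_data bundles what the old rules received separately
  # (axis name, axis size, spmd_axis_name), so one rule covers both the
  # plain-vmap and spmd cases.
  x, = args
  d, = dims
  out = my_p.bind(x, **params)  # primitive-specific batching logic goes here
  return out, d

# One registry replaces axis_primitive_batchers and spmd_axis_primitive_batchers:
batching.fancy_primitive_batchers[my_p] = my_prim_batcher
```
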
3 changes: 2 additions & 1 deletion jax/_src/ad_util.py
@@ -38,7 +38,8 @@ def add_jaxvals(x: ArrayLike, y: ArrayLike) -> Array:

@add_jaxvals_p.def_impl
def add_impl(x, y):
return raw_jaxval_adders[type(x)](x, y)
with core.set_current_trace(core.EvalTrace()):
return raw_jaxval_adders[type(x)](x, y)
raw_jaxval_adders = {} # type: ignore

@add_jaxvals_p.def_abstract_eval
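
The add_impl change above shows a pattern that recurs throughout this PR (shard_device_array in jax/_src/array.py below uses it too): impl rules explicitly install the eval trace before calling into concrete, registry-based implementations. A minimal sketch of the idiom, assuming `core.set_current_trace` and `core.EvalTrace` behave as the diff suggests (`my_impl_table` is a hypothetical stand-in for a registry like `raw_jaxval_adders`):

```python
# Sketch of the impl-rule idiom used in add_impl above.
from jax._src import core

my_impl_table = {}  # hypothetical registry, analogous to raw_jaxval_adders

def my_add_impl(x, y):
  # Install the eval trace so that any jax ops issued by the concrete
  # implementation evaluate eagerly instead of being captured by an
  # enclosing trace.
  with core.set_current_trace(core.EvalTrace()):
    return my_impl_table[type(x)](x, y)
```
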
32 changes: 13 additions & 19 deletions jax/_src/api.py
@@ -34,7 +34,7 @@
import weakref

import numpy as np
from contextlib import contextmanager, ExitStack
from contextlib import contextmanager

from jax._src import linear_util as lu
from jax._src import stages
@@ -983,10 +983,10 @@ def vmap_f(*args, **kwargs):
axis_size_ = (axis_size if axis_size is not None else
_mapped_axis_size(fun, in_tree, args_flat, in_axes_flat, "vmap"))
try:
axis_data = batching.AxisData(axis_name, axis_size_, spmd_axis_name)
out_flat = batching.batch(
flat_fun, axis_name, axis_size_, in_axes_flat,
lambda: flatten_axes("vmap out_axes", out_tree(), out_axes),
spmd_axis_name=spmd_axis_name
flat_fun, axis_data, in_axes_flat,
lambda: flatten_axes("vmap out_axes", out_tree(), out_axes)
).call_wrapped(*args_flat)
except batching.SpecMatchError as e:
out_axes_flat = flatten_axes("vmap out_axes", out_tree(), out_axes)
@@ -1532,16 +1532,13 @@ def cache_miss(*args, **kwargs):
is_explicit_global_axis_size=p.is_explicit_global_axis_size,
)

map_bind_continuation, top_trace, fun_, tracers, params = (
core.map_bind_with_continuation(pxla.xla_pmap_p, p.flat_fun,
*p.flat_args, **params))
execute: Callable | None = None
if isinstance(top_trace, core.EvalTrace):
execute = pxla.xla_pmap_impl_lazy(fun_, *tracers, **params)
out = map_bind_continuation(execute(*tracers))
else:
out = map_bind_continuation(
pxla.xla_pmap_p.process(top_trace, fun_, tracers, params))
with core.take_current_trace() as trace:
if isinstance(trace, core.EvalTrace):
execute = pxla.xla_pmap_impl_lazy(p.flat_fun, *p.flat_args, **params)
out = execute(*p.flat_args)
else:
out = pxla.xla_pmap_p.bind_with_trace(trace, (p.flat_fun, *p.flat_args), params)

out_tree, out_flat = p.out_tree, out
out_pytree_def = out_tree()
@@ -1788,7 +1785,7 @@ def linearize(fun: Callable, *primals, has_aux: bool = False
>>> def f(x): return 3. * jnp.sin(x) + jnp.cos(x / 2.)
...
>>> jax.jvp(f, (2.,), (3.,))
(Array(3.26819, dtype=float32, weak_type=True), Array(-5.00753, dtype=float32, weak_type=True))
(Array(3.2681944, dtype=float32, weak_type=True), Array(-5.007528, dtype=float32, weak_type=True))
>>> y, f_jvp = jax.linearize(f, 2.)
>>> print(y)
3.2681944
@@ -2146,11 +2143,8 @@ def make_jaxpr(
@wraps(fun)
@api_boundary
def make_jaxpr_f(*args, **kwargs):
with ExitStack() as stack:
for axis_name, size in axis_env or []:
stack.enter_context(core.extend_axis_env(axis_name, size, None))
traced = jit(fun, static_argnums=static_argnums,
abstracted_axes=abstracted_axes).trace(*args, **kwargs)
traced = jit(fun, static_argnums=static_argnums,
abstracted_axes=abstracted_axes).trace(*args, **kwargs)
# `jit` converts tracers in consts to args but that breaks the semantics of
# `make_jaxpr`. Hence convert the tracers in args back to consts in jaxpr.
if traced._num_consts:
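
The cache_miss change above is a good illustration of the stackless dispatch pattern: rather than `map_bind_with_continuation` locating a "top trace" on a global stack, the caller asks for the current trace and either runs an eager fast path or binds against that trace directly. A rough sketch of the idiom, assuming `core.take_current_trace` and `bind_with_trace` work as shown in the diff (the primitive, impl function, and arguments here are hypothetical):

```python
# Sketch of the stackless dispatch idiom from cache_miss above.
from jax._src import core

def dispatch_map_primitive(map_p, fun, flat_args, params, eager_impl):
  with core.take_current_trace() as trace:
    if isinstance(trace, core.EvalTrace):
      # Eager case: skip binding and run the compiled/impl path directly.
      return eager_impl(fun, *flat_args, **params)
    # Otherwise hand the call to whichever trace is active
    # (jit staging, vmap, autodiff, ...) via bind_with_trace.
    return map_p.bind_with_trace(trace, (fun, *flat_args), params)
```
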
3 changes: 2 additions & 1 deletion jax/_src/array.py
@@ -1088,7 +1088,8 @@ def shard_device_array(x, devices, indices, sharding):
if sharding.is_fully_replicated:
shards = [x] * len(devices)
else:
shards = x._multi_slice(start_indices, limit_indices, removed_dims)
with core.set_current_trace(core.EvalTrace()):
shards = x._multi_slice(start_indices, limit_indices, removed_dims)
aval = api_util.shaped_abstractify(x)
return pxla.batched_device_put(aval, sharding, shards, devices)

7 changes: 4 additions & 3 deletions jax/_src/callback.py
@@ -224,7 +224,8 @@ def pure_callback_lowering(
ctx, *args, callback: _FlatCallback, sharding: SingleDeviceSharding | None, **params
):
def _callback(*flat_args):
return tuple(
with core.concrete_eval():
return tuple(
pure_callback_impl(
*flat_args,
callback=callback,
@@ -429,7 +430,8 @@ def _batch_fun(batched_args):

def io_callback_lowering(ctx, *args, callback, sharding, ordered, **params):
def _callback(*flat_args):
return tuple(
with core.concrete_eval():
return tuple(
io_callback_impl(
*flat_args,
callback=callback,
@@ -511,7 +513,6 @@ def io_callback(
flat_shape_dtypes, out_tree = tree_util.tree_flatten(result_shape_dtypes)
flat_result_avals = map(lambda x: core.ShapedArray(x.shape, x.dtype),
flat_shape_dtypes)
flat_args = map(core.raise_as_much_as_possible, flat_args)
out_flat = io_callback_p.bind(
*flat_args,
callback=_FlatCallback(callback, in_tree),
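
Both callback lowerings above (like remat_impl in jax/_src/ad_checkpoint.py earlier) wrap their Python-side execution in `core.concrete_eval()`. A minimal sketch of that idiom, assuming `concrete_eval` installs a concrete evaluation context as the diff implies (`run_user_callback` is a hypothetical host-side function):

```python
# Sketch of the concrete_eval idiom used in the callback lowerings above.
from jax._src import core

def _callback_wrapper(run_user_callback, *flat_args):
  # concrete_eval makes any jax operations issued while running the user's
  # callback execute concretely rather than being staged into a trace.
  with core.concrete_eval():
    return tuple(run_user_callback(*flat_args))
```
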
4 changes: 2 additions & 2 deletions jax/_src/checkify.py
@@ -952,7 +952,7 @@ def shard_map_error_check(

if not isinstance(jaxpr, core.ClosedJaxpr):
jaxpr = core.ClosedJaxpr(jaxpr, ())
with core.extend_axis_env_nd(mesh.shape.items()):
with core.extend_axis_env(mesh.shape.items()):
# jaxpr to checked_jaxpr
checked_jaxpr, out_tree, _ = jaxpr_to_checkify_jaxpr(
jaxpr, enabled_errors, err_tree, *in_avals
@@ -966,7 +966,7 @@ def expand_errors_leading_dim(*xs):
errs = [lax.expand_dims(e, [0]) for e in errs]
return *errs, *outs

with core.extend_axis_env_nd(mesh.shape.items()):
with core.extend_axis_env(mesh.shape.items()):
jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(
expand_errors_leading_dim, checked_jaxpr.in_avals
)
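
The checkify change above swaps `core.extend_axis_env_nd` for `core.extend_axis_env`, which takes an iterable of `(axis_name, size)` pairs such as `mesh.shape.items()`. A minimal usage sketch under that assumption (the axis names and sizes here are made up):

```python
# Sketch of extending the axis environment with several named axes at once,
# mirroring the shard_map_error_check change above.
from jax._src import core

axis_sizes = {"x": 2, "y": 4}  # hypothetical stand-in for mesh.shape

with core.extend_axis_env(axis_sizes.items()):
  # Tracing performed here sees named axes "x" (size 2) and "y" (size 4),
  # so collectives over those names can be staged out.
  pass
```
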
2 changes: 1 addition & 1 deletion jax/_src/config.py
@@ -858,7 +858,7 @@ class _ThreadLocalExtraJitContext(NamedTuple):
The initialization, which uses both config.py and core.py is done using
`_update_thread_local_jit_state` in core.py to prevent circular imports.
"""
dynamic_trace_state: Any | None = None
trace_state: Any | None = None
axis_env_state: Hashable = ()
mesh_context_manager: Hashable = ()
compute_on_context_manager: Hashable = ()