Skip to content

Commit

Permalink
Fix jaxpr equation context propagation in jaxpr equations when `inlin…
Browse files Browse the repository at this point in the history
…e=True`.

PiperOrigin-RevId: 675754808
  • Loading branch information
yashk2810 authored and Google-ML-Automation committed Sep 17, 2024
1 parent 86fe463 commit 8b5b717
Show file tree
Hide file tree
Showing 4 changed files with 25 additions and 6 deletions.
3 changes: 1 addition & 2 deletions jax/_src/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,8 +343,7 @@ def new_jaxpr_eqn(invars, outvars, primitive, params, effects, source_info=None,
ctx = ctx or JaxprEqnContext(
compute_on.current_compute_type(),
config.threefry_partitionable.value,
xla_metadata_lib.current_xla_metadata(),
)
xla_metadata_lib.current_xla_metadata())
if config.enable_checks.value:
assert all(isinstance(x, (Var, Literal)) for x in invars)
assert all(isinstance(v, Var) for v in outvars)
Expand Down
3 changes: 1 addition & 2 deletions jax/_src/interpreters/partial_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -2828,8 +2828,7 @@ def inline_jaxpr_into_trace(
outvars = [Var('', v.aval) for v in eqn.outvars]
src_ = (src if not eqn.source_info.name_stack else
src.replace(name_stack=src.name_stack + eqn.source_info.name_stack))
trace.frame.add_eqn(core.new_jaxpr_eqn(invars, outvars, eqn.primitive,
eqn.params, eqn.effects, src_))
trace.frame.add_eqn(eqn.replace(invars, outvars, source_info=src_)) # type: ignore
map(env.setdefault, eqn.outvars, outvars)

tracer_env: dict[Var, Any] = dict(zip([*jaxpr.constvars, *jaxpr.invars],
Expand Down
2 changes: 0 additions & 2 deletions jax/_src/pjit.py
Original file line number Diff line number Diff line change
Expand Up @@ -1800,13 +1800,11 @@ def pjit_staging_rule(trace, *args, **params):
params['jaxpr'], params['out_shardings'], params['out_layouts'])
params = dict(params, jaxpr=jaxpr, out_shardings=out_shardings,
out_layouts=out_layouts)

if (params["inline"] and
all(is_unspecified(i) for i in params["in_shardings"]) and
all(is_unspecified(o) for o in params["out_shardings"]) and
all(i is None for i in params["in_layouts"]) and
all(o is None for o in params["out_layouts"])):

if config.dynamic_shapes.value:
# Inline jaxpr doesn't handle dynamic shapes when inlining. If dynamic
# shapes are enabled, use eval_jaxpr, which uses the tracing machinery,
Expand Down
23 changes: 23 additions & 0 deletions tests/memories_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -742,6 +742,29 @@ def h(x):
self.assertArraysEqual(out2, inp * 6)
self.assertEqual(out2.sharding.memory_kind, 'pinned_host')

def test_compute_on_basic_inline(self):
@compute_on('device_host')
@jax.jit
def g(x):
return x * 2

@functools.partial(jax.jit, inline=True)
def h(x):
y = g(x)
return y * 3

@jax.jit
def f(x):
return h(x)

inp = jnp.arange(8)
out = f(inp)
self.assertArraysEqual(out, inp * 6)

lowered_text = f.lower(jnp.arange(8)).as_text('hlo')
self.assertRegex(lowered_text,
'to_apply=g.*frontend_attributes={_xla_compute_type="host"}')

def test_compute_on_reduction(self):
out_s = SingleDeviceSharding(jax.devices()[0], memory_kind='pinned_host')

Expand Down

0 comments on commit 8b5b717

Please sign in to comment.