diff --git a/jax/_src/core.py b/jax/_src/core.py index 74d03b8d9464..51933a9f8bbf 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -343,8 +343,7 @@ def new_jaxpr_eqn(invars, outvars, primitive, params, effects, source_info=None, ctx = ctx or JaxprEqnContext( compute_on.current_compute_type(), config.threefry_partitionable.value, - xla_metadata_lib.current_xla_metadata(), - ) + xla_metadata_lib.current_xla_metadata()) if config.enable_checks.value: assert all(isinstance(x, (Var, Literal)) for x in invars) assert all(isinstance(v, Var) for v in outvars) diff --git a/jax/_src/interpreters/partial_eval.py b/jax/_src/interpreters/partial_eval.py index 2d27bf064fce..374816e001ec 100644 --- a/jax/_src/interpreters/partial_eval.py +++ b/jax/_src/interpreters/partial_eval.py @@ -2828,8 +2828,7 @@ def inline_jaxpr_into_trace( outvars = [Var('', v.aval) for v in eqn.outvars] src_ = (src if not eqn.source_info.name_stack else src.replace(name_stack=src.name_stack + eqn.source_info.name_stack)) - trace.frame.add_eqn(core.new_jaxpr_eqn(invars, outvars, eqn.primitive, - eqn.params, eqn.effects, src_)) + trace.frame.add_eqn(eqn.replace(invars, outvars, source_info=src_)) # type: ignore map(env.setdefault, eqn.outvars, outvars) tracer_env: dict[Var, Any] = dict(zip([*jaxpr.constvars, *jaxpr.invars], diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index 42a7c966b4d6..34bf257f639e 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -1800,13 +1800,11 @@ def pjit_staging_rule(trace, *args, **params): params['jaxpr'], params['out_shardings'], params['out_layouts']) params = dict(params, jaxpr=jaxpr, out_shardings=out_shardings, out_layouts=out_layouts) - if (params["inline"] and all(is_unspecified(i) for i in params["in_shardings"]) and all(is_unspecified(o) for o in params["out_shardings"]) and all(i is None for i in params["in_layouts"]) and all(o is None for o in params["out_layouts"])): - if config.dynamic_shapes.value: # Inline jaxpr doesn't handle dynamic shapes when inlining. If dynamic # shapes are enabled, use eval_jaxpr, which uses the tracing machinery, diff --git a/tests/memories_test.py b/tests/memories_test.py index 68aecfdf669f..3e0f444a1e66 100644 --- a/tests/memories_test.py +++ b/tests/memories_test.py @@ -742,6 +742,29 @@ def h(x): self.assertArraysEqual(out2, inp * 6) self.assertEqual(out2.sharding.memory_kind, 'pinned_host') + def test_compute_on_basic_inline(self): + @compute_on('device_host') + @jax.jit + def g(x): + return x * 2 + + @functools.partial(jax.jit, inline=True) + def h(x): + y = g(x) + return y * 3 + + @jax.jit + def f(x): + return h(x) + + inp = jnp.arange(8) + out = f(inp) + self.assertArraysEqual(out, inp * 6) + + lowered_text = f.lower(jnp.arange(8)).as_text('hlo') + self.assertRegex(lowered_text, + 'to_apply=g.*frontend_attributes={_xla_compute_type="host"}') + def test_compute_on_reduction(self): out_s = SingleDeviceSharding(jax.devices()[0], memory_kind='pinned_host')