diff --git a/jax/_src/lax/control_flow/conditionals.py b/jax/_src/lax/control_flow/conditionals.py index d3065d0f96d7..e058ced2cf58 100644 --- a/jax/_src/lax/control_flow/conditionals.py +++ b/jax/_src/lax/control_flow/conditionals.py @@ -33,7 +33,7 @@ from jax._src import linear_util as lu from jax._src import source_info_util from jax._src import util -from jax._src.state.discharge import register_discharge_rule, discharge_state +from jax._src.state.discharge import register_partial_discharge_rule, discharge_state from jax._src.state.types import AbstractRef, RefEffect from jax._src.core import ConcreteArray, raise_to_shaped, replace_jaxpr_effects from jax._src.interpreters import ad @@ -854,19 +854,21 @@ def _cond_lowering(ctx, index, *args, branches): mlir.register_lowering(cond_p, _cond_lowering) -@register_discharge_rule(cond_p) -def _cond_state_discharge_rule(in_avals, out_avals, *args, branches): +@register_partial_discharge_rule(cond_p) +def _cond_state_discharge_rule(in_avals, out_avals, index, *args, branches, should_discharge): discharged_branches = tuple( - core.ClosedJaxpr(discharge_state(branch.jaxpr, ())[0], ()) + core.ClosedJaxpr( + discharge_state(branch.jaxpr, (), + should_discharge=should_discharge[1:])[0], ()) for branch in branches) - out_vals = cond_p.bind(*args, branches=discharged_branches) + out_vals = cond_p.bind(index, *args, branches=discharged_branches) out_vals, out_ref_vals = util.split_list( out_vals, [len(out_avals)]) ref_val_iter = iter(out_ref_vals) new_invals = [] - for aval in in_avals: - new_invals.append( - next(ref_val_iter) if isinstance(aval, AbstractRef) else None) + for should, aval in zip(should_discharge, in_avals): + discharged_inval = isinstance(aval, AbstractRef) and should + new_invals.append(next(ref_val_iter) if discharged_inval else None) return new_invals, out_vals diff --git a/jax/_src/state/discharge.py b/jax/_src/state/discharge.py index 7970440d29a6..bc2ba997c107 100644 --- a/jax/_src/state/discharge.py +++ 
b/jax/_src/state/discharge.py @@ -97,11 +97,36 @@ def __call__(self, in_avals: Sequence[core.AbstractValue], _discharge_rules: dict[core.Primitive, DischargeRule] = {} +class PartialDischargeRule(Protocol): + """A partial discharge rule. + + Exactly like a discharge rule, except that it accepts a `should_discharge` + argument indicating which inputs should be discharged. The first element + of the returned tuple holds one entry per input: the new input value, or + `None`. Only inputs whose `should_discharge` entry is `True` may get a + non-`None` entry. + """ + + + def __call__(self, in_avals: Sequence[core.AbstractValue], + out_avals: Sequence[core.AbstractValue], *args: Any, + should_discharge: list[bool], + **params: Any) -> tuple[Sequence[Any | None], Sequence[Any]]: + ... + +_partial_discharge_rules: dict[core.Primitive, PartialDischargeRule] = {} + def register_discharge_rule(prim: core.Primitive): def register(f: DischargeRule): _discharge_rules[prim] = f return register +def register_partial_discharge_rule(prim: core.Primitive): + def register(f: PartialDischargeRule): + _partial_discharge_rules[prim] = f + return register + + def _eval_jaxpr_discharge_state( jaxpr: core.Jaxpr, should_discharge: Sequence[bool], consts: Sequence[Any], *args: Any): @@ -116,21 +141,35 @@ def _eval_jaxpr_discharge_state( if d and isinstance(v.aval, AbstractRef)} for eqn in jaxpr.eqns: + should_discharge_arg = [id(v.aval) in refs_to_discharge for v in eqn.invars] if eqn.primitive is core.mutable_array_p: [invar], [outvar] = eqn.invars, eqn.outvars ans = env.read(invar) refs_to_discharge.add(id(outvar.aval)) - elif (any(id(v.aval) in refs_to_discharge for v in eqn.invars) - or core.internal_mutable_array_effect in eqn.effects ): - if eqn.primitive not in _discharge_rules: + elif (any(should_discharge_arg) + or core.internal_mutable_array_effect in eqn.effects + ): + if eqn.primitive in _partial_discharge_rules: + rule = partial(_partial_discharge_rules[eqn.primitive], should_discharge=should_discharge_arg) # 
type: ignore + elif eqn.primitive in _discharge_rules: + should_discharge_arg = [True] * len(eqn.invars) + rule = _discharge_rules[eqn.primitive] + else: raise NotImplementedError("No state discharge rule implemented for " f"primitive: {eqn.primitive}") invals = map(env.read, eqn.invars) in_avals = [v.aval for v in eqn.invars] out_avals = [v.aval for v in eqn.outvars] - new_invals, ans = _discharge_rules[eqn.primitive]( + new_invals, ans = rule( in_avals, out_avals, *invals, **eqn.params) - for new_inval, invar in zip(new_invals, eqn.invars): + for invar, should, new_inval in zip(eqn.invars, should_discharge_arg, new_invals): + if not should: + if new_inval is not None: + raise ValueError( + f"Did not ask for inval to be discharged but it was. ({invar=}," + f" {new_inval=})" + ) + continue if new_inval is not None: env.write(invar, new_inval) # type: ignore[arg-type] else: diff --git a/tests/state_test.py b/tests/state_test.py index d04a674ab8c0..63449985e3bf 100644 --- a/tests/state_test.py +++ b/tests/state_test.py @@ -739,6 +739,20 @@ def f(ref): in_avals = [shaped_array_ref((), jnp.float32)] pe.trace_to_jaxpr_dynamic(lu.wrap_init(f), in_avals) + def test_partial_discharge(self): + def f(a_ref, b_ref): + a_ref[...] = jnp.array(0., dtype=jnp.float32) + b_ref[...] = jnp.array(1., dtype=jnp.float32) + return a_ref[...], b_ref[...] + + scalar_ref = shaped_array_ref((), jnp.float32) + jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic( + lu.wrap_init(f), [scalar_ref, scalar_ref]) + + discharged_jaxpr, _ = discharge_state(jaxpr, (), should_discharge=[False, True]) + self.assertEqual(discharged_jaxpr.eqns[0].primitive, swap_p) + self.assertEqual(discharged_jaxpr.eqns[1].primitive, get_p) + if CAN_USE_HYPOTHESIS: @@ -1061,6 +1075,27 @@ def false_fun(): out = jax.jit(f)(False) self.assertTupleEqual(out, (0., 5.)) + def test_cond_discharge(self): + def f0(pred, x_ref, y_ref): + def true_fun(): + x_ref[...] = 1. + def false_fun(): + y_ref[...] = 2. 
+ lax.cond(pred, true_fun, false_fun) + return x_ref[...], y_ref[...] + ref = lambda x: AbstractRef(core.raise_to_shaped(core.get_aval(x))) + f_jaxpr = jax.make_jaxpr(f0)(False, ref(3.), ref(4.)) + jaxpr, _ = discharge_state(f_jaxpr.jaxpr, (), should_discharge=[False, False, True]) + # Effects on y_ref were discharged away but not the effects on x_ref + self.assertEqual(f_jaxpr.effects, {ReadEffect(1), WriteEffect(1), ReadEffect(2), WriteEffect(2)}) + self.assertEqual(jaxpr.effects, {ReadEffect(1), WriteEffect(1)}) + # x_ref arg is still a reference but y_ref is discharged + self.assertNotIsInstance(jaxpr.invars[2].aval, AbstractRef) + self.assertIsInstance(jaxpr.invars[1].aval, AbstractRef) + # y_ref value is returned as an extra output, as part of the discharged refs set. + self.assertLen(f_jaxpr.out_avals, 2) + self.assertLen(jaxpr.outvars, 3) + def test_cond_with_ref_reuse(self): def f(pred): def body(x_ref):