diff --git a/jax/_src/lax/control_flow/conditionals.py b/jax/_src/lax/control_flow/conditionals.py index d3065d0f96d7..e058ced2cf58 100644 --- a/jax/_src/lax/control_flow/conditionals.py +++ b/jax/_src/lax/control_flow/conditionals.py @@ -33,7 +33,7 @@ from jax._src import linear_util as lu from jax._src import source_info_util from jax._src import util -from jax._src.state.discharge import register_discharge_rule, discharge_state +from jax._src.state.discharge import register_partial_discharge_rule, discharge_state from jax._src.state.types import AbstractRef, RefEffect from jax._src.core import ConcreteArray, raise_to_shaped, replace_jaxpr_effects from jax._src.interpreters import ad @@ -854,19 +854,21 @@ def _cond_lowering(ctx, index, *args, branches): mlir.register_lowering(cond_p, _cond_lowering) -@register_discharge_rule(cond_p) -def _cond_state_discharge_rule(in_avals, out_avals, *args, branches): +@register_partial_discharge_rule(cond_p) +def _cond_state_discharge_rule(in_avals, out_avals, index, *args, branches, should_discharge): discharged_branches = tuple( - core.ClosedJaxpr(discharge_state(branch.jaxpr, ())[0], ()) + core.ClosedJaxpr( + discharge_state(branch.jaxpr, (), + should_discharge=should_discharge[1:])[0], ()) for branch in branches) - out_vals = cond_p.bind(*args, branches=discharged_branches) + out_vals = cond_p.bind(index, *args, branches=discharged_branches) out_vals, out_ref_vals = util.split_list( out_vals, [len(out_avals)]) ref_val_iter = iter(out_ref_vals) new_invals = [] - for aval in in_avals: - new_invals.append( - next(ref_val_iter) if isinstance(aval, AbstractRef) else None) + for should, aval in zip(should_discharge, in_avals): + discharged_inval = isinstance(aval, AbstractRef) and should + new_invals.append(next(ref_val_iter) if discharged_inval else None) return new_invals, out_vals diff --git a/jax/_src/lax/control_flow/for_loop.py b/jax/_src/lax/control_flow/for_loop.py index 21b522b3d8bb..1f844ad8380d 100644 --- 
a/jax/_src/lax/control_flow/for_loop.py +++ b/jax/_src/lax/control_flow/for_loop.py @@ -228,18 +228,28 @@ def _for_abstract_eval(*avals, jaxpr, **__): nonlocal_state_effects = core.join_effects(*aval_effects) return list(avals), nonlocal_state_effects -@state_discharge.register_discharge_rule(for_p) +@state_discharge.register_partial_discharge_rule(for_p) def _for_discharge_rule(in_avals, _, *args: Any, jaxpr: core.Jaxpr, reverse: bool, which_linear: Sequence[bool], - nsteps: int, unroll: int - ) -> tuple[Sequence[Any | None], Sequence[Any]]: - out_vals = for_p.bind(*args, jaxpr=jaxpr, reverse=reverse, - which_linear=which_linear, nsteps=nsteps, - unroll=unroll) - new_invals = [] - for aval, out_val in zip(in_avals, out_vals): - new_invals.append(out_val if isinstance(aval, AbstractRef) else None) - return new_invals, out_vals + nsteps: int, unroll: int, + should_discharge: list[bool]) -> tuple[Sequence[Any | None], Sequence[Any]]: + if len(jaxpr.constvars) > 0: + raise NotImplementedError("Constants not supported for discharge.") + + # Jaxpr doesn't return anything in for loops so the return is a tuple of the discharged values. 
+ discharged_jaxpr, consts = state_discharge.discharge_state(jaxpr, (), should_discharge=[False, *should_discharge]) + def body(carry, _): + i, *args = carry + outputs = iter(core.eval_jaxpr(discharged_jaxpr, (), i, *args)) + carry_outputs = tuple(next(outputs) if should else arg for arg, should in zip(args, should_discharge)) + carry = (i - 1 if reverse else i + 1, *carry_outputs) + return carry, None + + # New args are also the result type in for_loop + beg = nsteps - 1 if reverse else 0 + (_, *new_args), _ = loops.scan(body, (beg, *args), None, length=nsteps, reverse=reverse, unroll=unroll) + discharged_args = tuple(arg if should else None for arg, should in zip(new_args, should_discharge)) + return discharged_args, new_args def _for_impl(*args, jaxpr, nsteps, reverse, which_linear, unroll): del which_linear diff --git a/jax/_src/state/discharge.py b/jax/_src/state/discharge.py index 7970440d29a6..76f278a63383 100644 --- a/jax/_src/state/discharge.py +++ b/jax/_src/state/discharge.py @@ -97,11 +97,36 @@ def __call__(self, in_avals: Sequence[core.AbstractValue], _discharge_rules: dict[core.Primitive, DischargeRule] = {} +class PartialDischargeRule(Protocol): + """A partial discharge rule. + + Exactly like a discharge rule only it accepts a `should_discharge` + argument that indicates which inputs should be discharged and the + return value returns a tuple of which the first element is the new + inputs or none but only the ones that correspond to `True` entries + in `should_discharge`. + """ + + + def __call__(self, in_avals: Sequence[core.AbstractValue], + out_avals: Sequence[core.AbstractValue], *args: Any, + should_discharge: list[bool], + **params: Any) -> tuple[Sequence[Any | None], Sequence[Any]]: + ... 
+ +_partial_discharge_rules: dict[core.Primitive, PartialDischargeRule] = {} + def register_discharge_rule(prim: core.Primitive): def register(f: DischargeRule): _discharge_rules[prim] = f return register +def register_partial_discharge_rule(prim: core.Primitive): + def register(f: PartialDischargeRule): + _partial_discharge_rules[prim] = f + return register + + def _eval_jaxpr_discharge_state( jaxpr: core.Jaxpr, should_discharge: Sequence[bool], consts: Sequence[Any], *args: Any): @@ -116,21 +141,35 @@ def _eval_jaxpr_discharge_state( if d and isinstance(v.aval, AbstractRef)} for eqn in jaxpr.eqns: + should_discharge_arg = [id(v.aval) in refs_to_discharge for v in eqn.invars] if eqn.primitive is core.mutable_array_p: [invar], [outvar] = eqn.invars, eqn.outvars ans = env.read(invar) refs_to_discharge.add(id(outvar.aval)) - elif (any(id(v.aval) in refs_to_discharge for v in eqn.invars) - or core.internal_mutable_array_effect in eqn.effects ): - if eqn.primitive not in _discharge_rules: + elif (any(should_discharge_arg) + or core.internal_mutable_array_effect in eqn.effects + ): + if eqn.primitive in _partial_discharge_rules: + rule = partial(_partial_discharge_rules[eqn.primitive], should_discharge=should_discharge_arg) + elif eqn.primitive in _discharge_rules: + should_discharge_arg = [True] * len(eqn.invars) + rule = _discharge_rules[eqn.primitive] + else: raise NotImplementedError("No state discharge rule implemented for " f"primitive: {eqn.primitive}") invals = map(env.read, eqn.invars) in_avals = [v.aval for v in eqn.invars] out_avals = [v.aval for v in eqn.outvars] - new_invals, ans = _discharge_rules[eqn.primitive]( + new_invals, ans = rule( in_avals, out_avals, *invals, **eqn.params) - for new_inval, invar in zip(new_invals, eqn.invars): + for invar, should, new_inval in zip(eqn.invars, should_discharge_arg, new_invals): + if not should: + if new_inval is not None: + raise ValueError( + f"Did not ask for inval to be discharged but it was. 
({invar=}," + f" {new_inval=})" + ) + continue if new_inval is not None: env.write(invar, new_inval) # type: ignore[arg-type] else: diff --git a/tests/state_test.py b/tests/state_test.py index d04a674ab8c0..be288711b50b 100644 --- a/tests/state_test.py +++ b/tests/state_test.py @@ -739,6 +739,41 @@ def f(ref): in_avals = [shaped_array_ref((), jnp.float32)] pe.trace_to_jaxpr_dynamic(lu.wrap_init(f), in_avals) + def test_partial_discharge(self): + def f(a_ref, b_ref): + a_ref[...] = jnp.array(0., dtype=jnp.float32) + b_ref[...] = jnp.array(1., dtype=jnp.float32) + return a_ref[...], b_ref[...] + + scalar_ref = shaped_array_ref((), jnp.float32) + jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic( + lu.wrap_init(f), [scalar_ref, scalar_ref]) + + discharged_jaxpr, _ = discharge_state(jaxpr, (), should_discharge=[False, True]) + self.assertEqual(discharged_jaxpr.eqns[0].primitive, swap_p) + self.assertEqual(discharged_jaxpr.eqns[1].primitive, get_p) + + def test_partial_for_discharge(self): + def f(a_ref, b_ref): + @partial(for_loop.for_loop, 5, init_state=()) + def _(i, st): + a_ref[...] = 0. + b_ref[...] = 1. + return a_ref[...], b_ref[...] + + ref = lambda x: AbstractRef(core.raise_to_shaped(core.get_aval(x))) + f_jaxpr = jax.make_jaxpr(f)(ref(1.), ref(2.)) + jaxpr, _ = discharge_state(f_jaxpr.jaxpr, (), should_discharge=[False, True]) + # Effects on b_ref were discharged away but not the effects on a_ref + self.assertEqual(f_jaxpr.effects, {ReadEffect(0), WriteEffect(0), ReadEffect(1), WriteEffect(1)}) + self.assertEqual(jaxpr.effects, {ReadEffect(0), WriteEffect(0)}) + # a_ref arg is still a reference but b_ref is discharged + self.assertNotIsInstance(jaxpr.invars[1].aval, AbstractRef) + self.assertIsInstance(jaxpr.invars[0].aval, AbstractRef) + # b_ref's value is returned as part of the discharged refs set. 
+ self.assertLen(f_jaxpr.out_avals, 2) + self.assertLen(jaxpr.outvars, 3) + if CAN_USE_HYPOTHESIS: @@ -1061,6 +1096,27 @@ def false_fun(): out = jax.jit(f)(False) self.assertTupleEqual(out, (0., 5.)) + def test_cond_discharge(self): + def f0(pred, x_ref, y_ref): + def true_fun(): + x_ref[...] = 1. + def false_fun(): + y_ref[...] = 2. + lax.cond(pred, true_fun, false_fun) + return x_ref[...], y_ref[...] + ref = lambda x: AbstractRef(core.raise_to_shaped(core.get_aval(x))) + f_jaxpr = jax.make_jaxpr(f0)(False, ref(3.), ref(4.)) + jaxpr, _ = discharge_state(f_jaxpr.jaxpr, (), should_discharge=[False, False, True]) + # Effects on y_ref were discharged away but not the effects on x_ref + self.assertEqual(f_jaxpr.effects, {ReadEffect(1), WriteEffect(1), ReadEffect(2), WriteEffect(2)}) + self.assertEqual(jaxpr.effects, {ReadEffect(1), WriteEffect(1)}) + # x_ref arg is still a reference but y_ref is discharged + self.assertNotIsInstance(jaxpr.invars[2].aval, AbstractRef) + self.assertIsInstance(jaxpr.invars[1].aval, AbstractRef) + # y_ref's value is returned as part of the discharged refs set. + self.assertLen(f_jaxpr.out_avals, 2) + self.assertLen(jaxpr.outvars, 3) + def test_cond_with_ref_reuse(self): def f(pred): def body(x_ref):