Skip to content

Commit

Permalink
Introducing partial discharge rules and implementations for cond_p
Browse files Browse the repository at this point in the history
As things stand you can partially discharge a jaxpr with
`discharge_state(should_discharge=[...])`, but each equation discharges *all*
of its arguments. This means that primitives like `scan_p` and `cond_p`
discharge all references they refer to (no pun intended) regardless of whether
the user asked for it. We introduce a partial discharge rule which, when
present, is preferred over the normal one and allows the op to discharge only
some of its references.

This feature is especially useful for pallas kernels because, contrary to all
other contexts, where jaxprs are expected to eventually be fully discharged,
pallas kernels lower references all the way to the runtime as pointers or
MLIR memrefs.

Here we implement the partial discharge rule for `cond_p` and will implement it
for others in due course.

PiperOrigin-RevId: 678665716
  • Loading branch information
cperivol authored and Google-ML-Automation committed Sep 30, 2024
1 parent b3fca90 commit f8c8630
Show file tree
Hide file tree
Showing 3 changed files with 89 additions and 13 deletions.
18 changes: 10 additions & 8 deletions jax/_src/lax/control_flow/conditionals.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
from jax._src import linear_util as lu
from jax._src import source_info_util
from jax._src import util
from jax._src.state.discharge import register_discharge_rule, discharge_state
from jax._src.state.discharge import register_partial_discharge_rule, discharge_state
from jax._src.state.types import AbstractRef, RefEffect
from jax._src.core import ConcreteArray, raise_to_shaped, replace_jaxpr_effects
from jax._src.interpreters import ad
Expand Down Expand Up @@ -854,19 +854,21 @@ def _cond_lowering(ctx, index, *args, branches):

mlir.register_lowering(cond_p, _cond_lowering)

@register_discharge_rule(cond_p)
def _cond_state_discharge_rule(in_avals, out_avals, *args, branches):
@register_partial_discharge_rule(cond_p)
def _cond_state_discharge_rule(in_avals, out_avals, index, *args, branches, should_discharge):
discharged_branches = tuple(
core.ClosedJaxpr(discharge_state(branch.jaxpr, ())[0], ())
core.ClosedJaxpr(
discharge_state(branch.jaxpr, (),
should_discharge=should_discharge[1:])[0], ())
for branch in branches)
out_vals = cond_p.bind(*args, branches=discharged_branches)
out_vals = cond_p.bind(index, *args, branches=discharged_branches)
out_vals, out_ref_vals = util.split_list(
out_vals, [len(out_avals)])
ref_val_iter = iter(out_ref_vals)
new_invals = []
for aval in in_avals:
new_invals.append(
next(ref_val_iter) if isinstance(aval, AbstractRef) else None)
for should, aval in zip(should_discharge, in_avals):
discharged_inval = isinstance(aval, AbstractRef) and should
new_invals.append(next(ref_val_iter) if discharged_inval else None)
return new_invals, out_vals


Expand Down
49 changes: 44 additions & 5 deletions jax/_src/state/discharge.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,11 +97,36 @@ def __call__(self, in_avals: Sequence[core.AbstractValue],

_discharge_rules: dict[core.Primitive, DischargeRule] = {}

class PartialDischargeRule(Protocol):
  """A discharge rule that may discharge only some of its inputs.

  Has the same calling convention as a regular discharge rule, plus a
  `should_discharge` keyword argument holding one flag per input that says
  whether that input should be discharged.

  The rule returns a pair. Its first element holds the new input values, with
  `None` for inputs that were not discharged; only inputs whose
  `should_discharge` entry is `True` may receive a non-`None` value. Its
  second element holds the outputs.
  """

  def __call__(self, in_avals: Sequence[core.AbstractValue],
               out_avals: Sequence[core.AbstractValue], *args: Any,
               # This must be an annotation: writing
               # `should_discharge=list[bool]` would make the `list[bool]`
               # type object the argument's default value.
               should_discharge: list[bool],
               **params: Any) -> tuple[Sequence[Any | None], Sequence[Any]]:
    ...

_partial_discharge_rules: dict[core.Primitive, PartialDischargeRule] = {}

def register_discharge_rule(prim: core.Primitive):
  """Returns a decorator that registers a discharge rule for `prim`.

  The decorator stores the rule in `_discharge_rules` under `prim`, replacing
  any rule stored there before. It then returns the rule unchanged, so the
  decorated name stays bound to the function rather than to `None`.
  """
  def register(f: DischargeRule):
    _discharge_rules[prim] = f
    return f
  return register

def register_partial_discharge_rule(prim: core.Primitive):
  """Returns a decorator that registers a partial discharge rule for `prim`.

  The decorator stores the rule in `_partial_discharge_rules` under `prim`,
  replacing any rule stored there before. It then returns the rule unchanged,
  so the decorated name stays bound to the function rather than to `None`.
  """
  # The rule takes the extra `should_discharge` keyword, so it is a
  # `PartialDischargeRule`, not a plain `DischargeRule`.
  def register(f: PartialDischargeRule):
    _partial_discharge_rules[prim] = f
    return f
  return register


def _eval_jaxpr_discharge_state(
jaxpr: core.Jaxpr, should_discharge: Sequence[bool], consts: Sequence[Any],
*args: Any):
Expand All @@ -116,21 +141,35 @@ def _eval_jaxpr_discharge_state(
if d and isinstance(v.aval, AbstractRef)}

for eqn in jaxpr.eqns:
should_discharge_arg = [id(v.aval) in refs_to_discharge for v in eqn.invars]
if eqn.primitive is core.mutable_array_p:
[invar], [outvar] = eqn.invars, eqn.outvars
ans = env.read(invar)
refs_to_discharge.add(id(outvar.aval))
elif (any(id(v.aval) in refs_to_discharge for v in eqn.invars)
or core.internal_mutable_array_effect in eqn.effects ):
if eqn.primitive not in _discharge_rules:
elif (any(should_discharge_arg)
or core.internal_mutable_array_effect in eqn.effects
):
if eqn.primitive in _partial_discharge_rules:
rule = partial(_partial_discharge_rules[eqn.primitive], should_discharge=should_discharge_arg) # type: ignore
elif eqn.primitive in _discharge_rules:
should_discharge_arg = [True] * len(eqn.invars)
rule = _discharge_rules[eqn.primitive]
else:
raise NotImplementedError("No state discharge rule implemented for "
f"primitive: {eqn.primitive}")
invals = map(env.read, eqn.invars)
in_avals = [v.aval for v in eqn.invars]
out_avals = [v.aval for v in eqn.outvars]
new_invals, ans = _discharge_rules[eqn.primitive](
new_invals, ans = rule(
in_avals, out_avals, *invals, **eqn.params)
for new_inval, invar in zip(new_invals, eqn.invars):
for invar, should, new_inval in zip(eqn.invars, should_discharge_arg, new_invals):
if not should:
if new_inval is not None:
raise ValueError(
f"Did not ask for inval to be discharged but it was. ({invar=},"
f" {new_inval=})"
)
continue
if new_inval is not None:
env.write(invar, new_inval) # type: ignore[arg-type]
else:
Expand Down
35 changes: 35 additions & 0 deletions tests/state_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -739,6 +739,20 @@ def f(ref):
in_avals = [shaped_array_ref((), jnp.float32)]
pe.trace_to_jaxpr_dynamic(lu.wrap_init(f), in_avals)

def test_partial_discharge(self):
  # Traces a function that writes to and then reads from two scalar refs,
  # discharges only the second ref, and checks the first two equations of
  # the resulting jaxpr.
  def write_then_read(a_ref, b_ref):
    a_ref[...] = jnp.array(0., dtype=jnp.float32)
    b_ref[...] = jnp.array(1., dtype=jnp.float32)
    return a_ref[...], b_ref[...]

  ref_aval = shaped_array_ref((), jnp.float32)
  traced = pe.trace_to_jaxpr_dynamic(
      lu.wrap_init(write_then_read), [ref_aval, ref_aval])
  # The fourth traced element is unpacked against `()`, so it must be empty.
  jaxpr, _, _, () = traced

  # One flag per argument: keep `a_ref` as a reference, discharge `b_ref`.
  discharged_jaxpr, _ = discharge_state(
      jaxpr, (), should_discharge=[False, True])

  # The leading equations must still be the stateful swap and get primitives.
  for position, expected_primitive in enumerate((swap_p, get_p)):
    self.assertEqual(
        discharged_jaxpr.eqns[position].primitive, expected_primitive)


if CAN_USE_HYPOTHESIS:

Expand Down Expand Up @@ -1061,6 +1075,27 @@ def false_fun():
out = jax.jit(f)(False)
self.assertTupleEqual(out, (0., 5.))

def test_cond_discharge(self):
  # Partially discharges a function containing a `lax.cond` whose branches
  # write to two different refs, asking for only the second ref to be
  # discharged.
  def f0(pred, x_ref, y_ref):
    def true_fun():
      x_ref[...] = 1.
    def false_fun():
      y_ref[...] = 2.
    lax.cond(pred, true_fun, false_fun)
    return x_ref[...], y_ref[...]

  def make_ref(value):
    return AbstractRef(core.raise_to_shaped(core.get_aval(value)))

  f_jaxpr = jax.make_jaxpr(f0)(False, make_ref(3.), make_ref(4.))
  # Arguments are (pred, x_ref, y_ref); only y_ref is flagged for discharge.
  discharge_flags = [False, False, True]
  jaxpr, _ = discharge_state(
      f_jaxpr.jaxpr, (), should_discharge=discharge_flags)

  # Before discharging, arguments 1 and 2 both carry read and write effects;
  # afterwards only the effects on argument 1 (x_ref) remain.
  x_ref_effects = {ReadEffect(1), WriteEffect(1)}
  y_ref_effects = {ReadEffect(2), WriteEffect(2)}
  self.assertEqual(f_jaxpr.effects, x_ref_effects | y_ref_effects)
  self.assertEqual(jaxpr.effects, x_ref_effects)

  # The y_ref argument is no longer a reference, while x_ref still is.
  self.assertNotIsInstance(jaxpr.invars[2].aval, AbstractRef)
  self.assertIsInstance(jaxpr.invars[1].aval, AbstractRef)

  # The traced function has two outputs and the discharged jaxpr has three;
  # the extra one is presumably the final value of the discharged y_ref.
  self.assertLen(f_jaxpr.out_avals, 2)
  self.assertLen(jaxpr.outvars, 3)

def test_cond_with_ref_reuse(self):
def f(pred):
def body(x_ref):
Expand Down

0 comments on commit f8c8630

Please sign in to comment.