Skip to content

Commit

Permalink
Introducing partial discharge rules and implementations for cond_p an…
Browse files Browse the repository at this point in the history
…d for_p

As things stand you can partially discharge a jaxpr with
`discharge_state(should_discharge=[...])`, but each equation still discharges *all*
its arguments. This means that primitives like `scan_p` and `cond_p` discharge
all references they refer to (no pun intended) regardless of whether the user
asked for it. We provide a special partial discharge rule that, when present, is
preferred to the normal one and allows the op to discharge only some of the
references.

This feature is especially useful for pallas kernels because, contrary to all
other contexts where jaxprs are expected to eventually be fully discharged,
pallas kernels lower references all the way to the runtime as pointers or
MLIR memrefs.

PiperOrigin-RevId: 678665716
  • Loading branch information
cperivol authored and Google-ML-Automation committed Sep 26, 2024
1 parent 5cef547 commit 961680c
Show file tree
Hide file tree
Showing 4 changed files with 130 additions and 23 deletions.
18 changes: 10 additions & 8 deletions jax/_src/lax/control_flow/conditionals.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
from jax._src import linear_util as lu
from jax._src import source_info_util
from jax._src import util
from jax._src.state.discharge import register_discharge_rule, discharge_state
from jax._src.state.discharge import register_partial_discharge_rule, discharge_state
from jax._src.state.types import AbstractRef, RefEffect
from jax._src.core import ConcreteArray, raise_to_shaped, replace_jaxpr_effects
from jax._src.interpreters import ad
Expand Down Expand Up @@ -854,19 +854,21 @@ def _cond_lowering(ctx, index, *args, branches):

mlir.register_lowering(cond_p, _cond_lowering)

# NOTE(review): diff view — removed and added lines of the commit are
# interleaved in this hunk; the `register_partial_discharge_rule` variant
# (taking `index` and `should_discharge`) appears to be the new side.
@register_discharge_rule(cond_p)
def _cond_state_discharge_rule(in_avals, out_avals, *args, branches):
@register_partial_discharge_rule(cond_p)
def _cond_state_discharge_rule(in_avals, out_avals, index, *args, branches, should_discharge):
discharged_branches = tuple(
core.ClosedJaxpr(discharge_state(branch.jaxpr, ())[0], ())
core.ClosedJaxpr(
# `should_discharge[1:]` skips the entry for `index`; presumably the branch
# jaxprs take only `*args` as inputs — confirm against cond_p's convention.
discharge_state(branch.jaxpr, (),
should_discharge=should_discharge[1:])[0], ())
for branch in branches)
out_vals = cond_p.bind(*args, branches=discharged_branches)
out_vals = cond_p.bind(index, *args, branches=discharged_branches)
# The first len(out_avals) results are the cond's original outputs; whatever
# follows is treated as the values of the discharged refs.
out_vals, out_ref_vals = util.split_list(
out_vals, [len(out_avals)])
ref_val_iter = iter(out_ref_vals)
new_invals = []
for aval in in_avals:
new_invals.append(
next(ref_val_iter) if isinstance(aval, AbstractRef) else None)
# Only ref inputs that were asked to be discharged consume a returned ref
# value; every other input is reported as None.
for should, aval in zip(should_discharge, in_avals):
discharged_inval = isinstance(aval, AbstractRef) and should
new_invals.append(next(ref_val_iter) if discharged_inval else None)
return new_invals, out_vals


Expand Down
30 changes: 20 additions & 10 deletions jax/_src/lax/control_flow/for_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,18 +228,28 @@ def _for_abstract_eval(*avals, jaxpr, **__):
nonlocal_state_effects = core.join_effects(*aval_effects)
return list(avals), nonlocal_state_effects

# NOTE(review): diff view — removed and added lines of the commit are
# interleaved in this hunk; the lines using `should_discharge` appear to be
# the new side.
@state_discharge.register_discharge_rule(for_p)
@state_discharge.register_partial_discharge_rule(for_p)
def _for_discharge_rule(in_avals, _, *args: Any, jaxpr: core.Jaxpr,
reverse: bool, which_linear: Sequence[bool],
nsteps: int, unroll: int
) -> tuple[Sequence[Any | None], Sequence[Any]]:
out_vals = for_p.bind(*args, jaxpr=jaxpr, reverse=reverse,
which_linear=which_linear, nsteps=nsteps,
unroll=unroll)
new_invals = []
for aval, out_val in zip(in_avals, out_vals):
new_invals.append(out_val if isinstance(aval, AbstractRef) else None)
return new_invals, out_vals
nsteps: int, unroll: int,
should_discharge:list[bool]) -> tuple[Sequence[Any | None], Sequence[Any]]:
if len(jaxpr.constvars) > 0:
raise NotImplementedError("Constants not supported for discharge.")

# Jaxpr doesn't return anything in for loops so the return is a tuple of the discharged values.
# The leading False covers the loop index, which is passed as the jaxpr's
# first input in `body` below and is never discharged.
discharged_jaxpr, consts = state_discharge.discharge_state(jaxpr, (), should_discharge=[False, *should_discharge])
def body(carry, _):
i, *args = carry
outputs = iter(core.eval_jaxpr(discharged_jaxpr, (), i, *args))
# Discharged args take the next jaxpr output; the others are carried through
# unchanged.
carry_outputs = tuple(next(outputs) if should else arg for arg, should in zip(args, should_discharge))
# NOTE(review): the step direction looks inverted. `beg` below is 0 when not
# reversed and nsteps - 1 when reversed, yet the index is decremented in the
# forward case and incremented in the reverse case; confirm whether this
# should be `i - 1 if reverse else i + 1`.
carry = (i + 1 if reverse else i - 1, *carry_outputs)
return carry, None

# New args are also the result type in for_loop
beg = nsteps - 1 if reverse else 0
(_, *new_args), _ = loops.scan(body, (beg, *args), None, length=nsteps, reverse=reverse, unroll=unroll)
# Only args that were asked to be discharged are reported back as new input
# values; the rest are None.
discharged_args = tuple(arg if should else None for arg, should in zip(new_args, should_discharge))
return discharged_args, new_args

def _for_impl(*args, jaxpr, nsteps, reverse, which_linear, unroll):
del which_linear
Expand Down
49 changes: 44 additions & 5 deletions jax/_src/state/discharge.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,11 +97,36 @@ def __call__(self, in_avals: Sequence[core.AbstractValue],

# Registry mapping each primitive to its discharge rule; filled in by
# `register_discharge_rule`.
_discharge_rules: dict[core.Primitive, DischargeRule] = {}

class PartialDischargeRule(Protocol):
  """A discharge rule that may discharge only some of its inputs.

  It is called exactly like a discharge rule, plus a keyword-only
  `should_discharge` argument holding one bool per input that says whether
  that input should be discharged.

  Returns:
    A pair `(new_invals, outs)`. `new_invals` has one entry per input: the new
    value for an input that was discharged, or `None` otherwise. Only inputs
    whose `should_discharge` entry is `True` may have a non-`None` entry.
  """

  def __call__(self, in_avals: Sequence[core.AbstractValue],
               out_avals: Sequence[core.AbstractValue], *args: Any,
               should_discharge: list[bool],
               **params: Any) -> tuple[Sequence[Any | None], Sequence[Any]]:
    ...

# Registry mapping each primitive to its partial discharge rule; filled in by
# `register_partial_discharge_rule`.
_partial_discharge_rules: dict[core.Primitive, PartialDischargeRule] = {}

def register_discharge_rule(prim: core.Primitive):
  """Return a decorator that registers a discharge rule for `prim`.

  The decorator stores the rule in `_discharge_rules`, replacing any rule
  already stored for `prim`, and returns the rule unchanged so that the
  decorated name stays bound to the function instead of becoming `None`.
  """
  def register(f: DischargeRule) -> DischargeRule:
    _discharge_rules[prim] = f
    return f
  return register

def register_partial_discharge_rule(prim: core.Primitive):
  """Return a decorator that registers a partial discharge rule for `prim`.

  The decorator stores the rule in `_partial_discharge_rules`, replacing any
  rule already stored for `prim`, and returns the rule unchanged so that the
  decorated name stays bound to the function instead of becoming `None`.
  """
  def register(f: PartialDischargeRule) -> PartialDischargeRule:
    _partial_discharge_rules[prim] = f
    return f
  return register


def _eval_jaxpr_discharge_state(
jaxpr: core.Jaxpr, should_discharge: Sequence[bool], consts: Sequence[Any],
*args: Any):
Expand All @@ -116,21 +141,35 @@ def _eval_jaxpr_discharge_state(
if d and isinstance(v.aval, AbstractRef)}

for eqn in jaxpr.eqns:
should_discharge_arg = [id(v.aval) in refs_to_discharge for v in eqn.invars]
if eqn.primitive is core.mutable_array_p:
[invar], [outvar] = eqn.invars, eqn.outvars
ans = env.read(invar)
refs_to_discharge.add(id(outvar.aval))
elif (any(id(v.aval) in refs_to_discharge for v in eqn.invars)
or core.internal_mutable_array_effect in eqn.effects ):
if eqn.primitive not in _discharge_rules:
elif (any(should_discharge_arg)
or core.internal_mutable_array_effect in eqn.effects
):
if eqn.primitive in _partial_discharge_rules:
rule = partial(_partial_discharge_rules[eqn.primitive], should_discharge=should_discharge_arg)
elif eqn.primitive in _discharge_rules:
should_discharge_arg = [True] * len(eqn.invars)
rule = _discharge_rules[eqn.primitive]
else:
raise NotImplementedError("No state discharge rule implemented for "
f"primitive: {eqn.primitive}")
invals = map(env.read, eqn.invars)
in_avals = [v.aval for v in eqn.invars]
out_avals = [v.aval for v in eqn.outvars]
new_invals, ans = _discharge_rules[eqn.primitive](
new_invals, ans = rule(
in_avals, out_avals, *invals, **eqn.params)
for new_inval, invar in zip(new_invals, eqn.invars):
for invar, should, new_inval in zip(eqn.invars, should_discharge_arg, new_invals):
if not should:
if new_inval is not None:
raise ValueError(
f"Did not ask for inval to be discharged but it was. ({invar=},"
f" {new_inval=})"
)
continue
if new_inval is not None:
env.write(invar, new_inval) # type: ignore[arg-type]
else:
Expand Down
56 changes: 56 additions & 0 deletions tests/state_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -739,6 +739,41 @@ def f(ref):
in_avals = [shaped_array_ref((), jnp.float32)]
pe.trace_to_jaxpr_dynamic(lu.wrap_init(f), in_avals)

def test_partial_discharge(self):
  """Discharge only the second of two refs and check the leading equations."""
  def write_then_read(a_ref, b_ref):
    a_ref[...] = jnp.array(0., dtype=jnp.float32)
    b_ref[...] = jnp.array(1., dtype=jnp.float32)
    return a_ref[...], b_ref[...]

  ref_aval = shaped_array_ref((), jnp.float32)
  traced = pe.trace_to_jaxpr_dynamic(
      lu.wrap_init(write_then_read), [ref_aval, ref_aval])
  jaxpr, _, _, () = traced

  # Leave a_ref as a reference and discharge b_ref only.
  partially_discharged, _ = discharge_state(
      jaxpr, (), should_discharge=[False, True])
  leading_eqns = partially_discharged.eqns
  # The first two equations are expected to be a swap followed by a get.
  self.assertEqual(leading_eqns[0].primitive, swap_p)
  self.assertEqual(leading_eqns[1].primitive, get_p)

def test_partial_for_discharge(self):
  """Partially discharge a function whose `for_loop` body writes two refs."""
  def f(a_ref, b_ref):
    def loop_body(i, st):
      a_ref[...] = 0.
      b_ref[...] = 1.
    for_loop.for_loop(5, loop_body, init_state=())
    return a_ref[...], b_ref[...]

  def as_ref(x):
    return AbstractRef(core.raise_to_shaped(core.get_aval(x)))

  f_jaxpr = jax.make_jaxpr(f)(as_ref(1.), as_ref(2.))
  jaxpr, _ = discharge_state(
      f_jaxpr.jaxpr, (), should_discharge=[False, True])

  # Before discharging, both refs carry read and write effects; afterwards
  # only the undischarged a_ref (argument 0) still does.
  effects_before = {ReadEffect(0), WriteEffect(0), ReadEffect(1), WriteEffect(1)}
  effects_after = {ReadEffect(0), WriteEffect(0)}
  self.assertEqual(f_jaxpr.effects, effects_before)
  self.assertEqual(jaxpr.effects, effects_after)

  # b_ref is no longer a reference input, while a_ref still is.
  self.assertNotIsInstance(jaxpr.invars[1].aval, AbstractRef)
  self.assertIsInstance(jaxpr.invars[0].aval, AbstractRef)

  # The discharged b_ref's value is returned in addition to the two original
  # outputs.
  self.assertLen(f_jaxpr.out_avals, 2)
  self.assertLen(jaxpr.outvars, 3)


if CAN_USE_HYPOTHESIS:

Expand Down Expand Up @@ -1061,6 +1096,27 @@ def false_fun():
out = jax.jit(f)(False)
self.assertTupleEqual(out, (0., 5.))

def test_cond_discharge(self):
  """Partially discharge a function whose `cond` branches each write one ref."""
  def f0(pred, x_ref, y_ref):
    def write_x():
      x_ref[...] = 1.
    def write_y():
      y_ref[...] = 2.
    lax.cond(pred, write_x, write_y)
    return x_ref[...], y_ref[...]

  def as_ref(x):
    return AbstractRef(core.raise_to_shaped(core.get_aval(x)))

  f_jaxpr = jax.make_jaxpr(f0)(False, as_ref(3.), as_ref(4.))
  # Argument 0 is the predicate; keep x_ref as a reference, discharge y_ref.
  jaxpr, _ = discharge_state(
      f_jaxpr.jaxpr, (), should_discharge=[False, False, True])

  # Before discharging, both refs carry read and write effects; afterwards
  # only the undischarged x_ref (argument 1) still does.
  effects_before = {ReadEffect(1), WriteEffect(1), ReadEffect(2), WriteEffect(2)}
  effects_after = {ReadEffect(1), WriteEffect(1)}
  self.assertEqual(f_jaxpr.effects, effects_before)
  self.assertEqual(jaxpr.effects, effects_after)

  # y_ref is no longer a reference input, while x_ref still is.
  self.assertNotIsInstance(jaxpr.invars[2].aval, AbstractRef)
  self.assertIsInstance(jaxpr.invars[1].aval, AbstractRef)

  # The discharged y_ref's value is returned in addition to the two original
  # outputs.
  self.assertLen(f_jaxpr.out_avals, 2)
  self.assertLen(jaxpr.outvars, 3)

def test_cond_with_ref_reuse(self):
def f(pred):
def body(x_ref):
Expand Down

0 comments on commit 961680c

Please sign in to comment.