Skip to content

Commit

Permalink
update remat opt
Browse files Browse the repository at this point in the history
  • Loading branch information
dougalm committed Aug 28, 2024
1 parent c1f871a commit 4944e2f
Showing 1 changed file with 7 additions and 5 deletions.
12 changes: 7 additions & 5 deletions jax/_src/custom_derivatives.py
Original file line number Diff line number Diff line change
Expand Up @@ -1418,7 +1418,7 @@ def wrapped_fwd(*args, **kwargs) -> tuple[ReturnValue, Any]:
f"functions with side effects, but {fwd_name} has the following "
f"effects: {fwd_jaxpr.effects}")

@pe._memoize
# @pe._memoize
def fun_jaxpr_thunk():
jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals)
return jaxpr, consts
Expand Down Expand Up @@ -1455,7 +1455,7 @@ def _remat_opt_abstract_eval(*args, fwd_jaxpr: core.ClosedJaxpr, **_):
return fwd_jaxpr.out_avals, fwd_jaxpr.effects

def _remat_opt_vmap(
spmd_axis_name, axis_size, axis_name, main_type, args, in_dims,
axis_data, main_type, args, in_dims,
*,
num_consts: int,
num_res: int,
Expand All @@ -1464,6 +1464,9 @@ def _remat_opt_vmap(
):
args = [batching.moveaxis(x, d, 0) if d is not not_mapped and d != 0
else x for x, d in zip(args, in_dims)]
axis_name = axis_data.name
axis_size = axis_data.size
spmd_axis_name = axis_data.spmd_name

in_batched = [d is not not_mapped for d in in_dims]
batched_fwd_jaxpr, out_batched = batching.batch_jaxpr(
Expand All @@ -1476,7 +1479,7 @@ def _remat_opt_vmap(

_, prim_batched = split_list(in_batched, [num_consts])

@pe._memoize
# @pe._memoize
def batched_fun_jaxpr_thunk():
fun_jaxpr = core.ClosedJaxpr(*fun_jaxpr_thunk())
batched_fun_jaxpr, out_batched = batching.batch_jaxpr(
Expand Down Expand Up @@ -1590,9 +1593,8 @@ def _remat_opt_dce(used_outs: list[bool], eqn: core.JaxprEqn):
mlir.register_lowering(remat_opt_p, mlir.lower_fun(
_remat_opt_impl, multiple_results=True))

# batching.spmd_axis_primitive_batchers[remat_opt_p] = _remat_opt_vmap
# batching.axis_primitive_batchers[remat_opt_p] = partial(_remat_opt_vmap, None)

batching.fancy_primitive_batchers[remat_opt_p] = _remat_opt_vmap
ad.primitive_jvps[remat_opt_p] = _remat_opt_jvp
ad.primitive_transposes[remat_opt_p] = _remat_opt_transpose
pe.dce_rules[remat_opt_p] = _remat_opt_dce

0 comments on commit 4944e2f

Please sign in to comment.