From 4944e2fce3d74c1893400291ccf680c1622f3578 Mon Sep 17 00:00:00 2001 From: Dougal Date: Wed, 28 Aug 2024 20:28:09 +0000 Subject: [PATCH] update remat opt --- jax/_src/custom_derivatives.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/jax/_src/custom_derivatives.py b/jax/_src/custom_derivatives.py index 2696675654ad..425e46499657 100644 --- a/jax/_src/custom_derivatives.py +++ b/jax/_src/custom_derivatives.py @@ -1418,7 +1418,7 @@ def wrapped_fwd(*args, **kwargs) -> tuple[ReturnValue, Any]: f"functions with side effects, but {fwd_name} has the following " f"effects: {fwd_jaxpr.effects}") - @pe._memoize + # @pe._memoize def fun_jaxpr_thunk(): jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals) return jaxpr, consts @@ -1455,7 +1455,7 @@ def _remat_opt_abstract_eval(*args, fwd_jaxpr: core.ClosedJaxpr, **_): return fwd_jaxpr.out_avals, fwd_jaxpr.effects def _remat_opt_vmap( - spmd_axis_name, axis_size, axis_name, main_type, args, in_dims, + axis_data, main_type, args, in_dims, *, num_consts: int, num_res: int, @@ -1464,6 +1464,9 @@ def _remat_opt_vmap( ): args = [batching.moveaxis(x, d, 0) if d is not not_mapped and d != 0 else x for x, d in zip(args, in_dims)] + axis_name = axis_data.name + axis_size = axis_data.size + spmd_axis_name = axis_data.spmd_name in_batched = [d is not not_mapped for d in in_dims] batched_fwd_jaxpr, out_batched = batching.batch_jaxpr( @@ -1476,7 +1479,7 @@ def _remat_opt_vmap( _, prim_batched = split_list(in_batched, [num_consts]) - @pe._memoize + # @pe._memoize def batched_fun_jaxpr_thunk(): fun_jaxpr = core.ClosedJaxpr(*fun_jaxpr_thunk()) batched_fun_jaxpr, out_batched = batching.batch_jaxpr( @@ -1590,9 +1593,8 @@ def _remat_opt_dce(used_outs: list[bool], eqn: core.JaxprEqn): mlir.register_lowering(remat_opt_p, mlir.lower_fun( _remat_opt_impl, multiple_results=True)) -# batching.spmd_axis_primitive_batchers[remat_opt_p] = _remat_opt_vmap -# batching.axis_primitive_batchers[remat_opt_p] = partial(_remat_opt_vmap, None) +batching.fancy_primitive_batchers[remat_opt_p] = _remat_opt_vmap ad.primitive_jvps[remat_opt_p] = _remat_opt_jvp ad.primitive_transposes[remat_opt_p] = _remat_opt_transpose pe.dce_rules[remat_opt_p] = _remat_opt_dce