Skip to content

Commit

Permalink
missed one
Browse files Browse the repository at this point in the history
  • Loading branch information
dougalm committed Sep 16, 2024
1 parent be0cc23 commit 00abb18
Showing 1 changed file with 0 additions and 4 deletions.
4 changes: 0 additions & 4 deletions jax/_src/custom_derivatives.py
Original file line number Diff line number Diff line change
Expand Up @@ -913,14 +913,10 @@ def _custom_vjp_call_jaxpr_jvp(
res, primals_out = split_list(res_and_primals_out, [res_tree.num_leaves])
avals_out = [raise_to_shaped(core.get_aval(x)).to_tangent_aval() for x in primals_out]
args_dot = map(ad.instantiate_zeros, args_dot)
# Cast float0 to zeros with the primal dtype because custom vjp rules don't
# currently handle float0s
args_dot = map(ad.replace_float0s, args, args_dot)
tangents_out = ad.custom_lin_p.bind(
*res, *args_dot, num_res=res_tree.num_leaves, bwd=bwd,
out_avals=avals_out, symbolic_zeros=symbolic_zeros)
tangents_out = map(lax.tie_p.bind, primals_out, tangents_out)
tangents_out = map(ad.recast_to_float0, primals_out, tangents_out)
return primals_out, tangents_out
ad.primitive_jvps[custom_vjp_call_jaxpr_p] = _custom_vjp_call_jaxpr_jvp

Expand Down

0 comments on commit 00abb18

Please sign in to comment.