Skip to content

Commit

Permalink
fix aval error
Browse files Browse the repository at this point in the history
  • Loading branch information
dougalm committed Aug 28, 2024
1 parent f4200ba commit c1f871a
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion jax/_src/interpreters/ad.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,7 +382,8 @@ def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, out_trees,
res, primals_out = split_list(res_and_primals_out, [res_tree.num_leaves])
primals_out = map(self.primal_part, primals_out)
res = map(self.primal_part, res)
avals_out = [Zero.from_primal_value(x) for x in primals_out]
avals_out = [core.primal_aval_to_tangent_aval(raise_to_shaped(core.get_aval(x)))
for x in primals_out]
# TODO(frostig,mattjj): avoid instantiating zeros when we don't have to!
with core.set_current_trace(self.parent_trace):
tangents_in = map(instantiate_zeros, tangents_in)
Expand Down

0 comments on commit c1f871a

Please sign in to comment.