Skip to content

Commit

Permalink
aaorf - delegate converting derivative from intermediate type to pyad…
Browse files Browse the repository at this point in the history
…joint
  • Loading branch information
JHopeCollins committed Dec 17, 2024
1 parent 74a0974 commit 6c4c5b7
Showing 1 changed file with 4 additions and 9 deletions.
13 changes: 4 additions & 9 deletions firedrake/adjoint/all_at_once_reduced_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -926,16 +926,11 @@ def derivative(self, adj_input: float = 1.0, options: dict = {},
dm_forward = self.forward_model.derivative(adj_input=dm_errors[0],
options=options)

sentinel = -12345
riesz_map = options.get('riesz_representation', sentinel)
derivatives.append(dm_forward)
if riesz_map != sentinel:
if riesz_map is None:
derivatives.append(dm_errors[1])
else:
derivatives.append(dm_errors[1].riesz_representation(riesz_map))
else:
derivatives.append(dm_errors[1].riesz_representation())

# dm_errors is still in the dual space, so we need to convert it to the
# type that the user has requested - this will be the type of dm_forward.
derivatives.append(dm_forward._ad_convert_type(dm_errors[1], options))

if (rftype is None) or (rftype == 'obs'):
# derivative of reduction
Expand Down

0 comments on commit 6c4c5b7

Please sign in to comment.