Skip to content

Commit

Permalink
aaorf - pass riesz_representation through chained rfs properly
Browse files Browse the repository at this point in the history
  • Loading branch information
JHopeCollins committed Dec 17, 2024
1 parent d15eba6 commit 74a0974
Showing 1 changed file with 25 additions and 13 deletions.
38 changes: 25 additions & 13 deletions firedrake/adjoint/all_at_once_reduced_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from typing import Callable, Optional
from contextlib import contextmanager
from mpi4py import MPI
from firedrake.petsc import PETSc

__all__ = ['AllAtOnceReducedFunctional']

Expand Down Expand Up @@ -77,6 +76,20 @@ def _ad_sub(left, right):
return result


def _intermediate_options(final_options):
"""
Options set for the intermediate stages of a chain of ReducedFunctionals
Takes all elements of the final_options except riesz_representation,
which is set to prevent returning derivatives to the primal space.
"""
return {
'riesz_representation': None,
**{k: v for k, v in final_options.items()
if (k != 'riesz_representation')}
}


class AllAtOnceReducedFunctional(ReducedFunctional):
"""ReducedFunctional for 4DVar data assimilation.
Expand Down Expand Up @@ -359,11 +372,7 @@ def derivative(self, adj_input: float = 1.0, options: dict = {}):

# chaining ReducedFunctionals means we need to pass Cofunctions not Functions
options = options or {}
intermediate_options = {
'riesz_representation': 'l2',
**{k: v for k, v in options.items()
if (k != 'riesz_representation')}
}
intermediate_options = _intermediate_options(options)

# evaluate first forward model, which contributes to previous chunk
sderiv0 = self.stages[0].derivative(
Expand Down Expand Up @@ -627,7 +636,6 @@ def __next__(self):
stage = StrongObservationStage(control, self.aaorf)
self._prev_stage = stage


return stage, self.ctx


Expand Down Expand Up @@ -903,11 +911,7 @@ def derivative(self, adj_input: float = 1.0, options: dict = {},

# chaining ReducedFunctionals means we need to pass Cofunctions not Functions
options = options or {}
intermediate_options = {
'riesz_representation': None,
**{k: v for k, v in options.items()
if (k != 'riesz_representation')}
}
intermediate_options = _intermediate_options(options)

if (rftype is None) or (rftype == 'model'):
# derivative of reduction
Expand All @@ -922,8 +926,16 @@ def derivative(self, adj_input: float = 1.0, options: dict = {},
dm_forward = self.forward_model.derivative(adj_input=dm_errors[0],
options=options)

sentinel = -12345
riesz_map = options.get('riesz_representation', sentinel)
derivatives.append(dm_forward)
derivatives.append(dm_errors[1].riesz_representation())
if riesz_map != sentinel:
if riesz_map is None:
derivatives.append(dm_errors[1])
else:
derivatives.append(dm_errors[1].riesz_representation(riesz_map))
else:
derivatives.append(dm_errors[1].riesz_representation())

if (rftype is None) or (rftype == 'obs'):
# derivative of reduction
Expand Down

0 comments on commit 74a0974

Please sign in to comment.