Skip to content

Commit

Permalink
EnsembleFunction - initial (incomplete) impl
Browse files Browse the repository at this point in the history
  • Loading branch information
JHopeCollins committed Nov 19, 2024
1 parent 6406402 commit 806a605
Show file tree
Hide file tree
Showing 5 changed files with 444 additions and 85 deletions.
1 change: 1 addition & 0 deletions firedrake/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@
from firedrake.vector import *
from firedrake.version import __version__ as ver, __version_info__, check # noqa: F401
from firedrake.ensemble import *
from firedrake.ensemblefunction import *
from firedrake.randomfunctiongen import *
from firedrake.external_operators import *
from firedrake.progress_bar import ProgressBar # noqa: F401
Expand Down
163 changes: 78 additions & 85 deletions firedrake/adjoint/all_at_once_reduced_functional.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from pyadjoint import ReducedFunctional, OverloadedType, Control, Tape, AdjFloat, \
stop_annotating, no_annotations, get_working_tape, set_working_tape
from pyadjoint.enlisting import Enlist
from firedrake import Ensemble
from functools import wraps, cached_property
from typing import Callable, Optional
from contextlib import contextmanager
Expand Down Expand Up @@ -100,11 +99,11 @@ class AllAtOnceReducedFunctional(ReducedFunctional):
----------
control
The initial condition :math:`x_{0}`. Starting value is used as the
background (prior) data :math:`x_{b}`.
The :class:`EnsembleFunction` for the control x_{i} at the initial
condition and at the end of each observation stage.
nlocal_stages
The number of observation stages on the local ensemble member.
control
The background (prior) data for the initial condition :math:`x_{b}`.
background_iprod
The inner product to calculate the background error functional
Expand All @@ -126,56 +125,56 @@ class AllAtOnceReducedFunctional(ReducedFunctional):
weak_constraint
Whether to use the weak or strong constraint 4DVar formulation.
ensemble
The ensemble communicator to parallelise over. None for no time parallelism.
If `ensemble` is provided, then `background_iprod`, `observation_err` and
`observation_iprod` must only be provided on ensemble rank 0.
See Also
--------
:class:`pyadjoint.ReducedFunctional`.
"""

def __init__(self, control: Control,
nlocal_stages: int,
background: OverloadedType,
background_iprod: Optional[Callable[[OverloadedType], AdjFloat]],
observation_err: Optional[Callable[[OverloadedType], OverloadedType]] = None,
observation_iprod: Optional[Callable[[OverloadedType], AdjFloat]] = None,
weak_constraint: bool = True,
tape: Optional[Tape] = None,
_annotate_accumulation: bool = False,
ensemble: Optional[Ensemble] = None):
_annotate_accumulation: bool = False):

self.tape = get_working_tape() if tape is None else tape

self.weak_constraint = weak_constraint
self.initial_observations = observation_err is not None

# We need a copy for the prior, but this shouldn't be part of the tape
with stop_annotating():
self.background = control.copy_data()
self.background = background._ad_copy()
_rename(self.background, "Background")

if self.weak_constraint:
self._annotate_accumulation = _annotate_accumulation
self._accumulation_started = False

self.nlocal_stages = nlocal_stages

ensemble = control.ensemble
self.ensemble = ensemble
self.trank = ensemble.ensemble_comm.rank if ensemble else 0
self.nchunks = ensemble.ensemble_comm.size if ensemble else 1

x = control.control.subfunctions
self.x = x

self.control = control
self._controls = tuple(Control(xi) for xi in x)

# first control on rank 0 is initial conditions, not end of observation stage
self.nlocal_stages = len(x) - (1 if self.trank == 0 else 0)

self.stages = [] # The record of each observation stage
self.controls = [] # The solution at the beginning of each time-chunk

# first rank sets up functionals for background initial observations
if self.trank == 0:
self.controls.append(control)

# RF to recalculate error vector (x_0 - x_b)
self.background_error = isolated_rf(
operation=lambda x0: _ad_sub(x0, self.background),
control=control,
control=x[0],
functional_name="bkg_err_vec",
control_name="Control_0_bkg_copy")

Expand All @@ -190,7 +189,7 @@ def __init__(self, control: Control,
# RF to recalculate error vector (H(x_0) - y_0)
self.initial_observation_error = isolated_rf(
operation=observation_err,
control=control,
control=x[0],
functional_name="obs_err_vec_0",
control_name="Control_0_obs_copy")

Expand All @@ -201,10 +200,6 @@ def __init__(self, control: Control,
functional_name="obs_err_vec_0_copy")

else:
# create halo for previous state
with stop_annotating():
self.xprev = control.copy_data()
self.control_prev = Control(self.xprev)

if background_iprod is not None:
raise ValueError("Only the first ensemble rank needs `background_iprod`")
Expand All @@ -213,21 +208,16 @@ def __init__(self, control: Control,
if observation_err is not None:
raise ValueError("Only the first ensemble rank needs `observation_err`")

# create all controls on local ensemble member
with stop_annotating():
for _ in range(nlocal_stages):
self.controls.append(Control(control.copy_data()))
# create halo for previous state
if self.ensemble and self.trank != 0:
with stop_annotating():
self.xprev = x[0]._ad_copy()
self._control_prev = Control(self.xprev)

# halo for the derivative from the next chunk
if self.ensemble and self.trank != self.nchunks - 1:
self.xnext = control.copy_data()

# new tape for the initial stage
if self.trank == 0:
self.stages.append(
WeakObservationStage(self.controls[0], index=0))
else:
self._stage_tape = None
with stop_annotating():
self.xnext = x[0]._ad_copy()

else:
self._annotate_accumulation = True
Expand Down Expand Up @@ -292,53 +282,54 @@ def __call__(self, values: OverloadedType):
The computed value. Typically of instance of :class:`pyadjoint.AdjFloat`.
"""
for c, v in zip(self.controls, values):
c.control.assign(v)
self.control.assign(values)
trank = self.trank

# first "control" for later ranks is the halo
if self.ensemble and trank != 0:
x = [self.xprev, *self.x]
else:
x = [*self.x]

# post messages for control of forward model propogation on next chunk
trank = self.trank
if self.ensemble:
src = trank - 1
dst = trank + 1

if trank != self.nchunks - 1:
self.ensemble.isend(
self.controls[-1].control, dest=dst, tag=dst)
x[-1], dest=dst, tag=dst)

if trank != 0:
recv_reqs = self.ensemble.irecv(
self.xprev, source=src, tag=trank)

# first "control" is the halo
if self.ensemble and trank != 0:
values = [self.xprev, *values]

# Initial condition functionals
if trank == 0:
Jlocal = (
self.background_norm(
self.background_error(values[0]))
self.background_error(x[0]))
)

# observations at time 0
if self.initial_observations:
Jlocal += (
self.initial_observation_norm(
self.initial_observation_error(values[0]))
self.initial_observation_error(x[0]))
)
else:
Jlocal = 0.

# evaluate all stages on chunk except first
for i in range(1, len(self.stages)):
Jlocal += self.stages[i](values[i:i+2])
Jlocal += self.stages[i](x[i:i+2])

# wait for halo swap to finish
if trank != 0:
MPI.Request.Waitall(recv_reqs)

# evaluate first stage model on chunk now we have data
Jlocal += self.stages[0](values[0:2])
Jlocal += self.stages[0](x[0:2])

# sum all stages
if self.ensemble:
Expand Down Expand Up @@ -370,8 +361,16 @@ def derivative(self, adj_input: float = 1.0, options: dict = {}):
The derivative with respect to the control.
Should be an instance of the same type as the control.
"""
trank = self.trank
# create a list of overloaded types to put derivative into
derivatives = []
derivatives = self.control._ad_copy()
derivatives.zero()

if self.ensemble and trank != 0:
self.xprev.zero()
derivs = [self.xprev, *derivatives.subfunctions]
else:
derivs = [*derivatives.subfunctions]

# chaining ReducedFunctionals means we need to pass Cofunctions not Functions
options = options or {}
Expand All @@ -382,57 +381,50 @@ def derivative(self, adj_input: float = 1.0, options: dict = {}):
}

# initial condition derivatives
if self.trank == 0:
if trank == 0:
bkg_deriv = self.background_norm.derivative(adj_input=adj_input,
options=intermediate_options)
derivatives.append(self.background_error.derivative(adj_input=bkg_deriv,
options=options))
derivs[0] += self.background_error.derivative(adj_input=bkg_deriv,
options=options)

# observations at time 0
if self.initial_observations:
obs_deriv = self.initial_observation_norm.derivative(adj_input=adj_input,
options=intermediate_options)
derivatives[0] += self.initial_observation_error.derivative(adj_input=obs_deriv,
options=options)
derivs[0] += self.initial_observation_error.derivative(adj_input=obs_deriv,
options=options)

# evaluate first forward model, which contributes to previous chunk
derivs = self.stages[0].derivative(adj_input=adj_input, options=options)
sderiv0 = self.stages[0].derivative(adj_input=adj_input, options=options)

if self.trank == 0:
derivatives[0] += derivs[0]
else:
derivatives.append(derivs[0])
derivatives.append(derivs[1])
derivs[0] += sderiv0[0]
derivs[1] += sderiv0[1]

# post the derivative halo exchange
from firedrake import norm

Check failure on line 404 in firedrake/adjoint/all_at_once_reduced_functional.py

View workflow job for this annotation

GitHub Actions / Run linter

F401

firedrake/adjoint/all_at_once_reduced_functional.py:404:9: F401 'firedrake.norm' imported but unused
if self.ensemble:
src = self.trank + 1
dst = self.trank - 1
src = trank + 1
dst = trank - 1

if self.trank != 0:
if trank != 0:
self.ensemble.isend(
derivatives[0], dest=dst, tag=dst)
derivs[0], dest=dst, tag=dst)

if self.trank != self.nchunks - 1:
if trank != self.nchunks - 1:
recv_reqs = self.ensemble.irecv(
self.xnext, source=src, tag=self.trank)
self.xnext, source=src, tag=trank)

# # evaluate all forward models on chunk except first while halo in flight
for i in range(1, len(self.stages)):
derivs = self.stages[i].derivative(adj_input=adj_input, options=options)
derivatives[i] += derivs[0]
derivatives.append(derivs[1])
sderiv = self.stages[i].derivative(adj_input=adj_input, options=options)
derivs[i] += sderiv[0]
derivs[i+1] += sderiv[1]

# finish the derivative halo exchange
if self.ensemble:
if self.trank != self.nchunks - 1:
if trank != self.nchunks - 1:
MPI.Request.Waitall(recv_reqs)
derivatives[-1] += self.xnext

# we don't own the control for the halo, so remove it from the
# list of local derivatives once the communication has finished
if self.trank != 0:
derivatives.pop(0)
derivs[-1] += self.xnext

return derivatives

Expand Down Expand Up @@ -514,10 +506,10 @@ def recording_stages(self, sequential=True, **kwargs):
stage_kwargs['local_index'] = 0

# subsequent ranks start from halo
controls = self.controls if trank == 0 else [self.control_prev, *self.controls]
controls = self._controls if trank == 0 else [self._control_prev, *self._controls]

stage_sequence = ObservationStageSequence(
controls, self, stage_kwargs, sequential, weak_constraint=True)
controls, self, stage_kwargs, sequential)

yield stage_sequence

Expand All @@ -541,23 +533,21 @@ def recording_stages(self, sequential=True, **kwargs):
else: # strong constraint

yield ObservationStageSequence(
self.controls, self, stage_kwargs,
sequential=True, weak_constraint=False)
self.controls, self, stage_kwargs, sequential=True)


class ObservationStageSequence:
def __init__(self, controls: Control,
aaorf: AllAtOnceReducedFunctional,
stage_kwargs: dict = None,
sequential: bool = True,
weak_constraint: bool = True):
sequential: bool = True):
self.controls = controls
self.nstages = len(controls) - 1
self.aaorf = aaorf
self.ctx = StageContext(**(stage_kwargs or {}))
self.index = 0
self.weak_constraint = weak_constraint
if weak_constraint:
self.weak_constraint = aaorf.weak_constraint
if self.weak_constraint:
self.stages = []

def __iter__(self):
Expand Down Expand Up @@ -792,6 +782,9 @@ def set_observation(self, state: OverloadedType,
# remove the stage initial condition "control" now we've finished recording
delattr(self, "control")

# stop the stage tape recording anything else
set_working_tape()

@no_annotations
def __call__(self, values: OverloadedType,
rftype: Optional[str] = None):
Expand Down
1 change: 1 addition & 0 deletions firedrake/adjoint_utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,4 @@
from firedrake.adjoint_utils.solving import * # noqa: F401
from firedrake.adjoint_utils.mesh import * # noqa: F401
from firedrake.adjoint_utils.checkpointing import * # noqa: F401
from firedrake.adjoint_utils.ensemblefunction import * # noqa: F401
Loading

0 comments on commit 806a605

Please sign in to comment.