From 806a6054f3391158019c73d6a767ede68741f176 Mon Sep 17 00:00:00 2001
From: Josh Hope-Collins
Date: Tue, 19 Nov 2024 15:59:07 +0000
Subject: [PATCH] EnsembleFunction - initial (incomplete) impl

---
 firedrake/__init__.py                         |   1 +
 .../adjoint/all_at_once_reduced_functional.py | 163 +++++-----
 firedrake/adjoint_utils/__init__.py           |   1 +
 firedrake/adjoint_utils/ensemblefunction.py   |  69 ++++
 firedrake/ensemblefunction.py                 | 295 ++++++++++++++++++
 5 files changed, 444 insertions(+), 85 deletions(-)
 create mode 100644 firedrake/adjoint_utils/ensemblefunction.py
 create mode 100644 firedrake/ensemblefunction.py

diff --git a/firedrake/__init__.py b/firedrake/__init__.py
index 0fc9aeeed6..b34a4f3d31 100644
--- a/firedrake/__init__.py
+++ b/firedrake/__init__.py
@@ -116,6 +116,7 @@ from firedrake.vector import *
 from firedrake.version import __version__ as ver, __version_info__, check  # noqa: F401
 from firedrake.ensemble import *
+from firedrake.ensemblefunction import *
 from firedrake.randomfunctiongen import *
 from firedrake.external_operators import *
 from firedrake.progress_bar import ProgressBar  # noqa: F401
diff --git a/firedrake/adjoint/all_at_once_reduced_functional.py b/firedrake/adjoint/all_at_once_reduced_functional.py
index 6a4e84617c..2c5ad915ea 100644
--- a/firedrake/adjoint/all_at_once_reduced_functional.py
+++ b/firedrake/adjoint/all_at_once_reduced_functional.py
@@ -1,7 +1,6 @@
 from pyadjoint import ReducedFunctional, OverloadedType, Control, Tape, AdjFloat, \
     stop_annotating, no_annotations, get_working_tape, set_working_tape
 from pyadjoint.enlisting import Enlist
-from firedrake import Ensemble
 from functools import wraps, cached_property
 from typing import Callable, Optional
 from contextlib import contextmanager
@@ -100,11 +99,11 @@ class AllAtOnceReducedFunctional(ReducedFunctional):
     ----------
 
     control
-        The initial condition :math:`x_{0}`. Starting value is used as the
-        background (prior) data :math:`x_{b}`.
+        The :class:`EnsembleFunction` holding the controls :math:`x_{i}`, i.e. the
+        initial condition and the state at the end of each observation stage.
 
-    nlocal_stages
-        The number of observation stages on the local ensemble member.
+    background
+        The background (prior) data for the initial condition :math:`x_{b}`.
 
     background_iprod
         The inner product to calculate the background error functional
@@ -126,56 +125,56 @@ class AllAtOnceReducedFunctional(ReducedFunctional):
     weak_constraint
         Whether to use the weak or strong constraint 4DVar formulation.
 
-    ensemble
-        The ensemble communicator to parallelise over. None for no time parallelism.
-        If `ensemble` is provided, then `background_iprod`, `observation_err` and
-        `observation_iprod` must only be provided on ensemble rank 0.
-
     See Also
     --------
     :class:`pyadjoint.ReducedFunctional`.
""" def __init__(self, control: Control, - nlocal_stages: int, + background: OverloadedType, background_iprod: Optional[Callable[[OverloadedType], AdjFloat]], observation_err: Optional[Callable[[OverloadedType], OverloadedType]] = None, observation_iprod: Optional[Callable[[OverloadedType], AdjFloat]] = None, weak_constraint: bool = True, tape: Optional[Tape] = None, - _annotate_accumulation: bool = False, - ensemble: Optional[Ensemble] = None): + _annotate_accumulation: bool = False): self.tape = get_working_tape() if tape is None else tape self.weak_constraint = weak_constraint self.initial_observations = observation_err is not None - # We need a copy for the prior, but this shouldn't be part of the tape with stop_annotating(): - self.background = control.copy_data() + self.background = background._ad_copy() + _rename(self.background, "Background") if self.weak_constraint: self._annotate_accumulation = _annotate_accumulation self._accumulation_started = False - self.nlocal_stages = nlocal_stages - + ensemble = control.ensemble self.ensemble = ensemble self.trank = ensemble.ensemble_comm.rank if ensemble else 0 self.nchunks = ensemble.ensemble_comm.size if ensemble else 1 + x = control.control.subfunctions + self.x = x + + self.control = control + self._controls = tuple(Control(xi) for xi in x) + + # first control on rank 0 is initial conditions, not end of observation stage + self.nlocal_stages = len(x) - (1 if self.trank == 0 else 0) + self.stages = [] # The record of each observation stage - self.controls = [] # The solution at the beginning of each time-chunk # first rank sets up functionals for background initial observations if self.trank == 0: - self.controls.append(control) # RF to recalculate error vector (x_0 - x_b) self.background_error = isolated_rf( operation=lambda x0: _ad_sub(x0, self.background), - control=control, + control=x[0], functional_name="bkg_err_vec", control_name="Control_0_bkg_copy") @@ -190,7 +189,7 @@ def __init__(self, control: Control, # RF to recalculate error vector (H(x_0) - y_0) self.initial_observation_error = isolated_rf( operation=observation_err, - control=control, + control=x[0], functional_name="obs_err_vec_0", control_name="Control_0_obs_copy") @@ -201,10 +200,6 @@ def __init__(self, control: Control, functional_name="obs_err_vec_0_copy") else: - # create halo for previous state - with stop_annotating(): - self.xprev = control.copy_data() - self.control_prev = Control(self.xprev) if background_iprod is not None: raise ValueError("Only the first ensemble rank needs `background_iprod`") @@ -213,21 +208,16 @@ def __init__(self, control: Control, if observation_err is not None: raise ValueError("Only the first ensemble rank needs `observation_err`") - # create all controls on local ensemble member - with stop_annotating(): - for _ in range(nlocal_stages): - self.controls.append(Control(control.copy_data())) + # create halo for previous state + if self.ensemble and self.trank != 0: + with stop_annotating(): + self.xprev = x[0]._ad_copy() + self._control_prev = Control(self.xprev) # halo for the derivative from the next chunk if self.ensemble and self.trank != self.nchunks - 1: - self.xnext = control.copy_data() - - # new tape for the initial stage - if self.trank == 0: - self.stages.append( - WeakObservationStage(self.controls[0], index=0)) - else: - self._stage_tape = None + with stop_annotating(): + self.xnext = x[0]._ad_copy() else: self._annotate_accumulation = True @@ -292,53 +282,54 @@ def __call__(self, values: OverloadedType): The computed 
value. Typically of instance of :class:`pyadjoint.AdjFloat`. """ - for c, v in zip(self.controls, values): - c.control.assign(v) + self.control.assign(values) + trank = self.trank + + # first "control" for later ranks is the halo + if self.ensemble and trank != 0: + x = [self.xprev, *self.x] + else: + x = [*self.x] # post messages for control of forward model propogation on next chunk - trank = self.trank if self.ensemble: src = trank - 1 dst = trank + 1 if trank != self.nchunks - 1: self.ensemble.isend( - self.controls[-1].control, dest=dst, tag=dst) + x[-1], dest=dst, tag=dst) if trank != 0: recv_reqs = self.ensemble.irecv( self.xprev, source=src, tag=trank) - # first "control" is the halo - if self.ensemble and trank != 0: - values = [self.xprev, *values] - # Initial condition functionals if trank == 0: Jlocal = ( self.background_norm( - self.background_error(values[0])) + self.background_error(x[0])) ) # observations at time 0 if self.initial_observations: Jlocal += ( self.initial_observation_norm( - self.initial_observation_error(values[0])) + self.initial_observation_error(x[0])) ) else: Jlocal = 0. # evaluate all stages on chunk except first for i in range(1, len(self.stages)): - Jlocal += self.stages[i](values[i:i+2]) + Jlocal += self.stages[i](x[i:i+2]) # wait for halo swap to finish if trank != 0: MPI.Request.Waitall(recv_reqs) # evaluate first stage model on chunk now we have data - Jlocal += self.stages[0](values[0:2]) + Jlocal += self.stages[0](x[0:2]) # sum all stages if self.ensemble: @@ -370,8 +361,16 @@ def derivative(self, adj_input: float = 1.0, options: dict = {}): The derivative with respect to the control. Should be an instance of the same type as the control. """ + trank = self.trank # create a list of overloaded types to put derivative into - derivatives = [] + derivatives = self.control._ad_copy() + derivatives.zero() + + if self.ensemble and trank != 0: + self.xprev.zero() + derivs = [self.xprev, *derivatives.subfunctions] + else: + derivs = [*derivatives.subfunctions] # chaining ReducedFunctionals means we need to pass Cofunctions not Functions options = options or {} @@ -382,57 +381,50 @@ def derivative(self, adj_input: float = 1.0, options: dict = {}): } # initial condition derivatives - if self.trank == 0: + if trank == 0: bkg_deriv = self.background_norm.derivative(adj_input=adj_input, options=intermediate_options) - derivatives.append(self.background_error.derivative(adj_input=bkg_deriv, - options=options)) + derivs[0] += self.background_error.derivative(adj_input=bkg_deriv, + options=options) # observations at time 0 if self.initial_observations: obs_deriv = self.initial_observation_norm.derivative(adj_input=adj_input, options=intermediate_options) - derivatives[0] += self.initial_observation_error.derivative(adj_input=obs_deriv, - options=options) + derivs[0] += self.initial_observation_error.derivative(adj_input=obs_deriv, + options=options) # evaluate first forward model, which contributes to previous chunk - derivs = self.stages[0].derivative(adj_input=adj_input, options=options) + sderiv0 = self.stages[0].derivative(adj_input=adj_input, options=options) - if self.trank == 0: - derivatives[0] += derivs[0] - else: - derivatives.append(derivs[0]) - derivatives.append(derivs[1]) + derivs[0] += sderiv0[0] + derivs[1] += sderiv0[1] # post the derivative halo exchange + from firedrake import norm if self.ensemble: - src = self.trank + 1 - dst = self.trank - 1 + src = trank + 1 + dst = trank - 1 - if self.trank != 0: + if trank != 0: self.ensemble.isend( - 
derivatives[0], dest=dst, tag=dst) + derivs[0], dest=dst, tag=dst) - if self.trank != self.nchunks - 1: + if trank != self.nchunks - 1: recv_reqs = self.ensemble.irecv( - self.xnext, source=src, tag=self.trank) + self.xnext, source=src, tag=trank) # # evaluate all forward models on chunk except first while halo in flight for i in range(1, len(self.stages)): - derivs = self.stages[i].derivative(adj_input=adj_input, options=options) - derivatives[i] += derivs[0] - derivatives.append(derivs[1]) + sderiv = self.stages[i].derivative(adj_input=adj_input, options=options) + derivs[i] += sderiv[0] + derivs[i+1] += sderiv[1] # finish the derivative halo exchange if self.ensemble: - if self.trank != self.nchunks - 1: + if trank != self.nchunks - 1: MPI.Request.Waitall(recv_reqs) - derivatives[-1] += self.xnext - - # we don't own the control for the halo, so remove it from the - # list of local derivatives once the communication has finished - if self.trank != 0: - derivatives.pop(0) + derivs[-1] += self.xnext return derivatives @@ -514,10 +506,10 @@ def recording_stages(self, sequential=True, **kwargs): stage_kwargs['local_index'] = 0 # subsequent ranks start from halo - controls = self.controls if trank == 0 else [self.control_prev, *self.controls] + controls = self._controls if trank == 0 else [self._control_prev, *self._controls] stage_sequence = ObservationStageSequence( - controls, self, stage_kwargs, sequential, weak_constraint=True) + controls, self, stage_kwargs, sequential) yield stage_sequence @@ -541,23 +533,21 @@ def recording_stages(self, sequential=True, **kwargs): else: # strong constraint yield ObservationStageSequence( - self.controls, self, stage_kwargs, - sequential=True, weak_constraint=False) + self.controls, self, stage_kwargs, sequential=True) class ObservationStageSequence: def __init__(self, controls: Control, aaorf: AllAtOnceReducedFunctional, stage_kwargs: dict = None, - sequential: bool = True, - weak_constraint: bool = True): + sequential: bool = True): self.controls = controls self.nstages = len(controls) - 1 self.aaorf = aaorf self.ctx = StageContext(**(stage_kwargs or {})) self.index = 0 - self.weak_constraint = weak_constraint - if weak_constraint: + self.weak_constraint = aaorf.weak_constraint + if self.weak_constraint: self.stages = [] def __iter__(self): @@ -792,6 +782,9 @@ def set_observation(self, state: OverloadedType, # remove the stage initial condition "control" now we've finished recording delattr(self, "control") + # stop the stage tape recording anything else + set_working_tape() + @no_annotations def __call__(self, values: OverloadedType, rftype: Optional[str] = None): diff --git a/firedrake/adjoint_utils/__init__.py b/firedrake/adjoint_utils/__init__.py index 3b3426a850..c71da61ade 100644 --- a/firedrake/adjoint_utils/__init__.py +++ b/firedrake/adjoint_utils/__init__.py @@ -12,3 +12,4 @@ from firedrake.adjoint_utils.solving import * # noqa: F401 from firedrake.adjoint_utils.mesh import * # noqa: F401 from firedrake.adjoint_utils.checkpointing import * # noqa: F401 +from firedrake.adjoint_utils.ensemblefunction import * # noqa: F401 diff --git a/firedrake/adjoint_utils/ensemblefunction.py b/firedrake/adjoint_utils/ensemblefunction.py new file mode 100644 index 0000000000..ee1e4a4ce3 --- /dev/null +++ b/firedrake/adjoint_utils/ensemblefunction.py @@ -0,0 +1,69 @@ +from pyadjoint.overloaded_type import OverloadedType +from functools import wraps + + +class EnsembleFunctionMixin(OverloadedType): + + @staticmethod + def _ad_annotate_init(init): + 
@wraps(init)
+        def wrapper(self, *args, **kwargs):
+            OverloadedType.__init__(self)
+            init(self, *args, **kwargs)
+        return wrapper
+
+    @staticmethod
+    def _ad_to_list(m):
+        raise ValueError("NotImplementedYet")
+
+    @staticmethod
+    def _ad_assign_numpy(dst, src, offset):
+        raise ValueError("NotImplementedYet")
+
+    def _ad_dot(self, other, options=None):
+        # local dot product
+        ldot = sum(
+            uself._ad_dot(uother, options=options)
+            for uself, uother in zip(self.subfunctions,
+                                     other.subfunctions))
+        # global dot product
+        gdot = self.ensemble.ensemble_comm.allreduce(ldot)
+        return gdot
+
+    def _ad_add(self, other):
+        new = self.copy()
+        new += other
+        return new
+
+    def _ad_mul(self, other):
+        new = self.copy()
+        # `self` can be a Cofunction in which case only left multiplication with a scalar is allowed.
+        other = other._fbuf if type(other) is type(self) else other
+        new._fbuf.assign(other*new._fbuf)
+        return new
+
+    def _ad_iadd(self, other):
+        self += other
+        return self
+
+    def _ad_imul(self, other):
+        self *= other
+        return self
+
+    def _ad_copy(self):
+        return self.copy()
+
+    def _ad_convert_riesz(self, value, options=None):
+        raise ValueError("NotImplementedYet")
+
+    def _ad_from_petsc(self, vec):
+        with self.vec_wo() as self_v:
+            vec.copy(result=self_v)
+
+    def _ad_to_petsc(self, vec=None):
+        with self.vec_ro() as self_v:
+            if vec:
+                self_v.copy(result=vec)
+            else:
+                vec = self_v.copy()
+        return vec
diff --git a/firedrake/ensemblefunction.py b/firedrake/ensemblefunction.py
new file mode 100644
index 0000000000..33ed7e7b2f
--- /dev/null
+++ b/firedrake/ensemblefunction.py
@@ -0,0 +1,295 @@
+from firedrake.petsc import PETSc
+from firedrake.adjoint_utils import EnsembleFunctionMixin
+from firedrake.functionspace import MixedFunctionSpace
+from firedrake.function import Function
+from ufl.duals import is_primal, is_dual
+from pyop2 import MixedDat
+
+from functools import cached_property
+from contextlib import contextmanager
+
+__all__ = ("EnsembleFunction", "EnsembleCofunction")
+
+
+class EnsembleFunctionBase(EnsembleFunctionMixin):
+    """
+    A mixed finite element (co)function distributed over an ensemble.
+
+    Parameters
+    ----------
+
+    ensemble
+        The ensemble communicator. The sub(co)functions are distributed
+        over the different ensemble members.
+
+    function_spaces
+        A list of function spaces for each (co)function on the
+        local ensemble member.
+    """
+
+    @PETSc.Log.EventDecorator()
+    @EnsembleFunctionMixin._ad_annotate_init
+    def __init__(self, ensemble, function_spaces):
+        self.ensemble = ensemble
+        self.local_function_spaces = function_spaces
+        self.local_size = len(function_spaces)
+
+        # the local functions are stored as a big mixed space
+        self._function_space = MixedFunctionSpace(function_spaces)
+        self._fbuf = Function(self._function_space)
+
+        # create a Vec containing the data for all functions on all
+        # ensemble members. Because we use the Vec of each local mixed
+        # function as the storage, if the data in the Function Vec
+        # is valid then the data in the EnsembleFunction Vec is valid.
+ + with self._fbuf.dat.vec as fvec: + local_size = self._function_space.node_set.size + sizes = (local_size, PETSc.DETERMINE) + self._vec = PETSc.Vec().createWithArray(fvec.array, + size=sizes, + comm=ensemble.global_comm) + self._vec.setFromOptions() + + @cached_property + def subfunctions(self): + """ + The (co)functions on the local ensemble member + """ + def local_function(i): + V = self.local_function_spaces[i] + usubs = self._subcomponents(i) + if len(usubs) == 1: + dat = usubs[0].dat + else: + dat = MixedDat((u.dat for u in usubs)) + return Function(V, val=dat) + + self._subfunctions = tuple(local_function(i) + for i in range(self.local_size)) + return self._subfunctions + + def _subcomponents(self, i): + """ + Return the subfunctions of the local mixed function storage + corresponding to the i-th local function. + """ + return tuple(self._fbuf.subfunctions[j] + for j in self._component_indices(i)) + + def _component_indices(self, i): + """ + Return the indices into the local mixed function storage + corresponding to the i-th local function. + """ + V = self.local_function_spaces[i] + offset = sum(len(V) for V in self.local_function_spaces[:i]) + return tuple(offset + i for i in range(len(V))) + + @PETSc.Log.EventDecorator() + def riesz_representation(self, riesz_map="L2", **kwargs): + """ + Return the Riesz representation of this :class:`EnsembleFunction` + with respect to the given Riesz map. + + Parameters + ---------- + + riesz_map + The Riesz map to use (`l2`, `L2`, or `H1`). This can also be a callable. + + kwargs + other arguments to be passed to the firedrake.riesz_map. + """ + DualType = { + EnsembleFunction: EnsembleCofunction, + EnsembleCofunction: EnsembleFunction, + }[type(self)] + Vdual = [V.dual() for V in self.local_function_spaces] + riesz = DualType(self.ensemble, Vdual) + for u in riesz.subfunctions: + u.assign(u.riesz_representation(riesz_map=riesz_map, **kwargs)) + return riesz + + @PETSc.Log.EventDecorator() + def assign(self, other, subsets=None): + r"""Set the :class:`EnsembleFunction` to the value of another + :class:`EnsembleFunction` other. + + Parameters + ---------- + + other + The :class:`EnsembleFunction` to assign from. + + subsets + An iterable of :class:`pyop2.types.set.Subset`, one for each local :class:`Function`. + The values of each local function will then only + be assigned on the nodes on the corresponding subset. + """ + if type(other) is not type(self): + raise ValueError( + f"Cannot assign {type(self)} from {type(other)}") + if subsets: + for i in range(self.local_size): + self.subfunctions[i].assign( + other.subfunctions[i], subset=subsets[i]) + else: + for i in range(self.local_size): + self.subfunctions[i].assign(other.subfunctions[i]) + return self + + @PETSc.Log.EventDecorator() + def copy(self): + """ + Return a deep copy of the :class:`EnsembleFunction`. + """ + new = type(self)(self.ensemble, self.local_function_spaces) + new.assign(self) + return new + + @PETSc.Log.EventDecorator() + def zero(self, subsets=None): + """ + Set values to zero. + + Parameters + ---------- + + subsets + An iterable of :class:`pyop2.types.set.Subset`, one for each local :class:`Function`. + The values of each local function will then only + be assigned on the nodes on the corresponding subset. 
+ """ + if subsets: + for i in range(self.local_size): + self.subfunctions[i].zero(subsets[i]) + else: + for u in self.subfunctions: + u.zero() + return self + + @PETSc.Log.EventDecorator() + def __iadd__(self, other): + for us, uo in zip(self.subfunctions, other.subfunctions): + us.assign(us + uo) + return self + + @PETSc.Log.EventDecorator() + def __imul__(self, other): + if type(other) is type(self): + for us, uo in zip(self.subfunctions, other.subfunctions): + us.assign(us*uo) + else: + for us in self.subfunctions: + us *= other + return self + + @PETSc.Log.EventDecorator() + def __add__(self, other): + new = self.copy() + for i in range(self.local_size): + new.subfunctions[i] += other.subfunctions[i] + return new + + @PETSc.Log.EventDecorator() + def __mul__(self, other): + new = self.copy() + if type(other) is type(self): + for i in range(self.local_size): + self.subfunctions[i].assign(other.subfunctions[i]*self.subfunctions[i]) + else: + for i in range(self.local_size): + self.subfunctions[i].assign(other*self.subfunctions[i]) + return new + + @PETSc.Log.EventDecorator() + def __rmul__(self, other): + return self.__mul__(other) + + @contextmanager + def vec(self): + """ + Context manager for the global PETSc Vec with read/write access. + + It is invalid to access the Vec outside of a context manager. + """ + # _fbuf.vec shares the same storage as _vec, so we need this + # nested context manager so that the data gets copied to/from + # the Function.dat storage and _vec. + # However, this copy is done without _vec knowing, so we have + # to manually increment the state. + with self._fbuf.dat.vec: + self._vec.stateIncrease() + yield self._vec + + @contextmanager + def vec_ro(self): + """ + Context manager for the global PETSc Vec with read only access. + + It is invalid to access the Vec outside of a context manager. + """ + # _fbuf.vec shares the same storage as _vec, so we need this + # nested context manager to make sure that the data gets copied + # to the Function.dat storage and _vec. + with self._fbuf.dat.vec_ro: + self._vec.stateIncrease() + yield self._vec + + @contextmanager + def vec_wo(self): + """ + Context manager for the global PETSc Vec with write only access. + + It is invalid to access the Vec outside of a context manager. + """ + # _fbuf.vec shares the same storage as _vec, so we need this + # nested context manager to make sure that the data gets copied + # from the Function.dat storage and _vec. + with self._fbuf.dat.vec_wo: + yield self._vec + + +class EnsembleFunction(EnsembleFunctionBase): + """ + A mixed finite element Function distributed over an ensemble. + + Parameters + ---------- + + ensemble + The ensemble communicator. The subfunctions are distributed + over the different ensemble members. + + function_spaces + A list of function spaces for each function on the + local ensemble member. + """ + def __init__(self, ensemble, function_spaces): + if not all(is_primal(V) for V in function_spaces): + raise TypeError( + "EnsembleFunction must be created using primal FunctionSpaces") + super().__init__(ensemble, function_spaces) + + +class EnsembleCofunction(EnsembleFunctionBase): + """ + A mixed finite element Cofunction distributed over an ensemble. + + Parameters + ---------- + + ensemble + The ensemble communicator. The subcofunctions are distributed + over the different ensemble members. + + function_spaces + A list of dual function spaces for each cofunction on the + local ensemble member. 
+ """ + def __init__(self, ensemble, function_spaces): + if not all(is_dual(V) for V in function_spaces): + raise TypeError( + "EnsembleCofunction must be created using dual FunctionSpaces") + super().__init__(ensemble, function_spaces)