Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Factor the riesz map out into a separate object. #3662

Open
wants to merge 18 commits into
base: master
Choose a base branch
from
19 changes: 15 additions & 4 deletions firedrake/adjoint_utils/blocks/solving.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,8 +86,15 @@ def _init_solver_parameters(self, args, kwargs):
self.assemble_kwargs = {}

def __str__(self):
return "solve({} = {})".format(ufl2unicode(self.lhs),
ufl2unicode(self.rhs))
try:
lhs_string = ufl2unicode(self.lhs)
except AttributeError:
lhs_string = str(self.lhs)
try:
rhs_string = ufl2unicode(self.rhs)
except AttributeError:
rhs_string = str(self.rhs)
return "solve({} = {})".format(lhs_string, rhs_string)

def _create_F_form(self):
# Process the equation forms, replacing values with checkpoints,
Expand Down Expand Up @@ -756,7 +763,7 @@ def evaluate_adj_component(self, inputs, adj_inputs, block_variable, idx,
c = block_variable.output
c_rep = block_variable.saved_output

if isinstance(c, firedrake.Function):
if isinstance(c, (firedrake.Function, firedrake.Cofunction)):
trial_function = firedrake.TrialFunction(c.function_space())
elif isinstance(c, firedrake.Constant):
mesh = F_form.ufl_domain()
Expand Down Expand Up @@ -793,7 +800,11 @@ def evaluate_adj_component(self, inputs, adj_inputs, block_variable, idx,
replace_map[self.func] = self.get_outputs()[0].saved_output
dFdm = replace(dFdm, replace_map)

dFdm = dFdm * adj_sol
if isinstance(dFdm, firedrake.Argument):
# Corner case. Should be fixed more permanently upstream in UFL.
dFdm = ufl.Action(dFdm, adj_sol)
else:
dFdm = dFdm * adj_sol
dFdm = firedrake.assemble(dFdm, **self.assemble_kwargs)

return dFdm
Expand Down
63 changes: 17 additions & 46 deletions firedrake/adjoint_utils/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,55 +221,18 @@ def _ad_create_checkpoint(self):
return self.copy(deepcopy=True)

def _ad_convert_riesz(self, value, options=None):
from firedrake import Function, Cofunction
from firedrake import Function

options = {} if options is None else options
riesz_representation = options.get("riesz_representation", "L2")
solver_options = options.get("solver_options", {})
V = options.get("function_space", self.function_space())
if value == 0.:
# In adjoint-based differentiation, value == 0. arises only when
# the functional is independent on the control variable.
return Function(V)

if not isinstance(value, (Cofunction, Function)):
raise TypeError("Expected a Cofunction or a Function")

if riesz_representation == "l2":
return Function(V, val=value.dat)

elif riesz_representation in ("L2", "H1"):
if not isinstance(value, Cofunction):
raise TypeError("Expected a Cofunction")

ret = Function(V)
a = self._define_riesz_map_form(riesz_representation, V)
firedrake.solve(a == value, ret, **solver_options)
return ret

elif callable(riesz_representation):
return riesz_representation(value)

else:
raise ValueError(
"Unknown Riesz representation %s" % riesz_representation)

def _define_riesz_map_form(self, riesz_representation, V):
from firedrake import TrialFunction, TestFunction
return Function(self.function_space())

u = TrialFunction(V)
v = TestFunction(V)
if riesz_representation == "L2":
a = firedrake.inner(u, v)*firedrake.dx

elif riesz_representation == "H1":
a = firedrake.inner(u, v)*firedrake.dx \
+ firedrake.inner(firedrake.grad(u), firedrake.grad(v))*firedrake.dx

else:
raise NotImplementedError(
"Unknown Riesz representation %s" % riesz_representation)
return a
return value.riesz_representation(riesz_map=riesz_representation,
solver_options=solver_options)

@no_annotations
def _ad_convert_type(self, value, options=None):
Expand All @@ -294,17 +257,16 @@ def _ad_restore_at_checkpoint(self, checkpoint):
return checkpoint

def _ad_will_add_as_dependency(self):
"""Method called when the object is added as a Block dependency.

"""
"""Method called when the object is added as a Block dependency."""
with checkpoint_init_data():
super()._ad_will_add_as_dependency()

def _ad_mul(self, other):
from firedrake import Function

r = Function(self.function_space())
# `self` can be a Cofunction in which case only left multiplication with a scalar is allowed.
# `self` can be a Cofunction in which case only left multiplication
# with a scalar is allowed.
r.assign(other * self)
return r

Expand All @@ -316,7 +278,10 @@ def _ad_add(self, other):
return r

def _ad_dot(self, other, options=None):
from firedrake import assemble
from firedrake import assemble, action, Cofunction

if isinstance(other, Cofunction):
return assemble(action(other, self))

options = {} if options is None else options
riesz_representation = options.get("riesz_representation", "L2")
Expand Down Expand Up @@ -406,3 +371,9 @@ def _ad_to_petsc(self, vec=None):

def __deepcopy__(self, memodict={}):
return self.copy(deepcopy=True)


class CofunctionMixin(FunctionMixin):

def _ad_dot(self, other):
return firedrake.assemble(firedrake.action(self, other))
Loading
Loading