Skip to content

Commit

Permalink
Merge pull request #2562 from firedrakeproject/connorjward/assign-wei…
Browse files Browse the repository at this point in the history
…ghted-sums-only

Only permit scalars and weighted sums for `assign`
  • Loading branch information
ReubenHill authored Nov 18, 2022
2 parents 146397a + ad6d7c8 commit 2ec31f9
Show file tree
Hide file tree
Showing 12 changed files with 347 additions and 831 deletions.
20 changes: 12 additions & 8 deletions firedrake/adjoint/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,10 +149,11 @@ def wrapper(self, other, *args, **kwargs):
def _ad_annotate_iadd(__iadd__):
@wraps(__iadd__)
def wrapper(self, other, **kwargs):
with stop_annotating():
func = __iadd__(self, other, **kwargs)

ad_block_tag = kwargs.pop("ad_block_tag", None)
annotate = annotate_tape(kwargs)
func = __iadd__(self, other, **kwargs)

if annotate:
block = FunctionAssignBlock(func, self + other, ad_block_tag=ad_block_tag)
tape = get_working_tape()
Expand All @@ -167,10 +168,11 @@ def wrapper(self, other, **kwargs):
def _ad_annotate_isub(__isub__):
@wraps(__isub__)
def wrapper(self, other, **kwargs):
with stop_annotating():
func = __isub__(self, other, **kwargs)

ad_block_tag = kwargs.pop("ad_block_tag", None)
annotate = annotate_tape(kwargs)
func = __isub__(self, other, **kwargs)

if annotate:
block = FunctionAssignBlock(func, self - other, ad_block_tag=ad_block_tag)
tape = get_working_tape()
Expand All @@ -185,10 +187,11 @@ def wrapper(self, other, **kwargs):
def _ad_annotate_imul(__imul__):
@wraps(__imul__)
def wrapper(self, other, **kwargs):
with stop_annotating():
func = __imul__(self, other, **kwargs)

ad_block_tag = kwargs.pop("ad_block_tag", None)
annotate = annotate_tape(kwargs)
func = __imul__(self, other, **kwargs)

if annotate:
block = FunctionAssignBlock(func, self*other, ad_block_tag=ad_block_tag)
tape = get_working_tape()
Expand All @@ -203,10 +206,11 @@ def wrapper(self, other, **kwargs):
def _ad_annotate_idiv(__idiv__):
@wraps(__idiv__)
def wrapper(self, other, **kwargs):
with stop_annotating():
func = __idiv__(self, other, **kwargs)

ad_block_tag = kwargs.pop("ad_block_tag", None)
annotate = annotate_tape(kwargs)
func = __idiv__(self, other, **kwargs)

if annotate:
block = FunctionAssignBlock(func, self/other, ad_block_tag=ad_block_tag)
tape = get_working_tape()
Expand Down
25 changes: 23 additions & 2 deletions firedrake/assemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,11 @@
import finat
import firedrake
import numpy
from pyadjoint.tape import annotate_tape
from tsfc import kernel_args
from tsfc.finatinterface import create_element
import ufl
from firedrake import (assemble_expressions, extrusion_utils as eutils, matrix, parameters, solving,
from firedrake import (extrusion_utils as eutils, matrix, parameters, solving,
tsfc_interface, utils)
from firedrake.adjoint import annotate_assemble
from firedrake.bcs import DirichletBC, EquationBC, EquationBCSplit
Expand Down Expand Up @@ -100,7 +101,7 @@ def assemble(expr, *args, **kwargs):
if isinstance(expr, (ufl.form.Form, slate.TensorBase)):
return _assemble_form(expr, *args, **kwargs)
elif isinstance(expr, ufl.core.expr.Expr):
return assemble_expressions.assemble_expression(expr)
return _assemble_expr(expr)
else:
raise TypeError(f"Unable to assemble: {expr}")

Expand Down Expand Up @@ -290,6 +291,20 @@ def _assemble_form(form, tensor=None, bcs=None, *,
return assembler.assemble()


def _assemble_expr(expr):
"""Assemble a pointwise expression.
:arg expr: The :class:`ufl.core.expr.Expr` to be evaluated.
:returns: A :class:`firedrake.Function` containing the result of this evaluation.
"""
try:
coefficients = ufl.algorithms.extract_coefficients(expr)
V, = set(c.function_space() for c in coefficients) - {None}
except ValueError:
raise ValueError("Cannot deduce correct target space from pointwise expression")
return firedrake.Function(V).assign(expr)


def _check_inputs(form, tensor, bcs, diagonal):
# Ensure mesh is 'initialised' as we could have got here without building a
# function space (e.g. if integrating a constant).
Expand Down Expand Up @@ -392,6 +407,12 @@ def assemble(self):
:returns: The assembled object.
"""
if annotate_tape():
raise NotImplementedError(
"Taping with explicit FormAssembler objects is not supported yet. "
"Use assemble instead."
)

if self._needs_zeroing:
self._as_pyop2_type(self._tensor).zero()

Expand Down
Loading

0 comments on commit 2ec31f9

Please sign in to comment.