diff --git a/firedrake/adjoint/function.py b/firedrake/adjoint/function.py index 131f859960..8f1d2d1477 100644 --- a/firedrake/adjoint/function.py +++ b/firedrake/adjoint/function.py @@ -149,10 +149,11 @@ def wrapper(self, other, *args, **kwargs): def _ad_annotate_iadd(__iadd__): @wraps(__iadd__) def wrapper(self, other, **kwargs): + with stop_annotating(): + func = __iadd__(self, other, **kwargs) + ad_block_tag = kwargs.pop("ad_block_tag", None) annotate = annotate_tape(kwargs) - func = __iadd__(self, other, **kwargs) - if annotate: block = FunctionAssignBlock(func, self + other, ad_block_tag=ad_block_tag) tape = get_working_tape() @@ -167,10 +168,11 @@ def wrapper(self, other, **kwargs): def _ad_annotate_isub(__isub__): @wraps(__isub__) def wrapper(self, other, **kwargs): + with stop_annotating(): + func = __isub__(self, other, **kwargs) + ad_block_tag = kwargs.pop("ad_block_tag", None) annotate = annotate_tape(kwargs) - func = __isub__(self, other, **kwargs) - if annotate: block = FunctionAssignBlock(func, self - other, ad_block_tag=ad_block_tag) tape = get_working_tape() @@ -185,10 +187,11 @@ def wrapper(self, other, **kwargs): def _ad_annotate_imul(__imul__): @wraps(__imul__) def wrapper(self, other, **kwargs): + with stop_annotating(): + func = __imul__(self, other, **kwargs) + ad_block_tag = kwargs.pop("ad_block_tag", None) annotate = annotate_tape(kwargs) - func = __imul__(self, other, **kwargs) - if annotate: block = FunctionAssignBlock(func, self*other, ad_block_tag=ad_block_tag) tape = get_working_tape() @@ -203,10 +206,11 @@ def wrapper(self, other, **kwargs): def _ad_annotate_idiv(__idiv__): @wraps(__idiv__) def wrapper(self, other, **kwargs): + with stop_annotating(): + func = __idiv__(self, other, **kwargs) + ad_block_tag = kwargs.pop("ad_block_tag", None) annotate = annotate_tape(kwargs) - func = __idiv__(self, other, **kwargs) - if annotate: block = FunctionAssignBlock(func, self/other, ad_block_tag=ad_block_tag) tape = get_working_tape() diff --git a/firedrake/assemble.py b/firedrake/assemble.py index 7ddff571ee..024f7cb03c 100644 --- a/firedrake/assemble.py +++ b/firedrake/assemble.py @@ -9,10 +9,11 @@ import finat import firedrake import numpy +from pyadjoint.tape import annotate_tape from tsfc import kernel_args from tsfc.finatinterface import create_element import ufl -from firedrake import (assemble_expressions, extrusion_utils as eutils, matrix, parameters, solving, +from firedrake import (extrusion_utils as eutils, matrix, parameters, solving, tsfc_interface, utils) from firedrake.adjoint import annotate_assemble from firedrake.bcs import DirichletBC, EquationBC, EquationBCSplit @@ -100,7 +101,7 @@ def assemble(expr, *args, **kwargs): if isinstance(expr, (ufl.form.Form, slate.TensorBase)): return _assemble_form(expr, *args, **kwargs) elif isinstance(expr, ufl.core.expr.Expr): - return assemble_expressions.assemble_expression(expr) + return _assemble_expr(expr) else: raise TypeError(f"Unable to assemble: {expr}") @@ -290,6 +291,20 @@ def _assemble_form(form, tensor=None, bcs=None, *, return assembler.assemble() +def _assemble_expr(expr): + """Assemble a pointwise expression. + + :arg expr: The :class:`ufl.core.expr.Expr` to be evaluated. + :returns: A :class:`firedrake.Function` containing the result of this evaluation. + """ + try: + coefficients = ufl.algorithms.extract_coefficients(expr) + V, = set(c.function_space() for c in coefficients) - {None} + except ValueError: + raise ValueError("Cannot deduce correct target space from pointwise expression") + return firedrake.Function(V).assign(expr) + + def _check_inputs(form, tensor, bcs, diagonal): # Ensure mesh is 'initialised' as we could have got here without building a # function space (e.g. if integrating a constant). @@ -392,6 +407,12 @@ def assemble(self): :returns: The assembled object. """ + if annotate_tape(): + raise NotImplementedError( + "Taping with explicit FormAssembler objects is not supported yet. " + "Use assemble instead." + ) + if self._needs_zeroing: self._as_pyop2_type(self._tensor).zero() diff --git a/firedrake/assemble_expressions.py b/firedrake/assemble_expressions.py deleted file mode 100644 index 52b13a3ae7..0000000000 --- a/firedrake/assemble_expressions.py +++ /dev/null @@ -1,544 +0,0 @@ -import itertools -import weakref -from collections import OrderedDict, defaultdict -from functools import singledispatch - -import gem -import loopy -import numpy -import ufl -from gem.impero_utils import compile_gem, preprocess_gem -from gem.node import MemoizerArg -from gem.node import traversal as gem_traversal -from pyop2 import op2 -from pyop2.caching import cached -from pyop2.parloop import GlobalLegacyArg, DatLegacyArg -from tsfc import ufl2gem -from tsfc.loopy import generate -from tsfc.ufl_utils import ufl_reuse_if_untouched -from ufl.algorithms.apply_algebra_lowering import LowerCompoundAlgebra -from ufl.classes import (Coefficient, ComponentTensor, Expr, - Index, Indexed, MultiIndex, Terminal) -from ufl.corealg.map_dag import map_expr_dags -from ufl.corealg.multifunction import MultiFunction -from ufl.corealg.traversal import unique_pre_traversal as ufl_traversal - -import firedrake -from firedrake.petsc import PETSc -from firedrake.utils import ScalarType, cached_property, known_pyop2_safe - - -def extract_coefficients(expr): - return tuple(e for e in ufl_traversal(expr) if isinstance(e, ufl.Coefficient)) - - -class Translator(MultiFunction, ufl2gem.Mixin): - def __init__(self): - self.varmapping = OrderedDict() - MultiFunction.__init__(self) - ufl2gem.Mixin.__init__(self) - - # Override shape-based things - # Need to inspect GEM shape not UFL shape, due to Coefficients changing shape. - def sum(self, o, *ops): - shape, = set(o.shape for o in ops) - indices = gem.indices(len(shape)) - return gem.ComponentTensor(gem.Sum(*[gem.Indexed(op, indices) for op in ops]), - indices) - - def real(self, o, expr): - indices = gem.indices(len(expr.shape)) - return gem.ComponentTensor(gem.MathFunction('real', gem.Indexed(expr, indices)), - indices) - - def imag(self, o, expr): - indices = gem.indices(len(expr.shape)) - return gem.ComponentTensor(gem.MathFunction('imag', gem.Indexed(expr, indices)), - indices) - - def conj(self, o, expr): - indices = gem.indices(len(expr.shape)) - return gem.ComponentTensor(gem.MathFunction('conj', gem.Indexed(expr, indices)), - indices) - - def abs(self, o, expr): - indices = gem.indices(len(expr.shape)) - return gem.ComponentTensor(gem.MathFunction('abs', gem.Indexed(expr, indices)), - indices) - - def conditional(self, o, condition, then, else_): - assert condition.shape == () - shape, = set([then.shape, else_.shape]) - indices = gem.indices(len(shape)) - return gem.ComponentTensor(gem.Conditional(condition, gem.Indexed(then, indices), - gem.Indexed(else_, indices)), - indices) - - def indexed(self, o, aggregate, index): - return gem.Indexed(aggregate, index[:len(aggregate.shape)]) - - def index_sum(self, o, summand, indices): - index, = indices - indices = gem.indices(len(summand.shape)) - return gem.ComponentTensor(gem.IndexSum(gem.Indexed(summand, indices), (index,)), - indices) - - def component_tensor(self, o, expression, index): - index = tuple(i for i in index if i in expression.free_indices) - return gem.ComponentTensor(expression, index) - - def expr(self, o): - raise ValueError(f"Expression of type {type(o)} unsupported in pointwise expressions") - - def coefficient(self, o): - # Because we act on dofs, the ufl_shape is not the right thing to check - shape = o.dat.dim - try: - var = self.varmapping[o] - except KeyError: - name = f"C{len(self.varmapping)}" - var = gem.Variable(name, shape) - self.varmapping[o] = var - if o.ufl_shape == (): - assert shape == (1, ) - return gem.Indexed(var, (0, )) - else: - return var - - -class IndexRelabeller(MultiFunction): - def __init__(self): - super().__init__() - self._reset() - - def _reset(self): - count = itertools.count() - self.index_cache = defaultdict(lambda: Index(next(count))) - - expr = MultiFunction.reuse_if_untouched - - def multi_index(self, o): - return type(o)(tuple(self.index_cache[i] if isinstance(i, Index) else i - for i in o.indices())) - - -def flatten(shape): - if shape == (): - return shape - else: - return (numpy.prod(shape, dtype=int), ) - - -def reshape(expr, shape): - if numpy.prod(expr.ufl_shape, dtype=int) != numpy.prod(shape, dtype=int): - raise ValueError(f"Can't reshape from {expr.ufl_shape} to {shape}") - if shape == expr.ufl_shape: - return expr - if shape == (): - return expr - else: - expr = numpy.asarray([expr[i] for i in numpy.ndindex(expr.ufl_shape)]) - return ufl.as_tensor(expr.reshape(shape)) - - -@singledispatch -def _split(o, self, inct): - raise AssertionError(f"Unhandled expression type {type(o)} in splitting") - - -@_split.register(Expr) -def _split_expr(o, self, inct): - return tuple(ufl_reuse_if_untouched(o, *ops) - for ops in zip(*(self(op, inct) for op in o.ufl_operands))) - - -@_split.register(Coefficient) -def _split_coefficient(o, self, inct): - if isinstance(o, firedrake.Constant): - return tuple(o for _ in range(self.n)) - else: - split = o.split() - assert len(split) == self.n - # Reshaping to handle tensor/vector confusion. - return tuple(reshape(s, flatten(s.ufl_shape)) for s in split) - - -@_split.register(Terminal) -def _split_terminal(o, self, inct): - return tuple(o for _ in range(self.n)) - - -@_split.register(ComponentTensor) -def _split_component_tensor(o, self, inct): - expressions, multiindices = (self(op, True) for op in o.ufl_operands) - result = [] - shape_indices = set(i.count() for i in multiindices[0].indices()) - for expression, multiindex in zip(expressions, multiindices): - if shape_indices <= set(expression.ufl_free_indices): - result.append(ufl_reuse_if_untouched(o, expression, multiindex)) - else: - result.append(expression) - return tuple(result) - - -@_split.register(Indexed) -def _split_indexed(o, self, inct): - aggregate, multiindex = o.ufl_operands - indices = multiindex.indices() - result = [] - for agg in self(aggregate, False): - ncmp = len(agg.ufl_shape) - if ncmp == 0: - result.append(agg) - elif not inct: - idx = indices[:ncmp] - indices = indices[ncmp:] - mi = multiindex if multiindex.indices() == idx else MultiIndex(idx) - result.append(ufl_reuse_if_untouched(o, agg, mi)) - else: - # shape and inct - aggshape = (flatten(agg.ufl_shape) - + tuple(itertools.repeat(1, len(aggregate.ufl_shape) - 1))) - agg = reshape(agg, aggshape) - result.append(ufl_reuse_if_untouched(o, agg, multiindex)) - return tuple(result) - - -class Assign(object): - """Representation of a pointwise assignment expression.""" - relabeller = IndexRelabeller() - symbol = "=" - - __slots__ = ("lvalue", "rvalue", "__dict__", "__weakref__") - - def __init__(self, lvalue, rvalue): - """ - :arg lvalue: The coefficient to assign into. - :arg rvalue: The pointwise expression. - """ - if not isinstance(lvalue, ufl.Coefficient): - raise ValueError("lvalue for pointwise assignment must be a coefficient") - self.lvalue = lvalue - self.rvalue = ufl.as_ufl(rvalue) - n = len(self.lvalue.function_space()) - if n > 1: - self.splitter = MemoizerArg(_split) - self.splitter.n = n - - def __str__(self): - return f"{self.lvalue} {self.symbol} {self.rvalue}" - - def __repr__(self): - return f"{self.__class__.__name__}({self.lvalue!r}, {self.rvalue!r})" - - @cached_property - def coefficients(self): - """Tuple of coefficients involved in the assignment.""" - return (self.lvalue, ) + tuple(c for c in self.rcoefficients if c.dat != self.lvalue.dat) - - @cached_property - def rcoefficients(self): - """Coefficients appearing in the rvalue.""" - return extract_coefficients(self.rvalue) - - @cached_property - def split(self): - """A tuple of assignment expressions, separated by subspace for mixed spaces.""" - V = self.lvalue.function_space() - if len(V) > 1: - # rvalue cases we handle for mixed: - # 1. rvalue is a scalar constant (broadcast to all subspaces) - # 2. rvalue is a function in the same mixed space (actually - # handled by copy special-case in function.assign) - # 3. rvalue is has indexed subspaces and all indices are - # the same (assign to that subspace of the output mixed - # space) - # 4. rvalue is an expression only over mixed spaces and - # the spaces match (split and evaluate subspace-wise). - spaces = tuple(c.function_space() for c in self.rcoefficients) - indices = set(s.index for s in spaces if s is not None) - if len(indices) == 0: - # rvalue is some combination of constants - if self.rvalue.ufl_shape != (): - raise ValueError("Can only broadcast scalar constants to " - "mixed spaces in pointwise assignment") - return tuple(type(self)(s, self.rvalue) for s in self.lvalue.split()) - else: - if indices == set([None]): - if len((set(spaces) | {V}) - {None}) != 1: - # Check that there were no unindexed coefficients - raise ValueError("Saw indexed coefficients in rvalue, " - "perhaps you meant to index the lvalue with .sub(...)") - rvalues = self.splitter(self.rvalue, False) - return tuple(type(self)(lvalue, rvalue) - for lvalue, rvalue in zip(self.lvalue.split(), rvalues)) - elif indices & set([None]): - raise ValueError("Either all or non of the rvalue coefficients must have " - "a .sub(...) index") - try: - index, = indices - except ValueError: - raise ValueError("All rvalue coefficients must have the same .sub(...) index") - return (type(self)(self.lvalue.sub(index), self.rvalue), ) - else: - return (weakref.proxy(self), ) - - @property - @known_pyop2_safe - def args(self): - """Tuple of par_loop arguments for the expression.""" - args = [] - if isinstance(self, AugmentedAssign) or self.lvalue in self.rcoefficients: - args.append(self._as_weakreffed_arg(self.lvalue.dat, op2.RW)) - else: - args.append(self._as_weakreffed_arg(self.lvalue.dat, op2.WRITE)) - for c in self.rcoefficients: - if c.dat == self.lvalue.dat: - continue - args.append(self._as_weakreffed_arg(c.dat, op2.READ)) - return tuple(args) - - @cached_property - def iterset(self): - return weakref.proxy(self.lvalue.node_set) - - @cached_property - def fast_key(self): - """A fast lookup key for this expression.""" - return (type(self), hash(self.lvalue), hash(self.rvalue)) - - @cached_property - def slow_key(self): - """A slow lookup key for this expression (relabelling UFL indices).""" - self.relabeller._reset() - rvalue, = map_expr_dags(self.relabeller, [self.rvalue]) - return (type(self), hash(self.lvalue), hash(rvalue)) - - @cached_property - def par_loop_args(self): - """Arguments for a parallel loop to evaluate this expression. - - If the expression is over a mixed space, this merges kernels - for subspaces with the same node_set (resulting in fewer - par_loop calls). - """ - result = [] - grouping = OrderedDict() - for e in self.split: - grouping.setdefault(e.lvalue.node_set, []).append(e) - for iterset, exprs in grouping.items(): - k, arg_numbers = pointwise_expression_kernel(exprs, ScalarType, PETSc.Log.isActive()) - args = tuple(expr.args[i] - for expr, numbers in zip(exprs, arg_numbers) - for i in numbers) - result.append((k, iterset, args)) - return tuple(result) - - @staticmethod - def _as_weakreffed_arg(dat, access): - if isinstance(dat, op2.Global): - return GlobalLegacyArg(weakref.ref(dat), access) - elif isinstance(dat, (op2.Dat, op2.DatView)): - return DatLegacyArg(weakref.ref(dat), None, access) - else: - raise AssertionError - - -class AugmentedAssign(Assign): - """Base class for augmented pointwise assignment.""" - - -class IAdd(AugmentedAssign): - symbol = "+=" - - -class ISub(AugmentedAssign): - symbol = "-=" - - -class IMul(AugmentedAssign): - symbol = "*=" - - -class IDiv(AugmentedAssign): - symbol = "/=" - - -@PETSc.Log.EventDecorator() -def compile_to_gem(expr, translator): - """Compile a single pointwise expression to GEM. - - :arg expr: The expression to compile. - :arg translator: a :class:`Translator` instance. - :returns: A (lvalue, rvalue) pair of preprocessed GEM.""" - if not isinstance(expr, Assign): - raise ValueError(f"Don't know how to assign expression of type {type(expr)}") - spaces = tuple(c.function_space() for c in expr.coefficients) - if any(type(s.ufl_element()) is ufl.MixedElement for s in spaces if s is not None): - raise ValueError("Not expecting a mixed space at this point, " - "did you forget to index a function with .sub(...)?") - if len(set(s.ufl_element() for s in spaces if s is not None)) != 1: - raise ValueError("All coefficients must be defined on the same space") - lvalue = expr.lvalue - rvalue = expr.rvalue - broadcast = all(isinstance(c, firedrake.Constant) for c in expr.rcoefficients) and rvalue.ufl_shape == () - if not broadcast and lvalue.ufl_shape != rvalue.ufl_shape: - try: - rvalue = reshape(rvalue, lvalue.ufl_shape) - except ValueError: - raise ValueError("Mismatching shapes between lvalue and rvalue in pointwise assignment") - rvalue, = map_expr_dags(LowerCompoundAlgebra(), [rvalue]) - try: - lvalue, rvalue = map_expr_dags(translator, [lvalue, rvalue]) - except (AssertionError, ValueError): - raise ValueError("Mismatching shapes in pointwise assignment. " - "For intrinsically vector-/tensor-valued spaces make " - "sure you're not using shaped Constants or literals.") - - indices = gem.indices(len(lvalue.shape)) - if not broadcast: - if rvalue.shape != lvalue.shape: - raise ValueError("Mismatching shapes in pointwise assignment. " - "For intrinsically vector-/tensor-valued spaces make " - "sure you're not using shaped Constants or literals.") - rvalue = gem.Indexed(rvalue, indices) - lvalue = gem.Indexed(lvalue, indices) - if isinstance(expr, IAdd): - rvalue = gem.Sum(lvalue, rvalue) - elif isinstance(expr, ISub): - rvalue = gem.Sum(lvalue, gem.Product(gem.Literal(-1), rvalue)) - elif isinstance(expr, IMul): - rvalue = gem.Product(lvalue, rvalue) - elif isinstance(expr, IDiv): - rvalue = gem.Division(lvalue, rvalue) - return preprocess_gem([lvalue, rvalue]) - - -_pointwise_expression_cache = {} -"""In-memory cache for pointwise expression kernels.""" - - -def _pointwise_expression_key(exprs, scalar_type, is_logging): - """Return a cache key for use with :func:`pointwise_expression_kernel`.""" - from firedrake.interpolation import hash_expr - return (tuple((e.__class__, hash(e.lvalue), hash_expr(e.rvalue)) for e in exprs) - + (scalar_type, is_logging)) - - -@PETSc.Log.EventDecorator() -@cached(_pointwise_expression_cache, key=_pointwise_expression_key) -def pointwise_expression_kernel(exprs, scalar_type, is_logging): - """Compile a kernel for pointwise expressions. - - :arg exprs: List of expressions, all on the same iteration set. - :arg scalar_type: Default scalar type (numpy.dtype). - :arg is_logging: ``True`` if the kernel is to be annotated with PETSc events. - :returns: A 2-tuple where the first entry is a PyOP2 kernel for - evaluating the expressions and the second is a list of lists containing - the indices of the parloop arguments that are needed from ``exprs``. - """ - if len(set(e.lvalue.node_set for e in exprs)) > 1: - raise ValueError("All expressions must have same node layout.") - translator = Translator() - assignments = tuple(compile_to_gem(expr, translator) for expr in exprs) - prefix_ordering = tuple(OrderedDict.fromkeys(itertools.chain.from_iterable( - node.index_ordering() - for node in gem_traversal([v for v, _ in assignments]) - if isinstance(node, gem.Indexed)))) - impero_c = compile_gem(assignments, prefix_ordering=prefix_ordering, - remove_zeros=False, emit_return_accumulate=False) - coefficients = translator.varmapping - loopy_args = [] - parloop_arg_numbers = [] - for expr in exprs: - parloop_arg_numbers.append([]) - for i, (c, arg) in enumerate(zip(expr.coefficients, expr.args)): - try: - var = coefficients.pop(c) - except KeyError: - continue - is_input = arg.access in [op2.INC, op2.MAX, op2.MIN, op2.READ, op2.RW] - is_output = arg.access in [op2.INC, op2.MAX, op2.MIN, op2.RW, op2.WRITE] - loopy_args.append(loopy.GlobalArg(var.name, shape=var.shape, dtype=c.dat.dtype, - is_input=is_input, is_output=is_output)) - parloop_arg_numbers[-1].append(i) - assert len(coefficients) == 0 - name = "expression_kernel" - knl, event = generate(impero_c, loopy_args, scalar_type, kernel_name=name, - return_increments=False, log=is_logging) - return firedrake.op2.Kernel(knl, name, events=(event,)), parloop_arg_numbers - - -class dereffed: - def __init__(self, args): - self.args = args - - def __enter__(self): - for a in self.args: - data = a.data() - if data is None: - raise ReferenceError - a.data = a.data() - return self.args - - def __exit__(self, *args, **kwargs): - for a in self.args: - a.data = weakref.ref(a.data) - - -@PETSc.Log.EventDecorator() -@known_pyop2_safe -def evaluate_expression(expr, subset=None): - """Evaluate a pointwise expression. - - :arg expr: The expression to evaluate. - :arg subset: An optional subset to apply the expression on. - :returns: The lvalue in the provided expression.""" - lvalue = expr.lvalue - cache = lvalue._expression_cache - if cache is not None: - fast_key = expr.fast_key - try: - arguments = cache[fast_key] - except KeyError: - slow_key = expr.slow_key - try: - arguments = cache[slow_key] - cache[fast_key] = arguments - except KeyError: - arguments = None - if arguments is not None: - try: - for kernel, iterset, args in arguments: - with dereffed(args) as args: - firedrake.op2.par_loop(kernel, subset or iterset, *args) - return lvalue - except ReferenceError: - # TODO: Is there a situation where some of the kernels - # succeed and others don't? - pass - arguments = expr.par_loop_args - if cache is not None: - cache[slow_key] = arguments - cache[fast_key] = arguments - for kernel, iterset, args in arguments: - with dereffed(args) as args: - firedrake.op2.par_loop(kernel, subset or iterset, *args) - return lvalue - - -@PETSc.Log.EventDecorator() -def assemble_expression(expr, subset=None): - """Evaluate a UFL expression pointwise and assign it to a new - :class:`~.Function`. - - :arg expr: The UFL expression. - :arg subset: Optional subset to apply the expression on. - :returns: A new function.""" - try: - coefficients = extract_coefficients(expr) - V, = set(c.function_space() for c in coefficients) - {None} - except ValueError: - raise ValueError("Cannot deduce correct target space from pointwise expression") - result = firedrake.Function(V) - return evaluate_expression(Assign(result, expr), subset) diff --git a/firedrake/assign.py b/firedrake/assign.py new file mode 100644 index 0000000000..4642de8f3a --- /dev/null +++ b/firedrake/assign.py @@ -0,0 +1,257 @@ +import functools +import operator + +import numpy as np +from pyadjoint.tape import annotate_tape +from pyop2.utils import cached_property +import pytools +from ufl.algorithms import extract_coefficients +from ufl.constantvalue import as_ufl +from ufl.corealg.map_dag import map_expr_dag +from ufl.corealg.multifunction import MultiFunction + +from firedrake.constant import Constant +from firedrake.function import Function +from firedrake.petsc import PETSc +from firedrake.utils import ScalarType, split_by +from firedrake.vector import Vector + + +class CoefficientCollector(MultiFunction): + """Multifunction used for converting an expression into a weighted sum of coefficients. + + Calling ``map_expr_dag(CoefficientCollector(), expr)`` will return a tuple whose entries + are of the form ``(coefficient, weight)``. Expressions that cannot be expressed as a + weighted sum will raise an exception. + + Note: As well as being simple weighted sums (e.g. ``u.assign(2*v1 + 3*v2)``), one can + also assign constant expressions of the appropriate shape (e.g. ``u.assign(1.0)`` or + ``u.assign(2*v + 3)``). Therefore the returned tuple must be split since ``coefficient`` + may be either a :class:`firedrake.constant.Constant` or :class:`firedrake.function.Function`. + """ + + def product(self, o, a, b): + scalars, vectors = split_by(self._is_scalar_equiv, [a, b]) + # Case 1: scalar * scalar + if len(scalars) == 2: + # Compress the first argument (arbitrary) + scalar, vector = scalars + # Case 2: scalar * vector + elif len(scalars) == 1: + scalar, = scalars + vector, = vectors + # Case 3: vector * vector (invalid) + else: + raise ValueError("Expressions containing the product of two vector-valued " + "subexpressions cannot be used for assignment. Consider using " + "interpolate instead.") + scaling = self._as_scalar(scalar) + return tuple((coeff, weight*scaling) for coeff, weight in vector) + + def division(self, o, a, b): + # Division is only valid if b (the divisor) is a scalar + if self._is_scalar_equiv(b): + divisor = self._as_scalar(b) + return tuple((coeff, weight/divisor) for coeff, weight in a) + else: + raise ValueError("Expressions involving division by a vector-valued subexpression " + "cannot be used for assignment. Consider using interpolate instead.") + + def sum(self, o, a, b): + # Note: a and b are tuples of (coefficient, weight) so addition is concatenation + return a + b + + def power(self, o, a, b): + # Only valid if a and b are scalars + return ((Constant(self._as_scalar(a) ** self._as_scalar(b)), 1),) + + def abs(self, o, a): + # Only valid if a is a scalar + return ((Constant(abs(self._as_scalar(a))), 1),) + + def _scalar(self, o): + return ((Constant(o), 1),) + + int_value = _scalar + float_value = _scalar + zero = _scalar + + def multi_index(self, o): + pass + + def indexed(self, o, a, _): + return a + + def component_tensor(self, o, a, _): + return a + + def coefficient(self, o): + return ((o, 1),) + + def expr(self, o, *operands): + raise NotImplementedError(f"Handler not defined for {type(o)}") + + def _is_scalar_equiv(self, weighted_coefficients): + """Return ``True`` if the sequence of ``(coefficient, weight)`` can be compressed to + a single scalar value. + + This is only true when all coefficients are :class:`firedrake.Constant` and have + shape ``(1,)``. + """ + return all(isinstance(c, Constant) and c.dat.dim == (1,) + for (c, _) in weighted_coefficients) + + def _as_scalar(self, weighted_coefficients): + """Compress a sequence of ``(coefficient, weight)`` tuples to a single scalar value. + + This is necessary because we do not know a priori whether a :class:`firedrake.Constant` + is going to be used as a scale factor (e.g. ``u.assign(Constant(2)*v)``), or as a + constant to be added (e.g. ``u.assign(2*v + Constant(3))``). Therefore we only + compress to a scalar when we know it is required (e.g. inside a product with a + :class:`firedrake.Function`). + """ + return pytools.one( + functools.reduce(operator.add, (c.dat.data_ro*w for c, w in weighted_coefficients)) + ) + + +class Assigner: + """Class performing pointwise assignment of an expression to a :class:`firedrake.Function`. + + :param assignee: The :class:`firedrake.Function` being assigned to. + :param expression: The :class:`ufl.Expr` to evaluate. + :param subset: Optional subset (:class:`op2.Subset`) to apply the assignment over. + """ + symbol = "=" + + _coefficient_collector = CoefficientCollector() + + def __init__(self, assignee, expression, subset=None): + if isinstance(expression, Vector): + expression = expression.function + expression = as_ufl(expression) + + if not all(c.function_space() == assignee.function_space() + for c in extract_coefficients(expression) + if isinstance(c, Function)): + raise ValueError("All functions in the expression must be in the same " + "function space as the assignee") + + self._assignee = assignee + self._expression = expression + self._subset = subset + + def __str__(self): + return f"{self._assignee} {self.symbol} {self._expression}" + + def __repr__(self): + return f"{self.__class__.__name__}({self._assignee!r}, {self._expression!r})" + + @PETSc.Log.EventDecorator() + def assign(self): + """Perform the assignment.""" + if annotate_tape(): + raise NotImplementedError( + "Taping with explicit Assigner objects is not supported yet. " + "Use Function.assign instead." + ) + + if self._is_real_space: + self._assign_global(self._assignee.dat, [f.dat for f in self._functions]) + else: + # If mixed, loop over individual components + for assignee_dat, *func_dats in zip(self._assignee.dat.split, + *(f.dat.split for f in self._functions)): + self._assign_single_dat(assignee_dat, func_dats) + # Halo values are also updated + assignee_dat.halo_valid = True + + @cached_property + def _constants(self): + return tuple(c for (c, _) in self._weighted_coefficients + if isinstance(c, Constant)) + + @cached_property + def _constant_weights(self): + return tuple(w for (c, w) in self._weighted_coefficients + if isinstance(c, Constant)) + + @cached_property + def _functions(self): + return tuple(c for (c, _) in self._weighted_coefficients + if isinstance(c, Function)) + + @cached_property + def _function_weights(self): + return tuple(w for (c, w) in self._weighted_coefficients + if isinstance(c, Function)) + + @property + def _indices(self): + return self._subset.indices if self._subset else ... + + @property + def _is_real_space(self): + return self._assignee.function_space().ufl_element().family() == "Real" + + def _assign_global(self, assignee_global, function_globals): + assignee_global.data[self._indices] = self._compute_rvalue(function_globals) + + # TODO: It would be more efficient in permissible cases to use VecMAXPY instead of numpy operations. + def _assign_single_dat(self, assignee_dat, function_dats): + assignee_dat.data_with_halos[self._indices] = self._compute_rvalue(function_dats) + + def _compute_rvalue(self, function_dats=()): + # There are two components to the rvalue: weighted functions (in the same function space), + # and constants (e.g. u.assign(2*v + 3)). + if self._is_real_space: + func_data = np.array([f.data_ro[self._indices] for f in function_dats]) + else: + func_data = np.array([f.data_ro_with_halos[self._indices] for f in function_dats]) + func_rvalue = (func_data.T @ self._function_weights).T + const_data = np.array([c.dat.data_ro for c in self._constants], dtype=ScalarType) + const_rvalue = const_data.T @ self._constant_weights + return func_rvalue + const_rvalue + + @cached_property + def _weighted_coefficients(self): + # TODO: It would be nice to stash this on the expression so we can avoid extra + # traversals for non-persistent Assigner objects, but expressions do not currently + # have caches attached to them. + return map_expr_dag(self._coefficient_collector, self._expression) + + +class IAddAssigner(Assigner): + """Assigner class for :func:`firedrake.Function.__iadd__`.""" + symbol = "+=" + + def _assign_single_dat(self, assignee_dat, function_dats): + assignee_dat.data_with_halos[self._indices] += self._compute_rvalue(function_dats) + + +class ISubAssigner(Assigner): + """Assigner class for :func:`firedrake.Function.__isub__`.""" + symbol = "-=" + + def _assign_single_dat(self, assignee_dat, function_dats): + assignee_dat.data_with_halos[self._indices] -= self._compute_rvalue(function_dats) + + +class IMulAssigner(Assigner): + """Assigner class for :func:`firedrake.Function.__imul__`.""" + symbol = "*=" + + def _assign_single_dat(self, assignee_dat, function_dats): + if function_dats: + raise ValueError("Only multiplication by scalars is supported") + assignee_dat.data_with_halos[self._indices] *= self._compute_rvalue() + + +class IDivAssigner(Assigner): + """Assigner class for :func:`firedrake.Function.__itruediv__`.""" + symbol = "/=" + + def _assign_single_dat(self, assignee_dat, function_dats): + if function_dats: + raise ValueError("Only division by scalars is supported") + assignee_dat.data_with_halos[self._indices] /= self._compute_rvalue() diff --git a/firedrake/constant.py b/firedrake/constant.py index 9737e3d34b..9f767721d9 100644 --- a/firedrake/constant.py +++ b/firedrake/constant.py @@ -53,7 +53,7 @@ class Constant(ufl.Coefficient, ConstantMixin): def __new__(cls, *args, **kwargs): # Hack to avoid hitting `ufl.Coefficient.__new__` which may perform operations # meant for coefficients and not constants (e.g. check if the function space is dual or not) - # This is a consequence of firedrake.Constant inheriting from ufl.Constant instead of ufl.Coefficient. + # This is a consequence of firedrake.Constant inheriting from ufl.Coefficient instead of ufl.Constant. return object.__new__(cls) @ConstantMixin._ad_annotate_init diff --git a/firedrake/function.py b/firedrake/function.py index 46963b72b2..4e5f950fda 100644 --- a/firedrake/function.py +++ b/firedrake/function.py @@ -2,26 +2,20 @@ import sys import ufl from ufl.formatting.ufl2unicode import ufl2unicode +import cachetools import ctypes from collections import OrderedDict from ctypes import POINTER, c_int, c_double, c_void_p -import numbers from pyop2 import op2 from firedrake.utils import ScalarType, IntType, as_ctypes from firedrake import functionspaceimpl -from firedrake.logging import warning from firedrake import utils from firedrake import vector from firedrake.adjoint import FunctionMixin from firedrake.petsc import PETSc -try: - import cachetools -except ImportError: - warning("cachetools not available, expression assembly will be slowed down") - cachetools = None __all__ = ['Function', 'PointNotInDomainError'] @@ -268,11 +262,8 @@ def __init__(self, function_space, val=None, name=None, dtype=ScalarType, self, self.function_space().ufl_function_space(), count=count ) - if cachetools: - # LRU cache for expressions assembled onto this function - self._expression_cache = cachetools.LRUCache(maxsize=50) - else: - self._expression_cache = None + # LRU cache for expressions assembled onto this function + self._expression_cache = cachetools.LRUCache(maxsize=50) if isinstance(function_space, Function): self.assign(function_space) @@ -375,7 +366,6 @@ def interpolate(self, expression, subset=None, ad_block_tag=None): @PETSc.Log.EventDecorator() @FunctionMixin._ad_annotate_assign - @utils.known_pyop2_safe def assign(self, expr, subset=None): r"""Set the :class:`Function` value to the pointwise value of expr. expr may only contain :class:`Function`\s on the same @@ -392,109 +382,42 @@ def assign(self, expr, subset=None): If present, subset must be an :class:`pyop2.Subset` of this :class:`Function`'s ``node_set``. The expression will then only be assigned to the nodes on that subset. + + .. note:: + + Assignment can only be performed for simple weighted sum expressions and constant + values. Things like ``u.assign(2*v + Constant(3.0))``. For more complicated + expressions (e.g. involving the product of functions) :meth:`.Function.interpolate` + should be used. """ - # Avoid generating code when assigning scalar values to the Real space - if (isinstance(expr, numbers.Number) - and self.function_space().ufl_element().family() == "Real"): - self.dat.data[...] = expr - return self - - expr = ufl.as_ufl(expr) - if isinstance(expr, ufl.classes.Zero): - self.dat.zero(subset=subset) - return self - elif (isinstance(expr, Function) - and expr.function_space() == self.function_space()): - expr.dat.copy(self.dat, subset=subset) - return self - - from firedrake import assemble_expressions - assemble_expressions.evaluate_expression( - assemble_expressions.Assign(self, expr), subset) + from firedrake.assign import Assigner + Assigner(self, expr, subset).assign() return self @FunctionMixin._ad_annotate_iadd - @utils.known_pyop2_safe def __iadd__(self, expr): - - if np.isscalar(expr): - self.dat += expr - return self - if isinstance(expr, vector.Vector): - expr = expr.function - if isinstance(expr, Function) and \ - expr.function_space() == self.function_space(): - self.dat += expr.dat - return self - - from firedrake import assemble_expressions - assemble_expressions.evaluate_expression( - assemble_expressions.IAdd(self, expr)) - + from firedrake.assign import IAddAssigner + IAddAssigner(self, expr).assign() return self @FunctionMixin._ad_annotate_isub - @utils.known_pyop2_safe def __isub__(self, expr): - - if np.isscalar(expr): - self.dat -= expr - return self - if isinstance(expr, vector.Vector): - expr = expr.function - if isinstance(expr, Function) and \ - expr.function_space() == self.function_space(): - self.dat -= expr.dat - return self - - from firedrake import assemble_expressions - assemble_expressions.evaluate_expression( - assemble_expressions.ISub(self, expr)) - + from firedrake.assign import ISubAssigner + ISubAssigner(self, expr).assign() return self @FunctionMixin._ad_annotate_imul - @utils.known_pyop2_safe def __imul__(self, expr): - - if np.isscalar(expr): - self.dat *= expr - return self - if isinstance(expr, vector.Vector): - expr = expr.function - if isinstance(expr, Function) and \ - expr.function_space() == self.function_space(): - self.dat *= expr.dat - return self - - from firedrake import assemble_expressions - assemble_expressions.evaluate_expression( - assemble_expressions.IMul(self, expr)) - + from firedrake.assign import IMulAssigner + IMulAssigner(self, expr).assign() return self @FunctionMixin._ad_annotate_idiv - @utils.known_pyop2_safe - def __idiv__(self, expr): - - if np.isscalar(expr): - self.dat /= expr - return self - if isinstance(expr, vector.Vector): - expr = expr.function - if isinstance(expr, Function) and \ - expr.function_space() == self.function_space(): - self.dat /= expr.dat - return self - - from firedrake import assemble_expressions - assemble_expressions.evaluate_expression( - assemble_expressions.IDiv(self, expr)) - + def __itruediv__(self, expr): + from firedrake.assign import IDivAssigner + IDivAssigner(self, expr).assign() return self - __itruediv__ = __idiv__ - def __float__(self): if ( diff --git a/firedrake/utils.py b/firedrake/utils.py index 0bc0e522e5..b7bee935e6 100644 --- a/firedrake/utils.py +++ b/firedrake/utils.py @@ -95,3 +95,22 @@ def tuplify(item): if not isinstance(item, dict): raise ValueError(f"tuplify does not know how to handle objects of type {type(item)}") return tuple((k, tuplify(item[k])) for k in sorted(item)) + + +def split_by(condition, items): + """Split an iterable in two according to some condition. + + :arg condition: Callable applied to each item in ``items``, returning ``True`` + or ``False``. + :arg items: Iterable to split apart. + :returns: A 2-tuple of the form ``(yess, nos)``, where ``yess`` is a tuple containing + the entries of ``items`` where ``condition`` is ``True`` and ``nos`` is a tuple + of those where ``condition`` is ``False``. + """ + result = [], [] + for item in items: + if condition(item): + result[0].append(item) + else: + result[1].append(item) + return tuple(result[0]), tuple(result[1]) diff --git a/tests/multigrid/test_p_multigrid.py b/tests/multigrid/test_p_multigrid.py index b799de52da..17e4e33296 100644 --- a/tests/multigrid/test_p_multigrid.py +++ b/tests/multigrid/test_p_multigrid.py @@ -50,7 +50,7 @@ def test_prolongation_matrix_matfree(): for u in us: for v in us: if u != v: - v.assign(zero(v.ufl_shape)) + v.assign(0) P = prolongation_matrix_matfree(v, u).getPythonContext() P._prolong() assert norm(v-expr, "L2") < tol diff --git a/tests/regression/test_adjoint_operators.py b/tests/regression/test_adjoint_operators.py index 20ebd50908..c63e19d4e8 100644 --- a/tests/regression/test_adjoint_operators.py +++ b/tests/regression/test_adjoint_operators.py @@ -25,7 +25,7 @@ def handle_annotation(): pause_annotation() -@pytest.fixture(params=['iadd', 'isub', 'imul', 'idiv']) +@pytest.fixture(params=['iadd', 'isub']) def op(request): return request.param @@ -469,10 +469,6 @@ def test_ioperator_replay(op, order, power): t += s elif op == 'isub': t -= s - elif op == 'imul': - t *= s - elif op == 'idiv': - t /= s else: raise ValueError("Operator '{:s}' not recognised".format(op)) @@ -497,12 +493,6 @@ def test_ioperator_replay(op, order, power): elif op == 'isub': ss -= ss tt -= tt - elif op == 'imul': - ss *= ss - tt *= tt - elif op == 'idiv': - ss /= ss - tt /= tt assert np.isclose(rf_s(t_orig), assemble(f(tt)*dx)) assert np.isclose(rf_t(s_orig), assemble(f(ss)*dx)) diff --git a/tests/regression/test_expressions.py b/tests/regression/test_expressions.py index d4e20bf9aa..f6a7900ed2 100644 --- a/tests/regression/test_expressions.py +++ b/tests/regression/test_expressions.py @@ -1,7 +1,6 @@ from operator import iadd, isub, imul, itruediv from functools import partial from itertools import permutations -from firedrake.assemble_expressions import Assign, evaluate_expression import pytest @@ -60,8 +59,7 @@ def func_factory(fs): f = Function(fs, name="f") one = Function(fs, name="one").assign(1) two = Function(fs, name="two").assign(2) - minusthree = Function(fs, name="minusthree").assign(-3) - return f, one, two, minusthree + return f, one, two @pytest.fixture() @@ -147,67 +145,33 @@ def ioptest(f, expr, x, op): 'iaddtest(f, 2, 2)', 'isubtest(two, 1, 1)', 'imultest(one, 2, 2)', - 'imultest(one, two, 2)', 'itruedivtest(two, 2, 1)', - 'itruedivtest(one, two, 0.5)', 'isubtest(one, one, 0)', 'assigntest(f, 2 * one, 2)', 'assigntest(f, one - one, 0)'] -scalar_tests = common_tests + [ - 'assigntest(f, sqrt(one), 1)', - 'exprtest(ufl.ln(one), 0)', - 'exprtest(two ** minusthree, 0.125)', - 'exprtest(ufl.sign(real(minusthree)), -1)', - 'exprtest(one + two / two ** minusthree, 17)'] - -mixed_tests = common_tests + [ - 'assigntest(f, one.sub(0), (1, 0))', - 'assigntest(f, one.sub(1), (0, 1))', - 'assigntest(two, one.sub(0), (1, 2))', - 'assigntest(two, one.sub(1), (2, 1))', - 'assigntest(two, one.sub(0) + two.sub(0), (3, 2))', - 'assigntest(two, two.sub(1) - one.sub(1), (2, 1))', - 'iaddtest(one, one.sub(0), (2, 1))', - 'iaddtest(one, one.sub(1), (1, 2))', -] - -indexed_fs_tests = [ - 'assigntest(f, one, (1, 0))', - 'assigntest(f, two, (0, 2))', - 'iaddtest(f, one, (1, 0))', - 'iaddtest(f, two, (0, 2))', - 'isubtest(f, one, (-1, 0))', - 'isubtest(f, two, (0, -2))'] - - -@pytest.mark.parametrize('expr', scalar_tests) + +@pytest.mark.parametrize('expr', common_tests) def test_scalar_expressions(expr, functions): - f, one, two, minusthree = functions + f, one, two = functions assert eval(expr) @pytest.mark.parametrize('expr', common_tests) def test_vector_expressions(expr, vfunctions): - f, one, two, minusthree = vfunctions + f, one, two = vfunctions assert eval(expr) @pytest.mark.parametrize('expr', common_tests) def test_tensor_expressions(expr, tfunctions): - f, one, two, minusthree = tfunctions + f, one, two = tfunctions assert eval(expr) -@pytest.mark.parametrize('expr', mixed_tests) +@pytest.mark.parametrize('expr', common_tests) def test_mixed_expressions(expr, mfunctions): - f, one, two, minusthree = mfunctions - assert eval(expr) - - -@pytest.mark.parametrize('expr', indexed_fs_tests) -def test_mixed_expressions_indexed_fs(expr, msfunctions): - f, one, two = msfunctions + f, one, two = mfunctions assert eval(expr) @@ -261,26 +225,6 @@ def test_asign_to_nonindexed_subspace_fails(mfs): Function(mfs).assign(Function(f)) -def test_assign_mixed_no_nan(mfs): - w = Function(mfs) - vs = w.split() - vs[0].assign(2) - w /= vs[0] - assert np.allclose(vs[0].dat.data_ro, 1.0) - for v in vs[1:]: - assert not np.isnan(v.dat.data_ro).any() - - -def test_assign_mixed_no_zero(mfs): - w = Function(mfs) - vs = w.split() - w.assign(2) - w *= vs[0] - assert np.allclose(vs[0].dat.data_ro, 4.0) - for v in vs[1:]: - assert np.allclose(v.dat.data_ro, 2.0) - - def test_assign_vector_const_to_vfs(vcg1): f = Function(vcg1) @@ -337,9 +281,6 @@ def test_assign_to_mfs_sub(cg1, vcg1): with pytest.raises(ValueError): w.sub(0).assign(v) - w.sub(0).assign(ufl.ln(q.sub(1))) - assert np.allclose(w.sub(0).dat.data_ro, ufl.ln(11)) - with pytest.raises(ValueError): w.assign(q.sub(1)) @@ -392,7 +333,7 @@ def test_assign_from_mfs_sub(cg1, vcg1): @pytest.mark.parametrize('value', [10, -10], ids=lambda v: "(f = %d)" % v) -@pytest.mark.parametrize('expr', ['f', '2*f', 'tanh(f)']) +@pytest.mark.parametrize('expr', ['f', '2*f']) def test_math_functions(expr, value): mesh = UnitSquareMesh(2, 2) V = FunctionSpace(mesh, 'CG', 1) @@ -407,26 +348,6 @@ def test_math_functions(expr, value): assert np.allclose(actual.dat.data_ro, expect) -@pytest.mark.parametrize('fn', [min_value, max_value]) -def test_minmax(fn): - mesh = UnitTriangleMesh() - V = FunctionSpace(mesh, "DG", 0) - - f = interpolate(as_ufl(1), V) - g = interpolate(as_ufl(2), V) - - h = Function(V) - - h.assign(fn(real(f), real(g))) - - if fn == min_value: - expect = 1 - else: - expect = 2 - - assert np.allclose(h.dat.data_ro, expect) - - def test_assign_mixed_multiple_shaped(): mesh = UnitTriangleMesh() V = VectorFunctionSpace(mesh, "DG", 0) @@ -454,71 +375,6 @@ def test_assign_mixed_multiple_shaped(): assert np.allclose(q.dat.data_ro, p1.dat.data_ro - p2.dat.data_ro) -def test_expression_cache(): - mesh = UnitSquareMesh(1, 1) - V = VectorFunctionSpace(mesh, "CG", 1) - W = TensorFunctionSpace(mesh, "CG", 1) - u = Function(V) - v = Function(V) - w = Function(W) - - i, j = indices(2) - exprA = Assign(u, as_vector(2*u[i], i)) - exprB = Assign(u, as_vector(2*u[j], j)) - - assert len(u._expression_cache) == 0 - - evaluate_expression(exprA) - - assert exprA.fast_key in u._expression_cache - assert exprA.slow_key in u._expression_cache - assert exprB.fast_key not in u._expression_cache - assert exprB.slow_key in u._expression_cache - - evaluate_expression(exprB) - assert exprB.fast_key in u._expression_cache - assert exprA.fast_key in u._expression_cache - - assert exprB.slow_key == exprA.slow_key - - assert len(u._expression_cache) == 3 - - u.assign(as_vector([1, 2])) - u.assign(as_vector(2*u[i], i)) - v.assign(as_vector(2*u[j], j)) - w.assign(as_tensor([[1, 2], [0, 3]])) - w.assign(as_tensor(w[i, j]+w[j, i], (i, j))) - - u -= as_vector([2, 4]) - assert u.dat.norm < 1e-15 - v -= as_vector([4, 8]) - assert v.dat.norm < 1e-15 - w -= as_tensor([[2, 2], [2, 6]]) - assert w.dat.norm < 1e-15 - - assert len(u._expression_cache) == 5 - - -def test_global_expression_cache(): - from firedrake.assemble_expressions import _pointwise_expression_cache - - mesh = UnitSquareMesh(1, 1) - V = VectorFunctionSpace(mesh, "CG", 1) - u = Function(V) - - _pointwise_expression_cache.clear() - assert len(_pointwise_expression_cache) == 0 - - u.assign(Constant(1)) - assert len(_pointwise_expression_cache) == 1 - - u.assign(Constant(2)) - assert len(_pointwise_expression_cache) == 1 - - u.assign(1) - assert len(_pointwise_expression_cache) == 2 - - def test_augmented_assignment_broadcast(): mesh = UnitSquareMesh(1, 1) V = FunctionSpace(mesh, "BDM", 1) diff --git a/tests/regression/test_real_space.py b/tests/regression/test_real_space.py index 3f8f87f700..8440211c77 100644 --- a/tests/regression/test_real_space.py +++ b/tests/regression/test_real_space.py @@ -258,11 +258,11 @@ def test_real_space_assign(): f = Function(V) f.assign(2) g = Function(V) - g.assign(2*f + f**3) + g.assign(2*f) h = Function(V) h.assign(0.0) assert np.allclose(float(f), 2.0) - assert np.allclose(float(g), 12.0) + assert np.allclose(float(g), 4.0) assert np.allclose(float(h), 0.0) diff --git a/tests/vertexonly/test_poisson_inverse_conductivity.py b/tests/vertexonly/test_poisson_inverse_conductivity.py index edee02744e..9073e7355e 100644 --- a/tests/vertexonly/test_poisson_inverse_conductivity.py +++ b/tests/vertexonly/test_poisson_inverse_conductivity.py @@ -1,7 +1,7 @@ import pytest import numpy as np from firedrake import * -from pyadjoint.tape import get_working_tape, pause_annotation, continue_annotation, set_working_tape +from pyadjoint.tape import get_working_tape, pause_annotation @pytest.fixture(autouse=True) @@ -29,12 +29,6 @@ def test_poisson_inverse_conductivity(): # Have to import inside test to make sure cleanup fixtures work as intended from firedrake_adjoint import Control, ReducedFunctional, minimize - # Manually set up annotation since test suite may have stopped it - tape = get_working_tape() - tape.clear_tape() - set_working_tape(tape) - continue_annotation() - # Use pyadjoint to estimate an unknown conductivity in a # poisson-like forward model from point measurements m = UnitSquareMesh(2, 2) @@ -98,7 +92,3 @@ def test_poisson_inverse_conductivity(): # Estimate q using Newton-CG which evaluates the hessian action minimize(Ĵ, method='Newton-CG', options={'disp': True}) - - # Make sure annotation is stopped - tape.clear_tape() - pause_annotation()