Skip to content

Commit

Permalink
Merge pull request #250 from firedrakeproject/ReubenHill/finat-dual-eval
Browse files Browse the repository at this point in the history
FInAT Dual Evaluation
  • Loading branch information
wence- authored Aug 26, 2021
2 parents 9d3b173 + a32ddc0 commit 3fe9dcf
Show file tree
Hide file tree
Showing 4 changed files with 163 additions and 132 deletions.
6 changes: 3 additions & 3 deletions gem/gem.py
Original file line number Diff line number Diff line change
Expand Up @@ -866,9 +866,9 @@ def index_sum(expression, indices):


def partial_indexed(tensor, indices):
"""Generalised indexing into a tensor. The number of indices may
be less than or equal to the rank of the tensor, so the result may
have a non-empty shape.
"""Generalised indexing into a tensor by eating shape off the front.
The number of indices may be less than or equal to the rank of the tensor,
so the result may have a non-empty shape.
:arg tensor: tensor-valued GEM expression
:arg indices: indices, at most as many as the rank of the tensor
Expand Down
17 changes: 15 additions & 2 deletions gem/optimise.py
Original file line number Diff line number Diff line change
Expand Up @@ -481,13 +481,18 @@ def traverse_sum(expression, stop_at=None):
return result


def contraction(expression):
def contraction(expression, ignore=None):
"""Optimise the contractions of the tensor product at the root of
the expression, including:
- IndexSum-Delta cancellation
- Sum factorisation
:arg ignore: Optional set of indices to ignore when applying sum
factorisation (otherwise all summation indices will be
considered). Use this if your expression has many contraction
indices.
This routine was designed with finite element coefficient
evaluation in mind.
"""
Expand All @@ -498,7 +503,15 @@ def contraction(expression):
def rebuild(expression):
sum_indices, factors = delta_elimination(*traverse_product(expression))
factors = remove_componenttensors(factors)
return sum_factorise(sum_indices, factors)
if ignore is not None:
# TODO: This is a really blunt instrument and one might
# plausibly want the ignored indices to be contracted on
# the inside rather than the outside.
extra = tuple(i for i in sum_indices if i in ignore)
to_factor = tuple(i for i in sum_indices if i not in ignore)
return IndexSum(sum_factorise(to_factor, factors), extra)
else:
return sum_factorise(sum_indices, factors)

# Sometimes the value shape is composed as a ListTensor, which
# could get in the way of decomposing factors. In particular,
Expand Down
64 changes: 64 additions & 0 deletions tests/test_interpolation_factorisation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
from functools import partial
import numpy
import pytest

from ufl import (Mesh, FunctionSpace, FiniteElement, VectorElement,
TensorElement, Coefficient,
interval, quadrilateral, hexahedron)

from tsfc.driver import compile_expression_dual_evaluation
from tsfc.finatinterface import create_element


@pytest.fixture(params=[interval, quadrilateral, hexahedron],
ids=lambda x: x.cellname())
def mesh(request):
return Mesh(VectorElement("P", request.param, 1))


@pytest.fixture(params=[FiniteElement, VectorElement, TensorElement],
ids=lambda x: x.__name__)
def element(request, mesh):
if mesh.ufl_cell() == interval:
family = "DP"
else:
family = "DQ"
return partial(request.param, family, mesh.ufl_cell())


def flop_count(mesh, source, target):
Vtarget = FunctionSpace(mesh, target)
Vsource = FunctionSpace(mesh, source)
to_element = create_element(Vtarget.ufl_element())
expr = Coefficient(Vsource)
kernel = compile_expression_dual_evaluation(expr, to_element)
return kernel.flop_count


def test_sum_factorisation(mesh, element):
# Interpolation between sum factorisable elements should cost
# O(p^{d+1})
degrees = numpy.asarray([2**n - 1 for n in range(2, 9)])
flops = []
for lo, hi in zip(degrees - 1, degrees):
flops.append(flop_count(mesh, element(int(lo)), element(int(hi))))
flops = numpy.asarray(flops)
rates = numpy.diff(numpy.log(flops)) / numpy.diff(numpy.log(degrees))
assert (rates < (mesh.topological_dimension()+1)).all()


def test_sum_factorisation_scalar_tensor(mesh, element):
# Interpolation into tensor elements should cost value_shape
# more than the equivalent scalar element.
degree = 2**7 - 1
source = element(degree - 1)
target = element(degree)
tensor_flops = flop_count(mesh, source, target)
expect = numpy.prod(target.value_shape())
if isinstance(target, FiniteElement):
scalar_flops = tensor_flops
else:
target = target.sub_elements()[0]
source = source.sub_elements()[0]
scalar_flops = flop_count(mesh, source, target)
assert numpy.allclose(tensor_flops / scalar_flops, expect, rtol=1e-2)
208 changes: 81 additions & 127 deletions tsfc/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from functools import reduce
from itertools import chain

from numpy import asarray, allclose, isnan
from numpy import asarray

import ufl
from ufl.algorithms import extract_arguments, extract_coefficients
Expand All @@ -18,13 +18,10 @@
import gem
import gem.impero_utils as impero_utils

import FIAT
from FIAT.reference_element import TensorProductCell
from FIAT.functional import PointEvaluation

from finat.point_set import PointSet, UnknownPointSet
from finat.quadrature import AbstractQuadratureRule, make_quadrature, QuadratureRule
from finat.quadrature_element import QuadratureElement
import finat
from finat.quadrature import AbstractQuadratureRule, make_quadrature

from tsfc import fem, ufl_utils
from tsfc.finatinterface import as_fiat_cell
Expand Down Expand Up @@ -282,14 +279,6 @@ def compile_expression_dual_evaluation(expression, to_element, *,
:arg parameters: parameters object
:returns: Loopy-based ExpressionKernel object.
"""
# Just convert FInAT element to FIAT for now.
# Dual evaluation in FInAT will bring a thorough revision.
finat_to_element = to_element
to_element = finat_to_element.fiat_equivalent

if any(len(dual.deriv_dict) != 0 for dual in to_element.dual_basis()):
raise NotImplementedError("Can only interpolate onto dual basis functionals without derivative evaluation, sorry!")

if parameters is None:
parameters = default_parameters()
else:
Expand All @@ -302,7 +291,7 @@ def compile_expression_dual_evaluation(expression, to_element, *,

# Find out which mapping to apply
try:
mapping, = set(to_element.mapping())
mapping, = set((to_element.mapping,))
except ValueError:
raise NotImplementedError("Don't know how to interpolate onto zany spaces, sorry")
expression = apply_mapping(expression, mapping, domain)
Expand Down Expand Up @@ -341,7 +330,7 @@ def compile_expression_dual_evaluation(expression, to_element, *,
# Split mixed coefficients
expression = ufl_utils.split_coefficients(expression, builder.coefficient_split)

# Translate to GEM
# Set up kernel config for translation of UFL expression to gem
kernel_cfg = dict(interface=builder,
ufl_cell=domain.ufl_cell(),
# FIXME: change if we ever implement
Expand All @@ -351,132 +340,97 @@ def compile_expression_dual_evaluation(expression, to_element, *,
index_cache={},
scalar_type=parameters["scalar_type"])

# A FInAT QuadratureElement with a runtime tabulated UnknownPointSet
# point set is the target element on the reference cell for dual evaluation
# where the points are specified at runtime. This special casing will not
# be necessary when FInAT dual evaluation is done - the dual evaluation
# method of every FInAT element will create the necessary gem code.
from finat.tensorfiniteelement import TensorFiniteElement
runtime_quadrature_rule = (
isinstance(finat_to_element, QuadratureElement) or
(
isinstance(finat_to_element, TensorFiniteElement) and
isinstance(finat_to_element.base_element, QuadratureElement)
) and
isinstance(finat_to_element._rule.point_set, UnknownPointSet)
)

if all(isinstance(dual, PointEvaluation) for dual in to_element.dual_basis()):
# This is an optimisation for point-evaluation nodes which
# should go away once FInAT offers the interface properly
config = kernel_cfg.copy()
if runtime_quadrature_rule:
# Until FInAT dual evaluation is done, FIAT
# QuadratureElements with UnknownPointSet point sets
# advertise NaNs as their points for each node in the dual
# basis. This has to be manually replaced with the real
# UnknownPointSet point set used to create the
# QuadratureElement rule.
point_set = finat_to_element._rule.point_set
config.update(point_indices=point_set.indices, point_expr=point_set.expression)
context = fem.GemPointContext(**config)
else:
qpoints = []
# Everything is just a point evaluation.
for dual in to_element.dual_basis():
ptdict = dual.get_point_dict()
qpoint, = ptdict.keys()
(qweight, component), = ptdict[qpoint]
assert allclose(qweight, 1.0)
assert component == ()
qpoints.append(qpoint)
point_set = PointSet(qpoints)
config.update(point_set=point_set)

# Allow interpolation onto QuadratureElements to refer to the quadrature
# rule they represent
if isinstance(to_element, FIAT.QuadratureElement):
assert allclose(asarray(qpoints), asarray(to_element._points))
quad_rule = QuadratureRule(point_set, to_element._weights)
config["quadrature_rule"] = quad_rule

context = fem.PointSetContext(**config)

expr, = fem.compile_ufl(expression, context, point_sum=False)
# In some cases point_set.indices may be dropped from expr, but nothing
# new should now appear
assert set(expr.free_indices) <= set(chain(point_set.indices, *argument_multiindices))
shape_indices = tuple(gem.Index() for _ in expr.shape)
basis_indices = point_set.indices
ir = gem.Indexed(expr, shape_indices)
else:
# This is general code but is more unrolled than necssary.
dual_expressions = [] # one for each functional
broadcast_shape = len(expression.ufl_shape) - len(to_element.value_shape())
shape_indices = tuple(gem.Index() for _ in expression.ufl_shape[:broadcast_shape])
expr_cache = {} # Sharing of evaluation of the expression at points
for dual in to_element.dual_basis():
pts = tuple(sorted(dual.get_point_dict().keys()))
try:
expr, point_set = expr_cache[pts]
except KeyError:
config = kernel_cfg.copy()
if runtime_quadrature_rule:
# Until FInAT dual evaluation is done, FIAT
# QuadratureElements with UnknownPointSet point sets
# advertise NaNs as their points for each node in the dual
# basis. This has to be manually replaced with the real
# UnknownPointSet point set used to create the
# QuadratureElement rule.
assert isnan(pts).all()
point_set = finat_to_element._rule.point_set
config.update(point_indices=point_set.indices, point_expr=point_set.expression)
context = fem.GemPointContext(**config)
else:
point_set = PointSet(pts)
config.update(point_set=point_set)
context = fem.PointSetContext(**config)
expr, = fem.compile_ufl(expression, context, point_sum=False)
# In some cases point_set.indices may be dropped from expr, but
# nothing new should now appear
assert set(expr.free_indices) <= set(chain(point_set.indices, *argument_multiindices))
expr = gem.partial_indexed(expr, shape_indices)
expr_cache[pts] = expr, point_set
weights = collections.defaultdict(list)
for p in pts:
for (w, cmp) in dual.get_point_dict()[p]:
weights[cmp].append(w)
qexprs = gem.Zero()
for cmp in sorted(weights):
qweights = gem.Literal(weights[cmp])
qexpr = gem.Indexed(expr, cmp)
qexpr = gem.index_sum(gem.Indexed(qweights, point_set.indices)*qexpr,
point_set.indices)
qexprs = gem.Sum(qexprs, qexpr)
assert qexprs.shape == ()
assert set(qexprs.free_indices) == set(chain(shape_indices, *argument_multiindices))
dual_expressions.append(qexprs)
basis_indices = (gem.Index(), )
ir = gem.Indexed(gem.ListTensor(dual_expressions), basis_indices)
# Allow interpolation onto QuadratureElements to refer to the quadrature
# rule they represent
if isinstance(to_element, finat.QuadratureElement):
kernel_cfg["quadrature_rule"] = to_element._rule

# Create callable for translation of UFL expression to gem
fn = DualEvaluationCallable(expression, kernel_cfg)

# Get the gem expression for dual evaluation and corresponding basis
# indices needed for compilation of the expression
evaluation, basis_indices = to_element.dual_evaluation(fn)

# Build kernel body
return_indices = basis_indices + shape_indices + tuple(chain(*argument_multiindices))
return_indices = basis_indices + tuple(chain(*argument_multiindices))
return_shape = tuple(i.extent for i in return_indices)
return_var = gem.Variable('A', return_shape)
return_expr = gem.Indexed(return_var, return_indices)

# TODO: one should apply some GEM optimisations as in assembly,
# but we don't for now.
ir, = impero_utils.preprocess_gem([ir])
impero_c = impero_utils.compile_gem([(return_expr, ir)], return_indices)
evaluation, = impero_utils.preprocess_gem([evaluation])
impero_c = impero_utils.compile_gem([(return_expr, evaluation)], return_indices)
index_names = dict((idx, "p%d" % i) for (i, idx) in enumerate(basis_indices))
# Handle kernel interface requirements
builder.register_requirements([ir])
builder.register_requirements([evaluation])
builder.set_output(return_var)
# Build kernel tuple
return builder.construct_kernel(impero_c, index_names, first_coefficient_fake_coords)


class DualEvaluationCallable(object):
"""
Callable representing a function to dual evaluate.
When called, this takes in a
:class:`finat.point_set.AbstractPointSet` and returns a GEM
expression for evaluation of the function at those points.
:param expression: UFL expression for the function to dual evaluate.
:param kernel_cfg: A kernel configuration for creation of a
:class:`GemPointContext` or a :class:`PointSetContext`
Not intended for use outside of
:func:`compile_expression_dual_evaluation`.
"""
def __init__(self, expression, kernel_cfg):
self.expression = expression
self.kernel_cfg = kernel_cfg

def __call__(self, ps):
"""The function to dual evaluate.
:param ps: The :class:`finat.point_set.AbstractPointSet` for
evaluating at
:returns: a gem expression representing the evaluation of the
input UFL expression at the given point set ``ps``.
For point set points with some shape ``(*value_shape)``
(i.e. ``()`` for scalar points ``(x)`` for vector points
``(x, y)`` for tensor points etc) then the gem expression
has shape ``(*value_shape)`` and free indices corresponding
to the input :class:`finat.point_set.AbstractPointSet`'s
free indices alongside any input UFL expression free
indices.
"""

if not isinstance(ps, finat.point_set.AbstractPointSet):
raise ValueError("Callable argument not a point set!")

# Avoid modifying saved kernel config
kernel_cfg = self.kernel_cfg.copy()

if isinstance(ps, finat.point_set.UnknownPointSet):
# Run time known points
kernel_cfg.update(point_indices=ps.indices, point_expr=ps.expression)
# GemPointContext's aren't allowed to have quadrature rules
kernel_cfg.pop("quadrature_rule", None)
translation_context = fem.GemPointContext(**kernel_cfg)
else:
# Compile time known points
kernel_cfg.update(point_set=ps)
translation_context = fem.PointSetContext(**kernel_cfg)

gem_expr, = fem.compile_ufl(self.expression, translation_context, point_sum=False)
# In some cases ps.indices may be dropped from expr, but nothing
# new should now appear
argument_multiindices = kernel_cfg["argument_multiindices"]
assert set(gem_expr.free_indices) <= set(chain(ps.indices, *argument_multiindices))

return gem_expr


def lower_integral_type(fiat_cell, integral_type):
"""Lower integral type into the dimension of the integration
subentity and a list of entity numbers for that dimension.
Expand Down

0 comments on commit 3fe9dcf

Please sign in to comment.