Skip to content

Commit

Permalink
Merge pull request #247 from firedrakeproject/ReubenHill/0d-mesh-inte…
Browse files Browse the repository at this point in the history
…rpolation-hack

Runtime point set FInAT QuadratureElement interpolation
  • Loading branch information
dham authored Jun 1, 2021
2 parents 231e301 + 994da9c commit 8eccb7d
Show file tree
Hide file tree
Showing 3 changed files with 87 additions and 35 deletions.
99 changes: 71 additions & 28 deletions tsfc/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from functools import reduce
from itertools import chain

from numpy import asarray, allclose
from numpy import asarray, allclose, isnan

import ufl
from ufl.algorithms import extract_arguments, extract_coefficients
Expand All @@ -22,8 +22,9 @@
from FIAT.reference_element import TensorProductCell
from FIAT.functional import PointEvaluation

from finat.point_set import PointSet
from finat.point_set import PointSet, UnknownPointSet
from finat.quadrature import AbstractQuadratureRule, make_quadrature, QuadratureRule
from finat.quadrature_element import QuadratureElement

from tsfc import fem, ufl_utils
from tsfc.finatinterface import as_fiat_cell
Expand Down Expand Up @@ -203,8 +204,8 @@ def compile_integral(integral_data, form_data, prefix, parameters, interface, co
config = kernel_cfg.copy()
config.update(quadrature_rule=quad_rule)
expressions = fem.compile_ufl(integrand,
interior_facet=interior_facet,
**config)
fem.PointSetContext(**config),
interior_facet=interior_facet)
reps = mode.Integrals(expressions, quadrature_multiindex,
argument_multiindices, params)
for var, rep in zip(return_variables, reps):
Expand Down Expand Up @@ -286,7 +287,8 @@ def compile_expression_dual_evaluation(expression, to_element, *,

# Just convert FInAT element to FIAT for now.
# Dual evaluation in FInAT will bring a thorough revision.
to_element = to_element.fiat_equivalent
finat_to_element = to_element
to_element = finat_to_element.fiat_equivalent

if any(len(dual.deriv_dict) != 0 for dual in to_element.dual_basis()):
raise NotImplementedError("Can only interpolate onto dual basis functionals without derivative evaluation, sorry!")
Expand Down Expand Up @@ -357,30 +359,58 @@ def compile_expression_dual_evaluation(expression, to_element, *,
index_cache={},
scalar_type=parameters["scalar_type"])

# A FInAT QuadratureElement with a runtime tabulated UnknownPointSet
# point set is the target element on the reference cell for dual evaluation
# where the points are specified at runtime. This special casing will not
# be necessary when FInAT dual evaluation is done - the dual evaluation
# method of every FInAT element will create the necessary gem code.
from finat.tensorfiniteelement import TensorFiniteElement
runtime_quadrature_rule = (
isinstance(finat_to_element, QuadratureElement) or
(
isinstance(finat_to_element, TensorFiniteElement) and
isinstance(finat_to_element.base_element, QuadratureElement)
) and
isinstance(finat_to_element._rule.point_set, UnknownPointSet)
)

if all(isinstance(dual, PointEvaluation) for dual in to_element.dual_basis()):
# This is an optimisation for point-evaluation nodes which
# should go away once FInAT offers the interface properly
qpoints = []
# Everything is just a point evaluation.
for dual in to_element.dual_basis():
ptdict = dual.get_point_dict()
qpoint, = ptdict.keys()
(qweight, component), = ptdict[qpoint]
assert allclose(qweight, 1.0)
assert component == ()
qpoints.append(qpoint)
point_set = PointSet(qpoints)
config = kernel_cfg.copy()
config.update(point_set=point_set)

# Allow interpolation onto QuadratureElements to refer to the quadrature
# rule they represent
if isinstance(to_element, FIAT.QuadratureElement):
assert allclose(asarray(qpoints), asarray(to_element._points))
quad_rule = QuadratureRule(point_set, to_element._weights)
config["quadrature_rule"] = quad_rule

expr, = fem.compile_ufl(expression, **config, point_sum=False)
if runtime_quadrature_rule:
# Until FInAT dual evaluation is done, FIAT
# QuadratureElements with UnknownPointSet point sets
# advertise NaNs as their points for each node in the dual
# basis. This has to be manually replaced with the real
# UnknownPointSet point set used to create the
# QuadratureElement rule.
point_set = finat_to_element._rule.point_set
config.update(point_indices=point_set.indices, point_expr=point_set.expression)
context = fem.GemPointContext(**config)
else:
qpoints = []
# Everything is just a point evaluation.
for dual in to_element.dual_basis():
ptdict = dual.get_point_dict()
qpoint, = ptdict.keys()
(qweight, component), = ptdict[qpoint]
assert allclose(qweight, 1.0)
assert component == ()
qpoints.append(qpoint)
point_set = PointSet(qpoints)
config.update(point_set=point_set)

# Allow interpolation onto QuadratureElements to refer to the quadrature
# rule they represent
if isinstance(to_element, FIAT.QuadratureElement):
assert allclose(asarray(qpoints), asarray(to_element._points))
quad_rule = QuadratureRule(point_set, to_element._weights)
config["quadrature_rule"] = quad_rule

context = fem.PointSetContext(**config)

expr, = fem.compile_ufl(expression, context, point_sum=False)
# In some cases point_set.indices may be dropped from expr, but nothing
# new should now appear
assert set(expr.free_indices) <= set(chain(point_set.indices, *argument_multiindices))
Expand All @@ -398,10 +428,23 @@ def compile_expression_dual_evaluation(expression, to_element, *,
try:
expr, point_set = expr_cache[pts]
except KeyError:
point_set = PointSet(pts)
config = kernel_cfg.copy()
config.update(point_set=point_set)
expr, = fem.compile_ufl(expression, **config, point_sum=False)
if runtime_quadrature_rule:
# Until FInAT dual evaluation is done, FIAT
# QuadratureElements with UnknownPointSet point sets
# advertise NaNs as their points for each node in the dual
# basis. This has to be manually replaced with the real
# UnknownPointSet point set used to create the
# QuadratureElement rule.
assert isnan(pts).all()
point_set = finat_to_element._rule.point_set
config.update(point_indices=point_set.indices, point_expr=point_set.expression)
context = fem.GemPointContext(**config)
else:
point_set = PointSet(pts)
config.update(point_set=point_set)
context = fem.PointSetContext(**config)
expr, = fem.compile_ufl(expression, context, point_sum=False)
# In some cases point_set.indices may be dropped from expr, but
# nothing new should now appear
assert set(expr.free_indices) <= set(chain(point_set.indices, *argument_multiindices))
Expand Down
21 changes: 15 additions & 6 deletions tsfc/fem.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,7 +340,7 @@ def cell_avg(self, o):
for name in ["ufl_cell", "index_cache", "scalar_type"]}
config.update(quadrature_degree=degree, interface=self.context,
argument_multiindices=argument_multiindices)
expr, = compile_ufl(integrand, point_sum=True, **config)
expr, = compile_ufl(integrand, PointSetContext(**config), point_sum=True)
return expr

def facet_avg(self, o):
Expand All @@ -357,7 +357,7 @@ def facet_avg(self, o):
"integral_type"]}
config.update(quadrature_degree=degree, interface=self.context,
argument_multiindices=argument_multiindices)
expr, = compile_ufl(integrand, point_sum=True, **config)
expr, = compile_ufl(integrand, PointSetContext(**config), point_sum=True)
return expr

def modified_terminal(self, o):
Expand Down Expand Up @@ -517,7 +517,7 @@ def translate_cellvolume(terminal, mt, ctx):
config = {name: getattr(ctx, name)
for name in ["ufl_cell", "index_cache", "scalar_type"]}
config.update(interface=interface, quadrature_degree=degree)
expr, = compile_ufl(integrand, point_sum=True, **config)
expr, = compile_ufl(integrand, PointSetContext(**config), point_sum=True)
return expr


Expand All @@ -531,7 +531,7 @@ def translate_facetarea(terminal, mt, ctx):
for name in ["ufl_cell", "integration_dim", "scalar_type",
"entity_ids", "index_cache"]}
config.update(interface=ctx, quadrature_degree=degree)
expr, = compile_ufl(integrand, point_sum=True, **config)
expr, = compile_ufl(integrand, PointSetContext(**config), point_sum=True)
return expr


Expand Down Expand Up @@ -693,8 +693,17 @@ def take_singleton(xs):
return result


def compile_ufl(expression, interior_facet=False, point_sum=False, **kwargs):
context = PointSetContext(**kwargs)
def compile_ufl(expression, context, interior_facet=False, point_sum=False):
"""Translate a UFL expression to GEM.
:arg expression: The UFL expression to compile.
:arg context: translation context - either a :class:`GemPointContext`
or :class:`PointSetContext`
:arg interior_facet: If ``true``, treat expression as an interior
facet integral (default ``False``)
:arg point_sum: If ``true``, return a `gem.IndexSum` of the final
gem expression along the ``context.point_indices`` (if present).
"""

# Abs-simplification
expression = simplify_abs(expression, context.complex_mode)
Expand Down
2 changes: 1 addition & 1 deletion tsfc/finatinterface.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ def convert_finiteelement(element, **kwargs):
if degree is None or scheme is None:
raise ValueError("Quadrature scheme and degree must be specified!")

return finat.QuadratureElement(cell, degree, scheme), set()
return finat.make_quadrature_element(cell, degree, scheme), set()
lmbda = supported_elements[element.family()]
if lmbda is None:
if element.cell().cellname() == "quadrilateral":
Expand Down

0 comments on commit 8eccb7d

Please sign in to comment.