diff --git a/tsfc/driver.py b/tsfc/driver.py index d5d665cd..b0e46af6 100644 --- a/tsfc/driver.py +++ b/tsfc/driver.py @@ -6,7 +6,7 @@ from functools import reduce from itertools import chain -from numpy import asarray, allclose +from numpy import asarray, allclose, isnan import ufl from ufl.algorithms import extract_arguments, extract_coefficients @@ -22,8 +22,9 @@ from FIAT.reference_element import TensorProductCell from FIAT.functional import PointEvaluation -from finat.point_set import PointSet +from finat.point_set import PointSet, UnknownPointSet from finat.quadrature import AbstractQuadratureRule, make_quadrature, QuadratureRule +from finat.quadrature_element import QuadratureElement from tsfc import fem, ufl_utils from tsfc.finatinterface import as_fiat_cell @@ -203,8 +204,8 @@ def compile_integral(integral_data, form_data, prefix, parameters, interface, co config = kernel_cfg.copy() config.update(quadrature_rule=quad_rule) expressions = fem.compile_ufl(integrand, - interior_facet=interior_facet, - **config) + fem.PointSetContext(**config), + interior_facet=interior_facet) reps = mode.Integrals(expressions, quadrature_multiindex, argument_multiindices, params) for var, rep in zip(return_variables, reps): @@ -286,7 +287,8 @@ def compile_expression_dual_evaluation(expression, to_element, *, # Just convert FInAT element to FIAT for now. # Dual evaluation in FInAT will bring a thorough revision. - to_element = to_element.fiat_equivalent + finat_to_element = to_element + to_element = finat_to_element.fiat_equivalent if any(len(dual.deriv_dict) != 0 for dual in to_element.dual_basis()): raise NotImplementedError("Can only interpolate onto dual basis functionals without derivative evaluation, sorry!") @@ -357,30 +359,58 @@ def compile_expression_dual_evaluation(expression, to_element, *, index_cache={}, scalar_type=parameters["scalar_type"]) + # A FInAT QuadratureElement with a runtime tabulated UnknownPointSet + # point set is the target element on the reference cell for dual evaluation + # where the points are specified at runtime. This special casing will not + # be necessary when FInAT dual evaluation is done - the dual evaluation + # method of every FInAT element will create the necessary gem code. + from finat.tensorfiniteelement import TensorFiniteElement + runtime_quadrature_rule = ( + isinstance(finat_to_element, QuadratureElement) or + ( + isinstance(finat_to_element, TensorFiniteElement) and + isinstance(finat_to_element.base_element, QuadratureElement) + ) and + isinstance(finat_to_element._rule.point_set, UnknownPointSet) + ) + if all(isinstance(dual, PointEvaluation) for dual in to_element.dual_basis()): # This is an optimisation for point-evaluation nodes which # should go away once FInAT offers the interface properly - qpoints = [] - # Everything is just a point evaluation. - for dual in to_element.dual_basis(): - ptdict = dual.get_point_dict() - qpoint, = ptdict.keys() - (qweight, component), = ptdict[qpoint] - assert allclose(qweight, 1.0) - assert component == () - qpoints.append(qpoint) - point_set = PointSet(qpoints) config = kernel_cfg.copy() - config.update(point_set=point_set) - - # Allow interpolation onto QuadratureElements to refer to the quadrature - # rule they represent - if isinstance(to_element, FIAT.QuadratureElement): - assert allclose(asarray(qpoints), asarray(to_element._points)) - quad_rule = QuadratureRule(point_set, to_element._weights) - config["quadrature_rule"] = quad_rule - - expr, = fem.compile_ufl(expression, **config, point_sum=False) + if runtime_quadrature_rule: + # Until FInAT dual evaluation is done, FIAT + # QuadratureElements with UnknownPointSet point sets + # advertise NaNs as their points for each node in the dual + # basis. This has to be manually replaced with the real + # UnknownPointSet point set used to create the + # QuadratureElement rule. + point_set = finat_to_element._rule.point_set + config.update(point_indices=point_set.indices, point_expr=point_set.expression) + context = fem.GemPointContext(**config) + else: + qpoints = [] + # Everything is just a point evaluation. + for dual in to_element.dual_basis(): + ptdict = dual.get_point_dict() + qpoint, = ptdict.keys() + (qweight, component), = ptdict[qpoint] + assert allclose(qweight, 1.0) + assert component == () + qpoints.append(qpoint) + point_set = PointSet(qpoints) + config.update(point_set=point_set) + + # Allow interpolation onto QuadratureElements to refer to the quadrature + # rule they represent + if isinstance(to_element, FIAT.QuadratureElement): + assert allclose(asarray(qpoints), asarray(to_element._points)) + quad_rule = QuadratureRule(point_set, to_element._weights) + config["quadrature_rule"] = quad_rule + + context = fem.PointSetContext(**config) + + expr, = fem.compile_ufl(expression, context, point_sum=False) # In some cases point_set.indices may be dropped from expr, but nothing # new should now appear assert set(expr.free_indices) <= set(chain(point_set.indices, *argument_multiindices)) @@ -398,10 +428,23 @@ def compile_expression_dual_evaluation(expression, to_element, *, try: expr, point_set = expr_cache[pts] except KeyError: - point_set = PointSet(pts) config = kernel_cfg.copy() - config.update(point_set=point_set) - expr, = fem.compile_ufl(expression, **config, point_sum=False) + if runtime_quadrature_rule: + # Until FInAT dual evaluation is done, FIAT + # QuadratureElements with UnknownPointSet point sets + # advertise NaNs as their points for each node in the dual + # basis. This has to be manually replaced with the real + # UnknownPointSet point set used to create the + # QuadratureElement rule. + assert isnan(pts).all() + point_set = finat_to_element._rule.point_set + config.update(point_indices=point_set.indices, point_expr=point_set.expression) + context = fem.GemPointContext(**config) + else: + point_set = PointSet(pts) + config.update(point_set=point_set) + context = fem.PointSetContext(**config) + expr, = fem.compile_ufl(expression, context, point_sum=False) # In some cases point_set.indices may be dropped from expr, but # nothing new should now appear assert set(expr.free_indices) <= set(chain(point_set.indices, *argument_multiindices)) diff --git a/tsfc/fem.py b/tsfc/fem.py index 7dd79f6d..2c5c865c 100644 --- a/tsfc/fem.py +++ b/tsfc/fem.py @@ -340,7 +340,7 @@ def cell_avg(self, o): for name in ["ufl_cell", "index_cache", "scalar_type"]} config.update(quadrature_degree=degree, interface=self.context, argument_multiindices=argument_multiindices) - expr, = compile_ufl(integrand, point_sum=True, **config) + expr, = compile_ufl(integrand, PointSetContext(**config), point_sum=True) return expr def facet_avg(self, o): @@ -357,7 +357,7 @@ def facet_avg(self, o): "integral_type"]} config.update(quadrature_degree=degree, interface=self.context, argument_multiindices=argument_multiindices) - expr, = compile_ufl(integrand, point_sum=True, **config) + expr, = compile_ufl(integrand, PointSetContext(**config), point_sum=True) return expr def modified_terminal(self, o): @@ -517,7 +517,7 @@ def translate_cellvolume(terminal, mt, ctx): config = {name: getattr(ctx, name) for name in ["ufl_cell", "index_cache", "scalar_type"]} config.update(interface=interface, quadrature_degree=degree) - expr, = compile_ufl(integrand, point_sum=True, **config) + expr, = compile_ufl(integrand, PointSetContext(**config), point_sum=True) return expr @@ -531,7 +531,7 @@ def translate_facetarea(terminal, mt, ctx): for name in ["ufl_cell", "integration_dim", "scalar_type", "entity_ids", "index_cache"]} config.update(interface=ctx, quadrature_degree=degree) - expr, = compile_ufl(integrand, point_sum=True, **config) + expr, = compile_ufl(integrand, PointSetContext(**config), point_sum=True) return expr @@ -693,8 +693,17 @@ def take_singleton(xs): return result -def compile_ufl(expression, interior_facet=False, point_sum=False, **kwargs): - context = PointSetContext(**kwargs) +def compile_ufl(expression, context, interior_facet=False, point_sum=False): + """Translate a UFL expression to GEM. + + :arg expression: The UFL expression to compile. + :arg context: translation context - either a :class:`GemPointContext` + or :class:`PointSetContext` + :arg interior_facet: If ``true``, treat expression as an interior + facet integral (default ``False``) + :arg point_sum: If ``true``, return a `gem.IndexSum` of the final + gem expression along the ``context.point_indices`` (if present). + """ # Abs-simplification expression = simplify_abs(expression, context.complex_mode) diff --git a/tsfc/finatinterface.py b/tsfc/finatinterface.py index 0e0b9e5a..f66cb81b 100644 --- a/tsfc/finatinterface.py +++ b/tsfc/finatinterface.py @@ -115,7 +115,7 @@ def convert_finiteelement(element, **kwargs): if degree is None or scheme is None: raise ValueError("Quadrature scheme and degree must be specified!") - return finat.QuadratureElement(cell, degree, scheme), set() + return finat.make_quadrature_element(cell, degree, scheme), set() lmbda = supported_elements[element.family()] if lmbda is None: if element.cell().cellname() == "quadrilateral":