Skip to content

Commit

Permalink
Merge pull request #282 from firedrakeproject/ksagiyam/fix-issue-274-…
Browse files Browse the repository at this point in the history
…retry

Ksagiyam/fix issue 274 retry
  • Loading branch information
ksagiyam authored Sep 8, 2022
2 parents bcea761 + 700feab commit 351994d
Show file tree
Hide file tree
Showing 5 changed files with 91 additions and 8 deletions.
8 changes: 3 additions & 5 deletions gem/gem.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,11 +211,7 @@ class Literal(Constant):

def __new__(cls, array):
array = asarray(array)
if (array == 0).all():
# All zeros, make symbolic zero
return Zero(array.shape)
else:
return super(Literal, cls).__new__(cls)
return super(Literal, cls).__new__(cls)

def __init__(self, array):
array = asarray(array)
Expand Down Expand Up @@ -548,6 +544,8 @@ def __new__(cls, aggregate, multiindex):
assert isinstance(index, IndexBase)
if isinstance(index, Index):
index.set_extent(extent)
elif isinstance(index, int) and not (0 <= index < extent):
raise IndexError("Invalid literal index")

# Empty multiindex
if not multiindex:
Expand Down
44 changes: 44 additions & 0 deletions gem/optimise.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,50 @@ def remove_componenttensors(expressions):
return [mapper(expression, ()) for expression in expressions]


@singledispatch
def _constant_fold_zero(node, self):
raise AssertionError("cannot handle type %s" % type(node))


_constant_fold_zero.register(Node)(reuse_if_untouched)


@_constant_fold_zero.register(Literal)
def _constant_fold_zero_literal(node, self):
if (node.array == 0).all():
# All zeros, make symbolic zero
return Zero(node.shape)
else:
return node


@_constant_fold_zero.register(ListTensor)
def _constant_fold_zero_listtensor(node, self):
new_children = list(map(self, node.children))
if all(isinstance(nc, Zero) for nc in new_children):
return Zero(node.shape)
elif all(nc == c for nc, c in zip(new_children, node.children)):
return node
else:
return node.reconstruct(*new_children)


def constant_fold_zero(exprs):
"""Produce symbolic zeros from Literals
:arg exprs: An iterable of gem expressions.
:returns: A list of gem expressions where any Literal containing
only zeros is replaced by symbolic Zero of the appropriate
shape.
We need a separate path for ListTensor so that its `reconstruct`
method will not be called when the new children are `Zero()`s;
otherwise Literal `0`s would be reintroduced.
"""
mapper = Memoizer(_constant_fold_zero)
return [mapper(e) for e in exprs]


def _select_expression(expressions, index):
"""Helper function to select an expression from a list of
expressions with an index. This function expect sanitised input,
Expand Down
40 changes: 40 additions & 0 deletions tests/test_tsfc_274.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
import gem
import numpy
from finat.point_set import PointSet
from gem.interpreter import evaluate
from tsfc.finatinterface import create_element
from ufl import FiniteElement, RestrictedElement, quadrilateral


def test_issue_274():
# See https://github.com/firedrakeproject/tsfc/issues/274
ufl_element = RestrictedElement(
FiniteElement("Q", quadrilateral, 2), restriction_domain="facet"
)
ps = PointSet([[0.5]])
finat_element = create_element(ufl_element)
evaluations = []
for eid in range(4):
(val,) = finat_element.basis_evaluation(0, ps, (1, eid)).values()
evaluations.append(val)

i = gem.Index()
j = gem.Index()
(expr,) = evaluate(
[
gem.ComponentTensor(
gem.Indexed(gem.select_expression(evaluations, i), (j,)),
(*ps.indices, i, j),
)
]
)

(expect,) = evaluate(
[
gem.ComponentTensor(
gem.Indexed(gem.ListTensor(evaluations), (i, j)), (*ps.indices, i, j)
)
]
)

assert numpy.allclose(expr.arr, expect.arr)
4 changes: 2 additions & 2 deletions tsfc/fem.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@

import gem
from gem.node import traversal
from gem.optimise import ffc_rounding
from gem.optimise import ffc_rounding, constant_fold_zero
from gem.unconcatenate import unconcatenate
from gem.utils import cached_property

Expand Down Expand Up @@ -718,4 +718,4 @@ def compile_ufl(expression, context, interior_facet=False, point_sum=False):
result = map_expr_dags(context.translator, expressions)
if point_sum:
result = [gem.index_sum(expr, context.point_indices) for expr in result]
return result
return constant_fold_zero(result)
3 changes: 2 additions & 1 deletion tsfc/kernel_interface/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from gem.node import traversal
from gem.utils import cached_property
import gem.impero_utils as impero_utils
from gem.optimise import remove_componenttensors as prune
from gem.optimise import remove_componenttensors as prune, constant_fold_zero

from tsfc import fem, ufl_utils
from tsfc.kernel_interface import KernelInterface
Expand Down Expand Up @@ -213,6 +213,7 @@ def compile_gem(self, ctx):
else:
return_variables = []
expressions = []
expressions = constant_fold_zero(expressions)

# Need optimised roots
options = dict(reduce(operator.and_,
Expand Down

0 comments on commit 351994d

Please sign in to comment.