Skip to content

Commit

Permalink
Simplify Restricted(ConstantValue)
Browse files Browse the repository at this point in the history
  • Loading branch information
pbrubeck committed Jan 1, 2025
1 parent aec1f3b commit 7100560
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 7 deletions.
2 changes: 0 additions & 2 deletions ufl/corealg/multifunction.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,12 @@
#
# Modified by Massimiliano Leoni, 2016

from functools import cache
from inspect import signature

from ufl.core.expr import Expr
from ufl.core.ufl_type import UFLType


@cache
def get_num_args(function):
"""Return the number of arguments accepted by *function*."""
sig = signature(function)
Expand Down
10 changes: 6 additions & 4 deletions ufl/indexed.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,9 @@ def __new__(cls, expression, multiindex):
# cyclic import
from ufl.tensors import ListTensor

simpler = False
flattened = False
indices = multiindex.indices()

while (
len(indices) > 0
and isinstance(expression, ListTensor)
Expand All @@ -39,13 +40,13 @@ def __new__(cls, expression, multiindex):
# Simplify indexed ListTensor objects
expression = expression[indices[0]]
indices = indices[1:]
simpler = True
flattened = True

if isinstance(expression, Indexed):
# Simplify nested Indexed objects
indices = expression.ufl_operands[1].indices() + indices
expression = expression.ufl_operands[0]
simpler = True
flattened = True

if len(indices) == 0:
return expression
Expand All @@ -64,7 +65,8 @@ def __new__(cls, expression, multiindex):
else:
fi, fid = (), ()
return Zero(shape=(), free_indices=fi, index_dimensions=fid)
elif simpler:
elif flattened:
# Simplified Indexed expression
return Indexed(expression, MultiIndex(indices))
else:
return Operator.__new__(cls)
Expand Down
8 changes: 7 additions & 1 deletion ufl/restriction.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#
# SPDX-License-Identifier: LGPL-3.0-or-later

from ufl.constantvalue import ConstantValue
from ufl.core.operator import Operator
from ufl.core.ufl_type import ufl_type
from ufl.precedence import parstr
Expand All @@ -24,7 +25,12 @@ class Restricted(Operator):

__slots__ = ()

# TODO: Add __new__ operator here, e.g. restricted(literal) == literal
def __new__(cls, expression):
"""Create a new Restricted."""
if isinstance(expression, ConstantValue):
return expression
else:
return Operator.__new__(cls)

def __init__(self, f):
"""Initialise."""
Expand Down

0 comments on commit 7100560

Please sign in to comment.