Skip to content

Commit

Permalink
WIP support Argument.part
Browse files Browse the repository at this point in the history
  • Loading branch information
pbrubeck committed Jan 8, 2025
1 parent bfa596d commit 33bd10d
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 20 deletions.
21 changes: 9 additions & 12 deletions ufl/algorithms/formtransformations.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from ufl.argument import Argument
from ufl.coefficient import Coefficient
from ufl.constantvalue import Zero
from ufl.split_functions import split

# All classes:
from ufl.core.expr import ufl_err_str
Expand Down Expand Up @@ -397,18 +398,18 @@ def compute_form_action(form, coefficient):
# Extract all arguments
arguments = form.arguments()

parts = [arg.part() for arg in arguments]
if set(parts) - {None}:
raise ValueError("compute_form_action cannot handle parts.")

# Pick last argument (will be replaced)
u = arguments[-1]

fs = u.ufl_function_space()
if u.part() is not None:
fs = fs[u.part()]
if coefficient is None:
coefficient = Coefficient(fs)
elif coefficient.ufl_function_space() != fs:
debug("Computing action of form on a coefficient in a different function space.")
if u.part() is not None:
coefficient = split(coefficient)[u.part()]
return replace(form, {u: coefficient})


Expand Down Expand Up @@ -457,10 +458,6 @@ def compute_form_adjoint(form, reordered_arguments=None):
"""
arguments = form.arguments()

parts = [arg.part() for arg in arguments]
if set(parts) - {None}:
raise ValueError("compute_form_adjoint cannot handle parts.")

if len(arguments) != 2:
raise ValueError("Expecting bilinear form.")

Expand All @@ -469,17 +466,17 @@ def compute_form_adjoint(form, reordered_arguments=None):
raise ValueError("Mistaken assumption in code!")

if reordered_arguments is None:
reordered_u = Argument(u.ufl_function_space(), number=v.number(), part=v.part())
reordered_v = Argument(v.ufl_function_space(), number=u.number(), part=u.part())
reordered_u = u.reconstruct(number=v.number())
reordered_v = v.reconstruct(number=u.number())
else:
reordered_u, reordered_v = reordered_arguments

if reordered_u.number() >= reordered_v.number():
raise ValueError("Ordering of new arguments is the same as the old arguments!")

if reordered_u.part() != v.part():
if reordered_v.part() != v.part():
raise ValueError("Ordering of new arguments is the same as the old arguments!")
if reordered_v.part() != u.part():
if reordered_u.part() != u.part():
raise ValueError("Ordering of new arguments is the same as the old arguments!")

if reordered_u.ufl_function_space() != u.ufl_function_space():
Expand Down
6 changes: 5 additions & 1 deletion ufl/argument.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,11 @@ def __init__(self, function_space, number, part=None):
raise ValueError("Expecting a FunctionSpace.")

self._ufl_function_space = function_space
self._ufl_shape = function_space.value_shape
if part is None:
shape = function_space.value_shape
else:
shape = function_space[part].value_shape
self._ufl_shape = shape

if not isinstance(number, numbers.Integral):
raise ValueError(f"Expecting an int for number, not {number}")
Expand Down
10 changes: 3 additions & 7 deletions ufl/indexed.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,12 +45,8 @@ def __new__(cls, expression, multiindex):
fi, fid = (), ()
return Zero(shape=(), free_indices=fi, index_dimensions=fid)

try:
# Simplify indexed ListTensor
c = expression[multiindex]
return Indexed(*c.ufl_operands) if isinstance(c, Indexed) else c
except ValueError:
return Operator.__new__(cls)

return Operator.__new__(cls)

def __init__(self, expression, multiindex):
"""Initialise."""
Expand Down Expand Up @@ -121,7 +117,7 @@ def __getitem__(self, key):
f"but object is already indexed: {ufl_err_str(self)}"
)

def _ufl_expr_reconstruct_(self, expression, multiindex):
def __ufl_expr_reconstruct_(self, expression, multiindex):
"""Reconstruct."""
try:
# Simplify indexed ListTensor
Expand Down
7 changes: 7 additions & 0 deletions ufl/split_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,13 @@ def split(v):
If v is a Coefficient or Argument in a mixed space, returns a tuple
with the function components corresponding to the subelements.
"""
from ufl.argument import BaseArgument
if isinstance(v, BaseArgument) and v.part() is None:
element = v.ufl_element()
r = tuple(v.reconstruct(part=part) for part in range(element.num_sub_elements))
for a in r:
print(a.ufl_shape)
return r
domain = extract_unique_domain(v)

# Default range is all of v
Expand Down

0 comments on commit 33bd10d

Please sign in to comment.