Skip to content

Commit

Permalink
Get value_shape from FunctionSpace (#3862)
Browse files Browse the repository at this point in the history
* Get value_shape from FunctionSpace

* Define FunctionSpace.block_size as the number of dofs per node

* sort_domains

---------

Co-authored-by: ksagiyam <k.sagiyama@imperial.ac.uk>
  • Loading branch information
2 people authored and connorjward committed Nov 25, 2024
1 parent dff63f8 commit cf46e77
Show file tree
Hide file tree
Showing 51 changed files with 235 additions and 235 deletions.
2 changes: 1 addition & 1 deletion firedrake/assemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -1749,7 +1749,7 @@ def _as_global_kernel_arg_coefficient(_, self):

ufl_element = V.ufl_element()
if ufl_element.family() == "Real":
return op2.GlobalKernelArg((ufl_element.value_size,))
return op2.GlobalKernelArg((V.value_size,))
else:
return self._make_dat_global_kernel_arg(V, index=index)

Expand Down
11 changes: 6 additions & 5 deletions firedrake/bcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,28 +342,29 @@ def function_arg(self, g):
del self._function_arg_update
except AttributeError:
pass
V = self.function_space()
if isinstance(g, firedrake.Function) and g.ufl_element().family() != "Real":
if g.function_space() != self.function_space():
if g.function_space() != V:
raise RuntimeError("%r is defined on incompatible FunctionSpace!" % g)
self._function_arg = g
elif isinstance(g, ufl.classes.Zero):
if g.ufl_shape and g.ufl_shape != self.function_space().ufl_element().value_shape:
if g.ufl_shape and g.ufl_shape != V.value_shape:
raise ValueError(f"Provided boundary value {g} does not match shape of space")
# Special case. Scalar zero for direct Function.assign.
self._function_arg = g
elif isinstance(g, ufl.classes.Expr):
if g.ufl_shape != self.function_space().ufl_element().value_shape:
if g.ufl_shape != V.value_shape:
raise RuntimeError(f"Provided boundary value {g} does not match shape of space")
try:
self._function_arg = firedrake.Function(self.function_space())
self._function_arg = firedrake.Function(V)
# Use `Interpolator` instead of assembling an `Interpolate` form
# as the expression compilation needs to happen at this stage to
# determine if we should use interpolation or projection
# -> e.g. interpolation may not be supported for the element.
self._function_arg_update = firedrake.Interpolator(g, self._function_arg)._interpolate
except (NotImplementedError, AttributeError):
# Element doesn't implement interpolation
self._function_arg = firedrake.Function(self.function_space()).project(g)
self._function_arg = firedrake.Function(V).project(g)
self._function_arg_update = firedrake.Projector(g, self._function_arg).project
else:
try:
Expand Down
7 changes: 3 additions & 4 deletions firedrake/checkpointing.py
Original file line number Diff line number Diff line change
Expand Up @@ -935,7 +935,7 @@ def save_function(self, f, idx=None, name=None, timestepping_info={}):
self._update_function_name_function_space_name_map(tmesh.name, mesh.name, {f.name(): V_name})
# Embed if necessary
element = V.ufl_element()
_element = get_embedding_element_for_checkpointing(element)
_element = get_embedding_element_for_checkpointing(element, V.value_shape)
if _element != element:
path = self._path_to_function_embedded(tmesh.name, mesh.name, V_name, f.name())
self.require_group(path)
Expand Down Expand Up @@ -1337,7 +1337,7 @@ def load_function(self, mesh, name, idx=None):
_name = self.get_attr(path, PREFIX_EMBEDDED + "_function")
_f = self.load_function(mesh, _name, idx=idx)
element = V.ufl_element()
_element = get_embedding_element_for_checkpointing(element)
_element = get_embedding_element_for_checkpointing(element, V.value_shape)
method = get_embedding_method_for_checkpointing(element)
assert _element == _f.function_space().ufl_element()
f = Function(V, name=name)
Expand Down Expand Up @@ -1436,8 +1436,7 @@ def _get_shared_data_key_for_checkpointing(self, mesh, ufl_element):
shape = ufl_element.reference_value_shape
block_size = np.prod(shape)
elif isinstance(ufl_element, finat.ufl.VectorElement):
shape = ufl_element.value_shape[:1]
block_size = np.prod(shape)
block_size = ufl_element.reference_value_shape[0]
else:
block_size = 1
return (nodes_per_entity, real_tensorproduct, block_size)
Expand Down
2 changes: 1 addition & 1 deletion firedrake/constant.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def __new__(cls, value, domain=None, name=None, count=None):

if not isinstance(domain, ufl.AbstractDomain):
cell = ufl.as_cell(domain)
coordinate_element = finat.ufl.VectorElement("Lagrange", cell, 1, gdim=cell.geometric_dimension)
coordinate_element = finat.ufl.VectorElement("Lagrange", cell, 1, dim=cell.topological_dimension())
domain = ufl.Mesh(coordinate_element)

cell = domain.ufl_cell()
Expand Down
8 changes: 4 additions & 4 deletions firedrake/embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import ufl


def get_embedding_dg_element(element, broken_cg=False):
def get_embedding_dg_element(element, value_shape, broken_cg=False):
cell = element.cell
family = lambda c: "DG" if c.is_simplex() else "DQ"
if isinstance(cell, ufl.TensorProductCell):
Expand All @@ -19,7 +19,7 @@ def get_embedding_dg_element(element, broken_cg=False):
scalar_element = finat.ufl.FiniteElement(family(cell), cell=cell, degree=degree)
if broken_cg:
scalar_element = finat.ufl.BrokenElement(scalar_element.reconstruct(family="Lagrange"))
shape = element.value_shape
shape = value_shape
if len(shape) == 0:
DG = scalar_element
elif len(shape) == 1:
Expand All @@ -37,12 +37,12 @@ def get_embedding_dg_element(element, broken_cg=False):
native_elements_for_checkpointing = {"Lagrange", "Discontinuous Lagrange", "Q", "DQ", "Real"}


def get_embedding_element_for_checkpointing(element):
def get_embedding_element_for_checkpointing(element, value_shape):
"""Convert the given UFL element to an element that :class:`~.CheckpointFile` can handle."""
if element.family() in native_elements_for_checkpointing:
return element
else:
return get_embedding_dg_element(element)
return get_embedding_dg_element(element, value_shape)


def get_embedding_method_for_checkpointing(element):
Expand Down
2 changes: 1 addition & 1 deletion firedrake/external_operators/point_expr_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def __init__(self, *operands, function_space, derivatives=None, argument_slots=(
if not isinstance(operator_data["func"], types.FunctionType):
raise TypeError("Expecting a FunctionType pointwise expression")
expr_shape = operator_data["func"](*operands).ufl_shape
if expr_shape != function_space.ufl_element().value_shape:
if expr_shape != function_space.value_shape:
raise ValueError("The dimension does not match with the dimension of the function space %s" % function_space)

@property
Expand Down
3 changes: 1 addition & 2 deletions firedrake/formmanipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,8 +124,7 @@ def argument(self, o):
args += [a_[j] for j in numpy.ndindex(a_.ufl_shape)]
else:
args += [Zero()
for j in numpy.ndindex(
V_is[i].ufl_element().value_shape)]
for j in numpy.ndindex(V_is[i].value_shape)]
return self._arg_cache.setdefault(o, as_vector(args))


Expand Down
2 changes: 1 addition & 1 deletion firedrake/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -633,7 +633,7 @@ def at(self, arg, *args, **kwargs):
raise NotImplementedError("Point evaluation not implemented for variable layers")

# Validate geometric dimension
gdim = mesh.ufl_cell().geometric_dimension()
gdim = mesh.geometric_dimension()
if arg.shape[-1] == gdim:
pass
elif len(arg.shape) == 1 and gdim == 1:
Expand Down
4 changes: 2 additions & 2 deletions firedrake/functionspace.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ def VectorFunctionSpace(mesh, family, degree=None, dim=None,
"""
sub_element = make_scalar_element(mesh, family, degree, vfamily, vdegree, variant)
if dim is None:
dim = mesh.ufl_cell().geometric_dimension()
dim = mesh.geometric_dimension()
if not isinstance(dim, numbers.Integral) and dim > 0:
raise ValueError(f"Can't make VectorFunctionSpace with dim={dim}")
element = finat.ufl.VectorElement(sub_element, dim=dim)
Expand Down Expand Up @@ -237,7 +237,7 @@ def TensorFunctionSpace(mesh, family, degree=None, shape=None,
"""
sub_element = make_scalar_element(mesh, family, degree, vfamily, vdegree, variant)
shape = shape or (mesh.ufl_cell().geometric_dimension(),) * 2
shape = shape or (mesh.geometric_dimension(),) * 2
element = finat.ufl.TensorElement(sub_element, shape=shape, symmetry=symmetry)
return FunctionSpace(mesh, element, name=name)

Expand Down
32 changes: 16 additions & 16 deletions firedrake/functionspaceimpl.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,16 +176,13 @@ def split(self):
def _components(self):
if len(self) == 1:
return tuple(type(self).create(self.topological.sub(i), self.mesh())
for i in range(self.value_size))
for i in range(self.block_size))
else:
return self.subfunctions

@PETSc.Log.EventDecorator()
def sub(self, i):
if len(self) == 1:
bound = self.value_size
else:
bound = len(self)
bound = len(self._components)
if i < 0 or i >= bound:
raise IndexError("Invalid component %d, not in [0, %d)" % (i, bound))
return self._components[i]
Expand Down Expand Up @@ -489,14 +486,17 @@ def __init__(self, mesh, element, name=None):
shape_element = element
if isinstance(element, finat.ufl.WithMapping):
shape_element = element.wrapee
sub = shape_element.sub_elements[0].value_shape
sub = shape_element.sub_elements[0].reference_value_shape
self.shape = rvs[:len(rvs) - len(sub)]
else:
self.shape = ()
self._label = ""
self._ufl_function_space = ufl.FunctionSpace(mesh.ufl_mesh(), element, label=self._label)
self._mesh = mesh

self.value_size = self._ufl_function_space.value_size
r"""The number of scalar components of this :class:`FunctionSpace`."""

self.rank = len(self.shape)
r"""The rank of this :class:`FunctionSpace`. Spaces where the
element is scalar-valued (or intrinsically vector-valued) have
Expand All @@ -505,7 +505,7 @@ def __init__(self, mesh, element, name=None):
the number of components of their
:attr:`finat.ufl.finiteelementbase.FiniteElementBase.value_shape`."""

self.value_size = int(numpy.prod(self.shape, dtype=int))
self.block_size = int(numpy.prod(self.shape, dtype=int))
r"""The total number of degrees of freedom at each function
space node."""
self.name = name
Expand Down Expand Up @@ -654,7 +654,7 @@ def __getitem__(self, i):

@utils.cached_property
def _components(self):
return tuple(ComponentFunctionSpace(self, i) for i in range(self.value_size))
return tuple(ComponentFunctionSpace(self, i) for i in range(self.block_size))

def sub(self, i):
r"""Return a view into the ith component."""
Expand Down Expand Up @@ -684,7 +684,7 @@ def node_count(self):
def dof_count(self):
r"""The number of degrees of freedom (includes halo dofs) of this
function space on this process. Cf. :attr:`FunctionSpace.node_count` ."""
return self.node_count*self.value_size
return self.node_count*self.block_size

def dim(self):
r"""The global number of degrees of freedom for this function space.
Expand Down Expand Up @@ -821,7 +821,7 @@ def local_to_global_map(self, bcs, lgmap=None):
else:
indices = lgmap.block_indices.copy()
bsize = lgmap.getBlockSize()
assert bsize == self.value_size
assert bsize == self.block_size
else:
# MatBlock case, LGMap is already unrolled.
indices = lgmap.block_indices.copy()
Expand All @@ -830,11 +830,11 @@ def local_to_global_map(self, bcs, lgmap=None):
nodes = []
for bc in bcs:
if bc.function_space().component is not None:
nodes.append(bc.nodes * self.value_size
nodes.append(bc.nodes * self.block_size
+ bc.function_space().component)
elif unblocked:
tmp = bc.nodes * self.value_size
for i in range(self.value_size):
tmp = bc.nodes * self.block_size
for i in range(self.block_size):
nodes.append(tmp + i)
else:
nodes.append(bc.nodes)
Expand Down Expand Up @@ -1300,9 +1300,9 @@ def ComponentFunctionSpace(parent, component):
"""
element = parent.ufl_element()
assert type(element) in frozenset([finat.ufl.VectorElement, finat.ufl.TensorElement])
if not (0 <= component < parent.value_size):
if not (0 <= component < parent.block_size):
raise IndexError("Invalid component %d. not in [0, %d)" %
(component, parent.value_size))
(component, parent.block_size))
new = ProxyFunctionSpace(parent.mesh(), element.sub_elements[0], name=parent.name)
new.identifier = "component"
new.component = component
Expand Down Expand Up @@ -1346,7 +1346,7 @@ def make_dof_dset(self):
def make_dat(self, val=None, valuetype=None, name=None):
r"""Return a newly allocated :class:`pyop2.types.glob.Global` representing the
data for a :class:`.Function` on this space."""
return op2.Global(self.value_size, val, valuetype, name, self._comm)
return op2.Global(self.block_size, val, valuetype, name, self._comm)

def entity_node_map(self, source_mesh, source_integral_type, source_subdomain_id, source_all_integer_subdomain_ids):
return None
Expand Down
20 changes: 9 additions & 11 deletions firedrake/interpolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -547,7 +547,7 @@ def __init__(
# VectorFunctionSpace equivalent is built from the scalar
# sub-element.
ufl_scalar_element = ufl_scalar_element.sub_elements[0]
if ufl_scalar_element.value_shape != ():
if ufl_scalar_element.reference_value_shape != ():
raise NotImplementedError(
"Can't yet cross-mesh interpolate onto function spaces made from VectorElements or TensorElements made from sub elements with value shape other than ()."
)
Expand Down Expand Up @@ -614,7 +614,7 @@ def __init__(
# I first point evaluate my expression at these locations, giving a
# P0DG function on the VOM. As described in the manual, this is an
# interpolation operation.
shape = V_dest.ufl_element().value_shape
shape = V_dest.ufl_function_space().value_shape
if len(shape) == 0:
fs_type = firedrake.FunctionSpace
elif len(shape) == 1:
Expand Down Expand Up @@ -988,18 +988,16 @@ def callable():
else:
# Make sure we have an expression of the right length i.e. a value for
# each component in the value shape of each function space
dims = [numpy.prod(fs.ufl_element().value_shape, dtype=int)
for fs in V]
loops = []
if numpy.prod(expr.ufl_shape, dtype=int) != sum(dims):
if numpy.prod(expr.ufl_shape, dtype=int) != V.value_size:
raise RuntimeError('Expression of length %d required, got length %d'
% (sum(dims), numpy.prod(expr.ufl_shape, dtype=int)))
% (V.value_size, numpy.prod(expr.ufl_shape, dtype=int)))
if len(V) > 1:
raise NotImplementedError(
"UFL expressions for mixed functions are not yet supported.")
loops.extend(_interpolator(V, tensor, expr, subset, arguments, access, bcs=bcs))
if bcs and len(arguments) == 0:
loops.extend([partial(bc.apply, f) for bc in bcs])
loops.extend(partial(bc.apply, f) for bc in bcs)

def callable(loops, f):
for l in loops:
Expand All @@ -1024,13 +1022,13 @@ def _interpolator(V, tensor, expr, subset, arguments, access, bcs=None):
if access is op2.READ:
raise ValueError("Can't have READ access for output function")

if len(expr.ufl_shape) != len(V.ufl_element().value_shape):
if len(expr.ufl_shape) != len(V.value_shape):
raise RuntimeError('Rank mismatch: Expression rank %d, FunctionSpace rank %d'
% (len(expr.ufl_shape), len(V.ufl_element().value_shape)))
% (len(expr.ufl_shape), len(V.value_shape)))

if expr.ufl_shape != V.ufl_element().value_shape:
if expr.ufl_shape != V.value_shape:
raise RuntimeError('Shape mismatch: Expression shape %r, FunctionSpace shape %r'
% (expr.ufl_shape, V.ufl_element().value_shape))
% (expr.ufl_shape, V.value_shape))

# NOTE: The par_loop is always over the target mesh cells.
target_mesh = as_domain(V)
Expand Down
Loading

0 comments on commit cf46e77

Please sign in to comment.