Skip to content

Commit

Permalink
rebase fix
Browse files Browse the repository at this point in the history
  • Loading branch information
ksagiyam committed Nov 13, 2024
1 parent d77495d commit b0c822e
Show file tree
Hide file tree
Showing 5 changed files with 57 additions and 10 deletions.
54 changes: 50 additions & 4 deletions firedrake/assemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -1600,6 +1600,8 @@ def __init__(self, form, local_knl, subdomain_id, all_integer_subdomain_ids, dia
self._constants = _FormHandler.iter_constants(form, local_knl.kinfo)
self._active_exterior_facets = _FormHandler.iter_active_exterior_facets(form, local_knl.kinfo)
self._active_interior_facets = _FormHandler.iter_active_interior_facets(form, local_knl.kinfo)
self._active_exterior_facet_orientations = _FormHandler.iter_active_exterior_facet_orientations(form, local_knl.kinfo)
self._active_interior_facet_orientations = _FormHandler.iter_active_interior_facet_orientations(form, local_knl.kinfo)

self._map_arg_cache = {}
# Cache for holding :class:`op2.MapKernelArg` instances.
Expand All @@ -1620,6 +1622,8 @@ def build(self):
assert_empty(self._constants)
assert_empty(self._active_exterior_facets)
assert_empty(self._active_interior_facets)
assert_empty(self._active_exterior_facet_orientations)
assert_empty(self._active_interior_facet_orientations)

iteration_regions = {"exterior_facet_top": op2.ON_TOP,
"exterior_facet_bottom": op2.ON_BOTTOM,
Expand Down Expand Up @@ -1817,12 +1821,24 @@ def _as_global_kernel_arg_interior_facet(_, self):

@_as_global_kernel_arg.register(kernel_args.ExteriorFacetOrientationKernelArg)
def _as_global_kernel_arg_exterior_facet_orientation(_, self):
return op2.DatKernelArg((1,))
mesh = next(self._active_exterior_facet_orientations)
if mesh is self._mesh:
return op2.DatKernelArg((1,))
else:
m, integral_type = mesh.topology.trans_mesh_entity_map(self._mesh.topology, self._integral_type, self._subdomain_id, self._all_integer_subdomain_ids)
assert integral_type == "exterior_facet"
return op2.DatKernelArg((1,), m._global_kernel_arg)


@_as_global_kernel_arg.register(kernel_args.InteriorFacetOrientationKernelArg)
def _as_global_kernel_arg_interior_facet_orientation(_, self):
return op2.DatKernelArg((2,))
mesh = next(self._active_interior_facet_orientations)
if mesh is self._mesh:
return op2.DatKernelArg((2,))
else:
m, integral_type = mesh.topology.trans_mesh_entity_map(self._mesh.topology, self._integral_type, self._subdomain_id, self._all_integer_subdomain_ids)
assert integral_type == "interior_facet"
return op2.DatKernelArg((2,), m._global_kernel_arg)


@_as_global_kernel_arg.register(CellFacetKernelArg)
Expand Down Expand Up @@ -1874,6 +1890,8 @@ def __init__(self, form, bcs, local_knl, subdomain_id,
self._constants = _FormHandler.iter_constants(form, local_knl.kinfo)
self._active_exterior_facets = _FormHandler.iter_active_exterior_facets(form, local_knl.kinfo)
self._active_interior_facets = _FormHandler.iter_active_interior_facets(form, local_knl.kinfo)
self._active_exterior_facet_orientations = _FormHandler.iter_active_exterior_facet_orientations(form, local_knl.kinfo)
self._active_interior_facet_orientations = _FormHandler.iter_active_interior_facet_orientations(form, local_knl.kinfo)

def build(self, tensor: op2.Global | op2.Dat | op2.Mat) -> op2.Parloop:
"""Construct the parloop.
Expand Down Expand Up @@ -2139,12 +2157,24 @@ def _as_parloop_arg_interior_facet(_, self):

@_as_parloop_arg.register(kernel_args.ExteriorFacetOrientationKernelArg)
def _as_parloop_arg_exterior_facet_orientation(_, self):
return op2.DatParloopArg(self._mesh.exterior_facets.local_facet_orientation_dat)
mesh = next(self._active_exterior_facet_orientations)
if mesh is self._mesh:
m = None
else:
m, integral_type = mesh.topology.trans_mesh_entity_map(self._mesh.topology, self._integral_type, self._subdomain_id, self._all_integer_subdomain_ids)
assert integral_type == "exterior_facet"
return op2.DatParloopArg(mesh.exterior_facets.local_facet_orientation_dat, m)


@_as_parloop_arg.register(kernel_args.InteriorFacetOrientationKernelArg)
def _as_parloop_arg_interior_facet_orientation(_, self):
return op2.DatParloopArg(self._mesh.interior_facets.local_facet_orientation_dat)
mesh = next(self._active_interior_facet_orientations)
if mesh is self._mesh:
m = None
else:
m, integral_type = mesh.topology.trans_mesh_entity_map(self._mesh.topology, self._integral_type, self._subdomain_id, self._all_integer_subdomain_ids)
assert integral_type == "interior_facet"
return op2.DatParloopArg(mesh.interior_facets.local_facet_orientation_dat, m)


@_as_parloop_arg.register(CellFacetKernelArg)
Expand Down Expand Up @@ -2222,6 +2252,22 @@ def iter_active_interior_facets(form, kinfo):
mesh = all_meshes[i]
yield mesh

@staticmethod
def iter_active_exterior_facet_orientations(form, kinfo):
"""Yield the form exterior facet orientations referenced in ``kinfo``."""
all_meshes = extract_domains(form)
for i in kinfo.active_domain_numbers.exterior_facet_orientations:
mesh = all_meshes[i]
yield mesh

@staticmethod
def iter_active_interior_facet_orientations(form, kinfo):
"""Yield the form interior facet orientations referenced in ``kinfo``."""
all_meshes = extract_domains(form)
for i in kinfo.active_domain_numbers.interior_facet_orientations:
mesh = all_meshes[i]
yield mesh

@staticmethod
def index_function_spaces(form, indices):
"""Return the function spaces of the form's arguments, indexed
Expand Down
4 changes: 1 addition & 3 deletions firedrake/mg/kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ def compile_element(expression, dual_space=None, parameters=None,

config = dict(interface=builder,
ufl_cell=cell,
integral_type="cell",
domain_integral_type_map={domain: "cell"},
point_indices=(),
point_expr=point,
argument_multiindices=argument_multiindices,
Expand Down Expand Up @@ -540,7 +540,6 @@ def dg_injection_kernel(Vf, Vc, ncell):
integration_dim, entity_ids = lower_integral_type(Vfe.cell, "cell")
macro_cfg = dict(interface=macro_builder,
ufl_cell=Vf.ufl_cell(),
integral_type="cell",
integration_dim=integration_dim,
entity_ids=entity_ids,
index_cache=index_cache,
Expand Down Expand Up @@ -580,7 +579,6 @@ def dg_injection_kernel(Vf, Vc, ncell):

coarse_cfg = dict(interface=coarse_builder,
ufl_cell=Vc.ufl_cell(),
integral_type="cell",
integration_dim=integration_dim,
entity_ids=entity_ids,
index_cache=index_cache,
Expand Down
2 changes: 1 addition & 1 deletion firedrake/pointeval_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def compile_element(expression, coordinates, parameters=None):

config = dict(interface=builder,
ufl_cell=extract_unique_domain(coordinates).ufl_cell(),
integral_type="cell",
domain_integral_type_map={domain: "cell"},
point_indices=(),
point_expr=point,
scalar_type=utils.ScalarType)
Expand Down
2 changes: 1 addition & 1 deletion firedrake/pointquery_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ def to_reference_coords_newton_step(ufl_coordinate_element, parameters, x0_dtype
context = tsfc.fem.GemPointContext(
interface=builder,
ufl_cell=cell,
integral_type="cell",
domain_integral_type_map={domain: "cell"},
point_indices=(),
point_expr=point,
scalar_type=parameters["scalar_type"]
Expand Down
5 changes: 4 additions & 1 deletion firedrake/slate/slac/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,10 @@ def generate_loopy_kernel(slate_expr, compiler_parameters=None):
cell_orientations=(0, ) if builder.bag.needs_cell_orientations else (),
cell_sizes=(0, ) if builder.bag.needs_cell_sizes else (),
exterior_facets=(),
interior_facets=()),
interior_facets=(),
exterior_facet_orientations=(),
interior_facet_orientations=(),
),
coefficient_numbers=coefficient_numbers,
constant_numbers=constant_numbers,
needs_cell_facets=builder.bag.needs_cell_facets,
Expand Down

0 comments on commit b0c822e

Please sign in to comment.