Skip to content

Commit

Permalink
Merge pull request #2437 from firedrakeproject/periodic_vertexonlymesh
Browse files Browse the repository at this point in the history
Transfer responsibility for computing cell inclusion for vertices from Plex to Firedrake. This enables periodic meshes to be supported in VertexOnlyMesh.
  • Loading branch information
dham authored May 20, 2022
1 parent 4e21d1a commit e6e778a
Show file tree
Hide file tree
Showing 6 changed files with 183 additions and 221 deletions.
135 changes: 0 additions & 135 deletions firedrake/cython/dmcommon.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -2845,141 +2845,6 @@ def clear_adjacency_callback(PETSc.DM dm not None):
CHKERR(DMPlexSetAdjacencyUser(dm.dm, NULL, NULL))


@cython.boundscheck(False)
@cython.wraparound(False)
def remove_ghosts_pic(PETSc.DM swarm, PETSc.DM plex):
"""Remove DMSwarm PICs which are in ghost cells of a distributed
DMPlex.
:arg swarm: The DMSWARM which has been associated with the input
DMPlex `plex` using PETSc `DMSwarmSetCellDM`.
:arg plex: The DMPlex which is associated with the input DMSWARM
`swarm`
"""
cdef:
PetscInt cStart, cEnd, ncells, i, npics
PETSc.SF sf
PetscInt nroots, nleaves
const PetscInt *ilocal = NULL
const PetscSFNode *iremote = NULL
np.ndarray[PetscInt, ndim=1, mode="c"] pic_cell_indices
np.ndarray[PetscInt, ndim=1, mode="c"] ghost_cell_indices

if type(plex) is not PETSc.DMPlex:
raise ValueError("plex must be a DMPlex")

if type(swarm) is not PETSc.DMSwarm:
raise ValueError("swarm must be a DMSwarm")

if plex.handle != swarm.getCellDM().handle:
raise ValueError("plex is not the swarm's CellDM")

if plex.comm.size > 1:
get_height_stratum(plex.dm, 0, &cStart, &cEnd)
ncells = cEnd - cStart

# Get full list of cell indices for particles
pic_cell_indices = np.copy(swarm.getField("DMSwarm_cellid"))
swarm.restoreField("DMSwarm_cellid")
npics = len(pic_cell_indices)

# Initialise with zeros since these can't be valid ranks or cell ids
ghost_cell_indices = np.full(ncells, -1, dtype=IntType)

# Search for ghost cell indices (spooky!)
sf = plex.getPointSF()
CHKERR(PetscSFGetGraph(sf.sf, &nroots, &nleaves, &ilocal, &iremote))
for i in range(nleaves):
if cStart <= ilocal[i] < cEnd:
# NOTE need to check this is correct index. Can I check the labels some how?
ghost_cell_indices[ilocal[i] - cStart] = ilocal[i]

# trim -1's and make into set to reduce searching needed
ghost_cell_indices_set = set(ghost_cell_indices[ghost_cell_indices != -1])

# remove swarm pic parent cell indices which match ghost cell indices
for i in range(npics-1, -1, -1):
if pic_cell_indices[i] in ghost_cell_indices_set:
# removePointAtIndex shift cell numbers down by 1
swarm.removePointAtIndex(i)


@cython.boundscheck(False)
@cython.wraparound(False)
def label_pic_parent_cell_info(PETSc.DM swarm, parentmesh):
"""
For each PIC in the input swarm, label its `parentcellnum` field with
the relevant cell number from the `parentmesh` in which is it emersed
and the `refcoord` field with the relevant cell reference coordinate.
This information is given by the
`parentmesh.locate_cell_and_reference_coordinate` method.
For a swarm with N PICs emersed in `parentmesh`
the `parentcellnum` field is N long and
the `refcoord` field is N*parentmesh.topological_dimension() long.
:arg swarm: The DMSWARM which contains the PICs immersed in
`parentmesh`
:arg parentmesh: The mesh within with the `swarm` PICs are immersed.
..note:: All PICs must be within the parentmesh or this will try to
assign `None` (returned by
`parentmesh.locate_cell_and_reference_coordinate`) to the
`parentcellnum` or `refcoord` fields.
"""
cdef:
PetscInt num_vertices, i, gdim, tdim
PetscInt parent_cell_num
np.ndarray[PetscReal, ndim=2, mode="c"] swarm_coords
np.ndarray[PetscInt, ndim=1, mode="c"] parent_cell_nums
np.ndarray[PetscReal, ndim=2, mode="c"] reference_coords
np.ndarray[PetscReal, ndim=1, mode="c"] reference_coord

gdim = parentmesh.geometric_dimension()
tdim = parentmesh.topological_dimension()

num_vertices = swarm.getLocalSize()

# Check size of biggest num_vertices so
# locate_cell can be called on every processor
comm = swarm.comm.tompi4py()
max_num_vertices = comm.allreduce(num_vertices, op=MPI.MAX)

# Create an out of mesh point to use in locate_cell when needed
out_of_mesh_point = np.full((1, gdim), np.inf)

# get fields - NOTE this isn't copied so make sure
# swarm.restoreField is called for each field too!
swarm_coords = swarm.getField("DMSwarmPIC_coor").reshape((num_vertices, gdim))
parent_cell_nums = swarm.getField("parentcellnum")
reference_coords = swarm.getField("refcoord").reshape((num_vertices, tdim))

# find parent cell numbers
# TODO We should be able to do this for all the points in in one call to
# the parent mesh's _c_locator.
# SUGGESTED API 1:
# parent_cell_nums, reference_coords = parentmesh.locate_cell_and_reference_coordinates(swarm_coords)
# with second call for collectivity.
# SUGGESTED API 2:
# parent_cell_nums, reference_coords = parentmesh.locate_cell_and_reference_coordinates(swarm_coords, local_num_vertices)
# with behaviour changing inside locate_cell_and_reference_coordinates to
# ensure collectivity.
for i in range(max_num_vertices):
if i < num_vertices:
parent_cell_num, reference_coord = parentmesh.locate_cell_and_reference_coordinate(swarm_coords[i])
parent_cell_nums[i] = parent_cell_num
reference_coords[i] = reference_coord
else:
parentmesh.locate_cell(out_of_mesh_point) # should return None

# have to restore fields once accessed to allow access again
swarm.restoreField("refcoord")
swarm.restoreField("parentcellnum")
swarm.restoreField("DMSwarmPIC_coor")


@cython.boundscheck(False)
@cython.wraparound(False)
def fill_reference_coordinates_function(reference_coordinates_f):
Expand Down
Loading

0 comments on commit e6e778a

Please sign in to comment.