Skip to content

Commit

Permalink
This is a mess, revert and try again
Browse files Browse the repository at this point in the history
  • Loading branch information
JDBetteridge committed Oct 3, 2024
1 parent f3ad982 commit 37e96d9
Showing 1 changed file with 57 additions and 49 deletions.
106 changes: 57 additions & 49 deletions ngsPETSc/utils/firedrake/meshes.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,10 +90,10 @@ def _slow_cdist(XA, XB):


def find_permutation(points_a, points_b, tol=1e-5):
""" Find all permutations between a list of two sets of points.
""" Find permutation between two sets of points.
Given two numpy arrays of shape (ncells, npoints, dim) containing
floating point coordinates for each cell, determine each index
Given two numpy arrays of shape (npoints, dim) containing
floating point coordinates for a cell, determine each index
permutation that takes `points_a` to `points_b`. Ie:
```
permutation = find_permutation(points_a, points_b)
Expand All @@ -103,18 +103,11 @@ def find_permutation(points_a, points_b, tol=1e-5):
if points_a.shape != points_b.shape:
raise ValueError("`points_a` and `points_b` must have the same shape.")

p = [np.where(cdist(a, b).T < tol)[1] for a, b in zip(points_a, points_b)]
try:
permutation = np.array(p, ndmin=2)
except ValueError as e:
raise ValueError(
"It was not possible to find a permutation for every cell"
" within the provided tolerance"
) from e
permutation = np.where(cdist(points_a, points_b).T < tol)[1]

if permutation.shape != points_a.shape[0:2]:
if permutation.shape[0] != points_a.shape[0]:
raise ValueError(
"It was not possible to find a permutation for every cell"
"It was not possible to find a permutation for the cell"
" within the provided tolerance"
)

Expand Down Expand Up @@ -155,64 +148,79 @@ def curveField(self, order, tol=1e-8):
ref.append(pt)
reference_space_points = np.array(ref)

# Map to the physical domain
physical_space_points = np.ndarray(
(ng_dimension, reference_space_points.shape[0], geom_dim)
)
curved_space_points = np.ndarray(
(ng_dimension, reference_space_points.shape[0], geom_dim)
)

# Curve the mesh on rank 0 only
# JBTODO: (why?)
if self.comm.rank == 0:
# Construct numpy arrays for physical domain data
physical_space_points = np.ndarray(
(ng_dimension, reference_space_points.shape[0], geom_dim)
)
curved_space_points = np.ndarray(
(ng_dimension, reference_space_points.shape[0], geom_dim)
)
self.netgen_mesh.CalcElementMapping(reference_space_points, physical_space_points)
self.netgen_mesh.Curve(order)
self.netgen_mesh.CalcElementMapping(reference_space_points, curved_space_points)
curved = ng_element().NumPy()["curved"]
# Broadcast curving data
curved = self.comm.bcast(curved, root=0)
physical_space_points = physical_space_points[curved]
curved_space_points = curved_space_points[curved]
else:
curved = np.array((ng_dimension, 1))
curved = self.comm.bcast(None, root=0)
# Construct numpy arrays as buffers to receive physical domain data
ncurved = np.sum(curved)
physical_space_points = np.ndarray(
(ncurved, reference_space_points.shape[0], geom_dim)
)
curved_space_points = np.ndarray(
(ncurved, reference_space_points.shape[0], geom_dim)
)

# Broadcast curving data
physical_space_points = self.comm.bcast(physical_space_points, root=0)
curved_space_points = self.comm.bcast(curved_space_points, root=0)
curved = self.comm.bcast(curved, root=0)
# ~ import pytest; pytest.set_trace()
# ~ breakpoint()
# Broadcast physical domain data for curved points
self.comm.Bcast(physical_space_points, root=0)
self.comm.Bcast(curved_space_points, root=0)
cell_node_map = new_coordinates.cell_node_map()

# Select only the points in curved cells
physical_space_points = physical_space_points[curved]
curved_space_points = curved_space_points[curved]
barycentres = np.average(physical_space_points, axis=1)
ng_index = [*map(self.locate_cell, barycentres)]

# Select only the indices of points owned by this rank
owned = [(0 <= ii < len(cell_node_map.values)) if ii is not None else False for ii in ng_index]

# Select only the points owned by this rank
physical_space_points = physical_space_points[owned]
curved_space_points = curved_space_points[owned]
barycentres = barycentres[owned]
ng_index = [idx for idx, o in zip(ng_index, owned) if o]
owned_physical_space_points = physical_space_points[owned]
owned_curved_space_points = curved_space_points[owned]
owned_barycentres = barycentres[owned]

Check failure on line 195 in ngsPETSc/utils/firedrake/meshes.py

View workflow job for this annotation

GitHub Actions / lint

W0612

ngsPETSc/utils/firedrake/meshes.py:195:4: W0612 Unused variable 'owned_barycentres'
owned_index = [idx for idx, o in zip(ng_index, owned) if o]

# Get the PyOP2 indices corresponding to the netgen indices
pyop2_index = []
for ngidx in ng_index:
for ngidx in owned_index:
pyop2_index.extend(cell_node_map.values[ngidx])
np.array(pyop2_index)

# Find the correct coordinate permutation for each cell
# JBTODO: This should be moved to the next loop if we have to loop
# over cells any way to actually perform the permutation
permutation = find_permutation(
physical_space_points,
new_coordinates.dat.data[pyop2_index].reshape(physical_space_points.shape),
tol=tol
)
pyop2_index = np.array(pyop2_index).reshape(-1, physical_space_points.shape[1])

# ~ breakpoint()
# Find the correct coordinate permutation for this cell and apply the
# permutation to each cell in turn. For owned cells this must not fail.
for ii, (physical, index) in enumerate(zip(owned_physical_space_points, pyop2_index)):
try:
permutation = find_permutation(
physical,
new_coordinates.dat.data[index].reshape(physical.shape),
tol=tol
)
owned_curved_space_points[ii] = physical[permutation]
except ValueError:
fd.logging.warning(
f"[{self.comm.rank}] Not able to curve Firedrake element"
f" {owned_index[ii]}"
)

# Apply the permutation to each cell in turn
for ii, p in enumerate(curved_space_points):
curved_space_points[ii] = p[permutation[ii]]

# Assign the curved coordinates to the dat
new_coordinates.dat.data[pyop2_index] = curved_space_points.reshape(-1, geom_dim)
new_coordinates.dat.data[pyop2_index.flatten()] = owned_curved_space_points.reshape(-1, geom_dim)

Check failure on line 223 in ngsPETSc/utils/firedrake/meshes.py

View workflow job for this annotation

GitHub Actions / lint

C0301

ngsPETSc/utils/firedrake/meshes.py:223:0: C0301 Line too long (101/100)

return new_coordinates

Expand Down

0 comments on commit 37e96d9

Please sign in to comment.