Skip to content

Commit

Permalink
Additional tidying, difficult numpy indexing rules
Browse files Browse the repository at this point in the history
  • Loading branch information
JDBetteridge committed Oct 1, 2024
1 parent 814ff66 commit 4bb63af
Showing 1 changed file with 66 additions and 62 deletions.
128 changes: 66 additions & 62 deletions ngsPETSc/utils/firedrake/meshes.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,12 @@
import numpy as np
from petsc4py import PETSc

try:
from scipy.spatial.distance import cdist
HAVE_SCIPY = True
except ImportError:
HAVE_SCIPY = False

import netgen
import netgen.meshing as ngm
from netgen.meshing import MeshingParameters
Expand Down Expand Up @@ -70,6 +76,45 @@ def refineMarkedElements(self, mark):
else:
raise NotImplementedError("No implementation for dimension other than 2 and 3.")


def _slow_cdist(XA, XB):
dist = np.zeros([len(XA), len(XB)])

Check failure on line 81 in ngsPETSc/utils/firedrake/meshes.py

View workflow job for this annotation

GitHub Actions / lint

W0311

ngsPETSc/utils/firedrake/meshes.py:81:0: W0311 Bad indentation. Found 8 spaces, expected 4
for ii, a in enumerate(XA):

Check failure on line 82 in ngsPETSc/utils/firedrake/meshes.py

View workflow job for this annotation

GitHub Actions / lint

W0311

ngsPETSc/utils/firedrake/meshes.py:82:0: W0311 Bad indentation. Found 8 spaces, expected 4
for jj, b in enumerate(XB):

Check failure on line 83 in ngsPETSc/utils/firedrake/meshes.py

View workflow job for this annotation

GitHub Actions / lint

W0311

ngsPETSc/utils/firedrake/meshes.py:83:0: W0311 Bad indentation. Found 12 spaces, expected 8
dist[ii, jj] = np.linalg.norm(b - a)

Check failure on line 84 in ngsPETSc/utils/firedrake/meshes.py

View workflow job for this annotation

GitHub Actions / lint

W0311

ngsPETSc/utils/firedrake/meshes.py:84:0: W0311 Bad indentation. Found 16 spaces, expected 12
return dist

Check failure on line 85 in ngsPETSc/utils/firedrake/meshes.py

View workflow job for this annotation

GitHub Actions / lint

W0311

ngsPETSc/utils/firedrake/meshes.py:85:0: W0311 Bad indentation. Found 8 spaces, expected 4


if not HAVE_SCIPY:
cdist = _slow_cdist


def find_permutation(points_a, points_b, tol=1e-5):
""" Find all permutations between a list of two sets of points.
Given two numpy arrays of shape (ncells, npoints, dim) containing
floating point coordinates for each cell, determine each index
permutation that takes `points_a` to `points_b`. Ie:
```
permutation = find_permutation(points_a, points_b)
assert np.allclose(points_a[permutation], points_b, rtol=0, atol=tol)
```
"""
if points_a.shape != points_b.shape:
raise ValueError("`points_a` and `points_b` must have the same shape.")

p = [np.where(cdist(a, b).T < tol)[1] for a, b in zip(points_a, points_b)]
try:
permutation = np.array(p, ndmin=2)
except ValueError:
raise ValueError("It was not possible to find a permutation for every cell within the provided tolerance")

Check failure on line 110 in ngsPETSc/utils/firedrake/meshes.py

View workflow job for this annotation

GitHub Actions / lint

C0301

ngsPETSc/utils/firedrake/meshes.py:110:0: C0301 Line too long (114/100)

Check failure on line 110 in ngsPETSc/utils/firedrake/meshes.py

View workflow job for this annotation

GitHub Actions / lint

W0707

ngsPETSc/utils/firedrake/meshes.py:110:8: W0707 Consider explicitly re-raising using 'except ValueError as exc' and 'raise ValueError('It was not possible to find a permutation for every cell within the provided tolerance') from exc'

if permutation.shape != points_a.shape[0:2]:
raise ValueError("It was not possible to find a permutation for every cell within the provided tolerance")

Check failure on line 113 in ngsPETSc/utils/firedrake/meshes.py

View workflow job for this annotation

GitHub Actions / lint

C0301

ngsPETSc/utils/firedrake/meshes.py:113:0: C0301 Line too long (114/100)

return permutation


def curveField(self, order, tol=1e-8):

Check failure on line 118 in ngsPETSc/utils/firedrake/meshes.py

View workflow job for this annotation

GitHub Actions / lint

W0613

ngsPETSc/utils/firedrake/meshes.py:118:28: W0613 Unused argument 'tol'
'''
This method returns a curved mesh as a Firedrake function.
Expand All @@ -78,21 +123,21 @@ def curveField(self, order, tol=1e-8):
'''
#Checking if the mesh is a surface mesh or two dimensional mesh
surf = len(self.netgen_mesh.Elements3D()) == 0 # REMOVE
if len(self.netgen_mesh.Elements3D()) == 0:
ng_element = self.netgen_mesh.Elements2D
else:
ng_element = self.netgen_mesh.Elements3D
ng_dimension = len(ng_element())
geom_dim = self.geometric_dimension()

#Constructing mesh as a function
low_order_element = self.coordinates.function_space().ufl_element().sub_elements[0]
ufl_element = low_order_element.reconstruct(degree=order)
firedrake_space = fd.VectorFunctionSpace(self, fd.BrokenElement(ufl_element))
newFunctionCoordinates = fd.assemble(interpolate(self.coordinates, firedrake_space))
new_coordinates = fd.assemble(interpolate(self.coordinates, firedrake_space))

#Computing reference points using fiat
fiat_element = newFunctionCoordinates.function_space().finat_element.fiat_equivalent
fiat_element = new_coordinates.function_space().finat_element.fiat_equivalent
entity_ids = fiat_element.entity_dofs()
nodes = fiat_element.dual_basis()
ref = []
Expand All @@ -105,9 +150,8 @@ def curveField(self, order, tol=1e-8):
reference_space_points = np.array(ref)

#Mapping to the physical domain
els = {True: self.netgen_mesh.Elements2D, False: self.netgen_mesh.Elements3D} # REMOVE
physical_space_points = np.ndarray((ng_dimension, reference_space_points.shape[0], self.geometric_dimension()))
curved_space_points = np.ndarray((ng_dimension, reference_space_points.shape[0], self.geometric_dimension()))
physical_space_points = np.ndarray((ng_dimension, reference_space_points.shape[0], geom_dim))
curved_space_points = np.ndarray((ng_dimension, reference_space_points.shape[0], geom_dim))

if self.comm.rank == 0:
#Curving the mesh on rank 0
Expand All @@ -122,7 +166,7 @@ def curveField(self, order, tol=1e-8):
physical_space_points = self.comm.bcast(physical_space_points, root=0)
curved_space_points = self.comm.bcast(curved_space_points, root=0)
curved = self.comm.bcast(curved, root=0)
cell_node_map = newFunctionCoordinates.cell_node_map()
cell_node_map = new_coordinates.cell_node_map()

# Select only the points in curved cells
physical_space_points = physical_space_points[curved]
Expand All @@ -137,68 +181,28 @@ def curveField(self, order, tol=1e-8):
barycentres = barycentres[owned]
ng_index = [idx for idx, o in zip(ng_index, owned) if o]

# Do we want the min or max of these???
norms = np.linalg.norm(physical_space_points - newFunctionCoordinates.dat.data[cell_node_map.values[ng_index]], axis=2)
breakpoint()

Check failure on line 184 in ngsPETSc/utils/firedrake/meshes.py

View workflow job for this annotation

GitHub Actions / lint

W1515

ngsPETSc/utils/firedrake/meshes.py:184:4: W1515 Leaving functions creating breakpoints in production code is not recommended

# PyOP2 index
pyop2_index = []
for ngidx in ng_index:
pyop2_index.extend(cell_node_map.values[ngidx])
np.array(pyop2_index)

for dim in range(self.geometric_dimension()):
newFunctionCoordinates.sub(dim).dat.data[pyop2_index] = curved_space_points[:,:,dim].flatten()

# ~ breakpoint()

# ~ for i in range(physical_space_points.shape[0]):
# ~ #Inefficent code but runs only on curved elements
# ~ if curved[i]:
# ~ pts = physical_space_points[i][0:reference_space_points.shape[0]]
# ~ bary = sum([np.array(pts[i]) for i in range(len(pts))])/len(pts)
# ~ Idx = self.locate_cell(bary)
# ~ isInMesh = (0<=Idx<len(cell_node_map.values)) if Idx is not None else False
# ~ #Check if element is shared across processes
# ~ shared = self.comm.gather(isInMesh, root=0)
# ~ shared = self.comm.bcast(shared, root=0) # `shared` isn't used in anything other than warning messages!?
# ~ # self.comm.Allgather(isInMesh, shared)
# ~ #Bend if not shared
# ~ if np.sum(shared) == 1:
# ~ if isInMesh:
# ~ # This seems dodgy, I think we're just trying to ensure the points in the netgen cell are close to the coordinates in the dat
# ~ p = [np.argmin(np.sum((pts - pt)**2, axis=1))
# ~ for pt in newFunctionCoordinates.dat.data[cell_node_map.values[Idx]][0:reference_space_points.shape[0]]]
# ~ curved_space_points[i] = curved_space_points[i][p]
# ~ res = np.linalg.norm(pts[p] - newFunctionCoordinates.dat.data[cell_node_map.values[Idx]][0:reference_space_points.shape[0]])
# ~ if res > tol:
# ~ fd.logging.warning(f"[{self.comm.rank}] Not able to curve Firedrake element {Idx} ({i}) -- residual: {res}")
# ~ else:
# ~ for j, datIdx in enumerate(cell_node_map.values[Idx][0:reference_space_points.shape[0]]):
# ~ for dim in range(self.geometric_dimension()):
# ~ coo = curved_space_points[i][j][dim]
# ~ newFunctionCoordinates.sub(dim).dat.data[datIdx] = coo
# ~ else:
# ~ if isInMesh:
# ~ p = [np.argmin(np.sum((pts - pt)**2, axis=1))
# ~ for pt in newFunctionCoordinates.dat.data[cell_node_map.values[Idx]][0:reference_space_points.shape[0]]]
# ~ curved_space_points[i] = curved_space_points[i][p]
# ~ res = np.linalg.norm(pts[p]-newFunctionCoordinates.dat.data[cell_node_map.values[Idx]][0:reference_space_points.shape[0]])
# ~ else:
# ~ res = np.inf
# ~ res = self.comm.gather(res, root=0)
# ~ res = self.comm.bcast(res, root=0)
# ~ rank = np.argmin(res)
# ~ if self.comm.rank == rank:
# ~ if res[rank] > tol:
# ~ fd.logging.warning("[{}, {}] Not able to curve Firedrake element {} \
# ~ ({}) -- residual: {}".format(self.comm.rank, shared, Idx,i, res))
# ~ else:
# ~ for j, datIdx in enumerate(cell_node_map.values[Idx][0:reference_space_points.shape[0]]):
# ~ for dim in range(self.geometric_dimension()):
# ~ coo = curved_space_points[i][j][dim]
# ~ newFunctionCoordinates.sub(dim).dat.data[datIdx] = coo
# ~ breakpoint()
return newFunctionCoordinates
# Find the correct coordinate permutation for each cell
permutation = find_permutation(
physical_space_points,
new_coordinates.dat.data[pyop2_index].reshape(physical_space_points.shape)
)

# Apply the permutation to each cell in turn
for ii, p in enumerate(curved_space_points):
curved_space_points[ii] = p[permutation[ii]]

# Assign the curved coordinates to the dat
new_coordinates.dat.data[pyop2_index] = curved_space_points.reshape(-1, geom_dim)

return new_coordinates

def splitToQuads(plex, dim, comm):
'''
Expand Down

0 comments on commit 4bb63af

Please sign in to comment.