From 4bb63af0e97b353ceee6d1eb091c27da5f849e4a Mon Sep 17 00:00:00 2001 From: Jack Betteridge Date: Tue, 1 Oct 2024 17:45:37 +0100 Subject: [PATCH] Additional tidying, difficult numpy indexing rules --- ngsPETSc/utils/firedrake/meshes.py | 128 +++++++++++++++-------------- 1 file changed, 66 insertions(+), 62 deletions(-) diff --git a/ngsPETSc/utils/firedrake/meshes.py b/ngsPETSc/utils/firedrake/meshes.py index 9a5a605..5a6e5ff 100644 --- a/ngsPETSc/utils/firedrake/meshes.py +++ b/ngsPETSc/utils/firedrake/meshes.py @@ -12,6 +12,12 @@ import numpy as np from petsc4py import PETSc +try: + from scipy.spatial.distance import cdist + HAVE_SCIPY = True +except ImportError: + HAVE_SCIPY = False + import netgen import netgen.meshing as ngm from netgen.meshing import MeshingParameters @@ -70,6 +76,45 @@ def refineMarkedElements(self, mark): else: raise NotImplementedError("No implementation for dimension other than 2 and 3.") + +def _slow_cdist(XA, XB): + dist = np.zeros([len(XA), len(XB)]) + for ii, a in enumerate(XA): + for jj, b in enumerate(XB): + dist[ii, jj] = np.linalg.norm(b - a) + return dist + + +if not HAVE_SCIPY: + cdist = _slow_cdist + + +def find_permutation(points_a, points_b, tol=1e-5): + """ Find all permutations between a list of two sets of points. + + Given two numpy arrays of shape (ncells, npoints, dim) containing + floating point coordinates for each cell, determine each index + permutation that takes `points_a` to `points_b`. Ie: + ``` + permutation = find_permutation(points_a, points_b) + assert np.allclose(points_a[permutation], points_b, rtol=0, atol=tol) + ``` + """ + if points_a.shape != points_b.shape: + raise ValueError("`points_a` and `points_b` must have the same shape.") + + p = [np.where(cdist(a, b).T < tol)[1] for a, b in zip(points_a, points_b)] + try: + permutation = np.array(p, ndmin=2) + except ValueError: + raise ValueError("It was not possible to find a permutation for every cell within the provided tolerance") + + if permutation.shape != points_a.shape[0:2]: + raise ValueError("It was not possible to find a permutation for every cell within the provided tolerance") + + return permutation + + def curveField(self, order, tol=1e-8): ''' This method returns a curved mesh as a Firedrake function. @@ -78,21 +123,21 @@ def curveField(self, order, tol=1e-8): ''' #Checking if the mesh is a surface mesh or two dimensional mesh - surf = len(self.netgen_mesh.Elements3D()) == 0 # REMOVE if len(self.netgen_mesh.Elements3D()) == 0: ng_element = self.netgen_mesh.Elements2D else: ng_element = self.netgen_mesh.Elements3D ng_dimension = len(ng_element()) + geom_dim = self.geometric_dimension() #Constructing mesh as a function low_order_element = self.coordinates.function_space().ufl_element().sub_elements[0] ufl_element = low_order_element.reconstruct(degree=order) firedrake_space = fd.VectorFunctionSpace(self, fd.BrokenElement(ufl_element)) - newFunctionCoordinates = fd.assemble(interpolate(self.coordinates, firedrake_space)) + new_coordinates = fd.assemble(interpolate(self.coordinates, firedrake_space)) #Computing reference points using fiat - fiat_element = newFunctionCoordinates.function_space().finat_element.fiat_equivalent + fiat_element = new_coordinates.function_space().finat_element.fiat_equivalent entity_ids = fiat_element.entity_dofs() nodes = fiat_element.dual_basis() ref = [] @@ -105,9 +150,8 @@ def curveField(self, order, tol=1e-8): reference_space_points = np.array(ref) #Mapping to the physical domain - els = {True: self.netgen_mesh.Elements2D, False: self.netgen_mesh.Elements3D} # REMOVE - physical_space_points = np.ndarray((ng_dimension, reference_space_points.shape[0], self.geometric_dimension())) - curved_space_points = np.ndarray((ng_dimension, reference_space_points.shape[0], self.geometric_dimension())) + physical_space_points = np.ndarray((ng_dimension, reference_space_points.shape[0], geom_dim)) + curved_space_points = np.ndarray((ng_dimension, reference_space_points.shape[0], geom_dim)) if self.comm.rank == 0: #Curving the mesh on rank 0 @@ -122,7 +166,7 @@ def curveField(self, order, tol=1e-8): physical_space_points = self.comm.bcast(physical_space_points, root=0) curved_space_points = self.comm.bcast(curved_space_points, root=0) curved = self.comm.bcast(curved, root=0) - cell_node_map = newFunctionCoordinates.cell_node_map() + cell_node_map = new_coordinates.cell_node_map() # Select only the points in curved cells physical_space_points = physical_space_points[curved] @@ -137,8 +181,7 @@ def curveField(self, order, tol=1e-8): barycentres = barycentres[owned] ng_index = [idx for idx, o in zip(ng_index, owned) if o] - # Do we want the min or max of these??? - norms = np.linalg.norm(physical_space_points - newFunctionCoordinates.dat.data[cell_node_map.values[ng_index]], axis=2) + breakpoint() # PyOP2 index pyop2_index = [] @@ -146,59 +189,20 @@ def curveField(self, order, tol=1e-8): pyop2_index.extend(cell_node_map.values[ngidx]) np.array(pyop2_index) - for dim in range(self.geometric_dimension()): - newFunctionCoordinates.sub(dim).dat.data[pyop2_index] = curved_space_points[:,:,dim].flatten() - - # ~ breakpoint() - - # ~ for i in range(physical_space_points.shape[0]): - # ~ #Inefficent code but runs only on curved elements - # ~ if curved[i]: - # ~ pts = physical_space_points[i][0:reference_space_points.shape[0]] - # ~ bary = sum([np.array(pts[i]) for i in range(len(pts))])/len(pts) - # ~ Idx = self.locate_cell(bary) - # ~ isInMesh = (0<=Idx tol: - # ~ fd.logging.warning(f"[{self.comm.rank}] Not able to curve Firedrake element {Idx} ({i}) -- residual: {res}") - # ~ else: - # ~ for j, datIdx in enumerate(cell_node_map.values[Idx][0:reference_space_points.shape[0]]): - # ~ for dim in range(self.geometric_dimension()): - # ~ coo = curved_space_points[i][j][dim] - # ~ newFunctionCoordinates.sub(dim).dat.data[datIdx] = coo - # ~ else: - # ~ if isInMesh: - # ~ p = [np.argmin(np.sum((pts - pt)**2, axis=1)) - # ~ for pt in newFunctionCoordinates.dat.data[cell_node_map.values[Idx]][0:reference_space_points.shape[0]]] - # ~ curved_space_points[i] = curved_space_points[i][p] - # ~ res = np.linalg.norm(pts[p]-newFunctionCoordinates.dat.data[cell_node_map.values[Idx]][0:reference_space_points.shape[0]]) - # ~ else: - # ~ res = np.inf - # ~ res = self.comm.gather(res, root=0) - # ~ res = self.comm.bcast(res, root=0) - # ~ rank = np.argmin(res) - # ~ if self.comm.rank == rank: - # ~ if res[rank] > tol: - # ~ fd.logging.warning("[{}, {}] Not able to curve Firedrake element {} \ - # ~ ({}) -- residual: {}".format(self.comm.rank, shared, Idx,i, res)) - # ~ else: - # ~ for j, datIdx in enumerate(cell_node_map.values[Idx][0:reference_space_points.shape[0]]): - # ~ for dim in range(self.geometric_dimension()): - # ~ coo = curved_space_points[i][j][dim] - # ~ newFunctionCoordinates.sub(dim).dat.data[datIdx] = coo - # ~ breakpoint() - return newFunctionCoordinates + # Find the correct coordinate permutation for each cell + permutation = find_permutation( + physical_space_points, + new_coordinates.dat.data[pyop2_index].reshape(physical_space_points.shape) + ) + + # Apply the permutation to each cell in turn + for ii, p in enumerate(curved_space_points): + curved_space_points[ii] = p[permutation[ii]] + + # Assign the curved coordinates to the dat + new_coordinates.dat.data[pyop2_index] = curved_space_points.reshape(-1, geom_dim) + + return new_coordinates def splitToQuads(plex, dim, comm): '''