Skip to content

Commit

Permalink
Merge pull request #6 from pbrubeck/pbrubeck/fix/fiat-ordering
Browse files Browse the repository at this point in the history
Read points from FIAT dual basis, thx Pablo!
  • Loading branch information
UZerbinati authored Oct 16, 2023
2 parents 6641c1a + 99f8bdd commit d4bc6fa
Showing 1 changed file with 22 additions and 12 deletions.
34 changes: 22 additions & 12 deletions ngsPETSc/utils/firedrake.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,17 +62,27 @@ def curveField(self, order):
:arg order: the order of the curved mesh
'''
newFunctionCoordinates = fd.interpolate(self.coordinates,
fd.VectorFunctionSpace(self,"DG",order))
low_order_element = self.coordinates.function_space().ufl_element().sub_elements()[0]
element = low_order_element.reconstruct(degree=order)
space = fd.VectorFunctionSpace(self, ufl.BrokenElement(element))
newFunctionCoordinates = fd.interpolate(self.coordinates, space)

#Computing reference points using fiat
fiat_element = newFunctionCoordinates.function_space().finat_element.fiat_equivalent
entity_ids = fiat_element.entity_dofs()
nodes = fiat_element.dual_basis()
refPts = []
for dim in entity_ids:
for entity in entity_ids[dim]:
for dof in entity_ids[dim][entity]:
# Assert singleton point for each node.
pt, = nodes[dof].get_point_dict().keys()
refPts.append(pt)

V = newFunctionCoordinates.dat.data
#Computing reference points using ufl
ref_element = newFunctionCoordinates.function_space().finat_element.fiat_equivalent.ref_el
getIdx = self._cell_numbering.getOffset
refPts = []
for (i,j) in ref_element.sub_entities[self.geometric_dimension()][0]:
if i < self.geometric_dimension():
refPts = refPts+list(ref_element.make_points(i,j,order))
refPts = np.array(refPts)
rnd = lambda x: round(x, 8)
if self.geometric_dimension() == 2:
#Mapping to the physical domain
physPts = np.ndarray((len(self.netgen_mesh.Elements2D()),
Expand All @@ -86,10 +96,10 @@ def curveField(self, order):
cellMap = newFunctionCoordinates.cell_node_map()
for i, el in enumerate(self.netgen_mesh.Elements2D()):
if el.curved:
pts = [tuple(map(lambda x: round(x,8),pts))
pts = [tuple(map(rnd, pts))
for pts in physPts[i][0:refPts.shape[0]]]
dofMap = {k: v for v, k in enumerate(pts)}
p = [dofMap[tuple(map(lambda x: round(x,8),pts))]
p = [dofMap[tuple(map(rnd, pts))]
for pts in V[cellMap.values[getIdx(i)]][0:refPts.shape[0]]]
curvedPhysPts[i] = curvedPhysPts[i][p]
for j, datIdx in enumerate(cellMap.values[getIdx(i)][0:refPts.shape[0]]):
Expand All @@ -109,10 +119,10 @@ def curveField(self, order):
cellMap = newFunctionCoordinates.cell_node_map()
for i, el in enumerate(self.netgen_mesh.Elements3D()):
if el.curved:
pts = [tuple(map(lambda x: round(x,8),pts))
pts = [tuple(map(rnd, pts))
for pts in physPts[i][0:refPts.shape[0]]]
dofMap = {k: v for v, k in enumerate(pts)}
p = [dofMap[tuple(map(lambda x: round(x,8),pts))]
p = [dofMap[tuple(map(rnd, pts))]
for pts in V[cellMap.values[getIdx(i)]][0:refPts.shape[0]]]
curvedPhysPts[i] = curvedPhysPts[i][p]
for j, datIdx in enumerate(cellMap.values[getIdx(i)][0:refPts.shape[0]]):
Expand Down

0 comments on commit d4bc6fa

Please sign in to comment.