From d7d26864870021e9e64087f9882933392aec68f0 Mon Sep 17 00:00:00 2001 From: Eivind Fonn Date: Fri, 21 Aug 2020 16:28:55 +0200 Subject: [PATCH] CP operations for make_periodic --- splipy/SplineObject.py | 98 +++++++++++++++++++++--------------------- splipy/operations.py | 35 +++++++++++---- 2 files changed, 77 insertions(+), 56 deletions(-) diff --git a/splipy/SplineObject.py b/splipy/SplineObject.py index 176b9680..ad0161de 100644 --- a/splipy/SplineObject.py +++ b/splipy/SplineObject.py @@ -9,7 +9,7 @@ from splipy.utils import reshape, rotation_matrix, is_singleton, ensure_listlike,\ check_direction, ensure_flatlist, check_section, sections,\ raise_order_1D -from splipy.operations import TensorDot, Transpose, Identity, Roll, Reverse, Swap, Rationalize +from splipy.operations import TensorDot, Transpose, Identity, Roll, Rationalize, Index, WeightedAverage __all__ = ['SplineObject'] @@ -172,7 +172,7 @@ def reverse(self, direction=0): """ direction = check_direction(direction, self.pardim) self.bases[direction].reverse() - return Reverse(direction) + return Index.reverse(direction) def swap(self, dir1=0, dir2=1): """ Swaps two parameter directions. @@ -189,7 +189,7 @@ def swap(self, dir1=0, dir2=1): dir1 = check_direction(dir1, self.pardim) dir2 = check_direction(dir2, self.pardim) self.bases[dir1], self.bases[dir2] = self.bases[dir2], self.bases[dir1] - return Swap(dir1, dir2, self.pardim) + return Transpose.swap(dir1, dir2, self.pardim) def insert_knot(self, knot, direction=0): """ Insert a new knot into the spline structure. @@ -334,6 +334,50 @@ def shape(self): """The dimensions of the control point array.""" return tuple(b.num_functions() for b in self.bases) + def make_periodic(self, continuity=None, direction=0): + """ Make the spline object periodic in a given parametric direction. + + :param int continuity: The continuity along the boundary (default max). + :param int direction: The direction to ensure continuity in. + :return: (SplineStructure, ControlPointOperation) + """ + direction = check_direction(direction, self.pardim) + basis = self.bases[direction] + if continuity is None: + continuity = basis.order - 2 + if not -1 <= continuity <= basis.order - 2: + raise ValueError('Illegal continuity for basis of order {}: {}'.format( + continuity, basis.order + )) + if continuity == -1: + raise ValueError( + 'Operation not supported. ' + 'For discontinuous spline spaces, consider SplineObject.split().' + ) + if basis.periodic >= 0: + raise ValueError('Basis is already periodic') + + basis = basis.make_periodic(continuity) + + # Merge control points + index_beg = [slice(None,None,None)] * (self.pardim + 1) + index_end = [slice(None,None,None)] * (self.pardim + 1) + cps = np.array(self.controlpoints) + weights = np.linspace(0, 1, continuity + 1) if continuity > 0 else [0.5] + operation = Identity() + for i, j, t in zip(range(continuity + 1), range(-continuity-1, 0), weights): + # Weighted average between cps[..., i, ..., :] and cps[..., -c-1+i, ..., :] + # The weights are chosen so that, for periodic c, the round trip + # c.split().make_periodic() with suitable arguments produces an + # object identical to c. (Mostly black magic.) + operation *= WeightedAverage(direction, i, j, t) + + operation *= Index.upto(direction, -(continuity + 1)) + + bases = list(self.bases) + bases[direction] = basis + return SplineStructure(bases, self.rational), operation + @classmethod def make_splines_compatible(cls, spline1, spline2): """Ensure that two spline structrures are compatible. This @@ -1303,51 +1347,9 @@ def make_periodic(self, continuity=None, direction=0): :param int continuity: The continuity along the boundary (default max). :param int direction: The direction to ensure continuity in. """ - - direction = check_direction(direction, self.pardim) - basis = self.bases[direction] - if continuity is None: - continuity = basis.order - 2 - if not -1 <= continuity <= basis.order - 2: - raise ValueError('Illegal continuity for basis of order {}: {}'.format( - continuity, basis.order - )) - if continuity == -1: - raise ValueError( - 'Operation not supported. ' - 'For discontinuous spline spaces, consider SplineObject.split().' - ) - if basis.periodic >= 0: - raise ValueError('Basis is already periodic') - - basis = basis.make_periodic(continuity) - - # Merge control points - index_beg = [slice(None,None,None)] * (self.pardim + 1) - index_end = [slice(None,None,None)] * (self.pardim + 1) - cps = np.array(self.controlpoints) - weights = np.linspace(0, 1, continuity + 1) if continuity > 0 else [0.5] - for i, j, t in zip(range(continuity + 1), range(-continuity-1, 0), weights): - # Weighted average between cps[..., i, ..., :] and cps[..., -c-1+i, ..., :] - # The weights are chosen so that, for periodic c, the round trip - # c.split().make_periodic() with suitable arguments produces an - # object identical to c. (Mostly black magic.) - index_beg[direction] = i - index_end[direction] = j - cps[tuple(index_beg)] = t * cps[tuple(index_beg)] + (1 - t) * cps[tuple(index_end)] - - # cps[..., :-(continuity+1), ..., :] - index_beg[direction] = slice(None, -(continuity + 1), None) - cps = cps[tuple(index_beg)] - - bases = list(self.bases) - bases[direction] = basis - args = bases + [cps] + [self.rational] - - # search for the right subclass constructor, i.e. Volume, Surface or Curve - constructor = [c for c in SplineObject.__subclasses__() if c._intended_pardim == len(self.bases)] - constructor = constructor[0] - return constructor(*args, raw=True) + structure, operation = super().make_periodic(continuity, direction) + cps = operation(self.controlpoints) + return structure.with_controlpoints(cps) def clone(self): """Clone the object.""" diff --git a/splipy/operations.py b/splipy/operations.py index a3be346f..a3304e22 100644 --- a/splipy/operations.py +++ b/splipy/operations.py @@ -38,10 +38,17 @@ def __init__(self, index): def __call__(self, cps): return cps[self.index] + @classmethod + def single(cls, axis, index): + return cls((slice(None, None, None),) * axis + (index,)) -def Reverse(axis): - index = (slice(None, None, None),) * axis + (slice(None, None, -1),) - return Index(index) + @classmethod + def upto(cls, axis, index): + return cls((slice(None, None, None),) * axis + (slice(None, index, None),)) + + @classmethod + def reverse(cls, axis): + return cls((slice(None, None, None),) * axis + (slice(None, None, -1),)) class Rationalize(ControlPointOperation): @@ -78,9 +85,21 @@ def __init__(self, permutation): def __call__(self, cps): return cps.transpose(self.permutation) + @classmethod + def swap(cls, dir1, dir2, pardim): + permutation = list(range(pardim + 1)) + permutation[dir1] = dir2 + permutation[dir2] = dir1 + return cls(tuple(permutation)) + + +class WeightedAverage(ControlPointOperation): -def Swap(dir1, dir2, pardim): - permutation = list(range(pardim + 1)) - permutation[dir1] = dir2 - permutation[dir2] = dir1 - return Transpose(tuple(permutation)) + def __init__(self, axis, i, j, weight): + self.i = Index.single(axis, i).index + self.j = Index.single(axis, j).index + self.weight = weight + + def __call__(self, cps): + cps[self.i] = self.weight * cps[self.i] + (1 - self.weight) * cps[self.j] + return cps