Skip to content

Commit

Permalink
CP operations for make_periodic
Browse files Browse the repository at this point in the history
  • Loading branch information
TheBB committed Aug 21, 2020
1 parent 58ea7aa commit d7d2686
Show file tree
Hide file tree
Showing 2 changed files with 77 additions and 56 deletions.
98 changes: 50 additions & 48 deletions splipy/SplineObject.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from splipy.utils import reshape, rotation_matrix, is_singleton, ensure_listlike,\
check_direction, ensure_flatlist, check_section, sections,\
raise_order_1D
from splipy.operations import TensorDot, Transpose, Identity, Roll, Reverse, Swap, Rationalize
from splipy.operations import TensorDot, Transpose, Identity, Roll, Rationalize, Index, WeightedAverage

__all__ = ['SplineObject']

Expand Down Expand Up @@ -172,7 +172,7 @@ def reverse(self, direction=0):
"""
direction = check_direction(direction, self.pardim)
self.bases[direction].reverse()
return Reverse(direction)
return Index.reverse(direction)

def swap(self, dir1=0, dir2=1):
""" Swaps two parameter directions.
Expand All @@ -189,7 +189,7 @@ def swap(self, dir1=0, dir2=1):
dir1 = check_direction(dir1, self.pardim)
dir2 = check_direction(dir2, self.pardim)
self.bases[dir1], self.bases[dir2] = self.bases[dir2], self.bases[dir1]
return Swap(dir1, dir2, self.pardim)
return Transpose.swap(dir1, dir2, self.pardim)

def insert_knot(self, knot, direction=0):
""" Insert a new knot into the spline structure.
Expand Down Expand Up @@ -334,6 +334,50 @@ def shape(self):
"""The dimensions of the control point array."""
return tuple(b.num_functions() for b in self.bases)

def make_periodic(self, continuity=None, direction=0):
""" Make the spline object periodic in a given parametric direction.
:param int continuity: The continuity along the boundary (default max).
:param int direction: The direction to ensure continuity in.
:return: (SplineStructure, ControlPointOperation)
"""
direction = check_direction(direction, self.pardim)
basis = self.bases[direction]
if continuity is None:
continuity = basis.order - 2
if not -1 <= continuity <= basis.order - 2:
raise ValueError('Illegal continuity for basis of order {}: {}'.format(
continuity, basis.order
))
if continuity == -1:
raise ValueError(
'Operation not supported. '
'For discontinuous spline spaces, consider SplineObject.split().'
)
if basis.periodic >= 0:
raise ValueError('Basis is already periodic')

basis = basis.make_periodic(continuity)

# Merge control points
index_beg = [slice(None,None,None)] * (self.pardim + 1)
index_end = [slice(None,None,None)] * (self.pardim + 1)
cps = np.array(self.controlpoints)
weights = np.linspace(0, 1, continuity + 1) if continuity > 0 else [0.5]
operation = Identity()
for i, j, t in zip(range(continuity + 1), range(-continuity-1, 0), weights):
# Weighted average between cps[..., i, ..., :] and cps[..., -c-1+i, ..., :]
# The weights are chosen so that, for periodic c, the round trip
# c.split().make_periodic() with suitable arguments produces an
# object identical to c. (Mostly black magic.)
operation *= WeightedAverage(direction, i, j, t)

operation *= Index.upto(direction, -(continuity + 1))

bases = list(self.bases)
bases[direction] = basis
return SplineStructure(bases, self.rational), operation

@classmethod
def make_splines_compatible(cls, spline1, spline2):
"""Ensure that two spline structrures are compatible. This
Expand Down Expand Up @@ -1303,51 +1347,9 @@ def make_periodic(self, continuity=None, direction=0):
:param int continuity: The continuity along the boundary (default max).
:param int direction: The direction to ensure continuity in.
"""

direction = check_direction(direction, self.pardim)
basis = self.bases[direction]
if continuity is None:
continuity = basis.order - 2
if not -1 <= continuity <= basis.order - 2:
raise ValueError('Illegal continuity for basis of order {}: {}'.format(
continuity, basis.order
))
if continuity == -1:
raise ValueError(
'Operation not supported. '
'For discontinuous spline spaces, consider SplineObject.split().'
)
if basis.periodic >= 0:
raise ValueError('Basis is already periodic')

basis = basis.make_periodic(continuity)

# Merge control points
index_beg = [slice(None,None,None)] * (self.pardim + 1)
index_end = [slice(None,None,None)] * (self.pardim + 1)
cps = np.array(self.controlpoints)
weights = np.linspace(0, 1, continuity + 1) if continuity > 0 else [0.5]
for i, j, t in zip(range(continuity + 1), range(-continuity-1, 0), weights):
# Weighted average between cps[..., i, ..., :] and cps[..., -c-1+i, ..., :]
# The weights are chosen so that, for periodic c, the round trip
# c.split().make_periodic() with suitable arguments produces an
# object identical to c. (Mostly black magic.)
index_beg[direction] = i
index_end[direction] = j
cps[tuple(index_beg)] = t * cps[tuple(index_beg)] + (1 - t) * cps[tuple(index_end)]

# cps[..., :-(continuity+1), ..., :]
index_beg[direction] = slice(None, -(continuity + 1), None)
cps = cps[tuple(index_beg)]

bases = list(self.bases)
bases[direction] = basis
args = bases + [cps] + [self.rational]

# search for the right subclass constructor, i.e. Volume, Surface or Curve
constructor = [c for c in SplineObject.__subclasses__() if c._intended_pardim == len(self.bases)]
constructor = constructor[0]
return constructor(*args, raw=True)
structure, operation = super().make_periodic(continuity, direction)
cps = operation(self.controlpoints)
return structure.with_controlpoints(cps)

def clone(self):
"""Clone the object."""
Expand Down
35 changes: 27 additions & 8 deletions splipy/operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,17 @@ def __init__(self, index):
def __call__(self, cps):
return cps[self.index]

@classmethod
def single(cls, axis, index):
return cls((slice(None, None, None),) * axis + (index,))

def Reverse(axis):
index = (slice(None, None, None),) * axis + (slice(None, None, -1),)
return Index(index)
@classmethod
def upto(cls, axis, index):
return cls((slice(None, None, None),) * axis + (slice(None, index, None),))

@classmethod
def reverse(cls, axis):
return cls((slice(None, None, None),) * axis + (slice(None, None, -1),))


class Rationalize(ControlPointOperation):
Expand Down Expand Up @@ -78,9 +85,21 @@ def __init__(self, permutation):
def __call__(self, cps):
return cps.transpose(self.permutation)

@classmethod
def swap(cls, dir1, dir2, pardim):
permutation = list(range(pardim + 1))
permutation[dir1] = dir2
permutation[dir2] = dir1
return cls(tuple(permutation))


class WeightedAverage(ControlPointOperation):

def Swap(dir1, dir2, pardim):
permutation = list(range(pardim + 1))
permutation[dir1] = dir2
permutation[dir2] = dir1
return Transpose(tuple(permutation))
def __init__(self, axis, i, j, weight):
self.i = Index.single(axis, i).index
self.j = Index.single(axis, j).index
self.weight = weight

def __call__(self, cps):
cps[self.i] = self.weight * cps[self.i] + (1 - self.weight) * cps[self.j]
return cps

0 comments on commit d7d2686

Please sign in to comment.