Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement SplineStructure superclass of SplineObject #122

Draft
wants to merge 5 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
113 changes: 111 additions & 2 deletions splipy/basis.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,15 @@

from bisect import bisect_right, bisect_left
import copy
from operator import attrgetter, methodcaller

import numpy as np
from scipy.sparse import csr_matrix

from .utils import ensure_listlike
from . import basis_eval, state
from typing import List, Iterable, Tuple

from .utils import ensure_listlike, check_direction, is_singleton
from . import basis_eval, state, transform as trf

__all__ = ['BSplineBasis']

Expand Down Expand Up @@ -527,3 +530,109 @@ def __repr__(self):
if self.periodic > -1:
result += ', C' + str(self.periodic) + '-periodic'
return result


class TensorBasis:

bases: List[BSplineBasis]
rational: bool

def __init__(self, *bases: BSplineBasis, rational: bool = False):
self.bases = list(bases)
self.rational = rational

def __iter__(self) -> Iterable[BSplineBasis]:
yield from self.bases

def __len__(self) -> int:
return len(self.bases)

def __getitem__(self, index: int) -> BSplineBasis:
return self.bases[index]

@property
def ndims(self) -> int:
return len(self)

@property
def shape(self) -> Tuple[int, ...]:
return tuple(b.num_functions() for b in self.bases)

def validate_domain(self, *params: float):
"""Check whether the given evaluation parameters are valid.

:raises ValueError: If the parameters are outside the domain
"""
for b, p in zip(self.bases, params):
b.snap(p)
if b.periodic < 0 and (min(p) < b.start() or b.end() < max(p)):
raise ValueError("Evaluation outside parametric domain")

def start(self, direction=None):
if direction is None:
return tuple(b.start() for b in self.bases)
direction = check_direction(direction, self.ndims)
return self.bases[direction].start()

def end(self, direction=None):
if direction is None:
return tuple(b.end() for b in self.bases)
direction = check_direction(direction, self.ndims)
return self.bases[direction].end()

def order(self, direction=None):
if direction is None:
return tuple(b.order for b in self.bases)
direction = check_direction(direction, self.ndims)
return self.bases[direction].order

def knots(self, direction=None, with_multiplicities=False):
getter = attrgetter('knots') if with_multiplicities else methodcaller('knot_spans')
if direction is None:
return tuple(getter(b) for b in self.bases)
direction = check_direction(direction, self.ndims)
return getter(self.bases[direction])

def evaluate(self, *params, **kwargs) -> trf.Transform:
squeeze = all(is_singleton(p) for p in params)
params = [ensure_listlike(p) for p in params]

tensor = kwargs.get('tensor', True)
if not tensor and len({len(p) for p in params}) != 1:
raise ValueError('Parameters must have same length')

self.validate_domain(*params)

# Evaluate the corresponding bases at the corresponding points
# and build the result array
Ns = [b.evaluate(p) for b, p in zip(self.bases, params)]
return trf.Evaluator(Ns, tensor, self.rational, squeeze)

def derivative(self, *params, **kwargs) -> trf.Transform:
squeeze = all(is_singleton(p) for p in params)
params = [ensure_listlike(p) for p in params]

derivs = kwargs.get('d', [1] * self.ndims)
derivs = ensure_listlike(derivs, self.ndims)

above = kwargs.get('above', [True] * self.ndims)
above = ensure_listlike(above, self.ndims)

tensor = kwargs.get('tensor', True)
if not tensor and len({len(p) for p in params}) != 1:
raise ValueError('Parameters must have same length')

self.validate_domain(*params)
dNs = [b.evaluate(p, d, from_right) for b, p, d, from_right in zip(self.bases, params, derivs, above)]

if not self.rational:
return trf.DerivativeEvaluator(dNs, tensor, None, squeeze)

Ns = [b.evaluate(p) for b, p in zip(self.bases, params)]
return trf.DerivativeEvaluator(dNs, tensor, Ns, squeeze)

def swap(self, dir1: int, dir2: int) -> trf.Transform:
dir1 = check_direction(dir1, self.ndims)
dir2 = check_direction(dir2, self.ndims)
self.bases[dir1], self.bases[dir2] = self.bases[dir2], self.bases[dir1]
return trf.SwapTransform(dir1, dir2)
8 changes: 4 additions & 4 deletions splipy/curve.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import numpy as np
import scipy.sparse.linalg as splinalg

from .basis import BSplineBasis
from .basis import BSplineBasis, TensorBasis
from .splineobject import SplineObject
from .utils import ensure_listlike, is_singleton

Expand Down Expand Up @@ -52,7 +52,7 @@ def evaluate(self, *params):
squeeze = is_singleton(params[0])
params = [ensure_listlike(p) for p in params]

self._validate_domain(*params)
self.basis.validate_domain(*params)

# Evaluate the derivatives of the corresponding bases at the corresponding points
# and build the result array
Expand Down Expand Up @@ -290,7 +290,7 @@ def raise_order(self, amount, direction=None):

# solve the interpolation problem
self.controlpoints = np.array(splinalg.spsolve(N_new, interpolation_pts_x))
self.bases = [newBasis]
self.basis = TensorBasis(newBasis, rational=self.rational)

return self

Expand Down Expand Up @@ -344,7 +344,7 @@ def append(self, curve):
new_controlpoints[n1:, :] = extending_curve.controlpoints[1:, :]

# update basis and controlpoints
self.bases = [BSplineBasis(p, new_knot)]
self.basis = TensorBasis(BSplineBasis(p, new_knot), rational=self.rational)
self.controlpoints = new_controlpoints

return self
Expand Down
Loading