-
Notifications
You must be signed in to change notification settings - Fork 0
/
spline_fit.py
154 lines (117 loc) · 4.67 KB
/
spline_fit.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
'A helper module for spline fitting, taken from ft-fsd-path-planning'
from dataclasses import dataclass
from typing import Any, Optional, Tuple
import numpy as np
from scipy.interpolate import splev, splprep
def trace_distance_to_next(trace: np.ndarray) -> np.ndarray:
    """
    Compute the Euclidean length of each segment of a trace.

    Entry ``i`` of the result is the distance from point ``i`` to point ``i + 1``.
    The final point has no successor, so the result is one element shorter than
    the trace along the point axis.

    Args:
        trace (np.ndarray): The trace points. Points run along axis -2 and their
            coordinates along axis -1; any leading axes are treated as a batch.

    Returns:
        np.ndarray: The per-segment distances.
    """
    segment_vectors = np.diff(trace, axis=-2)
    segment_lengths = np.linalg.norm(segment_vectors, axis=-1)
    return segment_lengths
@dataclass
class SplineEvaluator:
    """
    Evaluate a fitted parametric spline at regularly spaced parameter values.

    Attributes:
        max_u (float): Default exclusive upper bound for the evaluation parameter.
        tck (Tuple[Any, Any, int]): The spline representation handed to ``splev``
            (knots, coefficients, degree), e.g. as returned by ``splprep``.
        predict_every (float): Step between consecutive evaluation parameters.
    """

    max_u: float
    tck: Tuple[Any, Any, int]
    predict_every: float

    def calculate_u_eval(self, max_u: Optional[float] = None) -> np.ndarray:
        """
        Build the parameter values at which the spline will be evaluated.

        Args:
            max_u (Optional[float], optional): Exclusive upper bound for the
                parameter values. When None (the default), ``self.max_u`` is used.

        Returns:
            np.ndarray: ``0, predict_every, 2 * predict_every, ...`` up to, but not
            including, the upper bound.
        """
        upper_bound = self.max_u if max_u is None else max_u
        return np.arange(0, upper_bound, self.predict_every)

    def predict(self, der: int, max_u: Optional[float] = None) -> np.ndarray:
        """
        Evaluate the spline (``der=0``) or one of its derivatives (``der>=1``).

        Args:
            der (int): Order of the derivative to evaluate; 0 gives the spline itself.
            max_u (Optional[float], optional): Exclusive upper bound for the
                parameter values. When None (the default), ``self.max_u`` is used.

        Returns:
            np.ndarray: The evaluated values, one row per parameter value.
        """
        parameters = self.calculate_u_eval(max_u)
        per_dimension = splev(parameters, tck=self.tck, der=der)
        # For a parametric tck, splev returns one array per coordinate.
        # Stacking and transposing makes each row a single evaluated point.
        return np.array(per_dimension).T
class NullSplineEvaluator(SplineEvaluator):
    """
    A placeholder evaluator for traces that are too short to fit a spline.

    Every query returns an empty result, so callers can use it like a real
    ``SplineEvaluator`` without special-casing. The field values passed at
    construction are not read by these overrides.
    """

    def calculate_u_eval(self, max_u: Optional[float] = None) -> np.ndarray:
        """
        Return an empty array of parameter values.

        This override exists because null evaluators are typically constructed with
        a dummy step of 0. With that step, the inherited ``np.arange`` call would
        raise instead of producing values.

        Args:
            max_u (Optional[float], optional): Ignored.

        Returns:
            np.ndarray: An empty 1-D array.
        """
        return np.zeros(0)

    def predict(self, der: int, max_u: Optional[float] = None) -> np.ndarray:
        """
        Return an empty set of 2-D points, whatever derivative is requested.

        Args:
            der (int): Ignored.
            max_u (Optional[float], optional): Ignored.

        Returns:
            np.ndarray: An array of shape ``(0, 2)``.
        """
        return np.zeros((0, 2))
class SplineFitterFactory:
    """
    Fit parametric splines to traces with ``splprep`` and wrap the result in an
    evaluator that uses ``splev``.
    """

    def __init__(self, smoothing: float, predict_every: float, max_deg: int):
        """
        Constructor for the SplineFitterFactory class.

        Args:
            smoothing (float): The smoothing factor, passed to ``splprep`` as ``s``.
                0 means no smoothing.
            predict_every (float): The approximate distance along the fitted trace
                between consecutive evaluated points.
            max_deg (int): The maximum degree of the fitted splines.
        """
        self.smoothing = smoothing
        self.predict_every = predict_every
        self.max_deg = max_deg

    def fit(self, trace: np.ndarray, periodic: bool = False) -> SplineEvaluator:
        """
        Fit a spline to a trace and return an evaluator for it.

        The spline is parameterized by cumulative chord length, i.e. the running
        sum of distances between consecutive points. This makes ``predict_every``
        an approximate spacing along the trace. The spline degree is
        ``len(trace) - 1``, clipped to ``[1, max_deg]``.

        Args:
            trace (np.ndarray): The points to fit, one point per row.
            periodic (bool, optional): Passed to ``splprep`` as ``per``. When True,
                a periodic spline is fitted. Defaults to False.

        Returns:
            SplineEvaluator: An evaluator for the fitted spline. If the trace has
            fewer than two points, a ``NullSplineEvaluator`` is returned instead,
            which predicts an empty array.
        """
        if len(trace) < 2:
            return NullSplineEvaluator(
                # dummy values
                0,
                (0, 0, 0),
                0,
            )

        # Accept any array-like (e.g. a list of points), not only ndarrays.
        trace = np.asarray(trace)

        # A spline of degree k needs more than k points, so cap the degree by the
        # number of points available.
        k = int(np.clip(len(trace) - 1, 1, self.max_deg))

        # NOTE(review): consecutive duplicate points give zero-length segments and
        # therefore repeated u values, which splprep may reject. Consider
        # de-duplicating upstream.
        distance_to_next = trace_distance_to_next(trace)
        u_fit = np.concatenate(([0], np.cumsum(distance_to_next)))

        tck, _ = splprep(  # pylint: disable=unbalanced-tuple-unpacking
            trace.T, s=self.smoothing, k=k, u=u_fit, per=periodic
        )

        max_u = float(u_fit[-1])
        return SplineEvaluator(max_u, tck, self.predict_every)

    def fit_then_evaluate_trace_and_derivative(
        self, trace: np.ndarray, periodic: bool = False
    ) -> Tuple[np.ndarray, np.ndarray]:
        """
        Fit a trace, then evaluate the spline and its first derivative.

        Both are evaluated at parameter values spaced ``predict_every`` apart,
        from 0 up to (but excluding) the total chord length of the trace.

        Args:
            trace (np.ndarray): The trace to fit.
            periodic (bool, optional): Whether to fit a periodic spline; forwarded
                to ``fit``. Defaults to False.

        Returns:
            Tuple[np.ndarray, np.ndarray]: The evaluated trace and the evaluated
            derivative. If the trace has fewer than two points, both entries are
            copies of the input trace.
        """
        if len(trace) < 2:
            # Too short to fit: echo the input back unchanged.
            return trace.copy(), trace.copy()

        fitted_func = self.fit(trace, periodic=periodic)
        evaluated_trace = fitted_func.predict(der=0)
        evaluated_derivative = fitted_func.predict(der=1)
        return evaluated_trace, evaluated_derivative