-
Notifications
You must be signed in to change notification settings - Fork 2
/
Tridiagonal.h
65 lines (53 loc) · 1.18 KB
/
Tridiagonal.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
#include "Eigen.h"
template<typename T, int N>
class Tridiagonal
{
public:
Tridiagonal()
{
a.setZero();
b.setZero();
C.setZero();
}
void compute(const MatrixX& A)
{
assert(A.rows() == N && A.cols() == N);
a.tail(N-1) = A.diagonal(-1);
b = A.diagonal(0);
C.head(N-1) = A.diagonal(1);
C(0) = C(0)/b(0);
for (int j=1; j<N-1; j++)
{
C(j) = C(j)/(b(j)-a(j)*C(j-1));
}
}
template<typename R>
Matrix<R, N, 1> solve(const Matrix<R, N, 1>& d) const
{
Matrix<R, N, 1> x;
Matrix<R, N, 1> D;
// from wikipedia, Thomas algorithm
// forward pass
D(0) = d(0)/b(0);
for (int j=1; j<N; j++)
{
D(j) = (d(j) - a(j)*D(j-1))/(b(j)-a(j)*C(j-1));
}
// backward pass
x(N-1) = D(N-1);
for (int j=N-2; j>=0; j--)
{
x(j) = D(j) - C(j)*x(j+1);
}
return x;
}
int rows()
{
return N;
}
private:
// as per wikipedia
Matrix<T, N, 1> a; // lower
Matrix<T, N, 1> b; // diagonal
Matrix<T, N, 1> C; // transformed upper
};