-
Notifications
You must be signed in to change notification settings - Fork 0
/
Matrix.java
138 lines (117 loc) · 4.3 KB
/
Matrix.java
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
package com.company;
/**
* Created by Benjamin R. on 30-07-2016.
*/
public class Matrix { // Includes gaussian elimination methods
private double[][] data;
public Matrix(int m, int n) {
data = new double[m][n];
}
public static Matrix gaussian(Matrix a, Matrix b) {
int n = a.data.length; // Number of unknowns
Matrix q = new Matrix(n, n + 1);
for (int i = 0; i < n; i++) {
for (int j = 0; j < n; j++) // Form q matrix
q.data[i][j] = a.data[i][j];
q.data[i][n] = b.data[i][0];
}
forward_solve(q); // Do Gaussian elimination
back_solve(q); // Perform back substitution
Matrix x = new Matrix(n, 1);
for (int i = 0; i < n; i++)
x.data[i][0] = q.data[i][n];
return x;
}
private static void forward_solve(Matrix q) {
int n = q.data.length;
for (int i = 0; i < n; i++) { // Find row w/max element in this
int maxRow = i; // column, at or below diagonal
for (int k = i + 1; k < n; k++)
if (Math.abs(q.data[k][i]) > Math.abs(q.data[maxRow][i]))
maxRow = k;
if (maxRow != i) // If row not current row, swap
for (int j = i; j <= n; j++) {
double t = q.data[i][j];
q.data[i][j] = q.data[maxRow][j];
q.data[maxRow][j] = t;
}
for (int j = i + 1; j < n; j++) { // Calculate pivot ratio
double pivot = q.data[j][i] / q.data[i][i];
for (int k = i; k <= n; k++) // Pivot operation itself
q.data[j][k] -= q.data[i][k] * pivot;
}
}
}
private static void back_solve(Matrix q) {
int n = q.data.length;
for (int j = n - 1; j >= 0; j--) { // Start at last row
double t = 0.0; // t- temporary
for (int k = j + 1; k < n; k++)
t += q.data[j][k] * q.data[k][n];
q.data[j][n] = (q.data[j][n] - t) / q.data[j][j];
}
}
public void setIdentity() {
int i, j;
int nrows = data.length;
int ncols = data[0].length;
for (i = 0; i < nrows; i++)
for (j = 0; j < ncols; j++)
if (i == j)
data[i][j] = 1.0;
else
data[i][j] = 0.0;
}
public int getNumRows() {
return data.length;
}
public int getNumCols() {
return data[0].length;
}
public double getElement(int i, int j) {
return data[i][j];
}
public void setElement(int i, int j, double val) {
data[i][j] = val;
}
public void incrElement(int i, int j, double incr) {
data[i][j] += incr;
}
public Matrix add(Matrix b) {
Matrix result = null;
int nrows = data.length;
int ncols = data[0].length;
if (nrows == b.data.length && ncols == b.data[0].length) {
result = new Matrix(nrows, ncols);
for (int i = 0; i < nrows; i++)
for (int j = 0; j < ncols; j++)
result.data[i][j] = this.data[i][j] + b.data[i][j];
}
return result;
}
public Matrix mult(Matrix b) {
Matrix result = null;
int nrows = data.length;
int p = data[0].length;
if (p == b.data.length) {
result = new Matrix(nrows, b.data[0].length);
for (int i = 0; i < nrows; i++)
for (int j = 0; j < result.data[0].length; j++) {
double t = 0.0;
for (int k = 0; k < p; k++) {
t += data[i][k] * b.data[k][j];
}
result.data[i][j] = t;
}
}
return result;
}
public void print() {
for (int i = 0; i < data.length; i++) {
for (int j = 0; j < data[0].length; j++)
System.out.print(data[i][j] + " ");
System.out.println();
}
System.out.println();
}
}