-
Notifications
You must be signed in to change notification settings - Fork 7
/
klstm.py
36 lines (30 loc) · 940 Bytes
/
klstm.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
# klstm.py
import numpy
def create_dataset(dataset, look_back=1, include_last=False):
    """
    Convert an array of values into a (window, next value) dataset matrix.

    Only column 0 of ``dataset`` is used.

    Parameters
    ----------
    dataset : 2d array-like, shape (Ntime, Ncol)
        Rows are time steps; only the first column is read.
    look_back : int, optional
        Number of consecutive past values in each input window; must be >= 1.
    include_last : bool, optional
        By default (False) the final window is omitted. This preserves this
        function's historical output length of ``Ntime - look_back - 1``,
        which some callers' index arithmetic relies on. Pass True to also
        emit the last valid window, giving ``Ntime - look_back`` windows.

    Returns
    -------
    dataX : ndarray, shape (Nwin, look_back)
        ``dataX[i] = dataset[i:i + look_back, 0]``
    dataY : ndarray, shape (Nwin,)
        ``dataY[i] = dataset[i + look_back, 0]``

    Raises
    ------
    ValueError
        If ``look_back < 1``.
    """
    if look_back < 1:
        raise ValueError("look_back must be >= 1, got %r" % (look_back,))
    dataset = numpy.asarray(dataset)
    # The historical loop bound was len - look_back - 1, one short of the last
    # valid window; keep it as the default so existing callers see no change.
    n_windows = max(len(dataset) - look_back - (0 if include_last else 1), 0)
    starts = numpy.arange(n_windows)
    # Row i of window_idx holds the time indices i, i+1, ..., i+look_back-1.
    window_idx = starts[:, None] + numpy.arange(look_back)
    dataX = dataset[window_idx, 0]
    dataY = dataset[starts + look_back, 0]
    return dataX, dataY
def create_dataset_nd(dataset, look_back=1, include_last=False):
    """
    Convert a multivariate series into (window, next row) pairs.

    Parameters
    ----------
    dataset : 2d array-like, shape (Ntime, Nfeat)
        Rows are time steps, columns are variables.
    look_back : int, optional
        Number of consecutive rows in each input window; must be >= 1.
    include_last : bool, optional
        By default (False) the final window is omitted. This preserves this
        function's historical output length of ``Ntime - look_back - 1``.
        Pass True to also emit the last valid window, giving
        ``Ntime - look_back`` windows.

    Returns
    -------
    dataX : ndarray, shape (Nwin, Nfeat, look_back)
        ``dataX[i] = dataset[i:i + look_back, :].T``
    dataY : ndarray, shape (Nwin, Nfeat)
        ``dataY[i] = dataset[i + look_back, :]``

    Raises
    ------
    ValueError
        If ``look_back < 1``.
    """
    if look_back < 1:
        raise ValueError("look_back must be >= 1, got %r" % (look_back,))
    dataset = numpy.asarray(dataset)
    # The historical loop bound was len - look_back - 1, one short of the last
    # valid window; keep it as the default so existing callers see no change.
    n_windows = max(len(dataset) - look_back - (0 if include_last else 1), 0)
    starts = numpy.arange(n_windows)
    # Row i of window_idx holds the time indices i, i+1, ..., i+look_back-1.
    window_idx = starts[:, None] + numpy.arange(look_back)
    # dataset[window_idx] is (Nwin, look_back, Nfeat); swapping the last two
    # axes reproduces the per-window transpose, and the contiguous copy keeps
    # the same memory layout the list-to-array construction used to give.
    dataX = numpy.ascontiguousarray(numpy.swapaxes(dataset[window_idx], 1, 2))
    dataY = dataset[starts + look_back]
    return dataX, dataY
def create_dataset_a(dataset_a, look_back=1):
    """
    Build windowed inputs and targets from a batch of series stored row-wise.

    Input
    -----
    dataset_a, 2d nd.array[ Nsample, Ntime]
        Each row is one series; columns are time steps. The series are
        windowed along time by ``create_dataset_nd`` on ``dataset_a.T``.

    Output
    ------
    X : ndarray, shape (look_back * Nsample, 1, Nwin)
        Row ``k * Nsample + s`` holds ``dataset_a[s, k:k + Nwin]``.
    y : ndarray, shape (Nsample, Nwin)
        ``y[s, i] = dataset_a[s, i + look_back]``.

    ``Nwin`` is the number of windows ``create_dataset_nd`` returns
    (``Ntime - look_back - 1`` as that function is written here).

    NOTE(review): for look_back > 1, X has look_back times as many rows as
    y; confirm this layout is what callers expect.
    """
    windows_t, targets_t = create_dataset_nd(dataset_a.T, look_back=look_back)
    # Reverse every axis: (Nwin, Nsample, look_back) -> (look_back, Nsample, Nwin).
    lag_series_time = windows_t.T
    n_windows = lag_series_time.shape[2]
    return lag_series_time.reshape(-1, 1, n_windows), targets_t.T