-
Notifications
You must be signed in to change notification settings - Fork 0
/
formatData.py
69 lines (48 loc) · 2.13 KB
/
formatData.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
# coding: utf-8
# In[ ]:
import numpy as np
import random
# # Converting data to train and test
# In[ ]:
class CryptoData(object):
    """Split a row-ordered feature DataFrame into sliding-window train/test arrays.

    Each sample is a window of ``num_steps`` consecutive feature vectors, each
    of length ``input_size``. Its target is the label of the row that follows
    the window. The first ``(1 - test_ratio)`` fraction of windows, kept in
    order and not shuffled, becomes the training set; the rest becomes the
    test set.
    """

    def __init__(self,
                 stock_sym,
                 df,
                 labels,
                 input_size=1,
                 num_steps=30,
                 test_ratio=0.1,
                 normalized=True):
        """Build the windowed train/test split.

        Args:
            stock_sym: Symbol identifier; only stored on the instance.
            df: pandas DataFrame of features. Column position 3 is read as
                the date column, and the column named ``'end_time'`` is
                dropped before the remaining values are used as features.
            labels: Object exposing ``.values`` (e.g. a pandas Series), one
                label per row of ``df``.
            input_size: Length of each feature vector inside a window.
            num_steps: Number of consecutive vectors per window.
            test_ratio: Fraction of windows held out for the test set.
            normalized: Only stored as an attribute; this class applies no
                normalization itself.

        Sets ``train_X``, ``test_X``, ``train_y``, ``test_y`` and ``dates``
        (the dates of the test-set targets), and prints their shapes.
        """
        self.stock_sym = stock_sym
        self.input_size = input_size
        self.num_steps = num_steps
        self.test_ratio = test_ratio
        self.normalized = normalized
        self.training_data = df  # original frame, before 'end_time' is dropped
        self.labels = labels

        # NOTE(review): assumes column position 3 is the 'end_time' column that
        # is dropped on the next line — confirm against the caller's schema.
        dates = df.iloc[:, 3].values.tolist()
        features = df.drop(['end_time'], axis=1)

        # Flatten row-major, then regroup into input_size-long vectors. A
        # trailing remainder shorter than input_size is discarded.
        flat = np.array(features.values.tolist()).flatten('C')
        usable = (len(flat) // input_size) * input_size
        seq = flat[:usable].reshape(-1, input_size)

        # NOTE(review): labels are indexed per DataFrame row, while windows are
        # indexed per regrouped vector. The two only line up one-to-one when
        # input_size equals the number of feature columns left after the drop.
        # Window i covers vectors i .. i+num_steps-1; its target is entry
        # i+num_steps, i.e. the step right after the window.
        X_raw_seq = np.array([seq[i:i + num_steps]
                              for i in range(len(seq) - num_steps)])
        y_raw_seq = np.array(labels.values.tolist())[num_steps:]
        dates_y = np.array(dates[num_steps:])

        # Chronological split: earliest windows train, latest windows test.
        train_size = int(len(X_raw_seq) * (1 - test_ratio))
        self.train_X, self.test_X = X_raw_seq[:train_size], X_raw_seq[train_size:]
        self.train_y, self.test_y = y_raw_seq[:train_size], y_raw_seq[train_size:]
        # Dates are kept only for the test-set targets.
        self.dates = dates_y[train_size:]
        print(self.train_y.shape, self.train_X.shape, self.test_y.shape,
              self.test_X.shape, self.dates.shape)

    def generate_one_epoch(self, batch_size):
        """Yield ``(batch_X, batch_y)`` training batches in random batch order.

        The training set is cut into consecutive chunks of ``batch_size``
        samples; the last chunk may be smaller. The order of the chunks is
        shuffled with the module-level ``random`` generator. Samples within a
        chunk keep their original order.

        Raises:
            ValueError: if ``batch_size`` is not positive (raised on the first
                iteration, since this is a generator).
        """
        if batch_size <= 0:
            raise ValueError("batch_size must be positive, got %r" % (batch_size,))
        n_samples = len(self.train_X)
        # Ceiling division so a trailing partial batch is still produced.
        num_batches = (n_samples + batch_size - 1) // batch_size
        batch_indices = list(range(num_batches))
        random.shuffle(batch_indices)
        for j in batch_indices:
            start = j * batch_size
            stop = start + batch_size
            yield self.train_X[start:stop], self.train_y[start:stop]