forked from ribanez/Hilbert-AE
-
Notifications
You must be signed in to change notification settings - Fork 0
/
train.py
154 lines (99 loc) · 3.93 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
import argparse
import os
import torch
from torch import nn
from src.AutoEncoder import autoencoder
from src.DataLoader import Dataset_Hilbert, contruct_dataloader_from_disk
def get_args(argv=None):
    """Parse command-line options for training the Hilbert autoencoder.

    Args:
        argv: Optional list of argument strings to parse instead of
            ``sys.argv[1:]``; lets callers and tests configure training
            without touching the real command line.

    Returns:
        argparse.Namespace holding the training configuration.
    """
    parser = argparse.ArgumentParser('Train Hilbert AutoEncoder')
    # The data file is mandatory: fail at parse time with a usage message
    # instead of passing None on to the data loader.
    parser.add_argument('--hdf5_file',
                        type=str,
                        required=True,
                        help='Path to HDF5 file')
    parser.add_argument('--checkpoint',
                        type=str,
                        default=None,
                        help='Path to Checkpoint Model')
    parser.add_argument('--epochs',
                        type=int,
                        default=100,
                        help='Number of epochs')
    parser.add_argument('--early_stop',
                        type=int,
                        default=40,
                        help='Early stop limit')
    parser.add_argument('--lr',
                        type=float,
                        default=0.001,
                        help='learning rate')
    parser.add_argument('--weight_decay',
                        type=float,
                        default=1e-5,
                        help='weight decay to optimizer')
    parser.add_argument('--batch_size',
                        type=int,
                        default=256,
                        help='Batch size')
    parser.add_argument('--nc',
                        type=int,
                        default=1,
                        help='Number of channels in data')
    parser.add_argument('--ld',
                        type=int,
                        default=256,
                        help='latent dimension size')
    # Default matches the device the script previously hard-coded.
    parser.add_argument('--device',
                        type=str,
                        default='cuda:1',
                        help='Torch device to train on (e.g. cuda:0, cpu)')
    # One strict parse. The former parse_known_args() result was discarded
    # and immediately overwritten, so it only did redundant work.
    return parser.parse_args(argv)
def create_folders(path="./output/"):
    """Ensure the training output directory exists.

    Args:
        path: Directory to create. Defaults to ``./output/`` (relative to the
            current working directory), the directory ``train`` saves its
            ``.pth`` checkpoints into.
    """
    # makedirs(exist_ok=True) avoids the check-then-create race of an
    # exists()/mkdir() pair, and also creates any missing parent directories.
    os.makedirs(path, exist_ok=True)
def _train_one_epoch(model, loader, criterion, optimizer, device):
    """Run one optimisation pass over ``loader``; return the mean batch loss.

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    model.train()
    total_loss = 0.0
    num_batches = 0
    for minibatch in loader:
        # NOTE(review): assumes each minibatch is a sequence of rank-3
        # channels-last tensors; they are stacked to (N, H, W, C) and permuted
        # to (N, C, H, W) for the model -- confirm against src.DataLoader.
        hilbert_map = torch.stack(minibatch).permute(0, 3, 1, 2)
        hilbert_map = hilbert_map.to(device, dtype=torch.float32)
        # ===================forward=====================
        output = model(hilbert_map)
        loss = criterion(output, hilbert_map)
        # ===================backward====================
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # .item() yields a plain float. Accumulating the tensor itself kept
        # every batch's autograd graph alive for the whole epoch.
        total_loss += loss.item()
        num_batches += 1
    if num_batches == 0:
        raise ValueError("training data loader produced no batches")
    # Divide by the real batch count (the old `/ idx` was off by one and
    # crashed with a single batch).
    return total_loss / num_batches


def train(args):
    """Train the Hilbert autoencoder as a reconstruction (MSE) model.

    Checkpoints are written under ``./output/``:
    ``HILBERT_AE_<epoch>.pth`` every 10 epochs, and ``HILBERT_AE_best.pth``
    whenever the epoch loss matches or beats the best seen so far. Training
    stops early once the loss fails to improve for more than
    ``args.early_stop`` consecutive epochs.

    Args:
        args: Namespace with ``nc``, ``ld``, ``checkpoint``, ``lr``,
            ``weight_decay``, ``hdf5_file``, ``batch_size``, ``epochs``,
            ``early_stop`` and optionally ``device`` (default ``'cuda:1'``).

    Returns:
        List of mean training losses, one per completed epoch.
    """
    # getattr keeps Namespaces built without a `device` field working.
    device = torch.device(getattr(args, 'device', 'cuda:1'))
    model = autoencoder(args.nc, args.ld).to(device)
    checkpoint = args.checkpoint
    if checkpoint is not None and os.path.exists(checkpoint):
        # map_location lets a checkpoint saved on another device load here.
        model.load_state_dict(torch.load(checkpoint, map_location=device))
    criterion = nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=args.lr,
                                 weight_decay=args.weight_decay)
    train_loader = contruct_dataloader_from_disk(args.hdf5_file,
                                                 args.batch_size)
    num_epochs = args.epochs
    early_stop_limit = args.early_stop
    early_stop_count = 0
    best_loss = float('inf')
    train_loss = []
    create_folders()
    best_path = "./output/HILBERT_AE_best.pth"
    for epoch in range(num_epochs):
        epoch_loss = _train_one_epoch(model, train_loader, criterion,
                                      optimizer, device)
        train_loss.append(epoch_loss)
        # ===================log========================
        print('epoch [{}/{}], loss:{:.4f}'.format(epoch + 1, num_epochs,
                                                  epoch_loss))
        if epoch % 10 == 0:
            torch.save(model.state_dict(),
                       "./output/HILBERT_AE_{}.pth".format(epoch))
        # Save from the very first epoch, so the best checkpoint always
        # exists even if the minimum occurs in the first two epochs.
        if epoch_loss <= best_loss:
            best_loss = epoch_loss
            torch.save(model.state_dict(), best_path)
            early_stop_count = 0
        else:
            early_stop_count += 1
            if early_stop_count > early_stop_limit:
                break
    print("AutoEncoder was trained !!")
    return train_loss
if __name__ == '__main__':
    # Script entry point: read the CLI configuration and hand it to train().
    train(get_args())