-
Notifications
You must be signed in to change notification settings - Fork 1
/
main.py
executable file
·198 lines (152 loc) · 7.12 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# BSD 3-Clause License (see LICENSE file)
# Copyright (c) Image and Signaling Process Group (ISP) IPL-UV 2021
# All rights reserved.
"""
Main script to execute the Latent Granger autoencoder
"""
import os
import git
import argparse
import yaml
#from uuid import uuid4
from datetime import datetime
import numpy as np
import pytorch_lightning as pl
from pytorch_lightning import loggers as pl_loggers
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from pytorch_lightning.callbacks import LearningRateMonitor
# Model
import archs
import loaders
import torch
#torch.autograd.set_detect_anomaly(True)
def _load_yaml_config(kind, name):
    """Load ``configs/<kind>/<name>.yaml`` and return the parsed content.

    Parameters
    ----------
    kind : str
        Config sub-directory: ``'archs'``, ``'loaders'`` or ``'data'``.
    name : str
        Config file stem, as given on the command line.

    Returns
    -------
    The object produced by ``yaml.load``; ``main`` indexes it like a dict.
    """
    with open(f'configs/{kind}/{name}.yaml') as file:
        # The FullLoader parameter handles the conversion from YAML scalar
        # values to Python values. It is kept (rather than switching to
        # safe_load) so that existing config files parse exactly as before.
        return yaml.load(file, Loader=yaml.FullLoader)


def _to_numpy(tensor):
    """Return *tensor* as a NumPy array, detached and copied to host memory.

    ``.cpu()`` is a no-op for tensors already on the CPU. It is required
    before ``.numpy()`` when training ran on a GPU.
    """
    return tensor.detach().cpu().numpy()


def main(args):
    """Train, evaluate and export latents for one experiment configuration.

    Steps:

    1. Load the arch/loader/data YAML configs.
    2. Optionally check that the git working tree is clean.
    3. Build the model and data module, then train with PyTorch Lightning.
    4. Evaluate the best checkpoint on the validation and test splits and
       append a summary row to ``<args.dir>/log.txt``.
    5. Save the predicted latents and the targets as CSV files under
       ``<args.dir>/latents``.

    Parameters
    ----------
    args : argparse.Namespace
        Parsed command line. It holds this script's options (``arch``,
        ``data``, ``loader``, ``maxlag``, ``gamma``, ``gltype``,
        ``earlystop``, ``dir``, ``seed``, ``nogitcheck``) plus the Trainer
        options consumed by ``pl.Trainer.from_argparse_args``.

    Returns
    -------
    None. Returns early, without training, when the repository has
    uncommitted changes and the user does not confirm.
    """
    arch_config = _load_yaml_config('archs', args.arch)
    loader_config = _load_yaml_config('loaders', args.loader)
    data_config = _load_yaml_config('data', args.data)

    # Experiment ID: results are tagged with the current commit hash, so
    # warn when uncommitted changes would make that association misleading.
    repo = git.Repo(search_parent_directories=True)
    if repo.is_dirty() and not args.nogitcheck:
        answer = input("WARNING: the current repo has not tracked changes\n" +
                       "if you continue, any saved results will" +
                       " not be correctly associated to a commit hash\n" +
                       "type yes(y) to continue anyway: ")
        if answer.strip().lower() not in ('yes', 'y'):
            return
    git_commit_sha = repo.head.object.hexsha[:7]
    # The timestamp doubles as the run identifier: TensorBoard version,
    # checkpoint sub-directory and prefix of the saved CSV files.
    dt = str(datetime.now())

    log_dir = os.path.join(args.dir, 'logs')
    checkpoints_dir = os.path.join(args.dir, 'checkpoints',
                                   args.data, args.arch, dt)
    os.makedirs(args.dir, exist_ok=True)
    pathlogfile = os.path.join(args.dir, 'log.txt')
    savedir = os.path.join(args.dir, 'latents')
    os.makedirs(savedir, exist_ok=True)

    # Build model. 'flat' processing reads 'flat_input_size'; any other
    # mode reads 'input_size' from the data config.
    if arch_config['processing_mode'] == 'flat':
        input_size = data_config['flat_input_size']
    else:
        input_size = tuple(data_config['input_size'])
    model_class = getattr(archs, arch_config['class'])
    if args.seed >= 0:
        print(f'seed set to {args.seed}')
        # Seed before the model is constructed so initialisation is
        # reproducible.
        pl.seed_everything(seed=args.seed)
    model = model_class(arch_config, input_size, data_config['tpb'],
                        args.maxlag, args.gamma, args.gltype)
    print(model)

    # Build data module
    datamodule_class = getattr(loaders, loader_config['class'])
    datamodule = datamodule_class(loader_config, data_config,
                                  arch_config['processing_mode'])

    # Loggers
    tb_logger = pl_loggers.TensorBoardLogger(log_dir,
                                             name=f'{args.arch}_{args.data}',
                                             version=dt)

    # Callbacks: always keep the single best checkpoint by 'val_loss';
    # early stopping on the same metric is opt-in.
    callbacks = [ModelCheckpoint(dirpath=checkpoints_dir,
                                 filename='best',
                                 mode='min', monitor='val_loss',
                                 save_last=False, save_top_k=1)]
    if args.earlystop:
        callbacks.append(EarlyStopping(monitor='val_loss',
                                       min_delta=0.0, patience=10,
                                       verbose=False, mode='min',
                                       strict=True))
    trainer = pl.Trainer.from_argparse_args(args, logger=[tb_logger],
                                            callbacks=callbacks)

    # Training
    trainer.fit(model, datamodule)

    # Evaluate the best checkpoint on validation and test. The metric
    # layout (metric -> split -> tensor) is whatever the model logs.
    val_metrics = trainer.validate(ckpt_path='best')[0]
    mse_val = _to_numpy(val_metrics["mse_loss"]["val"])
    granger_val = _to_numpy(val_metrics["granger_loss_min"]["val"])
    loss_val = _to_numpy(val_metrics["loss"]["val"])

    test_metrics = trainer.test(ckpt_path='best')[0]
    mse_test = _to_numpy(test_metrics["mse_loss"]["test"])
    granger_test = _to_numpy(test_metrics["granger_loss_min"]["test"])
    loss_test = _to_numpy(test_metrics["loss"]["test"])

    # Append one comma-separated summary row per run.
    with open(pathlogfile, "a") as logfile:
        logfile.write(f'{dt},{git_commit_sha},{args.arch},' +
                      f'{args.data},{args.loader},{args.gltype},{args.gamma},{args.maxlag},' +
                      f'{loss_val},{mse_val},{granger_val},' +
                      f'{loss_test},{mse_test},{granger_test}\n')

    # Predict
    pred = trainer.predict(ckpt_path='best')
    # NOTE(review): only the first predict batch is exported. Confirm the
    # predict loader yields the whole set in a single batch.
    _, _, mu, _, causalix = pred[0]
    _, target = datamodule.data_predict.getAll()
    mu = _to_numpy(mu)
    causal_col = int(_to_numpy(causalix))

    # Save the causal latent column, all latents, and the targets.
    np.savetxt(os.path.join(savedir, f'{dt}_causal_latent.csv'),
               mu[:, causal_col])
    np.savetxt(os.path.join(savedir, f'{dt}_all_latents.csv'), mu)
    np.savetxt(os.path.join(savedir, f'{dt}_target.csv'), _to_numpy(target))
def _build_parser():
    """Return the command-line parser for this script.

    The parser combines the ``pl.Trainer`` options (added by
    ``pl.Trainer.add_argparse_args``) with the experiment options read by
    ``main``.
    """
    parser = argparse.ArgumentParser(description="ArgParse")
    parser = pl.Trainer.add_argparse_args(parser)
    parser.add_argument('--arch', default='vae', type=str,
                        help='name of the architecture (default: vae) ' +
                             'associated to a config file ' +
                             'in configs/archs/')
    parser.add_argument('-d', '--data', default='toy', type=str,
                        help='database name (default: toy) associated to a ' +
                             'config file in configs/data/')
    parser.add_argument('--loader', default='base', type=str,
                        help='loaders name (default: base) associated ' +
                             'to a config file in configs/loaders/')
    parser.add_argument('--maxlag', default=1, type=int,
                        help='maxlag (default: 1)')
    parser.add_argument('-g', '--gamma', default=0, type=float,
                        help='gamma regularizer for granger ' +
                             'penalty (default: 0)')
    parser.add_argument('--earlystop', action='store_true',
                        help='whether to use early stopping')
    parser.add_argument('--dir', default="experiment",
                        type=str,
                        help='experiment directory (default: experiment)')
    # main() seeds whenever seed >= 0, so 0 is a valid seed.
    parser.add_argument('--seed', default=-1,
                        type=int,
                        help='random seed, applied only if >= 0 ' +
                             '(default: -1, no seeding)')
    parser.add_argument('--nogitcheck', action='store_true',
                        help='do not check clean git')
    parser.add_argument('--gltype', type=str,
                        default='diff',
                        help='type of granger loss (default: diff)')
    return parser


if __name__ == '__main__':
    main(_build_parser().parse_args())