-
Notifications
You must be signed in to change notification settings - Fork 14
/
electricity.py
96 lines (77 loc) · 3.15 KB
/
electricity.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
import torch
# Force float64 everywhere for numerical stability of the NILM pipeline.
# NOTE(review): torch.set_default_tensor_type is deprecated in recent torch
# releases; torch.set_default_dtype(torch.float64) is the suggested
# replacement — confirm against the pinned torch version before changing.
torch.set_default_tensor_type(torch.DoubleTensor)
# import argparse
# Wildcard imports pull in get_args/setup_seed (config), the three dataset
# parsers, the ELECTRICITY model, the dataloader, and the Trainer.
from config import *
from UKDALE_Parser import *
from REDD_Parser import *
from Refit_Parser import *
from Electricity_model import *
from NILM_Dataloader import *
from Trainer import *
from time import time
import pickle as pkl
if __name__ == "__main__":
    # Entry point: train the ELECTRICITY NILM model on the configured
    # dataset, evaluate on a held-out house, and pickle the results.
    args = get_args()
    setup_seed(args.seed)  # fix RNG seeds for reproducibility

    # --- Select training houses and build the dataset parser ---
    if args.dataset_code == 'redd_lf':
        args.house_indicies = [2, 3, 4, 5, 6]
        ds_parser = Redd_Parser(args)
    elif args.dataset_code == 'uk_dale':
        args.house_indicies = [1, 3, 4, 5]
        ds_parser = UK_Dale_Parser(args)
    elif args.dataset_code == 'refit':
        args.house_indicies = [2, 3, 16]
        args.sampling = '7s'
        ds_parser = Refit_Parser(args)
    else:
        # Fail fast instead of hitting a NameError on ds_parser below.
        raise ValueError(f"Unknown dataset_code: {args.dataset_code!r}")

    model = ELECTRICITY(args)
    trainer = Trainer(args, ds_parser, model)

    # --- Training loop ---
    start_time = time()
    if args.num_epochs > 0:
        try:
            # Resume from the best checkpoint when one already exists.
            model.load_state_dict(
                torch.load(os.path.join(trainer.export_root, 'best_acc_model.pth'),
                           map_location='cpu'))
            print('Successfully loaded previous model, continue training...')
        except FileNotFoundError:
            print('Failed to load old model, continue training new model...')
        trainer.train()
    training_time = time() - start_time
    print("Total Training Time: " + str(training_time / 60) + " minutes")

    # --- Testing loop: rebuild a parser on a held-out house ---
    args.validation_size = 1.
    # Reuse the training normalization statistics for the test data.
    x_mean = trainer.x_mean.detach().cpu().numpy()
    x_std = trainer.x_std.detach().cpu().numpy()
    stats = (x_mean, x_std)
    if args.dataset_code == 'redd_lf':
        args.house_indicies = [1]
        ds_parser = Redd_Parser(args, stats)
    elif args.dataset_code == 'uk_dale':
        args.house_indicies = [2]
        ds_parser = UK_Dale_Parser(args, stats)
    elif args.dataset_code == 'refit':
        args.house_indicies = [5]
        # NOTE(review): unlike the other two datasets, stats are not passed
        # here — confirm whether Refit_Parser should also receive them.
        ds_parser = Refit_Parser(args)

    dataloader = NILMDataloader(args, ds_parser)
    _, test_loader = dataloader.get_dataloaders()
    mre, mae, acc, prec, recall, f1 = trainer.test(test_loader)
    print('Mean Accuracy:', acc)
    print('Mean F1-Score:', f1)
    print('MAE:', mae)
    print('MRE:', mre)

    # --- Persist run configuration, metrics, and prediction curves ---
    results = {
        'args': args,
        'training_time': training_time / 60,
        'best_epoch': trainer.best_model_epoch,
        'training_loss': trainer.training_loss,
        'val_rel_err': trainer.test_metrics_dict['mre'],
        'val_abs_err': trainer.test_metrics_dict['mae'],
        'val_acc': trainer.test_metrics_dict['acc'],
        'val_precision': trainer.test_metrics_dict['precision'],
        'val_recall': trainer.test_metrics_dict['recall'],
        'val_f1': trainer.test_metrics_dict['f1'],
        'label_curve': trainer.y_curve,
        'e_pred_curve': trainer.y_pred_curve,
        'status_curve': trainer.status_curve,
        's_pred_curve': trainer.s_pred_curve,
    }
    fname = trainer.export_root.joinpath('results.pkl')
    # Context manager closes the handle deterministically (the original
    # `pkl.dump(results, open(fname, "wb"))` leaked the file object).
    with open(fname, 'wb') as f:
        pkl.dump(results, f)