# simpleLogger.py

import collections
import os
import pickle  # used only by the commented-out state dump in save()

import matplotlib.pyplot as plt
import numpy as np

from pytorch_lightning.loggers.logger import Logger
# from pytorch_lightning.loggers import TensorBoardLogger
# from pytorch_lightning.loggers.base import rank_zero_experiment
from pytorch_lightning.utilities import rank_zero_only

class mySimpleLogger(Logger):
    def __init__(self, log_dir, keys=None, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._log_dir = log_dir
        self.keys = ['epoch', 'train_loss_step', 'train_loss_epoch', 'val_loss']
        if keys is not None:  # guard: iterating over the default None would raise a TypeError
            self.keys.extend(keys)
        # defaultdict(list) creates an empty list on first access to any key,
        # so metric values can be appended without prior initialization.
        self.history = collections.defaultdict(list)

    @property
    def name(self):
        return "my_simple_logger"

    @property
    def experiment(self):
        # Return the experiment object associated with this logger.
        return "default"

    @property
    def version(self):
        # Return the experiment version, int or str.
        return "1.0"

    @rank_zero_only
    def log_hyperparams(self, params):
        # params is an argparse.Namespace
        # your code to record hyperparameters goes here
        pass

    @rank_zero_only
    def log_metrics(self, metrics, step):
        # metrics is a dictionary of metric names and values
        for metric_name, metric_value in metrics.items():
            if metric_name not in self.keys:
                continue
            if metric_name != 'epoch':
                self.history[metric_name].append(metric_value)
            else:
                # The epoch number arrives once per logged metric, so only
                # append it when it differs from the last recorded value.
                if (not self.history['epoch']
                        or self.history['epoch'][-1] != metric_value):
                    self.history['epoch'].append(metric_value)

    @rank_zero_only
    def save(self):
        # Optional. Any code necessary to save logger data goes here.
        # If you implement this, remember to call `super().save()`
        # at the start of the method (important for aggregation of metrics).
        super().save()
        # with open(f"{self._log_dir}/my_logger_state.pkl", "wb") as f:
        #     pickle.dump(self.state_dict(), f)

    @property
    def log_dir(self):
        return self._log_dir

    @rank_zero_only
    def finalize(self, status):
        # Optional. Any code that needs to be run after training finishes.
        # Plot train and validation loss on a single figure.
        plt.figure()
        plt.plot(self.history["train_loss_epoch"], label='train')
        plt.plot(self.history["val_loss"], label='val')
        plt.xlabel('epoch')  # both series hold one value per epoch
        plt.title("loss")
        plt.gca().legend()
        plt.tight_layout()
        plt.savefig(os.path.join(self._log_dir, "loss.png"))
        plt.close()
        # One figure per remaining tracked metric.
        for key in self.history.keys():
            if not any(x in key for x in ["train_loss_epoch", "val_loss", "epoch"]):
                plt.figure()
                plt.plot(self.history[key])
                plt.title(key)
                plt.tight_layout()
                plt.savefig(os.path.join(self._log_dir, f"{key}.png"))
                plt.close()
        # Dump every metric history to its own CSV file.
        for key in self.history.keys():
            arr = np.asarray(self.history[key])
            if arr.ndim == 0:  # a scalar history still needs a 1-D array for savetxt
                arr = np.expand_dims(arr, axis=0)
            np.savetxt(os.path.join(self._log_dir, f"{key}.csv"), arr, delimiter=',')
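

# A minimal usage sketch, not part of the original file: the tiny model and
# random data below are illustrative assumptions only. It wires the logger
# into a Trainer; `self.log("train_loss", ..., on_step=True, on_epoch=True)`
# produces the `train_loss_step` / `train_loss_epoch` keys the logger tracks
# by default, and after fit() finishes, finalize() writes loss.png plus one
# CSV per metric into log_dir.
if __name__ == "__main__":
    import torch
    import pytorch_lightning as pl
    from torch.utils.data import DataLoader, TensorDataset

    class LitRegressor(pl.LightningModule):
        # Tiny linear model so the logger has something to record.
        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(8, 1)

        def training_step(self, batch, batch_idx):
            x, y = batch
            loss = torch.nn.functional.mse_loss(self.layer(x), y)
            # Logging on both step and epoch yields train_loss_step
            # and train_loss_epoch.
            self.log("train_loss", loss, on_step=True, on_epoch=True)
            return loss

        def validation_step(self, batch, batch_idx):
            x, y = batch
            # In validation, self.log defaults to on_epoch=True -> "val_loss".
            self.log("val_loss", torch.nn.functional.mse_loss(self.layer(x), y))

        def configure_optimizers(self):
            return torch.optim.SGD(self.parameters(), lr=1e-2)

    data = TensorDataset(torch.randn(64, 8), torch.randn(64, 1))
    loader = DataLoader(data, batch_size=16)

    logger = mySimpleLogger(log_dir=".")
    trainer = pl.Trainer(max_epochs=2, logger=logger, enable_checkpointing=False)
    trainer.fit(LitRegressor(), loader, loader)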