-
Notifications
You must be signed in to change notification settings - Fork 0
/
utils.py
82 lines (68 loc) · 2.6 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import torch
from tensorboardX import SummaryWriter
class TensorboardLogger(object):
    """Scalar logger backed by a tensorboardX ``SummaryWriter``.

    Keeps an internal ``step`` counter that is used as the x-axis value for
    every scalar written without an explicit step.
    """

    def __init__(self, log_dir):
        """Open a ``SummaryWriter`` on ``log_dir`` and reset the step to 0."""
        self.writer = SummaryWriter(logdir=log_dir)
        self.step = 0

    def set_step(self, step=None):
        """Set the internal step to ``step``, or advance it by one if omitted."""
        if step is None:
            self.step += 1
            return
        self.step = step

    def update(self, head='scalar', step=None, **kwargs):
        """Write each keyword argument as a scalar tagged ``<head>/<name>``.

        Entries whose value is ``None`` are skipped. ``torch.Tensor`` values
        are converted with ``.item()``; every value must then be a ``float``
        or ``int``. When ``step`` is ``None`` the internal counter is used.
        """
        for name, raw in kwargs.items():
            if raw is None:
                continue
            scalar = raw.item() if isinstance(raw, torch.Tensor) else raw
            assert isinstance(scalar, (float, int))
            at_step = step if step is not None else self.step
            self.writer.add_scalar("/".join((head, name)), scalar, at_step)

    def flush(self):
        """Flush the underlying writer."""
        self.writer.flush()
class WandbLogger(object):
    """Logs epoch-level metrics and checkpoint artifacts to Weights & Biases.

    ``wandb`` is imported inside ``__init__`` rather than at module level, and
    the imported module is kept on ``self._wandb`` for all later calls.
    """

    def __init__(self, args):
        """Store ``args`` and start a W&B run unless one is already active.

        Args:
            args: Configuration object. ``args.project`` and ``args.model``
                are used as the run's project and name, and the whole object
                is passed as the run config. ``args.output_dir`` is read
                later by ``log_checkpoints``.

        Raises:
            ImportError: If ``wandb`` cannot be imported.
        """
        self.args = args

        try:
            import wandb
        except ImportError as err:
            # The two sentences need a separating space; adjacent string
            # literals are concatenated verbatim.
            raise ImportError(
                "To use the Weights and Biases Logger please install wandb. "
                "Run `pip install wandb` to install it."
            ) from err
        self._wandb = wandb

        # Initialize a W&B run only if none is active yet.
        if self._wandb.run is None:
            self._wandb.init(
                project=args.project,
                name=args.model,
                config=args,
            )

    def log_epoch_metrics(self, metrics, commit=True):
        """Log train/test metrics onto W&B.

        ``metrics`` is only read, never modified. Keys containing ``'train'``
        are logged under ``Global Train/``, otherwise keys containing
        ``'test'`` are logged under ``Global Test/``; other keys are ignored.
        ``'epoch'`` (when present) is logged as-is, and ``'n_parameters'``
        (when present and not ``None``) is written to the run summary.

        Args:
            metrics: Mapping of metric name to value.
            commit: Forwarded to ``wandb.log``; ``True`` (default) commits
                the logged values immediately.
        """
        # Record the parameter count in the run summary, without clobbering
        # an earlier value when this call does not carry one.
        n_parameters = metrics.get('n_parameters')
        if n_parameters is not None:
            self._wandb.summary['n_parameters'] = n_parameters

        payload = {}
        if 'epoch' in metrics:
            payload['epoch'] = metrics['epoch']

        for key, value in metrics.items():
            if key in ('epoch', 'n_parameters'):
                continue
            if 'train' in key:
                payload[f'Global Train/{key}'] = value
            elif 'test' in key:
                payload[f'Global Test/{key}'] = value

        # One call carries everything for this epoch.
        self._wandb.log(payload, commit=commit)

    def log_checkpoints(self):
        """Upload ``args.output_dir`` as a model artifact.

        The artifact is named ``<run id>_model`` and is logged with the
        aliases ``latest`` and ``best``.
        """
        output_dir = self.args.output_dir
        model_artifact = self._wandb.Artifact(
            self._wandb.run.id + "_model", type="model"
        )
        model_artifact.add_dir(output_dir)
        self._wandb.log_artifact(model_artifact, aliases=["latest", "best"])

    def set_steps(self):
        """Declare the step metric for each metric namespace."""
        # Set global training step
        self._wandb.define_metric(
            'Rank-0 Batch Wise/*',
            step_metric='Rank-0 Batch Wise/global_train_step',
        )
        # Set epoch-wise step
        self._wandb.define_metric('Global Train/*', step_metric='epoch')
        self._wandb.define_metric('Global Test/*', step_metric='epoch')