plots.py
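"""Plotting utilities for training metrics.

Reads the per-checkpoint metric_*.json records and flags.json from
FLAGS.model_dir and renders accuracy and loss curves with matplotlib.
(Module docstring added for orientation; it only describes what the code
below does.)
"""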
import os
import glob
import json

import tensorflow.compat.v2 as tf
import pandas as pd
import matplotlib.pyplot as plt
from absl import flags

FLAGS = flags.FLAGS
def print_contrastive_history(training_history, flags):
  """Plot contrastive accuracy and copy the figure to the model directory."""
  con_acc = training_history['train/contrast_acc']
  nrows = training_history.shape[0]
  if flags['checkpoint_epochs'] == 1 and flags['checkpoint_steps'] == 0:
    # One checkpoint per epoch: label the x-axis in epochs.
    xvalues = pd.Series(range(nrows)) + 1
    plt.xlabel('Epochs')
  else:
    xvalues = training_history['global_step']
    plt.xlabel('Steps')
  plt.plot(xvalues, con_acc, color='black', label='Contrastive accuracy')
  plt.title('Accuracy')
  plt.ylabel('Accuracy')
  plt.legend()
  # Save locally first, then copy to the (possibly remote) model directory.
  save_path = os.path.join(FLAGS.tmp_folder, 'Accuracy.jpeg')
  plt.savefig(save_path)
  plt.close()
  if tf.io.gfile.exists(save_path):
    dest_path = os.path.join(FLAGS.model_dir, 'Accuracy.jpeg')
    tf.io.gfile.copy(save_path, dest_path, overwrite=True)
  return
def print_accuracy_history(training_history, flags):
  """Plot train and validation supervised accuracy."""
  acc = training_history['train/supervised_acc']
  val_acc = training_history['eval/label_top_1_accuracy']
  nrows = training_history.shape[0]
  if flags['checkpoint_epochs'] == 1 and flags['checkpoint_steps'] == 0:
    # One checkpoint per epoch: label the x-axis in epochs.
    xvalues = pd.Series(range(nrows)) + 1
    plt.xlabel('Epochs')
  else:
    xvalues = training_history['global_step']
    plt.xlabel('Steps')
  plt.plot(xvalues, acc, color='darkorange', label='Train')
  plt.plot(xvalues, val_acc, color='steelblue', label='Valid')
  plt.title('Accuracy')
  plt.ylabel('Accuracy')
  plt.legend()
  plt.savefig(os.path.join(FLAGS.model_dir, 'Accuracy.jpeg'))
  plt.close()
  return
def print_loss_history(training_history, flags, logscale=False):
  """Plot train and validation supervised loss, optionally on a log scale."""
  val_sup_loss = training_history['eval/supervised_loss']
  train_sup_loss = training_history['train/supervised_loss']
  nrows = training_history.shape[0]
  if flags['checkpoint_epochs'] == 1 and flags['checkpoint_steps'] == 0:
    # One checkpoint per epoch: label the x-axis in epochs.
    xvalues = pd.Series(range(nrows)) + 1
    plt.xlabel('Epochs')
  else:
    xvalues = training_history['global_step']
    plt.xlabel('Steps')
  plt.plot(xvalues, train_sup_loss, color='darkorange', label='Train')
  plt.plot(xvalues, val_sup_loss, color='steelblue', label='Valid')
  plt.title('Supervised Loss')
  plt.ylabel('Loss')
  plt.legend()
  if logscale:
    plt.yscale('log')
  plt.savefig(os.path.join(FLAGS.model_dir, 'Loss.jpeg'))
  plt.close()
  return
def gen_plots():
  """Generate train plots."""

  def create_df(fnames):
    # Load one JSON metrics record per checkpoint into a single DataFrame.
    results = []
    for fname in fnames:
      with tf.io.gfile.GFile(fname, 'r') as f:
        result = json.load(f)
      results.append(result)
    df = pd.DataFrame(results)
    return df

  # Collect the per-checkpoint metric files and order them by training step.
  metric_paths = tf.io.gfile.glob(
      os.path.join(FLAGS.model_dir, 'metric_[0-9]*.json'))
  train_df = create_df(metric_paths)
  train_df.sort_values(by=['global_step'], inplace=True)
  flags_path = os.path.join(FLAGS.model_dir, 'flags.json')
  with tf.io.gfile.GFile(flags_path, 'r') as f:
    flags_dict = json.load(f)
  training = (FLAGS.mode == 'train' or FLAGS.mode == 'train_then_eval')
  if training and FLAGS.train_mode == 'pretrain':
    print_contrastive_history(train_df, flags_dict)