forked from uzh-rpg/rpg_public_dronet
-
Notifications
You must be signed in to change notification settings - Fork 1
/
log_utils.py
54 lines (39 loc) · 1.81 KB
/
log_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
import logz
import numpy as np
import keras
from keras import backend as K
class MyCallback(keras.callbacks.Callback):
    """
    Customized callback class.

    At the beginning of every epoch it loads the value 1.0 into
    `model.alpha`. At the end of every epoch it logs training/validation
    loss and accuracy through `logz`, saves the model weights every
    `period` epochs, and updates `model.k_entropy` according to a schedule
    that depends on the epoch index and `batch_size`.

    # Arguments
       filepath: Path to save model.
       period: Frequency in epochs with which model is saved.
       batch_size: Number of images per batch.
    """

    def __init__(self, filepath, period, batch_size):
        # Initialise the Keras base class so its own attributes exist.
        super(MyCallback, self).__init__()
        self.filepath = filepath
        self.period = period
        self.batch_size = batch_size

    def on_epoch_begin(self, epoch, logs=None):
        """
        Load a constant weight of 1.0 into `model.alpha`.

        # Arguments
           epoch: Index of the epoch that is about to start (unused here).
           logs: Metrics dictionary passed by Keras (unused here).
        """
        sess = K.get_session()
        # Disabled: epoch-dependent schedule for `model.beta`.
        # self.model.beta.load(np.maximum(0.0, 1.0-np.exp(-1.0/10.0*(epoch-10))), sess)
        self.model.alpha.load(1.0, sess)

    def on_epoch_end(self, epoch, logs=None):
        """
        Log metrics, periodically save weights and update `k_entropy`.

        # Arguments
           epoch: Index of the epoch that just finished.
           logs: Dictionary of metrics for this epoch. Missing keys are
               logged as None; None is treated as an empty dictionary.
        """
        # Avoid a shared mutable default and tolerate an explicit None.
        if logs is None:
            logs = {}

        # Save training and validation losses
        logz.log_tabular('train_loss', logs.get('loss'))
        logz.log_tabular('val_loss', logs.get('val_loss'))
        logz.log_tabular('train_accuracy', logs.get('categorical_accuracy'))
        logz.log_tabular('val_accuracy',
                         logs.get('val_categorical_accuracy'))
        logz.dump_tabular()

        # Save model every 'period' epochs. The file name carries the
        # 0-based epoch index received from Keras.
        if (epoch + 1) % self.period == 0:
            filename = self.filepath + '/model_weights_' + str(epoch) + '.h5'
            self.model.save_weights(filename, overwrite=True)
            # Report only after the weights have actually been written.
            print("Saved model at {}".format(filename))

        # Hard mining: k stays at batch_size while epoch <= 30 (the
        # np.maximum term is 0 there) and then decreases towards 10 as the
        # exponential term vanishes.
        sess = K.get_session()
        decay = np.maximum(0.0, 1.0 - np.exp(-1.0 / 30.0 * (epoch - 30.0)))
        entropy_function = self.batch_size - (self.batch_size - 10) * decay
        # Disabled: the same schedule applied to `model.k_mse`.
        # mse_function = self.batch_size-(self.batch_size-10)*(np.maximum(0.0,1.0-np.exp(-1.0/30.0*(epoch-30.0))))
        # self.model.k_mse.load(int(np.round(mse_function)), sess)
        self.model.k_entropy.load(int(np.round(entropy_function)), sess)