-
Notifications
You must be signed in to change notification settings - Fork 8
/
util.py
84 lines (67 loc) · 3.32 KB
/
util.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
"""Contains custom loss, dice coefficient, and optimizer classes."""
import tensorflow as tf
class DiceVAELoss(object):
    """Implements custom dice-VAE loss.

    The returned loss is

        dice_loss + 0.1 * l2_loss + 0.1 * kld_loss

    where
      * ``dice_loss`` is a soft dice loss with squared terms in the
        denominator and +1 smoothing, computed per channel and then averaged;
      * ``l2_loss`` is the mean squared error between ``x`` and ``y_vae``;
      * ``kld_loss`` is the mean of ``mu**2 + exp(logvar) - logvar - 1``
        over the latent tensors.
    """

    def __init__(self,
                 name='custom_loss',
                 data_format='channels_last',
                 **kwargs):
        """Configure the loss.

        Args:
            name: Name of the loss; stored as ``self.name``.
            data_format: ``'channels_last'`` selects reduction axes
                ``(0, 1, 2, 3)``; any other value selects ``(0, 2, 3, 4)``
                (channels at axis 1). Stored as ``self.data_format``.
            **kwargs: Accepted and ignored.
        """
        self.name = name
        self.data_format = data_format
        # Reduce over every axis except the channel axis so the dice term is
        # computed once per channel. Assumes 5-D tensors
        # (batch + 3 spatial + channel) -- TODO confirm for other ranks.
        self.axis = (0, 1, 2, 3) if data_format == 'channels_last' else (0, 2, 3, 4)

    def __call__(self, x, y, y_pred, y_vae, z_mean, z_logvar, sample_weight=None):
        """Return the combined dice + VAE loss as a scalar tensor.

        Args:
            x: Input that ``y_vae`` reconstructs; cast to ``y_vae``'s dtype.
            y: Ground-truth segmentation with the same shape as ``y_pred``;
                cast to ``y_pred``'s dtype so integer label tensors work.
            y_pred: Predicted segmentation scores.
            y_vae: Reconstruction compared against ``x``.
            z_mean: Latent tensor used as the mean in the KL term.
            z_logvar: Latent tensor used as the log-variance in the KL term.
            sample_weight: Currently unused.

        Returns:
            Scalar loss tensor.
        """
        # Align dtypes so elementwise arithmetic does not fail on mixed types.
        y = tf.cast(y, y_pred.dtype)
        x = tf.cast(x, y_vae.dtype)

        l2_loss = tf.reduce_mean((x - y_vae) ** 2)
        kld_loss = tf.reduce_mean(z_mean ** 2 + tf.math.exp(z_logvar) - z_logvar - 1.0)

        # Soft dice loss per channel; +1 smoothing keeps empty channels finite.
        intersection = tf.reduce_sum(y_pred * y, axis=self.axis)
        pred = tf.reduce_sum(y_pred ** 2, axis=self.axis)
        true = tf.reduce_sum(y ** 2, axis=self.axis)
        dice_loss = tf.reduce_mean(1.0 - (2.0 * intersection + 1.0) / (pred + true + 1.0))

        return dice_loss + 0.1 * l2_loss + 0.1 * kld_loss
class DiceCoefficient(object):
    """Implements dice coefficient for binary classification.

    Scores are computed on hard predictions: every voxel is assigned to its
    argmax channel, and voxels whose highest channel score is not above 0.5
    are assigned to no channel at all. The channel axis is the last axis for
    ``'channels_last'`` and axis 1 for any other ``data_format``.
    """

    def __init__(self,
                 name='dice_coefficient',
                 data_format='channels_last'):
        self.name = name
        self.data_format = data_format

    @staticmethod
    def _reduction_axes(rank, data_format):
        """Return every axis of a rank-``rank`` tensor except its channel axis."""
        if data_format == 'channels_last':
            return tuple(range(rank - 1))
        return (0,) + tuple(range(2, rank))

    def __call__(self, y_true, y_pred):
        """Compute macro- and micro-averaged dice scores.

        Args:
            y_true: One-hot ground truth with the same shape as ``y_pred``;
                cast to float32.
            y_pred: Per-channel prediction scores.

        Returns:
            Tuple ``(macroavg, microavg)``. ``macroavg`` is the mean of the
            per-channel dice scores. ``microavg`` is the dice score of the
            counts pooled over all channels. Both use +1 smoothing, so empty
            prediction and ground truth score 1 rather than NaN.
        """
        onehot_axis = -1 if self.data_format == 'channels_last' else 1
        # Reduce over every non-channel axis so each channel yields exactly
        # one dice term, for any input rank.
        dice_axes = self._reduction_axes(len(y_pred.shape), self.data_format)
        y_true = tf.cast(y_true, tf.float32)

        # Keep only voxels whose highest channel score exceeds 0.5.
        mask = tf.reduce_max(y_pred, axis=onehot_axis, keepdims=True)
        mask = tf.cast(mask > 0.5, tf.float32)

        # Create one-hot encoding of the argmax channel of each voxel.
        out_ch = y_pred.shape[onehot_axis]
        y_pred = tf.argmax(y_pred, axis=onehot_axis, output_type=tf.int32)
        y_pred = tf.one_hot(y_pred, out_ch, axis=onehot_axis, dtype=tf.float32)
        y_pred *= mask

        # Per-channel dice scores, averaged over channels.
        intersection = tf.reduce_sum(y_pred * y_true, axis=dice_axes)
        pred = tf.reduce_sum(y_pred, axis=dice_axes)
        true = tf.reduce_sum(y_true, axis=dice_axes)
        macroavg = tf.reduce_mean((2.0 * intersection + 1.0) / (pred + true + 1.0))

        # Dice over counts pooled across channels. The factor 2 makes this a
        # true dice score in [0, 1].
        total_intersection = tf.reduce_sum(intersection)
        total_pred = tf.reduce_sum(pred)
        total_true = tf.reduce_sum(true)
        microavg = (2.0 * total_intersection + 1.0) / (total_pred + total_true + 1.0)
        return macroavg, microavg
class ScheduledOptim(tf.keras.optimizers.Adam):
    """Adam optimizer that allows for scheduling every epoch.

    Calling the optimizer with an epoch index applies a polynomial decay:

        lr = init_lr * max(1 - epoch / n_epochs, 0) ** 0.9

    The rate therefore reaches 0 at ``epoch == n_epochs`` and stays there for
    later epochs.
    """

    def __init__(self,
                 learning_rate=1e-4,
                 beta_1=0.9,
                 beta_2=0.999,
                 epsilon=1e-7,
                 amsgrad=False,
                 name='Adam',
                 n_epochs=300,
                 **kwargs):
        """Build the Adam optimizer and remember the decay parameters.

        Args:
            learning_rate: Initial learning rate; kept as ``self.init_lr``.
            beta_1, beta_2, epsilon, amsgrad, name, **kwargs: Forwarded
                unchanged to ``tf.keras.optimizers.Adam``.
            n_epochs: Number of epochs over which the rate decays to 0.
        """
        super(ScheduledOptim, self).__init__(
            learning_rate=learning_rate,
            beta_1=beta_1,
            beta_2=beta_2,
            epsilon=epsilon,
            amsgrad=amsgrad,
            name=name,
            **kwargs)
        # Keep the base rate separately: the optimizer's own learning rate is
        # overwritten on every call.
        self.init_lr = tf.constant(learning_rate, dtype=tf.float32)
        self.n_epochs = float(n_epochs)

    def __call__(self, epoch):
        """Set the learning rate for ``epoch`` (a Python number or tensor)."""
        # Cast so integer tensors can be divided by the float epoch count.
        progress = tf.cast(epoch, tf.float32) / self.n_epochs
        # Clamp at 0: a negative base raised to 0.9 would produce NaN once
        # epoch exceeds n_epochs.
        remaining = tf.maximum(1.0 - progress, 0.0)
        new_lr = self.init_lr * remaining ** 0.9
        # Attribute assignment updates the rate on both the legacy
        # OptimizerV2 (via its hyperparameter __setattr__) and newer Keras
        # optimizers (via the learning_rate setter). _set_hyper exists only
        # on the former.
        self.learning_rate = new_lr