-
Notifications
You must be signed in to change notification settings - Fork 10
/
autoclip_tf.py
55 lines (44 loc) · 2.05 KB
/
autoclip_tf.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import tensorflow as tf
import tensorflow_probability as tfp
class AutoClipper:
    """Gradient transformer that clips to a percentile of past gradient norms.

    Each call records the combined L2 norm of the incoming gradients in a
    fixed-size circular buffer, then clips every gradient with
    ``tf.clip_by_norm`` using the ``clip_percentile``-th percentile of the
    norms recorded so far as the threshold.

    Instances are callables that take and return a list of
    ``(gradient, variable)`` pairs.
    """

    def __init__(self, clip_percentile, history_size=10000):
        """Set up the norm history.

        Args:
            clip_percentile: Passed as ``q`` to ``tfp.stats.percentile`` to
                pick the clipping threshold from the recorded norms.
            history_size: Number of slots in the norm history buffer.
        """
        self.clip_percentile = clip_percentile
        # Circular buffer of observed total gradient norms.
        self.grad_history = tf.Variable(tf.zeros(history_size), trainable=False)
        # Number of times the clipper has been called.
        self.i = tf.Variable(0, trainable=False)
        self.history_size = history_size

    def __call__(self, grads_and_vars):
        """Record the current total gradient norm and clip the gradients.

        Args:
            grads_and_vars: Iterable of ``(gradient, variable)`` pairs. A
                gradient of ``None`` is ignored for the norm and passed
                through unchanged.

        Returns:
            A list of ``(clipped_gradient, variable)`` pairs in input order.
        """
        # Materialize once: the pairs are traversed twice below, and a
        # one-shot iterable (e.g. ``zip``) would be empty the second time.
        grads_and_vars = list(grads_and_vars)

        grad_norms = [
            self._get_grad_norm(g) for g, _ in grads_and_vars if g is not None
        ]
        # Norm of the per-tensor norms, i.e. the norm over all gradients.
        total_norm = tf.norm(grad_norms)

        # Overwrite the oldest slot once the buffer has wrapped around.
        slot = tf.math.mod(self.i, self.history_size)

        # Keep the results of the assignments in locals and read through
        # them, so the percentile below is computed from the updated state
        # while ``self.grad_history`` and ``self.i`` stay bound to the
        # variables created in ``__init__``.
        history = self.grad_history[slot].assign(total_norm)
        count = self.i.assign_add(1)

        # Only the first ``count`` slots are considered; once ``count``
        # exceeds the buffer length the slice is expected to cover the whole
        # buffer (TODO confirm out-of-range slice ends clamp as in Python).
        clip_value = tfp.stats.percentile(history[:count], q=self.clip_percentile)

        # NOTE(review): the history tracks the norm over *all* gradients, but
        # clipping is applied per tensor. Confirm whether
        # ``tf.clip_by_global_norm`` was intended here.
        return [
            (g if g is None else tf.clip_by_norm(g, clip_value), v)
            for g, v in grads_and_vars
        ]

    def _get_grad_norm(self, t, axes=None, name=None):
        """Return the L2 norm of ``t``.

        Args:
            t: A tensor-like value or ``tf.IndexedSlices``; for the latter
                only ``t.values`` is used.
            axes: Axes to reduce over when summing squares; ``None`` reduces
                over all of them.
            name: Unused; kept so the signature stays unchanged.

        Returns:
            The norm with size-1 dimensions squeezed away.
        """
        values = tf.convert_to_tensor(
            t.values if isinstance(t, tf.IndexedSlices) else t, name="t"
        )
        # Sum of squares; the square root is taken below.
        l2sum = tf.math.reduce_sum(values * values, axes, keepdims=True)
        pred = l2sum > 0
        # Two-tap tf.where trick to bypass NaN gradients: sqrt is only ever
        # evaluated on strictly positive inputs, and a zero sum maps to zero.
        l2sum_safe = tf.where(pred, l2sum, tf.ones_like(l2sum))
        return tf.squeeze(tf.where(pred, tf.math.sqrt(l2sum_safe), l2sum))
if __name__ == "__main__":
    # Demo: fit a small dense network on MNIST with an Adam optimizer whose
    # gradients are passed through AutoClipper(10).
    keras_layers = tf.keras.layers
    layer_stack = [
        keras_layers.Flatten(input_shape=(28, 28)),
        keras_layers.Dense(128, activation="relu"),
        keras_layers.Dropout(0.2),
        keras_layers.Dense(10),
    ]
    model = tf.keras.models.Sequential(layer_stack)

    clipper = AutoClipper(10)
    adam = tf.keras.optimizers.Adam(
        learning_rate=0.001,
        gradient_transformers=[clipper],
    )
    model.compile(
        optimizer=adam,
        loss="mean_absolute_error",
        metrics=["accuracy"],
    )

    train_split, test_split = tf.keras.datasets.mnist.load_data()
    x_train, y_train = train_split
    x_test, y_test = test_split

    # Divide the image arrays by 255.0 before training.
    x_train = x_train / 255.0
    x_test = x_test / 255.0

    model.fit(x_train, y_train)