-
Notifications
You must be signed in to change notification settings - Fork 844
/
tensorpack_extension.py
64 lines (58 loc) · 1.95 KB
/
tensorpack_extension.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
# -*- coding: utf-8 -*-
#!/usr/bin/env python
import re
from tensorpack.utils import logger
from tensorpack.tfutils.gradproc import GradientProcessor
from tensorpack.callbacks.monitor import JSONWriter
import tensorflow as tf
# class AudioWriter(TrainingMonitor):
# """
# Write summaries to TensorFlow event file.
# """
# def __new__(cls):
# if logger.get_logger_dir():
# return super(TFEventWriter, cls).__new__(cls)
# else:
# logger.warn("logger directory was not set. Ignore TFEventWriter.")
# return NoOpMonitor()
#
# def _setup_graph(self):
# self._writer = tf.summary.FileWriter(logger.get_logger_dir(), graph=tf.get_default_graph())
#
# def process_summary(self, summary):
# self._writer.add_summary(summary, self.global_step)
#
# def process_event(self, evt):
# self._writer.add_event(evt)
#
# def _trigger(self): # flush every epoch
# self._writer.flush()
#
# def _after_train(self):
# self._writer.close()
#
class FilterGradientVariables(GradientProcessor):
    """
    Skip the update of certain variables and print a warning.

    Only (gradient, variable) pairs whose ``var.op.name`` matches
    ``var_regex`` are kept. Matching uses ``re.match``, so it is anchored at
    the start of the name only, not at the end. Every other pair is dropped
    from the returned list, so those variables get no update from it.
    """

    def __init__(self, var_regex='.*', verbose=True):
        """
        Args:
            var_regex (string or compiled pattern): regular expression matched
                against each variable's op name; matching variables are
                updated, all others are skipped.
            verbose (bool): whether to log a warning listing the variables
                that were skipped by the filter.

        Raises:
            re.error: if ``var_regex`` is not a valid regular expression.
        """
        super(FilterGradientVariables, self).__init__()
        # Keep the raw pattern for backward compatibility with code that
        # inspects it.
        self._regex = var_regex
        # Compile once rather than re-resolving the pattern for every
        # variable on every call; an invalid pattern now fails here, early.
        self._pattern = re.compile(var_regex)
        self._verbose = verbose

    def _process(self, grads):
        """
        Filter a list of (gradient, variable) pairs by variable name.

        Args:
            grads: iterable of (gradient, variable) pairs; each variable must
                expose ``var.op.name``.

        Returns:
            list: the (gradient, variable) pairs whose variable name matches,
            in their original order.
        """
        kept = []
        skipped = []
        for grad, var in grads:
            if self._pattern.match(var.op.name):
                kept.append((grad, var))
            else:
                skipped.append(var.op.name)
        if self._verbose and skipped:
            message = ', '.join(skipped)
            logger.warn("No gradient w.r.t these trainable variables: {}".format(message))
        return kept