forked from mrahtz/learning-from-human-preferences
-
Notifications
You must be signed in to change notification settings - Fork 0
/
nn_layers.py
46 lines (36 loc) · 1.12 KB
/
nn_layers.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
import tensorflow as tf
"""
Wrappers for TensorFlow's tf.layers that put optional batch normalization in
the right place: after the linear transform but before the nonlinearity.
"""
def conv_layer(x, filters, kernel_size, strides, batchnorm, training, name,
               reuse, activation='relu'):
    """Apply conv2d, optional batch norm, then an optional activation.

    The convolution is built with no activation of its own so that batch
    normalization (when enabled) sits between the convolution and the
    nonlinearity.

    Args:
        x: Input tensor, passed straight to ``tf.layers.conv2d``.
        filters: Number of output filters.
        kernel_size: Kernel size, as accepted by ``tf.layers.conv2d``.
        strides: Strides, as accepted by ``tf.layers.conv2d``.
        batchnorm: If truthy, apply ``tf.layers.batch_normalization`` to the
            convolution output.
        training: Forwarded to ``tf.layers.batch_normalization``; only used
            when ``batchnorm`` is truthy.
        name: Name of the convolution layer. The batch-norm layer, if any,
            is named ``name + "_batchnorm"``.
        reuse: Forwarded to both layers to control variable reuse.
        activation: ``'relu'`` (default) applies a *leaky* ReLU with
            ``alpha=0.01``; ``None`` leaves the output linear.

    Returns:
        The output tensor.

    Raises:
        ValueError: If ``activation`` is neither ``'relu'`` nor ``None``.
            Raised before any layer (and hence any variable) is created.
    """
    # Validate first so a bad argument fails before variables are added to
    # the graph. ValueError subclasses Exception, so callers that caught the
    # previous bare Exception still catch this.
    if activation not in ('relu', None):
        raise ValueError(
            "Unknown activation for conv_layer: {!r}".format(activation))

    # No activation here: batch norm (if enabled) must come before it.
    x = tf.layers.conv2d(
        x,
        filters,
        kernel_size,
        strides,
        activation=None,
        name=name,
        reuse=reuse)

    if batchnorm:
        # NOTE(review): in TF1, batch_normalization registers its moving
        # mean/variance updates in tf.GraphKeys.UPDATE_OPS; the training op
        # must depend on them. Confirm the caller does this.
        x = tf.layers.batch_normalization(
            x, training=training, reuse=reuse, name=name + "_batchnorm")

    if activation == 'relu':
        # 'relu' maps to a leaky ReLU with negative slope 0.01.
        x = tf.nn.leaky_relu(x, alpha=0.01)

    return x
def dense_layer(x,
                units,
                name,
                reuse,
                activation=None):
    """Apply a fully-connected layer followed by an optional activation.

    Args:
        x: Input tensor, passed straight to ``tf.layers.dense``.
        units: Number of output units.
        name: Name of the dense layer.
        reuse: Forwarded to ``tf.layers.dense`` to control variable reuse.
        activation: ``None`` (default) leaves the output linear; ``'relu'``
            applies a *leaky* ReLU with ``alpha=0.01``.

    Returns:
        The output tensor.

    Raises:
        ValueError: If ``activation`` is neither ``None`` nor ``'relu'``.
            Raised before the layer (and hence any variable) is created.
    """
    # Validate first so a bad argument fails before variables are added to
    # the graph. ValueError subclasses Exception, so callers that caught the
    # previous bare Exception still catch this.
    if activation not in (None, 'relu'):
        raise ValueError(
            "Unknown activation for dense_layer: {!r}".format(activation))

    # The activation is applied separately below rather than via
    # tf.layers.dense, so 'relu' can map to a leaky ReLU.
    x = tf.layers.dense(x, units, activation=None, name=name, reuse=reuse)

    if activation == 'relu':
        # 'relu' maps to a leaky ReLU with negative slope 0.01.
        x = tf.nn.leaky_relu(x, alpha=0.01)

    return x