-
Notifications
You must be signed in to change notification settings - Fork 0
/
my_model.py
56 lines (47 loc) · 2.21 KB
/
my_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import tensorflow as tf
def get_my_model(conv):
    """Build the convolutional network and return a 144-unit linear output.

    The layer stack is driven by ``MODEL_LAYERS`` / ``model_arg`` below. The
    first letter of each layer name selects the layer type ('c' conv,
    'r' residual, 'm' max-pool, 'd' dropout). The matching tuple supplies
    that layer's arguments. Convolution kernels and pooling windows are
    ``WIDTH_SCALE`` times wider than they are tall.

    Args:
        conv: Input tensor. The first conv is built with ``input_filter=3``,
            so the input must have 3 channels; NHWC layout is assumed
            (TODO confirm with callers).

    Returns:
        The output of a flatten + ``fully_connected(..., 144,
        activation_fn=None)`` layer, i.e. a 2-D tensor whose last dimension
        is 144 with no activation applied (presumably logits).

    Raises:
        ValueError: if a layer name does not start with a known type letter.
    """
    # Layer name -> argument tuple. The first letter selects the layer type:
    #   'c' conv:     (input_filter, output_filter, kernel, stride)
    #   'r' residual: (filters, kernel, stride)
    #   'm' max-pool: (ksize, stride)
    #   'd' dropout:  (no arguments used)
    # NOTE(review): 'maxpool1' is listed twice. TF1 will presumably uniquify
    # the second op name scope. Names are left unchanged so existing graph
    # tensor names stay stable.
    MODEL_LAYERS = [
        'conv1', 'maxpool1', 'conv2', 'maxpool1', 'conv3', 'maxpool2'
    ]
    model_arg = [
        (3, 32, 3, 1), (2, 2), (32, 64, 3, 1), (2, 2), (64, 144, 3, 2), (2, 1)
    ]
    # Width multiplier shared by conv kernels and pooling windows, so the two
    # cannot drift apart (previously 4 was hard-coded separately in each).
    WIDTH_SCALE = 4

    def instance_norm(x):
        """Normalize x to zero mean / unit variance over axes 1 and 2."""
        epsilon = 1e-9
        # Moments over axes 1 and 2 (the spatial axes if x is NHWC), kept
        # broadcastable against x. `keep_dims` is the TF1 argument spelling.
        mean, var = tf.nn.moments(x, [1, 2], keep_dims=True)
        return tf.div(tf.subtract(x, mean), tf.sqrt(tf.add(var, epsilon)))

    def relu(layer):
        """Apply ReLU."""
        return tf.nn.relu(layer)

    def max_pool(layer, ksize, stride, scale=WIDTH_SCALE):
        """SAME-padded max-pool with a ``ksize x (ksize*scale)`` window."""
        return tf.nn.max_pool(layer,
                              ksize=[1, ksize, ksize * scale, 1],
                              strides=[1, stride, stride, 1],
                              padding='SAME')

    def conv2d(x, input_filter, output_filter, kernel, strides,
               scale=WIDTH_SCALE):
        """SAME-padded, bias-free conv with a new ``kernel x (kernel*scale)`` weight."""
        with tf.variable_scope('conv2d'):
            shape = [kernel, kernel * scale, input_filter, output_filter]
            weight = tf.Variable(tf.truncated_normal(shape, stddev=0.1),
                                 name='weight')
            return tf.nn.conv2d(x, filter=weight,
                                strides=[1, strides, strides, 1],
                                padding='SAME', name='conv')

    def reslayer(x, filters, kernel, strides):
        """Residual block: conv -> relu -> conv, added to a shortcut of x.

        ``x`` must already have ``filters`` channels. Only the first conv is
        strided. When ``strides != 1`` the shortcut is a strided 1x1
        projection, so both paths have the same spatial size.
        """
        with tf.variable_scope('resnet'):
            conv1 = conv2d(x, filters, filters, kernel, strides)
            # Stride 1 here: striding both convs would shrink the main path
            # twice while the shortcut shrinks at most once.
            conv2 = conv2d(relu(conv1), filters, filters, kernel, 1)
            shortcut = x
            if strides != 1:
                with tf.variable_scope('shortcut'):
                    shortcut = conv2d(x, filters, filters, 1, strides, scale=1)
            return shortcut + conv2

    for name, args in zip(MODEL_LAYERS, model_arg):
        if name.startswith('c'):
            with tf.variable_scope(name):
                input_filter, output_filter, kernel, stride = args
                conv = relu(instance_norm(
                    conv2d(conv, input_filter, output_filter, kernel, stride,
                           scale=WIDTH_SCALE)))
        elif name.startswith('r'):
            with tf.variable_scope(name):
                filters, kernel, stride = args
                conv = relu(instance_norm(
                    reslayer(conv, filters, kernel, stride)))
        elif name.startswith('m'):
            with tf.variable_scope(name):
                ksize, stride = args
                conv = max_pool(conv, ksize, stride)
        elif name.startswith('d'):
            with tf.variable_scope(name):
                # NOTE(review): there is no training flag, so this dropout
                # would also be active at inference. No 'd' layer is
                # configured above.
                conv = tf.nn.dropout(conv, keep_prob=0.75)
        else:
            # Previously an unrecognized name was silently skipped.
            raise ValueError('unknown layer type for layer %r' % (name,))

    conv = tf.contrib.layers.flatten(conv)
    conv = tf.contrib.layers.fully_connected(conv, 144, activation_fn=None)
    return conv