-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathtrain-fcn-unet.py
executable file
·84 lines (72 loc) · 3.11 KB
/
train-fcn-unet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
#!/usr/bin/env python3
import tensorflow as tf
import aardvark
# Command-line flags for this script. FLAGS is shared with the aardvark
# framework; flags read below but not defined here (e.g. clip_stride) are
# presumably defined by aardvark — confirm there.
flags = tf.app.flags
FLAGS = tf.app.flags.FLAGS
# Scale of the L2 kernel regularizer applied to every conv layer in myunet.
flags.DEFINE_float('re_weight', 1e-4, 'regularization weight')
class Model(aardvark.SegmentationModel):
    """Segmentation model whose logits come from the `myunet` backbone.

    Only `inference` is provided here; everything else is inherited from
    aardvark.SegmentationModel.
    """

    def inference(self, images, classes, is_training):
        """Map a batch of images to unactivated per-class scores.

        Args:
            images: input image tensor. NOTE(review): 127 is subtracted to
                roughly center the pixels, which assumes values in [0, 255]
                — confirm against the aardvark input pipeline.
            classes: number of output channels (one per class).
            is_training: forwarded to batch normalization's `training`
                argument inside the backbone.

        Returns:
            A tensor with `classes` channels and the same spatial size as
            the backbone output (stride-1, 'SAME' transposed convolution,
            no activation).

        Raises:
            ValueError: if FLAGS.clip_stride is not a multiple of the
                backbone's total downsampling factor.
        """
        # Use the tensors passed in rather than self.images /
        # self.is_training, so the method honours its own arguments.
        self.backbone, backbone_stride = myunet(images - 127.0, is_training)
        # Explicit check instead of `assert`, which is stripped under -O.
        if FLAGS.clip_stride % backbone_stride != 0:
            raise ValueError(
                'FLAGS.clip_stride ({}) must be a multiple of the backbone '
                'stride ({})'.format(FLAGS.clip_stride, backbone_stride))
        return tf.layers.conv2d_transpose(self.backbone, classes, 3, 1,
                                          activation=None, padding='SAME')
def myunet(X, is_training, batch_norm=True):
    """Build a small U-Net style encoder/decoder.

    Encoder:
    - One stride-2 convolution plus three 2x2 max-pools, for a total
      downsampling factor of 16.

    Decoder:
    - Four stride-2 transposed convolutions bring the features back to the
      input resolution.
    - After each of the first three, the encoder feature map of matching
      resolution (1/8, 1/4, 1/2) is concatenated on axis 3.

    Args:
        X: input tensor, channels-last (the tf.layers default; skip
            connections are concatenated on axis 3).
        is_training: passed to batch normalization's `training` argument.
        batch_norm: if True (the default, and the original behaviour), each
            conv / transposed conv is followed by batch normalization and
            then ReLU. If False, the layer applies ReLU itself.

    Returns:
        (net, 16): a 16-channel feature map at the input resolution, and the
        encoder's total downsampling factor. Input height/width should be
        multiples of 16. Otherwise the ceil-rounded encoder sizes ('SAME'
        padding) may not match the doubled decoder sizes, and tf.concat
        fails.
    """
    net = X
    stack = []  # encoder outputs saved for the skip connections
    # NOTE: tf.name_scope groups ops only; variables created by tf.layers
    # are not prefixed by it.
    with tf.name_scope('myunet'):
        regularizer = tf.contrib.layers.l2_regularizer(scale=FLAGS.re_weight)

        def _layer_relu(layer_fn, x, channels, filter_size, stride):
            """Apply layer_fn (conv or transposed conv) then ReLU, with optional BN in between."""
            if batch_norm:
                # NOTE(review): the layer keeps its bias even though BN
                # follows; left unchanged to keep variables/checkpoints stable.
                x = layer_fn(x, channels, filter_size, stride, padding='SAME',
                             activation=None, kernel_regularizer=regularizer)
                x = tf.layers.batch_normalization(x, training=is_training)
                return tf.nn.relu(x)
            return layer_fn(x, channels, filter_size, stride, padding='SAME',
                            kernel_regularizer=regularizer,
                            activation=tf.nn.relu)

        def conv2d(x, channels, filter_size, stride):
            """Convolution + (BN) + ReLU."""
            return _layer_relu(tf.layers.conv2d, x, channels, filter_size, stride)

        def max_pool2d(x, filter_size, stride):
            """Max pooling with 'SAME' padding."""
            return tf.layers.max_pooling2d(x, filter_size, stride, padding='SAME')

        def conv2d_transpose(x, channels, filter_size, stride):
            """Transposed convolution + (BN) + ReLU."""
            return _layer_relu(tf.layers.conv2d_transpose, x, channels,
                               filter_size, stride)

        # ---- encoder ----
        net = conv2d(net, 32, 3, 2)
        net = conv2d(net, 32, 3, 1)
        stack.append(net)  # 1/2
        net = conv2d(net, 64, 3, 1)
        net = conv2d(net, 64, 3, 1)
        net = max_pool2d(net, 2, 2)
        stack.append(net)  # 1/4
        net = conv2d(net, 128, 3, 1)
        net = conv2d(net, 128, 3, 1)
        net = max_pool2d(net, 2, 2)
        stack.append(net)  # 1/8
        net = conv2d(net, 256, 3, 1)
        net = conv2d(net, 256, 3, 1)
        net = max_pool2d(net, 2, 2)
        # 1/16 (bottleneck)
        net = conv2d(net, 256, 3, 1)
        net = conv2d(net, 256, 3, 1)

        # ---- decoder ----
        net = conv2d_transpose(net, 128, 5, 2)
        # 1/8
        net = tf.concat([net, stack.pop()], 3)
        net = conv2d_transpose(net, 64, 5, 2)
        # 1/4
        net = tf.concat([net, stack.pop()], 3)
        net = conv2d_transpose(net, 32, 5, 2)
        # 1/2
        net = tf.concat([net, stack.pop()], 3)
        net = conv2d_transpose(net, 16, 5, 2)
        # 1/1
        assert not stack, 'every skip connection must be consumed'
    return net, 16
def main(_):
    """Construct the segmentation model and hand it to aardvark's trainer.

    The positional argument (argv, supplied by tf.app.run) is unused.
    """
    aardvark.train(Model())
if __name__ == '__main__':
    # tf.app.run() parses the command-line flags and then calls main().
    # A Ctrl-C during training is deliberately swallowed, so the script
    # stops quietly instead of printing a traceback.
    try:
        tf.app.run()
    except KeyboardInterrupt:
        pass  # intentional: user-requested stop