-
Notifications
You must be signed in to change notification settings - Fork 52
/
model.py
115 lines (89 loc) · 4.9 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import tensorflow as tf
def conv2d(x, input_filters, output_filters, kernel, strides, mode='REFLECT'):
with tf.variable_scope('conv') as scope:
shape = [kernel, kernel, input_filters, output_filters]
weight = tf.Variable(tf.truncated_normal(shape, stddev=0.1), name='weight')
x_padded = tf.pad(x, [[0,0], [kernel / 2, kernel / 2], [kernel / 2, kernel / 2], [0,0]], mode=mode)
return tf.nn.conv2d(x_padded, weight, strides=[1, strides, strides, 1], padding='VALID', name='conv')
def conv2d_transpose(x, input_filters, output_filters, kernel, strides):
with tf.variable_scope('conv_transpose') as scope:
shape = [kernel, kernel, output_filters, input_filters]
weight = tf.Variable(tf.truncated_normal(shape, stddev=0.1), name='weight')
batch_size = tf.shape(x)[0]
height = tf.shape(x)[1] * strides
width = tf.shape(x)[2] * strides
output_shape = tf.pack([batch_size, height, width, output_filters])
return tf.nn.conv2d_transpose(x, weight, output_shape, strides=[1, strides, strides, 1], name='conv_transpose')
def resize_conv2d(x, input_filters, output_filters, kernel, strides, training):
'''
An alternative to transposed convolution where we first resize, then convolve.
See http://distill.pub/2016/deconv-checkerboard/
For some reason the shape needs to be statically known for gradient propagation
through tf.image.resize_images, but we only know that for fixed image size, so we
plumb through a "training" argument
'''
with tf.variable_scope('conv_transpose') as scope:
height = x.get_shape()[1].value if training else tf.shape(x)[1]
width = x.get_shape()[2].value if training else tf.shape(x)[2]
new_height = height * strides * 2
new_width = width * strides * 2
x_resized = tf.image.resize_images(x, [new_height, new_width], tf.image.ResizeMethod.NEAREST_NEIGHBOR)
shape = [kernel, kernel, input_filters, output_filters]
weight = tf.Variable(tf.truncated_normal(shape, stddev=0.1), name='weight')
return conv2d(x_resized, input_filters, output_filters, kernel, strides)
def instance_norm(x):
epsilon = 1e-9
mean, var = tf.nn.moments(x, [1, 2], keep_dims=True)
return tf.div(tf.sub(x, mean), tf.sqrt(tf.add(var, epsilon)))
def batch_norm(x, size, training, decay=0.999):
beta = tf.Variable(tf.zeros([size]), name='beta')
scale = tf.Variable(tf.ones([size]), name='scale')
pop_mean = tf.Variable(tf.zeros([size]))
pop_var = tf.Variable(tf.ones([size]))
epsilon = 1e-3
batch_mean, batch_var = tf.nn.moments(x, [0, 1, 2])
train_mean = tf.assign(pop_mean, pop_mean * decay + batch_mean * (1 - decay))
train_var = tf.assign(pop_var, pop_var * decay + batch_var * (1 - decay))
def batch_statistics():
with tf.control_dependencies([train_mean, train_var]):
return tf.nn.batch_normalization(x, batch_mean, batch_var, beta, scale, epsilon, name='batch_norm')
def population_statistics():
return tf.nn.batch_normalization(x, pop_mean, pop_var, beta, scale, epsilon, name='batch_norm')
return tf.cond(training, batch_statistics, population_statistics)
def residual(x, filters, kernel, strides):
with tf.variable_scope('residual') as scope:
conv1 = conv2d(x, filters, filters, kernel, strides)
conv2 = conv2d(tf.nn.relu(conv1), filters, filters, kernel, strides)
residual = x + conv2
return residual
def net(image, training):
# Less border effects when padding a little before passing through ..
image = tf.pad(image, [[0,0], [10,10], [10,10],[0,0]], mode='REFLECT')
with tf.variable_scope('conv1'):
conv1 = tf.nn.relu(instance_norm(conv2d(image, 3, 32, 9, 1)))
with tf.variable_scope('conv2'):
conv2 = tf.nn.relu(instance_norm(conv2d(conv1, 32, 64, 3, 2)))
with tf.variable_scope('conv3'):
conv3 = tf.nn.relu(instance_norm(conv2d(conv2, 64, 128, 3, 2)))
with tf.variable_scope('res1'):
res1 = residual(conv3, 128, 3, 1)
with tf.variable_scope('res2'):
res2 = residual(res1, 128, 3, 1)
with tf.variable_scope('res3'):
res3 = residual(res2, 128, 3, 1)
with tf.variable_scope('res4'):
res4 = residual(res3, 128, 3, 1)
with tf.variable_scope('res5'):
res5 = residual(res4, 128, 3, 1)
with tf.variable_scope('deconv1'):
deconv1 = tf.nn.relu(instance_norm(resize_conv2d(res5, 128, 64, 3, 2, training)))
with tf.variable_scope('deconv2'):
deconv2 = tf.nn.relu(instance_norm(resize_conv2d(deconv1, 64, 32, 3, 2, training)))
with tf.variable_scope('deconv3'):
deconv3 = tf.nn.tanh(instance_norm(conv2d(deconv2, 32, 3, 9, 1)))
y = (deconv3+1) * 127.5
# Remove border effect reducing padding.
height = tf.shape(y)[1]
width = tf.shape(y)[2]
y = tf.slice(y, [0, 10, 10, 0], tf.pack([-1, height - 20, width - 20, -1]))
return y