Skip to content

Commit cf65784

Browse files
Clean up in training loop
1 parent 78333e9 commit cf65784

File tree

21 files changed

+246
-582
lines changed

21 files changed

+246
-582
lines changed

README.md

-14
Original file line numberDiff line numberDiff line change
@@ -254,20 +254,6 @@ $ python3 gan.py
254254
<img src="http://eriklindernoren.se/images/gan_mnist5.gif" width="640"\>
255255
</p>
256256

257-
GAN on RGB face images
258-
[Code](gan/gan_rgb.py)
259-
260-
#### Example
261-
```
262-
$ cd gan/
263-
<follow steps at the top of gan_rgb.py>
264-
$ python3 gan_rgb.py
265-
```
266-
267-
<p align="center">
268-
<img src="gan/etc/adam.gif" width="640"\>
269-
</p>
270-
271257
### InfoGAN
272258
Implementation of _InfoGAN: Interpretable Representation Learning by Information Maximizing Generative Adversarial Nets_.
273259

aae/aae.py

+19-23
Original file line numberDiff line numberDiff line change
@@ -74,20 +74,20 @@ def build_encoder(self):
7474
return Model(img, latent_repr)
7575

7676
def build_decoder(self):
77-
# Decoder
78-
decoder = Sequential()
7977

80-
decoder.add(Dense(512, input_dim=self.latent_dim))
81-
decoder.add(LeakyReLU(alpha=0.2))
82-
decoder.add(Dense(512))
83-
decoder.add(LeakyReLU(alpha=0.2))
84-
decoder.add(Dense(np.prod(self.img_shape), activation='tanh'))
85-
decoder.add(Reshape(self.img_shape))
78+
model = Sequential()
79+
80+
model.add(Dense(512, input_dim=self.latent_dim))
81+
model.add(LeakyReLU(alpha=0.2))
82+
model.add(Dense(512))
83+
model.add(LeakyReLU(alpha=0.2))
84+
model.add(Dense(np.prod(self.img_shape), activation='tanh'))
85+
model.add(Reshape(self.img_shape))
8686

87-
decoder.summary()
87+
model.summary()
8888

8989
z = Input(shape=(self.latent_dim,))
90-
img = decoder(z)
90+
img = model(z)
9191

9292
return Model(z, img)
9393

@@ -116,38 +116,34 @@ def train(self, epochs, batch_size=128, sample_interval=50):
116116
X_train = (X_train.astype(np.float32) - 127.5) / 127.5
117117
X_train = np.expand_dims(X_train, axis=3)
118118

119-
half_batch = int(batch_size / 2)
119+
# Adversarial ground truths
120+
valid = np.ones((batch_size, 1))
121+
fake = np.zeros((batch_size, 1))
120122

121123
for epoch in range(epochs):
122124

123-
124125
# ---------------------
125126
# Train Discriminator
126127
# ---------------------
127128

128-
# Select a random half batch of images
129-
idx = np.random.randint(0, X_train.shape[0], half_batch)
129+
# Select a random batch of images
130+
idx = np.random.randint(0, X_train.shape[0], batch_size)
130131
imgs = X_train[idx]
131132

132133
latent_fake = self.encoder.predict(imgs)
133-
latent_real = np.random.normal(size=(half_batch, self.latent_dim))
134+
latent_real = np.random.normal(size=(batch_size, self.latent_dim))
134135

135136
# Train the discriminator
136-
d_loss_real = self.discriminator.train_on_batch(latent_real, np.ones((half_batch, 1)))
137-
d_loss_fake = self.discriminator.train_on_batch(latent_fake, np.zeros((half_batch, 1)))
137+
d_loss_real = self.discriminator.train_on_batch(latent_real, valid)
138+
d_loss_fake = self.discriminator.train_on_batch(latent_fake, fake)
138139
d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)
139140

140-
141141
# ---------------------
142142
# Train Generator
143143
# ---------------------
144144

145-
# Select a random half batch of images
146-
idx = np.random.randint(0, X_train.shape[0], batch_size)
147-
imgs = X_train[idx]
148-
149145
# Train the generator
150-
g_loss = self.adversarial_autoencoder.train_on_batch(imgs, [imgs, np.ones((batch_size, 1))])
146+
g_loss = self.adversarial_autoencoder.train_on_batch(imgs, [imgs, valid])
151147

152148
# Plot the progress
153149
print ("%d [D loss: %f, acc: %.2f%%] [G loss: %f, mse: %f]" % (epoch, d_loss[0], 100*d_loss[1], g_loss[0], g_loss[1]))

acgan/acgan.py

+14-27
Original file line numberDiff line numberDiff line change
@@ -47,8 +47,8 @@ def __init__(self):
4747
# and the label of that image
4848
valid, target_label = self.discriminator(img)
4949

50-
# The combined model (stacked generator and discriminator) takes
51-
# noise as input => generates images => determines validity
50+
# The combined model (stacked generator and discriminator)
51+
# Trains the generator to fool the discriminator
5252
self.combined = Model([noise, label], [valid, target_label])
5353
self.combined.compile(loss=losses,
5454
optimizer=optimizer)
@@ -75,11 +75,9 @@ def build_generator(self):
7575

7676
noise = Input(shape=(self.latent_dim,))
7777
label = Input(shape=(1,), dtype='int32')
78-
7978
label_embedding = Flatten()(Embedding(self.num_classes, 100)(label))
8079

8180
model_input = multiply([noise, label_embedding])
82-
8381
img = model(model_input)
8482

8583
return Model([noise, label], img)
@@ -123,38 +121,38 @@ def train(self, epochs, batch_size=128, sample_interval=50):
123121
# Load the dataset
124122
(X_train, y_train), (_, _) = mnist.load_data()
125123

126-
# Rescale -1 to 1
124+
# Configure inputs
127125
X_train = (X_train.astype(np.float32) - 127.5) / 127.5
128126
X_train = np.expand_dims(X_train, axis=3)
129127
y_train = y_train.reshape(-1, 1)
130128

131-
half_batch = int(batch_size / 2)
129+
# Adversarial ground truths
130+
valid = np.ones((batch_size, 1))
131+
fake = np.zeros((batch_size, 1))
132132

133133
for epoch in range(epochs):
134134

135135
# ---------------------
136136
# Train Discriminator
137137
# ---------------------
138138

139-
# Select a random half batch of images
140-
idx = np.random.randint(0, X_train.shape[0], half_batch)
139+
# Select a random batch of images
140+
idx = np.random.randint(0, X_train.shape[0], batch_size)
141141
imgs = X_train[idx]
142142

143-
noise = np.random.normal(0, 1, (half_batch, 100))
143+
# Sample noise as generator input
144+
noise = np.random.normal(0, 1, (batch_size, 100))
144145

145146
# The labels of the digits that the generator tries to create an
146147
# image representation of
147-
sampled_labels = np.random.randint(0, 10, half_batch).reshape(-1, 1)
148+
sampled_labels = np.random.randint(0, 10, (batch_size, 1))
148149

149150
# Generate a batch of new images
150151
gen_imgs = self.generator.predict([noise, sampled_labels])
151152

152-
valid = np.ones((half_batch, 1))
153-
fake = np.zeros((half_batch, 1))
154-
155153
# Image labels. 0-9 if image is valid or 10 if it is generated (fake)
156154
img_labels = y_train[idx]
157-
fake_labels = 10 * np.ones(half_batch).reshape(-1, 1)
155+
fake_labels = 10 * np.ones(img_labels.shape)
158156

159157
# Train the discriminator
160158
d_loss_real = self.discriminator.train_on_batch(imgs, [valid, img_labels])
@@ -165,14 +163,6 @@ def train(self, epochs, batch_size=128, sample_interval=50):
165163
# Train Generator
166164
# ---------------------
167165

168-
# Sample generator input
169-
noise = np.random.normal(0, 1, (batch_size, 100))
170-
171-
valid = np.ones((batch_size, 1))
172-
# Generator wants discriminator to label the generated images as the intended
173-
# digits
174-
sampled_labels = np.random.randint(0, 10, batch_size).reshape(-1, 1)
175-
176166
# Train the generator
177167
g_loss = self.combined.train_on_batch([noise, sampled_labels], [valid, sampled_labels])
178168

@@ -188,9 +178,7 @@ def sample_images(self, epoch):
188178
r, c = 10, 10
189179
noise = np.random.normal(0, 1, (r * c, 100))
190180
sampled_labels = np.array([num for _ in range(r) for num in range(c)])
191-
192181
gen_imgs = self.generator.predict([noise, sampled_labels])
193-
194182
# Rescale images 0 - 1
195183
gen_imgs = 0.5 * gen_imgs + 0.5
196184

@@ -215,9 +203,8 @@ def save(model, model_name):
215203
open(options['file_arch'], 'w').write(json_string)
216204
model.save_weights(options['file_weight'])
217205

218-
save(self.generator, "mnist_acgan_generator")
219-
save(self.discriminator, "mnist_acgan_discriminator")
220-
save(self.combined, "mnist_acgan_adversarial")
206+
save(self.generator, "generator")
207+
save(self.discriminator, "discriminator")
221208

222209

223210
if __name__ == '__main__':

bgan/bgan.py

+13-19
Original file line numberDiff line numberDiff line change
@@ -45,8 +45,8 @@ def __init__(self):
4545
# The discriminator takes generated images as input and determines validity
4646
valid = self.discriminator(img)
4747

48-
# The combined model (stacked generator and discriminator) takes
49-
# noise as input => generates images => determines validity
48+
# The combined model (stacked generator and discriminator)
49+
# Trains the generator to fool the discriminator
5050
self.combined = Model(z, valid)
5151
self.combined.compile(loss=self.boundary_loss, optimizer=optimizer)
5252

@@ -103,44 +103,39 @@ def train(self, epochs, batch_size=128, sample_interval=50):
103103
(X_train, _), (_, _) = mnist.load_data()
104104

105105
# Rescale -1 to 1
106-
X_train = (X_train.astype(np.float32) - 127.5) / 127.5
106+
X_train = X_train / 127.5 - 1.
107107
X_train = np.expand_dims(X_train, axis=3)
108108

109-
half_batch = int(batch_size / 2)
109+
# Adversarial ground truths
110+
valid = np.ones((batch_size, 1))
111+
fake = np.zeros((batch_size, 1))
110112

111113
for epoch in range(epochs):
112114

113115
# ---------------------
114116
# Train Discriminator
115117
# ---------------------
116118

117-
# Select a random half batch of images
118-
idx = np.random.randint(0, X_train.shape[0], half_batch)
119+
# Select a random batch of images
120+
idx = np.random.randint(0, X_train.shape[0], batch_size)
119121
imgs = X_train[idx]
120122

121-
noise = np.random.normal(0, 1, (half_batch, self.latent_dim))
123+
noise = np.random.normal(0, 1, (batch_size, self.latent_dim))
122124

123-
# Generate a half batch of new images
125+
# Generate a batch of new images
124126
gen_imgs = self.generator.predict(noise)
125127

126128
# Train the discriminator
127-
d_loss_real = self.discriminator.train_on_batch(imgs, np.ones((half_batch, 1)))
128-
d_loss_fake = self.discriminator.train_on_batch(gen_imgs, np.zeros((half_batch, 1)))
129+
d_loss_real = self.discriminator.train_on_batch(imgs, valid)
130+
d_loss_fake = self.discriminator.train_on_batch(gen_imgs, fake)
129131
d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)
130132

131133

132134
# ---------------------
133135
# Train Generator
134136
# ---------------------
135137

136-
noise = np.random.normal(0, 1, (batch_size, self.latent_dim))
137-
138-
# The generator wants the discriminator to label the generated samples
139-
# as valid (ones)
140-
valid_y = np.array([1] * batch_size)
141-
142-
# Train the generator
143-
g_loss = self.combined.train_on_batch(noise, valid_y)
138+
g_loss = self.combined.train_on_batch(noise, valid)
144139

145140
# Plot the progress
146141
print ("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (epoch, d_loss[0], 100*d_loss[1], g_loss))
@@ -153,7 +148,6 @@ def sample_images(self, epoch):
153148
r, c = 5, 5
154149
noise = np.random.normal(0, 1, (r * c, self.latent_dim))
155150
gen_imgs = self.generator.predict(noise)
156-
157151
# Rescale images 0 - 1
158152
gen_imgs = 0.5 * gen_imgs + 0.5
159153

bigan/bigan.py

+8-19
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ def __init__(self):
3434

3535
# Build the generator
3636
self.generator = self.build_generator()
37-
37+
3838
# Build the encoder
3939
self.encoder = self.build_encoder()
4040

@@ -54,6 +54,7 @@ def __init__(self):
5454
valid = self.discriminator([z_, img])
5555

5656
# Set up and compile the combined model
57+
# Trains generator to fool the discriminator
5758
self.bigan_generator = Model([z, img], [fake, valid])
5859
self.bigan_generator.compile(loss=['binary_crossentropy', 'binary_crossentropy'],
5960
optimizer=optimizer)
@@ -125,7 +126,9 @@ def train(self, epochs, batch_size=128, sample_interval=50):
125126
X_train = (X_train.astype(np.float32) - 127.5) / 127.5
126127
X_train = np.expand_dims(X_train, axis=3)
127128

128-
half_batch = int(batch_size / 2)
129+
# Adversarial ground truths
130+
valid = np.ones((batch_size, 1))
131+
fake = np.zeros((batch_size, 1))
129132

130133
for epoch in range(epochs):
131134

@@ -135,17 +138,14 @@ def train(self, epochs, batch_size=128, sample_interval=50):
135138
# ---------------------
136139

137140
# Sample noise and generate img
138-
z = np.random.normal(size=(half_batch, self.latent_dim))
141+
z = np.random.normal(size=(batch_size, self.latent_dim))
139142
imgs_ = self.generator.predict(z)
140143

141-
# Select a random half batch of images and encode
142-
idx = np.random.randint(0, X_train.shape[0], half_batch)
144+
# Select a random batch of images and encode
145+
idx = np.random.randint(0, X_train.shape[0], batch_size)
143146
imgs = X_train[idx]
144147
z_ = self.encoder.predict(imgs)
145148

146-
valid = np.ones((half_batch, 1))
147-
fake = np.zeros((half_batch, 1))
148-
149149
# Train the discriminator (img -> z is valid, z -> img is fake)
150150
d_loss_real = self.discriminator.train_on_batch([z_, imgs], valid)
151151
d_loss_fake = self.discriminator.train_on_batch([z, imgs_], fake)
@@ -155,16 +155,6 @@ def train(self, epochs, batch_size=128, sample_interval=50):
155155
# Train Generator
156156
# ---------------------
157157

158-
# Sample gaussian noise
159-
z = np.random.normal(size=(batch_size, self.latent_dim))
160-
161-
# Select a random half batch of images
162-
idx = np.random.randint(0, X_train.shape[0], batch_size)
163-
imgs = X_train[idx]
164-
165-
valid = np.ones((batch_size, 1))
166-
fake = np.zeros((batch_size, 1))
167-
168158
# Train the generator (z -> img is valid and img -> z is invalid)
169159
g_loss = self.bigan_generator.train_on_batch([z, imgs], [valid, fake])
170160

@@ -173,7 +163,6 @@ def train(self, epochs, batch_size=128, sample_interval=50):
173163

174164
# If at save interval => save generated image samples
175165
if epoch % sample_interval == 0:
176-
# Select a random half batch of images
177166
self.sample_interval(epoch)
178167

179168
def sample_interval(self, epoch):

0 commit comments

Comments
 (0)