-
Notifications
You must be signed in to change notification settings - Fork 18
/
lenet-bn.py
76 lines (60 loc) · 2.14 KB
/
lenet-bn.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
from src.net import Net
from src.utils import randn, uniform, guass, read_mnist
import numpy as np
import time
# Train and evaluate a LeNet-style convolutional network, with batch
# normalisation on the first convolution, on the data set returned by
# read_mnist().
#
# Layer stack, as built below:
#   conv 5x5 (1 -> 32) -> batch_norm -> +bias -> relu -> maxpool2
#   conv 5x5 (32 -> 64)              -> +bias -> relu -> maxpool2
#   reshape to 7*7*64 -> dense 1024 -> relu -> dense 10 -> softmax_crossent
#
# NOTE(review): the indentation of this script was missing in the copy under
# review. The loop nesting is reconstructed from the step counter
# (`i + 1 + count * batch_num`), and the test-set evaluation is assumed to
# run once after all training epochs -- confirm against the upstream file.

# ---- Hyper-parameters -----------------------------------------------------
inp_dim = 28    # side length of the square, single-channel input
out_dim = 10    # width of the label / logit vectors
std = 1.        # second argument handed to guass() for every weight tensor
lr = 1e-3       # learning rate handed to the optimiser
inp_shape = (inp_dim, inp_dim, 1)
flat_dim = 7 * 7 * 64   # size of pool2's output once flattened
batch = 128     # examples per training step
epochs = 5      # number of passes over the training split

# ---- Graph inputs ---------------------------------------------------------
net = Net()
image = net.portal(inp_shape)
label = net.portal((out_dim,))
# Fed True while training and False at evaluation time (see the feeds below).
is_training = net.portal()

# ---- Parameters -----------------------------------------------------------
# Every bias starts as a constant 0.1 vector.
k1 = net.variable(guass(0., std, (5, 5, 1, 32)))
b1 = net.variable(np.ones((32,)) * .1)
k2 = net.variable(guass(0., std, (5, 5, 32, 64)))
b2 = net.variable(np.ones((64,)) * .1)
w3 = net.variable(guass(0., std, (flat_dim, 1024)))
b3 = net.variable(np.ones((1024,)) * .1)
w4 = net.variable(guass(0., std, (1024, out_dim)))
b4 = net.variable(np.ones((out_dim,)) * .1)

# ---- Forward graph --------------------------------------------------------
conv1 = net.conv2d(image, k1, pad=(2, 2), stride=(1, 1))
# The batch-norm parameter is deliberately created here, after conv2d, so the
# order of random draws and variable registration matches the original.
# NOTE(review): presumably the per-channel scale of batch_norm -- confirm
# against Net.batch_norm.
bn1_param = net.variable(guass(0., std, (32,)))
conv1 = net.batch_norm(conv1, bn1_param, is_training)
conv1 = net.plus_b(conv1, b1)
conv1 = net.relu(conv1)
pool1 = net.maxpool2(conv1)

# Keyword arguments for consistency with the first conv2d call. The
# 7 * 7 * 64 flatten size is only reachable with pad=2 / stride=1 here.
conv2 = net.conv2d(pool1, k2, pad=(2, 2), stride=(1, 1))
conv2 = net.plus_b(conv2, b2)
conv2 = net.relu(conv2)
pool2 = net.maxpool2(conv2)

flat = net.reshape(pool2, (flat_dim,))
fc1 = net.plus_b(net.matmul(flat, w3), b3)
fc1 = net.relu(fc1)
fc2 = net.plus_b(net.matmul(fc1, w4), b4)

loss = net.softmax_crossent(fc2, label)
net.optimize(loss, 'adam', lr)

# ---- Training -------------------------------------------------------------
mnist_data = read_mnist()

# Number of full batches per epoch; it does not change between epochs, so it
# is computed once. Any trailing partial batch is not counted.
batch_num = mnist_data.train.num_examples // batch

for count in range(epochs):
    for i in range(batch_num):
        feed, target = mnist_data.train.next_batch(batch)
        feed = feed.reshape((batch,) + inp_shape).astype(np.float64)
        target = target.astype(np.float64)

        # net.train's result is unpacked as (values fetched for fc2, loss).
        pred, cost = net.train([fc2], {
            image: feed,
            label: target,
            is_training: True,
        })

        # Accuracy on the current mini-batch: argmax of the logits against
        # argmax of the label rows.
        predict = pred.argmax(1)
        truth = target.argmax(1)
        accuracy = np.equal(predict, truth).mean()
        print('Step {} Loss {} Accuracy {}'.format(
            i + 1 + count * batch_num, cost, accuracy))

# ---- Evaluation -----------------------------------------------------------
# One forward pass over the entire test split with is_training fed as False;
# forward() returns one entry per fetched node, so [0] selects fc2's output.
predict = net.forward([fc2], {
    image: mnist_data.test.images.reshape((-1,) + inp_shape),
    is_training: False,
})[0]
true_labels = mnist_data.test.labels.argmax(1)
pred_labels = predict.argmax(1)
accuracy = np.equal(true_labels, pred_labels).mean()
print('Accuracy on test set:', accuracy)