Commit 5fa06b7

Added initial source and data files
1 parent 40e73a2 commit 5fa06b7

11 files changed (+1053, -0 lines)

AUTHOR

+2
@@ -0,0 +1,2 @@
Jerod Weinman

Makefile

+30
@@ -0,0 +1,30 @@
all: mjsynth-download mjsynth-tfrecord train

demo: train

mjsynth-download: mjsynth-wget mjsynth-unpack

mjsynth-wget:
	mkdir -p data
	cd data ; \
	wget http://www.robots.ox.ac.uk/~vgg/data/text/mjsynth.tar.gz

mjsynth-unpack:
	mkdir -p data/images
	# strip leading mnt/ramdisk/max/90kDICT32px/
	tar xzvf data/mjsynth.tar.gz \
	    --strip=4 \
	    -C data/images

mjsynth-tfrecord:
	mkdir -p data/train data/val data/test
	cd src ; python mjsynth-tfrecord.py

train:
	cd src ; python train.py # use --help for options

monitor:
	tensorboard --logdir=data/model --port=8008

test:
	cd src ; python test.py # use --help for options

README.md

+87
@@ -0,0 +1,87 @@
# Overview

This collection demonstrates how to construct and train a deep,
bidirectional stacked LSTM using CNN features as input with CTC loss
to perform robust word recognition. The model is a straightforward
adaptation of Shi et al.'s CRNN architecture (arXiv:1507.05717). The
provided code downloads Jaderberg et al.'s synthetic data
(doi: 10.1007/s11263-015-0823-z) and trains on it.

# Structure

The model as built is a hybrid of Shi et al.'s CRNN architecture
(arXiv:1507.05717) and the VGG deep convnet, which reduces the number
of parameters by stacking pairs of small 3x3 kernels. In addition, the
pooling is limited in the horizontal direction to preserve
resolution for character recognition. There must be at least one
horizontal element per character.

Assuming one starts with a 32x32 image, the dimensions at each level
of filtering are as follows:

    ===================================================================
    Layer  Op    KrnSz  Stride(v,h)  OutDim   H   W  Options
    -------------------------------------------------------------------
    1      Conv  3      1            64      30  30  valid
    2      Conv  3      1            64      30  30  same
           Pool  2      2            64      15  15
    3      Conv  3      1            128     15  15  same
    4      Conv  3      1            128     15  15  same
           Pool  2      2,1          128      7  14
    5      Conv  3      1            256      7  14  same
    6      Conv  3      1            256      7  14  same
           Pool  2      2,1          256      3  13
    7      Conv  3      1            512      3  13  same
    8      Conv  3      1            512      3  13  same
           Pool  3      3,1          512      1  13
    9      LSTM                      512
    10     LSTM                      512
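
The H and W columns of the table can be re-derived with a few lines of arithmetic. The following is a minimal sketch, not code from this commit. It assumes that pooling uses "valid" padding and that the final pool spans a 3x1 (vertical x horizontal) window, which is the reading that reproduces the W = 13 shown in the last pooling row. The layer names (`conv1`, `pool2`, and so on) and the helpers `out_size` and `trace_dims` are hypothetical.

```python
# Sketch: recompute the H and W columns of the table for a 32x32 input.
# Assumptions (not stated in the README): pooling uses "valid" padding and
# the last pool uses a 3x1 window. Layer names are illustrative only.

def out_size(size, kernel, stride, padding):
    """Output length along one axis for a convolution or pooling op."""
    if padding == "same":
        return -(-size // stride)          # ceil(size / stride)
    return (size - kernel) // stride + 1   # "valid": no padding

# (name, (kernel_h, kernel_w), (stride_v, stride_h), padding)
LAYERS = [
    ("conv1", (3, 3), (1, 1), "valid"),
    ("conv2", (3, 3), (1, 1), "same"),
    ("pool2", (2, 2), (2, 2), "valid"),
    ("conv3", (3, 3), (1, 1), "same"),
    ("conv4", (3, 3), (1, 1), "same"),
    ("pool4", (2, 2), (2, 1), "valid"),
    ("conv5", (3, 3), (1, 1), "same"),
    ("conv6", (3, 3), (1, 1), "same"),
    ("pool6", (2, 2), (2, 1), "valid"),
    ("conv7", (3, 3), (1, 1), "same"),
    ("conv8", (3, 3), (1, 1), "same"),
    ("pool8", (3, 1), (3, 1), "valid"),
]

def trace_dims(height, width):
    """Return [(layer_name, H, W), ...] after each layer."""
    dims = []
    for name, (kh, kw), (sv, sh), padding in LAYERS:
        height = out_size(height, kh, sv, padding)
        width = out_size(width, kw, sh, padding)
        dims.append((name, height, width))
    return dims

for name, h, w in trace_dims(32, 32):
    print("%-6s H=%2d W=%2d" % (name, h, w))
```

Because only the first pool halves the width, a wider input keeps proportionally more horizontal positions, which is what leaves room for at least one element per character.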

To accelerate training, a batch normalization layer is included before
each pooling layer and ReLU non-linearities are used throughout. Other
model details should be easily identifiable in the code.

The default training mechanism uses the ADAM optimizer with learning
rate decay.

# Training

To completely train the model, you will need to download the mjsynth
dataset and pack it into sharded TensorFlow records. Then you can start
the training process, a TensorBoard monitor, and an ongoing evaluation
thread. The individual commands are packaged in the accompanying `Makefile`.

    make mjsynth-download
    make mjsynth-tfrecord
    make train &
    make monitor &
    make test

To monitor training, point your web browser to the URL (e.g.,
http://127.0.1.1:8008) given by the TensorBoard output.

Note that it may take 4-12 hours to download the complete mjsynth data
set. A very small set (0.1%) of packaged example data is included; to
run the small demo, skip the first two lines involving `mjsynth`.

With a GeForce GTX 1080, the demo takes about 20 minutes for the
validation character error to reach 45% (using the default
parameters); at one hour (roughly 7000 iterations), the validation
error is just over 20%.

With the full training data, the model typically converges to around
7% training character error and 35% word error, both varying by 2-5%.

# Testing

The test script streams statistics for small batches of validation (or test) data. It outputs the label error (percentage of characters predicted incorrectly), the test loss, and the sequence error (percentage of words--entire sequences--predicted incorrectly).
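
To make the two error measures concrete, here is a small pure-Python sketch. It is not the code in `test.py`, which is not part of this commit view. It assumes the label error is the edit distance between predicted and true strings divided by the total number of true characters, and that the sequence error counts any word that is not an exact match. The helper names and sample words are hypothetical.

```python
# Sketch of the two metrics, under the assumptions stated in the text above.

def edit_distance(a, b):
    """Levenshtein distance between two sequences."""
    prev = list(range(len(b) + 1))
    for i, ca in enumerate(a, 1):
        curr = [i]
        for j, cb in enumerate(b, 1):
            curr.append(min(prev[j] + 1,                # deletion
                            curr[j - 1] + 1,            # insertion
                            prev[j - 1] + (ca != cb)))  # substitution
        prev = curr
    return prev[-1]

def label_error(truths, predictions):
    """Percentage of character edits relative to the true character count."""
    total_edits = sum(edit_distance(t, p) for t, p in zip(truths, predictions))
    total_chars = sum(len(t) for t in truths)
    return 100.0 * total_edits / total_chars

def sequence_error(truths, predictions):
    """Percentage of words that are not predicted exactly."""
    wrong = sum(t != p for t, p in zip(truths, predictions))
    return 100.0 * wrong / len(truths)

# Illustrative data only
truths = ["MONIKER", "dog", "cat"]
predictions = ["MONIKER", "dig", "cart"]
print("label error: %.2f%%" % label_error(truths, predictions))
print("sequence error: %.2f%%" % sequence_error(truths, predictions))
```

One wrong character makes the whole word count as a sequence error, which is why the word error reported for the full data is much higher than the character error.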

# Configuration

There are many command-line options to configure training
parameters. Run `train.py` or `test.py` with the `--help` flag to see
them or inspect the scripts. Model parameters are not command-line
configurable and need to be edited in the code (see `model.py`).

data/test/words-000.tfrecord

1.59 MB
Binary file not shown.

data/train/words-000.tfrecord

12.9 MB
Binary file not shown.

data/val/words-000.tfrecord

1.46 MB
Binary file not shown.

src/mjsynth-tfrecord.py

+182
@@ -0,0 +1,182 @@
# CNN-LSTM-CTC-OCR
# Copyright (C) 2017 Jerod Weinman
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

import os
import tensorflow as tf
import math

"""Each record within the TFRecord file is a serialized Example proto.
The Example proto contains the following fields:
  image/encoded: string containing JPEG encoded grayscale image
  image/height, image/width: integers, image size in pixels
  image/filename: string containing the relative path of the image file
  image/labels: list containing the sequence labels for the image text
  text/string: string specifying the human-readable version of the text
  text/length: integer, number of characters in the text
"""

# The list (well, string) of valid output characters
# If any example contains a character not found here, an error will result
# from the calls to .index in the decoder below
out_charset="ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789"

jpeg_data = tf.placeholder(dtype=tf.string)
jpeg_decoder = tf.image.decode_jpeg(jpeg_data,channels=1)

kernel_sizes = [5,5,3,3,3,3] # CNN kernels for image reduction

# Minimum allowable width of image after CNN processing
min_width = 20

def calc_seq_len(image_width):
    """Calculate sequence length of given image after CNN processing"""

    conv1_trim = 2 * (kernel_sizes[0] // 2)
    fc6_trim = 2*(kernel_sizes[5] // 2)

    after_conv1 = image_width - conv1_trim
    after_pool1 = after_conv1 // 2
    after_pool2 = after_pool1 // 2
    after_pool4 = after_pool2 - 1 # max without stride
    after_fc6 = after_pool4 - fc6_trim
    seq_len = 2*after_fc6
    return seq_len

seq_lens = [calc_seq_len(w) for w in range(1024)]

def gen_data(input_base_dir, image_list_filename, output_filebase,
             num_shards=1000,start_shard=0):
    """ Generate several shards worth of TFRecord data """
    session_config = tf.ConfigProto()
    session_config.gpu_options.allow_growth=True
    sess = tf.Session(config=session_config)
    image_filenames = get_image_filenames(os.path.join(input_base_dir,
                                                       image_list_filename))
    num_digits = math.ceil( math.log10( num_shards - 1 ))
    shard_format = '%0'+ ('%d'%num_digits) + 'd' # Use appropriate # leading zeros
    images_per_shard = int(math.ceil( len(image_filenames) / float(num_shards) ))

    for i in range(start_shard,num_shards):
        start = i*images_per_shard
        end = (i+1)*images_per_shard
        out_filename = output_filebase+'-'+(shard_format % i)+'.tfrecord'
        if os.path.isfile(out_filename): # Don't recreate data if restarting
            continue
        print('%d of %d [%d:%d] %s' % (i, num_shards, start, end, out_filename))
        gen_shard(sess, input_base_dir, image_filenames[start:end], out_filename)
    # Clean up writing last shard (empty, since images_per_shard is rounded up)
    start = num_shards*images_per_shard
    out_filename = output_filebase+'-'+(shard_format % num_shards)+'.tfrecord'
    print('%d of %d [%d:] %s' % (num_shards, num_shards, start, out_filename))
    gen_shard(sess, input_base_dir, image_filenames[start:], out_filename)

    sess.close()

def gen_shard(sess, input_base_dir, image_filenames, output_filename):
    """Create a TFRecord file from a list of image filenames"""
    writer = tf.python_io.TFRecordWriter(output_filename)

    for filename in image_filenames:
        path_filename = os.path.join(input_base_dir,filename)
        if os.stat(path_filename).st_size == 0:
            print('SKIPPING ' + filename)
            continue
        try:
            image_data,height,width = get_image(sess,path_filename)
            text,labels = get_text_and_labels(filename)
            if is_writable(width,text):
                example = make_example(filename, image_data, labels, text,
                                       height, width)
                writer.write(example.SerializeToString())
            else:
                print('SKIPPING ' + filename)
        except Exception:
            # Some files have bogus payloads, catch and note the error, moving on
            print('ERROR ' + filename)
    writer.close()


def get_image_filenames(image_list_filename):
    """ Given input file, generate a list of relative filenames"""
    filenames = []
    with open(image_list_filename) as f:
        for line in f:
            # Carve out the ground truth string and file path from lines like:
            # ./2697/6/466_MONIKER_49537.jpg 49537
            filename = line.split(' ',1)[0][2:] # split off "./" and number
            filenames.append(filename)
    return filenames

def get_image(sess,filename):
    """Given path to an image file, load its data and size"""
    with tf.gfile.FastGFile(filename, 'rb') as f: # JPEG payload is binary
        image_data = f.read()
    image = sess.run(jpeg_decoder,feed_dict={jpeg_data: image_data})
    height = image.shape[0]
    width = image.shape[1]
    return image_data, height, width

def is_writable(image_width,text):
    """Determine whether the CNN-processed image is at least as long as the string"""
    return (image_width > min_width) and (len(text) <= seq_lens[image_width])

def get_text_and_labels(filename):
    """ Extract the human-readable text and label sequence from image filename"""
    # Ground truth string lies embedded within base filename between underscores
    # 2697/6/466_MONIKER_49537.jpg --> MONIKER
    text = os.path.basename(filename).split('_',2)[1]
    # Transform string text to sequence of indices using charset, e.g.,
    # MONIKER -> [12, 14, 13, 8, 10, 4, 17]
    labels = [out_charset.index(c) for c in list(text)]
    return text,labels

def make_example(filename, image_data, labels, text, height, width):
    """Build an Example proto for an example.
    Args:
      filename: string, path to an image file, e.g., '/path/to/example.JPG'
      image_data: string, JPEG encoding of grayscale image
      labels: integer list, identifiers for the ground truth for the network
      text: string, unique human-readable, e.g. 'dog'
      height: integer, image height in pixels
      width: integer, image width in pixels
    Returns:
      Example proto
    """
    example = tf.train.Example(features=tf.train.Features(feature={
        'image/encoded': _bytes_feature(tf.compat.as_bytes(image_data)),
        'image/labels': _int64_feature(labels),
        'image/height': _int64_feature([height]),
        'image/width': _int64_feature([width]),
        'image/filename': _bytes_feature(tf.compat.as_bytes(filename)),
        'text/string': _bytes_feature(tf.compat.as_bytes(text)),
        'text/length': _int64_feature([len(text)])
    }))
    return example

def _int64_feature(values):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=values))

def _bytes_feature(values):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[values]))

def main(argv=None):

    gen_data('../data/images', 'annotation_train.txt', '../data/train/words')
    gen_data('../data/images', 'annotation_val.txt', '../data/val/words')
    gen_data('../data/images', 'annotation_test.txt', '../data/test/words')

if __name__ == '__main__':
    main()
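
The filename parsing and width filter in `src/mjsynth-tfrecord.py` can be checked by hand without TensorFlow. The sketch below mirrors `get_text_and_labels`, `calc_seq_len`, and `is_writable` from the file above in plain Python. The upper-case constant names and the inverse helper `labels_to_text` are hypothetical additions for illustration and are not part of the commit.

```python
# TensorFlow-free sketch mirroring helpers from mjsynth-tfrecord.py.
# labels_to_text and the upper-case constant names are illustrative only.
import os

OUT_CHARSET = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789"
KERNEL_SIZES = [5, 5, 3, 3, 3, 3]
MIN_WIDTH = 20

def text_and_labels(filename):
    """Ground truth sits between the first two underscores of the basename."""
    text = os.path.basename(filename).split('_', 2)[1]
    return text, [OUT_CHARSET.index(c) for c in text]

def labels_to_text(labels):
    """Hypothetical inverse mapping from label indices back to text."""
    return ''.join(OUT_CHARSET[i] for i in labels)

def calc_seq_len(image_width):
    """Same arithmetic as calc_seq_len in the script."""
    after_conv1 = image_width - 2 * (KERNEL_SIZES[0] // 2)
    after_pool2 = (after_conv1 // 2) // 2
    after_pool4 = after_pool2 - 1
    after_fc6 = after_pool4 - 2 * (KERNEL_SIZES[5] // 2)
    return 2 * after_fc6

def is_writable(image_width, text):
    """Keep an image only if its sequence length can hold the text."""
    return image_width > MIN_WIDTH and len(text) <= calc_seq_len(image_width)

text, labels = text_and_labels('2697/6/466_MONIKER_49537.jpg')
print(text, labels)
print(labels_to_text(labels))
print(calc_seq_len(100), is_writable(100, text), is_writable(21, text))
```

A character outside `OUT_CHARSET` raises `ValueError` from `.index`, which is the error the comment above the charset in the script warns about.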
