|
| 1 | +# CNN-LSTM-CTC-OCR |
| 2 | +# Copyright (C) 2017 Jerod Weinman |
| 3 | +# |
| 4 | +# This program is free software: you can redistribute it and/or modify |
| 5 | +# it under the terms of the GNU General Public License as published by |
| 6 | +# the Free Software Foundation, either version 3 of the License, or |
| 7 | +# (at your option) any later version. |
| 8 | +# |
| 9 | +# This program is distributed in the hope that it will be useful, |
| 10 | +# but WITHOUT ANY WARRANTY; without even the implied warranty of |
| 11 | +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the |
| 12 | +# GNU General Public License for more details. |
| 13 | +# |
| 14 | +# You should have received a copy of the GNU General Public License |
| 15 | +# along with this program. If not, see <http://www.gnu.org/licenses/>. |
| 16 | + |
| 17 | +import os |
| 18 | +import tensorflow as tf |
| 19 | +import math |
| 20 | + |
"""Each record within the TFRecord file is a serialized Example proto.
The Example proto contains the following fields:
  image/encoded:  string containing JPEG encoded grayscale image
  image/labels:   list containing the sequence labels for the image text
  image/height:   integer, image height in pixels
  image/width:    integer, image width in pixels
  image/filename: string containing the relative path of the image file
  text/string:    string specifying the human-readable version of the text
  text/length:    integer, number of characters in the text
"""
| 30 | + |
# Alphabet of valid output characters; a character's label is its position
# in this string. If any example contains a character not found here, an
# error will result from the calls to .index in the decoder below.
out_charset = (
    "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
    "abcdefghijklmnopqrstuvwxyz"
    "0123456789"
)

# Graph nodes that decode a JPEG byte string, fed through the placeholder,
# into a single-channel image.
jpeg_data = tf.placeholder(dtype=tf.string)
jpeg_decoder = tf.image.decode_jpeg(jpeg_data, channels=1)

# CNN kernels for image reduction
kernel_sizes = [5, 5, 3, 3, 3, 3]

# Minimum allowable width of image after CNN processing
min_width = 20
| 43 | + |
def calc_seq_len(image_width):
    """Calculate sequence length of given image after CNN processing.

    Args:
      image_width: integer, image width in pixels
    Returns:
      Width remaining after the trims and reductions below, doubled.
    """
    first_kernel = kernel_sizes[0]
    last_kernel = kernel_sizes[5]

    width = image_width - 2 * (first_kernel // 2)  # conv1 trim
    width = width // 2                             # pool1
    width = width // 2                             # pool2
    width = width - 1                              # pool4: max without stride
    width = width - 2 * (last_kernel // 2)         # fc6 trim
    return 2 * width
| 57 | + |
# Lookup table: seq_lens[w] is calc_seq_len(w) for widths 0..1023
seq_lens = list(map(calc_seq_len, range(1024)))
| 59 | + |
def gen_data(input_base_dir, image_list_filename, output_filebase,
             num_shards=1000, start_shard=0):
    """Generate several shards worth of TFRecord data.

    The image list is split into num_shards consecutive slices of equal
    size (the last slices may be short or empty) and each slice is written
    by gen_shard to '<output_filebase>-<shard index>.tfrecord'. Shard files
    that already exist are left untouched so an interrupted run can be
    restarted.

    Args:
      input_base_dir: string, directory joined with image_list_filename to
        locate the list file; also passed to gen_shard as the base for
        the listed image paths
      image_list_filename: string, name of the image list file
      output_filebase: string, prefix of every shard file name
      num_shards: integer, number of slices the image list is split into
      start_shard: integer, index of the first shard to generate
    Raises:
      ValueError: if num_shards is less than 1
    """
    if num_shards < 1:
        raise ValueError('num_shards must be at least 1')

    image_filenames = get_image_filenames(os.path.join(input_base_dir,
                                                       image_list_filename))

    # Pad shard indices with leading zeros to the width of the largest
    # index. Counting digits directly (rather than ceil(log10(n-1))) is
    # defined for num_shards == 1 and correct when n-1 is a power of ten.
    num_digits = len(str(num_shards - 1))
    shard_format = '%0' + ('%d' % num_digits) + 'd'
    images_per_shard = int(math.ceil(len(image_filenames) / float(num_shards)))

    session_config = tf.ConfigProto()
    session_config.gpu_options.allow_growth = True
    sess = tf.Session(config=session_config)
    try:
        for i in range(start_shard, num_shards):
            start = i * images_per_shard
            end = (i + 1) * images_per_shard
            out_filename = (output_filebase + '-' + (shard_format % i) +
                            '.tfrecord')
            if os.path.isfile(out_filename):  # Don't recreate data if restarting
                continue
            # Single-string print works as both a Python 2 statement and a
            # Python 3 function call.
            print(' '.join([str(i), 'of', str(num_shards),
                            '[', str(start), ':', str(end), ']',
                            out_filename]))
            gen_shard(sess, input_base_dir, image_filenames[start:end],
                      out_filename)
        # No trailing "clean up" shard is written: images_per_shard is
        # rounded up, so num_shards*images_per_shard >= len(image_filenames)
        # and the slice past the last regular shard is always empty.
    finally:
        # Release the session even when a shard fails part-way through.
        sess.close()
| 87 | + |
def gen_shard(sess, input_base_dir, image_filenames, output_filename):
    """Create a TFRecord file from a list of image filenames.

    Zero-byte image files and examples rejected by is_writable are skipped
    with a 'SKIPPING' message. Any error raised while loading or encoding
    a single example is reported with an 'ERROR' message and the remaining
    files are still processed.

    Args:
      sess: session handed to get_image for decoding each image
      input_base_dir: string, directory the image filenames are relative to
      image_filenames: list of strings, relative image file paths
      output_filename: string, path of the TFRecord file to write
    """
    writer = tf.python_io.TFRecordWriter(output_filename)
    try:
        for filename in image_filenames:
            path_filename = os.path.join(input_base_dir, filename)
            if os.stat(path_filename).st_size == 0:
                print('SKIPPING ' + filename)
                continue
            try:
                image_data, height, width = get_image(sess, path_filename)
                text, labels = get_text_and_labels(filename)
                if is_writable(width, text):
                    example = make_example(filename, image_data, labels, text,
                                           height, width)
                    writer.write(example.SerializeToString())
                else:
                    print('SKIPPING ' + filename)
            except Exception as err:
                # Some files have bogus payloads; note the error and move on.
                # Catching Exception (not a bare except) lets Ctrl-C and
                # SystemExit still stop the run.
                print('ERROR ' + filename + ': ' + repr(err))
    finally:
        # Always flush and close the shard, even if os.stat raises.
        writer.close()
| 110 | + |
| 111 | + |
def get_image_filenames(image_list_filename):
    """Given input file, generate a list of relative filenames.

    Each non-blank line holds a file path optionally followed by a space
    and a number, e.g.
      ./2697/6/466_MONIKER_49537.jpg 49537
    The path is everything before the first space; a leading "./" is
    removed when present.

    Args:
      image_list_filename: string, path to the image list file
    Returns:
      List of strings, one relative file path per non-blank line
    """
    filenames = []
    with open(image_list_filename) as f:
        for line in f:
            # Strip the newline so lines without a trailing number do not
            # keep it in the path, and ignore blank lines entirely.
            line = line.strip()
            if not line:
                continue
            path = line.split(' ', 1)[0]
            # Only drop the "./" prefix when it is really there, instead of
            # unconditionally discarding the first two characters.
            if path.startswith('./'):
                path = path[2:]
            filenames.append(path)
    return filenames
| 122 | + |
def get_image(sess, filename):
    """Given path to an image file, load its data and size.

    Args:
      sess: session used to run the module-level jpeg_decoder
      filename: string, path to a JPEG image file
    Returns:
      Tuple (image_data, height, width): the raw encoded file contents and
      the first two dimensions of the decoded image.
    """
    # JPEG data is binary: read it in 'rb' mode so the bytes are never run
    # through text decoding (mode 'r' fails on them under Python 3).
    with tf.gfile.FastGFile(filename, 'rb') as f:
        image_data = f.read()
    # Decode only to learn the image size; the encoded bytes are returned.
    image = sess.run(jpeg_decoder, feed_dict={jpeg_data: image_data})
    height, width = image.shape[0], image.shape[1]
    return image_data, height, width
| 131 | + |
def is_writable(image_width, text):
    """Determine whether the CNN-processed image is longer than the string.

    Args:
      image_width: integer, image width in pixels
      text: string, ground truth text for the image
    Returns:
      True when image_width exceeds min_width and the text is no longer
      than the sequence length computed for that width; False otherwise.
    """
    if image_width <= min_width:
        return False
    # The precomputed table only covers a fixed range of widths; compute
    # the length directly for wider images instead of raising IndexError.
    if image_width < len(seq_lens):
        seq_len = seq_lens[image_width]
    else:
        seq_len = calc_seq_len(image_width)
    return len(text) <= seq_len
| 135 | + |
def get_text_and_labels(filename):
    """Extract the human-readable text and label sequence from image filename.

    Args:
      filename: string, image file path whose base name embeds the ground
        truth between underscores
    Returns:
      Tuple (text, labels): the text and the list of each character's
      index in out_charset.
    """
    # The ground truth sits between the first two underscores of the base
    # name: 2697/6/466_MONIKER_49537.jpg --> MONIKER
    base_name = os.path.basename(filename)
    pieces = base_name.split('_', 2)
    text = pieces[1]
    # Map every character to its position in the charset, e.g.,
    # MONIKER -> [12, 14, 13, 8, 10, 4, 17]
    labels = list(map(out_charset.index, text))
    return text, labels
| 145 | + |
def make_example(filename, image_data, labels, text, height, width):
    """Build an Example proto for an example.

    Args:
      filename: string, path to an image file, e.g., '/path/to/example.JPG'
      image_data: string, JPEG encoding of grayscale image
      labels: integer list, identifiers for the ground truth for the network
      text: string, unique human-readable, e.g. 'dog'
      height: integer, image height in pixels
      width: integer, image width in pixels
    Returns:
      Example proto
    """
    to_bytes = tf.compat.as_bytes

    # Assemble the feature map one entry at a time, then wrap it.
    feature_map = {}
    feature_map['image/encoded'] = _bytes_feature(to_bytes(image_data))
    feature_map['image/labels'] = _int64_feature(labels)
    feature_map['image/height'] = _int64_feature([height])
    feature_map['image/width'] = _int64_feature([width])
    feature_map['image/filename'] = _bytes_feature(to_bytes(filename))
    feature_map['text/string'] = _bytes_feature(to_bytes(text))
    feature_map['text/length'] = _int64_feature([len(text)])

    features = tf.train.Features(feature=feature_map)
    return tf.train.Example(features=features)
| 168 | + |
def _int64_feature(values):
    """Wrap a list of integers in an int64-list Feature."""
    int64_list = tf.train.Int64List(value=values)
    return tf.train.Feature(int64_list=int64_list)
| 171 | + |
def _bytes_feature(values):
    """Wrap a single value in a one-element bytes-list Feature."""
    bytes_list = tf.train.BytesList(value=[values])
    return tf.train.Feature(bytes_list=bytes_list)
| 174 | + |
def main(argv=None):
    """Generate the train, val, and test TFRecord data, in that order."""
    # (image list file, output file base) for each data split
    jobs = (
        ('annotation_train.txt', '../data/train/words'),
        ('annotation_val.txt', '../data/val/words'),
        ('annotation_test.txt', '../data/test/words'),
    )
    for list_name, out_base in jobs:
        gen_data('../data/images', list_name, out_base)


if __name__ == '__main__':
    main()
0 commit comments