-
Notifications
You must be signed in to change notification settings - Fork 23
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
11 changed files
with
535 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,198 @@ | ||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# ============================================================================== | ||
|
||
r"""Training executable for detection models. | ||
This executable is used to train DetectionModels. There are two ways of | ||
configuring the training job: | ||
1) A single pipeline_pb2.TrainEvalPipelineConfig configuration file | ||
can be specified by --pipeline_config_path. | ||
Example usage: | ||
./train \ | ||
--logtostderr \ | ||
--train_dir=path/to/train_dir \ | ||
--pipeline_config_path=pipeline_config.pbtxt | ||
2) Three configuration files can be provided: a model_pb2.DetectionModel | ||
configuration file to define what type of DetectionModel is being trained, an | ||
input_reader_pb2.InputReader file to specify what training data will be used and | ||
a train_pb2.TrainConfig file to configure training parameters. | ||
Example usage: | ||
./train \ | ||
--logtostderr \ | ||
--train_dir=path/to/train_dir \ | ||
--model_config_path=model_config.pbtxt \ | ||
--train_config_path=train_config.pbtxt \ | ||
--input_config_path=train_input_config.pbtxt | ||
""" | ||
|
||
import functools | ||
import json | ||
import os | ||
import tensorflow as tf | ||
|
||
from google.protobuf import text_format | ||
|
||
from object_detection import trainer | ||
from object_detection.builders import input_reader_builder | ||
from object_detection.builders import model_builder | ||
from object_detection.protos import input_reader_pb2 | ||
from object_detection.protos import model_pb2 | ||
from object_detection.protos import pipeline_pb2 | ||
from object_detection.protos import train_pb2 | ||
|
||
tf.logging.set_verbosity(tf.logging.INFO) | ||
|
||
flags = tf.app.flags | ||
flags.DEFINE_string('master', '', 'BNS name of the TensorFlow master to use.') | ||
flags.DEFINE_integer('task', 0, 'task id') | ||
flags.DEFINE_integer('num_clones', 1, 'Number of clones to deploy per worker.') | ||
flags.DEFINE_boolean('clone_on_cpu', False, | ||
'Force clones to be deployed on CPU. Note that even if ' | ||
'set to False (allowing ops to run on gpu), some ops may ' | ||
'still be run on the CPU if they have no GPU kernel.') | ||
flags.DEFINE_integer('worker_replicas', 1, 'Number of worker+trainer ' | ||
'replicas.') | ||
flags.DEFINE_integer('ps_tasks', 0, | ||
'Number of parameter server tasks. If None, does not use ' | ||
'a parameter server.') | ||
flags.DEFINE_string('train_dir', '', | ||
'Directory to save the checkpoints and training summaries.') | ||
|
||
flags.DEFINE_string('pipeline_config_path', '', | ||
'Path to a pipeline_pb2.TrainEvalPipelineConfig config ' | ||
'file. If provided, other configs are ignored') | ||
|
||
flags.DEFINE_string('train_config_path', '', | ||
'Path to a train_pb2.TrainConfig config file.') | ||
flags.DEFINE_string('input_config_path', '', | ||
'Path to an input_reader_pb2.InputReader config file.') | ||
flags.DEFINE_string('model_config_path', '', | ||
'Path to a model_pb2.DetectionModel config file.') | ||
|
||
FLAGS = flags.FLAGS | ||
|
||
|
||
def get_configs_from_pipeline_file(): | ||
"""Reads training configuration from a pipeline_pb2.TrainEvalPipelineConfig. | ||
Reads training config from file specified by pipeline_config_path flag. | ||
Returns: | ||
model_config: model_pb2.DetectionModel | ||
train_config: train_pb2.TrainConfig | ||
input_config: input_reader_pb2.InputReader | ||
""" | ||
pipeline_config = pipeline_pb2.TrainEvalPipelineConfig() | ||
with tf.gfile.GFile(FLAGS.pipeline_config_path, 'r') as f: | ||
text_format.Merge(f.read(), pipeline_config) | ||
|
||
model_config = pipeline_config.model | ||
train_config = pipeline_config.train_config | ||
input_config = pipeline_config.train_input_reader | ||
|
||
return model_config, train_config, input_config | ||
|
||
|
||
def get_configs_from_multiple_files(): | ||
"""Reads training configuration from multiple config files. | ||
Reads the training config from the following files: | ||
model_config: Read from --model_config_path | ||
train_config: Read from --train_config_path | ||
input_config: Read from --input_config_path | ||
Returns: | ||
model_config: model_pb2.DetectionModel | ||
train_config: train_pb2.TrainConfig | ||
input_config: input_reader_pb2.InputReader | ||
""" | ||
train_config = train_pb2.TrainConfig() | ||
with tf.gfile.GFile(FLAGS.train_config_path, 'r') as f: | ||
text_format.Merge(f.read(), train_config) | ||
|
||
model_config = model_pb2.DetectionModel() | ||
with tf.gfile.GFile(FLAGS.model_config_path, 'r') as f: | ||
text_format.Merge(f.read(), model_config) | ||
|
||
input_config = input_reader_pb2.InputReader() | ||
with tf.gfile.GFile(FLAGS.input_config_path, 'r') as f: | ||
text_format.Merge(f.read(), input_config) | ||
|
||
return model_config, train_config, input_config | ||
|
||
|
||
def main(_): | ||
assert FLAGS.train_dir, '`train_dir` is missing.' | ||
if FLAGS.pipeline_config_path: | ||
model_config, train_config, input_config = get_configs_from_pipeline_file() | ||
else: | ||
model_config, train_config, input_config = get_configs_from_multiple_files() | ||
|
||
model_fn = functools.partial( | ||
model_builder.build, | ||
model_config=model_config, | ||
is_training=True) | ||
|
||
create_input_dict_fn = functools.partial( | ||
input_reader_builder.build, input_config) | ||
|
||
env = json.loads(os.environ.get('TF_CONFIG', '{}')) | ||
cluster_data = env.get('cluster', None) | ||
cluster = tf.train.ClusterSpec(cluster_data) if cluster_data else None | ||
task_data = env.get('task', None) or {'type': 'master', 'index': 0} | ||
task_info = type('TaskSpec', (object,), task_data) | ||
|
||
# Parameters for a single worker. | ||
ps_tasks = 0 | ||
worker_replicas = 1 | ||
worker_job_name = 'lonely_worker' | ||
task = 0 | ||
is_chief = True | ||
master = '' | ||
|
||
if cluster_data and 'worker' in cluster_data: | ||
# Number of total worker replicas include "worker"s and the "master". | ||
worker_replicas = len(cluster_data['worker']) + 1 | ||
if cluster_data and 'ps' in cluster_data: | ||
ps_tasks = len(cluster_data['ps']) | ||
|
||
if worker_replicas > 1 and ps_tasks < 1: | ||
raise ValueError('At least 1 ps task is needed for distributed training.') | ||
|
||
if worker_replicas >= 1 and ps_tasks > 0: | ||
# Set up distributed training. | ||
server = tf.train.Server(tf.train.ClusterSpec(cluster), protocol='grpc', | ||
job_name=task_info.type, | ||
task_index=task_info.index) | ||
if task_info.type == 'ps': | ||
server.join() | ||
return | ||
|
||
worker_job_name = '%s/task:%d' % (task_info.type, task_info.index) | ||
task = task_info.index | ||
is_chief = (task_info.type == 'master') | ||
master = server.target | ||
|
||
trainer.train(create_input_dict_fn, model_fn, train_config, master, task, | ||
FLAGS.num_clones, worker_replicas, FLAGS.clone_on_cpu, ps_tasks, | ||
worker_job_name, is_chief, FLAGS.train_dir) | ||
|
||
|
||
if __name__ == '__main__': | ||
tf.app.run() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,106 @@ | ||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# ============================================================================== | ||
|
||
r"""Tool to export an object detection model for inference. | ||
Prepares an object detection tensorflow graph for inference using model | ||
configuration and an optional trained checkpoint. Outputs inference | ||
graph, associated checkpoint files, a frozen inference graph and a | ||
SavedModel (https://tensorflow.github.io/serving/serving_basic.html). | ||
The inference graph contains one of three input nodes depending on the user | ||
specified option. | ||
* `image_tensor`: Accepts a uint8 4-D tensor of shape [None, None, None, 3] | ||
* `encoded_image_string_tensor`: Accepts a 1-D string tensor of shape [None] | ||
containing encoded PNG or JPEG images. Image resolutions are expected to be | ||
the same if more than 1 image is provided. | ||
* `tf_example`: Accepts a 1-D string tensor of shape [None] containing | ||
serialized TFExample protos. Image resolutions are expected to be the same | ||
if more than 1 image is provided. | ||
and the following output nodes returned by the model.postprocess(..): | ||
* `num_detections`: Outputs float32 tensors of the form [batch] | ||
that specifies the number of valid boxes per image in the batch. | ||
* `detection_boxes`: Outputs float32 tensors of the form | ||
[batch, num_boxes, 4] containing detected boxes. | ||
* `detection_scores`: Outputs float32 tensors of the form | ||
[batch, num_boxes] containing class scores for the detections. | ||
* `detection_classes`: Outputs float32 tensors of the form | ||
[batch, num_boxes] containing classes for the detections. | ||
* `detection_masks`: Outputs float32 tensors of the form | ||
[batch, num_boxes, mask_height, mask_width] containing predicted instance | ||
masks for each box if its present in the dictionary of postprocessed | ||
tensors returned by the model. | ||
Notes: | ||
* This tool uses `use_moving_averages` from eval_config to decide which | ||
weights to freeze. | ||
Example Usage: | ||
-------------- | ||
python export_inference_graph \ | ||
--input_type image_tensor \ | ||
--pipeline_config_path path/to/ssd_inception_v2.config \ | ||
--trained_checkpoint_prefix path/to/model.ckpt \ | ||
--output_directory path/to/exported_model_directory | ||
The expected output would be in the directory | ||
path/to/exported_model_directory (which is created if it does not exist) | ||
with contents: | ||
- graph.pbtxt | ||
- model.ckpt.data-00000-of-00001 | ||
- model.ckpt.info | ||
- model.ckpt.meta | ||
- frozen_inference_graph.pb | ||
+ saved_model (a directory) | ||
""" | ||
import tensorflow as tf | ||
from google.protobuf import text_format | ||
from object_detection import exporter | ||
from object_detection.protos import pipeline_pb2 | ||
|
||
slim = tf.contrib.slim | ||
flags = tf.app.flags | ||
|
||
flags.DEFINE_string('input_type', 'image_tensor', 'Type of input node. Can be ' | ||
'one of [`image_tensor`, `encoded_image_string_tensor`, ' | ||
'`tf_example`]') | ||
flags.DEFINE_string('pipeline_config_path', None, | ||
'Path to a pipeline_pb2.TrainEvalPipelineConfig config ' | ||
'file.') | ||
flags.DEFINE_string('trained_checkpoint_prefix', None, | ||
'Path to trained checkpoint, typically of the form ' | ||
'path/to/model.ckpt') | ||
flags.DEFINE_string('output_directory', None, 'Path to write outputs.') | ||
|
||
FLAGS = flags.FLAGS | ||
|
||
|
||
def main(_): | ||
assert FLAGS.pipeline_config_path, '`pipeline_config_path` is missing' | ||
assert FLAGS.trained_checkpoint_prefix, ( | ||
'`trained_checkpoint_prefix` is missing') | ||
assert FLAGS.output_directory, '`output_directory` is missing' | ||
|
||
pipeline_config = pipeline_pb2.TrainEvalPipelineConfig() | ||
with tf.gfile.GFile(FLAGS.pipeline_config_path, 'r') as f: | ||
text_format.Merge(f.read(), pipeline_config) | ||
exporter.export_inference_graph( | ||
FLAGS.input_type, pipeline_config, FLAGS.trained_checkpoint_prefix, | ||
FLAGS.output_directory) | ||
|
||
|
||
if __name__ == '__main__': | ||
tf.app.run() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
model_checkpoint_path: "model.ckpt" | ||
all_model_checkpoint_paths: "model.ckpt" |
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
item { | ||
id: 1 | ||
name: 'cafo' | ||
} |
Oops, something went wrong.