forked from tensorpack/tensorpack
-
Notifications
You must be signed in to change notification settings - Fork 0
/
boilerplate.py
79 lines (58 loc) · 2.02 KB
/
boilerplate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
# -*- coding: utf-8 -*-
# Author: Your Name <[email protected]>
import os
import argparse
import tensorflow as tf
from tensorpack import *
"""
This is a boilerplate template.
All code in this file is a minimal way to solve a deep-learning problem with validation on a held-out set.
"""
# Datapoints per batch, side length of the square inputs, and number of
# input channels.
BATCH_SIZE, SHAPE, CHANNELS = 16, 28, 3
class Model(ModelDesc):
    """Minimal tensorpack model: declares the inputs, the cost and the optimizer.

    Everything here is a placeholder meant to be replaced by the real network.
    """

    def inputs(self):
        """Declare the graph inputs.

        Returns:
            list: two placeholders -- 'input1', a float32 tensor of shape
            (batch, SHAPE, SHAPE, CHANNELS), and 'input2', an int32 vector
            holding one value per element of the batch.
        """
        return [tf.placeholder(tf.float32, (None, SHAPE, SHAPE, CHANNELS), 'input1'),
                tf.placeholder(tf.int32, (None,), 'input2')]

    def build_graph(self, input1, input2):
        """Build the training graph and return the cost to minimize.

        Args:
            input1: float32 tensor of shape (batch, SHAPE, SHAPE, CHANNELS).
            input2: int32 tensor of shape (batch,).

        Returns:
            A scalar float32 tensor named 'total_costs'.
        """
        # Placeholder network: one linear layer predicting a single value per
        # datapoint, so that the optimizer has at least one trainable variable.
        # TODO: replace with the real model.
        prediction = tf.reshape(FullyConnected('fc', input1, 1), [-1])
        # input2 is int32; cast it so that arithmetic with the float32
        # prediction is valid (TF does not mix dtypes implicitly).
        target = tf.cast(input2, tf.float32)
        # The cost must be a scalar: the trainer, add_moving_summary and the
        # ScalarStats('total_costs') inference callback all expect one.
        cost = tf.reduce_mean(tf.squared_difference(prediction, target),
                              name='total_costs')
        summary.add_moving_summary(cost)
        return cost

    def optimizer(self):
        """Return the training optimizer: Adam with an initial learning rate of 5e-3."""
        # Stored in a non-trainable variable rather than a Python constant,
        # presumably so callbacks can change it during training -- confirm.
        lr = tf.get_variable('learning_rate', initializer=5e-3, trainable=False)
        return tf.train.AdamOptimizer(lr)
def get_data(subset):
    """Return a batched DataFlow for the requested split.

    Args:
        subset (str): 'train' or 'test'. The data is currently fake, so the
            split only decides whether multi-process prefetching is used.

    Returns:
        A DataFlow yielding [input1, input2] batches of BATCH_SIZE datapoints:
        float32 arrays of shape (SHAPE, SHAPE, CHANNELS) and integer labels,
        one per datapoint.
    """
    # TODO: replace the fake data with the real dataset for `subset`.
    ds = FakeData([[SHAPE, SHAPE, CHANNELS], [1]], 1000, random=False,
                  dtype=['float32', 'uint8'], domain=[(0, 255), (0, 10)])
    # FakeData yields labels of shape (1,); unwrap them to scalars so that a
    # batch of labels has shape (BATCH_SIZE,), matching the (None,) 'input2'
    # placeholder. A (BATCH_SIZE, 1) array cannot be fed into it.
    ds = MapDataComponent(ds, lambda label: label[0], index=1)
    if subset == 'train':
        # With more than one process, PrefetchDataZMQ repeats and interleaves
        # the data from every worker. That is acceptable for training but would
        # duplicate datapoints during evaluation, so only prefetch for training.
        ds = PrefetchDataZMQ(ds, 2)
    ds = BatchData(ds, BATCH_SIZE)
    return ds
def get_config():
    """Assemble the TrainConfig used by this template.

    Sets up the logging directory, builds the train and test DataFlows, and
    wires them into a TrainConfig that saves checkpoints and reports the
    'total_costs' statistic on the test DataFlow, for 100 epochs.
    """
    logger.auto_set_dir()

    train_flow = get_data('train')
    test_flow = get_data('test')

    model = Model()
    train_input = QueueInput(train_flow)
    callback_list = [
        ModelSaver(),
        InferenceRunner(test_flow, [ScalarStats('total_costs')]),
    ]
    # One epoch is one full pass over the training DataFlow.
    steps = train_flow.size()

    return TrainConfig(
        model=model,
        data=train_input,
        callbacks=callback_list,
        steps_per_epoch=steps,
        max_epoch=100,
    )
if __name__ == '__main__':
    # Command-line entry point: parse flags, select GPUs, and launch training.
    parser = argparse.ArgumentParser()
    parser.add_argument('--gpu', help='comma separated list of GPU(s) to use.')
    parser.add_argument('--load', help='load model')
    args = parser.parse_args()

    num_gpus = 1
    if args.gpu:
        # Restrict the visible devices before anything touches the GPUs.
        os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
        # Ignore empty entries so that e.g. "0," counts as a single GPU.
        num_gpus = max(1, len([g for g in args.gpu.split(',') if g.strip()]))

    config = get_config()
    if args.load:
        config.session_init = SaverRestore(args.load)

    # The number of GPUs is a property of the trainer, not of the config:
    # SimpleTrainer always builds a single tower, and setting
    # config.nr_tower does not change that. Use a multi-GPU trainer when
    # more than one GPU was requested.
    if num_gpus > 1:
        trainer = SyncMultiGPUTrainerReplicated(num_gpus)
    else:
        trainer = SimpleTrainer()
    launch_train_with_config(config, trainer)