-
Notifications
You must be signed in to change notification settings - Fork 34
/
Copy pathmain.py
113 lines (101 loc) · 3.98 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
# Copyright 2016 Timothy Dozat
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import re
import os
import sys
import codecs
from argparse import ArgumentParser
from parser import Configurable
from parser import Network
# TODO make the pretrained vocab names a list given to TokenVocab
#***************************************************************
# Set up the argparser
# The first positional argument of ArgumentParser is `prog`, so usage lines
# read "usage: Network ...".
argparser = ArgumentParser('Network')
argparser.add_argument('--save_dir', required=True)
subparsers = argparser.add_subparsers()

# Collect the names of every "[Section Name]" header in the defaults config,
# normalized to lower_snake_case. Each becomes a command-line option of the
# form:
#   --section_name opt1=value1 opt2=value2 opt3=value3
# NOTE(review): the path is relative to the current working directory, so
# the script presumably must be run from the repository root — confirm.
section_names = set()
with codecs.open('config/defaults.cfg') as f:
    # Raw string: '\[' in a plain literal is an invalid escape sequence
    # (DeprecationWarning on Python 3.6+, SyntaxWarning on 3.12+). The
    # compiled pattern is unchanged.
    section_regex = re.compile(r'\[(.*)\]')
    for line in f:
        match = section_regex.match(line)
        if match:
            section_names.add(match.group(1).lower().replace(' ', '_'))
#===============================================================
# Train
#---------------------------------------------------------------
def train(save_dir, **kwargs):
    """Build a Network from the parsed options and train it.

    Args:
      save_dir: directory for this run. Unless ``load`` is set, an existing
        directory triggers a confirmation prompt, and any ``config.cfg``
        inside it is deleted before the Network is built.
      **kwargs: must contain ``load`` (bool, forwarded to
        ``Network.train``). May contain ``config_file``, which is set to ''
        when absent. Every remaining entry is passed to ``Network(...)``.

    Exits the process with status 0 if the user presses Ctrl-C at the
    prompt.
    """
    # raw_input exists only on Python 2; Python 3 renamed it to input.
    try:
        read_line = raw_input
    except NameError:
        read_line = input

    kwargs['config_file'] = kwargs.pop('config_file', '')
    load = kwargs.pop('load')
    config_path = os.path.join(save_dir, 'config.cfg')
    try:
        if not load and os.path.isdir(save_dir):
            read_line('Save directory already exists. Press <Enter> to continue or <Ctrl-c> to abort.')
            # NOTE(review): presumably removed so a stale config from a
            # previous run is not picked up by Network — confirm.
            if os.path.isfile(config_path):
                os.remove(config_path)
    except KeyboardInterrupt:
        # Move past the "^C" echoed on the prompt line before exiting.
        print()
        sys.exit(0)
    network = Network(**kwargs)
    network.train(load=load)
    return
#---------------------------------------------------------------
# `train` subcommand: dispatches to train() via the `action` default.
train_parser = subparsers.add_parser('train')
train_parser.set_defaults(action=train)
train_parser.add_argument('--load', action='store_true')
train_parser.add_argument('--config_file')
# One --section_name option per config section, each taking opt=value pairs.
# Sorted so that --help output is stable: iterating the set directly yields
# an order that can change between runs under string hash randomization.
for section_name in sorted(section_names):
    train_parser.add_argument('--'+section_name, nargs='+')
#===============================================================
# Parse
#---------------------------------------------------------------
def parse(save_dir, **kwargs):
    """Parse input files with the network saved in ``save_dir``.

    The network is rebuilt from ``<save_dir>/config.cfg`` with
    ``is_evaluation=True``.

    Args:
      save_dir: directory holding the saved run's ``config.cfg``.
      **kwargs: must contain ``files`` (a list of paths). May contain
        ``output_file`` and ``output_dir`` (both default to None). Every
        remaining entry is passed to ``Network(...)``.

    Raises:
      ValueError: if ``output_file`` is given together with more than one
        input file, since a single output path cannot hold several results.
    """
    kwargs.update(
        config_file=os.path.join(save_dir, 'config.cfg'),
        is_evaluation=True,
    )
    files = kwargs.pop('files')
    output_file, output_dir = [kwargs.pop(name, None)
                               for name in ('output_file', 'output_dir')]
    if len(files) > 1 and output_file is not None:
        raise ValueError('Cannot provide a value for --output_file when parsing multiple files')
    Network(**kwargs).parse(files, output_file=output_file, output_dir=output_dir)
#---------------------------------------------------------------
# `parse` subcommand: dispatches to parse() via the `action` default.
parse_parser = subparsers.add_parser('parse')
parse_parser.set_defaults(action=parse)
parse_parser.add_argument('files', nargs='+')
# One --section_name option per config section, each taking opt=value pairs.
# Sorted so that --help output is stable: iterating the set directly yields
# an order that can change between runs under string hash randomization.
# NOTE(review): with nargs='+', an override option written before the
# positional files will greedily consume them — list the files first.
for section_name in sorted(section_names):
    parse_parser.add_argument('--'+section_name, nargs='+')
parse_parser.add_argument('--output_file')
parse_parser.add_argument('--output_dir')
#***************************************************************
# Parse the arguments
kwargs = vars(argparser.parse_args())
# Python 3's argparse (unlike Python 2's) does not require a subcommand, so
# 'action' is absent when neither `train` nor `parse` was given.
action = kwargs.pop('action', None)
if action is None:
    argparser.error('a subcommand is required (train or parse)')
save_dir = kwargs.pop('save_dir')
# Drop options the user did not supply, so the actions' own
# kwargs.pop(..., default) fallbacks apply.
kwargs = {key: value for key, value in kwargs.items() if value is not None}
# Turn each `--section opt1=v1 opt2=v2` list into {'opt1': 'v1', 'opt2': 'v2'}.
# Iterate over a snapshot of the keys because values are replaced in place.
for section in list(kwargs):
    if section not in section_names:
        continue
    options = {}
    for item in kwargs[section]:
        # partition splits on the first '=' only, so values may contain '='.
        opt, sep, value = item.partition('=')
        if not sep:
            argparser.error('--%s expects opt=value pairs, got %r' % (section, item))
        options[opt] = value
    kwargs[section] = options
# save_dir always overrides any user-supplied default-section save_dir.
kwargs.setdefault('default', {})['save_dir'] = save_dir
action(save_dir, **kwargs)