"""run.py -- command-line entry point.

Parses the command-line options, builds the run environment, loads the
config file and runs the trainer selected by the config.
"""
import argparse
import os
from omegaconf import OmegaConf
from trainers import trainers_dict
def make_args(argv=None):
    """Parse the command-line options for a run.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse reads the process command line, which
            is what the no-argument call has always done.

    Returns:
        argparse.Namespace with the attributes ``config``, ``name``,
        ``tag``, ``resume``, ``force_replace``, ``wandb``, ``save_root``
        and ``eval_only``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', default='configs/_.yaml')
    # --name and --tag default to None so callers can tell "not given"
    # apart from an explicit value.
    parser.add_argument('--name', '-n', default=None)
    parser.add_argument('--tag', '-t', default=None)
    # Boolean switches: False unless the flag is present.
    parser.add_argument('--resume', '-r', action='store_true')
    parser.add_argument('--force-replace', '-f', action='store_true')
    parser.add_argument('--wandb', '-w', action='store_true')
    parser.add_argument('--save-root', default='save')
    parser.add_argument('--eval-only', action='store_true')
    return parser.parse_args(argv)
def parse_config(config):
    """Resolve ``__base__`` inheritance in a config.

    If ``config`` has a non-None ``__base__`` entry, that entry is removed
    from ``config`` (the argument is mutated) and interpreted as one file
    name or a list of file names. Each file is loaded with
    ``OmegaConf.load`` and resolved recursively, the results are merged in
    list order, and ``config`` is merged on top of them.

    Args:
        config: Mapping-like config object supporting ``get`` and ``pop``.

    Returns:
        ``config`` itself when there is nothing to inherit from, otherwise
        the result of ``OmegaConf.merge`` over the bases and ``config``.
    """
    if config.get('__base__') is None:
        return config

    base_files = config.pop('__base__')
    if isinstance(base_files, str):
        base_files = [base_files]
    if not base_files:
        # An empty base list means there is nothing to inherit;
        # OmegaConf.merge() must not be called with zero configs.
        return config

    resolved_bases = [
        parse_config(OmegaConf.load(base_file))
        for base_file in base_files
    ]
    base_config = OmegaConf.merge(*resolved_bases)
    # config is passed last so it is merged on top of the bases.
    return OmegaConf.merge(base_config, config)
def make_env(args):
    """Build the run-environment dict from parsed command-line arguments.

    The experiment name is ``args.name`` when given, otherwise the base
    name of ``args.config`` without its extension. When ``args.tag`` is
    given, ``'_' + tag`` is appended to the name.

    Args:
        args: Object with the attributes ``config``, ``name``, ``tag``,
            ``save_root``, ``wandb``, ``resume`` and ``force_replace``.

    Returns:
        dict with the keys ``exp_name``, ``save_dir``, ``wandb``,
        ``resume`` and ``force_replace``.
    """
    if args.name is not None:
        exp_name = args.name
    else:
        config_file = os.path.basename(args.config)
        exp_name = os.path.splitext(config_file)[0]

    if args.tag is not None:
        exp_name = exp_name + '_' + args.tag

    return {
        'exp_name': exp_name,
        # Each experiment gets its own directory under the save root.
        'save_dir': os.path.join(args.save_root, exp_name),
        'wandb': args.wandb,
        'resume': args.resume,
        'force_replace': args.force_replace,
    }
if __name__ == '__main__':
    # Script entry point: parse the CLI, derive the run environment, then
    # load the config file with its __base__ inheritance resolved.
    args = make_args()
    env = make_env(args)
    config = parse_config(OmegaConf.load(args.config))
    # config.trainer selects the entry in trainers_dict; that entry is
    # called with (env, config) to build the trainer.
    trainer = trainers_dict[config.trainer](env, config)
    trainer.run(eval_only=args.eval_only)