-
Notifications
You must be signed in to change notification settings - Fork 10
/
Copy pathtrain_marepo.py
executable file
·67 lines (54 loc) · 2.21 KB
/
train_marepo.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
#!/usr/bin/env python3
# Copyright © Niantic, Inc. 2024.
from opt_marepo import get_opts
from marepo.marepo_trainer import TrainerMarepoTransformer
import sys
import time
import logging
# pytorch-lightning
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import TQDMProgressBar, ModelCheckpoint
from pytorch_lightning.strategies import DDPStrategy
_logger = logging.getLogger(__name__)


class MyProgressBar(TQDMProgressBar):
    """TQDM progress bar that hides evaluation bars when stdout is not a TTY.

    When output is redirected (e.g. to a log file), the validation, predict
    and test bars are disabled so they do not flood the log. The training
    bar is left as-is.

    Defined at module level rather than inside the ``__main__`` guard so the
    class remains importable by name (e.g. should the callback ever need to
    be pickled by a process launcher).
    """

    @staticmethod
    def _silence_if_not_tty(bar):
        """Disable ``bar`` when stdout is not an interactive terminal; return it."""
        if not sys.stdout.isatty():
            bar.disable = True
        return bar

    def init_validation_tqdm(self):
        return self._silence_if_not_tty(super().init_validation_tqdm())

    def init_predict_tqdm(self):
        return self._silence_if_not_tty(super().init_predict_tqdm())

    def init_test_tqdm(self):
        return self._silence_if_not_tty(super().init_test_tqdm())


def main():
    """Parse command-line options, build the marepo trainer and run training.

    Logs the total training time once ``trainer.fit`` returns.
    """
    options = get_opts()
    marepo_former = TrainerMarepoTransformer(options)

    # Automatic mixed precision when requested, otherwise full 32-bit.
    precision = "16-mixed" if options.use_half else 32

    # DDP (with unused-parameter detection) only for multi-GPU runs;
    # let Lightning choose the strategy for a single device.
    strategy = (DDPStrategy(find_unused_parameters=True)
                if options.num_gpus > 1 else "auto")

    trainer = Trainer(
        max_epochs=options.epochs,
        check_val_every_n_epoch=options.check_val_every_n_epoch,
        callbacks=[MyProgressBar(refresh_rate=10)],
        enable_model_summary=False,
        accelerator='gpu',
        devices=options.num_gpus,
        strategy=strategy,
        # Sanity-check validation iterations at the start of training:
        # 0 disables the check, -1 runs the whole validation set.
        num_sanity_val_steps=options.num_sanity_val_steps,
        precision=precision,
        # deterministic=True,  # enable to try for reproducible unit-test runs
    )

    # perf_counter is monotonic, so the measured duration is not affected by
    # system clock adjustments during a long training run (time.time() is).
    training_start = time.perf_counter()
    trainer.fit(marepo_former)
    elapsed = time.perf_counter() - training_start
    _logger.info('Done without errors. Total time: %.1f seconds.', elapsed)


if __name__ == '__main__':
    # Setup logging levels before anything else emits log records.
    logging.basicConfig(level=logging.INFO)
    main()