-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathstart.py
33 lines (30 loc) · 993 Bytes
/
start.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
import os
import random
import torch
import torch.backends.cudnn as cudnn
import torch.distributed as dist
from config import *
from core.worker import main_worker
# from utils import initialize_wandb
if __name__ == '__main__':
    # Distributed training entry point: read the launcher-provided rank
    # environment, initialize the NCCL process group, pin this process to
    # its GPU, optionally seed for reproducibility, then hand off to the
    # worker loop.
    args = get_parser()
    cudnn.benchmark = True
    if args.debug:
        # Debug mode: single process, no launcher env vars available.
        rank = 0
        local_rank = 0
    else:
        # Set by torchrun / torch.distributed.launch.
        rank = int(os.environ["RANK"])
        local_rank = int(os.environ["LOCAL_RANK"])
    args.local_rank = local_rank
    # FIX: prefer the launcher's WORLD_SIZE (correct for multi-node jobs);
    # fall back to the local GPU count for single-node / debug runs, which
    # matches the original behavior.
    world_size = int(os.environ.get("WORLD_SIZE", torch.cuda.device_count()))
    args.nprocs = world_size
    # if args.local_rank==0:
    #     initialize_wandb(args)
    # FIX: the process-group rank must be the *global* rank, not the local
    # one, or multi-node jobs register duplicate rank ids. Identical for
    # single-node runs, where RANK == LOCAL_RANK.
    dist.init_process_group(backend="nccl", init_method=args.init_method,
                            rank=rank, world_size=world_size)
    # FIX: each process drives the GPU matching its *local* rank;
    # `rank % world_size` is only coincidentally correct on one node.
    torch.cuda.set_device(local_rank)
    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        torch.cuda.manual_seed_all(args.seed)  # seed every visible GPU
        # benchmark=True lets cuDNN pick algorithms non-deterministically,
        # defeating the deterministic flag below — turn it off when the
        # user explicitly asked for a seeded, reproducible run.
        cudnn.benchmark = False
        cudnn.deterministic = True
    main_worker(args)