
Commit 15c4b33

Update on "[WIP][RFC] TorchFT integration"
**Summary**
This is a WIP TorchFT integration PR.

**Current Issues**
This does not work yet: existing groups hang when a new group joins. (A minimal sketch of where `should_commit` sits in the training loop follows after this commit message.)

**Issue 1:**
~Group 0 and group 1 hang during the first `should_commit` after group 1 applies the pending state_dict from group 0.~
Fixed with: pytorch/torchft#83

**Issue 2:**
~Group 0 and group 1 pass `should_commit`, but group 0 is incorrectly marked as needing healing, and the healing process causes another hang.~
Fixed with: pytorch/torchft#83

**Issue 3:**
A byproduct of issues 1 and 2: group 1 keeps printing

```
[rank0]:devgpu051:76838:80357 [0] misc/socket.cc:50 NCCL WARN socketProgress: Connection closed by remote peer devgpu051.cln3.svc.fbinfra.net<33618>
```

***How to reproduce?***
Follow the steps in `Reproduce steps` to run 2 groups, then kill either group after both start training. Remember to apply pytorch/torchft#83.

**Issue 4:**
When there are 3 groups, every group requests the state dict on every step.

***How to reproduce?***
Follow the `Reproduce steps` to run 2 groups, then add another group by modifying the command.

**Reproduce steps:**

1. Patch TorchFT with pytorch/torchft#82
2. Execute the lighthouse.
3. Execute the following command in one terminal:
   ```
   TORCHFT_MANAGER_PORT=29520 REPLICA_GROUP_ID=0 CUDA_VISIBLE_DEVICES=0,1 NGPU=2 ./run_llama_train.sh --training.data_parallel_shard_degree=2 --experimental.enable_torchft --experimental.ft_replica_group_id=0
   ```
4. Wait 10 seconds, then execute the following command in another terminal:
   ```
   TORCHFT_MANAGER_PORT=29522 REPLICA_GROUP_ID=1 CUDA_VISIBLE_DEVICES=2,3 NGPU=2 ./run_llama_train.sh --training.data_parallel_shard_degree=2 --experimental.enable_torchft --experimental.ft_replica_group_id=1
   ```

[ghstack-poisoned]
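For context on the issues above, here is a minimal sketch of where `should_commit` sits in a fault-tolerant training step. This is an illustrative assumption based only on the commit message, not the actual torchtitan integration code; `train_loop`, `model`, `optimizer`, `dataloader`, and the `ft_manager` parameter are hypothetical placeholders, and only `should_commit()` is taken from the text above.

```python
import torch.nn as nn


def train_loop(ft_manager, model: nn.Module, optimizer, dataloader) -> None:
    """Illustrative sketch only: per-step commit gating with a TorchFT manager."""
    for batch in dataloader:
        optimizer.zero_grad()
        loss = model(batch).sum()
        loss.backward()
        # should_commit() is the call described in Issues 1 and 2: it should
        # return True only once the live replica groups agree the step is
        # valid, and only then is the optimizer update applied.
        if ft_manager.should_commit():
            optimizer.step()
```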
2 parents: 4b2edcb + f7ae033

File tree: 1 file changed, +2 -2 lines changed

train.py

```diff
@@ -46,6 +46,8 @@ def main(job_config: JobConfig):
     # take control of garbage collection to avoid stragglers
     gc_handler = utils.GarbageCollection(gc_freq=job_config.training.gc_freq)
 
+    device = torch.device(f"{device_type}:{int(os.environ['LOCAL_RANK'])}")
+    device_module.set_device(device)
     ft_manager = init_ft_manager(job_config)
 
     # init distributed
@@ -60,8 +62,6 @@ def main(job_config: JobConfig):
         enable_loss_parallel=not job_config.training.disable_loss_parallel,
         ft_manager=ft_manager,
     )
-    device = torch.device(f"{device_type}:{int(os.environ['LOCAL_RANK'])}")
-    device_module.set_device(device)
     utils.init_distributed(job_config)
     # initialize device memory monitor and get peak flops for MFU calculation
     device_memory_monitor = build_device_memory_monitor()
```
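The net effect of this two-line move, reconstructed from the hunks above: device selection now happens before the TorchFT manager is created rather than after the parallel setup. The commit message does not state the motivation; a plausible reading is that `init_ft_manager` needs the correct accelerator already selected, but that is an assumption. Unrelated context lines are elided below.

```python
# Resulting order in train.py's main() after this commit (reconstructed from
# the diff above; unrelated lines elided with "..."):
gc_handler = utils.GarbageCollection(gc_freq=job_config.training.gc_freq)

device = torch.device(f"{device_type}:{int(os.environ['LOCAL_RANK'])}")
device_module.set_device(device)  # moved up: select the local device first
ft_manager = init_ft_manager(job_config)  # then create the TorchFT manager
...
utils.init_distributed(job_config)
```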
