
Commit b06100f

ProcessGroupNCCL,Manager: surface async abort errors correctly
1 parent 3724f7c commit b06100f

File tree: 4 files changed, +66 -0 lines
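
In short, this commit wires the NCCL async-abort path into the commit decision: ProcessGroupNCCL records a RuntimeError("aborted") when abort() is called, exposes it through a new errored() hook, and Manager.should_commit() reports that error so the step is not committed. Below is a minimal sketch of that flow, with class and attribute names simplified from the diffs that follow (not the real implementations):

from typing import Optional


class SketchPG:
    """Simplified stand-in for ProcessGroupNCCL's new error tracking."""

    def __init__(self) -> None:
        self._errored: Optional[Exception] = None

    def abort(self) -> None:
        # An async abort (e.g. after a NCCL timeout) marks the group as errored.
        self._errored = RuntimeError("aborted")

    def errored(self) -> Optional[Exception]:
        return self._errored


class SketchManager:
    """Simplified stand-in for the new check in Manager.should_commit()."""

    def __init__(self, pg: SketchPG) -> None:
        self._pg = pg
        self._errored: Optional[Exception] = None

    def report_error(self, err: Exception) -> None:
        self._errored = err

    def should_commit(self) -> bool:
        # New in this commit: surface async PG errors before committing.
        if err := self._pg.errored():
            self.report_error(err)
        return self._errored is None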

torchft/manager.py (+3)

@@ -605,6 +605,9 @@ def should_commit(self, timeout: Optional[timedelta] = None) -> bool:
 
         self._pending_work = []
 
+        if err := self._pg.errored():
+            self.report_error(err)
+
         # apply state_dict if healing
         if self._healing:
             self._apply_pending_state_dict()

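For context, a hedged sketch of how this check surfaces to user code: once the process group records an async abort, should_commit() reports the error and returns False, so the step is not committed. The loop below is purely illustrative (manager, model, optimizer, and dataloader are assumed to be set up elsewhere; gating optimizer.step() on should_commit() is the torchft-style pattern this change feeds into):

# Illustrative training loop; not part of this commit.
for batch in dataloader:
    optimizer.zero_grad()
    loss = model(batch).sum()
    loss.backward()  # gradients are all-reduced through the managed process group

    # If the PG hit an async abort, errored() is non-None, should_commit()
    # reports it and returns False, and the optimizer step is skipped.
    if manager.should_commit():
        optimizer.step()
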
torchft/manager_test.py (+35)

@@ -42,6 +42,8 @@ def _create_manager(
         timeout: timedelta = timedelta(seconds=10),
     ) -> Manager:
         pg = create_autospec(ProcessGroup)
+        pg.errored.return_value = None
+
         self.store = TCPStore(
             host_name="localhost", port=0, is_master=True, wait_for_workers=False
         )
@@ -408,6 +410,39 @@ def test_allreduce_error(self, client_mock: MagicMock) -> None:
         manager.allreduce(torch.tensor([1.0])).wait()
         self.assertTrue(manager.should_commit())
 
+    @patch("torchft.manager.ManagerClient", autospec=True)
+    def test_pg_errored(self, client_mock: MagicMock) -> None:
+        manager = self._create_manager()
+        client_mock().should_commit = mock_should_commit
+
+        quorum = QuorumResult()
+        quorum.quorum_id = 123
+        quorum.replica_rank = 1
+        quorum.replica_world_size = 2
+        quorum.recover_src_manager_address = "manager address"
+        quorum.store_address = f"localhost:{self.store.port}"
+        quorum.max_step = 1
+        quorum.max_rank = 1
+        quorum.max_world_size = 2
+        quorum.heal = False
+
+        client_mock()._quorum.return_value = quorum
+
+        self.assertEqual(manager._quorum_id, -1)
+        self.assertEqual(manager.current_step(), 0)
+
+        manager.start_quorum()
+
+        injected_failure = RuntimeError("injected failure")
+
+        # pyre-ignore[16]: _pg is mocked
+        manager._pg.errored.return_value = injected_failure
+
+        self.assertFalse(manager.should_commit())
+        self.assertEqual(manager._errored, injected_failure)
+        # pyre-ignore[16]: _pg is mocked
+        self.assertEqual(manager._pg.errored.call_count, 1)
+
     @patch("torchft.manager.ManagerClient", autospec=True)
     def test_quorum_fixed_world_size(self, client_mock: MagicMock) -> None:
         # test active and spares

torchft/process_group.py (+22)

@@ -344,6 +344,12 @@ def shutdown(self) -> None:
         """
         pass
 
+    def errored(self) -> Optional[Exception]:
+        """
+        Whether an async error occurred that requires reconfiguration.
+        """
+        return None
+
     def __repr__(self) -> str:
         return f"{self.__class__.__name__}()"
 
@@ -657,6 +663,8 @@ def __init__(self, timeout: timedelta = timedelta(seconds=60.0)) -> None:
         super().__init__(timeout)
         self._use_abort: bool = torch.cuda.nccl.version() >= (2, 25)
 
+        self._errored: Optional[Exception] = None
+
     def _opts_hook(self, opts: T) -> T:
         if not self._use_abort:
             return opts
@@ -679,6 +687,8 @@ def _wrap_work(self, work: Work, opts: object) -> Work:
         return _WorkCUDATimeout(self, work, timeout)
 
     def _create_pg(self, store: Store, rank: int, world_size: int) -> BaseProcessGroup:
+        self._errored = None
+
         pg = BaseProcessGroup(store, rank, world_size)
         pg._set_default_backend(ProcessGroup.BackendType.NCCL)
         # pyre-fixme[16]: no attribute ProcessGroupNCCL
@@ -689,6 +699,18 @@ def _create_pg(self, store: Store, rank: int, world_size: int) -> BaseProcessGroup:
         )
         return pg
 
+    def abort(self) -> None:
+        super().abort()
+
+        self._errored = RuntimeError("aborted")
+
+    def errored(self) -> Optional[Exception]:
+        pg = self._pg
+        if pg is not None:
+            pg._wait_for_pending_works()
+
+        return self._errored
+
     def getBackendName(self) -> str:
         return "torchft-nccl"

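A small usage sketch of the new hooks in isolation (hypothetical; pg stands for an already-configured torchft ProcessGroupNCCL, and in practice the abort is triggered by the timeout/abort machinery rather than called by hand):

pg.abort()  # the overridden abort() records RuntimeError("aborted")

err = pg.errored()  # waits for pending NCCL work, then returns the stored error
if err is not None:
    # The caller (e.g. the Manager) reports the error and reconfigures the group;
    # _create_pg() clears _errored when the group is rebuilt.
    print(f"process group needs reconfiguration: {err}")
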
torchft/process_group_test.py (+6)

@@ -921,6 +921,8 @@ def _run_with_resiliency(self, collective: str, device: str = "cpu") -> None:
         def worker(pg: ProcessGroup, rank: int, dev: str) -> str:
             if dev == "cuda":
                 torch.cuda.set_device(rank)
+                # Use a separate stream to avoid deadlocks between threads.
+                torch.cuda.set_stream(torch.cuda.Stream())
 
             fault_rank = self.WORLD_SIZE - 1
             test = _COLLECTIVE_TO_FUNC[collective]
@@ -952,6 +954,10 @@ def worker(pg: ProcessGroup, rank: int, dev: str) -> str:
                 test(pg, rank, t1.clone())
                 raise RuntimeError("no error")
 
+            if err := pg.errored():
+                with self.assertRaisesRegex(RuntimeError, "aborted"):
+                    raise err
+
             return f"Rank{rank} final success."
 
         # run in parallel
