Commit 038d222

manager: use separate stream for recovery (#144)

Authored Mar 20, 2025
1 parent f0a4061 commit 038d222
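
In summary (from the diff below): checkpoint send/receive during recovery in Manager._async_quorum now runs on a dedicated CUDA stream, allocated in __init__ when CUDA is available, instead of on the default stream; should_commit() synchronizes that stream before committing, so a successful commit implies recovery has completed. On CPU-only hosts the stream is None and behavior is unchanged.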

File tree

1 file changed: +66 −47 lines

torchft/manager.py (+66 −47)
@@ -31,6 +31,7 @@
 import socket
 import uuid
 from concurrent.futures import ThreadPoolExecutor
+from contextlib import nullcontext
 from datetime import timedelta
 from enum import Enum
 from typing import TYPE_CHECKING, Callable, Dict, List, Optional, TypeVar, cast
@@ -182,6 +183,10 @@ def __init__(
         self._pg = pg
         self._manager: Optional[ManagerServer] = None
 
+        self._recovery_stream: Optional["torch.cuda.Stream"] = (
+            torch.cuda.Stream() if torch.cuda.is_available() else None
+        )
+
         if rank == 0:
             if port is None:
                 port = int(os.environ.get(MANAGER_PORT_ENV, 0))
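
The __init__ hunk above allocates the stream only when CUDA is present, so CPU-only hosts keep the old behavior. A minimal standalone sketch of that guard (variable names here are illustrative, not torchft's API):

    import torch
    from typing import Optional

    # Allocate a dedicated side stream for recovery transfers only when a GPU
    # is present; on CPU-only hosts this stays None and callers fall back to a
    # no-op context instead of a stream context.
    recovery_stream: Optional["torch.cuda.Stream"] = (
        torch.cuda.Stream() if torch.cuda.is_available() else None
    )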
@@ -491,53 +496,63 @@ def _async_quorum(
             self._quorum_id = quorum_id
 
         if allow_heal:
-            if quorum.recover_dst_ranks:
-                self._logger.info(
-                    f"peers need recovery from us {quorum.recover_dst_ranks}"
-                )
-                self._checkpoint_transport.send_checkpoint(
-                    dst_ranks=quorum.recover_dst_ranks,
-                    step=max_step,
-                    state_dict=self._manager_state_dict(),
-                    timeout=self._timeout,
-                )
-
-            # See manager.rs for healing conditions
-            if heal:
-                self._healing = True
-                self._logger.info(
-                    f"healing required, fetching checkpoint metadata from {recover_src_manager_address=} {max_step=}"
-                )
-                primary_client = ManagerClient(
-                    recover_src_manager_address, connect_timeout=self._connect_timeout
-                )
-                checkpoint_metadata = primary_client._checkpoint_metadata(
-                    self._rank, timeout=self._timeout
-                )
-                recover_src_rank = quorum.recover_src_rank
-                assert (
-                    recover_src_rank is not None
-                ), "must have a recover rank when healing"
-
-                self._logger.info(
-                    f"fetching checkpoint from {recover_src_rank=} with {checkpoint_metadata=}"
-                )
-
-                # we apply the user state dict only when safe from the main thread
-                # save it for now
-                self._pending_state_dict = self._checkpoint_transport.recv_checkpoint(
-                    src_rank=recover_src_rank,
-                    metadata=checkpoint_metadata,
-                    step=max_step,
-                    timeout=self._timeout,
-                )
-
-                # pyre-fixme[6]: got object
-                self.load_state_dict(self._pending_state_dict["torchft"])
-
-                # This isn't strictly needed as loading the state_dict above should
-                # restore the correct step but it makes writing tests simpler.
-                self._step = max_step
+            # run recovery on the recovery stream if available
+            recovery_stream = self._recovery_stream
+            with (
+                torch.cuda.stream(recovery_stream)
+                if recovery_stream is not None
+                else nullcontext()
+            ):
+                if quorum.recover_dst_ranks:
+                    self._logger.info(
+                        f"peers need recovery from us {quorum.recover_dst_ranks}"
+                    )
+                    self._checkpoint_transport.send_checkpoint(
+                        dst_ranks=quorum.recover_dst_ranks,
+                        step=max_step,
+                        state_dict=self._manager_state_dict(),
+                        timeout=self._timeout,
+                    )
+
+                # See manager.rs for healing conditions
+                if heal:
+                    self._healing = True
+                    self._logger.info(
+                        f"healing required, fetching checkpoint metadata from {recover_src_manager_address=} {max_step=}"
+                    )
+                    primary_client = ManagerClient(
+                        recover_src_manager_address,
+                        connect_timeout=self._connect_timeout,
+                    )
+                    checkpoint_metadata = primary_client._checkpoint_metadata(
+                        self._rank, timeout=self._timeout
+                    )
+                    recover_src_rank = quorum.recover_src_rank
+                    assert (
+                        recover_src_rank is not None
+                    ), "must have a recover rank when healing"
+
+                    self._logger.info(
+                        f"fetching checkpoint from {recover_src_rank=} with {checkpoint_metadata=}"
+                    )
+
+                    # we apply the user state dict only when safe from the main thread
+                    # save it for now
+                    self._pending_state_dict = (
+                        self._checkpoint_transport.recv_checkpoint(
+                            src_rank=recover_src_rank,
+                            metadata=checkpoint_metadata,
+                            step=max_step,
+                            timeout=self._timeout,
+                        )
+                    )
+
+                    # pyre-fixme[6]: got object
+                    self.load_state_dict(self._pending_state_dict["torchft"])
+
+                    # This isn't strictly needed as loading the state_dict above should
+                    # restore the correct step but it makes writing tests simpler.
+                    self._step = max_step
 
     def _apply_pending_state_dict(self) -> None:
         assert self._healing, "must be in healing state"
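
The rewritten _async_quorum body wraps both the send path and the heal path in one context manager: torch.cuda.stream(...) when a recovery stream exists, contextlib.nullcontext() otherwise. A self-contained sketch of that dispatch, with a hypothetical helper name:

    import torch
    from contextlib import nullcontext
    from typing import Callable, Optional

    def on_stream(stream: Optional["torch.cuda.Stream"], fn: Callable[[], None]) -> None:
        # CUDA work issued inside the block is enqueued on `stream` rather than
        # the default stream; with nullcontext() on CPU this is just a plain call.
        with torch.cuda.stream(stream) if stream is not None else nullcontext():
            fn()

Because work enqueued this way runs asynchronously with respect to the default stream, anything that depends on recovery finishing needs an explicit synchronization point, which is what the should_commit() hunk below adds.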
@@ -584,6 +599,10 @@ def should_commit(self, timeout: Optional[timedelta] = None) -> bool:
             # never return an error.
             work.wait()
 
+        # make sure recovery is complete before committing
+        if self._recovery_stream is not None:
+            self._recovery_stream.synchronize()
+
         self._pending_work = []
 
         # apply state_dict if healing
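
Taken together, the three hunks form a standard side-stream pattern: enqueue recovery on its own stream, let training continue on the default stream, and block at the commit point. A minimal end-to-end sketch under those assumptions (not torchft's API):

    import torch
    from contextlib import nullcontext

    stream = torch.cuda.Stream() if torch.cuda.is_available() else None

    with torch.cuda.stream(stream) if stream is not None else nullcontext():
        pass  # checkpoint send/recv work would be enqueued here

    # ... training work proceeds on the default stream concurrently ...

    if stream is not None:
        # Block the host until everything queued on the side stream finishes,
        # so a successful commit implies recovery is complete.
        stream.synchronize()

A torch.cuda.Event recorded on the side stream and waited on by the default stream could avoid blocking the host, but a plain synchronize() is the simplest correctness barrier before committing.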
