31 | 31 | import socket
32 | 32 | import uuid
33 | 33 | from concurrent.futures import ThreadPoolExecutor
   | 34 | +from contextlib import nullcontext
34 | 35 | from datetime import timedelta
35 | 36 | from enum import Enum
36 | 37 | from typing import TYPE_CHECKING, Callable, Dict, List, Optional, TypeVar, cast
@@ -182,6 +183,10 @@ def __init__(
182 | 183 |         self._pg = pg
183 | 184 |         self._manager: Optional[ManagerServer] = None
184 | 185 |
    | 186 | +        self._recovery_stream: Optional["torch.cuda.Stream"] = (
    | 187 | +            torch.cuda.Stream() if torch.cuda.is_available() else None
    | 188 | +        )
    | 189 | +
185 | 190 |         if rank == 0:
186 | 191 |             if port is None:
187 | 192 |                 port = int(os.environ.get(MANAGER_PORT_ENV, 0))
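For context, a minimal sketch (using a hypothetical `_Example` class) of the optional-stream attribute introduced above: the stream is created once in the constructor and stays `None` on CPU-only hosts, so later code only needs an `is not None` check.

```python
from typing import Optional

import torch


class _Example:
    def __init__(self) -> None:
        # Created once up front; stays None when CUDA is unavailable.
        self._recovery_stream: Optional["torch.cuda.Stream"] = (
            torch.cuda.Stream() if torch.cuda.is_available() else None
        )

    def wait_recovery(self) -> None:
        # Callers branch on `is not None` rather than re-probing CUDA.
        if self._recovery_stream is not None:
            self._recovery_stream.synchronize()
```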
@@ -491,53 +496,63 @@ def _async_quorum(
491 | 496 |             self._quorum_id = quorum_id
492 | 497 |
493 | 498 |         if allow_heal:
494 | | -            if quorum.recover_dst_ranks:
495 | | -                self._logger.info(
496 | | -                    f"peers need recovery from us {quorum.recover_dst_ranks}"
497 | | -                )
498 | | -                self._checkpoint_transport.send_checkpoint(
499 | | -                    dst_ranks=quorum.recover_dst_ranks,
500 | | -                    step=max_step,
501 | | -                    state_dict=self._manager_state_dict(),
502 | | -                    timeout=self._timeout,
503 | | -                )
504 | | -
505 | | -            # See manager.rs for healing conditions
506 | | -            if heal:
507 | | -                self._healing = True
508 | | -                self._logger.info(
509 | | -                    f"healing required, fetching checkpoint metadata from {recover_src_manager_address=} {max_step=}"
510 | | -                )
511 | | -                primary_client = ManagerClient(
512 | | -                    recover_src_manager_address, connect_timeout=self._connect_timeout
513 | | -                )
514 | | -                checkpoint_metadata = primary_client._checkpoint_metadata(
515 | | -                    self._rank, timeout=self._timeout
516 | | -                )
517 | | -                recover_src_rank = quorum.recover_src_rank
518 | | -                assert (
519 | | -                    recover_src_rank is not None
520 | | -                ), "must have a recover rank when healing"
521 | | -
522 | | -                self._logger.info(
523 | | -                    f"fetching checkpoint from {recover_src_rank=} with {checkpoint_metadata=}"
524 | | -                )
525 | | -
526 | | -                # we apply the user state dict only when safe from the main thread
527 | | -                # save it for now
528 | | -                self._pending_state_dict = self._checkpoint_transport.recv_checkpoint(
529 | | -                    src_rank=recover_src_rank,
530 | | -                    metadata=checkpoint_metadata,
531 | | -                    step=max_step,
532 | | -                    timeout=self._timeout,
533 | | -                )
534 | | -
535 | | -                # pyre-fixme[6]: got object
536 | | -                self.load_state_dict(self._pending_state_dict["torchft"])
537 | | -
538 | | -                # This isn't strictly needed as loading the state_dict above should
539 | | -                # restore the correct step but it makes writing tests simpler.
540 | | -                self._step = max_step
| 499 | +            # run recovery on the recovery stream if available
| 500 | +            recovery_stream = self._recovery_stream
| 501 | +            with (
| 502 | +                torch.cuda.stream(recovery_stream)
| 503 | +                if recovery_stream is not None
| 504 | +                else nullcontext()
| 505 | +            ):
| 506 | +                if quorum.recover_dst_ranks:
| 507 | +                    self._logger.info(
| 508 | +                        f"peers need recovery from us {quorum.recover_dst_ranks}"
| 509 | +                    )
| 510 | +                    self._checkpoint_transport.send_checkpoint(
| 511 | +                        dst_ranks=quorum.recover_dst_ranks,
| 512 | +                        step=max_step,
| 513 | +                        state_dict=self._manager_state_dict(),
| 514 | +                        timeout=self._timeout,
| 515 | +                    )
| 516 | +
| 517 | +                # See manager.rs for healing conditions
| 518 | +                if heal:
| 519 | +                    self._healing = True
| 520 | +                    self._logger.info(
| 521 | +                        f"healing required, fetching checkpoint metadata from {recover_src_manager_address=} {max_step=}"
| 522 | +                    )
| 523 | +                    primary_client = ManagerClient(
| 524 | +                        recover_src_manager_address,
| 525 | +                        connect_timeout=self._connect_timeout,
| 526 | +                    )
| 527 | +                    checkpoint_metadata = primary_client._checkpoint_metadata(
| 528 | +                        self._rank, timeout=self._timeout
| 529 | +                    )
| 530 | +                    recover_src_rank = quorum.recover_src_rank
| 531 | +                    assert (
| 532 | +                        recover_src_rank is not None
| 533 | +                    ), "must have a recover rank when healing"
| 534 | +
| 535 | +                    self._logger.info(
| 536 | +                        f"fetching checkpoint from {recover_src_rank=} with {checkpoint_metadata=}"
| 537 | +                    )
| 538 | +
| 539 | +                    # we apply the user state dict only when safe from the main thread
| 540 | +                    # save it for now
| 541 | +                    self._pending_state_dict = (
| 542 | +                        self._checkpoint_transport.recv_checkpoint(
| 543 | +                            src_rank=recover_src_rank,
| 544 | +                            metadata=checkpoint_metadata,
| 545 | +                            step=max_step,
| 546 | +                            timeout=self._timeout,
| 547 | +                        )
| 548 | +                    )
| 549 | +
| 550 | +                    # pyre-fixme[6]: got object
| 551 | +                    self.load_state_dict(self._pending_state_dict["torchft"])
| 552 | +
| 553 | +                    # This isn't strictly needed as loading the state_dict above should
| 554 | +                    # restore the correct step but it makes writing tests simpler.
| 555 | +                    self._step = max_step
541 | 556 |
542 | 557 |     def _apply_pending_state_dict(self) -> None:
543 | 558 |         assert self._healing, "must be in healing state"
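For reference, a minimal runnable sketch of the stream-selection pattern used in the block above; `recover` is a hypothetical stand-in for the checkpoint send/recv calls.

```python
from contextlib import nullcontext
from typing import Callable, Optional

import torch


def run_on_recovery_stream(
    recovery_stream: Optional["torch.cuda.Stream"], recover: Callable[[], None]
) -> None:
    # On GPU hosts the recovery work is queued on the side stream so it can
    # overlap with work on the default stream; on CPU-only hosts nullcontext()
    # turns the `with` into a no-op and the code path stays identical.
    with (
        torch.cuda.stream(recovery_stream)
        if recovery_stream is not None
        else nullcontext()
    ):
        recover()
```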
@@ -584,6 +599,10 @@ def should_commit(self, timeout: Optional[timedelta] = None) -> bool:
584 | 599 |             # never return an error.
585 | 600 |             work.wait()
586 | 601 |
    | 602 | +        # make sure recovery is complete before committing
    | 603 | +        if self._recovery_stream is not None:
    | 604 | +            self._recovery_stream.synchronize()
    | 605 | +
587 | 606 |         self._pending_work = []
588 | 607 |
589 | 608 |         # apply state_dict if healing
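A minimal sketch of why the synchronize() call matters, assuming a CUDA host and a hypothetical `recovered` tensor standing in for the received checkpoint: work queued on a side stream runs asynchronously with respect to the host, so the stream has to be drained before the commit decision depends on the recovered state.

```python
import torch

if torch.cuda.is_available():
    recovery_stream = torch.cuda.Stream()
    recovered = torch.zeros(1024, device="cuda")

    # Let the side stream wait for work already queued on the default stream
    # (the allocation/zero-fill above) before it touches the tensor.
    recovery_stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(recovery_stream):
        recovered += 1  # stands in for the asynchronous checkpoint recv

    # ... later, before committing: block the host until all side-stream work
    # has finished so nothing races with the in-flight recovery copies.
    recovery_stream.synchronize()
    assert bool((recovered == 1).all())
```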