Skip to content

Commit ad0ca0a

Browse files
authored
ProcessGroupNCCL: use _TimeoutManager to provide sane NCCL abort semantics (#141)
1 parent 9d38340 commit ad0ca0a

File tree

5 files changed

+353
-86
lines changed

5 files changed

+353
-86
lines changed

pyproject.toml

+8
Original file line number | Diff line number | Diff line change
@@ -23,6 +23,7 @@ Issues = "https://github.com/pytorch-labs/torchft/issues"
2323
[project.optional-dependencies]
2424
dev = [
2525
"pytest",
26+
"pytest-timeout",
2627
"black",
2728
"pyre-check",
2829
"parameterized",
@@ -41,3 +42,10 @@ torchft_lighthouse = "torchft._torchft:lighthouse_main"
4142
[tool.isort]
4243
multi_line_output = 3
4344
combine_as_imports = true
45+
46+
[tool.pytest.ini_options]
47+
log_format = "%(asctime)s %(levelname)s %(message)s"
48+
log_date_format = "%Y-%m-%d %H:%M:%S"
49+
log_level = "INFO"
50+
timeout = 60
51+
timeout_method = "thread"

pytest.ini

-4
This file was deleted.

torchft/futures.py

+91-19
Original file line number | Diff line number | Diff line change
@@ -1,9 +1,11 @@
11
import asyncio
22
import threading
3+
from contextlib import contextmanager
34
from datetime import timedelta
4-
from typing import Optional, TypeVar
5+
from typing import Callable, Generator, Optional, TypeVar
56
from unittest.mock import Mock
67

8+
import torch
79
from torch.futures import Future
810

911
T = TypeVar("T")
@@ -12,20 +14,24 @@
1214
class _TimerHandle:
1315
def __init__(self) -> None:
1416
self._lock = threading.Lock()
15-
self._lock.acquire()
1617
self._timer_handle: Optional[asyncio.TimerHandle] = None
18+
self._cancelled = False
1719

18-
def set_timer(self, timer_handle: asyncio.TimerHandle) -> None:
19-
assert self._lock.locked()
20-
21-
self._timer_handle = timer_handle
22-
self._lock.release()
20+
def set_timer_handle(self, timer_handle: asyncio.TimerHandle) -> None:
21+
with self._lock:
22+
if self._cancelled:
23+
timer_handle.cancel()
24+
self._timer_handle = None
25+
else:
26+
self._timer_handle = timer_handle
2327

2428
def cancel(self) -> None:
2529
with self._lock:
26-
assert self._timer_handle is not None
27-
self._timer_handle.cancel()
28-
self._timer_handle = None
30+
assert not self._cancelled, "timer can only be cancelled once"
31+
self._cancelled = True
32+
if self._timer_handle is not None:
33+
self._timer_handle.cancel()
34+
self._timer_handle = None
2935

3036

3137
class _TimeoutManager:
@@ -81,8 +87,16 @@ def register(self, fut: Future[T], timeout: timedelta) -> Future[T]:
8187
# pyre-fixme[29]: Future is not a function
8288
timed_fut: Future[T] = Future()
8389
handle: _TimerHandle = _TimerHandle()
84-
# pyre-fixme[6]: *args
85-
loop.call_soon_threadsafe(self._register, loop, timed_fut, timeout, handle)
90+
loop.call_soon_threadsafe(
91+
self._register_callback,
92+
loop,
93+
lambda: timed_fut.set_exception(
94+
# pyre-fixme[6]: e is not T
95+
TimeoutError(f"future did not complete within {timeout}")
96+
),
97+
timeout,
98+
handle,
99+
)
86100

87101
def callback(fut: Future[T]) -> None:
88102
handle.cancel()
@@ -99,22 +113,48 @@ def callback(fut: Future[T]) -> None:
99113
fut.add_done_callback(callback)
100114
return timed_fut
101115

116+
def stream_timeout(self, callback: Callable[[], None], timeout: timedelta) -> None:
    """
    Arm a timer that invokes ``callback`` if a CUDA event recorded now has
    not completed once ``timeout`` has elapsed.

    The event is recorded at call time. When the timer expires the event is
    queried, and ``callback`` only runs if the event is still pending.

    Args:
        callback: Called by the event loop's timer when the event has not
            completed in time.
        timeout: How long to wait before checking the event.
    """
    event_loop = self._maybe_start_event_loop()

    marker: torch.cuda.Event = torch.cuda.Event()
    marker.record()

    def on_expiry() -> None:
        # A completed event means the recorded work finished in time.
        if marker.query():
            return
        callback()

    # The timer handle is intentionally not retained: this timer is never
    # cancelled, the expiry check above decides whether anything happens.
    event_loop.call_soon_threadsafe(
        self._register_callback, event_loop, on_expiry, timeout, _TimerHandle()
    )
129+
102130
@classmethod
103-
def _register(
131+
def _register_callback(
104132
cls,
105133
loop: asyncio.AbstractEventLoop,
106-
fut: Future[T],
134+
callback: Callable[[], None],
107135
timeout: timedelta,
108136
handle: _TimerHandle,
109137
) -> None:
110138
timer_handle = loop.call_later(
111139
timeout.total_seconds(),
112-
lambda: fut.set_exception(
113-
# pyre-fixme[6]: e is not T
114-
TimeoutError(f"future did not complete within {timeout}")
115-
),
140+
callback,
116141
)
117-
handle.set_timer(timer_handle)
142+
handle.set_timer_handle(timer_handle)
143+
144+
@contextmanager
def context_timeout(
    self, callback: Callable[[], None], timeout: timedelta
) -> Generator[None, None, None]:
    """
    Context manager that arms a timer on entry and cancels it on exit.

    If the ``with`` body has not exited within ``timeout``, ``callback`` is
    invoked by the event loop's timer.

    Args:
        callback: Called if the body is still running when the timer expires.
        timeout: How long the body may run before ``callback`` is called.
    """
    loop = self._maybe_start_event_loop()
    handle = _TimerHandle()

    loop.call_soon_threadsafe(
        self._register_callback, loop, callback, timeout, handle
    )

    try:
        yield
    finally:
        # Cancel on every exit path. Without this, an exception raised in the
        # body would leave the timer armed and ``callback`` would still fire
        # later even though the context has already exited.
        handle.cancel()
118158

119159

120160
_TIMEOUT_MANAGER = _TimeoutManager()
@@ -163,3 +203,35 @@ def callback(fut: Future[T]) -> T:
163203
raise TimeoutError(f"future did not complete within {timeout}")
164204

165205
return fut.wait()
206+
207+
208+
def stream_timeout(callback: Callable[[], None], timeout: timedelta) -> None:
    """
    Schedule ``callback`` to run if the current stream has not finished its
    work once ``timeout`` has elapsed.

    Completion is tracked with a cuda Event; if that event is still pending
    when the timeout expires, the callback is invoked.

    Args:
        callback: Invoked when the stream fails to complete in time.
        timeout: How long to give the stream before checking for completion.
    """
    # Delegates to the module-wide timeout manager.
    _TIMEOUT_MANAGER.stream_timeout(callback=callback, timeout=timeout)
221+
222+
223+
@contextmanager
def context_timeout(
    callback: Callable[[], None], timeout: timedelta
) -> Generator[None, None, None]:
    """
    Context manager that schedules ``callback`` to run if the enclosed block
    has not exited once ``timeout`` has elapsed.

    Args:
        callback: Invoked when the block times out.
        timeout: Maximum time the block may take before ``callback`` is called.
    """
    # Delegates to the module-wide timeout manager's context manager.
    guard = _TIMEOUT_MANAGER.context_timeout(callback, timeout)
    with guard:
        yield

0 commit comments

Comments
 (0)