diff --git a/torchft/futures.py b/torchft/futures.py
index df41a53..9409190 100644
--- a/torchft/futures.py
+++ b/torchft/futures.py
@@ -1,4 +1,6 @@
 import asyncio
+import queue
+import sys
 import threading
 from contextlib import contextmanager
 from datetime import timedelta
@@ -36,8 +38,13 @@ def cancel(self) -> None:
 
 class _TimeoutManager:
     """
-    This class manages timeouts for futures. It uses a background thread with an
-    event loop to schedule the timeouts.
+    This class manages timeouts for code blocks, futures, and CUDA events. It
+    uses a background thread with an event loop to schedule the timeouts and
+    call the callback function when the timeout is reached.
+
+    Generally there is a single instance of this class that is used for all
+    timeouts. The callbacks should not block, otherwise other timeouts may not
+    be processed.
     """
 
     def __init__(self) -> None:
@@ -46,6 +53,10 @@ def __init__(self) -> None:
         self._event_loop_thread: Optional[threading.Thread] = None
         self._next_timer_id = 0
 
+        # This queue is used to delete events on the main thread as cudaEventDestroy
+        # can block if the CUDA queue is full.
+        self._del_queue: queue.SimpleQueue[object] = queue.SimpleQueue()
+
     def _maybe_start_event_loop(self) -> asyncio.AbstractEventLoop:
         """
         Start the event loop if it has not already been started.
@@ -82,6 +93,8 @@ def register(self, fut: Future[T], timeout: timedelta) -> Future[T]:
         if isinstance(fut, Mock):
             return fut
 
+        self._clear_del_queue()
+
         loop = self._maybe_start_event_loop()
 
         # pyre-fixme[29]: Future is not a function
@@ -114,6 +127,8 @@ def callback(fut: Future[T]) -> None:
         return timed_fut
 
     def stream_timeout(self, callback: Callable[[], None], timeout: timedelta) -> None:
+        self._clear_del_queue()
+
         loop = self._maybe_start_event_loop()
 
         event: torch.cuda.Event = torch.cuda.Event()
@@ -123,6 +138,11 @@ def handler() -> None:
             if not event.query():
                 callback()
 
+            # cudaEventDestroy can block so we never want to delete in the event
+            # loop. Put it on the del queue so we can delete it in the main
+            # thread.
+            self._del_queue.put(event)
+
         loop.call_soon_threadsafe(
             self._register_callback, loop, handler, timeout, _TimerHandle()
         )
@@ -145,6 +165,8 @@ def _register_callback(
     def context_timeout(
         self, callback: Callable[[], None], timeout: timedelta
     ) -> Generator[None, None, None]:
+        self._clear_del_queue()
+
         loop = self._maybe_start_event_loop()
 
         handle = _TimerHandle()
@@ -156,6 +178,31 @@ def context_timeout(
             handle.cancel()
 
+    def _clear_del_queue(self) -> int:
+        """
+        Clear the queue of events to be deleted.
+
+        Returns the number of items deleted.
+        """
+        count = 0
+        while True:
+            try:
+                # get and immediately discard item
+                item = self._del_queue.get_nowait()
+                refcount = sys.getrefcount(item)
+                assert (
+                    # 1 from item, 1 from getrefcount
+                    refcount
+                    == 2
+                ), f"items in del_queue should not have any other references, found {refcount=}"
+                del item
+
+                count += 1
+            except queue.Empty:
+                break
+
+        return count
+
 
 _TIMEOUT_MANAGER = _TimeoutManager()
 
 
diff --git a/torchft/futures_test.py b/torchft/futures_test.py
index b9ea191..935079a 100644
--- a/torchft/futures_test.py
+++ b/torchft/futures_test.py
@@ -1,9 +1,17 @@
+import threading
 from datetime import timedelta
-from unittest import TestCase
+from unittest import TestCase, skipUnless
 
+import torch
 from torch.futures import Future
 
-from torchft.futures import future_timeout, future_wait
+from torchft.futures import (
+    _TIMEOUT_MANAGER,
+    context_timeout,
+    future_timeout,
+    future_wait,
+    stream_timeout,
+)
 
 
 class FuturesTest(TestCase):
@@ -45,3 +53,39 @@ def test_future_timeout_exception(self) -> None:
         fut.set_exception(RuntimeError("test"))
         with self.assertRaisesRegex(RuntimeError, "test"):
             timed_fut.wait()
+
+    def test_context_timeout(self) -> None:
+        barrier: threading.Barrier = threading.Barrier(2)
+
+        def callback() -> None:
+            barrier.wait()
+
+        with context_timeout(callback, timedelta(seconds=0.01)):
+            # block until timeout fires
+            barrier.wait()
+
+        def fail() -> None:
+            self.fail("timeout should be cancelled")
+
+        with context_timeout(fail, timedelta(seconds=10)):
+            pass
+
+    # pyre-fixme[56]: Pyre was not able to infer the type of decorator
+    @skipUnless(torch.cuda.is_available(), "CUDA is required for this test")
+    def test_stream_timeout(self) -> None:
+        torch.cuda.synchronize()
+
+        def callback() -> None:
+            self.fail()
+
+        stream_timeout(callback, timeout=timedelta(seconds=0.01))
+
+        # make sure event completes
+        torch.cuda.synchronize()
+
+        # make sure the event was put on the deletion queue
+        item = _TIMEOUT_MANAGER._del_queue.get(timeout=10.0)
+        _TIMEOUT_MANAGER._del_queue.put(item)
+        del item
+
+        self.assertEqual(_TIMEOUT_MANAGER._clear_del_queue(), 1)