1
1
import asyncio
2
2
import threading
3
+ from contextlib import contextmanager
3
4
from datetime import timedelta
4
- from typing import Optional , TypeVar
5
+ from typing import Callable , Generator , Optional , TypeVar
5
6
from unittest .mock import Mock
6
7
8
+ import torch
7
9
from torch .futures import Future
8
10
9
11
T = TypeVar ("T" )
12
14
class _TimerHandle:
    """
    Lock-guarded holder for an ``asyncio.TimerHandle``.

    The timer is scheduled on the event loop asynchronously, so ``cancel()``
    may run before ``set_timer_handle()`` has delivered the real handle. The
    ``_cancelled`` flag records that case so a late-arriving handle is
    cancelled on arrival instead of being stored.
    """

    def __init__(self) -> None:
        self._cancelled = False
        self._timer_handle: Optional[asyncio.TimerHandle] = None
        self._lock = threading.Lock()

    def set_timer_handle(self, timer_handle: asyncio.TimerHandle) -> None:
        """
        Attach the scheduled timer to this handle.

        Args:
            timer_handle: the timer created by the event loop. It is
                cancelled right away if ``cancel()`` has already been called.
        """
        with self._lock:
            if not self._cancelled:
                self._timer_handle = timer_handle
                return

            # cancel() won the race: drop the timer instead of keeping it.
            timer_handle.cancel()
            self._timer_handle = None

    def cancel(self) -> None:
        """
        Cancel the timer, or mark it to be cancelled once it is attached.

        Must be called at most once; a second call trips the assertion.
        """
        with self._lock:
            assert not self._cancelled, "timer can only be cancelled once"
            self._cancelled = True

            pending = self._timer_handle
            if pending is None:
                # Nothing attached yet; set_timer_handle() will cancel it.
                return

            pending.cancel()
            self._timer_handle = None
29
35
30
36
31
37
class _TimeoutManager :
@@ -81,8 +87,16 @@ def register(self, fut: Future[T], timeout: timedelta) -> Future[T]:
81
87
# pyre-fixme[29]: Future is not a function
82
88
timed_fut : Future [T ] = Future ()
83
89
handle : _TimerHandle = _TimerHandle ()
84
- # pyre-fixme[6]: *args
85
- loop .call_soon_threadsafe (self ._register , loop , timed_fut , timeout , handle )
90
+ loop .call_soon_threadsafe (
91
+ self ._register_callback ,
92
+ loop ,
93
+ lambda : timed_fut .set_exception (
94
+ # pyre-fixme[6]: e is not T
95
+ TimeoutError (f"future did not complete within { timeout } " )
96
+ ),
97
+ timeout ,
98
+ handle ,
99
+ )
86
100
87
101
def callback (fut : Future [T ]) -> None :
88
102
handle .cancel ()
@@ -99,22 +113,48 @@ def callback(fut: Future[T]) -> None:
99
113
fut .add_done_callback (callback )
100
114
return timed_fut
101
115
116
def stream_timeout(self, callback: Callable[[], None], timeout: timedelta) -> None:
    """
    Invoke ``callback`` if a CUDA event recorded now has not completed once
    ``timeout`` has elapsed.

    Args:
        callback: called from the timer when the event is still pending.
        timeout: how long to wait before checking the event.
    """
    loop = self._maybe_start_event_loop()

    # Record an event now; it is polled when the timer fires.
    marker: torch.cuda.Event = torch.cuda.Event()
    marker.record()

    def _on_deadline() -> None:
        if marker.query():
            # The event already completed, so there is nothing to report.
            return
        callback()

    # The timer handle is not kept, so this timer is never cancelled;
    # _on_deadline decides whether anything happens when it fires.
    loop.call_soon_threadsafe(
        self._register_callback, loop, _on_deadline, timeout, _TimerHandle()
    )
129
+
102
130
@classmethod
def _register_callback(
    cls,
    loop: asyncio.AbstractEventLoop,
    callback: Callable[[], None],
    timeout: timedelta,
    handle: _TimerHandle,
) -> None:
    """
    Schedule ``callback`` on ``loop`` after ``timeout`` and hand the
    resulting timer to ``handle``.

    Args:
        loop: event loop on which the timer is created.
        callback: zero-argument function to run when the timer fires.
        timeout: delay before ``callback`` is run.
        handle: receives the created timer so it can be cancelled later.
    """
    delay_seconds = timeout.total_seconds()
    scheduled = loop.call_later(delay_seconds, callback)
    handle.set_timer_handle(scheduled)
143
+
144
@contextmanager
def context_timeout(
    self, callback: Callable[[], None], timeout: timedelta
) -> Generator[None, None, None]:
    """
    Context manager that calls ``callback`` if the managed block has not
    exited within ``timeout``.

    The timer is cancelled when the block exits, whether it finishes
    normally or raises, so ``callback`` does not fire for a block that has
    already been left.

    Args:
        callback: zero-argument function to call on timeout.
        timeout: how long the block may run before ``callback`` is called.
    """
    loop = self._maybe_start_event_loop()
    handle = _TimerHandle()

    loop.call_soon_threadsafe(
        self._register_callback, loop, callback, timeout, handle
    )

    try:
        yield
    finally:
        # Without the finally, an exception raised inside the block would
        # skip the cancel and leave the timer armed after the context exits.
        handle.cancel()
118
158
119
159
120
160
# Module-level _TimeoutManager instance that the module-level helper
# functions (e.g. stream_timeout, context_timeout) delegate to.
_TIMEOUT_MANAGER = _TimeoutManager ()
@@ -163,3 +203,35 @@ def callback(fut: Future[T]) -> T:
163
203
raise TimeoutError (f"future did not complete within { timeout } " )
164
204
165
205
return fut .wait ()
206
+
207
+
208
def stream_timeout(callback: Callable[[], None], timeout: timedelta) -> None:
    """
    Call ``callback`` if the current stream has not finished within
    ``timeout``.

    Completion of the current stream is tracked with a cuda Event. When the
    timeout elapses and the stream is still not complete, the callback is
    called.

    Args:
        callback: The callback to call if the stream doesn't complete in time.
        timeout: The timeout to wait for the stream to complete.
    """
    _TIMEOUT_MANAGER.stream_timeout(callback=callback, timeout=timeout)
221
+
222
+
223
@contextmanager
def context_timeout(
    callback: Callable[[], None], timeout: timedelta
) -> Generator[None, None, None]:
    """
    Context manager that calls ``callback`` after ``timeout`` if the managed
    block doesn't exit in time.

    This delegates to the shared timeout manager's context manager of the
    same name.

    Args:
        callback: The callback to call if we time out.
        timeout: How long to wait for the contextmanager to exit.
    """
    guard = _TIMEOUT_MANAGER.context_timeout(callback=callback, timeout=timeout)
    with guard:
        yield
0 commit comments