20
20
from typing import Any , Callable
21
21
22
22
import zmq
23
+ import zmq_anyio
23
24
from anyio import create_task_group , run , sleep , to_thread
24
25
from jupyter_client .session import extract_header
25
26
@@ -55,11 +56,11 @@ def run(self):
55
56
run (self ._main )
56
57
57
58
async def _main (self ):
58
- async with create_task_group () as tg :
59
+ async with create_task_group () as self . _task_group :
59
60
for task in self ._tasks :
60
- tg .start_soon (task )
61
+ self . _task_group .start_soon (task )
61
62
await to_thread .run_sync (self .__stop .wait )
62
- tg .cancel_scope .cancel ()
63
+ self . _task_group .cancel_scope .cancel ()
63
64
64
65
def stop (self ):
65
66
"""Stop the thread.
@@ -78,7 +79,7 @@ class IOPubThread:
78
79
whose IO is always run in a thread.
79
80
"""
80
81
81
- def __init__ (self , socket , pipe = False ):
82
+ def __init__ (self , socket : zmq_anyio . Socket , pipe = False ):
82
83
"""Create IOPub thread
83
84
84
85
Parameters
@@ -91,10 +92,7 @@ def __init__(self, socket, pipe=False):
91
92
"""
92
93
# ensure all of our sockets as sync zmq.Sockets
93
94
# don't create async wrappers until we are within the appropriate coroutines
94
- self .socket : zmq .Socket [bytes ] | None = zmq .Socket (socket )
95
- if self .socket .context is None :
96
- # bug in pyzmq, shadow socket doesn't always inherit context attribute
97
- self .socket .context = socket .context # type:ignore[unreachable]
95
+ self .socket : zmq_anyio .Socket = socket
98
96
self ._context = socket .context
99
97
100
98
self .background_socket = BackgroundSocket (self )
@@ -108,14 +106,14 @@ def __init__(self, socket, pipe=False):
108
106
self ._event_pipe_gc_lock : threading .Lock = threading .Lock ()
109
107
self ._event_pipe_gc_seconds : float = 10
110
108
self ._setup_event_pipe ()
111
- tasks = [self ._handle_event , self ._run_event_pipe_gc ]
109
+ tasks = [self ._handle_event , self ._run_event_pipe_gc , self . socket . start ]
112
110
if pipe :
113
111
tasks .append (self ._handle_pipe_msgs )
114
112
self .thread = _IOPubThread (tasks )
115
113
116
114
def _setup_event_pipe (self ):
117
115
"""Create the PULL socket listening for events that should fire in this thread."""
118
- self ._pipe_in0 = self ._context .socket (zmq .PULL , socket_class = zmq . Socket )
116
+ self ._pipe_in0 = self ._context .socket (zmq .PULL )
119
117
self ._pipe_in0 .linger = 0
120
118
121
119
_uuid = b2a_hex (os .urandom (16 )).decode ("ascii" )
@@ -150,7 +148,7 @@ def _event_pipe(self):
150
148
except AttributeError :
151
149
# new thread, new event pipe
152
150
# create sync base socket
153
- event_pipe = self ._context .socket (zmq .PUSH , socket_class = zmq . Socket )
151
+ event_pipe = self ._context .socket (zmq .PUSH )
154
152
event_pipe .linger = 0
155
153
event_pipe .connect (self ._event_interface )
156
154
self ._local .event_pipe = event_pipe
@@ -169,30 +167,28 @@ async def _handle_event(self):
169
167
Whenever *an* event arrives on the event stream,
170
168
*all* waiting events are processed in order.
171
169
"""
172
- # create async wrapper within coroutine
173
- pipe_in = zmq . asyncio . Socket ( self . _pipe_in0 )
174
- try :
175
- while True :
176
- await pipe_in .recv ()
177
- # freeze event count so new writes don't extend the queue
178
- # while we are processing
179
- n_events = len (self ._events )
180
- for _ in range (n_events ):
181
- event_f = self ._events .popleft ()
182
- event_f ()
183
- except Exception :
184
- if self .thread .__stop .is_set ():
185
- return
186
- raise
170
+ pipe_in = zmq_anyio . Socket ( self . _pipe_in0 )
171
+ async with pipe_in :
172
+ try :
173
+ while True :
174
+ await pipe_in .arecv ()
175
+ # freeze event count so new writes don't extend the queue
176
+ # while we are processing
177
+ n_events = len (self ._events )
178
+ for _ in range (n_events ):
179
+ event_f = self ._events .popleft ()
180
+ event_f ()
181
+ except Exception :
182
+ if self .thread .__stop .is_set ():
183
+ return
184
+ raise
187
185
188
186
def _setup_pipe_in (self ):
189
187
"""setup listening pipe for IOPub from forked subprocesses"""
190
- ctx = self ._context
191
-
192
188
# use UUID to authenticate pipe messages
193
189
self ._pipe_uuid = os .urandom (16 )
194
190
195
- self ._pipe_in1 = ctx . socket (zmq .PULL , socket_class = zmq . Socket )
191
+ self ._pipe_in1 = zmq_anyio . Socket ( self . _context . socket (zmq .PULL ) )
196
192
self ._pipe_in1 .linger = 0
197
193
198
194
try :
@@ -210,18 +206,18 @@ def _setup_pipe_in(self):
210
206
async def _handle_pipe_msgs (self ):
211
207
"""handle pipe messages from a subprocess"""
212
208
# create async wrapper within coroutine
213
- self . _async_pipe_in1 = zmq . asyncio . Socket ( self ._pipe_in1 )
214
- try :
215
- while True :
216
- await self ._handle_pipe_msg ()
217
- except Exception :
218
- if self .thread .__stop .is_set ():
219
- return
220
- raise
209
+ async with self ._pipe_in1 :
210
+ try :
211
+ while True :
212
+ await self ._handle_pipe_msg ()
213
+ except Exception :
214
+ if self .thread .__stop .is_set ():
215
+ return
216
+ raise
221
217
222
218
async def _handle_pipe_msg (self , msg = None ):
223
219
"""handle a pipe message from a subprocess"""
224
- msg = msg or await self ._async_pipe_in1 . recv_multipart ()
220
+ msg = msg or await self ._pipe_in1 . arecv_multipart ()
225
221
if not self ._pipe_flag or not self ._is_main_process ():
226
222
return
227
223
if msg [0 ] != self ._pipe_uuid :
0 commit comments