Skip to content

Commit 90f12c2

Browse files
committed
Use zmq-anyio
1 parent 8c8d7d2 commit 90f12c2

12 files changed

+150
-188
lines changed

ipykernel/debugger.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -241,7 +241,7 @@ async def _send_request(self, msg):
241241
self.log.debug("DEBUGPYCLIENT:")
242242
self.log.debug(self.routing_id)
243243
self.log.debug(buf)
244-
await self.debugpy_socket.send_multipart((self.routing_id, buf))
244+
await self.debugpy_socket.asend_multipart((self.routing_id, buf))
245245

246246
async def _wait_for_response(self):
247247
# Since events are never pushed to the message_queue
@@ -437,7 +437,7 @@ async def start(self):
437437
(self.shell_socket.getsockopt(ROUTING_ID)),
438438
)
439439

440-
msg = await self.shell_socket.recv_multipart()
440+
msg = await self.shell_socket.arecv_multipart()
441441
ident, msg = self.session.feed_identities(msg, copy=True)
442442
try:
443443
msg = self.session.deserialize(msg, content=True, copy=True)

ipykernel/inprocess/session.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
class Session(_Session):
55
async def recv(self, socket, copy=True):
6-
return await socket.recv_multipart()
6+
return await socket.arecv_multipart()
77

88
def send(
99
self,

ipykernel/iostream.py

+34-38
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from typing import Any, Callable
2121

2222
import zmq
23+
import zmq_anyio
2324
from anyio import create_task_group, run, sleep, to_thread
2425
from jupyter_client.session import extract_header
2526

@@ -55,11 +56,11 @@ def run(self):
5556
run(self._main)
5657

5758
async def _main(self):
58-
async with create_task_group() as tg:
59+
async with create_task_group() as self._task_group:
5960
for task in self._tasks:
60-
tg.start_soon(task)
61+
self._task_group.start_soon(task)
6162
await to_thread.run_sync(self.__stop.wait)
62-
tg.cancel_scope.cancel()
63+
self._task_group.cancel_scope.cancel()
6364

6465
def stop(self):
6566
"""Stop the thread.
@@ -78,7 +79,7 @@ class IOPubThread:
7879
whose IO is always run in a thread.
7980
"""
8081

81-
def __init__(self, socket, pipe=False):
82+
def __init__(self, socket: zmq_anyio.Socket, pipe=False):
8283
"""Create IOPub thread
8384
8485
Parameters
@@ -91,10 +92,7 @@ def __init__(self, socket, pipe=False):
9192
"""
9293
# ensure all of our sockets as sync zmq.Sockets
9394
# don't create async wrappers until we are within the appropriate coroutines
94-
self.socket: zmq.Socket[bytes] | None = zmq.Socket(socket)
95-
if self.socket.context is None:
96-
# bug in pyzmq, shadow socket doesn't always inherit context attribute
97-
self.socket.context = socket.context # type:ignore[unreachable]
95+
self.socket: zmq_anyio.Socket = socket
9896
self._context = socket.context
9997

10098
self.background_socket = BackgroundSocket(self)
@@ -108,14 +106,14 @@ def __init__(self, socket, pipe=False):
108106
self._event_pipe_gc_lock: threading.Lock = threading.Lock()
109107
self._event_pipe_gc_seconds: float = 10
110108
self._setup_event_pipe()
111-
tasks = [self._handle_event, self._run_event_pipe_gc]
109+
tasks = [self._handle_event, self._run_event_pipe_gc, self.socket.start]
112110
if pipe:
113111
tasks.append(self._handle_pipe_msgs)
114112
self.thread = _IOPubThread(tasks)
115113

116114
def _setup_event_pipe(self):
117115
"""Create the PULL socket listening for events that should fire in this thread."""
118-
self._pipe_in0 = self._context.socket(zmq.PULL, socket_class=zmq.Socket)
116+
self._pipe_in0 = self._context.socket(zmq.PULL)
119117
self._pipe_in0.linger = 0
120118

121119
_uuid = b2a_hex(os.urandom(16)).decode("ascii")
@@ -150,7 +148,7 @@ def _event_pipe(self):
150148
except AttributeError:
151149
# new thread, new event pipe
152150
# create sync base socket
153-
event_pipe = self._context.socket(zmq.PUSH, socket_class=zmq.Socket)
151+
event_pipe = self._context.socket(zmq.PUSH)
154152
event_pipe.linger = 0
155153
event_pipe.connect(self._event_interface)
156154
self._local.event_pipe = event_pipe
@@ -169,30 +167,28 @@ async def _handle_event(self):
169167
Whenever *an* event arrives on the event stream,
170168
*all* waiting events are processed in order.
171169
"""
172-
# create async wrapper within coroutine
173-
pipe_in = zmq.asyncio.Socket(self._pipe_in0)
174-
try:
175-
while True:
176-
await pipe_in.recv()
177-
# freeze event count so new writes don't extend the queue
178-
# while we are processing
179-
n_events = len(self._events)
180-
for _ in range(n_events):
181-
event_f = self._events.popleft()
182-
event_f()
183-
except Exception:
184-
if self.thread.__stop.is_set():
185-
return
186-
raise
170+
pipe_in = zmq_anyio.Socket(self._pipe_in0)
171+
async with pipe_in:
172+
try:
173+
while True:
174+
await pipe_in.arecv()
175+
# freeze event count so new writes don't extend the queue
176+
# while we are processing
177+
n_events = len(self._events)
178+
for _ in range(n_events):
179+
event_f = self._events.popleft()
180+
event_f()
181+
except Exception:
182+
if self.thread.__stop.is_set():
183+
return
184+
raise
187185

188186
def _setup_pipe_in(self):
189187
"""setup listening pipe for IOPub from forked subprocesses"""
190-
ctx = self._context
191-
192188
# use UUID to authenticate pipe messages
193189
self._pipe_uuid = os.urandom(16)
194190

195-
self._pipe_in1 = ctx.socket(zmq.PULL, socket_class=zmq.Socket)
191+
self._pipe_in1 = zmq_anyio.Socket(self._context.socket(zmq.PULL))
196192
self._pipe_in1.linger = 0
197193

198194
try:
@@ -210,18 +206,18 @@ def _setup_pipe_in(self):
210206
async def _handle_pipe_msgs(self):
211207
"""handle pipe messages from a subprocess"""
212208
# create async wrapper within coroutine
213-
self._async_pipe_in1 = zmq.asyncio.Socket(self._pipe_in1)
214-
try:
215-
while True:
216-
await self._handle_pipe_msg()
217-
except Exception:
218-
if self.thread.__stop.is_set():
219-
return
220-
raise
209+
async with self._pipe_in1:
210+
try:
211+
while True:
212+
await self._handle_pipe_msg()
213+
except Exception:
214+
if self.thread.__stop.is_set():
215+
return
216+
raise
221217

222218
async def _handle_pipe_msg(self, msg=None):
223219
"""handle a pipe message from a subprocess"""
224-
msg = msg or await self._async_pipe_in1.recv_multipart()
220+
msg = msg or await self._pipe_in1.arecv_multipart()
225221
if not self._pipe_flag or not self._is_main_process():
226222
return
227223
if msg[0] != self._pipe_uuid:

ipykernel/ipkernel.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from dataclasses import dataclass
1313

1414
import comm
15-
import zmq.asyncio
15+
import zmq_anyio
1616
from anyio import TASK_STATUS_IGNORED, create_task_group, to_thread
1717
from anyio.abc import TaskStatus
1818
from IPython.core import release
@@ -76,7 +76,7 @@ class IPythonKernel(KernelBase):
7676
help="Set this flag to False to deactivate the use of experimental IPython completion APIs.",
7777
).tag(config=True)
7878

79-
debugpy_socket = Instance(zmq.asyncio.Socket, allow_none=True)
79+
debugpy_socket = Instance(zmq_anyio.Socket, allow_none=True)
8080

8181
user_module = Any()
8282

@@ -212,7 +212,7 @@ def __init__(self, **kwargs):
212212
}
213213

214214
async def process_debugpy(self):
215-
async with create_task_group() as tg:
215+
async with self.debug_shell_socket, self.debugpy_socket, create_task_group() as tg:
216216
tg.start_soon(self.receive_debugpy_messages)
217217
tg.start_soon(self.poll_stopped_queue)
218218
await to_thread.run_sync(self.debugpy_stop.wait)
@@ -235,7 +235,7 @@ async def receive_debugpy_message(self, msg=None):
235235

236236
if msg is None:
237237
assert self.debugpy_socket is not None
238-
msg = await self.debugpy_socket.recv_multipart()
238+
msg = await self.debugpy_socket.arecv_multipart()
239239
# The first frame is the socket id, we can drop it
240240
frame = msg[1].decode("utf-8")
241241
self.log.debug("Debugpy received: %s", frame)

ipykernel/kernelapp.py

+11-48
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from pathlib import Path
1919

2020
import zmq
21-
import zmq.asyncio
21+
import zmq_anyio
2222
from anyio import create_task_group, run
2323
from IPython.core.application import ( # type:ignore[attr-defined]
2424
BaseIPythonApplication,
@@ -325,15 +325,15 @@ def init_sockets(self):
325325
"""Create a context, a session, and the kernel sockets."""
326326
self.log.info("Starting the kernel at pid: %i", os.getpid())
327327
assert self.context is None, "init_sockets cannot be called twice!"
328-
self.context = context = zmq.asyncio.Context()
328+
self.context = context = zmq.Context()
329329
atexit.register(self.close)
330330

331-
self.shell_socket = context.socket(zmq.ROUTER)
331+
self.shell_socket = zmq_anyio.Socket(context.socket(zmq.ROUTER))
332332
self.shell_socket.linger = 1000
333333
self.shell_port = self._bind_socket(self.shell_socket, self.shell_port)
334334
self.log.debug("shell ROUTER Channel on port: %i" % self.shell_port)
335335

336-
self.stdin_socket = zmq.Context(context).socket(zmq.ROUTER)
336+
self.stdin_socket = context.socket(zmq.ROUTER)
337337
self.stdin_socket.linger = 1000
338338
self.stdin_port = self._bind_socket(self.stdin_socket, self.stdin_port)
339339
self.log.debug("stdin ROUTER Channel on port: %i" % self.stdin_port)
@@ -349,18 +349,19 @@ def init_sockets(self):
349349

350350
def init_control(self, context):
351351
"""Initialize the control channel."""
352-
self.control_socket = context.socket(zmq.ROUTER)
352+
self.control_socket = zmq_anyio.Socket(context.socket(zmq.ROUTER))
353353
self.control_socket.linger = 1000
354354
self.control_port = self._bind_socket(self.control_socket, self.control_port)
355355
self.log.debug("control ROUTER Channel on port: %i" % self.control_port)
356356

357-
self.debugpy_socket = context.socket(zmq.STREAM)
357+
self.debugpy_socket = zmq_anyio.Socket(context, zmq.STREAM)
358358
self.debugpy_socket.linger = 1000
359359

360-
self.debug_shell_socket = context.socket(zmq.DEALER)
360+
self.debug_shell_socket = zmq_anyio.Socket(context.socket(zmq.DEALER))
361361
self.debug_shell_socket.linger = 1000
362-
if self.shell_socket.getsockopt(zmq.LAST_ENDPOINT):
363-
self.debug_shell_socket.connect(self.shell_socket.getsockopt(zmq.LAST_ENDPOINT))
362+
last_endpoint = self.shell_socket.getsockopt(zmq.LAST_ENDPOINT)
363+
if last_endpoint:
364+
self.debug_shell_socket.connect(last_endpoint)
364365

365366
if hasattr(zmq, "ROUTER_HANDOVER"):
366367
# set router-handover to workaround zeromq reconnect problems
@@ -373,7 +374,7 @@ def init_control(self, context):
373374

374375
def init_iopub(self, context):
375376
"""Initialize the iopub channel."""
376-
self.iopub_socket = context.socket(zmq.PUB)
377+
self.iopub_socket = zmq_anyio.Socket(context.socket(zmq.PUB))
377378
self.iopub_socket.linger = 1000
378379
self.iopub_port = self._bind_socket(self.iopub_socket, self.iopub_port)
379380
self.log.debug("iopub PUB Channel on port: %i" % self.iopub_port)
@@ -634,43 +635,6 @@ def configure_tornado_logger(self):
634635
handler.setFormatter(formatter)
635636
logger.addHandler(handler)
636637

637-
def _init_asyncio_patch(self):
638-
"""set default asyncio policy to be compatible with tornado
639-
640-
Tornado 6 (at least) is not compatible with the default
641-
asyncio implementation on Windows
642-
643-
Pick the older SelectorEventLoopPolicy on Windows
644-
if the known-incompatible default policy is in use.
645-
646-
Support for Proactor via a background thread is available in tornado 6.1,
647-
but it is still preferable to run the Selector in the main thread
648-
instead of the background.
649-
650-
do this as early as possible to make it a low priority and overridable
651-
652-
ref: https://github.com/tornadoweb/tornado/issues/2608
653-
654-
FIXME: if/when tornado supports the defaults in asyncio without threads,
655-
remove and bump tornado requirement for py38.
656-
Most likely, this will mean a new Python version
657-
where asyncio.ProactorEventLoop supports add_reader and friends.
658-
659-
"""
660-
if sys.platform.startswith("win"):
661-
import asyncio
662-
663-
try:
664-
from asyncio import WindowsProactorEventLoopPolicy, WindowsSelectorEventLoopPolicy
665-
except ImportError:
666-
pass
667-
# not affected
668-
else:
669-
if type(asyncio.get_event_loop_policy()) is WindowsProactorEventLoopPolicy:
670-
# WindowsProactorEventLoopPolicy is not compatible with tornado 6
671-
# fallback to the pre-3.8 default of Selector
672-
asyncio.set_event_loop_policy(WindowsSelectorEventLoopPolicy())
673-
674638
def init_pdb(self):
675639
"""Replace pdb with IPython's version that is interruptible.
676640
@@ -690,7 +654,6 @@ def init_pdb(self):
690654
@catch_config_error
691655
def initialize(self, argv=None):
692656
"""Initialize the application."""
693-
self._init_asyncio_patch()
694657
super().initialize(argv)
695658
if self.subapp is not None:
696659
return

0 commit comments

Comments
 (0)