Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Typing highlevel open tcp stream #2725

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ module = [
"trio._deprecate",
"trio._dtls",
"trio._file_io",
"trio._highlevel_open_tcp_stream.py",
"trio._ki",
"trio._socket",
"trio._sync",
Expand Down
67 changes: 50 additions & 17 deletions trio/_highlevel_open_tcp_stream.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,14 @@
from __future__ import annotations

import sys
from collections.abc import Generator
from contextlib import contextmanager
from socket import AddressFamily, SocketKind
from typing import TYPE_CHECKING

import trio
from trio._core._multierror import MultiError
from trio.socket import SOCK_STREAM, getaddrinfo, socket
from trio.socket import SOCK_STREAM, Address, _SocketType, getaddrinfo, socket

if sys.version_info < (3, 11):
from exceptiongroup import ExceptionGroup
Expand Down Expand Up @@ -109,8 +114,8 @@


@contextmanager
def close_all():
sockets_to_close = set()
def close_all() -> Generator[set[_SocketType], None, None]:
sockets_to_close: set[_SocketType] = set()
try:
yield sockets_to_close
finally:
Expand All @@ -126,7 +131,17 @@ def close_all():
raise MultiError(errs)


def reorder_for_rfc_6555_section_5_4(targets):
def reorder_for_rfc_6555_section_5_4(
targets: list[
tuple[
AddressFamily,
SocketKind,
int,
str,
tuple[str, int] | tuple[str, int, int, int],
CoolCat467 marked this conversation as resolved.
Show resolved Hide resolved
]
]
) -> None:
# RFC 6555 section 5.4 says that if getaddrinfo returns multiple address
# families (e.g. IPv4 and IPv6), then you should make sure that your first
# and second attempts use different families:
Expand All @@ -144,7 +159,7 @@ def reorder_for_rfc_6555_section_5_4(targets):
break


def format_host_port(host, port):
def format_host_port(host: str | bytes, port: int) -> str:
host = host.decode("ascii") if isinstance(host, bytes) else host
if ":" in host:
return f"[{host}]:{port}"
Expand Down Expand Up @@ -173,8 +188,12 @@ def format_host_port(host, port):
# AF_INET6: "..."}
# this might be simpler after
async def open_tcp_stream(
host, port, *, happy_eyeballs_delay=DEFAULT_DELAY, local_address=None
):
host: str | bytes,
port: int,
*,
happy_eyeballs_delay: float | None = DEFAULT_DELAY,
local_address: str | None = None,
CoolCat467 marked this conversation as resolved.
Show resolved Hide resolved
) -> trio.abc.Stream:
"""Connect to the given host and port over TCP.

If the given ``host`` has multiple IP addresses associated with it, then
Expand Down Expand Up @@ -212,9 +231,9 @@ async def open_tcp_stream(

port (int): The port to connect to.

happy_eyeballs_delay (float): How many seconds to wait for each
happy_eyeballs_delay (float or None): How many seconds to wait for each
connection attempt to succeed or fail before getting impatient and
starting another one in parallel. Set to `math.inf` if you want
starting another one in parallel. Set to `None` if you want
to limit to only one connection attempt at a time (like
:func:`socket.create_connection`). Default: 0.25 (250 ms).

Expand Down Expand Up @@ -247,9 +266,8 @@ async def open_tcp_stream(
# To keep our public API surface smaller, rule out some cases that
# getaddrinfo will accept in some circumstances, but that act weird or
# have non-portable behavior or are just plain not useful.
# No type check on host though b/c we want to allow bytes-likes.
if host is None:
raise ValueError("host cannot be None")
if not isinstance(host, (str, bytes)):
raise ValueError(f"host must be str or bytes, not {host!r}")
if not isinstance(port, int):
raise TypeError(f"port must be int, not {port!r}")

Expand All @@ -274,7 +292,7 @@ async def open_tcp_stream(

# Keeps track of the socket that we're going to complete with,
# need to make sure this isn't automatically closed
winning_socket = None
winning_socket: _SocketType | None = None

# Try connecting to the specified address. Possible outcomes:
# - success: record connected socket in winning_socket and cancel
Expand All @@ -283,7 +301,11 @@ async def open_tcp_stream(
# the next connection attempt to start early
# code needs to ensure sockets can be closed appropriately in the
# face of crash or cancellation
async def attempt_connect(socket_args, sockaddr, attempt_failed):
async def attempt_connect(
socket_args: tuple[AddressFamily, SocketKind, int],
sockaddr: Address,
attempt_failed: trio.Event,
) -> None:
nonlocal winning_socket

try:
Expand Down Expand Up @@ -334,7 +356,7 @@ async def attempt_connect(socket_args, sockaddr, attempt_failed):
except OSError:
raise OSError(
f"local_address={local_address!r} is incompatible "
f"with remote address {sockaddr}"
f"with remote address {sockaddr!r}"
)

await sock.connect(sockaddr)
Expand All @@ -355,12 +377,23 @@ async def attempt_connect(socket_args, sockaddr, attempt_failed):
# nursery spawns a task for each connection attempt, will be
# cancelled by the task that gets a successful connection
async with trio.open_nursery() as nursery:
for *sa, _, addr in targets:
for address_family, socket_type, proto, _, addr in targets:
# create an event to indicate connection failure,
# allowing the next target to be tried early
attempt_failed = trio.Event()

nursery.start_soon(attempt_connect, sa, addr, attempt_failed)
# workaround to check types until typing of nursery.start_soon improved
if TYPE_CHECKING:
await attempt_connect(
(address_family, socket_type, proto), addr, attempt_failed
)

nursery.start_soon(
attempt_connect,
(address_family, socket_type, proto),
addr,
attempt_failed,
)

# give this attempt at most this time before moving on
with trio.move_on_after(happy_eyeballs_delay):
Expand Down
7 changes: 3 additions & 4 deletions trio/_tests/verify_types.json
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,11 @@
"warningCount": 0
},
"typeCompleteness": {
"completenessScore": 0.9137380191693291,
"completenessScore": 0.9154704944178629,
"exportedSymbolCounts": {
"withAmbiguousType": 0,
"withKnownType": 572,
"withUnknownType": 54
"withKnownType": 574,
"withUnknownType": 53
},
"ignoreUnknownTypesFromImports": true,
"missingClassDocStringCount": 1,
Expand Down Expand Up @@ -109,7 +109,6 @@
"trio.open_ssl_over_tcp_listeners",
"trio.open_ssl_over_tcp_stream",
"trio.open_tcp_listeners",
"trio.open_tcp_stream",
"trio.open_unix_socket",
"trio.run",
"trio.run_process",
Expand Down
1 change: 1 addition & 0 deletions trio/socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@

# import the overwrites
from ._socket import (
Address as Address,
SocketType as SocketType,
_SocketType as _SocketType,
from_stdlib_socket as from_stdlib_socket,
Expand Down