Skip to content

Commit

Permalink
Implement sync backend and add SyncClient tests
Browse files Browse the repository at this point in the history
  • Loading branch information
florimondmanca committed Jan 5, 2020
1 parent b28b6ac commit ad1cfd0
Show file tree
Hide file tree
Showing 7 changed files with 519 additions and 7 deletions.
2 changes: 1 addition & 1 deletion httpx/_async/backends/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def lookup_backend(
if not is_async_mode():
from ...backends.sync import SyncBackend

return SyncBackend # type: ignore
return SyncBackend() # type: ignore

if backend == "auto":
from ...backends.auto import AutoBackend
Expand Down
5 changes: 3 additions & 2 deletions httpx/_async/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import hstspreload

from ..auth import Auth, AuthTypes, BasicAuth, FunctionAuth
from ..backends.shells import AsyncFixes
from ..config import (
DEFAULT_MAX_REDIRECTS,
DEFAULT_POOL_LIMITS,
Expand Down Expand Up @@ -404,9 +405,9 @@ async def send(

if not stream:
try:
await response.aread()
await AsyncFixes.read_response(response)
finally:
await response.aclose()
await AsyncFixes.close_response(response)

return response

Expand Down
21 changes: 21 additions & 0 deletions httpx/backends/shells.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
from ..models import Response


class AsyncFixes:
@staticmethod
async def read_response(response: Response) -> bytes:
return await response.aread()

@staticmethod
async def close_response(response: Response) -> None:
await response.aclose()


class SyncFixes:
@staticmethod
def read_response(response: Response) -> bytes:
return response.read()

@staticmethod
def close_response(response: Response) -> None:
response.close()
242 changes: 238 additions & 4 deletions httpx/backends/sync.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,243 @@
from .._sync.backends.base import SyncConcurrencyBackend
import errno
import functools
import socket
import ssl
import threading
import time
import typing

from .._sync.backends.base import (
SyncBaseLock,
SyncBaseSemaphore,
SyncBaseSocketStream,
SyncConcurrencyBackend,
)
from ..config import Timeout
from ..exceptions import ConnectTimeout, ReadTimeout, WriteTimeout
from .sync_utils.wait import wait_for_socket as default_wait_for_socket


class SocketStream(SyncBaseSocketStream):
def __init__(
self,
sock: socket.socket,
timeout: Timeout,
wait_for_socket: typing.Callable = default_wait_for_socket,
):
self.sock = sock
self.timeout = timeout
self.wait_for_socket = wait_for_socket
self.write_buffer = b""
# Keep the socket in non-blocking mode, except during connect() and
# during the SSL handshake.
self.sock.setblocking(False)

def start_tls(
self, hostname: str, ssl_context: ssl.SSLContext, timeout: Timeout
) -> "SocketStream":
self.sock.setblocking(True)
wrapped = ssl_context.wrap_socket(self.sock, server_hostname=hostname)
wrapped.setblocking(False)
return SocketStream(wrapped, timeout=self.timeout)

def get_http_version(self) -> str:
if not isinstance(self.sock, ssl.SSLSocket):
return "HTTP/1.1"
ident = self.sock.selected_alpn_protocol()
return "HTTP/2" if ident == "h2" else "HTTP/1.1"

def _wait(
self, readable: bool, writable: bool, mode: str, timeout: typing.Optional[float]
) -> None:
assert mode in ("read", "write")
assert readable or writable
if not self.wait_for_socket(
self.sock, read=readable, write=writable, timeout=timeout
):
raise (ReadTimeout() if mode == "read" else WriteTimeout())

def read(self, n: int, timeout: Timeout) -> bytes:
read_timeout = timeout.read_timeout
start = time.time()
while True:
try:
return self.sock.recv(n)
except ssl.SSLWantReadError:
self._wait(
readable=True, writable=False, mode="read", timeout=read_timeout
)
except ssl.SSLWantWriteError:
self._wait(
readable=False, writable=True, mode="read", timeout=read_timeout
)
except (OSError, socket.error) as exc:
if exc.errno in (errno.EWOULDBLOCK, errno.EAGAIN):
self._wait(
readable=True, writable=False, mode="read", timeout=read_timeout
)
else:
raise

if read_timeout is not None:
read_timeout -= time.time() - start
if read_timeout < 0:
raise ReadTimeout()

def write(self, data: bytes, timeout: Timeout = None,) -> None:
if not data:
return

if timeout is None:
timeout = self.timeout
write_timeout = timeout.write_timeout
start = time.time()

while data:
made_progress = False
want_read = False
want_write = False

try:
sent = self.sock.send(data)
data = data[sent:]
except ssl.SSLWantReadError:
want_read = True
except ssl.SSLWantWriteError:
want_write = True
except (OSError, socket.error) as exc:
if exc.errno in (errno.EWOULDBLOCK, errno.EAGAIN):
want_write = True
else:
raise
else:
made_progress = True

if not made_progress:
self._wait(
readable=want_read,
writable=want_write,
mode="write",
timeout=write_timeout,
)

if write_timeout is not None:
write_timeout -= time.time() - start
if write_timeout < 0:
raise WriteTimeout()

def is_connection_dropped(self) -> bool:
# Counter-intuitively, what we really want to know here is whether the socket is
# *readable*, i.e. whether it would return immediately with empty bytes if we
# called `.recv()` on it, indicating that the other end has closed the socket.
# See: https://github.com/encode/httpx/pull/143#issuecomment-515181778
return self.wait_for_socket(self.sock, read=True, timeout=0)

def close(self) -> None:
self.sock.close()


class Semaphore(SyncBaseSemaphore):
def __init__(self, max_value: int, exc_class: type) -> None:
self.max_value = max_value
self.exc_class = exc_class

@property
def semaphore(self) -> threading.BoundedSemaphore:
if not hasattr(self, "_semaphore"):
self._semaphore = threading.BoundedSemaphore(value=self.max_value)
return self._semaphore

def acquire(self, timeout: float = None) -> None:
if timeout is None:
self.semaphore.acquire()
return

acquired = self.semaphore.acquire(blocking=True, timeout=timeout)

if not acquired:
raise self.exc_class()

def release(self) -> None:
self.semaphore.release()


class SyncBackend(SyncConcurrencyBackend):
"""
A synchronous backend.
TODO
Concurrency backend that performs synchronous I/O operations
while exposing async-annotated methods.
"""

def open_tcp_stream(
self,
hostname: str,
port: int,
ssl_context: typing.Optional[ssl.SSLContext],
timeout: Timeout,
) -> SocketStream:
try:
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.settimeout(timeout.connect_timeout)
sock.connect((hostname, port))
if ssl_context is not None:
sock = ssl_context.wrap_socket(sock, server_hostname=hostname)
except socket.timeout:
raise ConnectTimeout()
except socket.error:
raise # TODO: raise an HTTPX-specific exception
else:
return SocketStream(sock=sock, timeout=timeout)

def open_uds_stream(
self,
path: str,
hostname: typing.Optional[str],
ssl_context: typing.Optional[ssl.SSLContext],
timeout: Timeout,
) -> SocketStream:
raise NotImplementedError

def time(self) -> float:
return time.time()

def run_in_threadpool(
self, func: typing.Callable, *args: typing.Any, **kwargs: typing.Any
) -> typing.Any:
# This backend is a blocking one anyway, so no need to use
# a threadpool here.
return func(*args, **kwargs)

def run(
self, coroutine: typing.Callable, *args: typing.Any, **kwargs: typing.Any
) -> typing.Any:
if kwargs:
coroutine = functools.partial(coroutine, **kwargs)
return run_secretly_sync_async_function(coroutine, *args)

def create_semaphore(self, max_value: int, exc_class: type) -> SyncBaseSemaphore:
return Semaphore(max_value, exc_class)

def create_lock(self) -> SyncBaseLock:
return Lock()


class Lock(SyncBaseLock):
def __init__(self) -> None:
self._lock = threading.RLock()

def release(self) -> None:
self._lock.release()

def acquire(self) -> None:
self._lock.acquire()


def run_secretly_sync_async_function(
async_function: typing.Callable, *args: typing.Any
) -> typing.Any:
coro = async_function(*args)
try:
coro.send(None)
except StopIteration as exc:
return exc.value
else:
raise RuntimeError("This async function is not secretly synchronous.")
Empty file.
90 changes: 90 additions & 0 deletions httpx/backends/sync_utils/wait.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
"""
Adapted from:
https://github.com/python-trio/urllib3/blob/f5ff1acf157c167e549c941ee19715341cba2b58/src/urllib3/util/wait.py
"""

import select
import socket
import typing


class NoWayToWaitForSocketError(Exception):
pass


def select_wait_for_socket(
sock: socket.socket, read: bool = False, write: bool = False, timeout: float = None
) -> bool:
if not read and not write:
raise RuntimeError("must specify at least one of read=True, write=True")
rcheck = []
wcheck = []
if read:
rcheck.append(sock)
if write:
wcheck.append(sock)
# When doing a non-blocking connect, most systems signal success by
# marking the socket writable. Windows, though, signals success by marked
# it as "exceptional". We paper over the difference by checking the write
# sockets for both conditions. (The stdlib selectors module does the same
# thing.)
rready, wready, xready = select.select(rcheck, wcheck, wcheck, timeout)
return bool(rready or wready or xready)


def poll_wait_for_socket(
sock: socket.socket, read: bool = False, write: bool = False, timeout: float = None
) -> bool:
if not read and not write:
raise RuntimeError("must specify at least one of read=True, write=True")
mask = 0
if read:
mask |= select.POLLIN
if write:
mask |= select.POLLOUT
poll_obj = select.poll()
poll_obj.register(sock, mask)

# For some reason, poll() takes timeout in milliseconds
def do_poll(t: typing.Optional[float]) -> typing.Any:
if t is not None:
t *= 1000
return poll_obj.poll(t)

return bool(do_poll(timeout))


def null_wait_for_socket(
sock: socket.socket, read: bool = False, write: bool = False, timeout: float = None
) -> typing.NoReturn:
raise NoWayToWaitForSocketError("no select-equivalent available")


def _have_working_poll() -> bool:
# Apparently some systems have a select.poll that fails as soon as you try
# to use it, either due to strange configuration or broken monkeypatching
# from libraries like eventlet/greenlet.
try:
poll_obj = select.poll()
poll_obj.poll(0)
except (AttributeError, OSError):
return False
else:
return True


def wait_for_socket(
sock: socket.socket, read: bool = False, write: bool = False, timeout: float = None
) -> bool:
# We delay choosing which implementation to use until the first time we're
# called. We could do it at import time, but then we might make the wrong
# decision if someone goes wild with monkeypatching select.poll after
# we're imported.
global wait_for_socket
if _have_working_poll():
wait_for_socket = poll_wait_for_socket
elif hasattr(select, "select"):
wait_for_socket = select_wait_for_socket
else: # Platform-specific: Appengine.
wait_for_socket = null_wait_for_socket
return wait_for_socket(sock, read=read, write=write, timeout=timeout)
Loading

0 comments on commit ad1cfd0

Please sign in to comment.