Skip to content

Commit

Permalink
Add mypy to pre-commit and fix complaints (#356)
Browse files Browse the repository at this point in the history
* Add mypy to pre-commit

* typing fixes

* more fixes

* more fixes

* more fixes

* more fixes

* lint
  • Loading branch information
bdraco authored Feb 11, 2025
1 parent 28a6c14 commit d29742f
Show file tree
Hide file tree
Showing 16 changed files with 104 additions and 66 deletions.
6 changes: 6 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -26,3 +26,9 @@ repos:
rev: v3.0.3
hooks:
- id: prettier
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.14.1
hooks:
- id: mypy
additional_dependencies: []
files: ^((snitun)/.+)?[^/]+\.(py)$
22 changes: 12 additions & 10 deletions snitun/client/client_peer.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ class ClientPeer:

def __init__(self, snitun_host: str, snitun_port: int | None = None) -> None:
"""Initialize ClientPeer connector."""
self._multiplexer = None
self._multiplexer: Multiplexer | None = None
self._loop = asyncio.get_event_loop()
self._snitun_host = snitun_host
self._snitun_port = snitun_port or 8080
Expand All @@ -37,9 +37,9 @@ def is_connected(self) -> bool:
"""Return true, if a connection exists."""
return self._multiplexer is not None

def wait(self) -> asyncio.Task:
def wait(self) -> asyncio.Future[None]:
"""Block until connection to peer is closed."""
if not self._multiplexer:
if not self._multiplexer or not self._handler_task:
raise RuntimeError("No SniTun connection available")
# Wait until the handler task is done
# as we know the connection is closed
Expand Down Expand Up @@ -137,6 +137,7 @@ async def stop(self) -> None:

async def _stop_handler(self) -> None:
"""Stop the handler."""
assert self._handler_task, "Handler task not started"
self._handler_task.cancel()
try:
await self._handler_task
Expand All @@ -150,20 +151,21 @@ async def _stop_handler(self) -> None:
async def _handler(self) -> None:
"""Wait until connection is closed."""

async def _wait_with_timeout() -> None:
async def _wait_with_timeout(multiplexer: Multiplexer) -> None:
try:
async with asyncio_timeout.timeout(50):
await self._multiplexer.wait()
await multiplexer.wait()
except TimeoutError:
await self._multiplexer.ping()
await multiplexer.ping()

try:
while self._multiplexer.is_connected:
await _wait_with_timeout()
while self._multiplexer and self._multiplexer.is_connected:
await _wait_with_timeout(self._multiplexer)

except MultiplexerTransportError:
pass

finally:
self._multiplexer.shutdown()
self._multiplexer = None
if self._multiplexer:
self._multiplexer.shutdown()
self._multiplexer = None
16 changes: 9 additions & 7 deletions snitun/client/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,10 @@
from __future__ import annotations

import asyncio
from collections.abc import Coroutine
from collections.abc import Callable, Coroutine
from contextlib import suppress
import ipaddress
from ipaddress import IPv4Address
import logging
from typing import Any

Expand All @@ -24,13 +25,14 @@ def __init__(
end_host: str,
end_port: int | None = None,
whitelist: bool = False,
endpoint_connection_error_callback: Coroutine[Any, Any, None] | None = None,
endpoint_connection_error_callback: Callable[[], Coroutine[Any, Any, None]]
| None = None,
) -> None:
"""Initialize Connector."""
self._loop = asyncio.get_event_loop()
self._end_host = end_host
self._end_port = end_port or 443
self._whitelist = set()
self._whitelist: set[IPv4Address] = set()
self._whitelist_enabled = whitelist
self._endpoint_connection_error_callback = endpoint_connection_error_callback

Expand Down Expand Up @@ -99,16 +101,16 @@ async def handler(

# From proxy
if from_endpoint.done():
if from_endpoint.exception():
raise from_endpoint.exception()
if from_endpoint_exc := from_endpoint.exception():
raise from_endpoint_exc

await channel.write(from_endpoint.result())
from_endpoint = None

# From peer
if from_peer.done():
if from_peer.exception():
raise from_peer.exception()
if from_peer_exc := from_peer.exception():
raise from_peer_exc

writer.write(from_peer.result())
from_peer = None
Expand Down
6 changes: 3 additions & 3 deletions snitun/multiplexer/channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,13 @@ class MultiplexerChannel:

def __init__(
self,
output: asyncio.Queue,
output: asyncio.Queue[MultiplexerMessage | None],
ip_address: IPv4Address,
channel_id: MultiplexerChannelId | None = None,
throttling: float | None = None,
) -> None:
"""Initialize Multiplexer Channel."""
self._input: asyncio.Queue[MultiplexerMessage] = asyncio.Queue(8000)
self._input: asyncio.Queue[MultiplexerMessage | None] = asyncio.Queue(8000)
self._output = output
self._id = channel_id or MultiplexerChannelId(os.urandom(16))
self._ip_address = ip_address
Expand Down Expand Up @@ -92,7 +92,7 @@ async def write(self, data: bytes) -> None:
return
await asyncio.sleep(self._throttling)

async def read(self) -> MultiplexerMessage:
async def read(self) -> bytes:
"""Read data from peer."""
if self._closing and self._input.empty():
message = None
Expand Down
23 changes: 14 additions & 9 deletions snitun/multiplexer/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from __future__ import annotations

import asyncio
from collections.abc import Coroutine
from collections.abc import Callable, Coroutine
from contextlib import suppress
import ipaddress
import logging
Expand Down Expand Up @@ -66,15 +66,19 @@ def __init__(
crypto: CryptoTransport,
reader: asyncio.StreamReader,
writer: asyncio.StreamWriter,
new_connections: Coroutine[Any, Any, None] | None = None,
new_connections: Callable[
[Multiplexer, MultiplexerChannel],
Coroutine[Any, Any, None],
]
| None = None,
throttling: int | None = None,
) -> None:
"""Initialize Multiplexer."""
self._crypto = crypto
self._reader = reader
self._writer = writer
self._loop = asyncio.get_event_loop()
self._queue: asyncio.Queue[MultiplexerMessage] = asyncio.Queue(12000)
self._queue: asyncio.Queue[MultiplexerMessage | None] = asyncio.Queue(12000)
self._healthy = asyncio.Event()
self._processing_task = self._loop.create_task(self._runner())
self._channels: dict[MultiplexerChannelId, MultiplexerChannel] = {}
Expand All @@ -86,7 +90,7 @@ def is_connected(self) -> bool:
"""Return True is they is connected."""
return not self._processing_task.done()

def wait(self) -> asyncio.Task:
def wait(self) -> asyncio.Future[None]:
"""Block until the connection is closed.
Return a awaitable object.
Expand Down Expand Up @@ -160,16 +164,17 @@ async def _runner(self) -> None:

# From peer
if from_peer.done():
if from_peer.exception():
raise from_peer.exception()
if from_peer_exc := from_peer.exception():
raise from_peer_exc
await self._read_message(from_peer.result())
from_peer = None

# To peer
if to_peer.done():
if to_peer.exception():
raise to_peer.exception()
self._write_message(to_peer.result())
if to_peer_exc := to_peer.exception():
raise to_peer_exc
if msg := to_peer.result():
self._write_message(msg)
to_peer = None

# Flush buffer
Expand Down
2 changes: 1 addition & 1 deletion snitun/multiplexer/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def bytes(self) -> "bytes":
return self

@cached_property
def hex(self) -> str:
def hex(self) -> str: # type: ignore[override]
"""Return hex representation of the channel ID."""
return binascii.hexlify(self).decode("utf-8")

Expand Down
3 changes: 2 additions & 1 deletion snitun/server/listener_peer.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def __init__(
self._peer_manager = peer_manager
self._host = host
self._port = port or 8080
self._server = None
self._server: asyncio.Server | None = None

async def start(self) -> None:
"""Start peer server."""
Expand All @@ -40,6 +40,7 @@ async def start(self) -> None:

async def stop(self) -> None:
"""Stop peer server."""
assert self._server is not None, "Server not started"
self._server.close()
await self._server.wait_closed()

Expand Down
16 changes: 9 additions & 7 deletions snitun/server/listener_sni.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def __init__(
self._loop = asyncio.get_event_loop()
self._host = host
self._port = port or 443
self._server = None
self._server: asyncio.Server | None = None

async def start(self) -> None:
"""Start Proxy server."""
Expand All @@ -48,6 +48,7 @@ async def start(self) -> None:

async def stop(self) -> None:
"""Stop proxy server."""
assert self._server is not None, "Server not started"
self._server.close()
await self._server.wait_closed()

Expand Down Expand Up @@ -94,9 +95,10 @@ async def handle_connection(
_LOGGER.debug("Hostname %s not connected", hostname)
return
peer = self._peer_manager.get_peer(hostname)

assert peer is not None, "Peer not found"
# Proxy data over mutliplexer to client
_LOGGER.debug("Processing for hostname %s started", hostname)
assert peer.multiplexer is not None, "Multiplexer not initialized"
await self._proxy_peer(peer.multiplexer, client_hello, reader, writer)

finally:
Expand All @@ -114,7 +116,7 @@ async def _proxy_peer(
"""Proxy data between end points."""
transport = writer.transport
try:
ip_address = ipaddress.ip_address(writer.get_extra_info("peername")[0])
ip_address = ipaddress.IPv4Address(writer.get_extra_info("peername")[0])
except (TypeError, AttributeError):
_LOGGER.error("Can't read source IP")
return
Expand Down Expand Up @@ -147,16 +149,16 @@ async def _proxy_peer(

# From proxy
if from_proxy.done():
if from_proxy.exception():
raise from_proxy.exception()
if from_proxy_exc := from_proxy.exception():
raise from_proxy_exc

await channel.write(from_proxy.result())
from_proxy = None

# From peer
if from_peer.done():
if from_peer.exception():
raise from_peer.exception()
if from_peer_exc := from_peer.exception():
raise from_peer_exc

writer.write(from_peer.result())
from_peer = None
Expand Down
5 changes: 2 additions & 3 deletions snitun/server/peer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from __future__ import annotations

import asyncio
from collections.abc import Coroutine
from datetime import UTC, datetime
import hashlib
import logging
Expand Down Expand Up @@ -34,7 +33,7 @@ def __init__(
self._valid = valid
self._throttling = throttling
self._alias = alias or []
self._multiplexer = None
self._multiplexer: Multiplexer | None = None
self._crypto = CryptoTransport(aes_key, aes_iv)

@property
Expand Down Expand Up @@ -111,7 +110,7 @@ async def init_multiplexer_challenge(
throttling=self._throttling,
)

def wait_disconnect(self) -> Coroutine:
def wait_disconnect(self) -> asyncio.Future[None]:
"""Wait until peer is disconnected.
Return a coroutine.
Expand Down
8 changes: 5 additions & 3 deletions snitun/server/peer_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,9 +75,11 @@ def create_peer(self, fernet_data: bytes) -> Peer:

def add_peer(self, peer: Peer) -> None:
"""Register peer to internal hostname list."""
if self.peer_available(peer.hostname):
if self.peer_available(peer.hostname) and (
multiplexer := self._peers[peer.hostname].multiplexer
):
_LOGGER.warning("Found stale peer connection")
self._peers[peer.hostname].multiplexer.shutdown()
multiplexer.shutdown()

_LOGGER.debug("New peer connection: %s", peer.hostname)
self._peers[peer.hostname] = peer
Expand Down Expand Up @@ -120,7 +122,7 @@ async def close_connections(self, timeout: int = 10) -> None: # noqa: ASYNC109
"""
peers = list(self._peers.values())
for peer in peers:
if peer.is_connected:
if peer.is_connected and peer.multiplexer:
peer.multiplexer.shutdown()

if waiters := [peer.wait_disconnect() for peer in peers]:
Expand Down
Loading

0 comments on commit d29742f

Please sign in to comment.