diff --git a/setup.py b/setup.py index 1e9a0dc7..aaa79205 100644 --- a/setup.py +++ b/setup.py @@ -1,6 +1,6 @@ from setuptools import setup -VERSION = "0.36.0" +VERSION = "0.36.1" setup( name="snitun", diff --git a/snitun/server/peer.py b/snitun/server/peer.py index 30ebc9ad..bdfba0f4 100644 --- a/snitun/server/peer.py +++ b/snitun/server/peer.py @@ -45,6 +45,11 @@ def alias(self) -> List[str]: """Return the alias.""" return self._alias + @property + def all_hostnames(self) -> List[str]: + """Return a list of the base hostname and any alias.""" + return [self._hostname, *self._alias] + @property def is_connected(self) -> bool: """Return True if we are connected to peer.""" diff --git a/snitun/server/peer_manager.py b/snitun/server/peer_manager.py index f93fd07a..538e5d69 100644 --- a/snitun/server/peer_manager.py +++ b/snitun/server/peer_manager.py @@ -89,9 +89,8 @@ def remove_peer(self, peer: Peer) -> None: if self._peers.get(peer.hostname) != peer: return _LOGGER.debug("Close peer connection: %s", peer.hostname) - self._peers.pop(peer.hostname) - for alias in peer.alias: - self._peers.pop(alias, None) + for hostname in peer.all_hostnames: + self._peers.pop(hostname, None) if self._event_callback: self._loop.call_soon( diff --git a/snitun/server/worker.py b/snitun/server/worker.py index 75911c8c..77299cf8 100644 --- a/snitun/server/worker.py +++ b/snitun/server/worker.py @@ -3,7 +3,7 @@ import logging from multiprocessing import Process, Manager, Queue from threading import Thread -from typing import Dict, Optional, List +from typing import TYPE_CHECKING, Dict, Optional, List from socket import socket from .listener_peer import PeerListener @@ -13,6 +13,9 @@ _LOGGER = logging.getLogger(__name__) +if TYPE_CHECKING: + from multiprocessing.managers import SyncManager + class ServerWorker(Process): """Worker for multiplexer.""" @@ -35,14 +38,15 @@ def __init__( self._loop: Optional[asyncio.BaseEventLoop] = None # Communication between Parent/Child - self._manager: Manager = Manager() + self._manager: SyncManager = Manager() self._new: Queue = self._manager.Queue() self._sync: Dict[str, None] = self._manager.dict() + self._peer_count = self._manager.Value("peer_count", 0) @property def peer_size(self) -> int: """Return amount of managed peers.""" - return len(self._sync) + return self._peer_count.value def is_responsible_peer(self, sni: str) -> bool: """Return True if worker is responsible for this peer domain.""" @@ -61,9 +65,13 @@ async def _async_init(self) -> None: def _event_stream(self, peer: Peer, event: PeerManagerEvent) -> None: """Event stream peer connection data.""" if event == PeerManagerEvent.CONNECTED: - self._sync[peer.hostname] = None + self._peer_count.set(self._peer_count.value + 1) + for hostname in peer.all_hostnames: + self._sync[hostname] = None else: - self._sync.pop(peer.hostname, None) + self._peer_count.set(self._peer_count.value - 1) + for hostname in peer.all_hostnames: + self._sync.pop(hostname, None) def shutdown(self) -> None: """Shutdown child process.""" diff --git a/tests/server/test_worker.py b/tests/server/test_worker.py index 022ce46e..6c6598e3 100644 --- a/tests/server/test_worker.py +++ b/tests/server/test_worker.py @@ -89,7 +89,10 @@ def test_sni_connection( aes_key = os.urandom(32) aes_iv = os.urandom(16) hostname = "localhost" - fernet_token = create_peer_config(valid.timestamp(), hostname, aes_key, aes_iv) + alias = ["localhost.custom"] + fernet_token = create_peer_config( + valid.timestamp(), hostname, aes_key, aes_iv, alias=alias + ) worker.start() crypto = CryptoTransport(aes_key, aes_iv) @@ -102,6 +105,8 @@ def test_sni_connection( time.sleep(1) assert worker.is_responsible_peer(hostname) + for entry in alias: + assert worker.is_responsible_peer(entry) worker.handover_connection(test_server_sync[1], TLS_1_2, hostname) assert len(test_client_sync.recv(1048)) == 32