Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Handle alias in the worker #203

Merged
merged 2 commits into from
Aug 23, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from setuptools import setup

VERSION = "0.36.0"
VERSION = "0.36.1"

setup(
name="snitun",
Expand Down
5 changes: 5 additions & 0 deletions snitun/server/peer.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,11 @@ def alias(self) -> List[str]:
"""Return the alias."""
return self._alias

@property
def all_hostnames(self) -> List[str]:
"""Return a list of the base hostname and any alias."""
return [self._hostname, *self._alias]

@property
def is_connected(self) -> bool:
"""Return True if we are connected to peer."""
Expand Down
5 changes: 2 additions & 3 deletions snitun/server/peer_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,9 +89,8 @@ def remove_peer(self, peer: Peer) -> None:
if self._peers.get(peer.hostname) != peer:
return
_LOGGER.debug("Close peer connection: %s", peer.hostname)
self._peers.pop(peer.hostname)
for alias in peer.alias:
self._peers.pop(alias, None)
for hostname in peer.all_hostnames:
self._peers.pop(hostname, None)

if self._event_callback:
self._loop.call_soon(
Expand Down
18 changes: 13 additions & 5 deletions snitun/server/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import logging
from multiprocessing import Process, Manager, Queue
from threading import Thread
from typing import Dict, Optional, List
from typing import TYPE_CHECKING, Dict, Optional, List
from socket import socket

from .listener_peer import PeerListener
Expand All @@ -13,6 +13,9 @@

_LOGGER = logging.getLogger(__name__)

if TYPE_CHECKING:
from multiprocessing.managers import SyncManager


class ServerWorker(Process):
"""Worker for multiplexer."""
Expand All @@ -35,14 +38,15 @@ def __init__(
self._loop: Optional[asyncio.BaseEventLoop] = None

# Communication between Parent/Child
self._manager: Manager = Manager()
self._manager: SyncManager = Manager()
self._new: Queue = self._manager.Queue()
self._sync: Dict[str, None] = self._manager.dict()
self._peer_count = self._manager.Value("peer_count", 0)

@property
def peer_size(self) -> int:
"""Return amount of managed peers."""
return len(self._sync)
return self._peer_count.value

def is_responsible_peer(self, sni: str) -> bool:
"""Return True if worker is responsible for this peer domain."""
Expand All @@ -61,9 +65,13 @@ async def _async_init(self) -> None:
def _event_stream(self, peer: Peer, event: PeerManagerEvent) -> None:
"""Event stream peer connection data."""
if event == PeerManagerEvent.CONNECTED:
self._sync[peer.hostname] = None
self._peer_count.set(self._peer_count.value + 1)
for hostname in peer.all_hostnames:
self._sync[hostname] = None
else:
self._sync.pop(peer.hostname, None)
self._peer_count.set(self._peer_count.value - 1)
for hostname in peer.all_hostnames:
self._sync.pop(hostname, None)

def shutdown(self) -> None:
"""Shutdown child process."""
Expand Down
7 changes: 6 additions & 1 deletion tests/server/test_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,10 @@ def test_sni_connection(
aes_key = os.urandom(32)
aes_iv = os.urandom(16)
hostname = "localhost"
fernet_token = create_peer_config(valid.timestamp(), hostname, aes_key, aes_iv)
alias = ["localhost.custom"]
fernet_token = create_peer_config(
valid.timestamp(), hostname, aes_key, aes_iv, alias=alias
)

worker.start()
crypto = CryptoTransport(aes_key, aes_iv)
Expand All @@ -102,6 +105,8 @@ def test_sni_connection(

time.sleep(1)
assert worker.is_responsible_peer(hostname)
for entry in alias:
assert worker.is_responsible_peer(entry)

worker.handover_connection(test_server_sync[1], TLS_1_2, hostname)
assert len(test_client_sync.recv(1048)) == 32
Expand Down