Skip to content

Commit

Permalink
Ensure a channel can be deleted multiple times (#363)
Browse files Browse the repository at this point in the history
* Ensure a channel can be deleted multiple times

If the channel get deleted because a failure happened
on the remote, and than a similar failure happens on
the local as a result, the channel delete could happen
twice which would try to send the close message again
but the output queue is already destroyed after it was
sent the first time. This should not be a failure as
we want to be able to delete the channel any time there
is a failure.

* preen

* preen

* preen

* add missing coverage

* Update snitun/multiplexer/core.py

* Update snitun/multiplexer/core.py
  • Loading branch information
bdraco authored Feb 22, 2025
1 parent c6e901d commit f85cc0b
Show file tree
Hide file tree
Showing 2 changed files with 76 additions and 8 deletions.
27 changes: 19 additions & 8 deletions snitun/multiplexer/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,12 +312,10 @@ async def _process_message(self, message: MultiplexerMessage) -> None:
# Close
elif flow_type == CHANNEL_FLOW_CLOSE:
# check if message exists
if message.id not in self._channels:
_LOGGER.debug("Receive close from unknown channel")
return
channel = self._channels.pop(message.id)
self._queue.delete_channel(channel.id)
channel.close()
if channel_ := self._delete_channel_and_queue(message.id):
channel_.close()
else:
_LOGGER.debug("Receive close from unknown channel: %s", message.id)

# Ping
elif flow_type == CHANNEL_FLOW_PING:
Expand Down Expand Up @@ -384,6 +382,12 @@ async def create_channel(

async def delete_channel(self, channel: MultiplexerChannel) -> None:
"""Delete channel from transport."""
if channel.id not in self._channels:
# Make sure the queue is cleaned up if the channel
# is already deleted
self._queue.delete_channel(channel.id)
return

message = channel.init_close()

try:
Expand All @@ -392,5 +396,12 @@ async def delete_channel(self, channel: MultiplexerChannel) -> None:
except TimeoutError:
raise MultiplexerTransportError from None
finally:
self._channels.pop(channel.id, None)
self._queue.delete_channel(channel.id)
self._delete_channel_and_queue(channel.id)

def _delete_channel_and_queue(
self,
channel_id: MultiplexerChannelId,
) -> MultiplexerChannel | None:
"""Delete channel and queue from multiplexer if it exists."""
self._queue.delete_channel(channel_id)
return self._channels.pop(channel_id, None)
57 changes: 57 additions & 0 deletions tests/multiplexer/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

from snitun.exceptions import MultiplexerTransportClose, MultiplexerTransportError
from snitun.multiplexer import channel as channel_module, core as core_module
from snitun.multiplexer.channel import MultiplexerChannel
from snitun.multiplexer.core import Multiplexer
from snitun.multiplexer.crypto import CryptoTransport
from snitun.multiplexer.message import (
Expand Down Expand Up @@ -258,6 +259,62 @@ async def test_multiplexer_close_channel(
assert not multiplexer_server._channels


async def test_multiplexer_delete_unknown_channel(
multiplexer_client: Multiplexer,
multiplexer_server: Multiplexer,
caplog: pytest.LogCaptureFixture,
) -> None:
"""Test deleting an unknown channel."""
assert not multiplexer_client._channels
assert not multiplexer_server._channels

non_existant_channel = MultiplexerChannel(
multiplexer_server._queue,
ipaddress.IPv4Address("127.0.0.1"),
)
await multiplexer_server._queue.put(
non_existant_channel.id,
non_existant_channel.init_close(),
)
await asyncio.sleep(0.1)

assert not multiplexer_client._channels
assert not multiplexer_server._channels

assert (
f"Receive close from unknown channel: {non_existant_channel.id}" in caplog.text
)


async def test_multiplexer_delete_channel_called_multiple_times(
multiplexer_client: Multiplexer,
multiplexer_server: Multiplexer,
) -> None:
"""Test that channels can be deleted twice."""
assert not multiplexer_client._channels
assert not multiplexer_server._channels

channel = await multiplexer_client.create_channel(IP_ADDR, lambda _: None)
await asyncio.sleep(0.1)

assert multiplexer_client._channels
assert multiplexer_server._channels

assert multiplexer_client._channels[channel.id]
assert multiplexer_server._channels[channel.id]
assert multiplexer_client._channels[channel.id].ip_address == IP_ADDR
assert multiplexer_server._channels[channel.id].ip_address == IP_ADDR

await multiplexer_client.delete_channel(channel)
assert not multiplexer_client._channels

await multiplexer_client.delete_channel(channel)
assert not multiplexer_client._channels
await asyncio.sleep(0.1)

assert not multiplexer_server._channels


async def test_multiplexer_close_channel_full(multiplexer_client: Multiplexer) -> None:
"""Test that channels are nice removed but peer error is available."""
assert not multiplexer_client._channels
Expand Down

0 comments on commit f85cc0b

Please sign in to comment.