Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Wrap channel in an asyncio.Transport to eliminate loop back connection PoC #303

Draft
wants to merge 15 commits into
base: main
Choose a base branch
from
252 changes: 184 additions & 68 deletions snitun/client/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,32 +6,187 @@
from contextlib import suppress
import ipaddress
import logging
from typing import Any
from typing import Any, Callable
from asyncio import Transport, BufferedProtocol
from aiohttp.web import RequestHandler

from ssl import SSLContext, SSLError
from ..exceptions import MultiplexerTransportClose, MultiplexerTransportError
from ..multiplexer.channel import MultiplexerChannel
from ..multiplexer.core import Multiplexer

_LOGGER = logging.getLogger(__name__)


class ChannelTransport(Transport):
"""An asyncio.Transport implementation for multiplexer channel."""

_start_tls_compatible = True

def __init__(
self, loop: asyncio.AbstractEventLoop, channel: MultiplexerChannel
) -> None:
"""Initialize ChannelTransport."""
self._channel = channel
self._loop = loop
self._protocol: asyncio.Protocol | None = None
self._pause_future: asyncio.Future[None] | None = None
super().__init__(extra={"peername": (str(channel.ip_address), 0)})

def get_protocol(self) -> asyncio.Protocol:
"""Return the protocol."""
return self._protocol

def set_protocol(self, protocol: asyncio.Protocol) -> None:
"""Set the protocol."""
if not isinstance(protocol, BufferedProtocol):
raise ValueError("Protocol must be a BufferedProtocol")
self._protocol = protocol

def is_closing(self) -> bool:
"""Return True if the transport is closing or closed."""
return self._channel.closing

def close(self) -> None:
"""Close the underlying channel."""
self._channel.close()
self._release_pause_future()

def write(self, data: bytes) -> None:
"""Write data to the channel."""
if not self._channel.closing:
self._channel.write_no_wait(data)

async def start(self) -> None:
"""Start reading from the channel.

As a future improvement, it would be a bit more efficient to
have the channel call this as a callback from channel.message_transport.
"""
while True:
if self._pause_future:
await self._pause_future

try:
from_peer = await self._channel.read()
except MultiplexerTransportClose:
raise
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pretty sure connect lost will always get called because close on the finally block where this task is awaited but it should get a test

except (SystemExit, KeyboardInterrupt):
raise
except BaseException as exc:
self._fatal_error(exc, "Fatal error: channel.read() call failed.")

peer_payload_len = len(from_peer)
try:
buf = self._protocol.get_buffer(-1)
if not (available_len := len(buf)):
raise RuntimeError("get_buffer() returned an empty buffer")
except (SystemExit, KeyboardInterrupt):
raise
except BaseException as exc:
self._fatal_error(
exc, "Fatal error: protocol.get_buffer() call failed."
)
return

if available_len < peer_payload_len:
self._fatal_error(
RuntimeError(
"Available buffer is %s, need %s",
available_len,
peer_payload_len,
),
f"Fatal error: out of buffer need {peer_payload_len} bytes but only have {available_len} bytes",
)
return

try:
buf[:peer_payload_len] = from_peer
except (BlockingIOError, InterruptedError):
return
except (SystemExit, KeyboardInterrupt):
raise
except BaseException as exc:
self._fatal_error(exc, "Fatal read error on socket transport")
return

try:
self._protocol.buffer_updated(peer_payload_len)
except (SystemExit, KeyboardInterrupt):
raise
except BaseException as exc:
self._fatal_error(
exc, "Fatal error: protocol.buffer_updated() call failed."
)

def _force_close(self, exc: Exception) -> None:
"""Force close the transport."""
self._channel.close()
self._release_pause_future()
if self._protocol is not None:
self._loop.call_soon(self._protocol.connection_lost, exc)

def _fatal_error(self, exc: Exception, message: str) -> None:
"""Handle a fatal error."""
self._loop.call_exception_handler(
{
"message": message,
"exception": exc,
"transport": self,
"protocol": self._protocol,
}
)
self._force_close(exc)

def is_reading(self) -> bool:
"""Return True if the transport is receiving."""
return self._pause_future is not None

def pause_reading(self) -> None:
"""Pause the receiving end.

No data will be passed to the protocol's data_received()
method until resume_reading() is called.
"""
if self._pause_future is not None:
return
self._pause_future = self._loop.create_future()

def resume_reading(self) -> None:
"""Resume the receiving end.

Data received will once again be passed to the protocol's
data_received() method.
"""
self._release_pause_future()

def _release_pause_future(self) -> None:
"""Release the pause future, if it exists.

This will ensure that start can continue processing data.
"""
if self._pause_future is not None and not self._pause_future.done():
self._pause_future.set_result(None)
self._pause_future = None


class Connector:
"""Connector to end resource."""

def __init__(
self,
end_host: str,
end_port: int | None=None,
whitelist: bool=False,
protocol_factory: Callable[[], RequestHandler],
ssl_context: SSLContext,
whitelist: bool = False,
endpoint_connection_error_callback: Coroutine[Any, Any, None] | None = None,
) -> None:
"""Initialize Connector."""
self._loop = asyncio.get_event_loop()
self._end_host = end_host
self._end_port = end_port or 443
self._whitelist = set()
self._whitelist_enabled = whitelist
self._endpoint_connection_error_callback = endpoint_connection_error_callback
self._protocol_factory = protocol_factory
self._ssl_context = ssl_context

@property
def whitelist(self) -> set:
Expand All @@ -45,68 +200,44 @@ def _whitelist_policy(self, ip_address: ipaddress.IPv4Address) -> bool:
return True

async def handler(
self, multiplexer: Multiplexer, channel: MultiplexerChannel,
self, multiplexer: Multiplexer, channel: MultiplexerChannel
) -> None:
"""Handle new connection from SNIProxy."""
from_endpoint = None
from_peer = None

_LOGGER.debug(
"Receive from %s a request for %s", channel.ip_address, self._end_host,
)
_LOGGER.debug("Receive from %s a request for %s", channel.ip_address)

# Check policy
if not self._whitelist_policy(channel.ip_address):
_LOGGER.warning("Block request from %s per policy", channel.ip_address)
await multiplexer.delete_channel(channel)
return

transport = ChannelTransport(self._loop, channel)
request_handler = self._protocol_factory()
# Performance: We could avoid the task here if
# channel.message_transport feed the protocol directly, i.e.
# called the code in the loop in this task and would only queue
# the data if the protocol is paused.
transport_reader_task = asyncio.create_task(transport.start())
# Open connection to endpoint
try:
reader, writer = await asyncio.open_connection(
host=self._end_host, port=self._end_port,
)
except OSError:
_LOGGER.error(
"Can't connect to endpoint %s:%s", self._end_host, self._end_port,
new_transport = await self._loop.start_tls(
transport, request_handler, self._ssl_context, server_side=True
)
except (OSError, SSLError):
# This can can be just about any error, but mostly likely it's a TLS error
# or the connection gets dropped in the middle of the handshake
_LOGGER.debug("Can't start TLS for %s", channel.id, exc_info=True)
transport_reader_task.cancel()
with suppress(asyncio.CancelledError, Exception):
await transport_reader_task
await multiplexer.delete_channel(channel)
if self._endpoint_connection_error_callback:
await self._endpoint_connection_error_callback()
return

try:
# Process stream from multiplexer
while not writer.transport.is_closing():
if not from_endpoint:
from_endpoint = self._loop.create_task(reader.read(4096))
if not from_peer:
from_peer = self._loop.create_task(channel.read())

# Wait until data need to be processed
await asyncio.wait(
[from_endpoint, from_peer], return_when=asyncio.FIRST_COMPLETED,
)

# From proxy
if from_endpoint.done():
if from_endpoint.exception():
raise from_endpoint.exception()

await channel.write(from_endpoint.result())
from_endpoint = None

# From peer
if from_peer.done():
if from_peer.exception():
raise from_peer.exception()

writer.write(from_peer.result())
from_peer = None

# Flush buffer
await writer.drain()
request_handler.connection_made(new_transport)
_LOGGER.info("Connected peer: %s", new_transport.get_extra_info("peername"))

try:
await transport_reader_task
except (MultiplexerTransportError, OSError, RuntimeError):
_LOGGER.debug("Transport closed by endpoint for %s", channel.id)
with suppress(MultiplexerTransportError):
Expand All @@ -116,19 +247,4 @@ async def handler(
_LOGGER.debug("Peer close connection for %s", channel.id)

finally:
# Cleanup peer reader
if from_peer:
if not from_peer.done():
from_peer.cancel()
else:
# Avoid exception was never retrieved
from_peer.exception()

# Cleanup endpoint reader
if from_endpoint and not from_endpoint.done():
from_endpoint.cancel()

# Close Transport
if not writer.transport.is_closing():
with suppress(OSError):
writer.close()
new_transport.close()
16 changes: 16 additions & 0 deletions snitun/multiplexer/channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,22 @@ def close(self) -> None:
with suppress(asyncio.QueueFull):
self._input.put_nowait(None)

def write_no_wait(self, data: bytes) -> None:
"""Send data to peer."""
if not data:
raise MultiplexerTransportError
if self._closing:
raise MultiplexerTransportClose

# Create message
message = MultiplexerMessage(self._id, CHANNEL_FLOW_DATA, data)

try:
self._output.put_nowait(message)
except asyncio.QueueFull:
_LOGGER.debug("Can't write to peer transport")
raise MultiplexerTransportError from None

async def write(self, data: bytes) -> None:
"""Send data to peer."""
if not data:
Expand Down
Loading
Loading