Skip to content

Commit 52b0dec

Browse files
authored
Merge pull request #447 from dylex/main
Add `raw_socket_proxy` to directly proxy websockets to TCP/unix sockets
2 parents 887c3d1 + b068325 commit 52b0dec

File tree

6 files changed

+201
-41
lines changed

6 files changed

+201
-41
lines changed

docs/source/server-process.md

+12
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,18 @@ Defaults to _True_.
169169

170170
(server-process:callable-arguments)=
171171

172+
### `raw_socket_proxy`
173+
174+
_True_ to proxy only websocket connections into raw stream connections.
175+
_False_ (default) if the proxied server speaks full HTTP.
176+
177+
If _True_, the proxied server is treated as a raw TCP (or unix socket) server that
178+
does not use HTTP.
179+
In this mode, only websockets are handled, and messages are sent to the backend
180+
server as raw stream data. This is similar to running a
181+
[websockify](https://github.com/novnc/websockify) wrapper.
182+
All other HTTP requests return 405.
183+
172184
#### Callable arguments
173185

174186
Any time you specify a callable in the config, it can ask for any arguments it needs

jupyter_server_proxy/config.py

+41-41
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from traitlets.config import Configurable
1717

1818
from .handlers import AddSlashHandler, NamedLocalProxyHandler, SuperviseAndProxyHandler
19+
from .rawsocket import RawSocketHandler, SuperviseAndRawSocketHandler
1920

2021
try:
2122
# Traitlets >= 4.3.3
@@ -43,54 +44,56 @@
4344
"request_headers_override",
4445
"rewrite_response",
4546
"update_last_activity",
47+
"raw_socket_proxy",
4648
],
4749
)
4850

4951

50-
def _make_namedproxy_handler(sp: ServerProcess):
51-
class _Proxy(NamedLocalProxyHandler):
52-
def __init__(self, *args, **kwargs):
53-
super().__init__(*args, **kwargs)
54-
self.name = sp.name
55-
self.proxy_base = sp.name
56-
self.absolute_url = sp.absolute_url
57-
self.port = sp.port
58-
self.unix_socket = sp.unix_socket
59-
self.mappath = sp.mappath
60-
self.rewrite_response = sp.rewrite_response
61-
self.update_last_activity = sp.update_last_activity
62-
63-
def get_request_headers_override(self):
64-
return self._realize_rendered_template(sp.request_headers_override)
65-
66-
return _Proxy
67-
68-
69-
def _make_supervisedproxy_handler(sp: ServerProcess):
52+
def _make_proxy_handler(sp: ServerProcess):
7053
"""
71-
Create a SuperviseAndProxyHandler subclass with given parameters
54+
Create an appropriate handler with given parameters
7255
"""
56+
if sp.command:
57+
cls = SuperviseAndRawSocketHandler if sp.raw_socket_proxy else SuperviseAndProxyHandler
58+
args = dict(state={})
59+
elif not (sp.port or isinstance(sp.unix_socket, str)):
60+
warn(
61+
f"Server proxy {sp.name} does not have a command, port "
62+
f"number or unix_socket path. At least one of these is "
63+
f"required."
64+
)
65+
return
66+
else:
67+
cls = RawSocketHandler if sp.raw_socket_proxy else NamedLocalProxyHandler
68+
args = {}
7369

7470
# FIXME: Set 'name' properly
75-
class _Proxy(SuperviseAndProxyHandler):
71+
class _Proxy(cls):
72+
kwargs = args
73+
7674
def __init__(self, *args, **kwargs):
7775
super().__init__(*args, **kwargs)
7876
self.name = sp.name
7977
self.command = sp.command
8078
self.proxy_base = sp.name
8179
self.absolute_url = sp.absolute_url
82-
self.requested_port = sp.port
83-
self.requested_unix_socket = sp.unix_socket
80+
if sp.command:
81+
self.requested_port = sp.port
82+
self.requested_unix_socket = sp.unix_socket
83+
else:
84+
self.port = sp.port
85+
self.unix_socket = sp.unix_socket
8486
self.mappath = sp.mappath
8587
self.rewrite_response = sp.rewrite_response
8688
self.update_last_activity = sp.update_last_activity
8789

88-
def get_env(self):
89-
return self._realize_rendered_template(sp.environment)
90-
9190
def get_request_headers_override(self):
9291
return self._realize_rendered_template(sp.request_headers_override)
9392

93+
# these two methods are only used in supervise classes, but do no harm otherwise
94+
def get_env(self):
95+
return self._realize_rendered_template(sp.environment)
96+
9497
def get_timeout(self):
9598
return sp.timeout
9699

@@ -116,24 +119,14 @@ def make_handlers(base_url, server_processes):
116119
"""
117120
handlers = []
118121
for sp in server_processes:
119-
if sp.command:
120-
handler = _make_supervisedproxy_handler(sp)
121-
kwargs = dict(state={})
122-
else:
123-
if not (sp.port or isinstance(sp.unix_socket, str)):
124-
warn(
125-
f"Server proxy {sp.name} does not have a command, port "
126-
f"number or unix_socket path. At least one of these is "
127-
f"required."
128-
)
129-
continue
130-
handler = _make_namedproxy_handler(sp)
131-
kwargs = {}
122+
handler = _make_proxy_handler(sp)
123+
if not handler:
124+
continue
132125
handlers.append(
133126
(
134127
ujoin(base_url, sp.name, r"(.*)"),
135128
handler,
136-
kwargs,
129+
handler.kwargs
137130
)
138131
)
139132
handlers.append((ujoin(base_url, sp.name), AddSlashHandler))
@@ -169,6 +162,7 @@ def make_server_process(name, server_process_config, serverproxy_config):
169162
update_last_activity=server_process_config.get(
170163
"update_last_activity", True
171164
),
165+
raw_socket_proxy=server_process_config.get("raw_socket_proxy", False),
172166
)
173167

174168

@@ -292,6 +286,12 @@ def cats_only(response, path):
292286
293287
update_last_activity
294288
Will cause the proxy to report activity back to jupyter server.
289+
290+
raw_socket_proxy
291+
Proxy websocket requests as a raw TCP (or unix socket) stream.
292+
In this mode, only websockets are handled, and messages are sent to the backend,
293+
similar to running a websockify layer (https://github.com/novnc/websockify).
294+
All other HTTP requests return 405 (and thus this will also bypass rewrite_response).
295295
""",
296296
config=True,
297297
)

jupyter_server_proxy/rawsocket.py

+89
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
"""
2+
A simple translation layer between tornado websockets and asyncio stream
3+
connections.
4+
5+
This provides similar functionality to websockify
6+
(https://github.com/novnc/websockify) without needing an extra proxy hop
7+
or process through which all messages pass for translation.
8+
"""
9+
10+
import asyncio
11+
12+
from .handlers import NamedLocalProxyHandler, SuperviseAndProxyHandler
13+
14+
class RawSocketProtocol(asyncio.Protocol):
    """
    asyncio protocol attached to the proxied (backend) stream connection.

    Each block of data received from the backend is forwarded unchanged to
    the websocket as a binary message; when the backend connection is lost
    the websocket is closed too.
    """

    def __init__(self, handler):
        # The websocket handler that owns this backend connection.
        self.handler = handler

    def data_received(self, data):
        "Forward one received buffer to the websocket as a binary message."
        ws = self.handler
        ws._record_activity()
        # The "semi-synchronous" result of write_message is deliberately not
        # awaited: waiting is only needed for control flow and errors
        # (see https://github.com/tornadoweb/tornado/blob/bdfc017c66817359158185561cee7878680cd841/tornado/websocket.py#L1073)
        ws.write_message(data, binary=True)

    def connection_lost(self, exc):
        "Log that the backend connection went away and close the websocket."
        ws = self.handler
        ws.log.info(f"Raw websocket {ws.name} connection lost: {exc}")
        ws.close()
33+
34+
class RawSocketHandler(NamedLocalProxyHandler):
    """
    HTTP handler that proxies websocket connections into a backend stream.
    All other HTTP requests return 405.

    Once a websocket is open, ``ws_transp`` / ``ws_proto`` hold the asyncio
    transport and protocol of the backend connection.
    """

    def _create_ws_connection(self, proto: asyncio.BaseProtocol):
        """
        Create the appropriate backend asyncio connection.

        ``proto`` is passed to the event loop as the protocol factory.
        Returns the (not yet awaited) coroutine from the event loop, which
        resolves to a ``(transport, protocol)`` pair. A unix socket is used
        when ``self.unix_socket`` is set, otherwise TCP to
        ``localhost:self.port``.
        """
        loop = asyncio.get_running_loop()
        if self.unix_socket is not None:
            self.log.info(f"RawSocket {self.name} connecting to {self.unix_socket}")
            return loop.create_unix_connection(proto, self.unix_socket)
        else:
            self.log.info(f"RawSocket {self.name} connecting to port {self.port}")
            return loop.create_connection(proto, 'localhost', self.port)

    async def proxy(self, port, path):
        """
        Reject plain (non-websocket) HTTP requests with 405.

        Raises:
            tornado.web.HTTPError: always, with status 405.
        """
        # Imported locally: this module never imported ``web``, so the raise
        # below failed with NameError (surfacing as a 500) rather than
        # producing the intended 405 response.
        from tornado import web

        raise web.HTTPError(405, "this raw_socket_proxy backend only supports websocket connections")

    async def proxy_open(self, host, port, proxied_path=""):
        """
        Open the backend connection. host and port are ignored (as they are in
        the parent for unix sockets) since they are always passed known values.
        """
        transp, proto = await self._create_ws_connection(lambda: RawSocketProtocol(self))
        self.ws_transp = transp
        self.ws_proto = proto
        self._record_activity()
        self.log.info(f"RawSocket {self.name} connected")

    def on_message(self, message):
        "Send websocket messages as stream writes, encoding if necessary."
        self._record_activity()
        if isinstance(message, str):
            # Text frames are written to the byte stream as UTF-8.
            message = message.encode('utf-8')
        self.ws_transp.write(message)  # buffered non-blocking. should block (needs new enough tornado)

    def on_ping(self, message):
        "No-op apart from recording activity."
        self._record_activity()

    def on_close(self):
        "Close the backend connection."
        self.log.info(f"RawSocket {self.name} connection closed")
        # ws_transp only exists once proxy_open has connected successfully.
        if hasattr(self, "ws_transp"):
            self.ws_transp.close()
79+
80+
class SuperviseAndRawSocketHandler(SuperviseAndProxyHandler, RawSocketHandler):
    """
    Raw-socket variant of the supervising proxy handler.

    Combines SuperviseAndProxyHandler with RawSocketHandler and replaces the
    readiness check with a plain connection attempt.
    """

    async def _http_ready_func(self, p):
        """
        Report whether the backend accepts a stream connection.

        ``p`` is not used here. Returns True if a connection could be opened
        (it is closed again immediately), False if the attempt raised OSError.
        NOTE(review): presumably called by the supervising parent class to
        poll for readiness; confirm against SuperviseAndProxyHandler.
        """
        # not really HTTP here, just try an empty connection
        try:
            probe, _unused = await self._create_ws_connection(asyncio.Protocol)
        except OSError as exc:
            self.log.debug(f"RawSocket {self.name} connection check failed: {exc}")
            return False
        else:
            probe.close()
            return True

tests/resources/jupyter_server_config.py

+9
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,15 @@ def my_env():
127127
"rewrite_response": [cats_only, dog_to_cat],
128128
},
129129
"python-proxyto54321-no-command": {"port": 54321},
130+
"python-rawsocket-tcp": {
131+
"command": [sys.executable, "./tests/resources/rawsocket.py", "{port}"],
132+
"raw_socket_proxy": True
133+
},
134+
"python-rawsocket-unix": {
135+
"command": [sys.executable, "./tests/resources/rawsocket.py", "{unix_socket}"],
136+
"unix_socket": True,
137+
"raw_socket_proxy": True
138+
},
130139
}
131140

132141
c.ServerProxy.non_service_rewrite_response = hello_to_foo

tests/resources/rawsocket.py

+28
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
#!/usr/bin/env python
2+
3+
import os
4+
import socket
5+
import sys
6+
7+
# Minimal raw-stream test server: listens on a TCP port or a unix socket path
# (chosen by the single command-line argument) and echoes back whatever it
# receives with the letter case swapped.
if len(sys.argv) != 2:
    print(f"Usage: {sys.argv[0]} TCPPORT|SOCKPATH")
    sys.exit(1)
where = sys.argv[1]
try:
    port = int(where)
except ValueError:
    # Not an integer: treat the argument as a unix socket path.
    family = socket.AF_UNIX
    addr = where
else:
    family = socket.AF_INET
    addr = ('localhost', port)

with socket.create_server(addr, family=family) as serv:
    while True:
        # only handle a single connection at a time
        sock, caddr = serv.accept()
        # The context manager closes the client socket even if recv/send raises.
        with sock:
            while True:
                s = sock.recv(1024)
                if not s:
                    # Empty read: the peer closed the connection.
                    break
                # sendall, not send: send() may write only part of the buffer.
                sock.sendall(s.swapcase())

tests/test_proxies.py

+22
Original file line numberDiff line numberDiff line change
@@ -469,3 +469,25 @@ def test_callable_environment_formatting(
469469
PORT, TOKEN = a_server_port_and_token
470470
r = request_get(PORT, "/python-http-callable-env/test", TOKEN)
471471
assert r.code == 200
472+
473+
474+
@pytest.mark.parametrize(
    "rawsocket_type",
    [
        "tcp",
        pytest.param(
            "unix",
            marks=pytest.mark.skipif(
                sys.platform == "win32", reason="Unix socket not supported on Windows"
            ),
        ),
    ],
)
async def test_server_proxy_rawsocket(
    rawsocket_type: str, a_server_port_and_token: Tuple[int, str]
) -> None:
    """
    Send binary websocket messages through the raw-socket proxy endpoints and
    check that each reply is the message with its letter case swapped.
    """
    PORT, TOKEN = a_server_port_and_token
    url = f"ws://{LOCALHOST}:{PORT}/python-rawsocket-{rawsocket_type}/?token={TOKEN}"
    conn = await websocket_connect(url)
    try:
        for msg in [b"Hello,", b"world!"]:
            await conn.write_message(msg)
            res = await conn.read_message()
            assert res == msg.swapcase()
    finally:
        # Close the client connection even when an assertion fails, so the
        # websocket is not left open after the test.
        conn.close()

0 commit comments

Comments
 (0)