Skip to content

Commit cf81dab

Browse files
Use a custom ZAP authentication thread
To resolve issues around interpreter shutdown and asyncio
1 parent 109c565 commit cf81dab

File tree

1 file changed

+82
-17
lines changed

1 file changed

+82
-17
lines changed

zprocess/security.py

Lines changed: 82 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import base64
44
import weakref
55
import os
6+
import threading
67
import ipaddress
78
import zmq
89
import zmq.auth.thread
@@ -24,6 +25,7 @@
2425
else:
2526
from time import monotonic
2627

28+
PYZMQ_VER_MAJOR = int(zmq.__version__.split('.')[0])
2729

2830
_bundle_warning = """zprocess warning: pyzmq is using bundled libzmq, which on Windows
2931
is not built with the cryptography library libsodium. Encryption/decryption will be
@@ -165,9 +167,10 @@ def _configure_curve(self, server):
165167
are a server or not"""
166168
orig_server = self.curve_server
167169
if server:
170+
self.curve_server = True
171+
self.zap_domain = self.context.zap_domain
168172
self.curve_publickey = self.context.server_publickey
169173
self.curve_secretkey = self.context.server_secretkey
170-
self.curve_server = True
171174
else:
172175
self.curve_server = False
173176
self.curve_publickey = self.context.client_publickey
@@ -306,6 +309,72 @@ def recv(self, flags=0, copy=True, track=False):
306309
return msg
307310

308311

312+
class _ThreadAuthenticator:
313+
# We roll our own thread authenticator (just implementing what we need) since a)
314+
# zmq.auth.thread.ThreadAuthenticator uses asyncio, and we do not want to express an
315+
# opinion on the kerfuffle that is Windows asyncio selectors (see:
316+
# https://github.com/zeromq/pyzmq/issues/1423) and impose it on the rest of the
317+
# interpreter, when we're not even using async stuff ourself and b)
318+
# zmq.auth.thread.ThreadAuthenticator spawns a non-daemon thread which is difficult
319+
# to ensure gets shut down at interpreter shutdown, and otherwise holds the
320+
# interpreter open.
321+
def __init__(self, ctx, zap_domain, allowed_clients):
322+
if PYZMQ_VER_MAJOR >= 25:
323+
self.zap_socket = ctx.socket(zmq.REP, socket_class=zmq.Socket)
324+
else:
325+
self.zap_socket = ctx.socket(zmq.REP)
326+
self.zap_socket.linger = 1
327+
self.zap_socket.bind("inproc://zeromq.zap.01")
328+
329+
# Note: we hold a reference to the zap socket, since if the thread crashes, we
330+
# don't want the socket to be cleaned up and closed - that would have zmq
331+
# (stupidly and dangerously) fall back to its default authentication method,
332+
# which is to allow all.
333+
self.thread = threading.Thread(
334+
target=self.run,
335+
args=(self.zap_socket, zap_domain, allowed_clients),
336+
daemon=True,
337+
)
338+
self.started = threading.Event()
339+
self.thread.start()
340+
341+
def run(self, zap_socket, zap_domain, allowed_clients):
342+
VERSION = b'1.0'
343+
MECHANISM = b'CURVE'
344+
while True:
345+
try:
346+
msg = zap_socket.recv_multipart()
347+
except zmq.error.ContextTerminated:
348+
zap_socket.close()
349+
return
350+
version, request_id, domain, address, identity, mechanism = msg[:6]
351+
credentials = msg[6:]
352+
if version != VERSION:
353+
status_code = b"400"
354+
status_text = b"Invalid version"
355+
user_id = b""
356+
elif mechanism != MECHANISM:
357+
status_code = b"400"
358+
status_text = b"Security mechanism not supported"
359+
user_id = b""
360+
elif domain != zap_domain:
361+
status_code = b"400"
362+
status_text = b"Unknown domain"
363+
user_id = b""
364+
else:
365+
key = zmq.utils.z85.encode(credentials[0])
366+
if key in allowed_clients:
367+
status_code = b"200"
368+
status_text = b"OK"
369+
user_id = key
370+
else:
371+
status_code = b"400"
372+
status_text = b"Unknown key"
373+
user_id = b""
374+
response = [VERSION, request_id, status_code, status_text, user_id, b""]
375+
zap_socket.send_multipart(response)
376+
377+
309378
class SecureContext(zmq.Context):
310379
"""A ZeroMQ Context with SecureContext.socket() returning a
311380
SecureSocket(), which can authenticate and communicate securely with all
@@ -315,7 +384,10 @@ class SecureContext(zmq.Context):
315384

316385
_socket_class = SecureSocket
317386
_instances = weakref.WeakValueDictionary()
387+
zap_domain = b"zprocess"
388+
318389
# Dummy class attrs to distinguish from zmq options:
390+
auth = None
319391
secure = False
320392
client_publickey = None
321393
client_secretkey = None
@@ -331,22 +403,15 @@ def __init__(self, io_threads=1, shared_secret=None):
331403
self.client_publickey = zmq.curve_public(self.client_secretkey)
332404
self.server_publickey = zmq.curve_public(self.server_secretkey)
333405

334-
# There are potential reference cycles causing the authentication thread to
335-
# prevent the interpreter from shutting down. In pyzmq <25, the
336-
# authentication thread holds a reference to the context. So we must avoid
337-
# holding a reference to it. In pyzmq 25+, the authentication thread and
338-
# threadauthenticator objects both hold references to each other. So we
339-
# replace one with a weakref to break the cycle.
340-
341-
auth = zmq.auth.thread.ThreadAuthenticator(self)
342-
auth.start()
343-
344-
if auth.thread.authenticator is auth:
345-
auth.thread.authenticator = weakref.proxy(auth)
346-
347-
# Allow only clients who have the client public key:
348-
auth.allow_any = False
349-
auth.certs['*'] = {self.client_publickey: True}
406+
# Note: it is crucial we hold a reference to the authenticator, so that in
407+
# the case the zap authentication thread crashes, the zap socket does not
408+
# get cleaned up and closed - zmq will interpret that situation as us not
409+
# requiring any authentication.
410+
self.auth = _ThreadAuthenticator(
411+
self,
412+
zap_domain=self.zap_domain,
413+
allowed_clients=[self.client_publickey],
414+
)
350415
self.secure = True
351416

352417
@classmethod

0 commit comments

Comments
 (0)