Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 62 additions & 15 deletions salt/transport/zeromq.py
Original file line number Diff line number Diff line change
Expand Up @@ -422,6 +422,7 @@ def zmq_device(self):
"""
self.__setup_signals()
context = zmq.Context(self.opts["worker_threads"])

# Prepare the zeromq sockets
self.uri = "tcp://{interface}:{ret_port}".format(**self.opts)
self.clients = context.socket(zmq.ROUTER)
Expand All @@ -431,7 +432,7 @@ def zmq_device(self):
self.clients.setsockopt(zmq.IPV4ONLY, 0)
self.clients.setsockopt(zmq.BACKLOG, self.opts.get("zmq_backlog", 1000))
self._start_zmq_monitor()
self.workers = context.socket(zmq.DEALER)
self.workers = context.socket(zmq.ROUTER)
self.workers.setsockopt(zmq.LINGER, -1)

if self.opts["mworker_queue_niceness"] and not salt.utils.platform.is_windows():
Expand All @@ -458,11 +459,20 @@ def zmq_device(self):
if self.opts.get("ipc_mode", "") != "tcp":
os.chmod(os.path.join(self.opts["sock_dir"], "workers.ipc"), 0o600)

self.poller = zmq.Poller()
self.poller.register(self.workers, zmq.POLLIN)
self.workers_available = []

self.profile_last = 0
self.profile_interval = (
self.opts["worker_threads"] / self.opts.get("profile_level", 4) or 1
)

while True:
if self.clients.closed or self.workers.closed:
break
try:
zmq.device(zmq.QUEUE, self.clients, self.workers)
self.dispatch()
except zmq.ZMQError as exc:
if exc.errno == errno.EINTR:
continue
Expand All @@ -471,6 +481,53 @@ def zmq_device(self):
break
context.term()

def dispatch(self):
sockets_ready = dict(self.poller.poll())

if self.workers in sockets_ready:
message = self.workers.recv_multipart()
worker, _, client = message[:3]

if not self.workers_available:
self.poller.register(self.clients, zmq.POLLIN)
self.workers_available.append(worker)

if (
abs(len(self.workers_available) - self.profile_last)
>= self.profile_interval
):
self.profile_last = len(self.workers_available)
log.profile(
"Workers available: %d / %d (+/- %d)",
self.profile_last,
self.opts["worker_threads"],
self.profile_interval,
)

if client != b"":
self.clients.send_multipart(message[2:])

if self.clients in sockets_ready:
message = self.clients.recv_multipart()

worker = self.workers_available.pop(0)
if not self.workers_available:
self.poller.unregister(self.clients)

if (
abs(len(self.workers_available) - self.profile_last)
>= self.profile_interval
):
self.profile_last = len(self.workers_available)
log.profile(
"Workers available: %d / %d (+/- %d)",
self.profile_last,
self.opts["worker_threads"],
self.profile_interval,
)

self.workers.send_multipart([worker, b""] + message)

def close(self):
"""
Cleanly shutdown the router socket
Expand All @@ -490,8 +547,6 @@ def close(self):
self.clients.close()
if hasattr(self, "workers") and self.workers.closed is False:
self.workers.close()
if hasattr(self, "stream"):
self.stream.close()
if hasattr(self, "_socket") and self._socket.closed is False:
self._socket.close()
if hasattr(self, "context") and self.context.closed is False:
Expand Down Expand Up @@ -533,9 +588,8 @@ def post_fork(self, message_handler, io_loop):
they are picked up off the wire
:param IOLoop io_loop: An instance of a Tornado IOLoop, to handle event scheduling
"""
# context = zmq.Context(1)
self.context = zmq.asyncio.Context(1)
self._socket = self.context.socket(zmq.REP)
self._socket = self.context.socket(zmq.REQ)
# Linger -1 means we'll never discard messages.
self._socket.setsockopt(zmq.LINGER, -1)
self._start_zmq_monitor()
Expand Down Expand Up @@ -580,16 +634,13 @@ async def request_handler(self):
)
continue

async def handle_message(self, stream, payload):
async def handle_message(self, stream, message):
try:
payload = self.decode_payload(payload)
payload = salt.payload.loads(message[-1])
except salt.exceptions.SaltDeserializationError:
return {"msg": "bad load"}
return await self.message_handler(payload)

def encode_payload(self, payload):
return salt.payload.dumps(payload)

def __setup_signals(self):
signal.signal(signal.SIGINT, self._handle_signals)
signal.signal(signal.SIGTERM, self._handle_signals)
Expand All @@ -605,10 +656,6 @@ def _handle_signals(self, signum, sigframe):
self.close()
sys.exit(salt.defaults.exitcodes.EX_OK)

def decode_payload(self, payload):
payload = salt.payload.loads(payload)
return payload


def _set_tcp_keepalive(zmq_socket, opts):
"""
Expand Down
Loading