Skip to content

Commit

Permalink
Fix signal handling in SubprocessManager
Browse files Browse the repository at this point in the history
  • Loading branch information
filipcacky committed Sep 17, 2024
1 parent a85ee20 commit 8fe1164
Showing 1 changed file with 102 additions and 25 deletions.
127 changes: 102 additions & 25 deletions metaflow/runner/subprocess_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,23 +10,58 @@
from typing import Callable, Dict, Iterator, List, Optional, Tuple


def send_signals(pid, signal):
    """
    Send a signal to the children of a process and then to the process itself.

    Parameters
    ----------
    pid : int
        PID of the parent process. Processes whose parent is `pid` are
        signalled first (via `pkill -P`), then `pid` itself (via `kill`).
    signal : str
        Signal option passed verbatim to `pkill` and `kill`, e.g. "-TERM"
        or "-KILL". Note that inside this function the name shadows the
        `signal` module.
    """
    # TODO: there's a race condition that new descendants might
    # spawn b/w the invocations of 'pkill' and 'kill'.
    # Needs to be fixed in future.
    try:
        # The exit status of `pkill` is deliberately not checked: it exits
        # non-zero when no process matched, which is the normal case for a
        # process without children and must not prevent the `kill` below.
        subprocess.call(["pkill", signal, "-P", str(pid)])
        subprocess.check_call(["kill", signal, str(pid)])
    except subprocess.CalledProcessError:
        # Best effort: `kill` fails when the process is already gone.
        pass


def kill_process_and_descendants(pid, termination_timeout):
    """
    Terminate a process and its children, escalating from SIGTERM to SIGKILL.

    Parameters
    ----------
    pid : int
        PID of the process to kill.
    termination_timeout : float
        Seconds to wait after sending SIGTERM before sending SIGKILL.
    """
    send_signals(pid, "-TERM")

    # Grace period: the full timeout always elapses before SIGKILL is sent,
    # whether or not the process has already exited.
    time.sleep(termination_timeout)

    send_signals(pid, "-KILL")


def kill_processes_and_descendants(pids, termination_timeout):
    """
    Terminate several processes and their children with one shared grace period.

    SIGTERM is sent to every process first, then the timeout elapses once,
    then SIGKILL is sent to every process.

    Parameters
    ----------
    pids : Iterable[int]
        PIDs of the processes to kill.
    termination_timeout : float
        Seconds to wait after sending SIGTERM before sending SIGKILL.
    """
    # `pids` is walked twice; materialize it so a one-shot iterator does not
    # silently skip the SIGKILL pass.
    pids = list(pids)
    if not pids:
        # Nothing to signal, so do not block the caller for the grace period.
        return

    for pid in pids:
        send_signals(pid, "-TERM")

    time.sleep(termination_timeout)

    for pid in pids:
        send_signals(pid, "-KILL")


async def async_send_signals(pids, signal):
    """
    Asynchronously signal the children of each process, then the processes.

    All `pkill -P` commands are started (one per PID, in order) and awaited
    before any `kill` command is started.

    Parameters
    ----------
    pids : Iterable[int]
        PIDs of the parent processes.
    signal : str
        Signal option passed verbatim to `pkill` and `kill`, e.g. "-TERM".
    """

    async def _spawn_then_wait(*leading_args):
        # Launch one helper process per PID, then wait for each in turn.
        launched = []
        for target in pids:
            launched.append(
                await asyncio.create_subprocess_exec(*leading_args, str(target))
            )
        for helper in launched:
            await helper.wait()

    # Children first, then the parents themselves.
    await _spawn_then_wait("pkill", signal, "-P")
    await _spawn_then_wait("kill", signal)


async def async_kill_processes_and_descendants(pids, termination_timeout):
    """
    Asynchronously terminate processes and their children (SIGTERM, then SIGKILL).

    Parameters
    ----------
    pids : Iterable[int]
        PIDs of the processes to kill.
    termination_timeout : float
        Seconds to wait after sending SIGTERM before sending SIGKILL.
    """
    # `pids` is consumed several times downstream; materialize it so a
    # one-shot iterator behaves like a list.
    pids = list(pids)
    if not pids:
        # Nothing to signal, so skip the grace period entirely.
        return

    await async_send_signals(pids, "-TERM")

    await asyncio.sleep(termination_timeout)

    await async_send_signals(pids, "-KILL")


class LogReadTimeoutError(Exception):
Expand All @@ -42,6 +77,18 @@ class SubprocessManager(object):
def __init__(self):
self.commands: Dict[int, CommandManager] = {}

try:

async def handle_sigint():
await self._async_handle_sigint()

asyncio.get_running_loop().add_signal_handler(
signal.SIGINT, lambda: asyncio.create_task(handle_sigint())
)

except RuntimeError:
signal.signal(signal.SIGINT, self._handle_sigint)

async def __aenter__(self) -> "SubprocessManager":
return self

Expand Down Expand Up @@ -81,8 +128,14 @@ def run_command(
"""

command_obj = CommandManager(command, env, cwd)
pid = command_obj.run(show_output=show_output)
pid = command_obj.run(show_output=show_output, wait=False)

self.commands[pid] = command_obj

command_obj.process.wait()
command_obj.stdout_thread.join()
command_obj.stderr_thread.join()

return pid

async def async_run_command(
Expand Down Expand Up @@ -138,6 +191,30 @@ def cleanup(self) -> None:
for v in self.commands.values():
v.cleanup()

def kill(self, termination_timeout: float = 5):
    """
    Kill all managed subprocesses that are still running, and their descendants.

    Parameters
    ----------
    termination_timeout : float, default 5
        The time to wait after sending a SIGTERM to a subprocess and its descendants
        before sending a SIGKILL.
    """
    # Only signal commands whose exit has not been observed yet. The PID of
    # a finished process may already belong to an unrelated process, and
    # waiting out the grace period for finished commands is pointless.
    pids = [
        v.process.pid
        for v in self.commands.values()
        if v.process is not None and v.process.returncode is None
    ]
    if not pids:
        return

    kill_processes_and_descendants(
        pids,
        termination_timeout,
    )

def _handle_sigint(self, signum, frame):
    """
    Synchronous SIGINT handler: kill every managed subprocess.

    `signum` and `frame` are the arguments supplied by the `signal` module
    to a handler and are not used. `kill` is called with its default
    termination timeout.
    """
    self.kill()

async def _async_handle_sigint(self):
pids = [v.process.pid for v in self.commands.values()]
await async_kill_processes_and_descendants(pids, 5)


class CommandManager(object):
"""A manager for an individual subprocess."""
Expand Down Expand Up @@ -169,11 +246,11 @@ def __init__(
self.cwd = cwd if cwd is not None else os.getcwd()

self.process = None
self.stdout_thread = None
self.stderr_thread = None
self.run_called: bool = False
self.log_files: Dict[str, str] = {}

signal.signal(signal.SIGINT, self._handle_sigint)

async def __aenter__(self) -> "CommandManager":
return self

Expand Down Expand Up @@ -221,19 +298,22 @@ async def wait(
"within %s seconds." % (self.process.pid, command_string, timeout)
)

def run(self, show_output: bool = False):
def run(self, show_output: bool = False, wait: bool = True) -> int:
"""
Run the subprocess synchronously. This can only be called once.
This also waits on the process implicitly.
Parameters
----------
show_output : bool, default False
Suppress the 'stdout' and 'stderr' to the console by default.
They can be accessed later by reading the files present in:
- self.log_files["stdout"]
- self.log_files["stderr"]
wait : bool, default True
Wait for the process to finish before returning.
If false, the process will run in the background. You can then wait on
the process (using `wait`) or kill it (using `kill`).
Log forwarding threads `stdout_thread` and `stderr_thread` should be joined.
"""

if not self.run_called:
Expand Down Expand Up @@ -265,22 +345,22 @@ def stream_to_stdout_and_file(pipe, log_file):

self.run_called = True

stdout_thread = threading.Thread(
self.stdout_thread = threading.Thread(
target=stream_to_stdout_and_file,
args=(self.process.stdout, stdout_logfile),
)
stderr_thread = threading.Thread(
self.stderr_thread = threading.Thread(
target=stream_to_stdout_and_file,
args=(self.process.stderr, stderr_logfile),
)

stdout_thread.start()
stderr_thread.start()

self.process.wait()
self.stdout_thread.start()
self.stderr_thread.start()

stdout_thread.join()
stderr_thread.join()
if wait:
self.process.wait()
self.stdout_thread.join()
self.stderr_thread.join()

return self.process.pid
except Exception as e:
Expand Down Expand Up @@ -457,9 +537,6 @@ async def kill(self, termination_timeout: float = 5):
else:
print("No process to kill.")

def _handle_sigint(self, signum, frame):
asyncio.create_task(self.kill())


async def main():
flow_file = "../try.py"
Expand Down

0 comments on commit 8fe1164

Please sign in to comment.