diff --git a/src/zenlib/namespace/__init__.py b/src/zenlib/namespace/__init__.py new file mode 100644 index 0000000..f6bde34 --- /dev/null +++ b/src/zenlib/namespace/__init__.py @@ -0,0 +1,11 @@ +from os import environ +if not environ.get("CI"): + from .nsexec import nsexec + from .namespace import get_id_map +else: + nsexec, get_id_map = None, None + +__all__ = [ + "nsexec", + "get_id_map", +] diff --git a/src/zenlib/namespace/namespace.py b/src/zenlib/namespace/namespace.py new file mode 100644 index 0000000..7f4b08d --- /dev/null +++ b/src/zenlib/namespace/namespace.py @@ -0,0 +1,31 @@ +from os import CLONE_NEWNS, CLONE_NEWUSER, getlogin, unshare +from subprocess import CalledProcessError, run + + +def unshare_namespace(): + unshare(CLONE_NEWNS | CLONE_NEWUSER) + + +def get_id_map(username=None, id_type="uid"): + username = username or getlogin() + if id_type not in ("uid", "gid"): + raise ValueError("id_type must be 'uid' or 'gid'") + + with open(f"/etc/sub{id_type}") as f: + for line in f: + if line.startswith(f"{username}:"): + start, count = line.strip().split(":")[1:] + return int(start), int(count) + raise ValueError(f"User {username} not found in /etc/sub{id_type}") + + +def new_id_map(id_type, pid, id, nsid, count=1, *args, failures=0): + if id_type not in ("uid", "gid"): + raise ValueError("id_type must be 'uid' or 'gid") + cmd_args = [f"new{id_type}map", str(pid), str(id), str(nsid), str(count), *map(str, args)] + try: + return run(cmd_args, check=True) + except CalledProcessError as e: + if failures > 5: + raise e + new_id_map(id_type, pid, id, nsid, count, *args, failures=failures + 1) diff --git a/src/zenlib/namespace/namespace_process.py b/src/zenlib/namespace/namespace_process.py new file mode 100644 index 0000000..3f07b99 --- /dev/null +++ b/src/zenlib/namespace/namespace_process.py @@ -0,0 +1,44 @@ +from multiprocessing import Event, Pipe, Process, Queue +from os import chroot, getgid, getlogin, getuid, setgid, setuid + +from .namespace import get_id_map, new_id_map, unshare_namespace + + +class NamespaceProcess(Process): + """Like process, but runs in a new namespace. + Puts the target return value in a queue, and any exceptions in a pipe. + """ + + def __init__(self, target=None, args=None, kwargs=None, **ekwargs): + self.target_root = kwargs.pop("target_root", "/") + namespace_user = kwargs.pop("namespace_user", getlogin()) + self.subuid_start, self.subuid_count = get_id_map(namespace_user, "uid") + self.subgid_start, self.subgid_count = get_id_map(namespace_user, "gid") + self.orig_uid = getuid() + self.orig_gid = getgid() + self.uidmapped = Event() + self.completed = Event() + self.exception_recv, self.exception_send = Pipe() + self.function_queue = Queue() + super().__init__(target=target, args=args, kwargs=kwargs, **ekwargs) + + def map_ids(self): + new_id_map("uid", self.pid, 0, self.orig_uid, 1, 1, self.subuid_start, self.subuid_count) + new_id_map("gid", self.pid, 0, self.orig_gid, 1, 1, self.subgid_start, self.subgid_count) + + def map_unshare_uids(self): + self.start() + self.map_ids() + self.uidmapped.set() + + def run(self): + unshare_namespace() + self.uidmapped.wait() + setuid(0) + setgid(0) + chroot(self.target_root) + try: + self.function_queue.put(self._target(*self._args, **self._kwargs)) + except Exception as e: + self.exception_send.send(e) + self.completed.set() diff --git a/src/zenlib/namespace/nsexec.py b/src/zenlib/namespace/nsexec.py new file mode 100644 index 0000000..508a563 --- /dev/null +++ b/src/zenlib/namespace/nsexec.py @@ -0,0 +1,21 @@ +from subprocess import CalledProcessError +from .namespace_process import NamespaceProcess + + +def nsexec(target, *args, **kwargs): + p = NamespaceProcess(target=target, args=args, kwargs=kwargs) + try: + p.map_unshare_uids() + except CalledProcessError as e: + print(f"Error: {e}") + p.terminate() + raise e + + p.completed.wait() + if p.exception_recv.poll(): + p.terminate() + raise p.exception_recv.recv() + + ret = p.function_queue.get() + p.terminate() + return ret diff --git a/src/zenlib/util/__init__.py b/src/zenlib/util/__init__.py index e867774..bbeaf64 100644 --- a/src/zenlib/util/__init__.py +++ b/src/zenlib/util/__init__.py @@ -1,22 +1,13 @@ -from os import environ - from .colorize import colorize from .dict_check import contains, unset from .handle_plural import handle_plural from .main_funcs import get_args_n_logger, get_kwargs, get_kwargs_from_args, init_argparser, init_logger, process_args from .merge_class import merge_class -if not environ.get("CI"): - from .namespace import nsexec, nschroot, get_id_map -else: - nsexec, nschroot, get_id_map = None, None, None from .pretty_print import pretty_print from .replace_file_line import replace_file_line from ..types import NoDupFlatList __all__ = [ - "nsexec", - "nschroot", - "get_id_map", "handle_plural", "colorize", "NoDupFlatList", diff --git a/src/zenlib/util/namespace.py b/src/zenlib/util/namespace.py deleted file mode 100644 index 36a52ca..0000000 --- a/src/zenlib/util/namespace.py +++ /dev/null @@ -1,110 +0,0 @@ -from multiprocessing import Event, Pipe, Process, Queue -from os import CLONE_NEWNS, CLONE_NEWUSER, chroot, getlogin, setgid, setuid, getuid, getgid, unshare -from subprocess import CalledProcessError, run - - -def unshare_namespace(): - unshare(CLONE_NEWNS | CLONE_NEWUSER) - - -def get_id_map(username=None, id_type="uid"): - username = username or getlogin() - if id_type not in ("uid", "gid"): - raise ValueError("id_type must be 'uid' or 'gid'") - - with open(f"/etc/sub{id_type}") as f: - for line in f: - if line.startswith(f"{username}:"): - start, count = line.strip().split(":")[1:] - return int(start), int(count) - raise ValueError(f"User {username} not found in /etc/sub{id_type}") - - -def new_id_map(id_type, pid, id, nsid, count=1, *args, failures=0): - if id_type not in ("uid", "gid"): - raise ValueError("id_type must be 'uid' or 'gid") - cmd_args = [f"new{id_type}map", str(pid), str(id), str(nsid), str(count), *map(str, args)] - try: - return run(cmd_args, check=True) - except CalledProcessError as e: - if failures > 5: - raise e - new_id_map(id_type, pid, id, nsid, count, *args, failures=failures + 1) - - -class NamespaceProcess(Process): - """Like process, but runs in a new namespace. - Puts the target return value in a queue, and any exceptions in a pipe. - """ - - def __init__(self, target=None, args=None, kwargs=None, **ekwargs): - self.target_root = kwargs.pop("target_root", "/") - namespace_user = kwargs.pop("namespace_user", getlogin()) - self.subuid_start, self.subuid_count = get_id_map(namespace_user, "uid") - self.subgid_start, self.subgid_count = get_id_map(namespace_user, "gid") - self.orig_uid = getuid() - self.orig_gid = getgid() - self.uidmapped = Event() - self.completed = Event() - self.exception_recv, self.exception_send = Pipe() - self.function_queue = Queue() - super().__init__(target=target, args=args, kwargs=kwargs, **ekwargs) - - def map_ids(self): - new_id_map("uid", self.pid, 0, self.orig_uid, 1, 1, self.subuid_start, self.subuid_count) - new_id_map("gid", self.pid, 0, self.orig_gid, 1, 1, self.subgid_start, self.subgid_count) - - def map_unshare_uids(self): - self.start() - self.map_ids() - self.uidmapped.set() - - def run(self): - unshare_namespace() - self.uidmapped.wait() - setuid(0) - setgid(0) - chroot(self.target_root) - try: - self.function_queue.put(self._target(*self._args, **self._kwargs)) - except Exception as e: - self.exception_send.send(e) - self.completed.set() - - -def nschroot(target, *args, **kwargs): - p = NamespaceProcess(target=target, args=args, kwargs=kwargs) - try: - p.map_unshare_uids() - except CalledProcessError as e: - print(f"Error: {e}") - p.terminate() - raise e - - p.completed.wait() - if p.exception_recv.poll(): - p.terminate() - raise p.exception_recv.recv() - - ret = p.function_queue.get() - p.terminate() - return ret - - -def nsexec(target, *args, **kwargs): - p = NamespaceProcess(target=target, args=args, kwargs=kwargs) - try: - p.map_unshare_uids() - except CalledProcessError as e: - print(f"Error: {e}") - p.terminate() - raise e - - p.completed.wait() - if p.exception_recv.poll(): - p.terminate() - raise p.exception_recv.recv() - - ret = p.function_queue.get() - p.terminate() - return ret diff --git a/tests/test_namespace.py b/tests/test_namespace.py index 7608a49..c12df78 100644 --- a/tests/test_namespace.py +++ b/tests/test_namespace.py @@ -1,7 +1,7 @@ from os import environ from unittest import TestCase, main, skipIf -from zenlib.util import nsexec +from zenlib.namespace import nsexec class TestPassedException(Exception): @@ -16,12 +16,22 @@ def test_add_func(a, b): return a + b +def test_add_kwargs(a, b, add1=None, add2=None): + return add1 + add2 + + def test_uid_gid(): import os return os.getuid(), os.getgid() +def test_cwd(): + from pathlib import Path + + return [p for p in Path("/").rglob("")] + + class TestNamespace(TestCase): @skipIf(environ.get("CI") == "true", "Skipping test_namespace.py in CI") def test_user_namespace_exceptions(self): @@ -32,10 +42,22 @@ def test_user_namespace_exceptions(self): def test_user_namespace_func(self): self.assertEqual(nsexec(test_add_func, 1, 2), 3) + @skipIf(environ.get("CI") == "true", "Skipping test_namespace.py in CI") + def test_user_namespace_kwargs(self): + self.assertEqual(nsexec(test_add_kwargs, 1, 2, add1=3, add2=4), 7) + @skipIf(environ.get("CI") == "true", "Skipping test_namespace.py in CI") def test_user_namespace_uid_gid(self): self.assertEqual(nsexec(test_uid_gid), (0, 0)) + @skipIf(environ.get("CI") == "true", "Skipping test_namespace.py in CI") + def test_user_namespace_chroot(self): + from pathlib import Path + from tempfile import TemporaryDirectory + + with TemporaryDirectory() as test_dir: # It should be an empty root tree + self.assertEqual(nsexec(test_cwd, target_root=test_dir), [Path("/")]) + if __name__ == "__main__": main()