diff --git a/pyproject.toml b/pyproject.toml index 9623af4..4bad0ac 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "zenlib" -version = "3.0.2" +version = "3.1.0" authors = [ { name="Desultory", email="dev@pyl.onl" }, ] diff --git a/src/zenlib/util/__init__.py b/src/zenlib/util/__init__.py index bbeaf64..e867774 100644 --- a/src/zenlib/util/__init__.py +++ b/src/zenlib/util/__init__.py @@ -1,13 +1,22 @@ +from os import environ + from .colorize import colorize from .dict_check import contains, unset from .handle_plural import handle_plural from .main_funcs import get_args_n_logger, get_kwargs, get_kwargs_from_args, init_argparser, init_logger, process_args from .merge_class import merge_class +if not environ.get("CI"): + from .namespace import nsexec, nschroot, get_id_map +else: + nsexec, nschroot, get_id_map = None, None, None from .pretty_print import pretty_print from .replace_file_line import replace_file_line from ..types import NoDupFlatList __all__ = [ + "nsexec", + "nschroot", + "get_id_map", "handle_plural", "colorize", "NoDupFlatList", diff --git a/src/zenlib/util/namespace.py b/src/zenlib/util/namespace.py new file mode 100644 index 0000000..36a52ca --- /dev/null +++ b/src/zenlib/util/namespace.py @@ -0,0 +1,110 @@ +from multiprocessing import Event, Pipe, Process, Queue +from os import CLONE_NEWNS, CLONE_NEWUSER, chroot, getlogin, setgid, setuid, getuid, getgid, unshare +from subprocess import CalledProcessError, run + + +def unshare_namespace(): + unshare(CLONE_NEWNS | CLONE_NEWUSER) + + +def get_id_map(username=None, id_type="uid"): + username = username or getlogin() + if id_type not in ("uid", "gid"): + raise ValueError("id_type must be 'uid' or 'gid'") + + with open(f"/etc/sub{id_type}") as f: + for line in f: + if line.startswith(f"{username}:"): + start, count = line.strip().split(":")[1:] + return int(start), int(count) + raise ValueError(f"User {username} not found in /etc/sub{id_type}") + + +def new_id_map(id_type, pid, id, nsid, count=1, *args, failures=0): + if id_type not in ("uid", "gid"): + raise ValueError("id_type must be 'uid' or 'gid") + cmd_args = [f"new{id_type}map", str(pid), str(id), str(nsid), str(count), *map(str, args)] + try: + return run(cmd_args, check=True) + except CalledProcessError as e: + if failures > 5: + raise e + new_id_map(id_type, pid, id, nsid, count, *args, failures=failures + 1) + + +class NamespaceProcess(Process): + """Like process, but runs in a new namespace. + Puts the target return value in a queue, and any exceptions in a pipe. + """ + + def __init__(self, target=None, args=None, kwargs=None, **ekwargs): + self.target_root = kwargs.pop("target_root", "/") + namespace_user = kwargs.pop("namespace_user", getlogin()) + self.subuid_start, self.subuid_count = get_id_map(namespace_user, "uid") + self.subgid_start, self.subgid_count = get_id_map(namespace_user, "gid") + self.orig_uid = getuid() + self.orig_gid = getgid() + self.uidmapped = Event() + self.completed = Event() + self.exception_recv, self.exception_send = Pipe() + self.function_queue = Queue() + super().__init__(target=target, args=args, kwargs=kwargs, **ekwargs) + + def map_ids(self): + new_id_map("uid", self.pid, 0, self.orig_uid, 1, 1, self.subuid_start, self.subuid_count) + new_id_map("gid", self.pid, 0, self.orig_gid, 1, 1, self.subgid_start, self.subgid_count) + + def map_unshare_uids(self): + self.start() + self.map_ids() + self.uidmapped.set() + + def run(self): + unshare_namespace() + self.uidmapped.wait() + setuid(0) + setgid(0) + chroot(self.target_root) + try: + self.function_queue.put(self._target(*self._args, **self._kwargs)) + except Exception as e: + self.exception_send.send(e) + self.completed.set() + + +def nschroot(target, *args, **kwargs): + p = NamespaceProcess(target=target, args=args, kwargs=kwargs) + try: + p.map_unshare_uids() + except CalledProcessError as e: + print(f"Error: {e}") + p.terminate() + raise e + + p.completed.wait() + if p.exception_recv.poll(): + p.terminate() + raise p.exception_recv.recv() + + ret = p.function_queue.get() + p.terminate() + return ret + + +def nsexec(target, *args, **kwargs): + p = NamespaceProcess(target=target, args=args, kwargs=kwargs) + try: + p.map_unshare_uids() + except CalledProcessError as e: + print(f"Error: {e}") + p.terminate() + raise e + + p.completed.wait() + if p.exception_recv.poll(): + p.terminate() + raise p.exception_recv.recv() + + ret = p.function_queue.get() + p.terminate() + return ret diff --git a/tests/test_namespace.py b/tests/test_namespace.py new file mode 100644 index 0000000..d7184aa --- /dev/null +++ b/tests/test_namespace.py @@ -0,0 +1,39 @@ +from os import environ +from unittest import TestCase, main, skipIf + + + +from zenlib.util import nsexec + + +class TestPassedException(Exception): + pass + + +def test_exception(): + raise TestPassedException("This is a test exception") + +def test_add_func(a, b): + return a + b + +def test_uid_gid(): + import os + return os.getuid(), os.getgid() + +class TestNamespace(TestCase): + @skipIf(environ.get("CI") == "true", "Skipping test_namespace.py in CI") + def test_user_namespace_exceptions(self): + with self.assertRaises(TestPassedException): + nsexec(test_exception) + + @skipIf(environ.get("CI") == "true", "Skipping test_namespace.py in CI") + def test_user_namespace_func(self): + self.assertEqual(nsexec(test_add_func, 1, 2), 3) + + @skipIf(environ.get("CI") == "true", "Skipping test_namespace.py in CI") + def test_user_namespace_uid_gid(self): + self.assertEqual(nsexec(test_uid_gid), (0, 0)) + + +if __name__ == "__main__": + main()