Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "zenlib"
version = "3.0.2"
version = "3.1.0"
authors = [
{ name="Desultory", email="[email protected]" },
]
Expand Down
9 changes: 9 additions & 0 deletions src/zenlib/util/__init__.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,22 @@
from os import environ

from .colorize import colorize
from .dict_check import contains, unset
from .handle_plural import handle_plural
from .main_funcs import get_args_n_logger, get_kwargs, get_kwargs_from_args, init_argparser, init_logger, process_args
from .merge_class import merge_class
if not environ.get("CI"):
from .namespace import nsexec, nschroot, get_id_map
else:
nsexec, nschroot, get_id_map = None, None, None
from .pretty_print import pretty_print
from .replace_file_line import replace_file_line
from ..types import NoDupFlatList

__all__ = [
"nsexec",
"nschroot",
"get_id_map",
"handle_plural",
"colorize",
"NoDupFlatList",
Expand Down
110 changes: 110 additions & 0 deletions src/zenlib/util/namespace.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
from multiprocessing import Event, Pipe, Process, Queue
from os import CLONE_NEWNS, CLONE_NEWUSER, chroot, getlogin, setgid, setuid, getuid, getgid, unshare
from subprocess import CalledProcessError, run


def unshare_namespace():
unshare(CLONE_NEWNS | CLONE_NEWUSER)


def get_id_map(username=None, id_type="uid"):
username = username or getlogin()
if id_type not in ("uid", "gid"):
raise ValueError("id_type must be 'uid' or 'gid'")

with open(f"/etc/sub{id_type}") as f:
for line in f:
if line.startswith(f"{username}:"):
start, count = line.strip().split(":")[1:]
return int(start), int(count)
raise ValueError(f"User {username} not found in /etc/sub{id_type}")


def new_id_map(id_type, pid, id, nsid, count=1, *args, failures=0):
if id_type not in ("uid", "gid"):
raise ValueError("id_type must be 'uid' or 'gid")
cmd_args = [f"new{id_type}map", str(pid), str(id), str(nsid), str(count), *map(str, args)]
try:
return run(cmd_args, check=True)
except CalledProcessError as e:
if failures > 5:
raise e
new_id_map(id_type, pid, id, nsid, count, *args, failures=failures + 1)


class NamespaceProcess(Process):
"""Like process, but runs in a new namespace.
Puts the target return value in a queue, and any exceptions in a pipe.
"""

def __init__(self, target=None, args=None, kwargs=None, **ekwargs):
self.target_root = kwargs.pop("target_root", "/")
namespace_user = kwargs.pop("namespace_user", getlogin())
self.subuid_start, self.subuid_count = get_id_map(namespace_user, "uid")
self.subgid_start, self.subgid_count = get_id_map(namespace_user, "gid")
self.orig_uid = getuid()
self.orig_gid = getgid()
self.uidmapped = Event()
self.completed = Event()
self.exception_recv, self.exception_send = Pipe()
self.function_queue = Queue()
super().__init__(target=target, args=args, kwargs=kwargs, **ekwargs)

def map_ids(self):
new_id_map("uid", self.pid, 0, self.orig_uid, 1, 1, self.subuid_start, self.subuid_count)
new_id_map("gid", self.pid, 0, self.orig_gid, 1, 1, self.subgid_start, self.subgid_count)

def map_unshare_uids(self):
self.start()
self.map_ids()
self.uidmapped.set()

def run(self):
unshare_namespace()
self.uidmapped.wait()
setuid(0)
setgid(0)
chroot(self.target_root)
try:
self.function_queue.put(self._target(*self._args, **self._kwargs))
except Exception as e:
self.exception_send.send(e)
self.completed.set()


def nschroot(target, *args, **kwargs):
p = NamespaceProcess(target=target, args=args, kwargs=kwargs)
try:
p.map_unshare_uids()
except CalledProcessError as e:
print(f"Error: {e}")
p.terminate()
raise e

p.completed.wait()
if p.exception_recv.poll():
p.terminate()
raise p.exception_recv.recv()

ret = p.function_queue.get()
p.terminate()
return ret


def nsexec(target, *args, **kwargs):
p = NamespaceProcess(target=target, args=args, kwargs=kwargs)
try:
p.map_unshare_uids()
except CalledProcessError as e:
print(f"Error: {e}")
p.terminate()
raise e

p.completed.wait()
if p.exception_recv.poll():
p.terminate()
raise p.exception_recv.recv()

ret = p.function_queue.get()
p.terminate()
return ret
39 changes: 39 additions & 0 deletions tests/test_namespace.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
from os import environ
from unittest import TestCase, main, skipIf



from zenlib.util import nsexec


class TestPassedException(Exception):
pass


def test_exception():
raise TestPassedException("This is a test exception")

def test_add_func(a, b):
return a + b

def test_uid_gid():
import os
return os.getuid(), os.getgid()

class TestNamespace(TestCase):
@skipIf(environ.get("CI") == "true", "Skipping test_namespace.py in CI")
def test_user_namespace_exceptions(self):
with self.assertRaises(TestPassedException):
nsexec(test_exception)

@skipIf(environ.get("CI") == "true", "Skipping test_namespace.py in CI")
def test_user_namespace_func(self):
self.assertEqual(nsexec(test_add_func, 1, 2), 3)

@skipIf(environ.get("CI") == "true", "Skipping test_namespace.py in CI")
def test_user_namespace_uid_gid(self):
self.assertEqual(nsexec(test_uid_gid), (0, 0))


if __name__ == "__main__":
main()
Loading