Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions src/zenlib/namespace/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from os import environ
if not environ.get("CI"):
from .nsexec import nsexec
from .namespace import get_id_map
else:
nsexec, get_id_map = None, None

__all__ = [
"nsexec",
"get_id_map",
]
31 changes: 31 additions & 0 deletions src/zenlib/namespace/namespace.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from os import CLONE_NEWNS, CLONE_NEWUSER, getlogin, unshare
from subprocess import CalledProcessError, run


def unshare_namespace():
unshare(CLONE_NEWNS | CLONE_NEWUSER)


def get_id_map(username=None, id_type="uid"):
username = username or getlogin()
if id_type not in ("uid", "gid"):
raise ValueError("id_type must be 'uid' or 'gid'")

with open(f"/etc/sub{id_type}") as f:
for line in f:
if line.startswith(f"{username}:"):
start, count = line.strip().split(":")[1:]
return int(start), int(count)
raise ValueError(f"User {username} not found in /etc/sub{id_type}")


def new_id_map(id_type, pid, id, nsid, count=1, *args, failures=0):
if id_type not in ("uid", "gid"):
raise ValueError("id_type must be 'uid' or 'gid")
cmd_args = [f"new{id_type}map", str(pid), str(id), str(nsid), str(count), *map(str, args)]
try:
return run(cmd_args, check=True)
except CalledProcessError as e:
if failures > 5:
raise e
new_id_map(id_type, pid, id, nsid, count, *args, failures=failures + 1)
44 changes: 44 additions & 0 deletions src/zenlib/namespace/namespace_process.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
from multiprocessing import Event, Pipe, Process, Queue
from os import chroot, getgid, getlogin, getuid, setgid, setuid

from .namespace import get_id_map, new_id_map, unshare_namespace


class NamespaceProcess(Process):
"""Like process, but runs in a new namespace.
Puts the target return value in a queue, and any exceptions in a pipe.
"""

def __init__(self, target=None, args=None, kwargs=None, **ekwargs):
self.target_root = kwargs.pop("target_root", "/")
namespace_user = kwargs.pop("namespace_user", getlogin())
self.subuid_start, self.subuid_count = get_id_map(namespace_user, "uid")
self.subgid_start, self.subgid_count = get_id_map(namespace_user, "gid")
self.orig_uid = getuid()
self.orig_gid = getgid()
self.uidmapped = Event()
self.completed = Event()
self.exception_recv, self.exception_send = Pipe()
self.function_queue = Queue()
super().__init__(target=target, args=args, kwargs=kwargs, **ekwargs)

def map_ids(self):
new_id_map("uid", self.pid, 0, self.orig_uid, 1, 1, self.subuid_start, self.subuid_count)
new_id_map("gid", self.pid, 0, self.orig_gid, 1, 1, self.subgid_start, self.subgid_count)

def map_unshare_uids(self):
self.start()
self.map_ids()
self.uidmapped.set()

def run(self):
unshare_namespace()
self.uidmapped.wait()
setuid(0)
setgid(0)
chroot(self.target_root)
try:
self.function_queue.put(self._target(*self._args, **self._kwargs))
except Exception as e:
self.exception_send.send(e)
self.completed.set()
21 changes: 21 additions & 0 deletions src/zenlib/namespace/nsexec.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
from subprocess import CalledProcessError
from .namespace_process import NamespaceProcess


def nsexec(target, *args, **kwargs):
p = NamespaceProcess(target=target, args=args, kwargs=kwargs)
try:
p.map_unshare_uids()
except CalledProcessError as e:
print(f"Error: {e}")
p.terminate()
raise e

p.completed.wait()
if p.exception_recv.poll():
p.terminate()
raise p.exception_recv.recv()

ret = p.function_queue.get()
p.terminate()
return ret
9 changes: 0 additions & 9 deletions src/zenlib/util/__init__.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,13 @@
from os import environ

from .colorize import colorize
from .dict_check import contains, unset
from .handle_plural import handle_plural
from .main_funcs import get_args_n_logger, get_kwargs, get_kwargs_from_args, init_argparser, init_logger, process_args
from .merge_class import merge_class
if not environ.get("CI"):
from .namespace import nsexec, nschroot, get_id_map
else:
nsexec, nschroot, get_id_map = None, None, None
from .pretty_print import pretty_print
from .replace_file_line import replace_file_line
from ..types import NoDupFlatList

__all__ = [
"nsexec",
"nschroot",
"get_id_map",
"handle_plural",
"colorize",
"NoDupFlatList",
Expand Down
110 changes: 0 additions & 110 deletions src/zenlib/util/namespace.py

This file was deleted.

24 changes: 23 additions & 1 deletion tests/test_namespace.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from os import environ
from unittest import TestCase, main, skipIf

from zenlib.util import nsexec
from zenlib.namespace import nsexec


class TestPassedException(Exception):
Expand All @@ -16,12 +16,22 @@ def test_add_func(a, b):
return a + b


def test_add_kwargs(a, b, add1=None, add2=None):
return add1 + add2


def test_uid_gid():
import os

return os.getuid(), os.getgid()


def test_cwd():
from pathlib import Path

return [p for p in Path("/").rglob("")]


class TestNamespace(TestCase):
@skipIf(environ.get("CI") == "true", "Skipping test_namespace.py in CI")
def test_user_namespace_exceptions(self):
Expand All @@ -32,10 +42,22 @@ def test_user_namespace_exceptions(self):
def test_user_namespace_func(self):
self.assertEqual(nsexec(test_add_func, 1, 2), 3)

@skipIf(environ.get("CI") == "true", "Skipping test_namespace.py in CI")
def test_user_namespace_kwargs(self):
self.assertEqual(nsexec(test_add_kwargs, 1, 2, add1=3, add2=4), 7)

@skipIf(environ.get("CI") == "true", "Skipping test_namespace.py in CI")
def test_user_namespace_uid_gid(self):
self.assertEqual(nsexec(test_uid_gid), (0, 0))

@skipIf(environ.get("CI") == "true", "Skipping test_namespace.py in CI")
def test_user_namespace_chroot(self):
from pathlib import Path
from tempfile import TemporaryDirectory

with TemporaryDirectory() as test_dir: # It should be an empty root tree
self.assertEqual(nsexec(test_cwd, target_root=test_dir), [Path("/")])


if __name__ == "__main__":
main()
Loading