Skip to content

Commit 43fddae

Browse files
authored
Merge pull request #11 from desultory/dev
improve tests, split namespace stuff into separate module
2 parents 7b74843 + 65bf042 commit 43fddae

File tree

7 files changed

+130
-120
lines changed

7 files changed

+130
-120
lines changed

src/zenlib/namespace/__init__.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
from os import environ
2+
if not environ.get("CI"):
3+
from .nsexec import nsexec
4+
from .namespace import get_id_map
5+
else:
6+
nsexec, get_id_map = None, None
7+
8+
__all__ = [
9+
"nsexec",
10+
"get_id_map",
11+
]

src/zenlib/namespace/namespace.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
from os import CLONE_NEWNS, CLONE_NEWUSER, getlogin, unshare
2+
from subprocess import CalledProcessError, run
3+
4+
5+
def unshare_namespace():
6+
unshare(CLONE_NEWNS | CLONE_NEWUSER)
7+
8+
9+
def get_id_map(username=None, id_type="uid"):
10+
username = username or getlogin()
11+
if id_type not in ("uid", "gid"):
12+
raise ValueError("id_type must be 'uid' or 'gid'")
13+
14+
with open(f"/etc/sub{id_type}") as f:
15+
for line in f:
16+
if line.startswith(f"{username}:"):
17+
start, count = line.strip().split(":")[1:]
18+
return int(start), int(count)
19+
raise ValueError(f"User {username} not found in /etc/sub{id_type}")
20+
21+
22+
def new_id_map(id_type, pid, id, nsid, count=1, *args, failures=0):
23+
if id_type not in ("uid", "gid"):
24+
raise ValueError("id_type must be 'uid' or 'gid")
25+
cmd_args = [f"new{id_type}map", str(pid), str(id), str(nsid), str(count), *map(str, args)]
26+
try:
27+
return run(cmd_args, check=True)
28+
except CalledProcessError as e:
29+
if failures > 5:
30+
raise e
31+
new_id_map(id_type, pid, id, nsid, count, *args, failures=failures + 1)
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
from multiprocessing import Event, Pipe, Process, Queue
2+
from os import chroot, getgid, getlogin, getuid, setgid, setuid
3+
4+
from .namespace import get_id_map, new_id_map, unshare_namespace
5+
6+
7+
class NamespaceProcess(Process):
8+
"""Like process, but runs in a new namespace.
9+
Puts the target return value in a queue, and any exceptions in a pipe.
10+
"""
11+
12+
def __init__(self, target=None, args=None, kwargs=None, **ekwargs):
13+
self.target_root = kwargs.pop("target_root", "/")
14+
namespace_user = kwargs.pop("namespace_user", getlogin())
15+
self.subuid_start, self.subuid_count = get_id_map(namespace_user, "uid")
16+
self.subgid_start, self.subgid_count = get_id_map(namespace_user, "gid")
17+
self.orig_uid = getuid()
18+
self.orig_gid = getgid()
19+
self.uidmapped = Event()
20+
self.completed = Event()
21+
self.exception_recv, self.exception_send = Pipe()
22+
self.function_queue = Queue()
23+
super().__init__(target=target, args=args, kwargs=kwargs, **ekwargs)
24+
25+
def map_ids(self):
26+
new_id_map("uid", self.pid, 0, self.orig_uid, 1, 1, self.subuid_start, self.subuid_count)
27+
new_id_map("gid", self.pid, 0, self.orig_gid, 1, 1, self.subgid_start, self.subgid_count)
28+
29+
def map_unshare_uids(self):
30+
self.start()
31+
self.map_ids()
32+
self.uidmapped.set()
33+
34+
def run(self):
35+
unshare_namespace()
36+
self.uidmapped.wait()
37+
setuid(0)
38+
setgid(0)
39+
chroot(self.target_root)
40+
try:
41+
self.function_queue.put(self._target(*self._args, **self._kwargs))
42+
except Exception as e:
43+
self.exception_send.send(e)
44+
self.completed.set()

src/zenlib/namespace/nsexec.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
from subprocess import CalledProcessError
2+
from .namespace_process import NamespaceProcess
3+
4+
5+
def nsexec(target, *args, **kwargs):
6+
p = NamespaceProcess(target=target, args=args, kwargs=kwargs)
7+
try:
8+
p.map_unshare_uids()
9+
except CalledProcessError as e:
10+
print(f"Error: {e}")
11+
p.terminate()
12+
raise e
13+
14+
p.completed.wait()
15+
if p.exception_recv.poll():
16+
p.terminate()
17+
raise p.exception_recv.recv()
18+
19+
ret = p.function_queue.get()
20+
p.terminate()
21+
return ret

src/zenlib/util/__init__.py

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,13 @@
1-
from os import environ
2-
31
from .colorize import colorize
42
from .dict_check import contains, unset
53
from .handle_plural import handle_plural
64
from .main_funcs import get_args_n_logger, get_kwargs, get_kwargs_from_args, init_argparser, init_logger, process_args
75
from .merge_class import merge_class
8-
if not environ.get("CI"):
9-
from .namespace import nsexec, nschroot, get_id_map
10-
else:
11-
nsexec, nschroot, get_id_map = None, None, None
126
from .pretty_print import pretty_print
137
from .replace_file_line import replace_file_line
148
from ..types import NoDupFlatList
159

1610
__all__ = [
17-
"nsexec",
18-
"nschroot",
19-
"get_id_map",
2011
"handle_plural",
2112
"colorize",
2213
"NoDupFlatList",

src/zenlib/util/namespace.py

Lines changed: 0 additions & 110 deletions
This file was deleted.

tests/test_namespace.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from os import environ
22
from unittest import TestCase, main, skipIf
33

4-
from zenlib.util import nsexec
4+
from zenlib.namespace import nsexec
55

66

77
class TestPassedException(Exception):
@@ -16,12 +16,22 @@ def test_add_func(a, b):
1616
return a + b
1717

1818

19+
def test_add_kwargs(a, b, add1=None, add2=None):
20+
return add1 + add2
21+
22+
1923
def test_uid_gid():
2024
import os
2125

2226
return os.getuid(), os.getgid()
2327

2428

29+
def test_cwd():
30+
from pathlib import Path
31+
32+
return [p for p in Path("/").rglob("")]
33+
34+
2535
class TestNamespace(TestCase):
2636
@skipIf(environ.get("CI") == "true", "Skipping test_namespace.py in CI")
2737
def test_user_namespace_exceptions(self):
@@ -32,10 +42,22 @@ def test_user_namespace_exceptions(self):
3242
def test_user_namespace_func(self):
3343
self.assertEqual(nsexec(test_add_func, 1, 2), 3)
3444

45+
@skipIf(environ.get("CI") == "true", "Skipping test_namespace.py in CI")
46+
def test_user_namespace_kwargs(self):
47+
self.assertEqual(nsexec(test_add_kwargs, 1, 2, add1=3, add2=4), 7)
48+
3549
@skipIf(environ.get("CI") == "true", "Skipping test_namespace.py in CI")
3650
def test_user_namespace_uid_gid(self):
3751
self.assertEqual(nsexec(test_uid_gid), (0, 0))
3852

53+
@skipIf(environ.get("CI") == "true", "Skipping test_namespace.py in CI")
54+
def test_user_namespace_chroot(self):
55+
from pathlib import Path
56+
from tempfile import TemporaryDirectory
57+
58+
with TemporaryDirectory() as test_dir: # It should be an empty root tree
59+
self.assertEqual(nsexec(test_cwd, target_root=test_dir), [Path("/")])
60+
3961

4062
if __name__ == "__main__":
4163
main()

0 commit comments

Comments
 (0)