Skip to content

Commit fe33b97

Browse files
authored
Merge pull request #9 from desultory/dev
add namespace module
2 parents 2e57a46 + 9e7e37d commit fe33b97

File tree

4 files changed

+159
-1
lines changed

4 files changed

+159
-1
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
44

55
[project]
66
name = "zenlib"
7-
version = "3.0.2"
7+
version = "3.1.0"
88
authors = [
99
{ name="Desultory", email="[email protected]" },
1010
]

src/zenlib/util/__init__.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,22 @@
1+
from os import environ
2+
13
from .colorize import colorize
24
from .dict_check import contains, unset
35
from .handle_plural import handle_plural
46
from .main_funcs import get_args_n_logger, get_kwargs, get_kwargs_from_args, init_argparser, init_logger, process_args
57
from .merge_class import merge_class
8+
if not environ.get("CI"):
9+
from .namespace import nsexec, nschroot, get_id_map
10+
else:
11+
nsexec, nschroot, get_id_map = None, None, None
612
from .pretty_print import pretty_print
713
from .replace_file_line import replace_file_line
814
from ..types import NoDupFlatList
915

1016
__all__ = [
17+
"nsexec",
18+
"nschroot",
19+
"get_id_map",
1120
"handle_plural",
1221
"colorize",
1322
"NoDupFlatList",

src/zenlib/util/namespace.py

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
from multiprocessing import Event, Pipe, Process, Queue
2+
from os import CLONE_NEWNS, CLONE_NEWUSER, chroot, getlogin, setgid, setuid, getuid, getgid, unshare
3+
from subprocess import CalledProcessError, run
4+
5+
6+
def unshare_namespace():
7+
unshare(CLONE_NEWNS | CLONE_NEWUSER)
8+
9+
10+
def get_id_map(username=None, id_type="uid"):
11+
username = username or getlogin()
12+
if id_type not in ("uid", "gid"):
13+
raise ValueError("id_type must be 'uid' or 'gid'")
14+
15+
with open(f"/etc/sub{id_type}") as f:
16+
for line in f:
17+
if line.startswith(f"{username}:"):
18+
start, count = line.strip().split(":")[1:]
19+
return int(start), int(count)
20+
raise ValueError(f"User {username} not found in /etc/sub{id_type}")
21+
22+
23+
def new_id_map(id_type, pid, id, nsid, count=1, *args, failures=0):
24+
if id_type not in ("uid", "gid"):
25+
raise ValueError("id_type must be 'uid' or 'gid")
26+
cmd_args = [f"new{id_type}map", str(pid), str(id), str(nsid), str(count), *map(str, args)]
27+
try:
28+
return run(cmd_args, check=True)
29+
except CalledProcessError as e:
30+
if failures > 5:
31+
raise e
32+
new_id_map(id_type, pid, id, nsid, count, *args, failures=failures + 1)
33+
34+
35+
class NamespaceProcess(Process):
36+
"""Like process, but runs in a new namespace.
37+
Puts the target return value in a queue, and any exceptions in a pipe.
38+
"""
39+
40+
def __init__(self, target=None, args=None, kwargs=None, **ekwargs):
41+
self.target_root = kwargs.pop("target_root", "/")
42+
namespace_user = kwargs.pop("namespace_user", getlogin())
43+
self.subuid_start, self.subuid_count = get_id_map(namespace_user, "uid")
44+
self.subgid_start, self.subgid_count = get_id_map(namespace_user, "gid")
45+
self.orig_uid = getuid()
46+
self.orig_gid = getgid()
47+
self.uidmapped = Event()
48+
self.completed = Event()
49+
self.exception_recv, self.exception_send = Pipe()
50+
self.function_queue = Queue()
51+
super().__init__(target=target, args=args, kwargs=kwargs, **ekwargs)
52+
53+
def map_ids(self):
54+
new_id_map("uid", self.pid, 0, self.orig_uid, 1, 1, self.subuid_start, self.subuid_count)
55+
new_id_map("gid", self.pid, 0, self.orig_gid, 1, 1, self.subgid_start, self.subgid_count)
56+
57+
def map_unshare_uids(self):
58+
self.start()
59+
self.map_ids()
60+
self.uidmapped.set()
61+
62+
def run(self):
63+
unshare_namespace()
64+
self.uidmapped.wait()
65+
setuid(0)
66+
setgid(0)
67+
chroot(self.target_root)
68+
try:
69+
self.function_queue.put(self._target(*self._args, **self._kwargs))
70+
except Exception as e:
71+
self.exception_send.send(e)
72+
self.completed.set()
73+
74+
75+
def nschroot(target, *args, **kwargs):
76+
p = NamespaceProcess(target=target, args=args, kwargs=kwargs)
77+
try:
78+
p.map_unshare_uids()
79+
except CalledProcessError as e:
80+
print(f"Error: {e}")
81+
p.terminate()
82+
raise e
83+
84+
p.completed.wait()
85+
if p.exception_recv.poll():
86+
p.terminate()
87+
raise p.exception_recv.recv()
88+
89+
ret = p.function_queue.get()
90+
p.terminate()
91+
return ret
92+
93+
94+
def nsexec(target, *args, **kwargs):
95+
p = NamespaceProcess(target=target, args=args, kwargs=kwargs)
96+
try:
97+
p.map_unshare_uids()
98+
except CalledProcessError as e:
99+
print(f"Error: {e}")
100+
p.terminate()
101+
raise e
102+
103+
p.completed.wait()
104+
if p.exception_recv.poll():
105+
p.terminate()
106+
raise p.exception_recv.recv()
107+
108+
ret = p.function_queue.get()
109+
p.terminate()
110+
return ret

tests/test_namespace.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
from os import environ
2+
from unittest import TestCase, main, skipIf
3+
4+
5+
6+
from zenlib.util import nsexec
7+
8+
9+
class TestPassedException(Exception):
10+
pass
11+
12+
13+
def test_exception():
14+
raise TestPassedException("This is a test exception")
15+
16+
def test_add_func(a, b):
17+
return a + b
18+
19+
def test_uid_gid():
20+
import os
21+
return os.getuid(), os.getgid()
22+
23+
class TestNamespace(TestCase):
24+
@skipIf(environ.get("CI") == "true", "Skipping test_namespace.py in CI")
25+
def test_user_namespace_exceptions(self):
26+
with self.assertRaises(TestPassedException):
27+
nsexec(test_exception)
28+
29+
@skipIf(environ.get("CI") == "true", "Skipping test_namespace.py in CI")
30+
def test_user_namespace_func(self):
31+
self.assertEqual(nsexec(test_add_func, 1, 2), 3)
32+
33+
@skipIf(environ.get("CI") == "true", "Skipping test_namespace.py in CI")
34+
def test_user_namespace_uid_gid(self):
35+
self.assertEqual(nsexec(test_uid_gid), (0, 0))
36+
37+
38+
if __name__ == "__main__":
39+
main()

0 commit comments

Comments
 (0)