diff --git a/.github/workflows/unit_tests.yml b/.github/workflows/unit_tests.yml index a6645a6..ba766fc 100644 --- a/.github/workflows/unit_tests.yml +++ b/.github/workflows/unit_tests.yml @@ -19,6 +19,12 @@ jobs: steps: - uses: actions/checkout@v4 + - name: Permit unprivileged user namespaces + run: | + set -x + sudo sysctl -w kernel.apparmor_restrict_unprivileged_unconfined=0 + sudo sysctl -w kernel.apparmor_restrict_unprivileged_userns=0 + - name: Set up Python ${{ matrix.python-version }} uses: actions/setup-python@v3 with: diff --git a/pyproject.toml b/pyproject.toml index 80c68b6..26a9c8d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "zenlib" -version = "3.1.6" +version = "3.1.7" authors = [ { name="Desultory", email="dev@pyl.onl" }, ] diff --git a/src/zenlib/namespace/__init__.py b/src/zenlib/namespace/__init__.py index 6d4f656..2bac508 100644 --- a/src/zenlib/namespace/__init__.py +++ b/src/zenlib/namespace/__init__.py @@ -1,8 +1,7 @@ -from os import environ from platform import system from sys import version_info -if environ.get("CI", "false").lower() == "true" or version_info < (3, 12) or system() != "Linux": +if version_info < (3, 12) or system() != "Linux": nsexec, get_id_map = None, None else: from .namespace import get_id_map diff --git a/src/zenlib/namespace/namespace.py b/src/zenlib/namespace/namespace.py index ac77c5e..d5f5913 100644 --- a/src/zenlib/namespace/namespace.py +++ b/src/zenlib/namespace/namespace.py @@ -1,4 +1,5 @@ -from os import CLONE_NEWNS, CLONE_NEWUSER, getlogin, unshare, getuid +from os import CLONE_NEWNS, CLONE_NEWUSER, unshare, getuid +from getpass import getuser from subprocess import CalledProcessError, run @@ -7,7 +8,7 @@ def unshare_namespace(): def get_id_map(username=None, id_type="uid"): - username = username or getlogin() + username = username or getuser() if id_type not in ("uid", "gid"): raise ValueError("id_type must be 'uid' or 'gid'") diff --git a/src/zenlib/namespace/namespace_process.py b/src/zenlib/namespace/namespace_process.py index b0d43fc..0cd1f76 100644 --- a/src/zenlib/namespace/namespace_process.py +++ b/src/zenlib/namespace/namespace_process.py @@ -1,5 +1,6 @@ from multiprocessing import Event, Pipe, Process, Queue -from os import chroot, chdir, getgid, getlogin, getuid, setgid, setuid +from os import chroot, chdir, getgid, getuid, setgid, setuid +from getpass import getuser from .namespace import get_id_map, new_id_map, unshare_namespace @@ -11,7 +12,7 @@ class NamespaceProcess(Process): def __init__(self, target=None, args=None, kwargs=None, **ekwargs): self.target_root = kwargs.pop("target_root", "/") - namespace_user = kwargs.pop("namespace_user", getlogin()) + namespace_user = kwargs.pop("namespace_user", getuser()) self.subuid_start, self.subuid_count = get_id_map(namespace_user, "uid") self.subgid_start, self.subgid_count = get_id_map(namespace_user, "gid") self.orig_uid = getuid() diff --git a/src/zenlib/util/__init__.py b/src/zenlib/util/__init__.py index bbeaf64..e76c185 100644 --- a/src/zenlib/util/__init__.py +++ b/src/zenlib/util/__init__.py @@ -1,7 +1,7 @@ from .colorize import colorize from .dict_check import contains, unset from .handle_plural import handle_plural -from .main_funcs import get_args_n_logger, get_kwargs, get_kwargs_from_args, init_argparser, init_logger, process_args +from .main_funcs import get_args_n_logger, get_kwargs, get_kwargs_from_args, process_args from .merge_class import merge_class from .pretty_print import pretty_print from .replace_file_line import replace_file_line @@ -13,9 +13,7 @@ "NoDupFlatList", "pretty_print", "replace_file_line", - "init_logger", "process_args", - "init_argparser", "get_args_n_logger", "get_kwargs_from_args", "get_kwargs", diff --git a/src/zenlib/util/main_funcs.py b/src/zenlib/util/main_funcs.py index d45bec5..03c0a9f 100644 --- a/src/zenlib/util/main_funcs.py +++ b/src/zenlib/util/main_funcs.py @@ -2,8 +2,12 @@ Functions to help with the main() """ -__version__ = "1.3.1" -__author__ = "desultory" +__version__ = "1.4.1" + +from argparse import ArgumentError, ArgumentParser, Namespace +from importlib.metadata import version +from logging import FileHandler, Formatter, StreamHandler, getLogger +from sys import argv def get_base_args(): @@ -18,25 +22,11 @@ def get_base_args(): ] -def init_logger(name=None): - """Initialize the logger with a name""" - from logging import getLogger - - name = name or __name__ - return getLogger(name) - - -def init_argparser(prog=None, description=None): - """Initialize an argparser with common options.""" - from argparse import ArgumentParser - - argparser = ArgumentParser(prog=prog, description=description) - return argparser - - def get_kwargs_from_args(args, logger=None, base_kwargs={}, drop_base=True): """Get kwargs from argparser args. - Drop base doesn't add init_argparser args.""" + Drop base doesn't add args defined in get_base_args. + Empty args are not added to kwargs. + """ kwargs = base_kwargs.copy() if logger is not None: kwargs["logger"] = logger @@ -55,8 +45,6 @@ def get_kwargs_from_args(args, logger=None, base_kwargs={}, drop_base=True): def process_args(argparser, logger=None, strict=False): """Process argparser args, optionally configuring a logger.""" - from logging import Formatter - from zenlib.logging import ColorLognameFormatter if strict: @@ -66,9 +54,8 @@ def process_args(argparser, logger=None, strict=False): if unknown: args._unknown = unknown - if 'version' in args and args.version and argparser.prog != "zenlib_test": + if "version" in args and args.version and argparser.prog != "zenlib_test": package = argparser.prog - from importlib.metadata import version print(f"{package} {version(package)}") exit(0) @@ -96,13 +83,11 @@ def process_args(argparser, logger=None, strict=False): handler.setFormatter(formatter) break else: - from logging import FileHandler, StreamHandler - handler = StreamHandler() if args.log_file is None else FileHandler(args.log_file) handler.setFormatter(formatter) logger.addHandler(handler) - if '__unknown' in args and args.__unknown: + if "__unknown" in args and args.__unknown: logger.warning(f"Unknown args: {args.__unknown}") return args @@ -128,13 +113,11 @@ def get_args_n_logger(package, description: str, arguments=[], drop_default=Fals Returns the parsed args and logger. """ all_arguments = get_base_args() + arguments - from sys import argv if "--dump_args" in argv: dump_args_for_autocomplete(all_arguments) - from argparse import Namespace, ArgumentError - argparser = init_argparser(prog=package, description=description) - logger = init_logger(package) + argparser = ArgumentParser(prog=package, description=description) + logger = getLogger(package) def add_args(args, argparser): for arg in args: @@ -157,7 +140,9 @@ def add_args(args, argparser): return args, logger -def get_kwargs(package, description: str, arguments=[], base_kwargs={}, drop_default=False, drop_base=True, strict=False): +def get_kwargs( + package, description: str, arguments=[], base_kwargs={}, drop_default=False, drop_base=True, strict=False +): """Like get_args_n_logger, but only returns kwargs""" args, logger = get_args_n_logger(package, description, arguments, drop_default=drop_default, strict=strict) return get_kwargs_from_args(args, logger=logger, base_kwargs=base_kwargs, drop_base=drop_base) diff --git a/tests/test_main_funcs.py b/tests/test_main_funcs.py index a567cc1..4e28c74 100644 --- a/tests/test_main_funcs.py +++ b/tests/test_main_funcs.py @@ -1,8 +1,8 @@ -from argparse import ArgumentParser, Namespace +from argparse import Namespace from logging import Logger from unittest import TestCase, expectedFailure, main -from zenlib.util import get_args_n_logger, get_kwargs, get_kwargs_from_args, init_argparser, init_logger +from zenlib.util import get_args_n_logger, get_kwargs, get_kwargs_from_args from zenlib.util.main_funcs import dump_args_for_autocomplete, get_base_args DEFAULT_ARGS = ["debug", "trace", "log_time", "no_log_color"] @@ -15,25 +15,14 @@ def get_test_args(): class TestMainFuncs(TestCase): - def test_init_logger(self): - self.assertIsInstance(init_logger(), Logger) - - def test_init_argparser(self): - self.assertIsInstance(init_argparser(), ArgumentParser) - - def test_named_init_argparser(self): - parser = init_argparser("test", "test description") - self.assertEqual(parser.prog, "test") - self.assertEqual(parser.description, "test description") - def _check_for_test_args(self, args): self.assertIsInstance(args, Namespace) - if 'discover' not in args.args: + if "discover" not in args.args: self.fail("discover not found in args") - search_strs = ['tests', './tests'] - if '_unknown' in args: + search_strs = ["tests", "./tests"] + if "_unknown" in args: search_locs = [args.args, args._unknown] else: search_locs = [args.args] @@ -52,7 +41,9 @@ def test_get_args_n_logger(self): self._check_for_test_args(args) def test_get_args_n_logger_no_default(self): - args, logger = get_args_n_logger("zenlib_test", "test description", get_test_args(), strict=False, drop_default=True) + args, logger = get_args_n_logger( + "zenlib_test", "test description", get_test_args(), strict=False, drop_default=True + ) self.assertIsInstance(logger, Logger) for arg in DEFAULT_ARGS: diff --git a/tests/test_namespace.py b/tests/test_namespace.py index 4a6bd41..84708f7 100644 --- a/tests/test_namespace.py +++ b/tests/test_namespace.py @@ -1,4 +1,3 @@ -from os import environ from platform import system from sys import version_info from unittest import TestCase, main, skipIf @@ -8,10 +7,6 @@ def check_test_compat(): """Checks if tests are compatible with the current environment""" - - if environ.get("CI", "false").lower() == "true": - return - if system() != "Linux": return @@ -46,7 +41,7 @@ def test_uid_gid(): def test_cwd(): from pathlib import Path - return [p for p in Path("/").rglob("")] + return [p.resolve() for p in Path("/").rglob("")] @skipIf(not check_test_compat(), "Skipping test_namespace.py in CI")