Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions .github/workflows/unit_tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,12 @@ jobs:

steps:
- uses: actions/checkout@v4
- name: Permit unprivileged user namespaces
run: |
set -x
sudo sysctl -w kernel.apparmor_restrict_unprivileged_unconfined=0
sudo sysctl -w kernel.apparmor_restrict_unprivileged_userns=0
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v3
with:
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "zenlib"
version = "3.1.6"
version = "3.1.7"
authors = [
{ name="Desultory", email="[email protected]" },
]
Expand Down
3 changes: 1 addition & 2 deletions src/zenlib/namespace/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
from os import environ
from platform import system
from sys import version_info

if environ.get("CI", "false").lower() == "true" or version_info < (3, 12) or system() != "Linux":
if version_info < (3, 12) or system() != "Linux":
nsexec, get_id_map = None, None
else:
from .namespace import get_id_map
Expand Down
5 changes: 3 additions & 2 deletions src/zenlib/namespace/namespace.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from os import CLONE_NEWNS, CLONE_NEWUSER, getlogin, unshare, getuid
from os import CLONE_NEWNS, CLONE_NEWUSER, unshare, getuid
from getpass import getuser
from subprocess import CalledProcessError, run


Expand All @@ -7,7 +8,7 @@ def unshare_namespace():


def get_id_map(username=None, id_type="uid"):
username = username or getlogin()
username = username or getuser()
if id_type not in ("uid", "gid"):
raise ValueError("id_type must be 'uid' or 'gid'")

Expand Down
5 changes: 3 additions & 2 deletions src/zenlib/namespace/namespace_process.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from multiprocessing import Event, Pipe, Process, Queue
from os import chroot, chdir, getgid, getlogin, getuid, setgid, setuid
from os import chroot, chdir, getgid, getuid, setgid, setuid
from getpass import getuser

from .namespace import get_id_map, new_id_map, unshare_namespace

Expand All @@ -11,7 +12,7 @@ class NamespaceProcess(Process):

def __init__(self, target=None, args=None, kwargs=None, **ekwargs):
self.target_root = kwargs.pop("target_root", "/")
namespace_user = kwargs.pop("namespace_user", getlogin())
namespace_user = kwargs.pop("namespace_user", getuser())
self.subuid_start, self.subuid_count = get_id_map(namespace_user, "uid")
self.subgid_start, self.subgid_count = get_id_map(namespace_user, "gid")
self.orig_uid = getuid()
Expand Down
4 changes: 1 addition & 3 deletions src/zenlib/util/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from .colorize import colorize
from .dict_check import contains, unset
from .handle_plural import handle_plural
from .main_funcs import get_args_n_logger, get_kwargs, get_kwargs_from_args, init_argparser, init_logger, process_args
from .main_funcs import get_args_n_logger, get_kwargs, get_kwargs_from_args, process_args
from .merge_class import merge_class
from .pretty_print import pretty_print
from .replace_file_line import replace_file_line
Expand All @@ -13,9 +13,7 @@
"NoDupFlatList",
"pretty_print",
"replace_file_line",
"init_logger",
"process_args",
"init_argparser",
"get_args_n_logger",
"get_kwargs_from_args",
"get_kwargs",
Expand Down
47 changes: 16 additions & 31 deletions src/zenlib/util/main_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,12 @@
Functions to help with the main()
"""

__version__ = "1.3.1"
__author__ = "desultory"
__version__ = "1.4.1"

from argparse import ArgumentError, ArgumentParser, Namespace
from importlib.metadata import version
from logging import FileHandler, Formatter, StreamHandler, getLogger
from sys import argv


def get_base_args():
Expand All @@ -18,25 +22,11 @@ def get_base_args():
]


def init_logger(name=None):
"""Initialize the logger with a name"""
from logging import getLogger

name = name or __name__
return getLogger(name)


def init_argparser(prog=None, description=None):
"""Initialize an argparser with common options."""
from argparse import ArgumentParser

argparser = ArgumentParser(prog=prog, description=description)
return argparser


def get_kwargs_from_args(args, logger=None, base_kwargs={}, drop_base=True):
"""Get kwargs from argparser args.
Drop base doesn't add init_argparser args."""
Drop base doesn't add args defined in get_base_args.
Empty args are not added to kwargs.
"""
kwargs = base_kwargs.copy()
if logger is not None:
kwargs["logger"] = logger
Expand All @@ -55,8 +45,6 @@ def get_kwargs_from_args(args, logger=None, base_kwargs={}, drop_base=True):

def process_args(argparser, logger=None, strict=False):
"""Process argparser args, optionally configuring a logger."""
from logging import Formatter

from zenlib.logging import ColorLognameFormatter

if strict:
Expand All @@ -66,9 +54,8 @@ def process_args(argparser, logger=None, strict=False):
if unknown:
args._unknown = unknown

if 'version' in args and args.version and argparser.prog != "zenlib_test":
if "version" in args and args.version and argparser.prog != "zenlib_test":
package = argparser.prog
from importlib.metadata import version

print(f"{package} {version(package)}")
exit(0)
Expand Down Expand Up @@ -96,13 +83,11 @@ def process_args(argparser, logger=None, strict=False):
handler.setFormatter(formatter)
break
else:
from logging import FileHandler, StreamHandler

handler = StreamHandler() if args.log_file is None else FileHandler(args.log_file)
handler.setFormatter(formatter)
logger.addHandler(handler)

if '__unknown' in args and args.__unknown:
if "__unknown" in args and args.__unknown:
logger.warning(f"Unknown args: {args.__unknown}")

return args
Expand All @@ -128,13 +113,11 @@ def get_args_n_logger(package, description: str, arguments=[], drop_default=Fals
Returns the parsed args and logger.
"""
all_arguments = get_base_args() + arguments
from sys import argv
if "--dump_args" in argv:
dump_args_for_autocomplete(all_arguments)

from argparse import Namespace, ArgumentError
argparser = init_argparser(prog=package, description=description)
logger = init_logger(package)
argparser = ArgumentParser(prog=package, description=description)
logger = getLogger(package)

def add_args(args, argparser):
for arg in args:
Expand All @@ -157,7 +140,9 @@ def add_args(args, argparser):
return args, logger


def get_kwargs(package, description: str, arguments=[], base_kwargs={}, drop_default=False, drop_base=True, strict=False):
def get_kwargs(
package, description: str, arguments=[], base_kwargs={}, drop_default=False, drop_base=True, strict=False
):
"""Like get_args_n_logger, but only returns kwargs"""
args, logger = get_args_n_logger(package, description, arguments, drop_default=drop_default, strict=strict)
return get_kwargs_from_args(args, logger=logger, base_kwargs=base_kwargs, drop_base=drop_base)
25 changes: 8 additions & 17 deletions tests/test_main_funcs.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from argparse import ArgumentParser, Namespace
from argparse import Namespace
from logging import Logger
from unittest import TestCase, expectedFailure, main

from zenlib.util import get_args_n_logger, get_kwargs, get_kwargs_from_args, init_argparser, init_logger
from zenlib.util import get_args_n_logger, get_kwargs, get_kwargs_from_args
from zenlib.util.main_funcs import dump_args_for_autocomplete, get_base_args

DEFAULT_ARGS = ["debug", "trace", "log_time", "no_log_color"]
Expand All @@ -15,25 +15,14 @@ def get_test_args():


class TestMainFuncs(TestCase):
def test_init_logger(self):
self.assertIsInstance(init_logger(), Logger)

def test_init_argparser(self):
self.assertIsInstance(init_argparser(), ArgumentParser)

def test_named_init_argparser(self):
parser = init_argparser("test", "test description")
self.assertEqual(parser.prog, "test")
self.assertEqual(parser.description, "test description")

def _check_for_test_args(self, args):
self.assertIsInstance(args, Namespace)

if 'discover' not in args.args:
if "discover" not in args.args:
self.fail("discover not found in args")

search_strs = ['tests', './tests']
if '_unknown' in args:
search_strs = ["tests", "./tests"]
if "_unknown" in args:
search_locs = [args.args, args._unknown]
else:
search_locs = [args.args]
Expand All @@ -52,7 +41,9 @@ def test_get_args_n_logger(self):
self._check_for_test_args(args)

def test_get_args_n_logger_no_default(self):
args, logger = get_args_n_logger("zenlib_test", "test description", get_test_args(), strict=False, drop_default=True)
args, logger = get_args_n_logger(
"zenlib_test", "test description", get_test_args(), strict=False, drop_default=True
)
self.assertIsInstance(logger, Logger)

for arg in DEFAULT_ARGS:
Expand Down
7 changes: 1 addition & 6 deletions tests/test_namespace.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from os import environ
from platform import system
from sys import version_info
from unittest import TestCase, main, skipIf
Expand All @@ -8,10 +7,6 @@

def check_test_compat():
"""Checks if tests are compatible with the current environment"""

if environ.get("CI", "false").lower() == "true":
return

if system() != "Linux":
return

Expand Down Expand Up @@ -46,7 +41,7 @@ def test_uid_gid():
def test_cwd():
from pathlib import Path

return [p for p in Path("/").rglob("")]
return [p.resolve() for p in Path("/").rglob("")]


@skipIf(not check_test_compat(), "Skipping test_namespace.py in CI")
Expand Down