Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/pr/134'
Browse files Browse the repository at this point in the history
* origin/pr/134:
  Use fully-qualified exception names in .pylintrc
  Support additional policy directories
  • Loading branch information
marmarek committed Sep 25, 2024
2 parents 7d66382 + 966e383 commit 692209c
Show file tree
Hide file tree
Showing 12 changed files with 158 additions and 86 deletions.
2 changes: 1 addition & 1 deletion .pylintrc
Original file line number Diff line number Diff line change
Expand Up @@ -368,4 +368,4 @@ exclude-protected=_asdict,_fields,_replace,_source,_make

# Exceptions that will emit a warning when being caught. Defaults to
# "Exception"
overgeneral-exceptions=Exception
overgeneral-exceptions=builtins.Exception,builtins.BaseException
2 changes: 2 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,8 @@ install-dom0: all-dom0
install -t $(DESTDIR)/etc/qubes/policy.d/include -m 664 policy.d/include/*
install -d $(DESTDIR)/lib/systemd/system -m 755
install -t $(DESTDIR)/lib/systemd/system -m 644 systemd/qubes-qrexec-policy-daemon.service
install -m 755 -d $(DESTDIR)/usr/lib/tmpfiles.d/
install -m 0644 -t $(DESTDIR)/usr/lib/tmpfiles.d/ systemd/qrexec.conf
.PHONY: install-dom0


Expand Down
2 changes: 2 additions & 0 deletions qrexec/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,12 @@
RPC_PATH = "/etc/qubes-rpc"
POLICY_AGENT_SOCKET_PATH = "/var/run/qubes/policy-agent.sock"
POLICYPATH = pathlib.Path("/etc/qubes/policy.d")
RUNTIME_POLICY_PATH = pathlib.Path("/run/qubes/policy.d")
POLICYSOCKET = pathlib.Path("/var/run/qubes/policy.sock")
POLICY_EVAL_SOCKET = pathlib.Path("/etc/qubes-rpc/policy.EvalSimple")
POLICY_GUI_SOCKET = pathlib.Path("/etc/qubes-rpc/policy.EvalGUI")
INCLUDEPATH = POLICYPATH / "include"
RUNTIME_INCLUDE_PATH = RUNTIME_POLICY_PATH / "include"
POLICYSUFFIX = ".policy"
POLICYPATH_OLD = pathlib.Path("/etc/qubes-rpc/policy")

Expand Down
69 changes: 55 additions & 14 deletions qrexec/policy/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@
Sequence,
)

from .. import POLICYPATH, RPCNAME_ALLOWED_CHARSET, POLICYSUFFIX
from .. import POLICYPATH, RPCNAME_ALLOWED_CHARSET, POLICYSUFFIX, RUNTIME_POLICY_PATH
from ..utils import FullSystemInfo
from .. import exc
from ..exc import (
Expand Down Expand Up @@ -1790,22 +1790,54 @@ class AbstractFileSystemLoader(AbstractDirectoryLoader, AbstractFileLoader):
"""This class is used when policy is stored as regular files in a directory.
Args:
policy_path (pathlib.Path): Load this directory. Paths given to
``!include`` etc. directives are interpreted relative to this path.
policy_path: Load these directories. Paths given to
``!include`` etc. directives in a file are interpreted relative to
the path from which the file was loaded.
"""

def __init__(self, *, policy_path=POLICYPATH, **kwds):
super().__init__(**kwds)
self.policy_path = pathlib.Path(policy_path)

policy_path: Optional[pathlib.Path]
def __init__(
self,
*,
policy_path: Union[None, pathlib.PurePath, Iterable[pathlib.PurePath]]
) -> None:
super().__init__()
if policy_path is None:
iterable_policy_paths = [RUNTIME_POLICY_PATH, POLICYPATH]
elif isinstance(policy_path, pathlib.Path):
iterable_policy_paths = [policy_path]
elif isinstance(policy_path, list):
iterable_policy_paths = policy_path
else:
raise TypeError("unexpected type of policy path in AbstractFileSystemLoader.__init__!")
try:
self.load_policy_dir(self.policy_path)
self.load_policy_dirs(iterable_policy_paths)
except OSError as err:
raise AccessDenied(
"failed to load {} file: {!s}".format(err.filename, err)
) from err

def resolve_path(self, included_path):
self.policy_path = None

def load_policy_dirs(self, paths: Iterable[pathlib.PurePath]) -> None:
already_seen = set()
final_list = []
for path in paths:
for file_path in filter_filepaths(pathlib.Path(path).iterdir()):
basename = file_path.name
if basename not in already_seen:
already_seen.add(basename)
final_list.append(file_path)
final_list.sort(key=lambda x: x.name)
for file_path in final_list:
with file_path.open() as file:
self.policy_path = file_path.parent
try:
self.load_policy_file(file, file_path)
finally:
self.policy_path = None

def resolve_path(self, included_path: pathlib.PurePosixPath) -> pathlib.Path:
assert self.policy_path is not None, "Tried to resolve a path when not loading policy"
return (self.policy_path / included_path).resolve()


Expand Down Expand Up @@ -1840,12 +1872,21 @@ class ValidateParser(FilePolicy):
"""

def __init__(
self, *args, overrides: Dict[pathlib.Path, Optional[str]], **kwds
):
self,
*,
overrides: Dict[pathlib.Path, Optional[str]],
policy_path: Union[None, pathlib.PurePath, Iterable[pathlib.PurePath]] = None,
) -> None:
self.overrides = overrides
super().__init__(*args, **kwds)
super().__init__(policy_path=policy_path)

def load_policy_dir(self, dirpath):
def load_policy_dirs(self, paths: Iterable[pathlib.PurePath]) -> None:
assert len(paths) == 1
path, = paths
self.policy_path = path
self.load_policy_dir(path)

def load_policy_dir(self, dirpath: pathlib.Path) -> None:
for path in filter_filepaths(dirpath.iterdir()):
if path not in self.overrides:
with path.open() as file:
Expand Down
30 changes: 15 additions & 15 deletions qrexec/policy/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,18 +20,20 @@
import asyncio
import os.path
import pyinotify
from qrexec import POLICYPATH, POLICYPATH_OLD
from qrexec import POLICYPATH, POLICYPATH_OLD, RUNTIME_POLICY_PATH
from . import parser


class PolicyCache:
def __init__(self, path=POLICYPATH, use_legacy=True, lazy_load=False):
self.path = path
def __init__(
self, path=(RUNTIME_POLICY_PATH, POLICYPATH), use_legacy=True, lazy_load=False
) -> None:
self.paths = list(path)
self.outdated = lazy_load
if lazy_load:
self.policy = None
else:
self.policy = parser.FilePolicy(policy_path=self.path)
self.policy = parser.FilePolicy(policy_path=self.paths)

# default policy paths are listed manually, for compatibility with R4.0
# to be removed in Qubes 5.0
Expand Down Expand Up @@ -62,22 +64,20 @@ def initialize_watcher(self):
self.watch_manager, loop, default_proc_fun=PolicyWatcher(self)
)

if str(self.path) not in self.default_policy_paths and os.path.exists(
self.path
):
self.watches.append(
self.watch_manager.add_watch(
str(self.path), mask, rec=True, auto_add=True
for path in self.paths:
str_path = str(path)
if str_path not in self.default_policy_paths and os.path.exists(str_path):
self.watches.append(
self.watch_manager.add_watch(
str_path, mask, rec=True, auto_add=True
)
)
)

for path in self.default_policy_paths:
if not os.path.exists(path):
continue
self.watches.append(
self.watch_manager.add_watch(
str(path), mask, rec=True, auto_add=True
)
self.watch_manager.add_watch(str(path), mask, rec=True, auto_add=True)
)

def cleanup(self):
Expand All @@ -92,7 +92,7 @@ def cleanup(self):

def get_policy(self):
if self.outdated:
self.policy = parser.FilePolicy(policy_path=self.path)
self.policy = parser.FilePolicy(policy_path=self.paths)
self.outdated = False

return self.policy
Expand Down
2 changes: 1 addition & 1 deletion qrexec/tests/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def policy():
yield policy

assert mock_policy.mock_calls == [
mock.call(policy_path=PosixPath("/etc/qubes/policy.d"))
mock.call(policy_path=[PosixPath("/run/qubes/policy.d"), PosixPath("/etc/qubes/policy.d")]),
]


Expand Down
107 changes: 63 additions & 44 deletions qrexec/tests/policy_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,20 @@
import pytest
import unittest
import unittest.mock
import pathlib

from ..policy import utils


class TestPolicyCache:
@pytest.fixture
def tmp_paths(self, tmp_path: pathlib.Path) -> list[pathlib.Path]:
path1 = tmp_path / "path1"
path2 = tmp_path / "path2"
path1.mkdir()
path2.mkdir()
return [path1, path2]

@pytest.fixture
def mock_parser(self, monkeypatch):
mock_parser = unittest.mock.Mock()
Expand All @@ -37,58 +46,60 @@ def mock_parser(self, monkeypatch):
return mock_parser

def test_00_policy_init(self, tmp_path, mock_parser):
cache = utils.PolicyCache(tmp_path)
mock_parser.assert_called_once_with(policy_path=tmp_path)
cache = utils.PolicyCache([tmp_path])
mock_parser.assert_called_once_with(policy_path=[tmp_path])

@pytest.mark.asyncio
async def test_10_file_created(self, tmp_path, mock_parser):
cache = utils.PolicyCache(tmp_path)
cache.initialize_watcher()
async def test_10_file_created(self, tmp_paths, mock_parser):
for i in tmp_paths:
cache = utils.PolicyCache(tmp_paths)
cache.initialize_watcher()

assert not cache.outdated
assert not cache.outdated

file = tmp_path / "test"
file.write_text("test")
(i / "file").write_text("test")

await asyncio.sleep(1)
await asyncio.sleep(1)

assert cache.outdated
assert cache.outdated

@pytest.mark.asyncio
async def test_11_file_changed(self, tmp_path, mock_parser):
file = tmp_path / "test"
file.write_text("test")
async def test_11_file_changed(self, tmp_paths, mock_parser):
for i in tmp_paths:
file = i / "test"
file.write_text("test")

cache = utils.PolicyCache(tmp_path)
cache.initialize_watcher()
cache = utils.PolicyCache(tmp_paths)
cache.initialize_watcher()

assert not cache.outdated
assert not cache.outdated

file.write_text("new_content")
file.write_text("new_content")

await asyncio.sleep(1)
await asyncio.sleep(1)

assert cache.outdated
assert cache.outdated

@pytest.mark.asyncio
async def test_12_file_deleted(self, tmp_path, mock_parser):
file = tmp_path / "test"
file.write_text("test")
async def test_12_file_deleted(self, tmp_paths, mock_parser):
for i in tmp_paths:
file = i / "test"
file.write_text("test")

cache = utils.PolicyCache(tmp_path)
cache.initialize_watcher()
cache = utils.PolicyCache(tmp_paths)
cache.initialize_watcher()

assert not cache.outdated
assert not cache.outdated

os.remove(file)
os.remove(file)

await asyncio.sleep(1)
await asyncio.sleep(1)

assert cache.outdated
assert cache.outdated

@pytest.mark.asyncio
async def test_13_no_change(self, tmp_path, mock_parser):
cache = utils.PolicyCache(tmp_path)
async def test_13_no_change(self, tmp_paths, mock_parser):
cache = utils.PolicyCache(tmp_paths)
cache.initialize_watcher()

assert not cache.outdated
Expand All @@ -101,10 +112,10 @@ async def test_13_no_change(self, tmp_path, mock_parser):
async def test_14_policy_move(self, tmp_path, mock_parser):
policy_path = tmp_path / "policy"
policy_path.mkdir()
cache = utils.PolicyCache(policy_path)
cache = utils.PolicyCache([policy_path])
cache.initialize_watcher()

mock_parser.assert_called_once_with(policy_path=policy_path)
mock_parser.assert_called_once_with(policy_path=[policy_path])

assert not cache.outdated

Expand Down Expand Up @@ -135,27 +146,35 @@ async def test_14_policy_move(self, tmp_path, mock_parser):

cache.get_policy()

call = unittest.mock.call(policy_path=policy_path)
call = unittest.mock.call(policy_path=[policy_path])
assert mock_parser.mock_calls == [call, call, call]

@pytest.mark.asyncio
async def test_20_policy_updates(self, tmp_path, mock_parser):
cache = utils.PolicyCache(tmp_path)
async def test_20_policy_updates(self, tmp_paths, mock_parser):
cache = utils.PolicyCache(tmp_paths)
cache.initialize_watcher()
count = 0

mock_parser.assert_called_once_with(policy_path=tmp_path)
for i in tmp_paths:
call = unittest.mock.call(policy_path=tmp_paths)

assert not cache.outdated
count += 2
assert mock_parser.mock_calls == [call] * (count - 1)
cache = utils.PolicyCache(tmp_paths)
cache.initialize_watcher()

file = tmp_path / "test"
file.write_text("test")
l = len(mock_parser.mock_calls)
assert mock_parser.mock_calls == [call] * l

await asyncio.sleep(1)
assert not cache.outdated

assert cache.outdated
file = i / "test"
file.write_text("test")

cache.get_policy()
await asyncio.sleep(1)

assert cache.outdated

call = unittest.mock.call(policy_path=tmp_path)
cache.get_policy()

assert mock_parser.mock_calls == [call, call]
assert mock_parser.mock_calls == [call] * (count + 1)
2 changes: 1 addition & 1 deletion qrexec/tools/qrexec_legacy_convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,7 +288,7 @@ def main(args=None):
str(POLICYPATH), '--full-output'],
output=current_state_string)
current_state = set(current_state_string.getvalue().split('\n'))
except Exception: #pylint: disable-broad-except
except Exception: # pylint: disable=broad-except
current_state = 'ERROR'

if initial_state != current_state:
Expand Down
Loading

0 comments on commit 692209c

Please sign in to comment.