Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 20 additions & 10 deletions src/ugrd/fs/mounts.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
__author__ = "desultory"
__version__ = "7.1.2"
__version__ = "7.1.3"

from pathlib import Path
from re import search
Expand Down Expand Up @@ -130,13 +130,16 @@ def _get_mount_dev_fs_type(self, device: str, raise_exception=True) -> str:
self.logger.debug("No mount found for device: %s" % device)


def _get_mount_source_type(self, mount: dict, with_val=False) -> str:
"""Gets the source from the mount config."""
def _get_mount_source(self, mount: dict) -> str:
"""Gets the source from the mount config.
Uses the order of SOURCE_TYPES to determine the source type.
uuid, partuuid, label, path.
Returns the source type and value if found, otherwise raises a ValueError.
"""
for source_type in SOURCE_TYPES:
if source_type in mount:
if with_val:
return source_type, mount[source_type]
return source_type
return source_type, mount[source_type]
raise ValueError("No source type found in mount: %s" % mount)


Expand Down Expand Up @@ -249,8 +252,8 @@ def _get_mount_str(self, mount: dict, pad=False, pad_size=44) -> str:
"""returns the mount source string based on the config,
the output string should work with fstab and mount commands.
pad: pads the output string with spaces, defined by pad_size (44)."""
mount_type, mount_name = _get_mount_source_type(self, mount, with_val=True)
out_str = mount_name if mount_type == "path" else f"{mount_type.upper()}={mount_name}"
mount_type, mount_val = _get_mount_source(self, mount)
out_str = mount_val if mount_type == "path" else f"{mount_type.upper()}={mount_val}"

if pad:
if len(out_str) > pad_size:
Expand Down Expand Up @@ -899,7 +902,7 @@ def _validate_host_mount(self, mount, destination_path=None) -> bool:
if mount.get("base_mount"):
return self.logger.debug("Skipping host mount validation for base mount: %s" % mount)

mount_type, mount_val = _get_mount_source_type(self, mount, with_val=True)
mount_type, mount_val = _get_mount_source(self, mount)
# If a destination path is passed, like for /, use that instead of the mount's destination
destination_path = str(mount["destination"]) if destination_path is None else destination_path

Expand Down Expand Up @@ -994,7 +997,14 @@ def mount_root(self) -> str:

def export_mount_info(self) -> None:
"""Exports mount info based on the config to /run/MOUNTS_ROOT_{option}"""
self["exports"]["MOUNTS_ROOT_SOURCE"] = _get_mount_str(self, self["mounts"]["root"])
try:
self["exports"]["MOUNTS_ROOT_SOURCE"] = _get_mount_str(self, self["mounts"]["root"])
except ValueError as e:
self.logger.critical(f"Failed to get source info for the root mount: {e}")
if not self["hostonly"]:
self.logger.info("Root mount infomrmation can be defined under the '[mounts.root]' section.")
raise ValidationError("Root mount source information is not set, when hostonly mode is disabled, it must be manually defined.")
raise ValidationError("Root mount source information is not set even though hostonly mode is enabled. Please report a bug.")
self["exports"]["MOUNTS_ROOT_TYPE"] = self["mounts"]["root"].get("type", "auto")
self["exports"]["MOUNTS_ROOT_OPTIONS"] = ",".join(self["mounts"]["root"]["options"])
self["exports"]["MOUNTS_ROOT_TARGET"] = self["mounts"]["root"]["destination"]
Expand Down