Skip to content

Commit 2de2be4

Browse files
committed
don't log_init by default, black format, fix type check
Signed-off-by: Zen <[email protected]>
1 parent 2df9e0a commit 2de2be4

File tree

1 file changed

+73
-67
lines changed

1 file changed

+73
-67
lines changed

src/ugrd/initramfs_dict.py

Lines changed: 73 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
__author__ = "desultory"
22
__version__ = "2.2.1"
33

4-
from tomllib import load, TOMLDecodeError
4+
from collections import UserDict
55
from pathlib import Path
66
from queue import Queue
7-
from collections import UserDict
7+
from tomllib import TOMLDecodeError, load
88

99
from zenlib.logging import loggify
10-
from zenlib.util import handle_plural, pretty_print, NoDupFlatList
10+
from zenlib.util import NoDupFlatList, handle_plural, pretty_print
1111

1212

1313
@loggify
@@ -24,51 +24,57 @@ class InitramfsConfigDict(UserDict):
2424
2525
If parameters which are not registerd are set, they are added to the processing queue and processed when the type is known.
2626
"""
27-
builtin_parameters = {'modules': NoDupFlatList, # A list of the names of modules which have been loaded, mostly used for dependency checking
28-
'imports': dict, # A dict of functions to be imported into the initramfs, under their respective hooks
29-
'validated': bool, # A flag to indicate if the config has been validated, mostly used for log levels
30-
'custom_parameters': dict, # Custom parameters loaded from imports
31-
'custom_processing': dict, # Custom processing functions which will be run to validate and process parameters
32-
'_processing': dict} # A dict of queues containing parameters which have been set before the type was known
27+
28+
builtin_parameters = {
29+
"modules": NoDupFlatList, # A list of the names of modules which have been loaded, mostly used for dependency checking
30+
"imports": dict, # A dict of functions to be imported into the initramfs, under their respective hooks
31+
"validated": bool, # A flag to indicate if the config has been validated, mostly used for log levels
32+
"custom_parameters": dict, # Custom parameters loaded from imports
33+
"custom_processing": dict, # Custom processing functions which will be run to validate and process parameters
34+
"_processing": dict,
35+
} # A dict of queues containing parameters which have been set before the type was known
3336

3437
def __init__(self, NO_BASE=False, *args, **kwargs):
3538
super().__init__(*args, **kwargs)
3639
# Define the default parameters
3740
for parameter, default_type in self.builtin_parameters.items():
3841
if default_type == NoDupFlatList:
39-
self.data[parameter] = default_type(no_warn=True, log_bump=5, logger=self.logger, _log_init=False)
42+
self.data[parameter] = default_type(no_warn=True, log_bump=5, logger=self.logger)
4043
else:
4144
self.data[parameter] = default_type()
4245
if not NO_BASE:
43-
self['modules'] = 'ugrd.base.base'
46+
self["modules"] = "ugrd.base.base"
4447
else:
45-
self['modules'] = 'ugrd.base.core'
48+
self["modules"] = "ugrd.base.core"
4649

4750
def import_args(self, args: dict) -> None:
48-
""" Imports data from an argument dict. """
51+
"""Imports data from an argument dict."""
4952
for arg, value in args.items():
5053
self.logger.info("Importing argument '%s' with value: %s" % (arg, value))
51-
if arg == 'modules': # allow loading modules by name from the command line
52-
for module in value.split(','):
54+
if arg == "modules": # allow loading modules by name from the command line
55+
for module in value.split(","):
5356
self[arg] = module
5457
else:
5558
self[arg] = value
5659

5760
def __setitem__(self, key: str, value) -> None:
58-
if self['validated']:
61+
if self["validated"]:
5962
return self.logger.error("[%s] Config is validatied, refusing to set value: %s" % (key, value))
6063
# If the type is registered, use the appropriate update function
61-
if any(key in d for d in (self.builtin_parameters, self['custom_parameters'])):
64+
if any(key in d for d in (self.builtin_parameters, self["custom_parameters"])):
6265
return self.handle_parameter(key, value)
6366
else:
64-
self.logger.debug("[%s] Unable to determine expected type, valid builtin types: %s" % (key, self.builtin_parameters.keys()))
65-
self.logger.debug("[%s] Custom types: %s" % (key, self['custom_parameters'].keys()))
67+
self.logger.debug(
68+
"[%s] Unable to determine expected type, valid builtin types: %s"
69+
% (key, self.builtin_parameters.keys())
70+
)
71+
self.logger.debug("[%s] Custom types: %s" % (key, self["custom_parameters"].keys()))
6672
# for anything but the logger, add to the processing queue
67-
if key != 'logger':
73+
if key != "logger":
6874
self.logger.debug("Adding unknown internal parameter to processing queue: %s" % key)
69-
if key not in self['_processing']:
70-
self['_processing'][key] = Queue()
71-
self['_processing'][key].put(value)
75+
if key not in self["_processing"]:
76+
self["_processing"][key] = Queue()
77+
self["_processing"][key].put(value)
7278

7379
def handle_parameter(self, key: str, value) -> None:
7480
"""
@@ -78,7 +84,7 @@ def handle_parameter(self, key: str, value) -> None:
7884
Uses custom processing functions if they are defined, otherwise uses the standard setters.
7985
"""
8086
# Get the expected type, first searching builtin_parameters, then custom_parameters
81-
for d in (self.builtin_parameters, self['custom_parameters']):
87+
for d in (self.builtin_parameters, self["custom_parameters"]):
8288
expected_type = d.get(key)
8389
if expected_type:
8490
if expected_type.__name__ == "InitramfsGenerator":
@@ -94,17 +100,17 @@ def handle_parameter(self, key: str, value) -> None:
94100

95101
# Don't use masked processing functions for custom values, fall back to standard setters
96102
def check_mask(import_name: str) -> bool:
97-
""" Checks if the funnction is masked. """
98-
return import_name in self.get('masks', [])
103+
"""Checks if the funnction is masked."""
104+
return import_name in self.get("masks", [])
99105

100-
if func := self['custom_processing'].get(f"_process_{key}"):
106+
if func := self["custom_processing"].get(f"_process_{key}"):
101107
if check_mask(func.__name__):
102108
self.logger.debug("Skipping masked function: %s" % func.__name__)
103109
else:
104110
self.logger.log(5, "[%s] Using custom setitem: %s" % (key, func.__name__))
105111
return func(self, value)
106112

107-
if func := self['custom_processing'].get(f"_process_{key}_multi"):
113+
if func := self["custom_processing"].get(f"_process_{key}_multi"):
108114
if check_mask(func.__name__):
109115
self.logger.debug("Skipping masked function: %s" % func.__name__)
110116
else:
@@ -115,7 +121,7 @@ def check_mask(import_name: str) -> bool:
115121
self.logger.log(5, "Using list setitem for: %s" % key)
116122
return self[key].append(value)
117123

118-
if expected_type == dict: # Create new keys, update existing
124+
if expected_type is dict: # Create new keys, update existing
119125
if key not in self:
120126
self.logger.log(5, "Setting dict '%s' to: %s" % (key, value))
121127
return super().__setitem__(key, value)
@@ -134,12 +140,12 @@ def _process_custom_parameters(self, parameter_name: str, parameter_type: type)
134140
"""
135141
from pycpio import PyCPIO
136142

137-
self['custom_parameters'][parameter_name] = eval(parameter_type)
143+
self["custom_parameters"][parameter_name] = eval(parameter_type)
138144
self.logger.debug("Registered custom parameter '%s' with type: %s" % (parameter_name, parameter_type))
139145

140146
match parameter_type:
141147
case "NoDupFlatList":
142-
self.data[parameter_name] = NoDupFlatList(no_warn=True, log_bump=5, logger=self.logger, _log_init=False)
148+
self.data[parameter_name] = NoDupFlatList(no_warn=True, log_bump=5, logger=self.logger)
143149
case "list" | "dict":
144150
self.data[parameter_name] = eval(parameter_type)()
145151
case "bool":
@@ -153,38 +159,38 @@ def _process_custom_parameters(self, parameter_name: str, parameter_type: type)
153159
case "Path":
154160
self.data[parameter_name] = Path()
155161
case "PyCPIO":
156-
self.data[parameter_name] = PyCPIO(logger=self.logger, _log_init=False, _log_bump=10)
162+
self.data[parameter_name] = PyCPIO(logger=self.logger, _log_bump=10)
157163
case _: # For strings and things, don't init them so they are None
158164
self.logger.warning("Leaving '%s' as None" % parameter_name)
159165
self.data[parameter_name] = None
160166

161167
def _process_unprocessed(self, parameter_name: str) -> None:
162-
""" Processes queued values for a parameter. """
163-
if parameter_name not in self['_processing']:
168+
"""Processes queued values for a parameter."""
169+
if parameter_name not in self["_processing"]:
164170
self.logger.log(5, "No queued values for: %s" % parameter_name)
165171
return
166172

167-
value_queue = self['_processing'].pop(parameter_name)
173+
value_queue = self["_processing"].pop(parameter_name)
168174
while not value_queue.empty():
169175
value = value_queue.get()
170-
if self['validated']: # Log at info level if the config has been validated
176+
if self["validated"]: # Log at info level if the config has been validated
171177
self.logger.info("[%s] Processing queued value: %s" % (parameter_name, value))
172178
else:
173179
self.logger.debug("[%s] Processing queued value: %s" % (parameter_name, value))
174180
self[parameter_name] = value
175181

176182
@handle_plural
177183
def _process_imports(self, import_type: str, import_value: dict) -> None:
178-
""" Processes imports in a module, importing the functions and adding them to the appropriate list. """
184+
"""Processes imports in a module, importing the functions and adding them to the appropriate list."""
179185
from importlib import import_module
180-
from importlib.util import spec_from_file_location, module_from_spec
186+
from importlib.util import module_from_spec, spec_from_file_location
181187

182188
for module_name, function_names in import_value.items():
183189
self.logger.debug("[%s]<%s> Importing module functions : %s" % (module_name, import_type, function_names))
184190
try:
185191
module = import_module(module_name)
186192
except ModuleNotFoundError as e:
187-
module_path = Path('/var/lib/ugrd/' + module_name.replace('.', '/')).with_suffix('.py')
193+
module_path = Path("/var/lib/ugrd/" + module_name.replace(".", "/")).with_suffix(".py")
188194
self.logger.debug("Attempting to sideload module from: %s" % module_path)
189195
if not module_path.exists():
190196
raise ModuleNotFoundError("Module not found: %s" % module_name) from e
@@ -196,77 +202,77 @@ def _process_imports(self, import_type: str, import_value: dict) -> None:
196202
raise ModuleNotFoundError("Unable to load module: %s" % module_name) from e
197203

198204
self.logger.log(5, "[%s] Imported module contents: %s" % (module_name, dir(module)))
199-
if '_module_name' in dir(module) and module._module_name != module_name:
205+
if "_module_name" in dir(module) and module._module_name != module_name:
200206
self.logger.warning("Module name mismatch: %s != %s" % (module._module_name, module_name))
201207

202208
function_list = [getattr(module, function_name) for function_name in function_names]
203209

204-
if import_type not in self['imports']:
210+
if import_type not in self["imports"]:
205211
self.logger.log(5, "Creating import type: %s" % import_type)
206-
self['imports'][import_type] = NoDupFlatList(log_bump=10, logger=self.logger, _log_init=False)
212+
self["imports"][import_type] = NoDupFlatList(log_bump=10, logger=self.logger)
207213

208-
if import_type == 'custom_init':
209-
if self['imports']['custom_init']:
210-
raise ValueError("Custom init function already defined: %s" % self['imports']['custom_init'])
214+
if import_type == "custom_init":
215+
if self["imports"]["custom_init"]:
216+
raise ValueError("Custom init function already defined: %s" % self["imports"]["custom_init"])
211217
else:
212-
self['imports']['custom_init'] = function_list[0]
218+
self["imports"]["custom_init"] = function_list[0]
213219
self.logger.info("Registered custom init function: %s" % function_list[0].__name__)
214220
continue
215221

216-
if import_type == 'funcs':
222+
if import_type == "funcs":
217223
for function in function_list:
218-
if function.__name__ in self['imports']['funcs']:
224+
if function.__name__ in self["imports"]["funcs"]:
219225
raise ValueError("Function '%s' already registered" % function.__name__)
220-
if function.__name__ in self['binaries']:
226+
if function.__name__ in self["binaries"]:
221227
raise ValueError("Function collides with defined binary: %s'" % function.__name__)
222228

223-
self['imports'][import_type] += function_list
229+
self["imports"][import_type] += function_list
224230
self.logger.debug("[%s] Updated import functions: %s" % (import_type, function_list))
225231

226-
if import_type == 'config_processing':
232+
if import_type == "config_processing":
227233
for function in function_list:
228-
self['custom_processing'][function.__name__] = function
234+
self["custom_processing"][function.__name__] = function
229235
self.logger.debug("Registered config processing function: %s" % function.__name__)
230-
self._process_unprocessed(function.__name__.removeprefix('_process_'))
236+
self._process_unprocessed(function.__name__.removeprefix("_process_"))
231237

232238
@handle_plural
233239
def _process_modules(self, module: str) -> None:
234240
"""
235241
processes a single module into the config
236242
takes list with decorator
237243
"""
238-
if module in self['modules']:
244+
if module in self["modules"]:
239245
self.logger.debug("Module '%s' already loaded" % module)
240246
return
241247

242248
self.logger.info("Processing module: %s" % module)
243249

244-
module_subpath = module.replace('.', '/') + '.toml'
250+
module_subpath = module.replace(".", "/") + ".toml"
245251

246252
module_path = Path(__file__).parent.parent / module_subpath
247253
if not module_path.exists():
248-
module_path = Path('/var/lib/ugrd') / module_subpath
254+
module_path = Path("/var/lib/ugrd") / module_subpath
249255
if not module_path.exists():
250256
raise FileNotFoundError("Unable to locate module: %s" % module)
251257
self.logger.debug("Module path: %s" % module_path)
252258

253-
with open(module_path, 'rb') as module_file:
259+
with open(module_path, "rb") as module_file:
254260
try:
255261
module_config = load(module_file)
256262
except TOMLDecodeError as e:
257263
raise TOMLDecodeError("Unable to load module config: %s" % module) from e
258264

259-
if imports := module_config.get('imports'):
265+
if imports := module_config.get("imports"):
260266
self.logger.debug("[%s] Processing imports: %s" % (module, imports))
261-
self['imports'] = imports
267+
self["imports"] = imports
262268

263-
custom_parameters = module_config.get('custom_parameters', {})
269+
custom_parameters = module_config.get("custom_parameters", {})
264270
if custom_parameters:
265271
self.logger.debug("[%s] Processing custom parameters: %s" % (module, custom_parameters))
266-
self['custom_parameters'] = custom_parameters
272+
self["custom_parameters"] = custom_parameters
267273

268274
for name, value in module_config.items(): # Process config values, in order they are defined
269-
if name in ['imports', 'custom_parameters']:
275+
if name in ["imports", "custom_parameters"]:
270276
self.logger.log(5, "[%s] Skipping '%s'" % (module, name))
271277
continue
272278
self.logger.debug("[%s] (%s) Setting value: %s" % (module, name, value))
@@ -277,13 +283,13 @@ def _process_modules(self, module: str) -> None:
277283
self._process_unprocessed(custom_parameter)
278284

279285
# Append the module to the list of loaded modules, avoid recursion
280-
self['modules'].append(module)
286+
self["modules"].append(module)
281287

282288
def validate(self) -> None:
283-
""" Validate config, checks that all values are processed, sets validated flag."""
284-
if self['_processing']:
285-
self.logger.critical("Unprocessed config values: %s" % ', '.join(list(self['_processing'].keys())))
286-
self['validated'] = True
289+
"""Validate config, checks that all values are processed, sets validated flag."""
290+
if self["_processing"]:
291+
self.logger.critical("Unprocessed config values: %s" % ", ".join(list(self["_processing"].keys())))
292+
self["validated"] = True
287293

288294
def __str__(self) -> str:
289295
return pretty_print(self.data)

0 commit comments

Comments
 (0)