11__author__ = "desultory"
22__version__ = "2.2.1"
33
4- from tomllib import load , TOMLDecodeError
4+ from collections import UserDict
55from pathlib import Path
66from queue import Queue
7- from collections import UserDict
7+ from tomllib import TOMLDecodeError , load
88
99from zenlib .logging import loggify
10- from zenlib .util import handle_plural , pretty_print , NoDupFlatList
10+ from zenlib .util import NoDupFlatList , handle_plural , pretty_print
1111
1212
1313@loggify
@@ -24,51 +24,57 @@ class InitramfsConfigDict(UserDict):
2424
2525 If parameters which are not registerd are set, they are added to the processing queue and processed when the type is known.
2626 """
27- builtin_parameters = {'modules' : NoDupFlatList , # A list of the names of modules which have been loaded, mostly used for dependency checking
28- 'imports' : dict , # A dict of functions to be imported into the initramfs, under their respective hooks
29- 'validated' : bool , # A flag to indicate if the config has been validated, mostly used for log levels
30- 'custom_parameters' : dict , # Custom parameters loaded from imports
31- 'custom_processing' : dict , # Custom processing functions which will be run to validate and process parameters
32- '_processing' : dict } # A dict of queues containing parameters which have been set before the type was known
27+
28+ builtin_parameters = {
29+ "modules" : NoDupFlatList , # A list of the names of modules which have been loaded, mostly used for dependency checking
30+ "imports" : dict , # A dict of functions to be imported into the initramfs, under their respective hooks
31+ "validated" : bool , # A flag to indicate if the config has been validated, mostly used for log levels
32+ "custom_parameters" : dict , # Custom parameters loaded from imports
33+ "custom_processing" : dict , # Custom processing functions which will be run to validate and process parameters
34+ "_processing" : dict ,
35+ } # A dict of queues containing parameters which have been set before the type was known
3336
3437 def __init__ (self , NO_BASE = False , * args , ** kwargs ):
3538 super ().__init__ (* args , ** kwargs )
3639 # Define the default parameters
3740 for parameter , default_type in self .builtin_parameters .items ():
3841 if default_type == NoDupFlatList :
39- self .data [parameter ] = default_type (no_warn = True , log_bump = 5 , logger = self .logger , _log_init = False )
42+ self .data [parameter ] = default_type (no_warn = True , log_bump = 5 , logger = self .logger )
4043 else :
4144 self .data [parameter ] = default_type ()
4245 if not NO_BASE :
43- self [' modules' ] = ' ugrd.base.base'
46+ self [" modules" ] = " ugrd.base.base"
4447 else :
45- self [' modules' ] = ' ugrd.base.core'
48+ self [" modules" ] = " ugrd.base.core"
4649
4750 def import_args (self , args : dict ) -> None :
48- """ Imports data from an argument dict. """
51+ """Imports data from an argument dict."""
4952 for arg , value in args .items ():
5053 self .logger .info ("Importing argument '%s' with value: %s" % (arg , value ))
51- if arg == ' modules' : # allow loading modules by name from the command line
52- for module in value .split (',' ):
54+ if arg == " modules" : # allow loading modules by name from the command line
55+ for module in value .split ("," ):
5356 self [arg ] = module
5457 else :
5558 self [arg ] = value
5659
5760 def __setitem__ (self , key : str , value ) -> None :
58- if self [' validated' ]:
61+ if self [" validated" ]:
5962 return self .logger .error ("[%s] Config is validatied, refusing to set value: %s" % (key , value ))
6063 # If the type is registered, use the appropriate update function
61- if any (key in d for d in (self .builtin_parameters , self [' custom_parameters' ])):
64+ if any (key in d for d in (self .builtin_parameters , self [" custom_parameters" ])):
6265 return self .handle_parameter (key , value )
6366 else :
64- self .logger .debug ("[%s] Unable to determine expected type, valid builtin types: %s" % (key , self .builtin_parameters .keys ()))
65- self .logger .debug ("[%s] Custom types: %s" % (key , self ['custom_parameters' ].keys ()))
67+ self .logger .debug (
68+ "[%s] Unable to determine expected type, valid builtin types: %s"
69+ % (key , self .builtin_parameters .keys ())
70+ )
71+ self .logger .debug ("[%s] Custom types: %s" % (key , self ["custom_parameters" ].keys ()))
6672 # for anything but the logger, add to the processing queue
67- if key != ' logger' :
73+ if key != " logger" :
6874 self .logger .debug ("Adding unknown internal parameter to processing queue: %s" % key )
69- if key not in self [' _processing' ]:
70- self [' _processing' ][key ] = Queue ()
71- self [' _processing' ][key ].put (value )
75+ if key not in self [" _processing" ]:
76+ self [" _processing" ][key ] = Queue ()
77+ self [" _processing" ][key ].put (value )
7278
7379 def handle_parameter (self , key : str , value ) -> None :
7480 """
@@ -78,7 +84,7 @@ def handle_parameter(self, key: str, value) -> None:
7884 Uses custom processing functions if they are defined, otherwise uses the standard setters.
7985 """
8086 # Get the expected type, first searching builtin_parameters, then custom_parameters
81- for d in (self .builtin_parameters , self [' custom_parameters' ]):
87+ for d in (self .builtin_parameters , self [" custom_parameters" ]):
8288 expected_type = d .get (key )
8389 if expected_type :
8490 if expected_type .__name__ == "InitramfsGenerator" :
@@ -94,17 +100,17 @@ def handle_parameter(self, key: str, value) -> None:
94100
95101 # Don't use masked processing functions for custom values, fall back to standard setters
96102 def check_mask (import_name : str ) -> bool :
97- """ Checks if the funnction is masked. """
98- return import_name in self .get (' masks' , [])
103+ """Checks if the funnction is masked."""
104+ return import_name in self .get (" masks" , [])
99105
100- if func := self [' custom_processing' ].get (f"_process_{ key } " ):
106+ if func := self [" custom_processing" ].get (f"_process_{ key } " ):
101107 if check_mask (func .__name__ ):
102108 self .logger .debug ("Skipping masked function: %s" % func .__name__ )
103109 else :
104110 self .logger .log (5 , "[%s] Using custom setitem: %s" % (key , func .__name__ ))
105111 return func (self , value )
106112
107- if func := self [' custom_processing' ].get (f"_process_{ key } _multi" ):
113+ if func := self [" custom_processing" ].get (f"_process_{ key } _multi" ):
108114 if check_mask (func .__name__ ):
109115 self .logger .debug ("Skipping masked function: %s" % func .__name__ )
110116 else :
@@ -115,7 +121,7 @@ def check_mask(import_name: str) -> bool:
115121 self .logger .log (5 , "Using list setitem for: %s" % key )
116122 return self [key ].append (value )
117123
118- if expected_type == dict : # Create new keys, update existing
124+ if expected_type is dict : # Create new keys, update existing
119125 if key not in self :
120126 self .logger .log (5 , "Setting dict '%s' to: %s" % (key , value ))
121127 return super ().__setitem__ (key , value )
@@ -134,12 +140,12 @@ def _process_custom_parameters(self, parameter_name: str, parameter_type: type)
134140 """
135141 from pycpio import PyCPIO
136142
137- self [' custom_parameters' ][parameter_name ] = eval (parameter_type )
143+ self [" custom_parameters" ][parameter_name ] = eval (parameter_type )
138144 self .logger .debug ("Registered custom parameter '%s' with type: %s" % (parameter_name , parameter_type ))
139145
140146 match parameter_type :
141147 case "NoDupFlatList" :
142- self .data [parameter_name ] = NoDupFlatList (no_warn = True , log_bump = 5 , logger = self .logger , _log_init = False )
148+ self .data [parameter_name ] = NoDupFlatList (no_warn = True , log_bump = 5 , logger = self .logger )
143149 case "list" | "dict" :
144150 self .data [parameter_name ] = eval (parameter_type )()
145151 case "bool" :
@@ -153,38 +159,38 @@ def _process_custom_parameters(self, parameter_name: str, parameter_type: type)
153159 case "Path" :
154160 self .data [parameter_name ] = Path ()
155161 case "PyCPIO" :
156- self .data [parameter_name ] = PyCPIO (logger = self .logger , _log_init = False , _log_bump = 10 )
162+ self .data [parameter_name ] = PyCPIO (logger = self .logger , _log_bump = 10 )
157163 case _: # For strings and things, don't init them so they are None
158164 self .logger .warning ("Leaving '%s' as None" % parameter_name )
159165 self .data [parameter_name ] = None
160166
161167 def _process_unprocessed (self , parameter_name : str ) -> None :
162- """ Processes queued values for a parameter. """
163- if parameter_name not in self [' _processing' ]:
168+ """Processes queued values for a parameter."""
169+ if parameter_name not in self [" _processing" ]:
164170 self .logger .log (5 , "No queued values for: %s" % parameter_name )
165171 return
166172
167- value_queue = self [' _processing' ].pop (parameter_name )
173+ value_queue = self [" _processing" ].pop (parameter_name )
168174 while not value_queue .empty ():
169175 value = value_queue .get ()
170- if self [' validated' ]: # Log at info level if the config has been validated
176+ if self [" validated" ]: # Log at info level if the config has been validated
171177 self .logger .info ("[%s] Processing queued value: %s" % (parameter_name , value ))
172178 else :
173179 self .logger .debug ("[%s] Processing queued value: %s" % (parameter_name , value ))
174180 self [parameter_name ] = value
175181
176182 @handle_plural
177183 def _process_imports (self , import_type : str , import_value : dict ) -> None :
178- """ Processes imports in a module, importing the functions and adding them to the appropriate list. """
184+ """Processes imports in a module, importing the functions and adding them to the appropriate list."""
179185 from importlib import import_module
180- from importlib .util import spec_from_file_location , module_from_spec
186+ from importlib .util import module_from_spec , spec_from_file_location
181187
182188 for module_name , function_names in import_value .items ():
183189 self .logger .debug ("[%s]<%s> Importing module functions : %s" % (module_name , import_type , function_names ))
184190 try :
185191 module = import_module (module_name )
186192 except ModuleNotFoundError as e :
187- module_path = Path (' /var/lib/ugrd/' + module_name .replace ('.' , '/' )).with_suffix (' .py' )
193+ module_path = Path (" /var/lib/ugrd/" + module_name .replace ("." , "/" )).with_suffix (" .py" )
188194 self .logger .debug ("Attempting to sideload module from: %s" % module_path )
189195 if not module_path .exists ():
190196 raise ModuleNotFoundError ("Module not found: %s" % module_name ) from e
@@ -196,77 +202,77 @@ def _process_imports(self, import_type: str, import_value: dict) -> None:
196202 raise ModuleNotFoundError ("Unable to load module: %s" % module_name ) from e
197203
198204 self .logger .log (5 , "[%s] Imported module contents: %s" % (module_name , dir (module )))
199- if ' _module_name' in dir (module ) and module ._module_name != module_name :
205+ if " _module_name" in dir (module ) and module ._module_name != module_name :
200206 self .logger .warning ("Module name mismatch: %s != %s" % (module ._module_name , module_name ))
201207
202208 function_list = [getattr (module , function_name ) for function_name in function_names ]
203209
204- if import_type not in self [' imports' ]:
210+ if import_type not in self [" imports" ]:
205211 self .logger .log (5 , "Creating import type: %s" % import_type )
206- self [' imports' ][import_type ] = NoDupFlatList (log_bump = 10 , logger = self .logger , _log_init = False )
212+ self [" imports" ][import_type ] = NoDupFlatList (log_bump = 10 , logger = self .logger )
207213
208- if import_type == ' custom_init' :
209- if self [' imports' ][ ' custom_init' ]:
210- raise ValueError ("Custom init function already defined: %s" % self [' imports' ][ ' custom_init' ])
214+ if import_type == " custom_init" :
215+ if self [" imports" ][ " custom_init" ]:
216+ raise ValueError ("Custom init function already defined: %s" % self [" imports" ][ " custom_init" ])
211217 else :
212- self [' imports' ][ ' custom_init' ] = function_list [0 ]
218+ self [" imports" ][ " custom_init" ] = function_list [0 ]
213219 self .logger .info ("Registered custom init function: %s" % function_list [0 ].__name__ )
214220 continue
215221
216- if import_type == ' funcs' :
222+ if import_type == " funcs" :
217223 for function in function_list :
218- if function .__name__ in self [' imports' ][ ' funcs' ]:
224+ if function .__name__ in self [" imports" ][ " funcs" ]:
219225 raise ValueError ("Function '%s' already registered" % function .__name__ )
220- if function .__name__ in self [' binaries' ]:
226+ if function .__name__ in self [" binaries" ]:
221227 raise ValueError ("Function collides with defined binary: %s'" % function .__name__ )
222228
223- self [' imports' ][import_type ] += function_list
229+ self [" imports" ][import_type ] += function_list
224230 self .logger .debug ("[%s] Updated import functions: %s" % (import_type , function_list ))
225231
226- if import_type == ' config_processing' :
232+ if import_type == " config_processing" :
227233 for function in function_list :
228- self [' custom_processing' ][function .__name__ ] = function
234+ self [" custom_processing" ][function .__name__ ] = function
229235 self .logger .debug ("Registered config processing function: %s" % function .__name__ )
230- self ._process_unprocessed (function .__name__ .removeprefix (' _process_' ))
236+ self ._process_unprocessed (function .__name__ .removeprefix (" _process_" ))
231237
232238 @handle_plural
233239 def _process_modules (self , module : str ) -> None :
234240 """
235241 processes a single module into the config
236242 takes list with decorator
237243 """
238- if module in self [' modules' ]:
244+ if module in self [" modules" ]:
239245 self .logger .debug ("Module '%s' already loaded" % module )
240246 return
241247
242248 self .logger .info ("Processing module: %s" % module )
243249
244- module_subpath = module .replace ('.' , '/' ) + ' .toml'
250+ module_subpath = module .replace ("." , "/" ) + " .toml"
245251
246252 module_path = Path (__file__ ).parent .parent / module_subpath
247253 if not module_path .exists ():
248- module_path = Path (' /var/lib/ugrd' ) / module_subpath
254+ module_path = Path (" /var/lib/ugrd" ) / module_subpath
249255 if not module_path .exists ():
250256 raise FileNotFoundError ("Unable to locate module: %s" % module )
251257 self .logger .debug ("Module path: %s" % module_path )
252258
253- with open (module_path , 'rb' ) as module_file :
259+ with open (module_path , "rb" ) as module_file :
254260 try :
255261 module_config = load (module_file )
256262 except TOMLDecodeError as e :
257263 raise TOMLDecodeError ("Unable to load module config: %s" % module ) from e
258264
259- if imports := module_config .get (' imports' ):
265+ if imports := module_config .get (" imports" ):
260266 self .logger .debug ("[%s] Processing imports: %s" % (module , imports ))
261- self [' imports' ] = imports
267+ self [" imports" ] = imports
262268
263- custom_parameters = module_config .get (' custom_parameters' , {})
269+ custom_parameters = module_config .get (" custom_parameters" , {})
264270 if custom_parameters :
265271 self .logger .debug ("[%s] Processing custom parameters: %s" % (module , custom_parameters ))
266- self [' custom_parameters' ] = custom_parameters
272+ self [" custom_parameters" ] = custom_parameters
267273
268274 for name , value in module_config .items (): # Process config values, in order they are defined
269- if name in [' imports' , ' custom_parameters' ]:
275+ if name in [" imports" , " custom_parameters" ]:
270276 self .logger .log (5 , "[%s] Skipping '%s'" % (module , name ))
271277 continue
272278 self .logger .debug ("[%s] (%s) Setting value: %s" % (module , name , value ))
@@ -277,13 +283,13 @@ def _process_modules(self, module: str) -> None:
277283 self ._process_unprocessed (custom_parameter )
278284
279285 # Append the module to the list of loaded modules, avoid recursion
280- self [' modules' ].append (module )
286+ self [" modules" ].append (module )
281287
282288 def validate (self ) -> None :
283- """ Validate config, checks that all values are processed, sets validated flag."""
284- if self [' _processing' ]:
285- self .logger .critical ("Unprocessed config values: %s" % ', ' .join (list (self [' _processing' ].keys ())))
286- self [' validated' ] = True
289+ """Validate config, checks that all values are processed, sets validated flag."""
290+ if self [" _processing" ]:
291+ self .logger .critical ("Unprocessed config values: %s" % ", " .join (list (self [" _processing" ].keys ())))
292+ self [" validated" ] = True
287293
288294 def __str__ (self ) -> str :
289295 return pretty_print (self .data )
0 commit comments