11import builtins
22import keyword
33import os
4+ from typing import Any
45
56import yaml
67from black import FileMode , format_str
910predefined = ["description" , "default" , "data_type" , "required" , "alias" , "units" ]
1011
1112
12- def _find_alias (all_data : dict , head : list | None = None ):
13+ def _find_alias (all_data : dict , head : list | None = None ) -> dict [ str , str ] :
1314 """
1415 Find all aliases in the data structure.
1516
@@ -31,13 +32,13 @@ def _find_alias(all_data: dict, head: list | None = None):
3132 return results
3233
3334
34- def _replace_alias (all_data : dict ):
35+ def _replace_alias (all_data : dict ) -> dict :
3536 for key , value in _find_alias (all_data ).items ():
3637 _set (all_data , key , _get (all_data , value ))
3738 return all_data
3839
3940
40- def _get (obj : dict , path : str , sep : str = "/" ):
41+ def _get (obj : dict , path : str , sep : str = "/" ) -> Any :
4142 """
4243 Get a value from a nested dictionary.
4344
@@ -69,13 +70,13 @@ def _set(obj: dict, path: str, value):
6970 obj [last ] = value
7071
7172
72- def _get_safe_parameter_name (name : str ):
73+ def _get_safe_parameter_name (name : str ) -> str :
7374 if keyword .iskeyword (name ) or name in dir (builtins ):
7475 name = name + "_"
7576 return name
7677
7778
78- def _get_docstring_line (data : dict , key : str ):
79+ def _get_docstring_line (data : dict , key : str ) -> str :
7980 """
8081 Get a single line for the docstring.
8182
@@ -108,7 +109,12 @@ def _get_docstring_line(data: dict, key: str):
108109 return line
109110
110111
111- def _get_docstring (all_data , description = None , indent = indent , predefined = predefined ):
112+ def _get_docstring (
113+ all_data : dict ,
114+ description : str | None = None ,
115+ indent : str = indent ,
116+ predefined : list [str ] = predefined ,
117+ ) -> list [str ]:
112118 txt = [indent + '"""' ]
113119 if description is not None :
114120 txt .append (f"{ indent } { description } \n " )
@@ -123,7 +129,7 @@ def _get_docstring(all_data, description=None, indent=indent, predefined=predefi
123129 return txt
124130
125131
126- def _get_input_arg (key , entry , indent = indent ):
132+ def _get_input_arg (key : str , entry : dict , indent : str = indent ) -> str :
127133 t = entry .get ("data_type" , "dict" )
128134 units = "" .join (entry .get ("units" , "" ).split ())
129135 if not entry .get ("required" , False ) and units != "" :
@@ -136,7 +142,7 @@ def _get_input_arg(key, entry, indent=indent):
136142 return t
137143
138144
139- def _rename_keys (data ) :
145+ def _rename_keys (data : dict ) -> dict :
140146 d_1 = {_get_safe_parameter_name (key ): value for key , value in data .items ()}
141147 d_2 = {
142148 key : d
@@ -148,11 +154,11 @@ def _rename_keys(data):
148154
149155
150156def _get_function (
151- data ,
157+ data : dict ,
152158 function_name : list [str ],
153- predefined = predefined ,
154- is_kwarg = False ,
155- ):
159+ predefined : list [ str ] = predefined ,
160+ is_kwarg : bool = False ,
161+ ) -> str :
156162 d = _rename_keys (data )
157163 func = []
158164 if is_kwarg :
@@ -187,7 +193,9 @@ def _get_function(
187193 return "\n " .join (result )
188194
189195
190- def _get_all_function_names (all_data , head = "" , predefined = predefined ):
196+ def _get_all_function_names (
197+ all_data : dict , head : str = "" , predefined : list [str ] = predefined
198+ ) -> list [str ]:
191199 key_lst = []
192200 for tag , data in all_data .items ():
193201 if tag not in predefined and data .get ("data_type" , "dict" ) == "dict" :
@@ -196,7 +204,7 @@ def _get_all_function_names(all_data, head="", predefined=predefined):
196204 return key_lst
197205
198206
199- def _get_class (all_data ) :
207+ def _get_class (all_data : dict ) -> str :
200208 fnames = _get_all_function_names (all_data )
201209 txt = ""
202210 for name in fnames :
@@ -216,7 +224,7 @@ def _get_class(all_data):
216224 return txt
217225
218226
219- def _get_file_content (yml_file_name = "input_data.yml" ):
227+ def _get_file_content (yml_file_name : str = "input_data.yml" ) -> str :
220228 file_location = os .path .join (os .path .dirname (__file__ ), yml_file_name )
221229 with open (file_location , "r" ) as f :
222230 file_content = f .read ()
@@ -246,7 +254,7 @@ def _get_file_content(yml_file_name="input_data.yml"):
246254 return file_content
247255
248256
249- def export_class (yml_file_name = "input_data.yml" , py_file_name = "input.py" ):
257+ def export_class (yml_file_name : str = "input_data.yml" , py_file_name : str = "input.py" ):
250258 file_content = _get_file_content (yml_file_name )
251259 with open (os .path .join (os .path .dirname (__file__ ), ".." , py_file_name ), "w" ) as f :
252260 f .write (file_content )
0 commit comments