2424import copy
2525import logging
2626from itertools import repeat
27- from typing import Iterable , Optional
27+ from typing import Iterable , Optional , Union , overload
2828
2929import matplotlib .pyplot as plt
3030import numpy as np
@@ -119,7 +119,7 @@ def clear(self):
119119 """Reinitialize attributes."""
120120 self ._data = dict () # {hazard_type : {id:ImpactFunc}}
121121
122- def append (self , func ):
122+ def append (self , func : ImpactFunc ):
123123 """Append a ImpactFunc. Overwrite existing if same id and haz_type.
124124
125125 Parameters
@@ -141,7 +141,9 @@ def append(self, func):
141141 self ._data [func .haz_type ] = dict ()
142142 self ._data [func .haz_type ][func .id ] = func
143143
144- def remove_func (self , haz_type = None , fun_id = None ):
144+ def remove_func (
145+ self , haz_type : Optional [str ] = None , fun_id : Optional [str | int ] = None
146+ ):
145147 """Remove impact function(s) with provided hazard type and/or id.
146148 If no input provided, all impact functions are removed.
147149
@@ -173,7 +175,29 @@ def remove_func(self, haz_type=None, fun_id=None):
173175 else :
174176 self ._data = dict ()
175177
176- def get_func (self , haz_type = None , fun_id = None ):
178+ @overload
179+ def get_func (
180+ self , haz_type : None = None , fun_id : None = None
181+ ) -> dict [str , dict [Union [int , str ], ImpactFunc ]]: ...
182+
183+ @overload
184+ def get_func (
185+ self , haz_type : None = ..., fun_id : int | str = ...
186+ ) -> list [ImpactFunc ]: ...
187+
188+ @overload
189+ def get_func (
190+ self , haz_type : str = ..., fun_id : None = None
191+ ) -> list [ImpactFunc ]: ...
192+
193+ @overload
194+ def get_func (self , haz_type : str = ..., fun_id : int | str = ...) -> ImpactFunc : ...
195+
196+ def get_func (
197+ self , haz_type : Optional [str ] = None , fun_id : Optional [int | str ] = None
198+ ) -> Union [
199+ ImpactFunc , list [ImpactFunc ], dict [str , dict [Union [int , str ], ImpactFunc ]]
200+ ]:
177201 """Get ImpactFunc(s) of input hazard type and/or id.
178202 If no input provided, all impact functions are returned.
179203
@@ -209,7 +233,7 @@ def get_func(self, haz_type=None, fun_id=None):
209233 else :
210234 return self ._data
211235
212- def get_hazard_types (self , fun_id = None ):
236+ def get_hazard_types (self , fun_id : Optional [ str | int ] = None ) -> list [ str ] :
213237 """Get impact functions hazard types contained for the id provided.
214238 Return all hazard types if no input id.
215239
@@ -231,7 +255,15 @@ def get_hazard_types(self, fun_id=None):
231255 haz_types .append (vul_haz )
232256 return haz_types
233257
234- def get_ids (self , haz_type = None ):
258+ @overload
259+ def get_ids (self , haz_type : None = None ) -> dict [str , list [str | int ]]: ...
260+
261+ @overload
262+ def get_ids (self , haz_type : str ) -> list [int | str ]: ...
263+
264+ def get_ids (
265+ self , haz_type : Optional [str ] = None
266+ ) -> dict [str , list [str | int ]] | list [int | str ]:
235267 """Get impact functions ids contained for the hazard type provided.
236268 Return all ids for each hazard type if no input hazard type.
237269
@@ -256,7 +288,9 @@ def get_ids(self, haz_type=None):
256288 except KeyError :
257289 return list ()
258290
259- def size (self , haz_type = None , fun_id = None ):
291+ def size (
292+ self , haz_type : Optional [str ] = None , fun_id : Optional [str | int ] = None
293+ ) -> int :
260294 """Get number of impact functions contained with input hazard type and
261295 /or id. If no input provided, get total number of impact functions.
262296
@@ -279,6 +313,7 @@ def size(self, haz_type=None, fun_id=None):
279313 return 1
280314 if (haz_type is not None ) or (fun_id is not None ):
281315 return len (self .get_func (haz_type , fun_id ))
316+
282317 return sum (len (vul_list ) for vul_list in self .get_ids ().values ())
283318
284319 def check (self ):
@@ -300,7 +335,7 @@ def check(self):
300335 )
301336 vul .check ()
302337
303- def extend (self , impact_funcs ):
338+ def extend (self , impact_funcs : "ImpactFuncSet" ):
304339 """Append impact functions of input ImpactFuncSet to current
305340 ImpactFuncSet. Overwrite ImpactFunc if same id and haz_type.
306341
@@ -323,7 +358,13 @@ def extend(self, impact_funcs):
323358 for _ , vul in vul_dict .items ():
324359 self .append (vul )
325360
326- def plot (self , haz_type = None , fun_id = None , axis = None , ** kwargs ):
361+ def plot (
362+ self ,
363+ haz_type : Optional [str ] = None ,
364+ fun_id : Optional [str | int ] = None ,
365+ axis = None ,
366+ ** kwargs ,
367+ ):
327368 """Plot impact functions of selected hazard (all if not provided) and
328369 selected function id (all if not provided).
329370
0 commit comments