Skip to content

Commit 33e7bac

Browse files
authored
Merge pull request #1250 from CLIMADA-project/feature/impf-type-hints
Better type hints and overloads signatures for ImpactFuncSet
2 parents a031f4e + 8731d07 commit 33e7bac

File tree

3 files changed

+56
-13
lines changed

3 files changed

+56
-13
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@ Code freeze date: YYYY-MM-DD
1212

1313
### Added
1414

15+
- Better type hints and overloads signatures for ImpactFuncSet [#1250](https://github.com/CLIMADA-project/climada_python/pull/1250)
16+
1517
### Changed
1618
- Updated Impact Calculation Tutorial (`doc.climada_engine_Impact.ipynb`) [#1095](https://github.com/CLIMADA-project/climada_python/pull/1095).
1719

climada/entity/impact_funcs/base.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -189,7 +189,7 @@ def from_step_impf(
189189
haz_type: str,
190190
mdd: tuple[float, float] = (0, 1),
191191
paa: tuple[float, float] = (1, 1),
192-
impf_id: int = 1,
192+
impf_id: int | str = 1,
193193
**kwargs,
194194
):
195195
"""Step function type impact function.
@@ -207,7 +207,7 @@ def from_step_impf(
207207
(min, max) mdd values. The default is (0, 1)
208208
paa: tuple(float, float)
209209
(min, max) paa values. The default is (1, 1)
210-
impf_id : int, optional, default=1
210+
impf_id : int|str, optional, default=1
211211
impact function id
212212
kwargs :
213213
keyword arguments passed to ImpactFunc()
@@ -250,7 +250,7 @@ def from_sigmoid_impf(
250250
k: float,
251251
x0: float,
252252
haz_type: str,
253-
impf_id: int = 1,
253+
impf_id: int | str = 1,
254254
**kwargs,
255255
):
256256
r"""Sigmoid type impact function hinging on three parameter.
@@ -320,7 +320,7 @@ def from_poly_s_shape(
320320
scale: float,
321321
exponent: float,
322322
haz_type: str,
323-
impf_id: int = 1,
323+
impf_id: int | str = 1,
324324
**kwargs,
325325
):
326326
r"""S-shape polynomial impact function hinging on four parameter.

climada/entity/impact_funcs/impact_func_set.py

Lines changed: 50 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
import copy
2525
import logging
2626
from itertools import repeat
27-
from typing import Iterable, Optional
27+
from typing import Iterable, Optional, Union, overload
2828

2929
import matplotlib.pyplot as plt
3030
import numpy as np
@@ -119,7 +119,7 @@ def clear(self):
119119
"""Reinitialize attributes."""
120120
self._data = dict() # {hazard_type : {id:ImpactFunc}}
121121

122-
def append(self, func):
122+
def append(self, func: ImpactFunc):
123123
"""Append a ImpactFunc. Overwrite existing if same id and haz_type.
124124
125125
Parameters
@@ -141,7 +141,9 @@ def append(self, func):
141141
self._data[func.haz_type] = dict()
142142
self._data[func.haz_type][func.id] = func
143143

144-
def remove_func(self, haz_type=None, fun_id=None):
144+
def remove_func(
145+
self, haz_type: Optional[str] = None, fun_id: Optional[str | int] = None
146+
):
145147
"""Remove impact function(s) with provided hazard type and/or id.
146148
If no input provided, all impact functions are removed.
147149
@@ -173,7 +175,29 @@ def remove_func(self, haz_type=None, fun_id=None):
173175
else:
174176
self._data = dict()
175177

176-
def get_func(self, haz_type=None, fun_id=None):
178+
@overload
179+
def get_func(
180+
self, haz_type: None = None, fun_id: None = None
181+
) -> dict[str, dict[Union[int, str], ImpactFunc]]: ...
182+
183+
@overload
184+
def get_func(
185+
self, haz_type: None = ..., fun_id: int | str = ...
186+
) -> list[ImpactFunc]: ...
187+
188+
@overload
189+
def get_func(
190+
self, haz_type: str = ..., fun_id: None = None
191+
) -> list[ImpactFunc]: ...
192+
193+
@overload
194+
def get_func(self, haz_type: str = ..., fun_id: int | str = ...) -> ImpactFunc: ...
195+
196+
def get_func(
197+
self, haz_type: Optional[str] = None, fun_id: Optional[int | str] = None
198+
) -> Union[
199+
ImpactFunc, list[ImpactFunc], dict[str, dict[Union[int, str], ImpactFunc]]
200+
]:
177201
"""Get ImpactFunc(s) of input hazard type and/or id.
178202
If no input provided, all impact functions are returned.
179203
@@ -209,7 +233,7 @@ def get_func(self, haz_type=None, fun_id=None):
209233
else:
210234
return self._data
211235

212-
def get_hazard_types(self, fun_id=None):
236+
def get_hazard_types(self, fun_id: Optional[str | int] = None) -> list[str]:
213237
"""Get impact functions hazard types contained for the id provided.
214238
Return all hazard types if no input id.
215239
@@ -231,7 +255,15 @@ def get_hazard_types(self, fun_id=None):
231255
haz_types.append(vul_haz)
232256
return haz_types
233257

234-
def get_ids(self, haz_type=None):
258+
@overload
259+
def get_ids(self, haz_type: None = None) -> dict[str, list[str | int]]: ...
260+
261+
@overload
262+
def get_ids(self, haz_type: str) -> list[int | str]: ...
263+
264+
def get_ids(
265+
self, haz_type: Optional[str] = None
266+
) -> dict[str, list[str | int]] | list[int | str]:
235267
"""Get impact functions ids contained for the hazard type provided.
236268
Return all ids for each hazard type if no input hazard type.
237269
@@ -256,7 +288,9 @@ def get_ids(self, haz_type=None):
256288
except KeyError:
257289
return list()
258290

259-
def size(self, haz_type=None, fun_id=None):
291+
def size(
292+
self, haz_type: Optional[str] = None, fun_id: Optional[str | int] = None
293+
) -> int:
260294
"""Get number of impact functions contained with input hazard type and
261295
/or id. If no input provided, get total number of impact functions.
262296
@@ -279,6 +313,7 @@ def size(self, haz_type=None, fun_id=None):
279313
return 1
280314
if (haz_type is not None) or (fun_id is not None):
281315
return len(self.get_func(haz_type, fun_id))
316+
282317
return sum(len(vul_list) for vul_list in self.get_ids().values())
283318

284319
def check(self):
@@ -300,7 +335,7 @@ def check(self):
300335
)
301336
vul.check()
302337

303-
def extend(self, impact_funcs):
338+
def extend(self, impact_funcs: "ImpactFuncSet"):
304339
"""Append impact functions of input ImpactFuncSet to current
305340
ImpactFuncSet. Overwrite ImpactFunc if same id and haz_type.
306341
@@ -323,7 +358,13 @@ def extend(self, impact_funcs):
323358
for _, vul in vul_dict.items():
324359
self.append(vul)
325360

326-
def plot(self, haz_type=None, fun_id=None, axis=None, **kwargs):
361+
def plot(
362+
self,
363+
haz_type: Optional[str] = None,
364+
fun_id: Optional[str | int] = None,
365+
axis=None,
366+
**kwargs,
367+
):
327368
"""Plot impact functions of selected hazard (all if not provided) and
328369
selected function id (all if not provided).
329370

0 commit comments

Comments
 (0)