Skip to content

Commit 46e6624

Browse files
committed
fix(simulations): add input data types
1 parent c9b35dc commit 46e6624

File tree

5 files changed

+92
-59
lines changed

5 files changed

+92
-59
lines changed

openfisca_core/simulations/_build_default_simulation.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,11 @@
11
"""This module contains the _BuildDefaultSimulation class."""
22

3-
from typing_extensions import Self, TypeAlias
3+
from typing_extensions import Self
44

55
import numpy
66

77
from .simulation import Simulation
8-
from .types import CoreEntity, GroupPopulation, TaxBenefitSystem
9-
10-
Populations: TypeAlias = dict[str, GroupPopulation[CoreEntity]]
8+
from .types import Populations, TaxBenefitSystem
119

1210

1311
class _BuildDefaultSimulation:

openfisca_core/simulations/_build_from_variables.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,16 +3,14 @@
33
from __future__ import annotations
44

55
from collections.abc import Sized
6-
from typing_extensions import Self, TypeAlias
6+
from typing_extensions import Self
77

88
from openfisca_core import errors
99

1010
from ._build_default_simulation import _BuildDefaultSimulation
1111
from ._guards import is_variable_dated
1212
from .simulation import Simulation
13-
from .types import CoreEntity, GroupPopulation, TaxBenefitSystem, Variables
14-
15-
Populations: TypeAlias = dict[str, GroupPopulation[CoreEntity]]
13+
from .types import Populations, TaxBenefitSystem, Variables
1614

1715

1816
class _BuildFromVariables:

openfisca_core/simulations/simulation.py

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,14 @@
1010
from openfisca_core import commons, errors, indexed_enums, periods, tracers
1111
from openfisca_core import warnings as core_warnings
1212

13-
from .types import GroupPopulation, TaxBenefitSystem, Variable
13+
from .types import (
14+
EntityPlural,
15+
GroupEntity,
16+
GroupPopulation,
17+
Populations,
18+
TaxBenefitSystem,
19+
Variable,
20+
)
1421

1522

1623
class Simulation:
@@ -19,13 +26,13 @@ class Simulation:
1926
"""
2027

2128
tax_benefit_system: TaxBenefitSystem
22-
populations: dict[str, GroupPopulation]
29+
populations: Populations
2330
invalidated_caches: Set[Cache]
2431

2532
def __init__(
2633
self,
2734
tax_benefit_system: TaxBenefitSystem,
28-
populations: dict[str, GroupPopulation],
35+
populations: Populations,
2936
):
3037
"""
3138
This constructor is reserved for internal use; see :any:`SimulationBuilder`,
@@ -555,10 +562,14 @@ def get_population(self, plural: Optional[str] = None) -> Optional[GroupPopulati
555562

556563
def get_entity(
557564
self,
558-
plural: Optional[str] = None,
559-
) -> Optional[GroupPopulation]:
560-
population = self.get_population(plural)
561-
return population and population.entity
565+
plural: EntityPlural | None = None,
566+
) -> GroupEntity | None:
567+
population: GroupPopulation | None = self.get_population(plural)
568+
569+
if population is None:
570+
return None
571+
572+
return population.entity
562573

563574
def describe_entities(self):
564575
return {

openfisca_core/simulations/simulation_builder.py

Lines changed: 21 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,13 @@
11
from __future__ import annotations
22

33
from collections.abc import Iterable, Sequence
4-
from numpy.typing import NDArray as Array
5-
from typing import Dict, List
64

75
import copy
86

97
import dpath.util
108
import numpy
119

12-
from openfisca_core import entities, errors, periods, populations, variables
10+
from openfisca_core import errors, periods
1311

1412
from . import helpers
1513
from ._build_default_simulation import _BuildDefaultSimulation
@@ -22,23 +20,31 @@
2220
)
2321
from .simulation import Simulation
2422
from .types import (
23+
Array,
2524
Axis,
25+
EntityCounts,
26+
EntityIds,
27+
EntityRoles,
2628
FullySpecifiedEntities,
2729
GroupEntities,
2830
GroupEntity,
2931
ImplicitGroupEntities,
32+
InputBuffer,
33+
Memberships,
3034
Params,
3135
ParamsWithoutAxes,
36+
Populations,
3237
Role,
3338
SingleEntity,
3439
SinglePopulation,
3540
TaxBenefitSystem,
41+
VariableEntity,
3642
Variables,
3743
)
3844

3945

4046
class SimulationBuilder:
41-
def __init__(self):
47+
def __init__(self) -> None:
4248
self.default_period = (
4349
None # Simulation period used for variables when no period is defined
4450
)
@@ -47,26 +53,24 @@ def __init__(self):
4753
)
4854

4955
# JSON input - Memory of known input values. Indexed by variable or axis name.
50-
self.input_buffer: Dict[
51-
variables.Variable.name, Dict[str(periods.period), numpy.array]
52-
] = {}
53-
self.populations: Dict[entities.Entity.key, populations.Population] = {}
56+
self.input_buffer: InputBuffer = {}
57+
self.populations: Populations = {}
5458
# JSON input - Number of items of each entity type. Indexed by entities plural names. Should be consistent with ``entity_ids``, including axes.
55-
self.entity_counts: Dict[entities.Entity.plural, int] = {}
59+
self.entity_counts: EntityCounts = {}
5660
# JSON input - List of items of each entity type. Indexed by entities plural names. Should be consistent with ``entity_counts``.
57-
self.entity_ids: Dict[entities.Entity.plural, List[int]] = {}
61+
self.entity_ids: EntityIds = {}
5862

5963
# Links entities with persons. For each person index in persons ids list, set entity index in entity ids id. E.g.: self.memberships[entity.plural][person_index] = entity_ids.index(instance_id)
60-
self.memberships: Dict[entities.Entity.plural, List[int]] = {}
61-
self.roles: Dict[entities.Entity.plural, List[int]] = {}
64+
self.memberships: Memberships = {}
65+
self.roles: EntityRoles = {}
6266

63-
self.variable_entities: Dict[variables.Variable.name, entities.Entity] = {}
67+
self.variable_entities: VariableEntity = {}
6468

6569
self.axes = [[]]
66-
self.axes_entity_counts: Dict[entities.Entity.plural, int] = {}
67-
self.axes_entity_ids: Dict[entities.Entity.plural, List[int]] = {}
68-
self.axes_memberships: Dict[entities.Entity.plural, List[int]] = {}
69-
self.axes_roles: Dict[entities.Entity.plural, List[int]] = {}
70+
self.axes_entity_counts: EntityCounts = {}
71+
self.axes_entity_ids: EntityIds = {}
72+
self.axes_memberships: Memberships = {}
73+
self.axes_roles: EntityRoles = {}
7074

7175
def build_from_dict(
7276
self,

openfisca_core/simulations/types.py

Lines changed: 49 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,10 @@
33
from __future__ import annotations
44

55
from collections.abc import Callable, Iterable, Sequence
6-
from typing import Protocol, TypeVar, TypedDict, Union
6+
from typing import NewType, Protocol, TypeVar, TypedDict, Union
77
from typing_extensions import NotRequired, Required, TypeAlias
88

99
import datetime
10-
from abc import abstractmethod
1110

1211
from numpy import bool_ as Bool
1312
from numpy import datetime64 as Date
@@ -19,32 +18,36 @@
1918
from openfisca_core import types as t
2019

2120
# Generic type variables.
22-
D = TypeVar("D")
23-
E = TypeVar("E", covariant=True)
2421
G = TypeVar("G", covariant=True)
2522
T = TypeVar("T", Bool, Date, Enum, Float, Int, String, covariant=True)
2623
U = TypeVar("U", bool, datetime.date, float, str)
2724
V = TypeVar("V", covariant=True)
2825

26+
# New types.
27+
PeriodStr = NewType("PeriodStr", str)
28+
EntityKey = NewType("EntityKey", str)
29+
EntityPlural = NewType("EntityPlural", str)
30+
VariableName = NewType("VariableName", str)
31+
32+
# Type aliases.
33+
2934
#: Type alias for numpy arrays values.
3035
Item: TypeAlias = Union[Bool, Date, Enum, Float, Int, String]
3136

37+
#: Type Alias for a numpy Array.
38+
Array: TypeAlias = t.Array
3239

3340
# Entities
3441

3542

36-
#: Type alias for a simulation dictionary defining the roles.
37-
Roles: TypeAlias = dict[str, Union[str, Iterable[str]]]
38-
39-
4043
class CoreEntity(t.CoreEntity, Protocol):
41-
key: str
42-
plural: str | None
44+
key: EntityKey
45+
plural: EntityPlural | None
4346

4447
def get_variable(
4548
self,
46-
__variable_name: str,
47-
__check_existence: bool = ...,
49+
__variable_name: VariableName,
50+
check_existence: bool = ...,
4851
) -> Variable[T] | None:
4952
...
5053

@@ -55,7 +58,6 @@ class SingleEntity(t.SingleEntity, Protocol):
5558

5659
class GroupEntity(t.GroupEntity, Protocol):
5760
@property
58-
@abstractmethod
5961
def flattened_roles(self) -> Iterable[Role[G]]:
6062
...
6163

@@ -69,11 +71,10 @@ class Role(t.Role, Protocol[G]):
6971

7072
class Holder(t.Holder, Protocol[V]):
7173
@property
72-
@abstractmethod
7374
def variable(self) -> Variable[T]:
7475
...
7576

76-
def get_array(self, __period: str) -> t.Array[T] | None:
77+
def get_array(self, __period: PeriodStr) -> t.Array[T] | None:
7778
...
7879

7980
def set_input(
@@ -94,18 +95,19 @@ class Period(t.Period, Protocol):
9495
# Populations
9596

9697

97-
class CorePopulation(t.CorePopulation, Protocol[D]):
98-
entity: D
98+
class CorePopulation(t.CorePopulation, Protocol):
99+
entity: CoreEntity
99100

100-
def get_holder(self, __variable_name: str) -> Holder[V]:
101+
def get_holder(self, __variable_name: VariableName) -> Holder[V]:
101102
...
102103

103104

104-
class SinglePopulation(t.SinglePopulation, Protocol[E]):
105-
...
105+
class SinglePopulation(t.SinglePopulation, Protocol):
106+
entity: SingleEntity
106107

107108

108-
class GroupPopulation(t.GroupPopulation, Protocol[E]):
109+
class GroupPopulation(t.GroupPopulation, Protocol):
110+
entity: GroupEntity
109111
members_entity_id: t.Array[String]
110112

111113
def nb_persons(self, __role: Role[G] | None = ...) -> int:
@@ -114,6 +116,29 @@ def nb_persons(self, __role: Role[G] | None = ...) -> int:
114116

115117
# Simulations
116118

119+
#: Dictionary with axes parameters per variable.
120+
InputBuffer: TypeAlias = dict[VariableName, dict[PeriodStr, Array]]
121+
122+
#: Dictionary with entity/population key/pais.
123+
Populations: TypeAlias = dict[EntityKey, GroupPopulation]
124+
125+
#: Dictionary with single entity count per group entity.
126+
EntityCounts: TypeAlias = dict[EntityPlural, int]
127+
128+
#: Dictionary with a list of single entities per group entity.
129+
EntityIds: TypeAlias = dict[EntityPlural, Iterable[int]]
130+
131+
#: Dictionary with a list of members per group entity.
132+
Memberships: TypeAlias = dict[EntityPlural, Iterable[int]]
133+
134+
#: Dictionary with a list of roles per group entity.
135+
EntityRoles: TypeAlias = dict[EntityPlural, Iterable[int]]
136+
137+
#: Dictionary with a map between variables and entities.
138+
VariableEntity: TypeAlias = dict[VariableName, CoreEntity]
139+
140+
#: Type alias for a simulation dictionary defining the roles.
141+
Roles: TypeAlias = dict[str, Union[str, Iterable[str]]]
117142

118143
#: Type alias for a simulation dictionary with undated variables.
119144
UndatedVariable: TypeAlias = dict[str, object]
@@ -169,21 +194,18 @@ class Simulation(t.Simulation, Protocol):
169194

170195
class TaxBenefitSystem(t.TaxBenefitSystem, Protocol):
171196
@property
172-
@abstractmethod
173197
def person_entity(self) -> SingleEntity:
174198
...
175199

176200
@person_entity.setter
177-
@abstractmethod
178201
def person_entity(self, person_entity: SingleEntity) -> None:
179202
...
180203

181204
@property
182-
@abstractmethod
183205
def variables(self) -> dict[str, V]:
184206
...
185207

186-
def entities_by_singular(self) -> dict[str, E]:
208+
def entities_by_singular(self) -> dict[str, CoreEntity]:
187209
...
188210

189211
def entities_plural(self) -> Iterable[str]:
@@ -198,7 +220,7 @@ def get_variable(
198220

199221
def instantiate_entities(
200222
self,
201-
) -> dict[str, GroupPopulation[E]]:
223+
) -> Populations:
202224
...
203225

204226

@@ -209,7 +231,7 @@ class Variable(t.Variable, Protocol[T]):
209231
calculate_output: Callable[[Simulation, str, str], t.Array[T]] | None
210232
definition_period: str
211233
end: str
212-
name: str
234+
name: VariableName
213235

214236
def default_array(self, __array_size: int) -> t.Array[T]:
215237
...

0 commit comments

Comments
 (0)