Skip to content

Commit 32105f2

Browse files
committed
feat(tests): fix corner cases
1 parent d0cec25 commit 32105f2

File tree

13 files changed

+752
-94
lines changed

13 files changed

+752
-94
lines changed

openfisca_core/indexed_enums/enum_array.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,11 @@ class EnumArray(numpy.ndarray):
2121
# Subclassing ndarray is a little tricky.
2222
# To read more about the two following methods, see:
2323
# https://docs.scipy.org/doc/numpy-1.13.0/user/basics.subclassing.html#slightly-more-realistic-example-attribute-added-to-existing-array.
24+
25+
@property
26+
def dtype(self) -> t.ArrayEnum:
27+
return super().dtype
28+
2429
def __new__(
2530
cls,
2631
input_array: t.Array[t.ArrayEnum],
@@ -84,7 +89,7 @@ def decode(self) -> numpy.object_:
8489
list(self.possible_values),
8590
)
8691

87-
def decode_to_str(self) -> numpy.str_:
92+
def decode_to_str(self) -> t.Array[t.ArrayStr]:
8893
"""
8994
Return the array of string identifiers corresponding to self.
9095

openfisca_core/indexed_enums/types.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
1-
from openfisca_core.types import Array, ArrayBytes, ArrayEnum
1+
from openfisca_core.types import Array, ArrayBytes, ArrayEnum, ArrayStr
22

3-
__all__ = ["Array", "ArrayBytes", "ArrayEnum"]
3+
__all__ = ["Array", "ArrayBytes", "ArrayEnum", "ArrayStr"]

openfisca_core/tools/test_runner.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -35,9 +35,9 @@ class Options(TypedDict, total=False):
3535

3636
@dataclasses.dataclass(frozen=True)
3737
class ErrorMargin:
38-
__root__: Dict[Union[str, Literal["default"]], Optional[float]]
38+
__root__: Dict[Union[str, Literal["default"]], float]
3939

40-
def __getitem__(self, key: str) -> Optional[float]:
40+
def __getitem__(self, key: str) -> float:
4141
if key in self.__root__:
4242
return self.__root__[key]
4343

@@ -66,7 +66,7 @@ def build_test(params: Dict[str, Any]) -> Test:
6666
value = params.get(key)
6767

6868
if value is None:
69-
value = {"default": None}
69+
value = {"default": 0}
7070

7171
elif isinstance(value, (float, int, str)):
7272
value = {"default": float(value)}
@@ -326,9 +326,9 @@ def check_variable(
326326
return assert_near(
327327
actual_value,
328328
expected_value,
329-
self.test.absolute_error_margin[variable_name],
330-
f"{variable_name}@{period}: ",
331-
self.test.relative_error_margin[variable_name],
329+
message=f"{variable_name}@{period}: ",
330+
absolute_error_margin=self.test.absolute_error_margin[variable_name],
331+
relative_error_margin=self.test.relative_error_margin[variable_name],
332332
)
333333

334334
def should_ignore_variable(self, variable_name: str):

openfisca_core/variables/config.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
"is_period_size_independent": False,
2828
},
2929
str: {
30-
"dtype": t.ArrayBytes,
30+
"dtype": t.ArrayStr,
3131
"default": "",
3232
"json_type": "string",
3333
"formatted_value_type": "String",

openfisca_core/variables/types.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
ArrayEnum,
66
ArrayFloat,
77
ArrayInt,
8+
ArrayObject,
9+
ArrayStr,
810
)
911

1012
__any__ = [
@@ -14,4 +16,6 @@
1416
ArrayEnum,
1517
ArrayFloat,
1618
ArrayInt,
19+
ArrayObject,
20+
ArrayStr,
1721
]

openfisca_core/variables/variable.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -469,7 +469,10 @@ def check_set_value(self, value):
469469
return value
470470

471471
def default_array(self, array_size):
472-
array = numpy.empty(array_size, dtype=self.dtype)
472+
if numpy.issubdtype(self.dtype, numpy.datetime64):
473+
array = numpy.empty(array_size, dtype="datetime64[D]")
474+
else:
475+
array = numpy.empty(array_size, dtype=self.dtype)
473476
if self.value_type == Enum:
474477
array.fill(self.default_value.index)
475478
return EnumArray(array, self.possible_values)

openfisca_tasks/lint.mk

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@ check-style: $(shell git ls-files "*.py" "*.pyi")
2020
lint-doc: \
2121
lint-doc-commons \
2222
lint-doc-entities \
23-
lint-doc-types \
2423
;
2524

2625
## Run linters to check for syntax and style errors in the doc.
@@ -32,14 +31,22 @@ lint-doc-%:
3231
@## able to integrate documentation improvements progresively.
3332
@##
3433
@$(call print_help,$(subst $*,%,$@:))
35-
@flake8 --select=D101,D102,D103,DAR openfisca_core/$* openfisca_test
36-
@pylint openfisca_core/$* openfisca_test
34+
@flake8 \
35+
--select=D101,D102,D103,DAR \
36+
openfisca_core/$* \
37+
openfisca_core/types.py \
38+
openfisca_test \
39+
stubs
40+
@pylint openfisca_core/$* \
41+
openfisca_core/$* \
42+
openfisca_core/types.py \
43+
openfisca_test \
44+
stubs
3745
@$(call print_pass,$@:)
3846

3947
## Run static type checkers for type errors.
4048
check-types:
4149
@$(call print_help,$@:)
42-
@command -v pyright && pyright
4350
@mypy \
4451
openfisca_core/commons \
4552
openfisca_core/entities \

openfisca_test/_assertions.py

Lines changed: 138 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,138 @@
1+
from __future__ import annotations
2+
3+
from typing import NoReturn
4+
5+
import numpy
6+
7+
from . import types as t
8+
from ._parsers import parse
9+
10+
11+
def assert_near(
12+
actual: object,
13+
expected: object,
14+
/,
15+
message: str = "",
16+
*,
17+
absolute_error_margin: float = 0,
18+
relative_error_margin: float = 0,
19+
) -> None | NoReturn:
20+
"""Assert that two values are near each other.
21+
22+
Args:
23+
actual: Value returned by the test.
24+
expected: Value that the test should return to pass.
25+
message: Error message to be displayed if the test fails.
26+
absolute_error_margin: Absolute error margin authorized.
27+
relative_error_margin: Relative error margin authorized.
28+
29+
Returns:
30+
None
31+
32+
Raises:
33+
ValueError: If the error margin is negative.
34+
AssertionError: If the two values are not near each other.
35+
NotImplementedError: If the data type is not supported.
36+
37+
Note:
38+
This function cannot be used to assert near periods.
39+
40+
Examples:
41+
>>> actual = numpy.array([1.0, 2.0, 3.0])
42+
>>> expected = numpy.array([1.0, 2.0, 2.9])
43+
>>> assert_near(actual, expected, absolute_error_margin=0.2)
44+
45+
>>> expected = numpy.array([1.0, 2.0, 2.95])
46+
>>> assert_near(actual, expected, absolute_error_margin=0.1)
47+
48+
>>> assert_near(actual, expected, relative_error_margin=0.05)
49+
50+
>>> assert_near(True, [True])
51+
52+
"""
53+
54+
# Validate absolute_error_margin.
55+
if absolute_error_margin < 0:
56+
raise ValueError("The absolute error margin must be positive.")
57+
58+
# Validate relative_error_margin.
59+
if relative_error_margin < 0:
60+
raise ValueError("The relative error margin must be positive.")
61+
62+
# Parse the actual value.
63+
actual = parse(actual)
64+
65+
# Parse the expected value.
66+
expected = parse(expected)
67+
68+
# Get the common data type.
69+
try:
70+
common_dtype = numpy.promote_types(actual.dtype, expected.dtype)
71+
72+
except TypeError:
73+
raise AssertionError(
74+
f"Incompatible types: {actual.dtype} and {expected.dtype}."
75+
)
76+
77+
if numpy.issubdtype(common_dtype, numpy.datetime64):
78+
actual = actual.astype(t.ArrayDate)
79+
expected = expected.astype(t.ArrayDate)
80+
assert (actual == expected).all(), f"{message}{actual} differs from {expected}."
81+
return None
82+
83+
if numpy.issubdtype(common_dtype, numpy.bool_):
84+
actual = actual.astype(t.ArrayBool)
85+
expected = expected.astype(t.ArrayBool)
86+
assert (actual == expected).all(), f"{message}{actual} differs from {expected}."
87+
return None
88+
89+
if numpy.issubdtype(common_dtype, numpy.bytes_):
90+
actual = actual.astype(t.ArrayBytes)
91+
expected = expected.astype(t.ArrayBytes)
92+
assert (actual == expected).all(), f"{message}{actual} differs from {expected}."
93+
return None
94+
95+
if numpy.issubdtype(common_dtype, numpy.str_):
96+
actual = actual.astype(t.ArrayStr)
97+
expected = expected.astype(t.ArrayStr)
98+
assert (actual == expected).all(), f"{message}{actual} differs from {expected}."
99+
return None
100+
101+
if numpy.issubdtype(common_dtype, numpy.int32) or numpy.issubdtype(
102+
common_dtype, numpy.int64
103+
):
104+
actual = actual.astype(t.ArrayInt)
105+
expected = expected.astype(t.ArrayInt)
106+
diff = abs(expected - actual)
107+
assert (
108+
diff == 0
109+
).all(), f"{message}{actual} differs from {expected} by {diff}."
110+
return None
111+
112+
if numpy.issubdtype(common_dtype, numpy.float32) or numpy.issubdtype(
113+
common_dtype, numpy.float64
114+
):
115+
actual = actual.astype(t.ArrayFloat)
116+
expected = expected.astype(t.ArrayFloat)
117+
diff = abs(expected - actual)
118+
if absolute_error_margin > 0:
119+
assert (diff <= absolute_error_margin).all(), (
120+
f"{message}{actual} differs from {expected} with an absolute margin "
121+
f"{diff} > {absolute_error_margin}"
122+
)
123+
return None
124+
if relative_error_margin > 0:
125+
assert (diff <= abs(relative_error_margin * expected)).all(), (
126+
f"{message}{actual} differs from {expected} with a relative margin "
127+
f"{diff} > {abs(relative_error_margin * expected)}"
128+
)
129+
return None
130+
assert (
131+
actual == expected
132+
).all(), f"{message}{actual} differs from {expected} by {diff}."
133+
return None
134+
135+
raise NotImplementedError
136+
137+
138+
__all__ = ["assert_near"]

0 commit comments

Comments
 (0)