Skip to content

Commit e0a8f66

Browse files
author
Mauko Quiroga
committed
Add an explicit Cache API
1 parent 49e6083 commit e0a8f66

File tree

8 files changed

+319
-407
lines changed

8 files changed

+319
-407
lines changed

openfisca_core/data_storage.py

Lines changed: 22 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import abc
22
import os
33
import shutil
4-
from typing import Any, Dict, KeysView, Optional
4+
from typing import Any, Dict, KeysView, Optional, Union
55

66
import numpy
77

@@ -51,11 +51,11 @@ def delete(self, period: Optional[Period] = None) -> None:
5151
...
5252

5353
@abc.abstractmethod
54-
def known_periods(self) -> KeysView[Period]:
54+
def memory_usage(self) -> Dict[str, int]:
5555
...
5656

5757
@abc.abstractmethod
58-
def memory_usage(self) -> Dict[str, int]:
58+
def known_periods(self) -> KeysView[Period]:
5959
...
6060

6161

@@ -74,7 +74,7 @@ def cast_period(self, period: Optional[Period], eternal: bool) -> Period:
7474

7575

7676
class MemoryStorage(StorageLike):
77-
"""Responsible for storing and retrieving values in memory."""
77+
"""Low-level class responsible for storing and retrieving values in memory."""
7878

7979
def get(self, state: Dict[Period, Any], key: Period) -> Any:
8080
return state.get(key)
@@ -90,7 +90,6 @@ def delete_all(self, state: StateType) -> dict:
9090
state.clear()
9191
return state
9292

93-
# TODO: test
9493
def memory_usage(self, state: StateType) -> Dict[str, int]:
9594
if not state:
9695
return {
@@ -110,14 +109,17 @@ def memory_usage(self, state: StateType) -> Dict[str, int]:
110109

111110

112111
class DiskStorage(StorageLike):
113-
"""Responsible for storing and retrieving values on disk."""
112+
"""Low-level class responsible for storing and retrieving values on disk."""
114113
directory: str
115114
preserve: bool
116115

117116
def __init__(self, directory: str, preserve: bool) -> None:
    """
    Keep the storage settings and make sure the target directory exists.

    :param directory: Path of the directory assigned to ``self.directory``.
    :param preserve: Flag assigned to ``self.preserve``.
    :raises OSError: If ``directory`` cannot be created (for example when
        the path already exists but is not a directory).
    """
    self.directory = directory
    self.preserve = preserve

    # `exist_ok = True` already turns the call into a no-op when the
    # directory is there, so a prior `os.path.isdir` check is redundant and
    # only opens a check-then-act race window.
    os.makedirs(self.directory, exist_ok = True)
122+
121123
def get(self, state: Dict[Period, Any], key: Period) -> Any:
122124
file = state.get(key)
123125

@@ -153,7 +155,6 @@ def delete_all(self, state: StateType) -> dict:
153155
state.clear()
154156
return state
155157

156-
# TODO: test
157158
def memory_usage(self, state: StateType) -> Dict[str, int]:
158159
if not state:
159160
return {
@@ -203,21 +204,26 @@ def __del__(self) -> None:
203204
shutil.rmtree(parent_dir)
204205

205206

206-
class InMemoryStorage(CachingLike, SupportsPeriodCasting):
207+
StorageType = Union[MemoryStorage, DiskStorage]
208+
209+
210+
class Cache(CachingLike, SupportsPeriodCasting):
207211
"""
208-
Low-level class responsible for storing and retrieving calculated vectors in memory.
212+
Explicit Cache API responsible for:
209213
210-
TODO: separate concerns between the caching API and the storing API.
214+
* keeping cache state
215+
* reading from storages
216+
* writing to storages
211217
"""
212218

213219
state: StateType
220+
storage: StorageType
214221
is_eternal: bool
215-
storage: MemoryStorage
216222

217-
def __init__(self, is_eternal: bool = False) -> None:
223+
def __init__(self, storage: StorageType, is_eternal: bool = False) -> None:
218224
self.state = {}
225+
self.storage = storage
219226
self.is_eternal = is_eternal
220-
self.storage = MemoryStorage()
221227

222228
def get(self, period: Period) -> Any:
223229
casted: Period = self.cast_period(period, self.is_eternal)
@@ -235,75 +241,20 @@ def delete(self, period: Optional[Period] = None) -> None:
235241
casted: Period = self.cast_period(period, self.is_eternal)
236242
self.state = self.storage.delete(self.state, casted)
237243

238-
# TODO: test
239-
def known_periods(self) -> KeysView[Period]:
240-
return self.state.keys()
241-
242-
# TODO: test
243244
def memory_usage(self) -> Dict[str, int]:
244245
return self.storage.memory_usage(self.state)
245246

246-
def get_known_periods(self) -> KeysView[Period]:
247-
raise ValueError("TODO: add a deprecation warning")
248-
249-
# TODO: decide what to do with this.
250-
def get_memory_usage(self) -> Dict[str, int]:
251-
raise ValueError("TODO: add a deprecation warning")
252-
253-
254-
class OnDiskStorage(CachingLike, SupportsPeriodCasting):
255-
"""
256-
Low-level class responsible for storing and retrieving calculated vectors on disk.
257-
258-
TODO: separate concerns between the caching API and the storing API.
259-
"""
260-
261-
state: StateType
262-
is_eternal: bool
263-
storage: DiskStorage
264-
265-
def __init__(
266-
self,
267-
storage_dir: str,
268-
is_eternal: bool = False,
269-
preserve_storage_dir: bool = False,
270-
) -> None:
271-
self.state = {}
272-
self.is_eternal = is_eternal
273-
self.storage = DiskStorage(storage_dir, preserve_storage_dir)
274-
275-
def get(self, period: Period) -> Any:
276-
casted: Period = self.cast_period(period, self.is_eternal)
277-
return self.storage.get(self.state, casted)
278-
279-
def put(self, value: Any, period: Period) -> None:
280-
casted: Period = self.cast_period(period, self.is_eternal)
281-
self.state = self.storage.put(self.state, casted, value)
282-
283-
def delete(self, period: Optional[Period] = None) -> None:
284-
if period is None:
285-
self.state = self.storage.delete_all(self.state)
286-
return
287-
288-
casted: Period = self.cast_period(period, self.is_eternal)
289-
self.state = self.storage.delete(self.state, casted)
290-
291-
# TODO: test
292247
def known_periods(self) -> KeysView[Period]:
293248
return self.state.keys()
294249

295-
# TODO: test
296-
def memory_usage(self) -> Dict[str, int]:
297-
return self.storage.memory_usage(self.state)
298-
250+
# TODO: test
299251
def get_known_periods(self) -> KeysView[Period]:
300252
raise ValueError("TODO: add a deprecation warning")
301253

254+
# TODO: test
302255
def get_memory_usage(self) -> Dict[str, int]:
303256
raise ValueError("TODO: add a deprecation warning")
304257

258+
# TODO: test
305259
def restore(self, state: StateType) -> StateType:
306260
raise ValueError("TODO: add a deprecation warning")
307-
308-
def __del__(self) -> None:
309-
raise ValueError("TODO: add a deprecation warning")

openfisca_core/holders.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
from openfisca_core import periods
1111
from openfisca_core.commons import empty_clone
12-
from openfisca_core.data_storage import InMemoryStorage, OnDiskStorage
12+
from openfisca_core.data_storage import Cache, DiskStorage, MemoryStorage
1313
from openfisca_core.errors import PeriodMismatchError
1414
from openfisca_core.indexed_enums import Enum
1515
from openfisca_core.periods import MONTH, YEAR, ETERNITY
@@ -27,7 +27,7 @@ def __init__(self, variable, population):
2727
self.population = population
2828
self.variable = variable
2929
self.simulation = population.simulation
30-
self._memory_storage = InMemoryStorage(is_eternal = (self.variable.definition_period == ETERNITY))
30+
self._memory_storage = Cache(MemoryStorage(), is_eternal = (self.variable.definition_period == ETERNITY))
3131

3232
# By default, do not activate on-disk storage, or variable dropping
3333
self._disk_storage = None
@@ -60,13 +60,13 @@ def create_disk_storage(self, directory = None, preserve = False):
6060
if directory is None:
6161
directory = self.simulation.data_storage_dir
6262
storage_dir = os.path.join(directory, self.variable.name)
63+
6364
if not os.path.isdir(storage_dir):
6465
os.mkdir(storage_dir)
65-
return OnDiskStorage(
66-
storage_dir,
67-
is_eternal = (self.variable.definition_period == ETERNITY),
68-
preserve_storage_dir = preserve
69-
)
66+
67+
storage = DiskStorage(storage_dir, preserve)
68+
69+
return Cache(storage, is_eternal = (self.variable.definition_period == ETERNITY))
7070

7171
def delete_arrays(self, period = None):
7272
"""

openfisca_core/tools/simulation_dumper.py

Lines changed: 7 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import numpy as np
77

88
from openfisca_core.simulations import Simulation
9-
from openfisca_core.data_storage import OnDiskStorage
9+
from openfisca_core.data_storage import Cache, DiskStorage
1010
from openfisca_core.periods import ETERNITY
1111

1212

@@ -106,18 +106,12 @@ def _restore_entity(population, directory):
106106

107107
def _restore_holder(simulation, variable, directory):
108108
storage_dir = os.path.join(directory, variable)
109-
is_variable_eternal = simulation.tax_benefit_system.get_variable(variable).definition_period == ETERNITY
110-
disk_storage = OnDiskStorage(
111-
storage_dir,
112-
is_eternal = is_variable_eternal,
113-
preserve_storage_dir = True
114-
)
115-
116-
# TODO: decide whether to turn this variable public
117-
disk_storage.state = disk_storage.storage.restore(disk_storage.state)
118-
109+
is_eternal = simulation.tax_benefit_system.get_variable(variable).definition_period == ETERNITY
110+
storage = DiskStorage(storage_dir, preserve = True)
111+
cache = Cache(storage, is_eternal = is_eternal)
112+
cache.state = storage.restore(cache.state)
119113
holder = simulation.get_holder(variable)
120114

121-
for period in disk_storage.known_periods():
122-
value = disk_storage.get(period)
115+
for period in cache.known_periods():
116+
value = cache.get(period)
123117
holder.put_in_cache(value, period)

tests/core/data_storage/test_cache.py

Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
import numpy
2+
3+
from openfisca_core import periods
4+
from openfisca_core.data_storage import Cache, DiskStorage
5+
6+
import pytest
7+
8+
9+
@pytest.fixture
def storage(tmp_path):
    """
    Provide a non-preserving `DiskStorage` rooted in a per-test directory.

    pytest's built-in `tmp_path` fixture is used instead of a hard-coded
    "/tmp/openfisca" so that tests do not share on-disk state with each other
    (or with previous runs) and do not depend on a POSIX-only location.
    """
    # `DiskStorage` annotates `directory` as `str`, while `tmp_path` is a
    # `pathlib.Path`; convert it explicitly.
    return DiskStorage(directory = str(tmp_path), preserve = False)
12+
13+
14+
@pytest.fixture
def cache(storage):
    """Provide a `Cache` built on top of the `storage` fixture."""
    instance = Cache(storage = storage)
    return instance
17+
18+
19+
@pytest.fixture
def period():
    """Provide the period built from the string "2020"."""
    instant_2020 = periods.period("2020")
    return instant_2020
22+
23+
24+
@pytest.fixture
def eternal_period():
    """Provide the period built from `periods.ETERNITY`."""
    forever = periods.period(periods.ETERNITY)
    return forever
27+
28+
29+
@pytest.fixture
def array():
    """Provide a one-element numpy array used as the value to store."""
    values = [1]
    return numpy.array(values)
32+
33+
34+
def test_get(cache, storage, period, mocker):
    """`Cache.get` must call `storage.get` once with the empty state and the period."""
    mocker.patch.object(storage, "get")

    cache.get(period)

    expected_state = {}
    storage.get.assert_called_once_with(expected_state, period)
41+
42+
43+
def test_get_when_is_eternal(cache, storage, period, eternal_period, mocker):
    """An eternal cache must query the storage with the eternal period, whatever period is asked."""
    cache.is_eternal = True
    mocker.patch.object(storage, "get")

    cache.get(period)

    expected_state = {}
    storage.get.assert_called_once_with(expected_state, eternal_period)
52+
53+
54+
def test_put(cache, storage, period, array, mocker):
    """`Cache.put` must call `storage.put` once with the empty state, the period and the value."""
    mocker.patch.object(storage, "put")

    cache.put(array, period)

    expected_state = {}
    storage.put.assert_called_once_with(expected_state, period, array)
61+
62+
63+
def test_put_when_is_eternal(cache, storage, period, eternal_period, array, mocker):
    """An eternal cache must store the value under the eternal period, whatever period is given."""
    cache.is_eternal = True
    mocker.patch.object(storage, "put")

    cache.put(array, period)

    expected_state = {}
    storage.put.assert_called_once_with(expected_state, eternal_period, array)
72+
73+
74+
def test_delete(cache, storage, period, mocker):
    """`Cache.delete` with a period must call `storage.delete` once with the empty state and that period."""
    mocker.patch.object(storage, "delete")

    cache.delete(period)

    expected_state = {}
    storage.delete.assert_called_once_with(expected_state, period)
81+
82+
83+
def test_delete_when_period_is_not_specified(cache, storage, mocker):
    """`Cache.delete` without a period must call `storage.delete_all` once with the empty state."""
    mocker.patch.object(storage, "delete_all")

    cache.delete()

    storage.delete_all.assert_called_once_with({})
90+
91+
92+
def test_delete_when_is_eternal(cache, storage, period, eternal_period, mocker):
    """An eternal cache must delete under the eternal period, whatever period is given."""
    cache.is_eternal = True
    mocker.patch.object(storage, "delete")

    cache.delete(period)

    expected_state = {}
    storage.delete.assert_called_once_with(expected_state, eternal_period)
101+
102+
103+
def test_get_memory_usage(cache, storage, mocker):
    """`Cache.memory_usage` must call `storage.memory_usage` once with the empty state."""
    mocker.patch.object(storage, "memory_usage")

    cache.memory_usage()

    storage.memory_usage.assert_called_once_with({})
110+
111+
112+
def test_known_periods(cache, period, array):
    """A period that has been put in the cache must be listed by `known_periods`."""
    cache.put(array, period)

    known = cache.known_periods()

    assert period in known

0 commit comments

Comments
 (0)