Skip to content

Commit 9d80743

Browse files
Resolves #81 (#84)
* chore(issue-81): Create a base class for saving and loading #81 * refactor: Adjusted to reviewer's suggestions . renamed black_list to block_list . renamed pickle_block_list to attr_block_list . removed SavableAttr class . left only get_full_persistence_path of the get_full_x_persistence_path methods * refactor: Adjusted to some of reviewer's suggestions #84 * refactor: More alterations #84 * fix: Replaced attr with key * refactor: Adjusted to some of reviewer's suggestions * refactor: Changed class and file name to Persistence #84 * feat: Create class PersistencePickle .Persistence class is now abstract .Furthermore, pickle related behaviour was moved to PersistencePickle * feat: Generalize some of Persistence's methods #84 * refactor: Remove unused import #84 * refactor: Rename _save to _simple_save #84 * test(issue-81): Add unit tests for Persistence #84 * test(issue-81): Add unit tests for PersistencePickle #84 * test(issue-81): Mock simple_save and load #84 * refactor(issue-81): Adjusted to reviewer's suggestions #84 * refactor(issue-81): Adjusted to lint's suggestions #84 * refactor(issue-81): Tests now in "given when then" format #84
1 parent de65870 commit 9d80743

File tree

5 files changed

+361
-0
lines changed

5 files changed

+361
-0
lines changed

tests/units/base/__init__.py

Whitespace-only changes.

tests/units/base/test_persistence.py

Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
import os
2+
import unittest
3+
from abc import ABCMeta
4+
5+
from urnai.base.persistence import Persistence
6+
7+
8+
class FakePersistence(Persistence):
    """
    Minimal concrete ``Persistence`` used to exercise the base class.

    Every abstract method is overridden with a thin delegation to the base
    implementation. This makes the class instantiable without assigning to
    ``Persistence.__abstractmethods__``, which would alter the abstract base
    itself for every other user of it in the same interpreter.
    """

    def __init__(self, threaded_saving=False):
        super().__init__(threaded_saving)

    def _simple_save(self, savepath):
        """Run the base class body (a no-op) and return its result."""
        return super()._simple_save(savepath)

    def load(self, savepath):
        """Run the base class body (a no-op) and return its result."""
        return super().load(savepath)

    def _get_dict(self):
        """Run the base class body (a no-op) and return its result."""
        return super()._get_dict()

    def _get_attributes(self):
        """Run the base class body (a no-op) and return its result."""
        return super()._get_attributes()
13+
14+
class TestPersistence(unittest.TestCase):
    """Unit tests for the behaviour implemented by ``Persistence`` itself."""

    def _assert_save_processes_ran(self, persistence, expected_count):
        """
        Check that ``persistence`` started ``expected_count`` save processes
        and that each of them ran to a clean exit.

        The processes are joined instead of being probed with ``is_alive()``:
        the save target used in these tests returns immediately, so whether
        the child is still alive when the assertion runs depends on timing.
        """
        self.assertEqual(len(persistence.processes), expected_count)

        for process in persistence.processes:
            # pid is only set once the process has been started.
            self.assertIsNotNone(process.pid)
            # Bounded wait so a stuck child fails the test instead of
            # hanging the run; exitcode stays None if it did not finish.
            process.join(timeout=10)
            self.assertEqual(process.exitcode, 0)

    def test_abstract_methods(self):
        """The abstract method bodies of ``Persistence`` return ``None``."""

        # GIVEN
        fake_persistence = FakePersistence()

        # WHEN
        _simple_save_return = fake_persistence._simple_save(".")
        load_return = fake_persistence.load(".")
        _get_dict_return = fake_persistence._get_dict()
        _get_attributes_return = fake_persistence._get_attributes()

        # THEN
        self.assertIsInstance(Persistence, ABCMeta)
        self.assertIsNone(_simple_save_return)
        self.assertIsNone(load_return)
        self.assertIsNone(_get_dict_return)
        self.assertIsNone(_get_attributes_return)

    def test_get_default_save_stamp(self):
        """The save stamp is the class name followed by an underscore."""

        # GIVEN
        fake_persistence = FakePersistence()

        # WHEN
        _get_def_save_stamp_return = fake_persistence._get_default_save_stamp()

        # THEN
        self.assertEqual(_get_def_save_stamp_return,
                         fake_persistence.__class__.__name__ + '_')

    def test_get_full_persistance_path(self):
        """The full path is the given directory plus the save stamp."""

        # GIVEN
        fake_persistence = FakePersistence()
        persist_path = "test"

        # WHEN
        get_full_pers_path_return = fake_persistence.get_full_persistance_path(
            persist_path)

        # THEN
        self.assertEqual(get_full_pers_path_return,
                         persist_path + os.path.sep +
                         fake_persistence._get_default_save_stamp())

    def test_save(self):
        """``save`` returns ``None`` and only spawns a process when asked."""

        # GIVEN
        fake_persistence = FakePersistence()
        fake_persistence_threaded = FakePersistence(threaded_saving=True)
        persist_path = "test"

        # WHEN
        save_return = fake_persistence.save(persist_path)
        fake_persistence_threaded.save(persist_path)

        # THEN
        self.assertIsNone(save_return)
        # The non-threaded path must not start any process.
        self.assertEqual(fake_persistence.processes, [])
        self._assert_save_processes_ran(fake_persistence_threaded, 1)

    def test_threaded_save(self):
        """A threaded save starts exactly one process, which exits cleanly."""

        # GIVEN
        fake_persistence = FakePersistence(threaded_saving=True)
        persist_path = "test"

        # WHEN
        fake_persistence.save(persist_path)

        # THEN
        self._assert_save_processes_ran(fake_persistence, 1)

    def test_restore_attributes(self):
        """Blocked keys are skipped; all other keys become attributes."""

        # GIVEN
        fake_persistence = FakePersistence()
        dict_to_restore = {"TestAttribute1": 314, "TestAttribute2": "string"}

        # WHEN
        fake_persistence.attr_block_list = ["TestAttribute2"]
        fake_persistence._restore_attributes(dict_to_restore)

        # THEN
        self.assertTrue(hasattr(fake_persistence, "TestAttribute1"))
        self.assertFalse(hasattr(fake_persistence, "TestAttribute2"))
        self.assertEqual(fake_persistence.TestAttribute1, 314)
Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
import unittest
2+
from unittest.mock import patch
3+
4+
from urnai.base.persistence_pickle import PersistencePickle
5+
6+
7+
class FakePersistencePickle(PersistencePickle):
    """Plain ``PersistencePickle`` subclass used as the object under test."""

    def __init__(self, threaded_saving=False):
        # Nothing is added here; the flag is handed straight to the parent.
        super().__init__(threaded_saving=threaded_saving)
10+
11+
class TestPersistence(unittest.TestCase):
    """Unit tests for ``PersistencePickle``."""

    @patch('urnai.base.persistence_pickle.PersistencePickle._simple_save')
    def test_simple_save(self, mock_simple_save):
        """
        ``_simple_save`` is replaced by a mock, so nothing is written to
        disk. The test checks that the call made on the instance reaches the
        patched method with the given path and that its result is returned.
        """

        # GIVEN
        fake_persistence_pickle = FakePersistencePickle()
        persist_path = "test_simple_save"
        mock_simple_save.return_value = "return_value"

        # WHEN
        simple_save_return = fake_persistence_pickle._simple_save(persist_path)

        # THEN
        self.assertEqual(simple_save_return, "return_value")
        # The mock is a plain class attribute, not a bound method, so it is
        # called without the instance.
        mock_simple_save.assert_called_once_with(persist_path)

    @patch('urnai.base.persistence_pickle.PersistencePickle.load')
    def test_load(self, mock_load):
        """
        ``load`` is replaced by a mock, so nothing is read from disk. The
        test checks that the call made on the instance reaches the patched
        method with the given path and that its result is returned.
        """

        # GIVEN
        fake_persistence_pickle = FakePersistencePickle()
        persist_path = "test_load"
        mock_load.return_value = "return_value"

        # WHEN
        load_return = fake_persistence_pickle.load(persist_path)

        # THEN
        self.assertEqual(load_return, "return_value")
        mock_load.assert_called_once_with(persist_path)

    def test_get_attributes(self):
        """Only ``threaded_saving`` is reported for a fresh instance."""

        # GIVEN
        fake_persistence_pickle = FakePersistencePickle()

        # WHEN
        return_list = fake_persistence_pickle._get_attributes()

        # THEN
        self.assertEqual(return_list, ['threaded_saving'])

    def test_get_dict(self):
        """The dict to persist maps ``threaded_saving`` to its value."""

        # GIVEN
        fake_persistence_pickle = FakePersistencePickle()

        # WHEN
        return_dict = fake_persistence_pickle._get_dict()

        # THEN
        self.assertEqual(return_dict, {"threaded_saving": False})

urnai/base/persistence.py

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
import os
2+
from abc import ABC, abstractmethod
3+
from multiprocessing import Process
4+
5+
6+
class Persistence(ABC):
    """
    Abstract base for objects whose state can be saved to and loaded from
    disk.

    Subclasses provide the storage format by implementing ``_simple_save``,
    ``load``, ``_get_dict`` and ``_get_attributes``. This class supplies the
    file naming, the choice between an in-process and a background save, and
    the logic that copies loaded values back onto the instance.
    """

    def __init__(self, threaded_saving=False):
        """
        Args:
            threaded_saving: when truthy, ``save`` runs ``_simple_save`` in a
                separate process instead of in the calling process.
        """
        self.threaded_saving = threaded_saving
        # Keys that _restore_attributes must not set on the instance.
        self.attr_block_list = []
        # Every Process started by _threaded_save, in start order.
        self.processes = []

    def _get_default_save_stamp(self):
        """
        This method returns the default
        file name that should be used while
        persisting the object: the class name
        followed by an underscore.
        """
        return self.__class__.__name__ + '_'

    def get_full_persistance_path(self, persist_path):
        """
        This method returns the default persistance path: the default save
        stamp placed inside the directory ``persist_path``.

        ``os.path.join`` is used instead of string concatenation so that an
        empty ``persist_path`` yields a path relative to the current
        directory rather than one at the filesystem root, and a trailing
        separator does not produce a doubled one.
        """
        return os.path.join(persist_path, self._get_default_save_stamp())

    def save(self, savepath):
        """
        Save the object under ``savepath``, in a separate process when
        ``threaded_saving`` is set and in the calling process otherwise.
        """
        if self.threaded_saving:
            self._threaded_save(savepath)
        else:
            self._simple_save(savepath)

    def _threaded_save(self, savepath):
        """
        This method runs ``_simple_save`` in a separate process
        (a ``multiprocessing.Process``, not a thread).

        The process is recorded in ``self.processes`` and started; it is
        not joined here.
        """
        new_process = Process(target=self._simple_save, args=(savepath,))
        self.processes.append(new_process)
        new_process.start()

    @abstractmethod
    def _simple_save(self, savepath):
        """
        This method handles logic related
        to non-threaded saving in the child class
        """
        ...

    @abstractmethod
    def load(self, savepath):
        """
        This method handles the logic that reads
        a previously saved state back in the child class
        """
        ...

    @abstractmethod
    def _get_dict(self):
        """
        This method returns a dict with
        all the attributes that will be
        saved
        """
        ...

    @abstractmethod
    def _get_attributes(self):
        """
        This method returns all of the
        attribute names that can be saved
        except those in blocklist
        """
        ...

    def _restore_attributes(self, dict_to_restore):
        """
        Set every item of ``dict_to_restore`` as an attribute on the
        instance, skipping keys listed in ``self.attr_block_list``.
        """
        for key, value in dict_to_restore.items():
            if key not in self.attr_block_list:
                setattr(self, key, value)

urnai/base/persistence_pickle.py

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
import os
2+
import pickle
3+
import tempfile
4+
5+
from urnai.base.persistence import Persistence
6+
7+
8+
class PersistencePickle(Persistence):
    """
    Persistence implementation that stores an object's attributes with
    pickle.

    The picklable attributes are gathered into a dict, which is written to
    the file given by ``get_full_persistance_path``. ``load`` reads that dict
    back and restores its items as attributes of the instance.
    """

    def __init__(self, threaded_saving=False):
        super().__init__(threaded_saving)

    def _simple_save(self, persist_path):
        """
        This method saves our instance
        using pickle.

        The dict returned by ``_get_dict`` (the picklable,
        non-blocked attributes) is dumped to the full
        persistance path. The directory of that path is
        created first when it does not exist.
        """
        path = self.get_full_persistance_path(persist_path)

        # dirname is '' for a bare file name, and os.makedirs('') raises.
        directory = os.path.dirname(path)
        if directory:
            os.makedirs(directory, exist_ok=True)

        with open(path, 'wb') as pickle_out:
            pickle.dump(self._get_dict(), pickle_out)

    def load(self, persist_path):
        """
        This method loads the attribute dict
        saved by pickle and restores it onto
        the instance.

        Nothing happens when the file is missing or empty.
        """
        pickle_path = self.get_full_persistance_path(persist_path)

        if os.path.isfile(pickle_path) and os.path.getsize(pickle_path) > 0:
            with open(pickle_path, 'rb') as pickle_in:
                # NOTE(security): unpickling can execute arbitrary code, so
                # persist_path must only ever point at trusted files.
                pickle_dict = pickle.load(pickle_in)

            self._restore_attributes(pickle_dict)

    @staticmethod
    def _is_pickleable(value):
        """
        Return True when ``value`` can be pickled, False when pickling fails
        with one of the known "this object is not picklable" errors.

        Any other error of the handled types is re-raised unchanged, keeping
        its original message and traceback.
        """
        try:
            # In-memory probe; the resulting bytes are discarded.
            pickle.dumps(value)

        except pickle.PicklingError:
            return False

        except TypeError as type_error:
            message = str(type_error)
            if ("can't pickle" not in message and
                    'cannot pickle' not in message):
                raise
            return False

        except NotImplementedError as notimpl_error:
            if (str(notimpl_error) !=
                    'numpy() is only available when eager execution is enabled.'):
                raise
            return False

        except AttributeError as attr_error:
            message = str(attr_error)
            if ("Can't pickle" not in message and
                    "object has no attribute '__getstate__'" not in message):
                raise
            return False

        except ValueError as value_error:
            if 'ctypes objects' not in str(value_error):
                raise
            return False

        return True

    def _get_attributes(self):
        """
        This method returns a list of pickeable attributes.
        If you wish to block one particular pickleable attribute, put it
        in self.attr_block_list as a string.
        """
        if not hasattr(self, 'attr_block_list') or self.attr_block_list is None:
            self.attr_block_list = []

        # The bookkeeping attributes of Persistence are never saved.
        blocked = set(self.attr_block_list) | {'attr_block_list', 'processes'}

        pickleable_list = []

        for attr in dir(self):
            # Skip dunder names, blocked names and the internals that ABCMeta
            # stores on the class (their names start with '_abc_'). The name
            # checks run first so the value is only read when needed.
            if (attr.startswith('__')
                    or attr.startswith('_abc_')
                    or attr in blocked):
                continue

            value = getattr(self, attr)

            # Methods and other callables are not part of the saved state.
            if callable(value):
                continue

            if self._is_pickleable(value):
                pickleable_list.append(attr)

        return pickleable_list

    def _get_dict(self):
        """
        This method returns a dict that maps each name from
        ``_get_attributes`` to its current value.
        """
        return {attr: getattr(self, attr) for attr in self._get_attributes()}

0 commit comments

Comments
 (0)