Skip to content

Commit 98fed43

Browse files
committed
changed tests to reflect the changes made to the manager
1 parent d97edba commit 98fed43

File tree

2 files changed

+36
-36
lines changed

2 files changed

+36
-36
lines changed

StormDBManager.py

Lines changed: 27 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,7 @@ def __init__(self, db_path):
2121
"""
2222
self._logger = logging.getLogger(self.__class__.__name__)
2323

24-
# Open or create the database
25-
self._database = create_database(db_path)
24+
self.db_path = db_path
2625

2726
# The transactor is required when you have methods decorated with the @transact decorator
2827
# This field name must NOT be changed.
@@ -31,9 +30,16 @@ def __init__(self, db_path):
3130
# Create a DeferredLock that should be used by callers to schedule their call.
3231
self.db_lock = DeferredLock()
3332

34-
self._version = 1
33+
def initialize(self):
34+
"""
35+
Opens/creates the database and initializes the version.
36+
"""
37+
# Open or create the database
38+
self._database = create_database(self.db_path)
39+
self._version = 0
3540
self._retrieve_version()
3641

42+
3743
def _retrieve_version(self):
3844
"""
3945
Attempts to retrieve the current datbase version from the MyInfo table.
@@ -49,25 +55,18 @@ def on_error(failure):
4955
self._logger.exception(u"Failed to load database version: %s", failure.getTraceback())
5056

5157
# Schedule the query and add a callback and errback to the deferred.
52-
return self.fetch_one(u"SELECT value FROM MyInfo WHERE entry == 'version'").addCallbacks(on_result, on_error)
58+
return self.schedule_query(self.fetch_one, u"SELECT value FROM MyInfo WHERE entry == 'version'").addCallbacks(on_result, on_error)
5359

54-
def schedule_query(*args, **kwargs):
60+
def schedule_query(self, callable, *args, **kwargs):
5561
"""
5662
Utility function to schedule a query to be executed using the db_lock.
57-
:param args: The arguments of which the first is self and the second the function to be run.
58-
Any additional arguments will be passed as the function arguments.
59-
:param kwargs: Keyword arguments that are passed to the function
63+
:param callable: The database function that is to be executed.
64+
:param args: Any additional arguments that will be passed as the callable's arguments.
65+
:param kwargs: Keyword arguments that are passed to the callable function.
6066
:return: A deferred that fires with the result of the query.
6167
"""
62-
if len(args) < 2:
63-
if not args:
64-
raise TypeError("run() takes at least 2 arguments, none given.")
65-
raise TypeError("%s.run() takes at least 2 arguments, 1 given" % (
66-
args[0].__class__.__name__,))
67-
self, f = args[:2]
68-
args = args[2:]
6968

70-
return self.db_lock.run(f, *args, **kwargs)
69+
return self.db_lock.run(callable, *args, **kwargs)
7170

7271
@transact
7372
def execute_query(self, query, arguments=None):
@@ -109,18 +108,18 @@ def fetch_all(self, query, arguments=None):
109108
return connection.execute(query, arguments).get_all()
110109

111110
@transact
112-
def insert(self, table_name, **argv):
111+
def insert(self, table_name, **kwargs):
113112
"""
114113
Inserts data provided as keyword arguments into the table provided as an argument.
115114
:param table_name: The name of the table the data has to be inserted into.
116115
:param argv: A dictionary where the key represents the column and the value the value to be inserted.
117116
:return: A deferred that fires when the data has been inserted.
118117
"""
119118
connection = Connection(self._database)
120-
self._insert(connection, table_name, **argv)
119+
self._insert(connection, table_name, **kwargs)
121120
connection.close()
122121

123-
def _insert(self, connection, table_name, **argv):
122+
def _insert(self, connection, table_name, **kwargs):
124123
"""
125124
Utility function to insert data which is not decorated by the @transact to prevent
126125
a loop calling this function to create many threads.
@@ -130,14 +129,14 @@ def _insert(self, connection, table_name, **argv):
130129
:param argv: A dictionary where the key represents the column and the value the value to be inserted.
131130
:return: A deferred that fires when the data has been inserted.
132131
"""
133-
if len(argv) == 0: return
134-
if len(argv) == 1:
135-
sql = u'INSERT INTO %s (%s) VALUES (?);' % (table_name, argv.keys()[0])
132+
if len(kwargs) == 0: raise ValueError("No keyword arguments supplied.")
133+
if len(kwargs) == 1:
134+
sql = u'INSERT INTO %s (%s) VALUES (?);' % (table_name, kwargs.keys()[0])
136135
else:
137-
questions = '?,' * len(argv)
138-
sql = u'INSERT INTO %s %s VALUES (%s);' % (table_name, tuple(argv.keys()), questions[:-1])
136+
questions = ','.join(('?',)*len(kwargs))
137+
sql = u'INSERT INTO %s %s VALUES (%s);' % (table_name, tuple(kwargs.keys()), questions)
139138

140-
connection.execute(sql, argv.values(), noresult=True)
139+
connection.execute(sql, kwargs.values(), noresult=True)
141140

142141
@transact
143142
def insert_many(self, table_name, arg_list):
@@ -155,7 +154,7 @@ def insert_many(self, table_name, arg_list):
155154

156155
connection.close()
157156

158-
def delete(self, table_name, **argv):
157+
def delete(self, table_name, **kwargs):
159158
"""
160159
Utility function to delete from the database.
161160
:param table_name: the table name to delete from
@@ -167,7 +166,7 @@ def delete(self, table_name, **argv):
167166
"""
168167
sql = u'DELETE FROM %s WHERE ' % table_name
169168
arg = []
170-
for k, v in argv.iteritems():
169+
for k, v in kwargs.iteritems():
171170
if isinstance(v, tuple):
172171
sql += u'%s %s ? AND ' % (k, v[0])
173172
arg.append(v[1])
@@ -177,7 +176,7 @@ def delete(self, table_name, **argv):
177176
sql = sql[:-5] # Remove the last AND
178177
return self.execute_query(sql, arg)
179178

180-
def num_rows(self, table_name):
179+
def count(self, table_name):
181180
"""
182181
Utility function to get the number of rows of a table.
183182
:param table_name: The table name

tests/test_storm_db_manager.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ def setUp(self):
1919
# in-memory database do not point towards the same database.
2020
# http://stackoverflow.com/questions/3315046/sharing-a-memory-database-between-different-threads-in-python-using-sqlite3-pa
2121
self.storm_db = StormDBManager("sqlite:%s" % self.SQLITE_TEST_DB)
22+
self.storm_db.initialize()
2223

2324
def tearDown(self):
2425
super(TestStormDBManager, self).tearDown()
@@ -183,7 +184,7 @@ def assert_result(result):
183184
self.assertEquals(result[0], 2, "Result was not 2")
184185

185186
def get_size(_):
186-
return self.storm_db.schedule_query(self.storm_db.num_rows, "car")
187+
return self.storm_db.schedule_query(self.storm_db.count, "car")
187188

188189
def insert_into_db(_):
189190
list = []
@@ -199,18 +200,18 @@ def insert_into_db(_):
199200
return result_deferred
200201

201202
@deferred(timeout=5)
202-
def test_version(self):
203+
def test_version_no_table(self):
203204
"""
204-
This test tests whether the version is 1 if an sql error occurs.
205+
This test tests whether the version is 0 if an sql error occurs.
205206
In this case the table MyInfo does not exist.
206207
"""
207208

208209
def assert_result(_):
209210
self.assertIsInstance(self.storm_db._version, int, "_version field is not an int!")
210-
self.assertEqual(self.storm_db._version, 1, "Version was not 1 while it should be!")
211+
self.assertEqual(self.storm_db._version, 0, "Version was not 0 but: %r" % self.storm_db._version)
211212

212213
def get_size(_):
213-
return self.storm_db.schedule_query(self.storm_db.num_rows, "car")
214+
return self.storm_db.schedule_query(self.storm_db.count, "car")
214215

215216
result_deferred = self.create_car_database() # Create the car table
216217
result_deferred.addCallback(get_size) # Get the version
@@ -219,17 +220,17 @@ def get_size(_):
219220
return result_deferred
220221

221222
@deferred(timeout=5)
222-
def test_version(self):
223+
def test_version_myinfo_table(self):
223224
"""
224-
This test tests whether the version is 1 if the MyInfo table exists.
225+
This test tests whether the version is 2 if the MyInfo table exists.
225226
"""
226227

227228
def assert_result(_):
228229
self.assertIsInstance(self.storm_db._version, int, "_version field is not an int!")
229230
self.assertEqual(self.storm_db._version, 2, "Version was not 2 but: %r" % self.storm_db._version)
230231

231232
def get_version(_):
232-
return self.storm_db.schedule_query(self.storm_db._retrieve_version)
233+
return self.storm_db._retrieve_version()
233234

234235
def insert_version(_):
235236
return self.storm_db.schedule_query(self.storm_db.insert, "MyInfo", entry="version", value="2")

0 commit comments

Comments
 (0)