Skip to content

Commit

Permalink
changed tests to reflect the changes made to the manager
Browse files Browse the repository at this point in the history
  • Loading branch information
lfdversluis committed May 27, 2016
1 parent d97edba commit 98fed43
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 36 deletions.
55 changes: 27 additions & 28 deletions StormDBManager.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,7 @@ def __init__(self, db_path):
"""
self._logger = logging.getLogger(self.__class__.__name__)

# Open or create the database
self._database = create_database(db_path)
self.db_path = db_path

# The transactor is required when you have methods decorated with the @transact decorator
# This field name must NOT be changed.
Expand All @@ -31,9 +30,16 @@ def __init__(self, db_path):
# Create a DeferredLock that should be used by callers to schedule their call.
self.db_lock = DeferredLock()

self._version = 1
def initialize(self):
"""
Opens/creates the database and initializes the version.
"""
# Open or create the database
self._database = create_database(self.db_path)
self._version = 0
self._retrieve_version()


def _retrieve_version(self):
"""
Attempts to retrieve the current datbase version from the MyInfo table.
Expand All @@ -49,25 +55,18 @@ def on_error(failure):
self._logger.exception(u"Failed to load database version: %s", failure.getTraceback())

# Schedule the query and add a callback and errback to the deferred.
return self.fetch_one(u"SELECT value FROM MyInfo WHERE entry == 'version'").addCallbacks(on_result, on_error)
return self.schedule_query(self.fetch_one, u"SELECT value FROM MyInfo WHERE entry == 'version'").addCallbacks(on_result, on_error)

def schedule_query(*args, **kwargs):
def schedule_query(self, callable, *args, **kwargs):
"""
Utility function to schedule a query to be executed using the db_lock.
:param args: The arguments of which the first is self and the second the function to be run.
Any additional arguments will be passed as the function arguments.
:param kwargs: Keyword arguments that are passed to the function
:param callable: The database function that is to be executed.
:param args: Any additional arguments that will be passed as the callable's arguments.
:param kwargs: Keyword arguments that are passed to the callable function.
:return: A deferred that fires with the result of the query.
"""
if len(args) < 2:
if not args:
raise TypeError("run() takes at least 2 arguments, none given.")
raise TypeError("%s.run() takes at least 2 arguments, 1 given" % (
args[0].__class__.__name__,))
self, f = args[:2]
args = args[2:]

return self.db_lock.run(f, *args, **kwargs)
return self.db_lock.run(callable, *args, **kwargs)

@transact
def execute_query(self, query, arguments=None):
Expand Down Expand Up @@ -109,18 +108,18 @@ def fetch_all(self, query, arguments=None):
return connection.execute(query, arguments).get_all()

@transact
def insert(self, table_name, **argv):
def insert(self, table_name, **kwargs):
"""
Inserts data provided as keyword arguments into the table provided as an argument.
:param table_name: The name of the table the data has to be inserted into.
:param argv: A dictionary where the key represents the column and the value the value to be inserted.
:return: A deferred that fires when the data has been inserted.
"""
connection = Connection(self._database)
self._insert(connection, table_name, **argv)
self._insert(connection, table_name, **kwargs)
connection.close()

def _insert(self, connection, table_name, **argv):
def _insert(self, connection, table_name, **kwargs):
"""
Utility function to insert data which is not decorated by the @transact to prevent
a loop calling this function to create many threads.
Expand All @@ -130,14 +129,14 @@ def _insert(self, connection, table_name, **argv):
:param argv: A dictionary where the key represents the column and the value the value to be inserted.
:return: A deferred that fires when the data has been inserted.
"""
if len(argv) == 0: return
if len(argv) == 1:
sql = u'INSERT INTO %s (%s) VALUES (?);' % (table_name, argv.keys()[0])
if len(kwargs) == 0: raise ValueError("No keyword arguments supplied.")
if len(kwargs) == 1:
sql = u'INSERT INTO %s (%s) VALUES (?);' % (table_name, kwargs.keys()[0])
else:
questions = '?,' * len(argv)
sql = u'INSERT INTO %s %s VALUES (%s);' % (table_name, tuple(argv.keys()), questions[:-1])
questions = ','.join(('?',)*len(kwargs))
sql = u'INSERT INTO %s %s VALUES (%s);' % (table_name, tuple(kwargs.keys()), questions)

connection.execute(sql, argv.values(), noresult=True)
connection.execute(sql, kwargs.values(), noresult=True)

@transact
def insert_many(self, table_name, arg_list):
Expand All @@ -155,7 +154,7 @@ def insert_many(self, table_name, arg_list):

connection.close()

def delete(self, table_name, **argv):
def delete(self, table_name, **kwargs):
"""
Utility function to delete from the database.
:param table_name: the table name to delete from
Expand All @@ -167,7 +166,7 @@ def delete(self, table_name, **argv):
"""
sql = u'DELETE FROM %s WHERE ' % table_name
arg = []
for k, v in argv.iteritems():
for k, v in kwargs.iteritems():
if isinstance(v, tuple):
sql += u'%s %s ? AND ' % (k, v[0])
arg.append(v[1])
Expand All @@ -177,7 +176,7 @@ def delete(self, table_name, **argv):
sql = sql[:-5] # Remove the last AND
return self.execute_query(sql, arg)

def num_rows(self, table_name):
def count(self, table_name):
"""
Utility function to get the number of rows of a table.
:param table_name: The table name
Expand Down
17 changes: 9 additions & 8 deletions tests/test_storm_db_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ def setUp(self):
# in-memory database do not point towards the same database.
# http://stackoverflow.com/questions/3315046/sharing-a-memory-database-between-different-threads-in-python-using-sqlite3-pa
self.storm_db = StormDBManager("sqlite:%s" % self.SQLITE_TEST_DB)
self.storm_db.initialize()

def tearDown(self):
super(TestStormDBManager, self).tearDown()
Expand Down Expand Up @@ -183,7 +184,7 @@ def assert_result(result):
self.assertEquals(result[0], 2, "Result was not 2")

def get_size(_):
return self.storm_db.schedule_query(self.storm_db.num_rows, "car")
return self.storm_db.schedule_query(self.storm_db.count, "car")

def insert_into_db(_):
list = []
Expand All @@ -199,18 +200,18 @@ def insert_into_db(_):
return result_deferred

@deferred(timeout=5)
def test_version(self):
def test_version_no_table(self):
"""
This test tests whether the version is 1 if an sql error occurs.
This test tests whether the version is 0 if an sql error occurs.
In this case the table MyInfo does not exist.
"""

def assert_result(_):
self.assertIsInstance(self.storm_db._version, int, "_version field is not an int!")
self.assertEqual(self.storm_db._version, 1, "Version was not 1 while it should be!")
self.assertEqual(self.storm_db._version, 0, "Version was not 0 but: %r" % self.storm_db._version)

def get_size(_):
return self.storm_db.schedule_query(self.storm_db.num_rows, "car")
return self.storm_db.schedule_query(self.storm_db.count, "car")

result_deferred = self.create_car_database() # Create the car table
result_deferred.addCallback(get_size) # Get the version
Expand All @@ -219,17 +220,17 @@ def get_size(_):
return result_deferred

@deferred(timeout=5)
def test_version(self):
def test_version_myinfo_table(self):
"""
This test tests whether the version is 1 if the MyInfo table exists.
This test tests whether the version is 2 if the MyInfo table exists.
"""

def assert_result(_):
self.assertIsInstance(self.storm_db._version, int, "_version field is not an int!")
self.assertEqual(self.storm_db._version, 2, "Version was not 2 but: %r" % self.storm_db._version)

def get_version(_):
return self.storm_db.schedule_query(self.storm_db._retrieve_version)
return self.storm_db._retrieve_version()

def insert_version(_):
return self.storm_db.schedule_query(self.storm_db.insert, "MyInfo", entry="version", value="2")
Expand Down

0 comments on commit 98fed43

Please sign in to comment.