diff --git a/database.py b/database.py index 37779910..425bb292 100644 --- a/database.py +++ b/database.py @@ -11,8 +11,7 @@ from abc import ABCMeta, abstractmethod from sqlite3 import Connection -from .util import attach_runtime_statistics - +from .util import attach_runtime_statistics, check_io_originates_from_init if "--explain-query-plan" in getattr(sys, "argv", []): _explain_query_plan_logger = logging.getLogger("explain-query-plan") @@ -120,6 +119,8 @@ def _initial_statements(self): assert self._cursor is not None, "Database.close() has been called or Database.open() has not been called" assert self._connection is not None, "Database.close() has been called or Database.open() has not been called" + check_io_originates_from_init() + # collect current database configuration page_size = int(next(self._cursor.execute(u"PRAGMA page_size"))[0]) journal_mode = unicode(next(self._cursor.execute(u"PRAGMA journal_mode"))[0]).upper() @@ -173,6 +174,8 @@ def _prepare_version(self): assert self._cursor is not None, "Database.close() has been called or Database.open() has not been called" assert self._connection is not None, "Database.close() has been called or Database.open() has not been called" + check_io_originates_from_init() + # check is the database contains an 'option' table try: count, = next(self.execute(u"SELECT COUNT(*) FROM sqlite_master WHERE type = 'table' AND name = 'option'")) @@ -273,6 +276,7 @@ def execute(self, statement, bindings=(), get_lastrowid=False): @returns: unknown @raise sqlite.Error: unknown """ + check_io_originates_from_init() if __debug__: assert self._cursor is not None, "Database.close() has been called or Database.open() has not been called" assert self._connection is not None, "Database.close() has been called or Database.open() has not been called" @@ -297,6 +301,7 @@ def execute(self, statement, bindings=(), get_lastrowid=False): @attach_runtime_statistics(u"{0.__class__.__name__}.{function_name} {1} [{0.file_path}]") def executescript(self, statements): + check_io_originates_from_init() assert self._cursor is not None, "Database.close() has been called or Database.open() has not been called" assert self._connection is not None, "Database.close() has been called or Database.open() has not been called" assert self._debug_thread_ident != 0, "please call database.open() first" @@ -335,6 +340,7 @@ def executemany(self, statement, sequenceofbindings): @returns: unknown @raise sqlite.Error: unknown """ + check_io_originates_from_init() assert self._cursor is not None, "Database.close() has been called or Database.open() has not been called" assert self._connection is not None, "Database.close() has been called or Database.open() has not been called" assert self._debug_thread_ident != 0, "please call database.open() first" @@ -368,6 +374,7 @@ def executemany(self, statement, sequenceofbindings): @attach_runtime_statistics(u"{0.__class__.__name__}.{function_name} [{0.file_path}]") def commit(self, exiting=False): + check_io_originates_from_init() assert self._cursor is not None, "Database.close() has been called or Database.open() has not been called" assert self._connection is not None, "Database.close() has been called or Database.open() has not been called" assert self._debug_thread_ident != 0, "please call database.open() first" diff --git a/message.py b/message.py index 1c9091c4..c5d5eb7f 100644 --- a/message.py +++ b/message.py @@ -318,7 +318,7 @@ class Message(MetaObject): class Implementation(Packet): - def __init__(self, meta, authentication, resolution, distribution, destination, payload, conversion=None, candidate=None, source=u"unknown", packet="", packet_id=0, sign=True): + def __init__(self, meta, authentication, resolution, distribution, destination, payload, conversion=None, candidate=None, source=u"unknown", packet="", packet_id=0): from .conversion import Conversion assert isinstance(meta, Message), "META has invalid type '%s'" % type(meta) assert isinstance(authentication, meta.authentication.Implementation), "AUTHENTICATION has invalid type '%s'" % type(authentication) @@ -358,16 +358,24 @@ def __init__(self, meta, authentication, resolution, distribution, destination, else: self._conversion = meta.community.get_conversion_for_message(self) - if not packet: - self._packet = self._conversion.encode_message(self, sign=sign) + def initialize_packet(self, sign): + """ + Must be called if packet was None in the constructor. + Args: + sign: The verify sign for the packet. + + """ + self._packet = self._conversion.encode_message(self, sign=sign) + + if __debug__: # attempt to decode the message when running in debug + try: + self._conversion.decode_message(LoopbackCandidate(), self._packet, verify=sign, + allow_empty_signature=True) + except DropPacket: + from binascii import hexlify + self._logger.error("Could not decode message created by me, hex '%s'", hexlify(self._packet)) + raise - if __debug__: # attempt to decode the message when running in debug - try: - self._conversion.decode_message(LoopbackCandidate(), self._packet, verify=sign, allow_empty_signature=True) - except DropPacket: - from binascii import hexlify - self._logger.error("Could not decode message created by me, hex '%s'", hexlify(self._packet)) - raise @property def conversion(self): @@ -522,8 +530,13 @@ def impl(self, authentication=(), resolution=(), distribution=(), destination=() distribution_impl = self._distribution.Implementation(self._distribution, *distribution) destination_impl = self._destination.Implementation(self._destination, *destination) payload_impl = self._payload.Implementation(self._payload, *payload) - return self.Implementation(self, authentication_impl, resolution_impl, distribution_impl, destination_impl, payload_impl, *args, **kargs) + impl = self.Implementation(self, authentication_impl, resolution_impl, distribution_impl, destination_impl, payload_impl, *args, **kargs) + packet = kargs.get("packet", "") + if not packet: + sign = kargs["sign"] if "sign" in kargs else True + impl.initialize_packet(sign) + return impl except (TypeError, DropPacket): self._logger.error("message name: %s", self._name) self._logger.error("authentication: %s.Implementation", self._authentication.__class__.__name__) diff --git a/tests/debugcommunity/node.py b/tests/debugcommunity/node.py index 03702af2..ef616f92 100644 --- a/tests/debugcommunity/node.py +++ b/tests/debugcommunity/node.py @@ -30,15 +30,21 @@ class DebugNode(object): node.init_my_member() """ - def __init__(self, testclass, dispersy, communityclass=DebugCommunity, c_master_member=None, curve=u"low"): + def __init__(self, testclass, dispersy): super(DebugNode, self).__init__() self._logger = logging.getLogger(self.__class__.__name__) self._testclass = testclass self._dispersy = dispersy - self._my_member = self._dispersy.get_new_member(curve) - self._my_pub_member = Member(self._dispersy, self._my_member._ec.pub(), self._my_member.database_id) + self._central_node = None + self._tunnel = False + self._connection_type = u"unknown" + + @inlineCallbacks + def initialize(self, communityclass=DebugCommunity, c_master_member=None, curve=u"low"): + self._my_member = yield self._dispersy.get_new_member(curve) + self._my_pub_member = Member(self._dispersy, self._my_member._ec.pub(), self._my_member.database_id) if c_master_member == None: self._community = communityclass.create_community(self._dispersy, self._my_member) else: @@ -46,8 +52,6 @@ def __init__(self, testclass, dispersy, communityclass=DebugCommunity, c_master_ self._community = communityclass.init_community(self._dispersy, mm, self._my_member) self._central_node = c_master_member - self._tunnel = False - self._connection_type = u"unknown" @property def community(self): diff --git a/tests/dispersytestclass.py b/tests/dispersytestclass.py index ce3106dc..66d494d6 100644 --- a/tests/dispersytestclass.py +++ b/tests/dispersytestclass.py @@ -91,7 +91,7 @@ def _create_nodes(amount, store_identity, tunnel, communityclass, autoload_disco self.dispersy_objects.append(dispersy) - node = self._create_node(dispersy, communityclass, self._mm) + node = yield self._create_node(dispersy, communityclass, self._mm) yield node.init_my_member(tunnel=tunnel, store_identity=store_identity) nodes.append(node) @@ -101,5 +101,8 @@ def _create_nodes(amount, store_identity, tunnel, communityclass, autoload_disco return blockingCallFromThread(reactor, _create_nodes, amount, store_identity, tunnel, community_class, autoload_discovery, memory_database) + @inlineCallbacks def _create_node(self, dispersy, community_class, c_master_member): - return DebugNode(self, dispersy, community_class, c_master_member) + node = DebugNode(self, dispersy) + yield node.initialize(community_class, c_master_member) + returnValue(node) diff --git a/util.py b/util.py index 08ab2f8d..e04223a9 100644 --- a/util.py +++ b/util.py @@ -335,5 +335,7 @@ def address_in_subnet(address, subnet): subnet_main >>= 32-netmask return address == subnet_main - - +def check_io_originates_from_init(): + for line in traceback.format_stack(): + if "__init__" in line: + raise IOError("IO Originates from on __init__!") \ No newline at end of file