Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

pep8 compliance in pgamit classes #114

Open
wants to merge 7 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 58 additions & 33 deletions pgamit/dbConnection.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
Date: 02/16/2017
Author: Demian D. Gomez

This class is used to connect to the database and handles inserts, updates and selects
This class is used to connect to the database
and handles inserts, updates and selects.
It also handles the error, info and warning messages
"""

Expand Down Expand Up @@ -38,7 +39,9 @@ def cast_array_to_float(recordset):
new_record = []
for field in record:
if isinstance(field, list):
new_record.append([float(value) if isinstance(value, Decimal) else value for value in field])
new_record.append(
[float(value) if isinstance(value, Decimal)
else value for value in field])
else:
if isinstance(field, Decimal):
new_record.append(float(field))
Expand All @@ -54,7 +57,8 @@ def cast_array_to_float(recordset):
for key, value in record.items():
if isinstance(value, Decimal):
record[key] = float(value)
elif isinstance(value, list) and all(isinstance(i, Decimal) for i in value):
elif (isinstance(value, list)
and all(isinstance(i, Decimal) for i in value)):
record[key] = [float(i) for i in value]

return recordset
Expand Down Expand Up @@ -91,19 +95,24 @@ def debug(s):
file_append('/tmp/db.log', "DB: %s\n" % s)


class dbErrInsert (Exception): pass
class dbErrInsert (Exception):
pass


class dbErrUpdate (Exception): pass
class dbErrUpdate (Exception):
pass


class dbErrConnect(Exception): pass
class dbErrConnect(Exception):
pass


class dbErrDelete (Exception): pass
class dbErrDelete (Exception):
pass


class DatabaseError(psycopg2.DatabaseError): pass
class DatabaseError(psycopg2.DatabaseError):
pass


class Cnn(object):
Expand All @@ -116,8 +125,8 @@ def __init__(self, configfile, use_float=False, write_cfg_file=False):
'database': DB_NAME}

self.active_transaction = False
self.options = options
self.options = options

# parse session config file
config = configparser.ConfigParser()

Expand All @@ -126,7 +135,8 @@ def __init__(self, configfile, use_float=False, write_cfg_file=False):
except FileNotFoundError:
if write_cfg_file:
create_empty_cfg()
print(' >> No gnss_data.cfg file found, an empty one has been created. Replace all the necessary '
print(' >> No gnss_data.cfg file found, an empty one '
'has been created. Replace all the necessary '
'config and try again.')
exit(1)
else:
Expand All @@ -143,9 +153,11 @@ def __init__(self, configfile, use_float=False, write_cfg_file=False):

# Define the custom type for an array of decimals
DECIMAL_ARRAY_TYPE = psycopg2.extensions.new_type(
(psycopg2.extensions.DECIMAL.values,), # This matches the type codes for DECIMAL
(psycopg2.extensions.DECIMAL.values,),
# This matches the type codes for DECIMAL
'DECIMAL_ARRAY', # Name of the type
lambda value, curs: [float(d) for d in value] if value is not None else None
lambda value, curs:
[float(d) for d in value] if value is not None else None
)

psycopg2.extensions.register_type(DEC2FLOAT)
Expand All @@ -155,11 +167,14 @@ def __init__(self, configfile, use_float=False, write_cfg_file=False):
err = None
for i in range(3):
try:
self.cnn = psycopg2.connect(host=options['hostname'], user=options['username'],
password=options['password'], dbname=options['database'])
self.cnn = psycopg2.connect(host=options['hostname'],
user=options['username'],
password=options['password'],
dbname=options['database'])

self.cnn.autocommit = True
self.cursor = self.cnn.cursor(cursor_factory=psycopg2.extras.RealDictCursor)
self.cursor = self.cnn.cursor(
cursor_factory=psycopg2.extras.RealDictCursor)

debug("Database connection established")

Expand All @@ -179,8 +194,8 @@ def query(self, command):
return query_obj(self.cursor)

def query_float(self, command, as_dict=False):
# deprecated: using psycopg2 now solves the problem of returning float numbers
# still in to maintain backwards compatibility
# deprecated: using psycopg2 now solves the problem of returning float
# numbers still in to maintain backwards compatibility

if not as_dict:
cursor = self.cnn.cursor()
Expand All @@ -195,18 +210,23 @@ def query_float(self, command, as_dict=False):

def get(self, table, filter_fields, return_fields):
"""
Selects from the given table the records that match filter_fields and returns ONE dictionary.
Method should not be used to retrieve more than one single record.
Selects from the given table the records that match filter_fields and
returns ONE dictionary. Method should not be used to retrieve more
than one single record.

Parameters:
table (str): The table to select from.
filter_fields (dict): The dictionary where the keys are the field names and the values are the filter values.
filter_fields (dict): The dictionary where the keys are the field
names and the values are the filter values.
return_fields (list of str): The fields to return.

Returns:
list: A list of dictionaries, each representing a record that matches the filter.
list: A list of dictionaries, each representing a record that
matches the filter.
"""

where_clause = ' AND '.join([f'"{key}" = %s' for key in filter_fields.keys()])
where_clause = ' AND '.join([f'"{key}" = %s'
for key in filter_fields.keys()])
fields_clause = ', '.join([f'"{field}"' for field in return_fields])
query = f'SELECT {fields_clause} FROM {table} WHERE {where_clause}'
values = list(filter_fields.values())
Expand All @@ -225,7 +245,9 @@ def get(self, table, filter_fields, return_fields):
raise e

def get_columns(self, table):
tblinfo = self.query('select column_name, data_type from information_schema.columns where table_name=\'%s\''
tblinfo = self.query(('select column_name, data_type from'
'information_schema.columns where '
'table_name=\'%s\'')
% table).dictresult()

return {field['column_name']: field['data_type'] for field in tblinfo}
Expand Down Expand Up @@ -262,13 +284,15 @@ def insert(self, table, **kw):

def update(self, table, row, **kwargs):
"""
Updates the specified table with new field values. The row(s) are updated based on the primary key(s)
indicated in the 'row' dictionary. New values are specified in kwargs. Field names must be enclosed
with double quotes to handle camel case names.
Updates the specified table with new field values. The row(s) are
updated based on the primary key(s) indicated in the 'row' dictionary.
New values are specified in kwargs. Field names must be enclosed with
double quotes to handle camel case names.

Parameters:
table (str): The table to update.
row (dict): The dictionary where the keys are the primary key fields and the values are the row's identifiers.
row (dict): The dictionary where the keys are the primary key fields
and the values are the row's identifiers.
kwargs: New field values for the row.
"""
# Build the SET clause of the query
Expand All @@ -293,7 +317,8 @@ def update(self, table, row, **kwargs):

def delete(self, table, **kw):
"""
Deletes row(s) from the specified table based on the provided keyword arguments.
Deletes row(s) from the specified table based on the provided
keyword arguments.

Parameters:
table (str): The table to delete from.
Expand Down Expand Up @@ -332,7 +357,8 @@ def insert_event_bak(self, type, module, desc):
desc = re.sub(r'BASH.*', '', desc)
desc = re.sub(r'PSQL.*', '', desc)

# warn = self.query('SELECT * FROM events WHERE "EventDescription" = \'%s\'' % (desc))
# warn = self.query('SELECT * FROM events WHERE
# "EventDescription" = \'%s\'' % (desc))

# if warn.ntuples() == 0:
self.insert('events', EventType=type, EventDescription=desc)
Expand All @@ -354,8 +380,7 @@ def __del__(self):
def _caller_str():
# get the module calling to make clear how is logging this message
frame = inspect.stack()[2]
line = frame[2]
line = frame[2]
caller = frame[3]

return '[%s:%s(%s)]\n' % (platform.node(), caller, str(line))

return '[%s:%s(%s)]\n' % (platform.node(), caller, str(line))
Loading