Commit 26365c3

Merge pull request #626 from ibm-watson-iot/kohlmann-master-sqlInjection
Kohlmann master sql injection
pkohlmann authored Mar 20, 2024
2 parents 9621317 + c7ea24f commit 26365c3
Showing 4 changed files with 68 additions and 41 deletions.
45 changes: 37 additions & 8 deletions iotfunctions/dbhelper.py
@@ -7,13 +7,17 @@
# http://www.apache.org/licenses/LICENSE-2.0
#
# *****************************************************************************

import datetime
import logging

import ibm_db
import pandas as pd
import psycopg2.extras
import re

logger = logging.getLogger(__name__)
SQL_PATTERN = re.compile(r'\w*')
SQL_PATTERN_EXTENDED = re.compile(r'[\w-]*')

# PostgreSQL Queries
POSTGRE_SQL_INFORMATION_SCHEMA = " SELECT table_schema,table_name , column_name ,udt_name, character_maximum_length FROM INFORMATION_SCHEMA.COLUMNS WHERE TABLE_SCHEMA = %s AND TABLE_NAME = %s ; "
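
The two %s markers in this statement are psycopg2 bind parameters, not Python string interpolation, so the schema and table names supplied at execution time never become part of the SQL text. A minimal usage sketch (the cursor and the two values below are hypothetical):

    # Hypothetical call: psycopg2 ships 'mam_data' and 'raw_events' as bind
    # values for the two %s placeholders instead of splicing them into the SQL.
    cursor.execute(POSTGRE_SQL_INFORMATION_SCHEMA, ('mam_data', 'raw_events'))
    rows = cursor.fetchall()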
@@ -30,26 +34,21 @@ def quotingSchemaName(schemaName, is_postgre_sql=False):

def quotingTableName(tableName, is_postgre_sql=False):
    quotedTableName = 'NULL'
    quote = '\"'
    twoQuotes = '\"\"'

    if tableName is not None:
        tableName = tableName if is_postgre_sql else tableName.upper()
        # Quote string and escape all quotes in string by an additional quote
        quotedTableName = quote + tableName.replace(quote, twoQuotes) + quote
        quotedTableName = f'"{tableName if is_postgre_sql else tableName.upper()}"'

    return quotedTableName


def quotingSqlString(sqlValue):
    preparedValue = 'NULL'
    quote = '\''
    twoQuotes = '\'\''

    if sqlValue is not None:
        if isinstance(sqlValue, str):
            # Quote string and escape all quotes in string by an additional quote
            preparedValue = quote + sqlValue.replace(quote, twoQuotes) + quote
            preparedValue = f"'{sqlValue}'"
        else:
            # sqlValue is not a string; just return it as is
            preparedValue = sqlValue
@@ -141,3 +140,33 @@ def execute_postgre_sql_query(db_connection, sql, params=None, raise_error=True)
    finally:
        if cursor is not None:
            cursor.close()


def check_sql_injection(input_obj):
    input_type = type(input_obj)
    if input_type == str:
        if SQL_PATTERN.fullmatch(input_obj) is None:
            raise RuntimeError(f"The following string contains forbidden characters and cannot be inserted into an "
                               f"SQL statement for security reasons. Only word characters (letters, digits and "
                               f"underscore) are allowed: {input_obj}")
    elif input_type == int or input_type == float:
        pass
    elif input_type == pd.Timestamp or input_type == datetime.datetime:
        pass
    else:
        raise RuntimeError(f"The following object has an unexpected type {input_type} and cannot be inserted into an "
                           f"SQL statement for security reasons: {input_obj}")

    return input_obj


def check_sql_injection_extended(input_string):
    if type(input_string) == str:
        if SQL_PATTERN_EXTENDED.fullmatch(input_string) is None:
            raise RuntimeError(f"The string {input_string} contains forbidden characters and cannot be inserted "
                               f"into an SQL statement for security reasons. Only word characters (letters, digits, "
                               f"underscore) and hyphens are allowed.")
    else:
        raise RuntimeError(f"A string is expected but the object {input_string} has type {type(input_string)}. "
                           f"It cannot be inserted into an SQL statement for security reasons.")

    return input_string
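
As a quick illustration (not part of the commit), the checks accept identifier-like strings, pass numeric and timestamp values through untouched, and raise on anything containing quotes, spaces, or other punctuation; the sample values below are hypothetical:

    from iotfunctions.dbhelper import check_sql_injection, check_sql_injection_extended

    check_sql_injection('event_table_42')      # word characters only: returned unchanged
    check_sql_injection(17)                    # ints and floats pass through untouched
    check_sql_injection_extended('pump-01')    # the extended check also permits hyphens

    try:
        check_sql_injection("x'; DROP TABLE users; --")
    except RuntimeError as exc:
        print(exc)                             # quote, space and semicolon fail the \w-only pattern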
30 changes: 13 additions & 17 deletions iotfunctions/dbtables.py
@@ -21,6 +21,7 @@
import pyarrow.parquet

from iotfunctions import dbhelper
from iotfunctions.dbhelper import check_sql_injection

logger = logging.getLogger(__name__)

@@ -50,8 +51,9 @@ def __init__(self, tenant_id, entity_type_id, schema, db_connection, db_type):
            raise Exception('Initialization of %s failed because the database type %s is unknown.' % (
                self.__class__.__name__, self.db_type))

        self.quoted_schema = dbhelper.quotingSchemaName(self.schema, self.is_postgre_sql)
        self.quoted_cache_tablename = dbhelper.quotingTableName(self.cache_tablename, self.is_postgre_sql)
        self.quoted_schema = dbhelper.quotingSchemaName(check_sql_injection(self.schema), self.is_postgre_sql)
        self.quoted_cache_tablename = dbhelper.quotingTableName(check_sql_injection(self.cache_tablename), self.is_postgre_sql)
        self.quoted_constraint_name = dbhelper.quotingTableName(check_sql_injection('uc_%s' % self.cache_tablename), self.is_postgre_sql)

        self._handle_cache_table()

@@ -66,8 +68,7 @@ def _create_cache_table(self):
"UPDATED_TS TIMESTAMP NOT NULL DEFAULT CURRENT TIMESTAMP, " \
"CONSTRAINT %s UNIQUE(ENTITY_TYPE_ID, PARQUET_NAME) ENFORCED ) " \
"ORGANIZE BY ROW" % (self.quoted_schema, self.quoted_cache_tablename,
dbhelper.quotingTableName('uc_%s' % self.cache_tablename,
self.is_postgre_sql))
self.quoted_constraint_name)
try:
stmt = ibm_db.exec_immediate(self.db_connection, sql_statement)
ibm_db.free_result(stmt)
@@ -80,8 +81,7 @@ def _create_cache_table(self):
"parquet_file BYTEA, " \
"updated_ts TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, " \
"CONSTRAINT %s UNIQUE(entity_type_id, parquet_name))" % (
self.quoted_schema, self.quoted_cache_tablename,
dbhelper.quotingTableName('uc_%s' % self.cache_tablename, self.is_postgre_sql))
self.quoted_schema, self.quoted_cache_tablename, self.quoted_constraint_name)
try:
dbhelper.execute_postgre_sql_query(self.db_connection, sql_statement)
except Exception as ex:
@@ -161,8 +161,7 @@ def _push_cache(self, cache_filename, cache_pathname):

statement3 = "ON CONFLICT ON CONSTRAINT %s DO update set entity_type_id = EXCLUDED.entity_type_id, " \
"parquet_name = EXCLUDED.parquet_name, parquet_file = EXCLUDED.parquet_file, " \
"updated_ts = EXCLUDED.updated_ts" % dbhelper.quotingTableName(
('uc_%s' % self.cache_tablename), self.is_postgre_sql)
"updated_ts = EXCLUDED.updated_ts" % self.quoted_constraint_name

sql_statement = statement1 + " values (%s, %s, %s, current_timestamp) " + statement3

@@ -482,8 +481,9 @@ def __init__(self, tenant_id, entity_type_id, schema, db_connection, db_type):
            raise Exception('Initialization of %s failed because the database type %s is unknown.' % (
                self.__class__.__name__, self.db_type))

        self.quoted_schema = dbhelper.quotingSchemaName(self.schema, self.is_postgre_sql)
        self.quoted_store_tablename = dbhelper.quotingTableName(self.store_tablename, self.is_postgre_sql)
        self.quoted_schema = dbhelper.quotingSchemaName(check_sql_injection(self.schema), self.is_postgre_sql)
        self.quoted_store_tablename = dbhelper.quotingTableName(check_sql_injection(self.store_tablename), self.is_postgre_sql)
        self.quoted_constraint_name = dbhelper.quotingTableName(check_sql_injection('uc_%s' % self.store_tablename), self.is_postgre_sql)

        self._handle_store_table()

@@ -497,9 +497,7 @@ def _create_store_table(self):
"UPDATED_TS TIMESTAMP NOT NULL DEFAULT CURRENT TIMESTAMP, " \
"LAST_UPDATED_BY VARCHAR(256), " \
"CONSTRAINT %s UNIQUE(ENTITY_TYPE_ID, MODEL_NAME) ENFORCED) " \
"ORGANIZE BY ROW" % (self.quoted_schema, self.quoted_store_tablename,
dbhelper.quotingTableName('uc_%s' % self.store_tablename,
self.is_postgre_sql))
"ORGANIZE BY ROW" % (self.quoted_schema, self.quoted_store_tablename, self.quoted_constraint_name)
try:
stmt = ibm_db.exec_immediate(self.db_connection, sql_statement)
ibm_db.free_result(stmt)
@@ -513,8 +511,7 @@ def _create_store_table(self):
"updated_ts TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, " \
"last_updated_by VARCHAR(256), " \
"CONSTRAINT %s UNIQUE(entity_type_id, model_name))" % (
self.quoted_schema, self.quoted_store_tablename,
dbhelper.quotingTableName('uc_%s' % self.store_tablename, self.is_postgre_sql))
self.quoted_schema, self.quoted_store_tablename, self.quoted_constraint_name)
try:
dbhelper.execute_postgre_sql_query(self.db_connection, sql_statement)
except Exception as ex:
@@ -591,8 +588,7 @@ def store_model(self, model_name, model, user_name=None, serialize=True):

statement3 = "ON CONFLICT ON CONSTRAINT %s DO update set entity_type_id = EXCLUDED.entity_type_id, " \
"model_name = EXCLUDED.model_name, model = EXCLUDED.model, " \
"updated_ts = EXCLUDED.updated_ts, last_updated_by = EXCLUDED.last_updated_by" % dbhelper.quotingTableName(
('uc_%s' % self.store_tablename), self.is_postgre_sql)
"updated_ts = EXCLUDED.updated_ts, last_updated_by = EXCLUDED.last_updated_by" % self.quoted_constraint_name

sql_statement = statement1 + " values (%s, %s, %s, current_timestamp, %s) " + statement3

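
For context, a rough sketch (assumed schema, table, and constraint names; not code from this commit) of the PostgreSQL upsert pattern these tables rely on: the validated, quoted identifiers are the only pieces interpolated into the statement text, while the row values travel as bind parameters.

    import psycopg2  # assumes a reachable PostgreSQL instance

    upsert = ('INSERT INTO "mam_data"."cache_1234" (entity_type_id, parquet_name, parquet_file, updated_ts) '
              'VALUES (%s, %s, %s, current_timestamp) '
              'ON CONFLICT ON CONSTRAINT "uc_cache_1234" DO UPDATE SET '
              'parquet_file = EXCLUDED.parquet_file, updated_ts = EXCLUDED.updated_ts')

    with psycopg2.connect('dbname=iot user=iot') as conn, conn.cursor() as cur:
        # Only the pre-validated identifiers are part of the SQL text; the row
        # values are sent separately as bind parameters.
        cur.execute(upsert, (1234, 'part-0001.parquet', psycopg2.Binary(b'...')))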
23 changes: 12 additions & 11 deletions iotfunctions/loader.py
@@ -13,6 +13,7 @@
import pandas as pd

from iotfunctions import dbhelper, util
from iotfunctions.dbhelper import check_sql_injection, check_sql_injection_extended


class LoaderPipeline:
@@ -143,16 +144,16 @@ def execute(self, df, start_ts, end_ts, entities=None):
        key_timestamp = 'key_timestamp_'

        if self.schema is not None and len(self.schema) > 0:
            schema_prefix = f"{dbhelper.quotingSchemaName(self.schema, self.dms.is_postgre_sql)}."
            schema_prefix = f"{dbhelper.quotingSchemaName(check_sql_injection(self.schema), self.dms.is_postgre_sql)}."
        else:
            schema_prefix = ""

        sql = 'SELECT %s, %s AS "%s", %s AS "%s" FROM %s%s' % (
            ', '.join([dbhelper.quotingColumnName(col, self.dms.is_postgre_sql) for col in self.columns]),
            dbhelper.quotingColumnName(self.id_col, self.dms.is_postgre_sql), key_id,
            dbhelper.quotingColumnName(self.timestamp_col, self.dms.is_postgre_sql), key_timestamp,
            ', '.join([dbhelper.quotingColumnName(check_sql_injection(col), self.dms.is_postgre_sql) for col in self.columns]),
            dbhelper.quotingColumnName(check_sql_injection(self.id_col), self.dms.is_postgre_sql), key_id,
            dbhelper.quotingColumnName(check_sql_injection(self.timestamp_col), self.dms.is_postgre_sql), key_timestamp,
            schema_prefix,
            dbhelper.quotingTableName(self.table, self.dms.is_postgre_sql))
            dbhelper.quotingTableName(check_sql_injection(self.table), self.dms.is_postgre_sql))
        condition_applied = False
        if self.where_clause is not None:
            sql += ' WHERE %s' % self.where_clause
@@ -162,18 +163,18 @@
                sql += ' WHERE '
            else:
                sql += ' AND '
            sql += "%s <= %s AND %s < %s" % (dbhelper.quotingSqlString(str(start_ts)),
                                             dbhelper.quotingColumnName(self.timestamp_col, self.dms.is_postgre_sql),
                                             dbhelper.quotingColumnName(self.timestamp_col, self.dms.is_postgre_sql),
                                             dbhelper.quotingSqlString(str(end_ts)))
            sql += "%s <= %s AND %s < %s" % (dbhelper.quotingSqlString(str(check_sql_injection(start_ts))),
                                             dbhelper.quotingColumnName(check_sql_injection(self.timestamp_col), self.dms.is_postgre_sql),
                                             dbhelper.quotingColumnName(check_sql_injection(self.timestamp_col), self.dms.is_postgre_sql),
                                             dbhelper.quotingSqlString(str(check_sql_injection(end_ts))))
            condition_applied = True
        if entities is not None:
            if not condition_applied:
                sql += ' WHERE '
            else:
                sql += ' AND '
            sql += "%s IN (%s)" % (dbhelper.quotingColumnName(self.id_col, self.dms.is_postgre_sql),
                                   ', '.join([dbhelper.quotingSqlString(ent) for ent in entities]))
            sql += "%s IN (%s)" % (dbhelper.quotingColumnName(check_sql_injection(self.id_col), self.dms.is_postgre_sql),
                                   ', '.join([dbhelper.quotingSqlString(check_sql_injection_extended(ent)) for ent in entities]))

        self.parse_dates.add(key_timestamp)
        requested_col_names = self.names + [key_id, key_timestamp]
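
For PostgreSQL, the statement assembled by execute() ends up shaped roughly like this (hypothetical schema, table, column names, and values); every identifier has passed check_sql_injection and every entity id check_sql_injection_extended before being quoted into the text:

    # Approximate shape of the generated query; all names and values are hypothetical.
    sql = ('SELECT "temp", "pressure", "deviceid" AS "key_id_", "evt_timestamp" AS "key_timestamp_" '
           'FROM "mam_data"."raw_events" '
           'WHERE \'2024-03-01 00:00:00\' <= "evt_timestamp" AND "evt_timestamp" < \'2024-03-02 00:00:00\' '
           'AND "deviceid" IN (\'pump-01\', \'pump-02\')')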
11 changes: 6 additions & 5 deletions iotfunctions/stages.py
@@ -19,6 +19,7 @@
from sqlalchemy import (MetaData, Table)

from . import dbhelper
from .dbhelper import check_sql_injection
from .exceptions import StageException, DataWriterException
from .util import MessageHub, asList
from . import metadata as md
@@ -271,7 +272,7 @@ def create_upsert_statement(self, tableName, grain):
        joinExtension = ''
        sourceExtension = ''
        for dimension in dimensions:
            quoted_dimension = dbhelper.quotingColumnName(dimension)
            quoted_dimension = dbhelper.quotingColumnName(check_sql_injection(dimension))
            colExtension += ', ' + quoted_dimension
            parmExtension += ', ?'
            joinExtension += ' AND TARGET.' + quoted_dimension + ' = SOURCE.' + quoted_dimension
@@ -284,7 +285,7 @@ def create_upsert_statement_postgres_sql(self, tableName, grain):
"UPDATE SET TARGET.VALUE_B = SOURCE.VALUE_B, TARGET.VALUE_N = SOURCE.VALUE_N, TARGET.VALUE_S = SOURCE.VALUE_S, TARGET.VALUE_T = SOURCE.VALUE_T, TARGET.LAST_UPDATE = SOURCE.LAST_UPDATE "
"WHEN NOT MATCHED THEN "
"INSERT (KEY%s, VALUE_B, VALUE_N, VALUE_S, VALUE_T, LAST_UPDATE) VALUES (SOURCE.KEY%s, SOURCE.VALUE_B, SOURCE.VALUE_N, SOURCE.VALUE_S, SOURCE.VALUE_T, CURRENT TIMESTAMP)") % (
dbhelper.quotingSchemaName(self.schema), dbhelper.quotingTableName(tableName), parmExtension,
dbhelper.quotingSchemaName(check_sql_injection(self.schema)), dbhelper.quotingTableName(check_sql_injection(tableName)), parmExtension,
colExtension, joinExtension, colExtension, sourceExtension)

def create_upsert_statement_postgres_sql(self, tableName, grain):
@@ -305,7 +306,7 @@ def create_upsert_statement_postgres_sql(self, tableName, grain):

        for dimension in dimensions:
            # Note: the dimension grain needs to be in lower case since the table is created with lowercase column names.
            quoted_dimension = dbhelper.quotingColumnName(dimension.lower(), self.is_postgre_sql)
            quoted_dimension = dbhelper.quotingColumnName(check_sql_injection(dimension.lower()), self.is_postgre_sql)
            colExtension += ', ' + quoted_dimension
            parmExtension += ', %s'

@@ -337,8 +338,8 @@ def __init__(self, dms, alerts=None, data_item_names=None, **kwargs):
        except AttributeError:
            self.entity_type_name = dms.entity_type

        self.quoted_schema = dbhelper.quotingSchemaName(dms.default_db_schema, self.dms.is_postgre_sql)
        self.quoted_table_name = dbhelper.quotingTableName(self.ALERT_TABLE_NAME, self.dms.is_postgre_sql)
        self.quoted_schema = dbhelper.quotingSchemaName(check_sql_injection(dms.default_db_schema), self.dms.is_postgre_sql)
        self.quoted_table_name = dbhelper.quotingTableName(check_sql_injection(self.ALERT_TABLE_NAME), self.dms.is_postgre_sql)
        self.alert_to_kpi_input_dict = dict()

        # Requirement: alerts_to_msg_hub must be a subset of alerts_to_db because the alerts in the database are exploited
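
As a closing illustration (hypothetical table, dimension column, and connection details; not code from this commit), the Db2 MERGE built above keeps the validated, quoted identifiers in the statement text and leaves the row values behind ? parameter markers for ibm_db:

    import ibm_db  # assumes an available Db2 instance

    merge_sql = ('MERGE INTO "MAM_DATA"."KPI_HOURLY" AS TARGET '
                 'USING (VALUES (?, ?, ?)) AS SOURCE (KEY, "SITE", VALUE_N) '
                 'ON TARGET.KEY = SOURCE.KEY AND TARGET."SITE" = SOURCE."SITE" '
                 'WHEN MATCHED THEN UPDATE SET TARGET.VALUE_N = SOURCE.VALUE_N '
                 'WHEN NOT MATCHED THEN INSERT (KEY, "SITE", VALUE_N) '
                 'VALUES (SOURCE.KEY, SOURCE."SITE", SOURCE.VALUE_N)')

    conn = ibm_db.connect('DATABASE=iot;HOSTNAME=db2;PORT=50000;UID=iot;PWD=secret', '', '')
    stmt = ibm_db.prepare(conn, merge_sql)
    ibm_db.execute(stmt, ('row-key-1', 'Munich', 42.0))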
