Kohlmann master sql injection #626

Merged: 4 commits, Mar 20, 2024
45 changes: 37 additions & 8 deletions iotfunctions/dbhelper.py
@@ -7,13 +7,17 @@
# http://www.apache.org/licenses/LICENSE-2.0
#
# *****************************************************************************

import datetime
import logging

import ibm_db
import pandas as pd
import psycopg2.extras
import re

logger = logging.getLogger(__name__)
SQL_PATTERN = re.compile(r'\w*')
SQL_PATTERN_EXTENDED = re.compile(r'[\w-]*')

# PostgreSQL Queries
POSTGRE_SQL_INFORMATION_SCHEMA = " SELECT table_schema,table_name , column_name ,udt_name, character_maximum_length FROM INFORMATION_SCHEMA.COLUMNS WHERE TABLE_SCHEMA = %s AND TABLE_NAME = %s ; "
@@ -30,26 +34,21 @@ def quotingSchemaName(schemaName, is_postgre_sql=False):

def quotingTableName(tableName, is_postgre_sql=False):
quotedTableName = 'NULL'
quote = '\"'
twoQuotes = '\"\"'

if tableName is not None:
tableName = tableName if is_postgre_sql else tableName.upper()
# Quote string and escape all quotes in string by an additional quote
quotedTableName = quote + tableName.replace(quote, twoQuotes) + quote
quotedTableName = f'"{tableName if is_postgre_sql else tableName.upper()}"'

return quotedTableName


def quotingSqlString(sqlValue):
preparedValue = 'NULL'
quote = '\''
twoQuotes = '\'\''

if sqlValue is not None:
if isinstance(sqlValue, str):
# Quote string and escape all quotes in string by an additional quote
preparedValue = quote + sqlValue.replace(quote, twoQuotes) + quote
preparedValue = f"'{sqlValue}'"
else:
# sqlValue is no string; therefore just return it as is
preparedValue = sqlValue
@@ -141,3 +140,33 @@ def execute_postgre_sql_query(db_connection, sql, params=None, raise_error=True)
finally:
if cursor is not None:
cursor.close()


def check_sql_injection(input_obj):
input_type = type(input_obj)
if input_type == str:
if SQL_PATTERN.fullmatch(input_obj) is None:
raise RuntimeError(f"The following string contains forbidden characters and cannot be inserted into a sql "
f"statement for security reason. Only letters including underscore are allowed: {input_obj}")
elif input_type == int or input_type == float:
pass
elif input_type == pd.Timestamp or input_type == datetime.datetime:
pass
else:
raise RuntimeError(f"The following object has an unexpected type {input_type} and cannot be inserted into a "
f"sql statement for security reason: {input_obj}")

return input_obj


def check_sql_injection_extended(input_string):

if type(input_string) == str:
if SQL_PATTERN_EXTENDED.fullmatch(input_string) is None:
raise RuntimeError(f"The string {input_string} contains forbidden characters and cannot be inserted "
f"into a sql statement for security reason. Only letters, underscore and hyphen are allowed.")
else:
raise RuntimeError(f"A string is expected but the object {input_string} has type {type(input_string)}. "
f"It cannot be inserted into a sql statement for security reason.")

return input_string
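
For orientation, a minimal usage sketch (not part of the PR) of the two validators added above; the sample values are invented, and the behavior follows directly from SQL_PATTERN and SQL_PATTERN_EXTENDED:

import datetime
from iotfunctions.dbhelper import check_sql_injection, check_sql_injection_extended

# Identifiers made only of letters, digits and underscores pass through unchanged.
check_sql_injection("MY_TABLE_01")                    # returns "MY_TABLE_01"

# Numbers and timestamps are returned as-is because they cannot carry SQL text.
check_sql_injection(42)                               # returns 42
check_sql_injection(datetime.datetime(2024, 3, 20))   # returns the datetime

# Quotes, spaces and semicolons fail SQL_PATTERN.fullmatch and raise RuntimeError.
try:
    check_sql_injection("x'; DROP TABLE t; --")
except RuntimeError as error:
    print(error)

# The extended variant additionally accepts hyphens, e.g. for entity ids.
check_sql_injection_extended("device-001")            # returns "device-001"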
30 changes: 13 additions & 17 deletions iotfunctions/dbtables.py
@@ -21,6 +21,7 @@
import pyarrow.parquet

from iotfunctions import dbhelper
from iotfunctions.dbhelper import check_sql_injection

logger = logging.getLogger(__name__)

@@ -50,8 +51,9 @@ def __init__(self, tenant_id, entity_type_id, schema, db_connection, db_type):
raise Exception('Initialization of %s failed because the database type %s is unknown.' % (
self.__class__.__name__, self.db_type))

self.quoted_schema = dbhelper.quotingSchemaName(self.schema, self.is_postgre_sql)
self.quoted_cache_tablename = dbhelper.quotingTableName(self.cache_tablename, self.is_postgre_sql)
self.quoted_schema = dbhelper.quotingSchemaName(check_sql_injection(self.schema), self.is_postgre_sql)
self.quoted_cache_tablename = dbhelper.quotingTableName(check_sql_injection(self.cache_tablename), self.is_postgre_sql)
self.quoted_constraint_name = dbhelper.quotingTableName(check_sql_injection('uc_%s' % self.cache_tablename), self.is_postgre_sql)

self._handle_cache_table()

@@ -66,8 +68,7 @@ def _create_cache_table(self):
"UPDATED_TS TIMESTAMP NOT NULL DEFAULT CURRENT TIMESTAMP, " \
"CONSTRAINT %s UNIQUE(ENTITY_TYPE_ID, PARQUET_NAME) ENFORCED ) " \
"ORGANIZE BY ROW" % (self.quoted_schema, self.quoted_cache_tablename,
dbhelper.quotingTableName('uc_%s' % self.cache_tablename,
self.is_postgre_sql))
self.quoted_constraint_name)
try:
stmt = ibm_db.exec_immediate(self.db_connection, sql_statement)
ibm_db.free_result(stmt)
@@ -80,8 +81,7 @@
"parquet_file BYTEA, " \
"updated_ts TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, " \
"CONSTRAINT %s UNIQUE(entity_type_id, parquet_name))" % (
self.quoted_schema, self.quoted_cache_tablename,
dbhelper.quotingTableName('uc_%s' % self.cache_tablename, self.is_postgre_sql))
self.quoted_schema, self.quoted_cache_tablename, self.quoted_constraint_name)
try:
dbhelper.execute_postgre_sql_query(self.db_connection, sql_statement)
except Exception as ex:
@@ -161,8 +161,7 @@ def _push_cache(self, cache_filename, cache_pathname):

statement3 = "ON CONFLICT ON CONSTRAINT %s DO update set entity_type_id = EXCLUDED.entity_type_id, " \
"parquet_name = EXCLUDED.parquet_name, parquet_file = EXCLUDED.parquet_file, " \
"updated_ts = EXCLUDED.updated_ts" % dbhelper.quotingTableName(
('uc_%s' % self.cache_tablename), self.is_postgre_sql)
"updated_ts = EXCLUDED.updated_ts" % self.quoted_constraint_name

sql_statement = statement1 + " values (%s, %s, %s, current_timestamp) " + statement3

@@ -482,8 +481,9 @@ def __init__(self, tenant_id, entity_type_id, schema, db_connection, db_type):
raise Exception('Initialization of %s failed because the database type %s is unknown.' % (
self.__class__.__name__, self.db_type))

self.quoted_schema = dbhelper.quotingSchemaName(self.schema, self.is_postgre_sql)
self.quoted_store_tablename = dbhelper.quotingTableName(self.store_tablename, self.is_postgre_sql)
self.quoted_schema = dbhelper.quotingSchemaName(check_sql_injection(self.schema), self.is_postgre_sql)
self.quoted_store_tablename = dbhelper.quotingTableName(check_sql_injection(self.store_tablename), self.is_postgre_sql)
self.quoted_constraint_name = dbhelper.quotingTableName(check_sql_injection('uc_%s' % self.store_tablename), self.is_postgre_sql)

self._handle_store_table()

@@ -497,9 +497,7 @@ def _create_store_table(self):
"UPDATED_TS TIMESTAMP NOT NULL DEFAULT CURRENT TIMESTAMP, " \
"LAST_UPDATED_BY VARCHAR(256), " \
"CONSTRAINT %s UNIQUE(ENTITY_TYPE_ID, MODEL_NAME) ENFORCED) " \
"ORGANIZE BY ROW" % (self.quoted_schema, self.quoted_store_tablename,
dbhelper.quotingTableName('uc_%s' % self.store_tablename,
self.is_postgre_sql))
"ORGANIZE BY ROW" % (self.quoted_schema, self.quoted_store_tablename, self.quoted_constraint_name)
try:
stmt = ibm_db.exec_immediate(self.db_connection, sql_statement)
ibm_db.free_result(stmt)
@@ -513,8 +511,7 @@
"updated_ts TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, " \
"last_updated_by VARCHAR(256), " \
"CONSTRAINT %s UNIQUE(entity_type_id, model_name))" % (
self.quoted_schema, self.quoted_store_tablename,
dbhelper.quotingTableName('uc_%s' % self.store_tablename, self.is_postgre_sql))
self.quoted_schema, self.quoted_store_tablename, self.quoted_constraint_name)
try:
dbhelper.execute_postgre_sql_query(self.db_connection, sql_statement)
except Exception as ex:
@@ -591,8 +588,7 @@ def store_model(self, model_name, model, user_name=None, serialize=True):

statement3 = "ON CONFLICT ON CONSTRAINT %s DO update set entity_type_id = EXCLUDED.entity_type_id, " \
"model_name = EXCLUDED.model_name, model = EXCLUDED.model, " \
"updated_ts = EXCLUDED.updated_ts, last_updated_by = EXCLUDED.last_updated_by" % dbhelper.quotingTableName(
('uc_%s' % self.store_tablename), self.is_postgre_sql)
"updated_ts = EXCLUDED.updated_ts, last_updated_by = EXCLUDED.last_updated_by" % self.quoted_constraint_name

sql_statement = statement1 + " values (%s, %s, %s, current_timestamp, %s) " + statement3

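To illustrate the refactoring above (validate once in __init__, then reuse the quoted names), a small sketch with an invented table name; the output comments assume the PostgreSQL path shown in quotingTableName:

from iotfunctions import dbhelper
from iotfunctions.dbhelper import check_sql_injection

cache_tablename = "dm_parquet_cache_1001"   # invented example name

# Validate the raw name once, then derive the quoted identifiers that the
# CREATE TABLE and ON CONFLICT statements reuse.
quoted_cache_tablename = dbhelper.quotingTableName(check_sql_injection(cache_tablename), True)
quoted_constraint_name = dbhelper.quotingTableName(check_sql_injection('uc_%s' % cache_tablename), True)

print(quoted_cache_tablename)   # "dm_parquet_cache_1001"
print(quoted_constraint_name)   # "uc_dm_parquet_cache_1001"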
23 changes: 12 additions & 11 deletions iotfunctions/loader.py
@@ -13,6 +13,7 @@
import pandas as pd

from iotfunctions import dbhelper, util
from iotfunctions.dbhelper import check_sql_injection, check_sql_injection_extended


class LoaderPipeline:
@@ -143,16 +144,16 @@ def execute(self, df, start_ts, end_ts, entities=None):
key_timestamp = 'key_timestamp_'

if self.schema is not None and len(self.schema) > 0:
schema_prefix = f"{dbhelper.quotingSchemaName(self.schema, self.dms.is_postgre_sql)}."
schema_prefix = f"{dbhelper.quotingSchemaName(check_sql_injection(self.schema), self.dms.is_postgre_sql)}."
else:
schema_prefix = ""

sql = 'SELECT %s, %s AS "%s", %s AS "%s" FROM %s%s' % (
', '.join([dbhelper.quotingColumnName(col, self.dms.is_postgre_sql) for col in self.columns]),
dbhelper.quotingColumnName(self.id_col, self.dms.is_postgre_sql), key_id,
dbhelper.quotingColumnName(self.timestamp_col, self.dms.is_postgre_sql), key_timestamp,
', '.join([dbhelper.quotingColumnName(check_sql_injection(col), self.dms.is_postgre_sql) for col in self.columns]),
dbhelper.quotingColumnName(check_sql_injection(self.id_col), self.dms.is_postgre_sql), key_id,
dbhelper.quotingColumnName(check_sql_injection(self.timestamp_col), self.dms.is_postgre_sql), key_timestamp,
schema_prefix,
dbhelper.quotingTableName(self.table, self.dms.is_postgre_sql))
dbhelper.quotingTableName(check_sql_injection(self.table), self.dms.is_postgre_sql))
condition_applied = False
if self.where_clause is not None:
sql += ' WHERE %s' % self.where_clause
@@ -162,18 +163,18 @@
sql += ' WHERE '
else:
sql += ' AND '
sql += "%s <= %s AND %s < %s" % (dbhelper.quotingSqlString(str(start_ts)),
dbhelper.quotingColumnName(self.timestamp_col, self.dms.is_postgre_sql),
dbhelper.quotingColumnName(self.timestamp_col, self.dms.is_postgre_sql),
dbhelper.quotingSqlString(str(end_ts)))
sql += "%s <= %s AND %s < %s" % (dbhelper.quotingSqlString(str(check_sql_injection(start_ts))),
dbhelper.quotingColumnName(check_sql_injection(self.timestamp_col), self.dms.is_postgre_sql),
dbhelper.quotingColumnName(check_sql_injection(self.timestamp_col), self.dms.is_postgre_sql),
dbhelper.quotingSqlString(str(check_sql_injection(end_ts))))
condition_applied = True
if entities is not None:
if not condition_applied:
sql += ' WHERE '
else:
sql += ' AND '
sql += "%s IN (%s)" % (dbhelper.quotingColumnName(self.id_col, self.dms.is_postgre_sql),
', '.join([dbhelper.quotingSqlString(ent) for ent in entities]))
sql += "%s IN (%s)" % (dbhelper.quotingColumnName(check_sql_injection(self.id_col), self.dms.is_postgre_sql),
', '.join([dbhelper.quotingSqlString(check_sql_injection_extended(ent)) for ent in entities]))

self.parse_dates.add(key_timestamp)
requested_col_names = self.names + [key_id, key_timestamp]
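A sketch of what the checks added to execute() achieve when the SELECT statement is assembled; the column and entity values here are hypothetical:

from iotfunctions import dbhelper
from iotfunctions.dbhelper import check_sql_injection, check_sql_injection_extended

is_postgre_sql = True
columns = ["temperature", "pressure"]    # hypothetical data item columns
entities = ["pump-01", "pump-02"]        # entity ids may contain hyphens

# Column names are validated before they are quoted into the statement.
select_list = ", ".join(dbhelper.quotingColumnName(check_sql_injection(col), is_postgre_sql)
                        for col in columns)

# Entity ids go through the extended check because hyphens are legitimate there.
in_list = ", ".join(dbhelper.quotingSqlString(check_sql_injection_extended(ent))
                    for ent in entities)

# A malicious value never reaches the SQL text; it raises instead, e.g.:
#   check_sql_injection_extended("pump-01'; DROP TABLE x; --")  -> RuntimeError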
11 changes: 6 additions & 5 deletions iotfunctions/stages.py
@@ -19,6 +19,7 @@
from sqlalchemy import (MetaData, Table)

from . import dbhelper
from .dbhelper import check_sql_injection
from .exceptions import StageException, DataWriterException
from .util import MessageHub, asList
from . import metadata as md
@@ -271,7 +272,7 @@ def create_upsert_statement(self, tableName, grain):
joinExtension = ''
sourceExtension = ''
for dimension in dimensions:
quoted_dimension = dbhelper.quotingColumnName(dimension)
quoted_dimension = dbhelper.quotingColumnName(check_sql_injection(dimension))
colExtension += ', ' + quoted_dimension
parmExtension += ', ?'
joinExtension += ' AND TARGET.' + quoted_dimension + ' = SOURCE.' + quoted_dimension
@@ -284,7 +285,7 @@
"UPDATE SET TARGET.VALUE_B = SOURCE.VALUE_B, TARGET.VALUE_N = SOURCE.VALUE_N, TARGET.VALUE_S = SOURCE.VALUE_S, TARGET.VALUE_T = SOURCE.VALUE_T, TARGET.LAST_UPDATE = SOURCE.LAST_UPDATE "
"WHEN NOT MATCHED THEN "
"INSERT (KEY%s, VALUE_B, VALUE_N, VALUE_S, VALUE_T, LAST_UPDATE) VALUES (SOURCE.KEY%s, SOURCE.VALUE_B, SOURCE.VALUE_N, SOURCE.VALUE_S, SOURCE.VALUE_T, CURRENT TIMESTAMP)") % (
dbhelper.quotingSchemaName(self.schema), dbhelper.quotingTableName(tableName), parmExtension,
dbhelper.quotingSchemaName(check_sql_injection(self.schema)), dbhelper.quotingTableName(check_sql_injection(tableName)), parmExtension,
colExtension, joinExtension, colExtension, sourceExtension)

def create_upsert_statement_postgres_sql(self, tableName, grain):
@@ -305,7 +306,7 @@

for dimension in dimensions:
# Note: the dimension grain need to be in lower case since the table will be created with lowercase column.
quoted_dimension = dbhelper.quotingColumnName(dimension.lower(), self.is_postgre_sql)
quoted_dimension = dbhelper.quotingColumnName(check_sql_injection(dimension.lower()), self.is_postgre_sql)
colExtension += ', ' + quoted_dimension
parmExtension += ', %s'

@@ -337,8 +338,8 @@ def __init__(self, dms, alerts=None, data_item_names=None, **kwargs):
except AttributeError:
self.entity_type_name = dms.entity_type

self.quoted_schema = dbhelper.quotingSchemaName(dms.default_db_schema, self.dms.is_postgre_sql)
self.quoted_table_name = dbhelper.quotingTableName(self.ALERT_TABLE_NAME, self.dms.is_postgre_sql)
self.quoted_schema = dbhelper.quotingSchemaName(check_sql_injection(dms.default_db_schema), self.dms.is_postgre_sql)
self.quoted_table_name = dbhelper.quotingTableName(check_sql_injection(self.ALERT_TABLE_NAME), self.dms.is_postgre_sql)
self.alert_to_kpi_input_dict = dict()

# Requirement: alerts_to_msg_hub must be a subset of alerts_to_db because the alerts in data base are exploited
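Finally, a sketch of the dimension handling in the upsert statements above, using an invented dimension list; the exact quoting style is assumed to match quotingTableName:

from iotfunctions import dbhelper
from iotfunctions.dbhelper import check_sql_injection

dimensions = ["SITE", "LINE"]   # invented grain dimensions

col_extension = ""
join_extension = ""
for dimension in dimensions:
    # Each dimension name is validated before it is embedded in the MERGE statement.
    quoted_dimension = dbhelper.quotingColumnName(check_sql_injection(dimension))
    col_extension += ", " + quoted_dimension
    join_extension += " AND TARGET." + quoted_dimension + " = SOURCE." + quoted_dimension

# col_extension  -> ', "SITE", "LINE"' (assuming double-quote style as in quotingTableName)
# join_extension -> ' AND TARGET."SITE" = SOURCE."SITE" AND TARGET."LINE" = SOURCE."LINE"'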