Skip to content

Commit

Permalink
Amazon athena resolve column not found exception (#1673)
Browse files Browse the repository at this point in the history
Signed-off-by: DerekRushton <[email protected]>
  • Loading branch information
DerekRushton authored Apr 9, 2024
1 parent 3bae101 commit 84f074a
Show file tree
Hide file tree
Showing 5 changed files with 364 additions and 13 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@

from venv import logger
import regex as re
import time
from stix_shifter_modules.aws_athena.stix_transmission import status_connector

class PostQueryConnectorErrorHandler():
async def check_status_for_missing_column(client, search_id, query) -> None:
"""Creates a status check loop to see if the query fails with a column doesn't exist exception. If it does, return the query with the offending column removed.
If it does not, return with True
Args:
client (RestApiClientAsync)
search_id (String): For each query sent to Athena, a job is created that runs in till it finishes. This ID is used to find the job and access it's results.
query (String): The query that will be modified if a missing column error occurs.
Returns:
String: Returns either a modified query, or a special key that can be used to know that the query was successful and to stop.
"""
status = status_connector.StatusConnector(client)
column_to_delete = ""

#Wait ten seconds if the status is RUNNING. Exits early if the status is not RUNNING.
for i in range(10):
time.sleep(1)
status_response = await status.create_status_connection(search_id)
#Checks if there is a message that can be read and if that message matches the column not found message.
if(status_response != None and "message" in status_response):
match = re.search(f"Column '(.*)' cannot be resolved", status_response["message"])
if(match):
#If there is a match, return the column name.
column_to_delete = match.group(1)
break
elif(status_response != None and "status" in status_response and status_response["status"] != "RUNNING"):
#If there is no match and the status is not running, than stop trying and exit with the success message.
break

if(column_to_delete != ""):
return PostQueryConnectorErrorHandler._remove_invalid_column_table(column_to_delete, query)
else:
#May not always be successful, just that no column error is occurring.
return True


def _remove_invalid_column_table(column_to_remove, query):
""" Uses regex to iterate over the query and replace all comparison operations using the invalid column
Args:
column_to_remove (string): This is the name of the column that should be removed from the query.
query (String): This is the current query that is failing. It should contain a column that needs to be removed.
Returns:
String : Returns the modified query with the column comparisons replaced with either TRUE or FALSE.
"""
#These are the possible forms for a left expression in a comparison
COLUMN_NAME_PATTERN="([\\w\\d]*(?:\\.[\\w\\d]+)*)"
VARCHAR_CAST_LEFT_EXPRESSION = f"CAST\\({COLUMN_NAME_PATTERN} as varchar\\)"
REAL_CAST_LEFT_EXPRESSION = f"CAST\\({COLUMN_NAME_PATTERN} as real\\)"
LOWER_LEFT_EXPRESSION = f"lower\\({COLUMN_NAME_PATTERN}\\)"
#These are the possible forms for a right expression in a comparison
LOWER_RIGHT_EXPRESSION="lower\\(.*?\\)"
BRACKET_RIGHT_EXPRESSION="\\(.*?\\)"
QUOTE_RIGHT_EXPRESSION="\\\'.*?\\\'"

#These are the possible forms for an operator
OPERATORS=">|>=|<|<=|!=|LIKE|IN|="

#The general format for the pattern is {left expression} {operator} {right expression}. In order to get the column name, it needs to match on the left expression.
standard_pattern = f"((?:{VARCHAR_CAST_LEFT_EXPRESSION}|{REAL_CAST_LEFT_EXPRESSION}|{LOWER_LEFT_EXPRESSION}|{COLUMN_NAME_PATTERN}) ({OPERATORS}) (?:{LOWER_RIGHT_EXPRESSION}|{BRACKET_RIGHT_EXPRESSION}|{QUOTE_RIGHT_EXPRESSION}))"
#Match Expressions are unique. They act as a function, for example regexp(string, pattern). Standard pattern is like ID = 5 format.
match_pattern = f"((REGEXP_LIKE)\\((?:{VARCHAR_CAST_LEFT_EXPRESSION}|{REAL_CAST_LEFT_EXPRESSION}|{LOWER_LEFT_EXPRESSION}|{COLUMN_NAME_PATTERN}|{QUOTE_RIGHT_EXPRESSION}), '.*?'\\))"

logger.debug(f"The failing column name : {column_to_remove}")
logger.debug(f"Current query : {query}")
logger.debug(f"Attempt to match the standard comparison pattern : {standard_pattern}")

#Matches against the standard pattern, this gets all of the comparison expressions in the query except for TRUE/FALSE.
#It checks each comparison against the offending column and replaces any that match with TRUE or FALSE
all_comparison_strings = re.findall(standard_pattern, query, flags=re.IGNORECASE)
if (len(all_comparison_strings) > 0):
for comparison in all_comparison_strings:
filtered_comparison_list = [item for item in comparison if item != ""]
if(column_to_remove in filtered_comparison_list[1]):
logger.debug(f"The following comparison expression will be replaced (standard): {filtered_comparison_list[0]}" )

#If the column doesn't exist and the comparison is a != it will always be true.
#In the case of <,>,>=,<=, the number doesn't exist, thus it is false.
#In the case of =, it will never = the value, thus it must be false.
#In the case of IN or LIKE, the value will always resolve to FALSE because something, can't be in nothing.
#Match will always resolve to false. This one may be true if the match is impossible (that however is a weird edge case), thus false.
if("!=" in filtered_comparison_list[2]):
query = query.replace(comparison[0], f"TRUE")
else:
query = query.replace(comparison[0], f"FALSE")

#Matches against the match pattern, this gets all of the comparison expressions in the query except for TRUE/FALSE.
#It checks each comparison against the offending column and replaces any that match with TRUE or FALSE
all_match_strings = re.findall(match_pattern, query, flags=re.IGNORECASE)
if (len(all_match_strings) > 0):
for comparison in all_match_strings:
filtered_comparison_list = [item for item in comparison if item != ""]
if(column_to_remove in filtered_comparison_list[2]):
logger.debug(f"The following comparison expression will be replaced (match) : {filtered_comparison_list[0]}" )

#If the column doesn't exist and the comparison is a != it will always be true.
#In the case of <,>,>=,<=, the number doesn't exist, thus it is false.
#In the case of =, it will never = the value, thus it must be false.
#In the case of IN or LIKE, the value will always resolve to FALSE because something, can't be in nothing.
#Match will always resolve to false. This one may be true if the match is impossible (that however is a weird edge case), thus false.
query = query.replace(comparison[0], f"FALSE")

return query
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
from venv import logger
from stix_shifter_modules.aws_athena.stix_transmission import status_connector
from stix_shifter_modules.aws_athena.stix_transmission.post_query_connector_error_handling import PostQueryConnectorErrorHandler
from stix_shifter_utils.modules.base.stix_transmission.base_connector import BaseQueryConnector
from stix_shifter_utils.utils.error_response import ErrorResponder
import json
Expand Down Expand Up @@ -25,6 +28,7 @@ def __init__(self, client, connection):
self.client = client
self.connection = connection
self.connector = __name__.split('.')[1]
self.total_try_count = 0

async def create_query_connection(self, query):
"""
Expand All @@ -47,6 +51,7 @@ async def create_query_connection(self, query):
raise InvalidParameterException("{} is required for {} query operation".format(config,
query_service_type))
table_config = self.connection[config_details[0]] + '."' + self.connection[config_details[1]] + '"'

other_tables = ''
findall = re.finditer("##UNNEST.*?##", query[query_service_type])
if findall:
Expand All @@ -57,36 +62,72 @@ async def create_query_connection(self, query):
other_tables += ' %s%s%s ' % ('LEFT JOIN ', match_str.replace('##', ''), ' ON TRUE ')

if query_service_type == 'ocsf':
columns = await self.column_list(self.connection[config_details[1]])
columns = await self.column_list(self.connection[config_details[0]], self.connection[config_details[1]])
column_cast = []
for column in columns:
column_cast.append("CAST(%s as JSON) AS %s" % (column, column))

select_statement = "SELECT %s FROM %s%s WHERE " % (", ".join(column_cast), table_config, other_tables)
else:
select_statement = "SELECT %s.* FROM %s%s WHERE " % (table_config, table_config, other_tables)

#self.get_list_of_columns_and_rows(query[query_service_type])
#await self.row_list(self.connection[config_details[0]], self.connection[config_details[1]])
# for multiple observation operators union and intersect, select statement will be added
if 'UNION' in query[query_service_type] or 'INTERSECT' in query[query_service_type]:
query_string = re.sub(r'\(\(', '(({}'.format(select_statement), query[query_service_type], 1)
query = query_string.replace('UNION (', 'UNION ({}'.format(select_statement)).\
query_with_select = query_string.replace('UNION (', 'UNION ({}'.format(select_statement)).\
replace('INTERSECT (', 'INTERSECT ({}'.format(select_statement))
else:
query = select_statement + query[query_service_type]
query_with_select = select_statement + query[query_service_type]
result_config = self.get_result_config()
query_args = {"QueryString": query, "ResultConfiguration": result_config}
response_dict = await self.client.makeRequest('athena', 'start_query_execution', **query_args)
return_obj['success'] = True
return_obj['search_id'] = response_dict['QueryExecutionId'] + ":" + query_service_type


return await self.query_api(query_with_select, result_config, query_service_type, return_obj)

except Exception as ex:
response_dict['__type'] = ex.__class__.__name__
response_dict['message'] = ex
ErrorResponder.fill_error(return_obj, response_dict, ['message'], connector=self.connector)
return return_obj

async def column_list(self, table):
async def query_api(self, query, result_config, query_service_type, return_obj):
"""Creates a query job and ensures that none of the columns requested are missing from the query.
Args:
query (String): The original query without modification.
result_config (Dict): Query configuration.
query_service_type (String): The type of query ("OCSF", "VPCFlow, etc)
return_obj (Dict): Contains the metadata about the request.
Returns:
Dict: The metadata about the request such as the query/search ID and the status.
"""
logger.debug(f"The current query is : {query}")

#Creates the initial query job.
query_args = {"QueryString": query, "ResultConfiguration": result_config}
response_dict = await self.client.makeRequest('athena', 'start_query_execution', **query_args)
return_obj['search_id'] = response_dict['QueryExecutionId'] + ":" + query_service_type

modified_query = dict()
modified_query = await PostQueryConnectorErrorHandler.check_status_for_missing_column(self.client, return_obj['search_id'], query)
#If the query is successful (or the exception isn't column related) than it's considered a success and exits.
#If 10 columns are not found or it fails to replace a column 10 times, than it exits (to prevent endless loops).
#If the query is not successful, than it will retry the query with the modified query.
if(modified_query == True or self.total_try_count > 10):
logger.debug(f"The number of attempts to remove missing columns was {self.total_try_count}")
if(self.total_try_count >= 10):
logger.warn("There were 10 failed exceptions related to columns. This could be because there were more invalid columns than 10, \
or alternatively that the replacement failed to remove the offending column.")
return_obj['success'] = True
return return_obj
else:
self.total_try_count = self.total_try_count + 1
return await self.query_api(modified_query, result_config, query_service_type, return_obj)

async def column_list(self, database, table):
columns = []
query = "SELECT column_name FROM information_schema.columns WHERE table_name = '%s'" % table
query = f"SELECT column_name,data_type FROM information_schema.columns WHERE table_name = '{table}' AND table_schema = '{database}'"
result_config = self.get_result_config()
query_args = {"QueryString": query, "ResultConfiguration": result_config}
response_dict = await self.client.makeRequest('athena', 'start_query_execution', **query_args)
Expand Down Expand Up @@ -133,6 +174,13 @@ async def column_list(self, table):

return columns

def get_rows_from_response(self, data, parent, row_list):
data = data.casefold()
##Root of the tree.
if(data.startswith("row(".casefold()) or data.startwith("array(".casefold())):
remainder = data[data.find("("):data.rfind(")")]
self.get_rows_from_response(remainder, parent, row_list)

def get_result_config(self):
"""
Output location and encryption configuration are added
Expand All @@ -152,3 +200,4 @@ def get_result_config(self):
output_location = 's3://' + path
result_config = {'OutputLocation': output_location}
return result_config

Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,8 @@ async def create_status_connection(self, search_id):
return_obj['status'] = self._getstatus(response_dict.get('QueryExecution', 'FAILED').
get('Status', 'FAILED').
get('State', 'FAILED'))
if (response_dict != None and "QueryExecution" in response_dict and "Status" in response_dict["QueryExecution"] and "StateChangeReason" in response_dict["QueryExecution"]["Status"]):
return_obj['message'] = response_dict["QueryExecution"]["Status"]["StateChangeReason"]
if return_obj['status'] == 'COMPLETED':
return_obj['progress'] = 100
elif return_obj['status'] == 'RUNNING':
Expand Down
Loading

0 comments on commit 84f074a

Please sign in to comment.