Skip to content

Commit 6bcefb0

Browse files
authored
Input sanitizer for benchmark result renderer (#2594)
Since `DataFrame.query` is potentially vulnerable we limit the possible filter input to a fixed grammar that is roughly like this: ``` expr = left op right left = ( expr ) | literal right = ( expr ) | literal op = in | >= | < | <= | == | and | or ``` this will give us boolean operations and basic comparisons. Note that `literal` can be arbitrary python literals (strings, tuples, ...).
1 parent 1f4143a commit 6bcefb0

File tree

4 files changed

+147
-4
lines changed

4 files changed

+147
-4
lines changed

method_comparison/__init__.py

Whitespace-only changes.

method_comparison/app.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import plotly.express as px
2222
import plotly.graph_objects as go
2323
from processing import load_df
24+
from sanitizer import parse_and_filter
2425

2526

2627
metric_preferences = {
@@ -246,7 +247,8 @@ def update_on_task(task_name, current_filter):
246247
filtered = filter_data(task_name, new_models[0] if new_models else "", df)
247248
if current_filter.strip():
248249
try:
249-
df_queried = filtered.query(current_filter)
250+
mask = parse_and_filter(filtered, current_filter)
251+
df_queried = filtered[mask]
250252
if not df_queried.empty:
251253
filtered = df_queried
252254
except Exception:
@@ -262,7 +264,8 @@ def update_on_model(task_name, model_id, current_filter):
262264
filtered = filter_data(task_name, model_id, df)
263265
if current_filter.strip():
264266
try:
265-
filtered = filtered.query(current_filter)
267+
mask = parse_and_filter(filtered, current_filter)
268+
filtered = filtered[mask]
266269
except Exception:
267270
pass
268271
return filtered
@@ -275,7 +278,8 @@ def update_pareto_plot_and_summary(task_name, model_id, metric_x, metric_y, curr
275278
filtered = filter_data(task_name, model_id, df)
276279
if current_filter.strip():
277280
try:
278-
filtered = filtered.query(current_filter)
281+
mask = parse_and_filter(filtered, current_filter)
282+
filtered = filtered[mask]
279283
except Exception as e:
280284
return generate_pareto_plot(filtered, metric_x, metric_y), f"Filter error: {e}"
281285

@@ -295,7 +299,8 @@ def apply_filter(filter_query, task_name, model_id, metric_x, metric_y):
295299
filtered = filter_data(task_name, model_id, df)
296300
if filter_query.strip():
297301
try:
298-
filtered = filtered.query(filter_query)
302+
mask = parse_and_filter(filtered, filter_query)
303+
filtered = filtered[mask]
299304
except Exception as e:
300305
# Update the table, plot, and summary even if there is a filter error.
301306
return (

method_comparison/sanitizer.py

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
import ast
2+
3+
import pandas as pd
4+
5+
6+
def _evaluate_node(df, node):
7+
"""
8+
Recursively evaluates an AST node to generate a pandas boolean mask.
9+
"""
10+
# Base Case: A simple comparison like 'price > 100'
11+
if isinstance(node, ast.Compare):
12+
if not isinstance(node.left, ast.Name):
13+
raise ValueError("Left side of comparison must be a column name.")
14+
col = node.left.id
15+
if col not in df.columns:
16+
raise ValueError(f"Column '{col}' not found in DataFrame.")
17+
18+
if len(node.ops) > 1:
19+
raise ValueError("Chained comparisons like '10 < price < 100' are not supported.")
20+
21+
op_node = node.ops[0]
22+
val_node = node.comparators[0]
23+
try:
24+
value = ast.literal_eval(val_node)
25+
except ValueError:
26+
raise ValueError("Right side of comparison must be a literal (number, string, list).")
27+
28+
operator_map = {
29+
ast.Gt: lambda c, v: df[c] > v,
30+
ast.GtE: lambda c, v: df[c] >= v,
31+
ast.Lt: lambda c, v: df[c] < v,
32+
ast.LtE: lambda c, v: df[c] <= v,
33+
ast.Eq: lambda c, v: df[c] == v,
34+
ast.NotEq: lambda c, v: df[c] != v,
35+
ast.In: lambda c, v: df[c].isin(v),
36+
ast.NotIn: lambda c, v: ~df[c].isin(v)
37+
}
38+
op_type = type(op_node)
39+
if op_type not in operator_map:
40+
raise ValueError(f"Unsupported operator '{op_type.__name__}'.")
41+
return operator_map[op_type](col, value)
42+
43+
# Recursive Step: "Bitwise" operation & and | (the same as boolean operations)
44+
elif isinstance(node, ast.BinOp):
45+
if isinstance(node.op, ast.BitOr):
46+
return _evaluate_node(df, node.left) | _evaluate_node(df, node.right)
47+
elif isinstance(node.op, ast.BitAnd):
48+
return _evaluate_node(df, node.left) & _evaluate_node(df, node.right)
49+
50+
# Recursive Step: A boolean operation like '... and ...' or '... or ...'
51+
elif isinstance(node, ast.BoolOp):
52+
op_type = type(node.op)
53+
# Evaluate the first value in the boolean expression
54+
result = _evaluate_node(df, node.values[0])
55+
# Combine it with the rest of the values based on the operator
56+
for i in range(1, len(node.values)):
57+
if op_type is ast.And or op_type is ast.BitAnd:
58+
result &= _evaluate_node(df, node.values[i])
59+
elif op_type is ast.Or or op_type is ast.BitOr:
60+
result |= _evaluate_node(df, node.values[i])
61+
return result
62+
63+
elif isinstance(node, ast.UnaryOp):
64+
if not isinstance(node.op, ast.Not):
65+
raise ValueError("Only supported unary op is negation.")
66+
return ~_evaluate_node(df, node.operand)
67+
68+
# If the node is not a comparison or boolean op, it's an unsupported expression type
69+
else:
70+
raise ValueError(f"Unsupported expression type: {type(node).__name__}")
71+
72+
73+
def parse_and_filter(df, filter_str):
74+
"""
75+
Filters a pandas DataFrame using a string expression parsed by AST.
76+
This is done to avoid the security vulnerables that `DataFrame.query`
77+
brings (arbitrary code execution).
78+
79+
Args:
80+
df (pd.DataFrame): The DataFrame to filter.
81+
filter_str (str): A string representing a filter expression.
82+
e.g., "price > 100 and stock < 50"
83+
Supported operators: >, >=, <, <=, ==, !=, in, not in, and, or.
84+
85+
Returns:
86+
pd.Series: A boolean Series representing the filter mask.
87+
"""
88+
if not filter_str:
89+
return pd.Series([True] * len(df), index=df.index)
90+
91+
try:
92+
# 'eval' mode ensures the source is a single expression.
93+
tree = ast.parse(filter_str, mode='eval')
94+
expression_node = tree.body
95+
except (SyntaxError, ValueError) as e:
96+
raise ValueError(f"Invalid filter syntax: {e}")
97+
98+
# The recursive evaluation starts here
99+
mask = _evaluate_node(df, expression_node)
100+
return mask
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
import pandas as pd
2+
import pytest
3+
4+
from .sanitizer import parse_and_filter
5+
6+
7+
@pytest.fixture
8+
def df_products():
9+
data = {
10+
'product_id': [101, 102, 103, 104, 105, 106],
11+
'category': ['Electronics', 'Books', 'Electronics', 'Home Goods', 'Books', 'Electronics'],
12+
'price': [799.99, 19.99, 49.50, 120.00, 24.99, 150.00],
13+
'stock': [15, 300, 50, 25, 150, 0]
14+
}
15+
return pd.DataFrame(data)
16+
17+
18+
def test_exploit_fails(df_products):
19+
with pytest.raises(ValueError) as e:
20+
mask1 = parse_and_filter(df_products,
21+
"""price < 50 and @os.system("/bin/echo password")""")
22+
assert 'Invalid filter syntax' in str(e)
23+
24+
25+
@pytest.mark.parametrize('expression,ids', [
26+
("price < 50", [102, 103, 105]),
27+
("product_id in [101, 102]", [101, 102]),
28+
("price < 50 and category == 'Electronics'", [103]),
29+
("stock < 100 or category == 'Home Goods'", [101, 103, 104, 106]),
30+
("(price > 100 and stock < 20) or category == 'Books'", [101, 102, 105, 106]),
31+
("not (price > 50 or stock > 100)", [103]),
32+
("not price > 50", [102, 103, 105]),
33+
("(price < 50) & (category == 'Electronics')", [103]),
34+
("(stock < 100) | (category == 'Home Goods')", [101, 103, 104, 106]),
35+
])
36+
def test_operations(df_products, expression, ids):
37+
mask1 = parse_and_filter(df_products, expression)
38+
assert sorted(df_products[mask1].product_id) == sorted(ids)

0 commit comments

Comments
 (0)