 
 import pytest
 import sqlparse
+from sqlparse.exceptions import SQLParseError
 import time
 
 
 class TestDoSPrevention:
     """Test cases to ensure sqlparse is protected against DoS attacks."""
 
     def test_large_tuple_list_performance(self):
-        """Test that parsing a large list of tuples doesn't cause DoS."""
+        """Test that parsing a large list of tuples raises SQLParseError."""
         # Generate SQL with many tuples (like Django composite primary key queries)
         sql = '''
         SELECT "composite_pk_comment"."tenant_id", "composite_pk_comment"."comment_id"
         FROM "composite_pk_comment"
         WHERE ("composite_pk_comment"."tenant_id", "composite_pk_comment"."comment_id") IN ('''
 
-        # Generate 5000 tuples - this would previously cause a hang
+        # Generate 5000 tuples - enough to exceed MAX_GROUPING_TOKENS
         tuples = []
         for i in range(1, 5001):
             tuples.append(f"(1, {i})")
 
         sql += ", ".join(tuples) + ")"
 
-        # Test should complete quickly (under 5 seconds)
-        start_time = time.time()
-        result = sqlparse.format(sql, reindent=True, keyword_case="upper")
-        execution_time = time.time() - start_time
-
-        assert execution_time < 5.0, f"Parsing took too long: {execution_time:.2f}s"
-        assert len(result) > 0, "Result should not be empty"
-        assert "SELECT" in result.upper(), "SQL should be properly formatted"
+        # Should raise SQLParseError due to the token limit
+        with pytest.raises(SQLParseError, match="Maximum number of tokens exceeded"):
+            sqlparse.format(sql, reindent=True, keyword_case="upper")
 
     def test_deeply_nested_groups_limited(self):
-        """Test that deeply nested groups don't cause stack overflow."""
+        """Test that deeply nested groups raise SQLParseError."""
         # Create deeply nested parentheses
         sql = "SELECT " + "(" * 200 + "1" + ")" * 200
 
-        # Should not raise RecursionError
-        result = sqlparse.format(sql, reindent=True)
-        assert "SELECT" in result
-        assert "1" in result
+        # Should raise SQLParseError due to the depth limit
+        with pytest.raises(SQLParseError, match="Maximum grouping depth exceeded"):
+            sqlparse.format(sql, reindent=True)
 
     def test_very_large_token_list_limited(self):
-        """Test that very large token lists are handled gracefully."""
+        """Test that very large token lists raise SQLParseError."""
         # Create SQL with many identifiers
         identifiers = []
         for i in range(15000):  # More than MAX_GROUPING_TOKENS
             identifiers.append(f"col{i}")
 
         sql = f"SELECT {', '.join(identifiers)} FROM table1"
 
-        # Should complete without hanging
-        start_time = time.time()
-        result = sqlparse.format(sql, reindent=True)
-        execution_time = time.time() - start_time
-
-        assert execution_time < 10.0, f"Parsing took too long: {execution_time:.2f}s"
-        assert "SELECT" in result
-        assert "FROM" in result
+        # Should raise SQLParseError due to the token limit
+        with pytest.raises(SQLParseError, match="Maximum number of tokens exceeded"):
+            sqlparse.format(sql, reindent=True)
 
     def test_normal_sql_still_works(self):
         """Test that normal SQL still works correctly after DoS protections."""
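
For callers, the practical effect of this change is that sqlparse.format() can now raise instead of hanging on pathological input. Below is a minimal sketch of how an application might guard a formatting call; safe_format is a hypothetical helper, not part of sqlparse, and it assumes SQLParseError is raised when the grouping limits are hit, as the tests above expect.

import sqlparse
from sqlparse.exceptions import SQLParseError


def safe_format(sql):
    """Hypothetical wrapper: format SQL, falling back to the raw statement.

    Assumes sqlparse raises SQLParseError when its token/depth limits
    are exceeded, as exercised by the tests above.
    """
    try:
        return sqlparse.format(sql, reindent=True, keyword_case="upper")
    except SQLParseError:
        # Oversized or deeply nested input: return it unformatted
        # rather than letting the error propagate.
        return sql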