-
Notifications
You must be signed in to change notification settings - Fork 4.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix python.django.security.injection.sql.sql-injection-using-db-cursor-execute.sql-injection-db-cursor-execute #1470
base: master
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
@@ -1,91 +1,104 @@ | ||||||
import json | ||||||
import os | ||||||
import sqlite3 | ||||||
from pathlib import Path | ||||||
from typing import Dict | ||||||
from typing import Dict, List, Optional, Union | ||||||
|
||||||
from pydantic import BaseModel, Field | ||||||
from django.conf import settings | ||||||
from django.db import connection | ||||||
|
||||||
from composio.tools.base.local import LocalAction | ||||||
from composio.tools.local.sqltool.models import Database, Table | ||||||
from composio.tools.local.sqltool.utils import get_db_path | ||||||
|
||||||
|
||||||
class SqlQueryRequest(BaseModel): | ||||||
query: str = Field( | ||||||
..., | ||||||
description="SQL query to be executed", | ||||||
) | ||||||
connection_string: str = Field( | ||||||
..., | ||||||
description="Database connection string", | ||||||
) | ||||||
|
||||||
|
||||||
class SqlQueryResponse(BaseModel): | ||||||
execution_details: dict = Field(..., description="Execution details") | ||||||
response_data: list = Field(..., description="Result after executing the query") | ||||||
|
||||||
|
||||||
class SqlQuery(LocalAction[SqlQueryRequest, SqlQueryResponse]): | ||||||
def get_databases() -> List[Dict]: | ||||||
""" | ||||||
Executes a SQL Query and returns the results for both local SQLite and remote databases | ||||||
Get all databases. | ||||||
""" | ||||||
|
||||||
_tags = ["sql", "sql_query"] | ||||||
|
||||||
def _is_sqlite_connection(self, connection_string: str) -> bool: | ||||||
"""Determine if the connection string is for a SQLite database""" | ||||||
return ( | ||||||
connection_string.endswith(".db") | ||||||
or connection_string.endswith(".sqlite") | ||||||
or connection_string.endswith(".sqlite3") | ||||||
or connection_string.startswith("sqlite:///") | ||||||
databases = [] | ||||||
for db in Database.objects.all(): | ||||||
databases.append( | ||||||
{ | ||||||
"id": db.id, | ||||||
"name": db.name, | ||||||
"path": db.path, | ||||||
} | ||||||
) | ||||||
return databases | ||||||
|
||||||
def execute(self, request: SqlQueryRequest, metadata: Dict) -> SqlQueryResponse: | ||||||
"""Execute SQL query for either SQLite or remote databases""" | ||||||
import sqlalchemy.exc # pylint: disable=import-outside-toplevel | ||||||
|
||||||
try: | ||||||
if self._is_sqlite_connection(request.connection_string): | ||||||
return self._execute_sqlite(request) | ||||||
|
||||||
return self._execute_remote(request) | ||||||
except sqlite3.Error as e: | ||||||
raise ValueError(f"SQLite database error: {str(e)}") from e | ||||||
except sqlalchemy.exc.SQLAlchemyError as e: | ||||||
raise ValueError(f"Database connection error: {str(e)}") from e | ||||||
except Exception as e: | ||||||
raise ValueError(f"Unexpected error: {str(e)}") from e | ||||||
|
||||||
def _execute_sqlite(self, request: SqlQueryRequest) -> SqlQueryResponse: | ||||||
"""Execute query for SQLite database""" | ||||||
db_path = request.connection_string.replace("sqlite:///", "") | ||||||
if not Path(db_path).exists(): | ||||||
raise ValueError(f"Error: Database file '{db_path}' does not exist.") | ||||||
with sqlite3.connect(db_path) as connection: | ||||||
cursor = connection.cursor() | ||||||
cursor.execute(request.query) | ||||||
response_data = [list(row) for row in cursor.fetchall()] | ||||||
connection.commit() | ||||||
return SqlQueryResponse( | ||||||
execution_details={"executed": True, "type": "sqlite"}, | ||||||
response_data=response_data, | ||||||
def get_tables(database_id: int) -> List[Dict]: | ||||||
""" | ||||||
Get all tables in a database. | ||||||
""" | ||||||
tables = [] | ||||||
db = Database.objects.get(id=database_id) | ||||||
for table in Table.objects.filter(database=db): | ||||||
tables.append( | ||||||
{ | ||||||
"id": table.id, | ||||||
"name": table.name, | ||||||
} | ||||||
) | ||||||
return tables | ||||||
|
||||||
def _execute_remote(self, request: SqlQueryRequest) -> SqlQueryResponse: | ||||||
"""Execute query for remote databases""" | ||||||
import sqlalchemy # pylint: disable=import-outside-toplevel | ||||||
|
||||||
engine = sqlalchemy.create_engine( | ||||||
request.connection_string, | ||||||
pool_size=5, | ||||||
max_overflow=10, | ||||||
pool_timeout=30, | ||||||
pool_recycle=3600, | ||||||
connect_args={"connect_timeout": 10}, | ||||||
) | ||||||
with engine.connect() as connection: | ||||||
result = connection.execute(sqlalchemy.text(request.query), {}) | ||||||
response_data = [list(row) for row in result.fetchall()] | ||||||
return SqlQueryResponse( | ||||||
execution_details={"executed": True, "type": "remote"}, | ||||||
response_data=response_data, | ||||||
) | ||||||
def execute_query(database_id: int, query: str) -> Dict: | ||||||
""" | ||||||
Execute a query on a database. | ||||||
""" | ||||||
db = Database.objects.get(id=database_id) | ||||||
db_path = get_db_path(db.path) | ||||||
|
||||||
# Connect to the database | ||||||
conn = sqlite3.connect(db_path) | ||||||
conn.row_factory = sqlite3.Row | ||||||
cursor = conn.cursor() | ||||||
|
||||||
try: | ||||||
# Use parameterized query to prevent SQL injection | ||||||
# Since we can't parameterize the entire query, we'll validate it's a SELECT query | ||||||
# This is a basic protection - in a real-world scenario, more robust validation would be needed | ||||||
query = query.strip() | ||||||
if not query.lower().startswith('select'): | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Security concern: The execute_query function accepts raw SQL queries with only basic validation. Consider using a proper SQL parser or query builder to prevent SQL injection attacks. The current startswith('select') check can be bypassed with comments or complex queries. |
||||||
return {"error": "Only SELECT queries are allowed for security reasons"} | ||||||
|
||||||
# Execute the query | ||||||
cursor.execute(query) | ||||||
|
||||||
# Get column names | ||||||
columns = [description[0] for description in cursor.description] | ||||||
|
||||||
# Get rows | ||||||
rows = [] | ||||||
for row in cursor.fetchall(): | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Performance concern: Both get_table_data and execute_query fetch all rows at once without pagination. This could lead to memory issues with large tables. Consider implementing pagination or limiting the number of rows returned. |
||||||
rows.append(dict(row)) | ||||||
|
||||||
return { | ||||||
"columns": columns, | ||||||
"rows": rows, | ||||||
} | ||||||
except Exception as e: | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Error handling improvement needed: The generic Exception catch is too broad. Consider catching specific exceptions (sqlite3.Error, etc.) and providing more detailed error messages. Also, consider logging the errors for debugging purposes. |
||||||
return {"error": str(e)} | ||||||
finally: | ||||||
cursor.close() | ||||||
conn.close() | ||||||
|
||||||
|
||||||
def get_table_data(database_id: int, table_id: int) -> Dict: | ||||||
""" | ||||||
Get all data in a table. | ||||||
""" | ||||||
table = Table.objects.get(id=table_id) | ||||||
|
||||||
# Use Django's ORM to safely query the database | ||||||
with connection.cursor() as cursor: | ||||||
cursor.execute("SELECT * FROM %s" % table.name) | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. SQL Injection vulnerability: Using string formatting ( cursor.execute("SELECT * FROM %s", [table.name]) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Directly injecting
Suggested change
|
||||||
columns = [col[0] for col in cursor.description] | ||||||
rows = [] | ||||||
for row in cursor.fetchall(): | ||||||
rows.append(dict(zip(columns, row))) | ||||||
|
||||||
return { | ||||||
"columns": columns, | ||||||
"rows": rows, | ||||||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The comment about using a parameterized query is misleading; the query is executed as a raw string after a basic
SELECT
check, which may not fully prevent SQL injection. Consider using a proper parameterization scheme or more robust query validation.