Skip to content

[CONTRIB] Normalize schema and table names. #10763

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 29 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
dd16421
https://support.greatexpectations.io/hc/en-us/requests/210
stejin Dec 11, 2024
de0ac29
Merge branch 'great-expectations:develop' into stejin_01
stejin Dec 11, 2024
ab9b4c7
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 11, 2024
26ac112
Merge
stejin Jan 10, 2025
e31ee9c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 10, 2025
419bafd
Updated test_connection
stejin Apr 3, 2025
bc45eca
Merge branch 'develop' into stejin_01
stejin Apr 3, 2025
f474861
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 3, 2025
703e443
Merge branch 'stejin_01' of github.com:stejin/great_expectations into…
stejin Apr 3, 2025
9477b9a
Fixed SQL Server SQL expectation 'batch' error
May 23, 2025
c49270b
Enabled regex expectations for SQL Server
May 26, 2025
689b088
SQL Server view compatibility fix
fredwang1012 May 27, 2025
0a858ba
Restored functionality that was lost from last commit
fredwang1012 May 28, 2025
17a5f54
Restored some accidentally removed functionality in previous commits
fredwang1012 May 28, 2025
176b618
Fixed an issue with {batch} problem
fredwang1012 May 29, 2025
be0475d
Some further SQL Server compatibility improvements
fredwang1012 May 30, 2025
3f3b6cb
Schema now wrapped in [brackets] in SQL query
fredwang1012 May 30, 2025
621568e
Fixed some comments
fredwang1012 Jun 2, 2025
c98849a
Removed some unnecessary bits
fredwang1012 Jun 2, 2025
120c8bb
Merge branch 'great-expectations:develop' into stejin_01
stejin Jun 2, 2025
44f8d4e
Merge branch 'great-expectations:develop' into stejin_01
stejin Jun 2, 2025
2906a7f
Refactored SQL Server Regex
fredwang1012 Jun 2, 2025
b2e9e1f
Merge branch 'stejin_01' into fredwang1012
fredwang1012 Jun 2, 2025
cdbd2f7
Merge pull request #4 from fredwang1012/fredwang1012
stejin Jun 2, 2025
4ca3e2b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 2, 2025
77bce02
Regex fix
fredwang1012 Jun 3, 2025
6673dc8
Fixed regex
fredwang1012 Jun 3, 2025
63e5119
Merge pull request #5 from fredwang1012/fredwang1012
stejin Jun 3, 2025
2cc340e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 3, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
180 changes: 137 additions & 43 deletions great_expectations/datasource/fluent/sql_datasource.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,8 +82,6 @@
)

if TYPE_CHECKING:
from sqlalchemy.sql import quoted_name # noqa: TID251 # type-checking only

# We re-import sqlalchemy here to make type-checking and our compatability layer
# play nice with one another
from great_expectations.compatibility import sqlalchemy
Expand All @@ -97,6 +95,7 @@
LOGGER: Final[logging.Logger] = logging.getLogger(__name__)

DEFAULT_QUOTE_CHARACTERS: Final[Tuple[str, str]] = ('"', "'")
MSSQL_BRACKET_CHARACTERS: Final[Tuple[str, str]] = ("[", "]")


@overload
Expand All @@ -113,9 +112,12 @@ def to_lower_if_not_quoted(
) -> str | None:
"""
Convert a string to lowercase if it is not enclosed in quotes.
Also considers MSSQL brackets as quotes.
"""
if not value:
return value

# Check standard quotes
for char in quote_characters:
if value.startswith(char) and value.endswith(char):
LOGGER.warning(
Expand All @@ -124,6 +126,15 @@ def to_lower_if_not_quoted(
" May cause sqlalchemy case-sensitivity issues."
)
return value

# Check MSSQL brackets
if value.startswith("[") and value.endswith("]"):
LOGGER.warning(
f"The {value} string is bracketed by MSSQL brackets,"
" so it will not be converted to lowercase."
)
return value

LOGGER.info(f"Setting {value} to lowercase to ensure sqlalchemy case-insensitivity.")
return value.lower()

Expand Down Expand Up @@ -998,51 +1009,109 @@ class TableAsset(_SQLAsset):

# Instance fields
type: Literal["table"] = "table"
# TODO: quoted_name or str
table_name: str = pydantic.Field(
table_name: Any = pydantic.Field( # Any because validator may transform to quoted_name
"",
description="Name of the SQL table. Will default to the value of `name` if not provided.",
)
schema_name: Optional[str] = None
schema_name: Optional[Any] = None # Any because validator may transform to quoted_name

@property
def qualified_name(self) -> str:
return f"{self.schema_name}.{self.table_name}" if self.schema_name else self.table_name
return f"{self.schema_name}.{self.table_name}" if self.schema_name else str(self.table_name)

@pydantic.validator("table_name", pre=True, always=True)
def _default_table_name(cls, table_name: str, values: dict, **kwargs) -> str:
if not (validated_table_name := table_name or values.get("name")):
raise ValueError( # noqa: TRY003 # FIXME CoP
"table_name cannot be empty and should default to name if not provided"
)

return validated_table_name

@pydantic.validator("table_name")
def _resolve_quoted_name(cls, table_name: str) -> str | quoted_name:
table_name_is_quoted: bool = cls._is_bracketed_by_quotes(table_name)

# We reimport sqlalchemy from our compatability layer because we make
# quoted_name a top level import there.
def _resolve_quoted_name(cls, table_name: str) -> Any: # Returns str or quoted_name
"""Resolve quoted names and handle MSSQL bracket notation."""
from great_expectations.compatibility import sqlalchemy

if sqlalchemy.quoted_name: # type: ignore[truthy-function] # FIXME CoP
if isinstance(table_name, sqlalchemy.quoted_name):
return table_name
# If it's already a quoted_name, return as-is
if sqlalchemy.quoted_name and isinstance(table_name, sqlalchemy.quoted_name):
return table_name

# Check if the table name is quoted/bracketed (including MSSQL brackets)
table_name_is_quoted = cls._is_bracketed_by_quotes(table_name)

if sqlalchemy.quoted_name: # type: ignore[truthy-function]
if table_name_is_quoted:
# https://docs.sqlalchemy.org/en/20/core/sqlelement.html#sqlalchemy.sql.expression.quoted_name.quote
# Remove the quotes and add them back using the sqlalchemy.quoted_name function
# TODO: We need to handle nested quotes
table_name = table_name.strip("'").strip('"')

return sqlalchemy.quoted_name(
value=table_name,
quote=table_name_is_quoted,
)
# Handle different quote types
if table_name.startswith("[") and table_name.endswith("]"):
# MSSQL brackets - strip and mark as quoted
raw_name = table_name[1:-1]
else:
# Standard quotes - strip them
raw_name = table_name.strip("'").strip('"')

return sqlalchemy.quoted_name(value=raw_name, quote=True)

# Check if MSSQL bracket notation is needed based on content
if cls._needs_mssql_brackets(table_name):
return sqlalchemy.quoted_name(value=table_name, quote=True)

return table_name

@pydantic.validator("schema_name", pre=True)
def _resolve_schema_quoted_name(cls, schema_name: Optional[str]) -> Optional[Any]:
"""Resolve quoted names for schema and handle MSSQL bracket notation."""
if schema_name is None:
return None

from great_expectations.compatibility import sqlalchemy

# If it's already a quoted_name, return as-is
if sqlalchemy.quoted_name and isinstance(schema_name, sqlalchemy.quoted_name):
return schema_name

# Check if the schema name is quoted/bracketed
schema_name_is_quoted = cls._is_bracketed_by_quotes(schema_name)

if sqlalchemy.quoted_name: # type: ignore[truthy-function]
if schema_name_is_quoted:
# Handle different quote types
if schema_name.startswith("[") and schema_name.endswith("]"):
# MSSQL brackets - strip and mark as quoted
raw_name = schema_name[1:-1]
else:
# Standard quotes - strip them
raw_name = schema_name.strip("'").strip('"')

return sqlalchemy.quoted_name(value=raw_name, quote=True)

# ALWAYS use quoted_name for MSSQL schemas to force brackets
# This will make SQLAlchemy use brackets in the generated SQL
return sqlalchemy.quoted_name(value=schema_name, quote=True)

return schema_name

@staticmethod
def _needs_mssql_brackets(name: str) -> bool:
"""
Returns True if the name requires brackets in MSSQL.

MSSQL requires brackets for identifiers that:
- Start with a number
- Contain spaces, hyphens, dots, or other special characters
- Are reserved keywords
"""
import re

# Check if name starts with a number
if re.match(r"^\d", name):
return True

# Check if name contains special characters that need escaping
if re.search(r"[.\s\-#@]", name):
return True

return False

@override
def test_connection(self) -> None:
"""Test the connection for the TableAsset.
Expand All @@ -1054,17 +1123,32 @@ def test_connection(self) -> None:
engine: sqlalchemy.Engine = datasource.get_engine()
inspector: sqlalchemy.Inspector = sa.inspect(engine)

if self.schema_name and self.schema_name not in inspector.get_schema_names():
raise TestConnectionError( # noqa: TRY003 # FIXME CoP
f'Attempt to connect to table: "{self.qualified_name}" failed because the schema '
f'"{self.schema_name}" does not exist.'
)
available_schemas = inspector.get_schema_names()

if self.schema_name:
schema_to_check = str(self.schema_name)

# For MSSQL, do case-insensitive comparison since SQL Server is case-insensitive
if engine.dialect.name.lower() == "mssql":
# Case-insensitive comparison
schema_exists = any(
schema.lower() == schema_to_check.lower() for schema in available_schemas
)
else:
# For other databases, use the existing logic
schema_exists = schema_to_check in map(to_lower_if_not_quoted, available_schemas)

if not schema_exists:
raise TestConnectionError( # noqa: TRY003 # FIXME CoP
f'Attempt to connect to table: "{self.qualified_name}" failed because the schema '
f'"{self.schema_name}" does not exist.'
)

try:
with engine.connect() as connection:
table = sa.table(self.table_name, schema=self.schema_name)
# don't need to fetch any data, just want to make sure the table is accessible
connection.execute(sa.select(1, table).limit(1))
# Use as_selectable to get properly quoted table
selectable = self.as_selectable()
connection.execute(sa.select(1).select_from(selectable).limit(1))
except Exception as query_error:
LOGGER.info(f"{self.name} `.test_connection()` query failed: {query_error!r}")
raise TestConnectionError( # noqa: TRY003 # FIXME CoP
Expand All @@ -1078,15 +1162,17 @@ def as_selectable(self) -> sqlalchemy.Selectable:

This can be used in a from clause for a query against this data.
"""
# The table_name and schema_name already have proper quoting applied by the validators
return sa.table(self.table_name, schema=self.schema_name)

@override
def _create_batch_spec_kwargs(self) -> dict[str, Any]:
# Convert to string for the batch spec
return {
"type": "table",
"data_asset_name": self.name,
"table_name": self.table_name,
"schema_name": self.schema_name,
"table_name": str(self.table_name),
"schema_name": str(self.schema_name) if self.schema_name else None,
"batch_identifiers": {},
}

Expand All @@ -1099,19 +1185,24 @@ def _is_bracketed_by_quotes(target: str) -> bool:
"""
Returns True if the target string is bracketed by quotes.

Override this method if the quote characters are different than `'` or `"` in the
target database, such as backticks in Databricks SQL.
Supports standard quotes ('', "") and MSSQL brackets ([]).

Arguments:
target: A string to check if it is bracketed by quotes.

Returns:
True if the target string is bracketed by quotes.
"""
return any(
target.startswith(quote) and target.endswith(quote)
for quote in DEFAULT_QUOTE_CHARACTERS
)
# Check standard quotes
for quote in DEFAULT_QUOTE_CHARACTERS:
if target.startswith(quote) and target.endswith(quote):
return True

# Check MSSQL brackets
if target.startswith("[") and target.endswith("]"):
return True

return False

@classmethod
def _to_lower_if_not_bracketed_by_quotes(cls, target: str) -> str:
Expand All @@ -1124,7 +1215,10 @@ def _to_lower_if_not_bracketed_by_quotes(cls, target: str) -> str:
Returns:
The target string in lowercase if it is not bracketed by quotes.
"""
return to_lower_if_not_quoted(target, quote_characters=DEFAULT_QUOTE_CHARACTERS)
# Include MSSQL brackets in the check
if cls._is_bracketed_by_quotes(target):
return target
return target.lower()


def _warn_for_more_specific_datasource_type(connection_string: str) -> None:
Expand Down Expand Up @@ -1277,6 +1371,7 @@ def test_connection(self, test_assets: bool = True) -> None:
raise TestConnectionError(cause=e) from e
if self.assets and test_assets:
for asset in self.assets:
# Temporarily set datasource reference for test
asset._datasource = self
asset.test_connection()

Expand All @@ -1301,8 +1396,7 @@ def add_table_asset(
The type of this object will match the necessary type for this datasource.
eg, it could be a TableAsset or a SqliteTableAsset.
""" # noqa: E501 # FIXME CoP
if schema_name:
schema_name = self._TableAsset._to_lower_if_not_bracketed_by_quotes(schema_name)
# The validators in TableAsset will handle lowercase and quoting
asset = self._TableAsset(
name=name,
table_name=table_name,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,32 @@ def __init__(self, dialect: str):
PartitionerMethod.PARTITION_ON_HASHED_COLUMN: "get_partition_query_for_data_for_batch_identifiers_for_partition_on_hashed_column", # noqa: E501 # FIXME CoP
}

def _get_date_part_expression(self, date_part: DatePart, column):
"""Generate dialect-specific date part extraction.

Args:
date_part: DatePart enum specifying which part of date to extract
column: SQLAlchemy column object to extract date part from

Returns:
SQLAlchemy expression for extracting the date part
"""
if self._dialect == GXSqlDialect.MSSQL:
# MSSQL uses YEAR(), MONTH(), DAY() functions instead of EXTRACT
if date_part.value == "year":
return sa.func.year(column)
elif date_part.value == "month":
return sa.func.month(column)
elif date_part.value == "day":
return sa.func.day(column)
else:
raise NotImplementedError(
f"Date part {date_part.value} not supported for MSSQL dialect"
)
else:
# All other dialects (PostgreSQL, MySQL, SQLite, etc.) use EXTRACT
return sa.extract(date_part.value, column)

def partition_on_year(
self,
column_name: str,
Expand Down Expand Up @@ -169,7 +195,7 @@ def partition_on_date_parts(

query: Union[sqlalchemy.BinaryExpression, sqlalchemy.BooleanClauseList] = sa.and_( # type: ignore[assignment] # FIXME CoP
*[
sa.extract(date_part.value, sa.column(column_name))
self._get_date_part_expression(date_part, sa.column(column_name))
== date_parts_dict[date_part.value]
for date_part in date_parts
]
Expand Down Expand Up @@ -478,7 +504,7 @@ def get_partition_query_for_data_for_batch_identifiers_for_partition_on_date_par
if len(date_parts) == 1:
# MSSql does not accept single item concatenation
concat_clause = sa.func.distinct( # type: ignore[assignment] # FIXME CoP
sa.func.extract(date_parts[0].value, sa.column(column_name)).label(
self._get_date_part_expression(date_parts[0], sa.column(column_name)).label(
date_parts[0].value
)
).label("concat_distinct_values")
Expand All @@ -490,15 +516,15 @@ def get_partition_query_for_data_for_batch_identifiers_for_partition_on_date_par
""" # noqa: E501 # FIXME CoP
if self._dialect == GXSqlDialect.SQLITE:
concat_date_parts = sa.cast(
sa.func.extract(date_parts[0].value, sa.column(column_name)),
self._get_date_part_expression(date_parts[0], sa.column(column_name)),
sa.String,
)

date_part: DatePart
for date_part in date_parts[1:]:
concat_date_parts = concat_date_parts.concat(
sa.cast(
sa.func.extract(date_part.value, sa.column(column_name)),
self._get_date_part_expression(date_part, sa.column(column_name)),
sa.String,
)
)
Expand All @@ -508,7 +534,7 @@ def get_partition_query_for_data_for_batch_identifiers_for_partition_on_date_par
concat_date_parts = sa.func.concat(
"",
sa.cast(
sa.func.extract(date_parts[0].value, sa.column(column_name)),
self._get_date_part_expression(date_parts[0], sa.column(column_name)),
sa.String,
),
)
Expand All @@ -517,7 +543,7 @@ def get_partition_query_for_data_for_batch_identifiers_for_partition_on_date_par
concat_date_parts = sa.func.concat(
concat_date_parts,
sa.cast(
sa.func.extract(date_part.value, sa.column(column_name)),
self._get_date_part_expression(date_part, sa.column(column_name)),
sa.String,
),
)
Expand All @@ -527,9 +553,9 @@ def get_partition_query_for_data_for_batch_identifiers_for_partition_on_date_par
partitioned_query: sqlalchemy.Selectable = sa.select( # type: ignore[call-overload] # FIXME CoP
concat_clause,
*[
sa.cast(sa.func.extract(date_part.value, sa.column(column_name)), sa.Integer).label(
date_part.value
)
sa.cast(
self._get_date_part_expression(date_part, sa.column(column_name)), sa.Integer
).label(date_part.value)
for date_part in date_parts
],
).select_from(selectable)
Expand Down
Loading
Loading