Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for hybrid tables and indexes #533

Merged
merged 2 commits into from
Oct 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions .github/workflows/build_test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,9 @@ jobs:
run: |
gpg --quiet --batch --yes --decrypt --passphrase="$PARAMETERS_SECRET" \
.github/workflows/parameters/parameters_${{ matrix.cloud-provider }}.py.gpg > tests/parameters.py
- name: Run test for AWS
run: hatch run test-dialect-aws
if: matrix.cloud-provider == 'aws'
- name: Run tests
run: hatch run test-dialect
- uses: actions/upload-artifact@v4
Expand Down Expand Up @@ -203,6 +206,9 @@ jobs:
python -m pip install -U uv
python -m uv pip install -U hatch
python -m hatch env create default
- name: Run test for AWS
run: hatch run sa14:test-dialect-aws
if: matrix.cloud-provider == 'aws'
- name: Run tests
run: hatch run sa14:test-dialect
- uses: actions/upload-artifact@v4
Expand Down
2 changes: 1 addition & 1 deletion DESCRIPTION.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ Source code is also available at:
- (Unreleased)

- Add support for dynamic tables and required options
- Fixed SAWarning when registering functions with existing name in default namespace
- Add support for hybrid tables
- Fixed SAWarning when registering functions with existing name in default namespace

- v1.6.1(July 9, 2024)
Expand Down
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ SQLACHEMY_WARN_20 = "1"
check = "pre-commit run --all-files"
test-dialect = "pytest -ra -vvv --tb=short --cov snowflake.sqlalchemy --cov-append --junitxml ./junit.xml --ignore=tests/sqlalchemy_test_suite tests/"
test-dialect-compatibility = "pytest -ra -vvv --tb=short --cov snowflake.sqlalchemy --cov-append --junitxml ./junit.xml tests/sqlalchemy_test_suite"
test-dialect-aws = "pytest -m \"aws\" -ra -vvv --tb=short --cov snowflake.sqlalchemy --cov-append --junitxml ./junit.xml --ignore=tests/sqlalchemy_test_suite tests/"
gh-cache-sum = "python -VV | sha256sum | cut -d' ' -f1"
check-import = "python -c 'import snowflake.sqlalchemy; print(snowflake.sqlalchemy.__version__)'"

Expand All @@ -110,7 +111,7 @@ line-length = 88
line-length = 88

[tool.pytest.ini_options]
addopts = "-m 'not feature_max_lob_size'"
addopts = "-m 'not feature_max_lob_size and not aws'"
markers = [
# Optional dependency groups markers
"lambda: AWS lambda tests",
Expand Down
3 changes: 2 additions & 1 deletion src/snowflake/sqlalchemy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@
VARBINARY,
VARIANT,
)
from .sql.custom_schema import DynamicTable
from .sql.custom_schema import DynamicTable, HybridTable
from .sql.custom_schema.options import AsQuery, TargetLag, TimeUnit, Warehouse
from .util import _url as URL

Expand Down Expand Up @@ -120,4 +120,5 @@
"TargetLag",
"TimeUnit",
"Warehouse",
"HybridTable",
)
138 changes: 129 additions & 9 deletions src/snowflake/sqlalchemy/snowdialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@
_CUSTOM_Float,
_CUSTOM_Time,
)
from .sql.custom_schema.custom_table_prefix import CustomTablePrefix
from .util import (
_update_connection_application_name,
parse_url_boolean,
Expand Down Expand Up @@ -352,14 +353,6 @@ def _map_name_to_idx(result):
name_to_idx[col[0]] = idx
return name_to_idx

@reflection.cache
def get_indexes(self, connection, table_name, schema=None, **kw):
"""
Gets all indexes
"""
# no index is supported by Snowflake
return []

@reflection.cache
def get_check_constraints(self, connection, table_name, schema, **kw):
# check constraints are not supported by Snowflake
Expand Down Expand Up @@ -895,6 +888,129 @@ def get_table_comment(self, connection, table_name, schema=None, **kw):
)
}

def get_multi_indexes(
    self,
    connection,
    *,
    schema,
    filter_names,
    **kw,
):
    """
    Gets the indexes definition for the hybrid tables of a schema.

    Runs ``SHOW INDEXES`` (scoped to ``schema``, or to the dialect's default
    schema when ``schema`` is falsy) and groups the resulting rows by table.

    :param connection: connection used to run the ``SHOW`` statements.
    :param schema: schema to inspect; falls back to ``default_schema_name``.
    :param filter_names: table names to keep; ``None`` means no filtering.
    :return: a list of ``((schema, table), [index, ...])`` pairs, or an
        empty list when the schema contains no hybrid table.
    """

    hybrid = CustomTablePrefix.HYBRID.name
    table_prefixes = self.get_multi_prefixes(
        connection, schema, filter_prefix=hybrid
    )
    if not table_prefixes:
        return []

    schema = schema or self.default_schema_name
    if not schema:
        result = connection.execute(
            text("SHOW /* sqlalchemy:get_multi_indexes */ INDEXES")
        )
    else:
        result = connection.execute(
            text(
                f"SHOW /* sqlalchemy:get_multi_indexes */ INDEXES IN SCHEMA {self._denormalize_quote_join(schema)}"
            )
        )

    n2i = self.__class__._map_name_to_idx(result)
    indexes = {}

    for row in result.cursor.fetchall():
        table = self.normalize_name(str(row[n2i["table"]]))
        key = (schema, table)
        if (
            # Skip the index named SYS_INDEX_<table>_PRIMARY (presumably the
            # implicit primary-key index, which is not a user-defined index).
            row[n2i["name"]] == f'SYS_INDEX_{row[n2i["table"]]}_PRIMARY'
            # ``None`` means "no name filter"; testing membership against it
            # would raise TypeError.
            or (filter_names is not None and table not in filter_names)
            # Only tables reported as HYBRID are kept.
            or hybrid not in table_prefixes.get(key, ())
        ):
            continue
        index = {
            "name": row[n2i["name"]],
            "unique": row[n2i["is_unique"]] == "Y",
            "column_names": row[n2i["columns"]],
            "include_columns": row[n2i["included_columns"]],
            "dialect_options": {},
        }
        # list.append returns None, so its result must never be assigned back
        # into the mapping; setdefault keeps the accumulated list intact.
        indexes.setdefault(key, []).append(index)

    return list(indexes.items())

def _value_or_default(self, data, table, schema):
table = self.normalize_name(str(table))
dic_data = dict(data)
if (schema, table) in dic_data:
return dic_data[(schema, table)]
else:
return []

def get_prefixes_from_data(self, n2i, row, **kw):
    """
    Collect the table prefixes flagged on a result row.

    For every ``CustomTablePrefix`` member the column ``is_<member name in
    lower case>`` is looked up through ``n2i`` (column name -> row index);
    the member name is reported when that column exists and holds ``"Y"``.
    """
    flag_columns = (
        (member.name, f"is_{member.name.lower()}") for member in CustomTablePrefix
    )
    return [
        prefix_name
        for prefix_name, column in flag_columns
        if column in n2i and row[n2i[column]] == "Y"
    ]

@reflection.cache
def get_multi_prefixes(
    self, connection, schema, table_name=None, filter_prefix=None, **kw
):
    """
    Gets all table prefixes

    Runs ``SHOW TABLES`` and maps every table to the prefix names detected
    by ``get_prefixes_from_data``.

    :param connection: connection used to run the ``SHOW`` statement.
    :param schema: schema to inspect; falls back to ``default_schema_name``.
        When neither is set the statement is issued without ``IN SCHEMA``.
    :param table_name: optional ``LIKE`` pattern restricting the tables.
    :param filter_prefix: when given, only tables carrying this prefix name
        are returned.
    :return: dict mapping ``(schema, table)`` to a list of prefix names.
    """
    schema = schema or self.default_schema_name
    # Snowflake grammar: SHOW TABLES [LIKE '<pattern>'] [IN SCHEMA <name>].
    # The LIKE clause has to follow TABLES and is only emitted when a pattern
    # was actually supplied (never LIKE 'None').
    like_clause = f" LIKE '{table_name}'" if table_name else ""
    # Quote the schema the same way get_multi_indexes does.
    in_clause = (
        f" IN SCHEMA {self._denormalize_quote_join(schema)}" if schema else ""
    )
    result = connection.execute(
        text(
            f"SHOW /* sqlalchemy:get_multi_prefixes */ TABLES{like_clause}{in_clause}"
        )
    )

    n2i = self.__class__._map_name_to_idx(result)
    tables_prefixes = {}
    for row in result.cursor.fetchall():
        table = self.normalize_name(str(row[n2i["name"]]))
        table_prefixes = self.get_prefixes_from_data(n2i, row)
        if filter_prefix and filter_prefix not in table_prefixes:
            continue
        # extend (not append) keeps the value a flat list of prefix names
        # should the same table be reported more than once.
        tables_prefixes.setdefault((schema, table), []).extend(table_prefixes)

    return tables_prefixes

@reflection.cache
def get_indexes(self, connection, tablename, schema=None, **kw):
    """
    Gets the indexes definition

    :param connection: connection used to run the reflection queries.
    :param tablename: name of the table whose indexes are requested.
    :param schema: schema of the table; falls back to
        ``default_schema_name`` when falsy.
    :return: list of index dictionaries for the table, empty when none is
        reported.
    """
    table_name = self.normalize_name(str(tablename))
    # get_multi_indexes keys its result by the *resolved* schema, so resolve
    # it here as well; otherwise a lookup with schema=None can never match.
    schema = schema or self.default_schema_name
    data = self.get_multi_indexes(
        connection=connection, schema=schema, filter_names=[table_name], **kw
    )

    return self._value_or_default(data, table_name, schema)

def connect(self, *cargs, **cparams):
return (
super().connect(
Expand All @@ -912,8 +1028,12 @@ def connect(self, *cargs, **cparams):

@sa_vnt.listens_for(Table, "before_create")
def check_table(table, connection, _ddl_runner, **kw):
    """
    ``before_create`` hook rejecting indexes on non-hybrid tables.

    Tables matching ``HybridTable.is_equal_type`` are accepted as-is. Any
    other table that declares indexes and is created through a
    ``SnowflakeDialect`` raises ``NotImplementedError``.
    """
    # Imported here rather than at module level, as in the original code
    # (presumably to avoid a circular import -- confirm).
    from .sql.custom_schema.hybrid_table import HybridTable

    if HybridTable.is_equal_type(table):  # noqa
        return True
    if isinstance(_ddl_runner.dialect, SnowflakeDialect) and table.indexes:
        # Single raise only: the earlier "Snowflake does not support indexes"
        # statement made this message unreachable.
        raise NotImplementedError("Only Snowflake Hybrid Tables supports indexes")


dialect = SnowflakeDialect
3 changes: 2 additions & 1 deletion src/snowflake/sqlalchemy/sql/custom_schema/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,6 @@
# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved.
#
from .dynamic_table import DynamicTable
from .hybrid_table import HybridTable

__all__ = ["DynamicTable"]
__all__ = ["DynamicTable", "HybridTable"]
23 changes: 18 additions & 5 deletions src/snowflake/sqlalchemy/sql/custom_schema/custom_table_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,17 @@
from ..._constants import DIALECT_NAME
from ...compat import IS_VERSION_20
from ...custom_commands import NoneType
from .custom_table_prefix import CustomTablePrefix
from .options.table_option import TableOption


class CustomTableBase(Table):
__table_prefix__ = ""
_support_primary_and_foreign_keys = True
__table_prefixes__: typing.List[CustomTablePrefix] = []
_support_primary_and_foreign_keys: bool = True

@property
def table_prefixes(self) -> typing.List[str]:
    """Names of the entries in ``__table_prefixes__``, in declaration order."""
    names = []
    for member in self.__table_prefixes__:
        names.append(member.name)
    return names

def __init__(
self,
Expand All @@ -24,8 +29,8 @@ def __init__(
*args: SchemaItem,
**kw: Any,
) -> None:
if self.__table_prefix__ != "":
prefixes = kw.get("prefixes", []) + [self.__table_prefix__]
if len(self.__table_prefixes__) > 0:
prefixes = kw.get("prefixes", []) + self.table_prefixes
kw.update(prefixes=prefixes)
if not IS_VERSION_20 and hasattr(super(), "_init"):
super()._init(name, metadata, *args, **kw)
Expand All @@ -40,7 +45,7 @@ def _validate_table(self):
self.primary_key or self.foreign_keys
):
raise ArgumentError(
f"Primary key and foreign keys are not supported in {self.__table_prefix__} TABLE."
f"Primary key and foreign keys are not supported in {' '.join(self.table_prefixes)} TABLE."
)

return True
Expand All @@ -49,3 +54,11 @@ def _get_dialect_option(self, option_name: str) -> typing.Optional[TableOption]:
if option_name in self.dialect_options[DIALECT_NAME]:
return self.dialect_options[DIALECT_NAME][option_name]
return NoneType

@classmethod
def is_equal_type(cls, table: Table) -> bool:
    """
    Tell whether ``table`` carries every prefix declared by this class.

    Returns ``True`` when each name in ``cls.__table_prefixes__`` is present
    in ``table._prefixes`` (trivially ``True`` when the class declares no
    prefix), ``False`` otherwise.
    """
    return all(
        required.name in table._prefixes for required in cls.__table_prefixes__
    )
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
#
# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved.

from enum import Enum


class CustomTablePrefix(Enum):
    """
    Table kinds recognised by the custom table classes.

    Member *names* are what the rest of the package consumes: they are
    passed as table ``prefixes`` and matched against ``is_<name>`` columns
    when reading table metadata. The integer values are plain identifiers.
    """

    # Placeholder for tables without a special prefix
    # (NOTE(review): inferred from the name -- confirm).
    DEFAULT = 0
    EXTERNAL = 1
    EVENT = 2
    HYBRID = 3
    ICEBERG = 4
    DYNAMIC = 5
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

from snowflake.sqlalchemy.custom_commands import NoneType

from .custom_table_prefix import CustomTablePrefix
from .options.target_lag import TargetLag
from .options.warehouse import Warehouse
from .table_from_query import TableFromQueryBase
Expand All @@ -27,7 +28,7 @@ class DynamicTable(TableFromQueryBase):

"""

__table_prefix__ = "DYNAMIC"
__table_prefixes__ = [CustomTablePrefix.DYNAMIC]

_support_primary_and_foreign_keys = False

Expand Down
67 changes: 67 additions & 0 deletions src/snowflake/sqlalchemy/sql/custom_schema/hybrid_table.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
#
# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved.
#

from typing import Any

from sqlalchemy.exc import ArgumentError
from sqlalchemy.sql.schema import MetaData, SchemaItem

from snowflake.sqlalchemy.custom_commands import NoneType

from .custom_table_base import CustomTableBase
from .custom_table_prefix import CustomTablePrefix


class HybridTable(CustomTableBase):
    """
    A class representing a hybrid table with configurable options and settings.

    The `HybridTable` class allows for the creation and querying of OLTP
    Snowflake tables.

    While it does not support reflection at this time, it provides a flexible
    interface for creating and managing hybrid tables.
    """

    # Prefix entries for this table type; CustomTableBase turns their names
    # into the table ``prefixes``.
    __table_prefixes__ = [CustomTablePrefix.HYBRID]

    # Primary and foreign keys are allowed on this table type.
    _support_primary_and_foreign_keys = True

    def __init__(
        self,
        name: str,
        metadata: MetaData,
        *args: SchemaItem,
        **kw: Any,
    ) -> None:
        # Returns without initializing unless the caller explicitly passes
        # ``_no_init=False``.
        # NOTE(review): presumably mirrors SQLAlchemy's own Table
        # construction protocol, where ``_init`` does the real work -- confirm.
        if kw.get("_no_init", True):
            return
        super().__init__(name, metadata, *args, **kw)

    def _init(
        self,
        name: str,
        metadata: MetaData,
        *args: SchemaItem,
        **kw: Any,
    ) -> None:
        # Delegates to the parent ``__init__`` unconditionally.
        super().__init__(name, metadata, *args, **kw)

    def _validate_table(self):
        """
        Validate the table definition.

        Raises ``ArgumentError`` listing the missing attributes when any are
        detected, then runs the parent validation.
        """
        missing_attributes = []
        # NOTE(review): ``self.key`` looks like the table's lookup key rather
        # than its primary key; confirm this check can ever report a missing
        # "Primary Key" (``self.primary_key`` may have been intended).
        if self.key is NoneType:
            missing_attributes.append("Primary Key")
        if missing_attributes:
            raise ArgumentError(
                "HYBRID TABLE must have the following arguments: %s"
                % ", ".join(missing_attributes)
            )
        # The parent's return value is not propagated by this override.
        super()._validate_table()

    def __repr__(self) -> str:
        # Renders name, metadata, every column and the ``schema`` attribute.
        return "HybridTable(%s)" % ", ".join(
            [repr(self.name)]
            + [repr(self.metadata)]
            + [repr(x) for x in self.columns]
            + [f"{k}={repr(getattr(self, k))}" for k in ["schema"]]
        )
4 changes: 4 additions & 0 deletions tests/__snapshots__/test_orm.ambr
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
# serializer version: 1
# name: test_orm_one_to_many_relationship_with_hybrid_table
ProgrammingError('(snowflake.connector.errors.ProgrammingError) 200009 (22000): Foreign key constraint "SYS_INDEX_HB_TBL_ADDRESS_FOREIGN_KEY_USER_ID_HB_TBL_USER_ID" was violated.')
# ---
2 changes: 2 additions & 0 deletions tests/custom_tables/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
#
# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved.
Loading
Loading