fix: db properly support with_log_level
eakmanrq committed Feb 6, 2025
1 parent 7d48f38 commit 41d7823
Showing 2 changed files with 31 additions and 10 deletions.
8 changes: 8 additions & 0 deletions sqlmesh/core/engine_adapter/base.py
@@ -119,6 +119,7 @@ def __init__(
         register_comments: bool = True,
         pre_ping: bool = False,
         pretty_sql: bool = False,
+        skip_connection_post_init: bool = False,
         **kwargs: t.Any,
     ):
         self.dialect = dialect.lower() or self.DIALECT
@@ -132,6 +133,11 @@ def __init__(
         self._register_comments = register_comments
         self._pre_ping = pre_ping
         self._pretty_sql = pretty_sql
+        if not skip_connection_post_init:
+            self._connection_post_init()
+
+    def _connection_post_init(self) -> None:
+        pass
 
     def with_log_level(self, level: int) -> EngineAdapter:
         adapter = self.__class__(
@@ -141,10 +147,12 @@ def with_log_level(self, level: int) -> EngineAdapter:
             default_catalog=self._default_catalog,
             execute_log_level=level,
             register_comments=self._register_comments,
+            skip_connection_post_init=True,
             **self._extra_config,
         )
 
         adapter._connection_pool = self._connection_pool
+        adapter._connection_post_init()
 
         return adapter
 
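The base-class change above introduces a deferred post-init hook: construction can skip _connection_post_init(), and with_log_level clones the adapter with skip_connection_post_init=True, shares the existing connection pool, and only then runs the hook, so per-connection setup happens against the shared connection rather than a throwaway one. A minimal sketch of that pattern, with hypothetical names (BaseAdapter, a dict-based _connection_pool) standing in for the real sqlmesh classes:

import logging
import typing as t


class BaseAdapter:
    # Hypothetical stand-in for sqlmesh's EngineAdapter; names are illustrative only.
    def __init__(
        self,
        execute_log_level: int = logging.DEBUG,
        skip_connection_post_init: bool = False,
        **kwargs: t.Any,
    ) -> None:
        self._execute_log_level = execute_log_level
        self._extra_config = kwargs
        self._connection_pool: t.Dict[str, t.Any] = {}
        if not skip_connection_post_init:
            self._connection_post_init()

    def _connection_post_init(self) -> None:
        # Subclasses put per-connection setup here (e.g. aligning the current catalog).
        pass

    def with_log_level(self, level: int) -> "BaseAdapter":
        # Build the clone without running post-init against a fresh connection ...
        adapter = self.__class__(
            execute_log_level=level,
            skip_connection_post_init=True,
            **self._extra_config,
        )
        # ... then share the existing pool and run post-init against it.
        adapter._connection_pool = self._connection_pool
        adapter._connection_post_init()
        return adapter


clone = BaseAdapter().with_log_level(logging.INFO)
assert clone._execute_log_level == logging.INFO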
33 changes: 23 additions & 10 deletions sqlmesh/core/engine_adapter/databricks.py
@@ -48,9 +48,13 @@ class DatabricksEngineAdapter(SparkEngineAdapter):
         },
     )
 
-    def __init__(self, *args: t.Any, **kwargs: t.Any):
-        super().__init__(*args, **kwargs)
+    def __init__(self, *args: t.Any, **kwargs: t.Any) -> None:
+        super().__init__(*args, **{**kwargs, "skip_connection_post_init": True})
         self._set_spark_engine_adapter_if_needed()
+        if not kwargs.get("skip_connection_post_init"):
+            self._connection_post_init()
+
+    def _connection_post_init(self) -> None:
         # Set the default catalog for both connections to make sure they are aligned
         self.set_current_catalog(self.default_catalog)  # type: ignore
 
@@ -121,7 +125,8 @@ def _set_spark_engine_adapter_if_needed(self) -> None:
                 DatabricksSession.builder.remote(**connect_kwargs).userAgent("sqlmesh").getOrCreate()
             )
             self._spark_engine_adapter = SparkEngineAdapter(
-                partial(connection, spark=spark, catalog=catalog)
+                partial(connection, spark=spark, catalog=catalog),
+                default_catalog=catalog,
             )
 
     @property
@@ -181,7 +186,7 @@ def _fetch_native_df(
         """Fetches a DataFrame that can be either Pandas or PySpark from the cursor"""
         if self.is_spark_session_connection:
             return super()._fetch_native_df(query, quote_identifiers=quote_identifiers)
-        if self._use_spark_session:
+        if self._spark_engine_adapter:
             return self._spark_engine_adapter._fetch_native_df(  # type: ignore
                 query, quote_identifiers=quote_identifiers
             )
@@ -211,6 +216,8 @@ def get_current_catalog(self) -> t.Optional[str]:
                 pyspark_catalog = self._spark_engine_adapter.get_current_catalog()
             except (Py4JError, SparkConnectGrpcException):
                 pass
+        elif self.is_spark_session_connection:
+            pyspark_catalog = self.connection.spark.catalog.currentCatalog()
         if not self.is_spark_session_connection:
             result = self.fetchone(exp.select(self.CURRENT_CATALOG_EXPRESSION))
             sql_connector_catalog = result[0] if result else None
@@ -221,20 +228,26 @@
         return pyspark_catalog or sql_connector_catalog
 
     def set_current_catalog(self, catalog_name: str) -> None:
-        # Since Databricks splits commands across the Dataframe API and the SQL Connector
-        # (depending if databricks-connect is installed and a Dataframe is used) we need to ensure both
-        # are set to the same catalog since they maintain their default catalog separately
-        self.execute(exp.Use(this=exp.to_identifier(catalog_name), kind="CATALOG"))
-        if self._use_spark_session:
+        def _set_spark_session_current_catalog(spark: PySparkSession) -> None:
             from py4j.protocol import Py4JError
             from pyspark.errors.exceptions.connect import SparkConnectGrpcException
 
             try:
                 # Note: Spark 3.4+ Only API
-                self._spark_engine_adapter.set_current_catalog(catalog_name)  # type: ignore
+                spark.catalog.setCurrentCatalog(catalog_name)
             except (Py4JError, SparkConnectGrpcException):
                 pass
+
+        # Since Databricks splits commands across the Dataframe API and the SQL Connector
+        # (depending if databricks-connect is installed and a Dataframe is used) we need to ensure both
+        # are set to the same catalog since they maintain their default catalog separately
+        self.execute(exp.Use(this=exp.to_identifier(catalog_name), kind="CATALOG"))
+        if self.is_spark_session_connection:
+            _set_spark_session_current_catalog(self.connection.spark)
+
+        if self._spark_engine_adapter:
+            _set_spark_session_current_catalog(self._spark_engine_adapter.spark)
 
     def _get_data_objects(
         self, schema_name: SchemaName, object_names: t.Optional[t.Set[str]] = None
     ) -> t.List[DataObject]:
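In the Databricks adapter, the catalog alignment that used to run directly in __init__ now lives in _connection_post_init, and set_current_catalog routes every Spark-side update through a local helper so the SQL connector and any active Spark session end up on the same catalog. A rough sketch of that split, using fake stand-ins for the Spark session (nothing below is the actual pyspark or sqlmesh API):

import typing as t


class FakeSparkCatalog:
    # Stand-in for spark.catalog, for this sketch only.
    def __init__(self) -> None:
        self.current = "hive_metastore"

    def setCurrentCatalog(self, name: str) -> None:
        self.current = name


class FakeSparkSession:
    def __init__(self) -> None:
        self.catalog = FakeSparkCatalog()


def set_current_catalog(
    catalog_name: str,
    execute_sql: t.Callable[[str], None],
    spark_sessions: t.Sequence[FakeSparkSession],
) -> None:
    # Mirrors the diff's local _set_spark_session_current_catalog helper:
    # apply the change to one Spark session, swallowing transport errors.
    def _set_session_catalog(spark: FakeSparkSession) -> None:
        try:
            spark.catalog.setCurrentCatalog(catalog_name)
        except Exception:
            pass

    # The SQL connector and each Spark session track their current catalog
    # independently, so both sides are updated here.
    execute_sql(f"USE CATALOG {catalog_name}")
    for spark in spark_sessions:
        _set_session_catalog(spark)


session = FakeSparkSession()
set_current_catalog("analytics", execute_sql=print, spark_sessions=[session])
assert session.catalog.current == "analytics"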
