From 41d78237477e2cbd23c7c7b99f0a22e2e9d2a7a9 Mon Sep 17 00:00:00 2001
From: eakmanrq <6326532+eakmanrq@users.noreply.github.com>
Date: Thu, 6 Feb 2025 14:09:15 -0800
Subject: [PATCH] fix: db properly support `with_log_level`

---
 sqlmesh/core/engine_adapter/base.py       |  8 ++++++++
 sqlmesh/core/engine_adapter/databricks.py | 33 +++++++++++++++++++++++----------
 2 files changed, 31 insertions(+), 10 deletions(-)

diff --git a/sqlmesh/core/engine_adapter/base.py b/sqlmesh/core/engine_adapter/base.py
index ad48eb80f..7cd133d5e 100644
--- a/sqlmesh/core/engine_adapter/base.py
+++ b/sqlmesh/core/engine_adapter/base.py
@@ -119,6 +119,7 @@ def __init__(
         register_comments: bool = True,
         pre_ping: bool = False,
         pretty_sql: bool = False,
+        skip_connection_post_init: bool = False,
         **kwargs: t.Any,
     ):
         self.dialect = dialect.lower() or self.DIALECT
@@ -132,6 +133,11 @@ def __init__(
         self._register_comments = register_comments
         self._pre_ping = pre_ping
         self._pretty_sql = pretty_sql
+        if not skip_connection_post_init:
+            self._connection_post_init()
+
+    def _connection_post_init(self) -> None:
+        pass
 
     def with_log_level(self, level: int) -> EngineAdapter:
         adapter = self.__class__(
@@ -141,10 +147,12 @@ def with_log_level(self, level: int) -> EngineAdapter:
             default_catalog=self._default_catalog,
             execute_log_level=level,
             register_comments=self._register_comments,
+            skip_connection_post_init=True,
             **self._extra_config,
         )
 
         adapter._connection_pool = self._connection_pool
+        adapter._connection_post_init()
         return adapter
 
diff --git a/sqlmesh/core/engine_adapter/databricks.py b/sqlmesh/core/engine_adapter/databricks.py
index 031cbcd6c..64fa32cd1 100644
--- a/sqlmesh/core/engine_adapter/databricks.py
+++ b/sqlmesh/core/engine_adapter/databricks.py
@@ -48,9 +48,13 @@ class DatabricksEngineAdapter(SparkEngineAdapter):
         },
     )
 
-    def __init__(self, *args: t.Any, **kwargs: t.Any):
-        super().__init__(*args, **kwargs)
+    def __init__(self, *args: t.Any, **kwargs: t.Any) -> None:
+        super().__init__(*args, **{**kwargs, "skip_connection_post_init": True})
         self._set_spark_engine_adapter_if_needed()
+        if not kwargs.get("skip_connection_post_init"):
+            self._connection_post_init()
+
+    def _connection_post_init(self) -> None:
         # Set the default catalog for both connections to make sure they are aligned
         self.set_current_catalog(self.default_catalog)  # type: ignore
 
@@ -121,7 +125,8 @@ def _set_spark_engine_adapter_if_needed(self) -> None:
             DatabricksSession.builder.remote(**connect_kwargs).userAgent("sqlmesh").getOrCreate()
         )
         self._spark_engine_adapter = SparkEngineAdapter(
-            partial(connection, spark=spark, catalog=catalog)
+            partial(connection, spark=spark, catalog=catalog),
+            default_catalog=catalog,
         )
 
     @property
@@ -181,7 +186,7 @@ def _fetch_native_df(
         """Fetches a DataFrame that can be either Pandas or PySpark from the cursor"""
         if self.is_spark_session_connection:
             return super()._fetch_native_df(query, quote_identifiers=quote_identifiers)
-        if self._use_spark_session:
+        if self._spark_engine_adapter:
             return self._spark_engine_adapter._fetch_native_df(  # type: ignore
                 query, quote_identifiers=quote_identifiers
             )
@@ -211,6 +216,8 @@ def get_current_catalog(self) -> t.Optional[str]:
                 pyspark_catalog = self._spark_engine_adapter.get_current_catalog()
             except (Py4JError, SparkConnectGrpcException):
                 pass
+        elif self.is_spark_session_connection:
+            pyspark_catalog = self.connection.spark.catalog.currentCatalog()
         if not self.is_spark_session_connection:
             result = self.fetchone(exp.select(self.CURRENT_CATALOG_EXPRESSION))
             sql_connector_catalog = result[0] if result else None
@@ -221,20 +228,26 @@ def get_current_catalog(self) -> t.Optional[str]:
         return pyspark_catalog or sql_connector_catalog
 
     def set_current_catalog(self, catalog_name: str) -> None:
-        # Since Databricks splits commands across the Dataframe API and the SQL Connector
-        # (depending if databricks-connect is installed and a Dataframe is used) we need to ensure both
-        # are set to the same catalog since they maintain their default catalog separately
-        self.execute(exp.Use(this=exp.to_identifier(catalog_name), kind="CATALOG"))
-        if self._use_spark_session:
+        def _set_spark_session_current_catalog(spark: PySparkSession) -> None:
             from py4j.protocol import Py4JError
             from pyspark.errors.exceptions.connect import SparkConnectGrpcException
 
             try:
                 # Note: Spark 3.4+ Only API
-                self._spark_engine_adapter.set_current_catalog(catalog_name)  # type: ignore
+                spark.catalog.setCurrentCatalog(catalog_name)
             except (Py4JError, SparkConnectGrpcException):
                 pass
 
+        # Since Databricks splits commands across the Dataframe API and the SQL Connector
+        # (depending if databricks-connect is installed and a Dataframe is used) we need to ensure both
+        # are set to the same catalog since they maintain their default catalog separately
+        self.execute(exp.Use(this=exp.to_identifier(catalog_name), kind="CATALOG"))
+        if self.is_spark_session_connection:
+            _set_spark_session_current_catalog(self.connection.spark)
+
+        if self._spark_engine_adapter:
+            _set_spark_session_current_catalog(self._spark_engine_adapter.spark)
+
     def _get_data_objects(
         self, schema_name: SchemaName, object_names: t.Optional[t.Set[str]] = None
     ) -> t.List[DataObject]:
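
The core of this change is a deferred post-init hook: `EngineAdapter.__init__` gains a `skip_connection_post_init` flag, `with_log_level` constructs its clone with that flag set, hands the clone the parent's connection pool, and only then calls `_connection_post_init()`. The Databricks adapter overrides the hook to align the SQL connector and the Spark session on the same catalog, so the clone no longer re-runs that connection setup mid-construction, before its connection pool is wired in. Below is a minimal, runnable sketch of the lifecycle; `Adapter`, `CatalogAligningAdapter`, and `connection_factory` are simplified stand-ins for illustration, not the real sqlmesh API.

# Minimal sketch of the deferred post-init lifecycle introduced by this patch
# (stand-in classes, not the real sqlmesh EngineAdapter).
from __future__ import annotations

import logging
import typing as t


class Adapter:
    def __init__(
        self,
        connection_factory: t.Callable[[], t.Any],
        execute_log_level: int = logging.DEBUG,
        skip_connection_post_init: bool = False,
        **kwargs: t.Any,
    ) -> None:
        self._connection_factory = connection_factory
        self._execute_log_level = execute_log_level
        self._extra_config = kwargs
        # Run connection-dependent setup only when not explicitly deferred.
        if not skip_connection_post_init:
            self._connection_post_init()

    def _connection_post_init(self) -> None:
        # No-op in the base class; engine-specific subclasses override this
        # with setup that must run against a live connection.
        pass

    def with_log_level(self, level: int) -> Adapter:
        # Build the clone without re-running connection setup ...
        adapter = self.__class__(
            self._connection_factory,
            execute_log_level=level,
            skip_connection_post_init=True,
            **self._extra_config,
        )
        # ... then run the deferred setup exactly once for the clone.
        adapter._connection_post_init()
        return adapter


class CatalogAligningAdapter(Adapter):
    # Stand-in for the Databricks override: align every connection on the
    # same default catalog as soon as the connection is usable.
    def _connection_post_init(self) -> None:
        print("USE CATALOG my_default")  # placeholder for set_current_catalog(...)


adapter = CatalogAligningAdapter(lambda: object())  # prints once: setup on construction
quiet = adapter.with_log_level(logging.WARNING)     # prints once: setup deferred, then run

In the patch itself, `DatabricksEngineAdapter.__init__` always forwards `skip_connection_post_init=True` to the superclass and invokes the hook only when its own caller did not opt out, so the catalog alignment runs exactly once whether the adapter is constructed directly or cloned via `with_log_level`.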