Skip to content

Commit 41d7823

Browse files
committed
fix: make the db engine adapter properly support with_log_level
1 parent 7d48f38 commit 41d7823

File tree

2 files changed

+31
-10
lines changed

2 files changed

+31
-10
lines changed

sqlmesh/core/engine_adapter/base.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,7 @@ def __init__(
119119
register_comments: bool = True,
120120
pre_ping: bool = False,
121121
pretty_sql: bool = False,
122+
skip_connection_post_init: bool = False,
122123
**kwargs: t.Any,
123124
):
124125
self.dialect = dialect.lower() or self.DIALECT
@@ -132,6 +133,11 @@ def __init__(
132133
self._register_comments = register_comments
133134
self._pre_ping = pre_ping
134135
self._pretty_sql = pretty_sql
136+
if not skip_connection_post_init:
137+
self._connection_post_init()
138+
139+
def _connection_post_init(self) -> None:
140+
pass
135141

136142
def with_log_level(self, level: int) -> EngineAdapter:
137143
adapter = self.__class__(
@@ -141,10 +147,12 @@ def with_log_level(self, level: int) -> EngineAdapter:
141147
default_catalog=self._default_catalog,
142148
execute_log_level=level,
143149
register_comments=self._register_comments,
150+
skip_connection_post_init=True,
144151
**self._extra_config,
145152
)
146153

147154
adapter._connection_pool = self._connection_pool
155+
adapter._connection_post_init()
148156

149157
return adapter
150158

sqlmesh/core/engine_adapter/databricks.py

Lines changed: 23 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -48,9 +48,13 @@ class DatabricksEngineAdapter(SparkEngineAdapter):
4848
},
4949
)
5050

51-
def __init__(self, *args: t.Any, **kwargs: t.Any):
52-
super().__init__(*args, **kwargs)
51+
def __init__(self, *args: t.Any, **kwargs: t.Any) -> None:
52+
super().__init__(*args, **{**kwargs, "skip_connection_post_init": True})
5353
self._set_spark_engine_adapter_if_needed()
54+
if not kwargs.get("skip_connection_post_init"):
55+
self._connection_post_init()
56+
57+
def _connection_post_init(self) -> None:
5458
# Set the default catalog for both connections to make sure they are aligned
5559
self.set_current_catalog(self.default_catalog) # type: ignore
5660

@@ -121,7 +125,8 @@ def _set_spark_engine_adapter_if_needed(self) -> None:
121125
DatabricksSession.builder.remote(**connect_kwargs).userAgent("sqlmesh").getOrCreate()
122126
)
123127
self._spark_engine_adapter = SparkEngineAdapter(
124-
partial(connection, spark=spark, catalog=catalog)
128+
partial(connection, spark=spark, catalog=catalog),
129+
default_catalog=catalog,
125130
)
126131

127132
@property
@@ -181,7 +186,7 @@ def _fetch_native_df(
181186
"""Fetches a DataFrame that can be either Pandas or PySpark from the cursor"""
182187
if self.is_spark_session_connection:
183188
return super()._fetch_native_df(query, quote_identifiers=quote_identifiers)
184-
if self._use_spark_session:
189+
if self._spark_engine_adapter:
185190
return self._spark_engine_adapter._fetch_native_df( # type: ignore
186191
query, quote_identifiers=quote_identifiers
187192
)
@@ -211,6 +216,8 @@ def get_current_catalog(self) -> t.Optional[str]:
211216
pyspark_catalog = self._spark_engine_adapter.get_current_catalog()
212217
except (Py4JError, SparkConnectGrpcException):
213218
pass
219+
elif self.is_spark_session_connection:
220+
pyspark_catalog = self.connection.spark.catalog.currentCatalog()
214221
if not self.is_spark_session_connection:
215222
result = self.fetchone(exp.select(self.CURRENT_CATALOG_EXPRESSION))
216223
sql_connector_catalog = result[0] if result else None
@@ -221,20 +228,26 @@ def get_current_catalog(self) -> t.Optional[str]:
221228
return pyspark_catalog or sql_connector_catalog
222229

223230
def set_current_catalog(self, catalog_name: str) -> None:
224-
# Since Databricks splits commands across the Dataframe API and the SQL Connector
225-
# (depending if databricks-connect is installed and a Dataframe is used) we need to ensure both
226-
# are set to the same catalog since they maintain their default catalog separately
227-
self.execute(exp.Use(this=exp.to_identifier(catalog_name), kind="CATALOG"))
228-
if self._use_spark_session:
231+
def _set_spark_session_current_catalog(spark: PySparkSession) -> None:
229232
from py4j.protocol import Py4JError
230233
from pyspark.errors.exceptions.connect import SparkConnectGrpcException
231234

232235
try:
233236
# Note: Spark 3.4+ Only API
234-
self._spark_engine_adapter.set_current_catalog(catalog_name) # type: ignore
237+
spark.catalog.setCurrentCatalog(catalog_name)
235238
except (Py4JError, SparkConnectGrpcException):
236239
pass
237240

241+
# Since Databricks splits commands across the Dataframe API and the SQL Connector
242+
# (depending if databricks-connect is installed and a Dataframe is used) we need to ensure both
243+
# are set to the same catalog since they maintain their default catalog separately
244+
self.execute(exp.Use(this=exp.to_identifier(catalog_name), kind="CATALOG"))
245+
if self.is_spark_session_connection:
246+
_set_spark_session_current_catalog(self.connection.spark)
247+
248+
if self._spark_engine_adapter:
249+
_set_spark_session_current_catalog(self._spark_engine_adapter.spark)
250+
238251
def _get_data_objects(
239252
self, schema_name: SchemaName, object_names: t.Optional[t.Set[str]] = None
240253
) -> t.List[DataObject]:

0 commit comments

Comments (0)