@@ -48,9 +48,13 @@ class DatabricksEngineAdapter(SparkEngineAdapter):
4848 },
4949 )
5050
def __init__(self, *args: t.Any, **kwargs: t.Any) -> None:
    """Initialize the adapter and align catalogs across its connections.

    The parent constructor is told to skip its connection post-init so that
    it only runs here, after ``_set_spark_engine_adapter_if_needed`` has had
    a chance to create the companion Spark engine adapter. Callers that pass
    ``skip_connection_post_init=True`` suppress the post-init entirely.
    """
    # Force the base class to defer post-init; we decide below whether to run it.
    super().__init__(*args, **{**kwargs, "skip_connection_post_init": True})
    self._set_spark_engine_adapter_if_needed()
    if not kwargs.get("skip_connection_post_init"):
        self._connection_post_init()
56+
def _connection_post_init(self) -> None:
    """Run one-time setup after the connection(s) are established.

    Databricks may route statements through both the SQL connector and a
    Spark session, each of which tracks its current catalog independently.
    """
    # Set the default catalog for both connections to make sure they are aligned
    self.set_current_catalog(self.default_catalog)  # type: ignore
5660
@@ -121,7 +125,8 @@ def _set_spark_engine_adapter_if_needed(self) -> None:
121125 DatabricksSession .builder .remote (** connect_kwargs ).userAgent ("sqlmesh" ).getOrCreate ()
122126 )
123127 self ._spark_engine_adapter = SparkEngineAdapter (
124- partial (connection , spark = spark , catalog = catalog )
128+ partial (connection , spark = spark , catalog = catalog ),
129+ default_catalog = catalog ,
125130 )
126131
127132 @property
@@ -181,7 +186,7 @@ def _fetch_native_df(
181186 """Fetches a DataFrame that can be either Pandas or PySpark from the cursor"""
182187 if self .is_spark_session_connection :
183188 return super ()._fetch_native_df (query , quote_identifiers = quote_identifiers )
184- if self ._use_spark_session :
189+ if self ._spark_engine_adapter :
185190 return self ._spark_engine_adapter ._fetch_native_df ( # type: ignore
186191 query , quote_identifiers = quote_identifiers
187192 )
@@ -211,6 +216,8 @@ def get_current_catalog(self) -> t.Optional[str]:
211216 pyspark_catalog = self ._spark_engine_adapter .get_current_catalog ()
212217 except (Py4JError , SparkConnectGrpcException ):
213218 pass
219+ elif self .is_spark_session_connection :
220+ pyspark_catalog = self .connection .spark .catalog .currentCatalog ()
214221 if not self .is_spark_session_connection :
215222 result = self .fetchone (exp .select (self .CURRENT_CATALOG_EXPRESSION ))
216223 sql_connector_catalog = result [0 ] if result else None
@@ -221,20 +228,26 @@ def get_current_catalog(self) -> t.Optional[str]:
221228 return pyspark_catalog or sql_connector_catalog
222229
def set_current_catalog(self, catalog_name: str) -> None:
    """Point every connection this adapter manages at *catalog_name*.

    Since Databricks splits commands across the Dataframe API and the SQL
    Connector (depending if databricks-connect is installed and a Dataframe
    is used) we need to ensure both are set to the same catalog since they
    maintain their default catalog separately.
    """

    def _set_spark_session_current_catalog(spark: PySparkSession) -> None:
        # Imported lazily: py4j/pyspark-connect are only present when a
        # Spark session connection or databricks-connect is in use.
        from py4j.protocol import Py4JError
        from pyspark.errors.exceptions.connect import SparkConnectGrpcException

        try:
            # Note: Spark 3.4+ Only API
            spark.catalog.setCurrentCatalog(catalog_name)
        except (Py4JError, SparkConnectGrpcException):
            # Best effort on the Spark side; the SQL connector's USE CATALOG
            # below/above remains authoritative if this API is unavailable.
            pass

    # Switch the SQL connector's catalog via a plain USE CATALOG statement.
    self.execute(exp.Use(this=exp.to_identifier(catalog_name), kind="CATALOG"))
    if self.is_spark_session_connection:
        _set_spark_session_current_catalog(self.connection.spark)

    if self._spark_engine_adapter:
        _set_spark_session_current_catalog(self._spark_engine_adapter.spark)
238251 def _get_data_objects (
239252 self , schema_name : SchemaName , object_names : t .Optional [t .Set [str ]] = None
240253 ) -> t .List [DataObject ]:
0 commit comments