
Commit 591645c

feat: don't force db connect if using serverless (#3781)
1 parent 26026f3 commit 591645c

File tree

6 files changed: +85, -49 lines


docs/integrations/engines/databricks.md

Lines changed: 6 additions & 4 deletions

@@ -14,9 +14,9 @@ SQLMesh connects to Databricks with the [Databricks SQL Connector](https://docs.
 
 The SQL Connector is bundled with SQLMesh and automatically installed when you include the `databricks` extra in the command `pip install "sqlmesh[databricks]"`.
 
-The SQL Connector has all the functionality needed for SQLMesh to execute SQL models on Databricks and Python models locally (the default SQLMesh approach).
+The SQL Connector has all the functionality needed for SQLMesh to execute SQL models on Databricks and Python models that do not return PySpark DataFrames.
 
-The SQL Connector does not support Databricks Serverless Compute. If you require Serverless Compute then you must use the Databricks Connect library.
+If you have Python models returning PySpark DataFrames, check out the [Databricks Connect](#databricks-connect-1) section.
 
 ### Databricks Connect

@@ -229,7 +229,9 @@ If you want Databricks to process PySpark DataFrames in SQLMesh Python models, t
 
 SQLMesh **DOES NOT** include/bundle the Databricks Connect library. You must [install the version of Databricks Connect](https://docs.databricks.com/en/dev-tools/databricks-connect/python/install.html) that matches the Databricks Runtime used in your Databricks cluster.
 
-SQLMesh's Databricks Connect implementation supports Databricks Runtime 13.0 or higher. If SQLMesh detects that you have Databricks Connect installed, then it will use it for all Python models (both Pandas and PySpark DataFrames).
+If SQLMesh detects that you have Databricks Connect installed, then it will automatically configure the connection and use it for all Python models that return a Pandas or PySpark DataFrame.
+
+To have databricks-connect installed but ignored by SQLMesh, set `disable_databricks_connect` to `true` in the connection configuration.
 
 Databricks Connect can execute SQL and DataFrame operations on different clusters by setting the SQLMesh `databricks_connect_*` connection options. For example, these options could configure SQLMesh to run SQL on a [Databricks SQL Warehouse](https://docs.databricks.com/sql/admin/create-sql-warehouse.html) while still routing DataFrame operations to a normal Databricks Cluster.

@@ -259,7 +261,7 @@ The only relevant SQLMesh configuration parameter is the optional `catalog` para
 | `databricks_connect_server_hostname` | Databricks Connect Only: Databricks Connect server hostname. Uses `server_hostname` if not set. | string | N |
 | `databricks_connect_access_token` | Databricks Connect Only: Databricks Connect access token. Uses `access_token` if not set. | string | N |
 | `databricks_connect_cluster_id` | Databricks Connect Only: Databricks Connect cluster ID. Uses `http_path` if not set. Cannot be a Databricks SQL Warehouse. | string | N |
-| `databricks_connect_use_serverless` | Databricks Connect Only: Use a serverless cluster for Databricks Connect. If using serverless then SQL connector is disabled since Serverless is not supported for SQL Connector | bool | N |
+| `databricks_connect_use_serverless` | Databricks Connect Only: Use a serverless cluster for Databricks Connect instead of `databricks_connect_cluster_id`. | bool | N |
 | `force_databricks_connect` | When running locally, force the use of Databricks Connect for all model operations (so don't use SQL Connector for SQL models) | bool | N |
 | `disable_databricks_connect` | When running locally, disable the use of Databricks Connect for all model operations (so use SQL Connector for all models) | bool | N |
 | `disable_spark_session` | Do not use SparkSession if it is available (like when running in a notebook). | bool | N |
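For context, the options in the table above live under a gateway's `connection` block in the SQLMesh config file. A minimal sketch of a serverless setup (hostname, path, and token below are placeholders, not real workspace values):

```yaml
gateways:
  databricks:
    connection:
      type: databricks
      server_hostname: example-workspace.cloud.databricks.com   # placeholder host
      http_path: /sql/1.0/warehouses/abcdef1234567890            # placeholder SQL Warehouse path
      access_token: dapiXXXXXXXXXXXXXXXX                         # placeholder token
      # Route DataFrame operations to serverless compute instead of a
      # named cluster; no databricks_connect_cluster_id needed.
      databricks_connect_use_serverless: true
```

With `databricks_connect_use_serverless` set, this commit no longer forces all operations through Databricks Connect; SQL models can still run over the SQL Connector.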

sqlmesh/__init__.py

Lines changed: 3 additions & 0 deletions

@@ -141,6 +141,9 @@ def configure_logging(
     log_limit: int = c.DEFAULT_LOG_LIMIT,
     log_file_dir: t.Optional[t.Union[str, Path]] = None,
 ) -> None:
+    # Remove noisy grpc logs that are not useful for users
+    os.environ["GRPC_VERBOSITY"] = os.environ.get("GRPC_VERBOSITY", "NONE")
+
     logger = logging.getLogger()
     debug = force_debug or debug_mode_enabled()
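The added line uses a common idiom: default an environment variable only when the user has not already set it, so an explicit `GRPC_VERBOSITY` survives. A standalone sketch of the same pattern (the helper name is mine, not SQLMesh's):

```python
import os


def default_env(name: str, default: str) -> str:
    """Set `name` to `default` only if it is unset; return the effective value."""
    os.environ[name] = os.environ.get(name, default)
    return os.environ[name]


# Unset variable falls back to the quiet default...
os.environ.pop("GRPC_VERBOSITY", None)
print(default_env("GRPC_VERBOSITY", "NONE"))  # NONE

# ...but a value the user already exported is preserved.
os.environ["GRPC_VERBOSITY"] = "DEBUG"
print(default_env("GRPC_VERBOSITY", "NONE"))  # DEBUG
```

Writing the default back into `os.environ` (rather than just reading it) matters here because the gRPC library reads the variable itself in the child process.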

sqlmesh/core/config/connection.py

Lines changed: 12 additions & 9 deletions

@@ -623,6 +623,12 @@ class DatabricksConnectionConfig(ConnectionConfig):
 
     @model_validator(mode="before")
     def _databricks_connect_validator(cls, data: t.Any) -> t.Any:
+        # SQLQueryContextLogger will output any error SQL queries even if they are in a try/except block.
+        # Disabling this allows SQLMesh to determine what should be shown to the user.
+        # Ex: We describe a table to see if it exists and therefore that execution can fail but we don't need to show
+        # the user since it is expected if the table doesn't exist. Without this change the user would see the error.
+        logging.getLogger("SQLQueryContextLogger").setLevel(logging.CRITICAL)
+
         if not isinstance(data, dict):
             return data

@@ -641,10 +647,6 @@ def _databricks_connect_validator(cls, data: t.Any) -> t.Any:
             data.get("auth_type"),
         )
 
-        if databricks_connect_use_serverless:
-            data["force_databricks_connect"] = True
-            data["disable_databricks_connect"] = False
-
         if (not server_hostname or not http_path or not access_token) and (
             not databricks_connect_use_serverless and not auth_type
         ):

@@ -666,11 +668,12 @@ def _databricks_connect_validator(cls, data: t.Any) -> t.Any:
             data["databricks_connect_access_token"] = access_token
         if not data.get("databricks_connect_server_hostname"):
             data["databricks_connect_server_hostname"] = f"https://{server_hostname}"
-        if not databricks_connect_use_serverless:
-            if not data.get("databricks_connect_cluster_id"):
-                if t.TYPE_CHECKING:
-                    assert http_path is not None
-                data["databricks_connect_cluster_id"] = http_path.split("/")[-1]
+        if not databricks_connect_use_serverless and not data.get(
+            "databricks_connect_cluster_id"
+        ):
+            if t.TYPE_CHECKING:
+                assert http_path is not None
+            data["databricks_connect_cluster_id"] = http_path.split("/")[-1]
 
         if auth_type:
             from databricks.sql.auth.auth import AuthType
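The restructured validator keeps the same fallback behavior: when serverless is off and no explicit `databricks_connect_cluster_id` is given, the cluster ID is derived from the last path segment of `http_path`. That derivation can be sketched in isolation (the function name and sample paths are mine, not SQLMesh's):

```python
import typing as t


def derive_cluster_id(http_path: t.Optional[str], use_serverless: bool) -> t.Optional[str]:
    """Mirror of the validator's fallback: serverless connections need no
    cluster ID; otherwise the last segment of the HTTP path is used."""
    if use_serverless or not http_path:
        return None
    return http_path.split("/")[-1]


# Classic cluster path: the trailing segment is the cluster ID.
print(derive_cluster_id("sql/protocolv1/o/1234567890/0123-456789-abcdefgh", False))
# 0123-456789-abcdefgh

# Serverless skips the derivation entirely.
print(derive_cluster_id("sql/protocolv1/o/1234567890/0123-456789-abcdefgh", True))
# None
```

Flattening the two nested `if` statements into one condition is purely cosmetic; the removed `force_databricks_connect`/`disable_databricks_connect` overrides are the actual behavior change this commit's title describes.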

sqlmesh/core/engine_adapter/base.py

Lines changed: 8 additions & 2 deletions

@@ -43,7 +43,11 @@
 from sqlmesh.utils import columns_to_types_all_known, random_id
 from sqlmesh.utils.connection_pool import create_connection_pool
 from sqlmesh.utils.date import TimeLike, make_inclusive, to_time_column
-from sqlmesh.utils.errors import SQLMeshError, UnsupportedCatalogOperationError
+from sqlmesh.utils.errors import (
+    SQLMeshError,
+    UnsupportedCatalogOperationError,
+    MissingDefaultCatalogError,
+)
 from sqlmesh.utils.pandas import columns_to_types_from_df
 
 if t.TYPE_CHECKING:

@@ -186,7 +190,9 @@ def default_catalog(self) -> t.Optional[str]:
             return None
         default_catalog = self._default_catalog or self.get_current_catalog()
         if not default_catalog:
-            raise SQLMeshError("Could not determine a default catalog despite it being supported.")
+            raise MissingDefaultCatalogError(
+                "Could not determine a default catalog despite it being supported."
+            )
         return default_catalog
 
     @property

sqlmesh/core/engine_adapter/databricks.py

Lines changed: 52 additions & 34 deletions

@@ -1,7 +1,6 @@
 from __future__ import annotations
 
 import logging
-import os
 import typing as t
 
 import pandas as pd

@@ -17,7 +16,7 @@
 from sqlmesh.core.engine_adapter.spark import SparkEngineAdapter
 from sqlmesh.core.node import IntervalUnit
 from sqlmesh.core.schema_diff import SchemaDiffer
-from sqlmesh.utils.errors import SQLMeshError
+from sqlmesh.utils.errors import SQLMeshError, MissingDefaultCatalogError
 
 if t.TYPE_CHECKING:
     from sqlmesh.core._typing import SchemaName, TableName

@@ -92,17 +91,6 @@ def _use_spark_session(self) -> bool:
             )
         )
 
-    @property
-    def use_serverless(self) -> bool:
-        from sqlmesh import RuntimeEnv
-        from sqlmesh.utils import str_to_bool
-
-        if not self._use_spark_session:
-            return False
-        return (
-            RuntimeEnv.get().is_databricks and str_to_bool(os.environ.get("IS_SERVERLESS", "False"))
-        ) or bool(self._extra_config["databricks_connect_use_serverless"])
-
     @property
     def is_spark_session_cursor(self) -> bool:
         from sqlmesh.engines.spark.db_api.spark_session import SparkSessionCursor

@@ -124,12 +112,17 @@ def spark(self) -> PySparkSession:
         from databricks.connect import DatabricksSession
 
         if self._spark is None:
+            connect_kwargs = dict(
+                host=self._extra_config["databricks_connect_server_hostname"],
+                token=self._extra_config["databricks_connect_access_token"],
+            )
+            if "databricks_connect_use_serverless" in self._extra_config:
+                connect_kwargs["serverless"] = True
+            else:
+                connect_kwargs["cluster_id"] = self._extra_config["databricks_connect_cluster_id"]
+
             self._spark = (
-                DatabricksSession.builder.remote(
-                    host=self._extra_config["databricks_connect_server_hostname"],
-                    token=self._extra_config["databricks_connect_access_token"],
-                    cluster_id=self._extra_config["databricks_connect_cluster_id"],
-                )
+                DatabricksSession.builder.remote(**connect_kwargs)
                 .userAgent("sqlmesh")
                 .getOrCreate()
             )

@@ -157,14 +150,8 @@ def _df_to_source_queries(
 
         def query_factory() -> Query:
             temp_table = self._get_temp_table(target_table or "spark", table_only=True)
-            if self.use_serverless:
-                # Global temp views are not supported on Databricks Serverless
-                # This also means we can't mix Python SQL Connection and DB Connect since they wouldn't
-                # share the same temp objects.
-                df.createOrReplaceTempView(temp_table.sql(dialect=self.dialect))  # type: ignore
-            else:
-                df.createOrReplaceGlobalTempView(temp_table.sql(dialect=self.dialect))  # type: ignore
-                temp_table.set("db", "global_temp")
+            df.createOrReplaceTempView(temp_table.sql(dialect=self.dialect))
+            self._connection_pool.set_attribute("requires_spark_session_temp_objects", True)
             return exp.select(*self._casted_columns(columns_to_types)).from_(temp_table)
 
         if self._use_spark_session:

@@ -199,28 +186,50 @@ def fetchdf(
             return df.toPandas()
         return df
 
+    def _execute(
+        self,
+        sql: str,
+        **kwargs: t.Any,
+    ) -> None:
+        if self._connection_pool.get_attribute("requires_spark_session_temp_objects"):
+            self._fetch_native_df(sql)
+        else:
+            super()._execute(sql, **kwargs)
+
+    def _end_session(self) -> None:
+        """End the existing session."""
+        self._connection_pool.set_attribute("requires_spark_session_temp_objects", False)
+
     def get_current_catalog(self) -> t.Optional[str]:
-        # Update the Dataframe API if we have a spark session
+        pyspark_catalog = None
+        sql_connector_catalog = None
         if self._use_spark_session:
             from py4j.protocol import Py4JError
             from pyspark.errors.exceptions.connect import SparkConnectGrpcException
 
             try:
                 # Note: Spark 3.4+ Only API
-                return self.spark.catalog.currentCatalog()
+                pyspark_catalog = self.spark.catalog.currentCatalog()
             except (Py4JError, SparkConnectGrpcException):
                 pass
-        result = self.fetchone(exp.select(self.CURRENT_CATALOG_EXPRESSION))
-        if result:
-            return result[0]
-        return None
+        if not self.is_spark_session_cursor:
+            result = self.fetchone(exp.select(self.CURRENT_CATALOG_EXPRESSION))
+            sql_connector_catalog = result[0] if result else None
+        if (
+            self._use_spark_session
+            and not self.is_spark_session_cursor
+            and pyspark_catalog != sql_connector_catalog
+        ):
+            raise SQLMeshError(
+                f"Current catalog mismatch between Databricks SQL Connector and Databricks-Connect: `{sql_connector_catalog}` != `{pyspark_catalog}`. Set `catalog` connection property to make them the same."
+            )
+        return pyspark_catalog or sql_connector_catalog
 
     def set_current_catalog(self, catalog_name: str) -> None:
         # Since Databricks splits commands across the Dataframe API and the SQL Connector
         # (depending if databricks-connect is installed and a Dataframe is used) we need to ensure both
-        # are set to the same catalog since they maintain their default catalog seperately
+        # are set to the same catalog since they maintain their default catalog separately
         self.execute(exp.Use(this=exp.to_identifier(catalog_name), kind="CATALOG"))
-        # Update the Dataframe API is we have a spark session
         if self._use_spark_session:
             from py4j.protocol import Py4JError
             from pyspark.errors.exceptions.connect import SparkConnectGrpcException

@@ -257,6 +266,15 @@ def clone_table(
     def wap_supported(self, table_name: TableName) -> bool:
         return False
 
+    @property
+    def default_catalog(self) -> t.Optional[str]:
+        try:
+            return super().default_catalog
+        except MissingDefaultCatalogError as e:
+            raise MissingDefaultCatalogError(
+                "Could not determine default catalog. Define the connection property `catalog` since it can't be inferred from your connection. See SQLMesh Databricks documentation for details"
+            ) from e
+
     def _build_table_properties_exp(
         self,
         catalog_name: t.Optional[str] = None,
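The `spark` property now assembles the `DatabricksSession.builder.remote(**kwargs)` arguments conditionally: host and token always, then either `serverless=True` or a `cluster_id`, never both. The kwarg-assembly step can be exercised on its own (the helper name and plain-dict config are mine; the real code reads `self._extra_config` on the adapter):

```python
import typing as t


def build_connect_kwargs(extra_config: t.Dict[str, t.Any]) -> t.Dict[str, t.Any]:
    """Assemble remote-session kwargs the way the adapter does:
    host/token always, then serverless XOR cluster_id."""
    connect_kwargs: t.Dict[str, t.Any] = {
        "host": extra_config["databricks_connect_server_hostname"],
        "token": extra_config["databricks_connect_access_token"],
    }
    if "databricks_connect_use_serverless" in extra_config:
        connect_kwargs["serverless"] = True
    else:
        connect_kwargs["cluster_id"] = extra_config["databricks_connect_cluster_id"]
    return connect_kwargs


serverless_cfg = {
    "databricks_connect_server_hostname": "https://example.cloud.databricks.com",
    "databricks_connect_access_token": "dapi-placeholder",
    "databricks_connect_use_serverless": True,
}
print(build_connect_kwargs(serverless_cfg))
# {'host': 'https://example.cloud.databricks.com', 'token': 'dapi-placeholder', 'serverless': True}
```

Note the membership test mirrors the diff exactly: presence of the `databricks_connect_use_serverless` key (not its truthiness) selects the serverless branch.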

sqlmesh/utils/errors.py

Lines changed: 4 additions & 0 deletions

@@ -159,6 +159,10 @@ class PythonModelEvalError(SQLMeshError):
     pass
 
 
+class MissingDefaultCatalogError(SQLMeshError):
+    pass
+
+
 def raise_config_error(
     msg: str,
     location: t.Optional[str | Path] = None,
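Defining `MissingDefaultCatalogError` as a `SQLMeshError` subclass is what lets the Databricks adapter catch the base class's generic failure and re-raise it with engine-specific guidance, chaining the original via `raise ... from e`. A minimal sketch of the pattern (simplified stand-in functions; not the real adapter methods):

```python
class SQLMeshError(Exception):
    pass


class MissingDefaultCatalogError(SQLMeshError):
    pass


def base_default_catalog() -> str:
    # Stand-in for the base EngineAdapter failing to infer a catalog.
    raise MissingDefaultCatalogError("Could not determine a default catalog.")


def databricks_default_catalog() -> str:
    # Stand-in for the Databricks override: same exception type, a more
    # actionable message, original error preserved on __cause__.
    try:
        return base_default_catalog()
    except MissingDefaultCatalogError as e:
        raise MissingDefaultCatalogError(
            "Define the connection property `catalog`; it can't be inferred."
        ) from e


try:
    databricks_default_catalog()
except MissingDefaultCatalogError as e:
    print(isinstance(e.__cause__, MissingDefaultCatalogError))  # True
```

Because the new class still inherits from `SQLMeshError`, existing callers that catch the base class keep working unchanged; only the Databricks adapter opts into the narrower type.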

0 commit comments
