Revert "feat: don't force db connect if using serverless" (#3784)
izeigerman authored Feb 4, 2025
1 parent d36c2d4 commit c24ee34
Showing 6 changed files with 49 additions and 85 deletions.
10 changes: 4 additions & 6 deletions docs/integrations/engines/databricks.md
@@ -14,9 +14,9 @@ SQLMesh connects to Databricks with the [Databricks SQL Connector](https://docs.

The SQL Connector is bundled with SQLMesh and automatically installed when you include the `databricks` extra in the command `pip install "sqlmesh[databricks]"`.

The SQL Connector has all the functionality needed for SQLMesh to execute SQL models on Databricks and Python models that do not return PySpark DataFrames.
The SQL Connector has all the functionality needed for SQLMesh to execute SQL models on Databricks and Python models locally (the default SQLMesh approach).

If you have Python models returning PySpark DataFrames, check out the [Databricks Connect](#databricks-connect-1) section.
The SQL Connector does not support Databricks Serverless Compute. If you require Serverless Compute then you must use the Databricks Connect library.

### Databricks Connect

@@ -229,9 +229,7 @@ If you want Databricks to process PySpark DataFrames in SQLMesh Python models, t

SQLMesh **DOES NOT** include/bundle the Databricks Connect library. You must [install the version of Databricks Connect](https://docs.databricks.com/en/dev-tools/databricks-connect/python/install.html) that matches the Databricks Runtime used in your Databricks cluster.

If SQLMesh detects that you have Databricks Connect installed, then it will automatically configure the connection and use it for all Python models that return a Pandas or PySpark DataFrame.

To have databricks-connect installed but ignored by SQLMesh, set `disable_databricks_connect` to `true` in the connection configuration.
SQLMesh's Databricks Connect implementation supports Databricks Runtime 13.0 or higher. If SQLMesh detects that you have Databricks Connect installed, then it will use it for all Python models (both Pandas and PySpark DataFrames).

Databricks Connect can execute SQL and DataFrame operations on different clusters by setting the SQLMesh `databricks_connect_*` connection options. For example, these options could configure SQLMesh to run SQL on a [Databricks SQL Warehouse](https://docs.databricks.com/sql/admin/create-sql-warehouse.html) while still routing DataFrame operations to a normal Databricks Cluster.

@@ -261,7 +259,7 @@ The only relevant SQLMesh configuration parameter is the optional `catalog` para
| `databricks_connect_server_hostname` | Databricks Connect Only: Databricks Connect server hostname. Uses `server_hostname` if not set. | string | N |
| `databricks_connect_access_token` | Databricks Connect Only: Databricks Connect access token. Uses `access_token` if not set. | string | N |
| `databricks_connect_cluster_id` | Databricks Connect Only: Databricks Connect cluster ID. Uses `http_path` if not set. Cannot be a Databricks SQL Warehouse. | string | N |
| `databricks_connect_use_serverless` | Databricks Connect Only: Use a serverless cluster for Databricks Connect instead of `databricks_connect_cluster_id`. | bool | N |
| `databricks_connect_use_serverless` | Databricks Connect Only: Use a serverless cluster for Databricks Connect. If using serverless then SQL connector is disabled since Serverless is not supported for SQL Connector | bool | N |
| `force_databricks_connect` | When running locally, force the use of Databricks Connect for all model operations (so don't use SQL Connector for SQL models) | bool | N |
| `disable_databricks_connect` | When running locally, disable the use of Databricks Connect for all model operations (so use SQL Connector for all models) | bool | N |
| `disable_spark_session` | Do not use SparkSession if it is available (like when running in a notebook). | bool | N |
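The connection options in the table above can also be expressed through SQLMesh's Python configuration API. The sketch below is a minimal, illustrative example rather than a verbatim recipe: the `Config`, `GatewayConfig`, and `ModelDefaultsConfig` entry points are assumed from SQLMesh's standard Python config layout, and the hostname, warehouse path, and token values are placeholders.

```python
# Minimal sketch of a Databricks gateway in a Python config (config.py).
# Field names mirror the options table above; host/path/token values are placeholders.
from sqlmesh.core.config import Config, GatewayConfig, ModelDefaultsConfig
from sqlmesh.core.config.connection import DatabricksConnectionConfig

config = Config(
    model_defaults=ModelDefaultsConfig(dialect="databricks"),
    default_gateway="databricks",
    gateways={
        "databricks": GatewayConfig(
            connection=DatabricksConnectionConfig(
                server_hostname="example.cloud.databricks.com",
                http_path="/sql/1.0/warehouses/0123456789abcdef",
                access_token="dapi-placeholder",
                catalog="main",
                # With the behavior restored by this revert, enabling serverless
                # also forces Databricks Connect and disables the SQL Connector.
                databricks_connect_use_serverless=False,
            ),
        )
    },
)
```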
3 changes: 0 additions & 3 deletions sqlmesh/__init__.py
@@ -141,9 +141,6 @@ def configure_logging(
log_limit: int = c.DEFAULT_LOG_LIMIT,
log_file_dir: t.Optional[t.Union[str, Path]] = None,
) -> None:
# Remove noisy grpc logs that are not useful for users
os.environ["GRPC_VERBOSITY"] = os.environ.get("GRPC_VERBOSITY", "NONE")

logger = logging.getLogger()
debug = force_debug or debug_mode_enabled()

21 changes: 9 additions & 12 deletions sqlmesh/core/config/connection.py
@@ -623,12 +623,6 @@ class DatabricksConnectionConfig(ConnectionConfig):

@model_validator(mode="before")
def _databricks_connect_validator(cls, data: t.Any) -> t.Any:
# SQLQueryContextLogger will output any error SQL queries even if they are in a try/except block.
# Disabling this allows SQLMesh to determine what should be shown to the user.
# Ex: We describe a table to see if it exists and therefore that execution can fail but we don't need to show
# the user since it is expected if the table doesn't exist. Without this change the user would see the error.
logging.getLogger("SQLQueryContextLogger").setLevel(logging.CRITICAL)

if not isinstance(data, dict):
return data

@@ -647,6 +641,10 @@ def _databricks_connect_validator(cls, data: t.Any) -> t.Any:
data.get("auth_type"),
)

if databricks_connect_use_serverless:
data["force_databricks_connect"] = True
data["disable_databricks_connect"] = False

if (not server_hostname or not http_path or not access_token) and (
not databricks_connect_use_serverless and not auth_type
):
@@ -668,12 +666,11 @@ def _databricks_connect_validator(cls, data: t.Any) -> t.Any:
data["databricks_connect_access_token"] = access_token
if not data.get("databricks_connect_server_hostname"):
data["databricks_connect_server_hostname"] = f"https://{server_hostname}"
if not databricks_connect_use_serverless and not data.get(
"databricks_connect_cluster_id"
):
if t.TYPE_CHECKING:
assert http_path is not None
data["databricks_connect_cluster_id"] = http_path.split("/")[-1]
if not databricks_connect_use_serverless:
if not data.get("databricks_connect_cluster_id"):
if t.TYPE_CHECKING:
assert http_path is not None
data["databricks_connect_cluster_id"] = http_path.split("/")[-1]

if auth_type:
from databricks.sql.auth.auth import AuthType
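The restored validator logic above has two notable fallbacks: setting `databricks_connect_use_serverless` forces `force_databricks_connect` on (and clears `disable_databricks_connect`), while otherwise a missing `databricks_connect_cluster_id` is derived from the last path segment of `http_path`. The standalone sketch below (a hypothetical helper, not an excerpt from the module) isolates just that behavior:

```python
# Hypothetical standalone helper mirroring the restored defaults:
# serverless forces Databricks Connect on; otherwise the cluster ID
# falls back to the last segment of `http_path`.
def apply_connect_defaults(data: dict) -> dict:
    if data.get("databricks_connect_use_serverless"):
        data["force_databricks_connect"] = True
        data["disable_databricks_connect"] = False
    elif not data.get("databricks_connect_cluster_id") and data.get("http_path"):
        # e.g. "sql/protocolv1/o/123/0123-456789-abcdefgh" -> "0123-456789-abcdefgh"
        data["databricks_connect_cluster_id"] = data["http_path"].split("/")[-1]
    return data


assert (
    apply_connect_defaults({"http_path": "sql/protocolv1/o/123/0123-456789-abcdefgh"})[
        "databricks_connect_cluster_id"
    ]
    == "0123-456789-abcdefgh"
)
```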
10 changes: 2 additions & 8 deletions sqlmesh/core/engine_adapter/base.py
@@ -43,11 +43,7 @@
from sqlmesh.utils import columns_to_types_all_known, random_id
from sqlmesh.utils.connection_pool import create_connection_pool
from sqlmesh.utils.date import TimeLike, make_inclusive, to_time_column
from sqlmesh.utils.errors import (
SQLMeshError,
UnsupportedCatalogOperationError,
MissingDefaultCatalogError,
)
from sqlmesh.utils.errors import SQLMeshError, UnsupportedCatalogOperationError
from sqlmesh.utils.pandas import columns_to_types_from_df

if t.TYPE_CHECKING:
@@ -190,9 +186,7 @@ def default_catalog(self) -> t.Optional[str]:
return None
default_catalog = self._default_catalog or self.get_current_catalog()
if not default_catalog:
raise MissingDefaultCatalogError(
"Could not determine a default catalog despite it being supported."
)
raise SQLMeshError("Could not determine a default catalog despite it being supported.")
return default_catalog

@property
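One side effect of the base-adapter change above is that the dedicated `MissingDefaultCatalogError` goes away, so callers that previously caught it must catch the generic `SQLMeshError` instead. A brief, hypothetical handling sketch (the helper name and the fallback-to-`None` policy are illustrative, not part of the codebase):

```python
# Hypothetical caller-side handling after the revert: only the generic
# SQLMeshError is raised when no default catalog can be determined.
from sqlmesh.utils.errors import SQLMeshError


def resolve_default_catalog(adapter):
    """Return the adapter's default catalog, or None if it cannot be determined."""
    try:
        return adapter.default_catalog
    except SQLMeshError:
        # Fall back, e.g. when the engine cannot report a current catalog.
        return None
```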
86 changes: 34 additions & 52 deletions sqlmesh/core/engine_adapter/databricks.py
@@ -1,6 +1,7 @@
from __future__ import annotations

import logging
import os
import typing as t

import pandas as pd
@@ -16,7 +17,7 @@
from sqlmesh.core.engine_adapter.spark import SparkEngineAdapter
from sqlmesh.core.node import IntervalUnit
from sqlmesh.core.schema_diff import SchemaDiffer
from sqlmesh.utils.errors import SQLMeshError, MissingDefaultCatalogError
from sqlmesh.utils.errors import SQLMeshError

if t.TYPE_CHECKING:
from sqlmesh.core._typing import SchemaName, TableName
@@ -91,6 +92,17 @@ def _use_spark_session(self) -> bool:
)
)

@property
def use_serverless(self) -> bool:
from sqlmesh import RuntimeEnv
from sqlmesh.utils import str_to_bool

if not self._use_spark_session:
return False
return (
RuntimeEnv.get().is_databricks and str_to_bool(os.environ.get("IS_SERVERLESS", "False"))
) or bool(self._extra_config["databricks_connect_use_serverless"])

@property
def is_spark_session_cursor(self) -> bool:
from sqlmesh.engines.spark.db_api.spark_session import SparkSessionCursor
@@ -112,17 +124,12 @@ def spark(self) -> PySparkSession:
from databricks.connect import DatabricksSession

if self._spark is None:
connect_kwargs = dict(
host=self._extra_config["databricks_connect_server_hostname"],
token=self._extra_config["databricks_connect_access_token"],
)
if "databricks_connect_use_serverless" in self._extra_config:
connect_kwargs["serverless"] = True
else:
connect_kwargs["cluster_id"] = self._extra_config["databricks_connect_cluster_id"]

self._spark = (
DatabricksSession.builder.remote(**connect_kwargs)
DatabricksSession.builder.remote(
host=self._extra_config["databricks_connect_server_hostname"],
token=self._extra_config["databricks_connect_access_token"],
cluster_id=self._extra_config["databricks_connect_cluster_id"],
)
.userAgent("sqlmesh")
.getOrCreate()
)
@@ -150,8 +157,14 @@ def _df_to_source_queries(

def query_factory() -> Query:
temp_table = self._get_temp_table(target_table or "spark", table_only=True)
df.createOrReplaceTempView(temp_table.sql(dialect=self.dialect))
self._connection_pool.set_attribute("requires_spark_session_temp_objects", True)
if self.use_serverless:
# Global temp views are not supported on Databricks Serverless
# This also means we can't mix Python SQL Connection and DB Connect since they wouldn't
# share the same temp objects.
df.createOrReplaceTempView(temp_table.sql(dialect=self.dialect)) # type: ignore
else:
df.createOrReplaceGlobalTempView(temp_table.sql(dialect=self.dialect)) # type: ignore
temp_table.set("db", "global_temp")
return exp.select(*self._casted_columns(columns_to_types)).from_(temp_table)

if self._use_spark_session:
@@ -186,50 +199,28 @@ def fetchdf(
return df.toPandas()
return df

def _execute(
self,
sql: str,
**kwargs: t.Any,
) -> None:
if self._connection_pool.get_attribute("requires_spark_session_temp_objects"):
self._fetch_native_df(sql)
else:
super()._execute(sql, **kwargs)

def _end_session(self) -> None:
"""End the existing session."""
self._connection_pool.set_attribute("requires_spark_session_temp_objects", False)

def get_current_catalog(self) -> t.Optional[str]:
pyspark_catalog = None
sql_connector_catalog = None
# Update the Dataframe API if we have a spark session
if self._use_spark_session:
from py4j.protocol import Py4JError
from pyspark.errors.exceptions.connect import SparkConnectGrpcException

try:
# Note: Spark 3.4+ Only API
pyspark_catalog = self.spark.catalog.currentCatalog()
return self.spark.catalog.currentCatalog()
except (Py4JError, SparkConnectGrpcException):
pass
if not self.is_spark_session_cursor:
result = self.fetchone(exp.select(self.CURRENT_CATALOG_EXPRESSION))
sql_connector_catalog = result[0] if result else None
if (
self._use_spark_session
and not self.is_spark_session_cursor
and pyspark_catalog != sql_connector_catalog
):
raise SQLMeshError(
f"Current catalog mismatch between Databricks SQL Connector and Databricks-Connect: `{sql_connector_catalog}` != `{pyspark_catalog}`. Set `catalog` connection property to make them the same."
)
return pyspark_catalog or sql_connector_catalog
result = self.fetchone(exp.select(self.CURRENT_CATALOG_EXPRESSION))
if result:
return result[0]
return None

def set_current_catalog(self, catalog_name: str) -> None:
# Since Databricks splits commands across the Dataframe API and the SQL Connector
# (depending if databricks-connect is installed and a Dataframe is used) we need to ensure both
# are set to the same catalog since they maintain their default catalog separately
# are set to the same catalog since they maintain their default catalog seperately
self.execute(exp.Use(this=exp.to_identifier(catalog_name), kind="CATALOG"))
# Update the Dataframe API is we have a spark session
if self._use_spark_session:
from py4j.protocol import Py4JError
from pyspark.errors.exceptions.connect import SparkConnectGrpcException
@@ -266,15 +257,6 @@ def clone_table(
def wap_supported(self, table_name: TableName) -> bool:
return False

@property
def default_catalog(self) -> t.Optional[str]:
try:
return super().default_catalog
except MissingDefaultCatalogError as e:
raise MissingDefaultCatalogError(
"Could not determine default catalog. Define the connection property `catalog` since it can't be inferred from your connection. See SQLMesh Databricks documentation for details"
) from e

def _build_table_properties_exp(
self,
catalog_name: t.Optional[str] = None,
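The `_df_to_source_queries` hunk earlier in this file's diff is where the serverless-specific behavior lived: session-scoped temp views were used because global temp views are not supported on Databricks Serverless, while the restored code registers a global temp view under the reserved `global_temp` database. The plain PySpark sketch below illustrates the difference between the two view styles (the local session and view names are arbitrary demo values):

```python
# Plain PySpark demo of the two temp-view styles touched by this revert.
# Session-scoped views are queried by bare name; global temp views must be
# qualified with the reserved `global_temp` database.
from pyspark.sql import SparkSession

spark = SparkSession.builder.master("local[1]").appName("temp-view-demo").getOrCreate()
df = spark.createDataFrame([(1, "a"), (2, "b")], ["id", "value"])

df.createOrReplaceTempView("demo_view")            # visible only to this session
spark.sql("SELECT * FROM demo_view").show()

df.createOrReplaceGlobalTempView("demo_view")      # shared across sessions of this app
spark.sql("SELECT * FROM global_temp.demo_view").show()

spark.stop()
```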
4 changes: 0 additions & 4 deletions sqlmesh/utils/errors.py
@@ -159,10 +159,6 @@ class PythonModelEvalError(SQLMeshError):
pass


class MissingDefaultCatalogError(SQLMeshError):
pass


def raise_config_error(
msg: str,
location: t.Optional[str | Path] = None,
