diff --git a/kedro-datasets/RELEASE.md b/kedro-datasets/RELEASE.md
index 52ba9fe51..c788d0d03 100755
--- a/kedro-datasets/RELEASE.md
+++ b/kedro-datasets/RELEASE.md
@@ -14,6 +14,8 @@
 |----------------------|------------------------------------------------|-------------------------|
 | `plotly.HTMLDataset` | A dataset for saving a `plotly` figure as HTML | `kedro_datasets.plotly` |
 
+* Added `spark_mode` parameter to Spark-based datasets to control whether `databricks-connect` is used.
+
 ## Bug fixes and other changes
 * Refactored all datasets to set `fs_args` defaults in the same way as `load_args` and `save_args` and not have hardcoded values in the save methods.
 * Fixed bug related to loading/saving models from/to remote storage using `TensorFlowModelDataset`.
diff --git a/kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py b/kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py
index 677db0d56..d0ff464ff 100644
--- a/kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py
+++ b/kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py
@@ -226,6 +226,7 @@ def __init__(  # noqa: PLR0913
         dataframe_type: str = "spark",
         primary_key: str | list[str] | None = None,
         version: Version | None = None,
+        spark_mode: str | None = None,
         # the following parameters are used by project hooks
         # to create or update table properties
         schema: dict[str, Any] | None = None,
@@ -252,6 +253,8 @@ def __init__(  # noqa: PLR0913
                 Can be in the form of a list. Defaults to None.
             version: kedro.io.core.Version instance to load the data.
                 Defaults to None.
+            spark_mode: The mode used to initialize the Spark session. Can be 'spark'
+                or 'databricks-connect'. Defaults to None.
             schema: the schema of the table in JSON form.
                 Dataframes will be truncated to match the schema if provided.
                 Used by the hooks to create the table if the schema is provided
@@ -280,6 +283,7 @@
         )
 
         self._version = version
+        self.spark_mode = spark_mode
         self.metadata = metadata
 
         super().__init__(
@@ -303,7 +307,7 @@ def _load(self) -> DataFrame | pd.DataFrame:
         if self._version and self._version.load >= 0:
             try:
                 data = (
-                    _get_spark()
+                    _get_spark(self.spark_mode)
                     .read.format("delta")
                     .option("versionAsOf", self._version.load)
                     .table(self._table.full_table_location())
@@ -311,7 +315,7 @@
             except Exception as exc:
                 raise VersionNotFoundError(self._version.load) from exc
         else:
-            data = _get_spark().table(self._table.full_table_location())
+            data = _get_spark(self.spark_mode).table(self._table.full_table_location())
         if self._table.dataframe_type == "pandas":
             data = data.toPandas()
         return data
@@ -349,7 +353,9 @@ def _save_upsert(self, update_data: DataFrame) -> None:
             update_data (DataFrame): the Spark dataframe to upsert
         """
         if self._exists():
-            base_data = _get_spark().table(self._table.full_table_location())
+            base_data = _get_spark(self.spark_mode).table(
+                self._table.full_table_location()
+            )
             base_columns = base_data.columns
             update_columns = update_data.columns
 
@@ -372,11 +378,13 @@
                 )
 
             update_data.createOrReplaceTempView("update")
-            _get_spark().conf.set("fullTableAddress", self._table.full_table_location())
-            _get_spark().conf.set("whereExpr", where_expr)
+            _get_spark(self.spark_mode).conf.set(
+                "fullTableAddress", self._table.full_table_location()
+            )
+            _get_spark(self.spark_mode).conf.set("whereExpr", where_expr)
             upsert_sql = """MERGE INTO ${fullTableAddress} base USING update ON ${whereExpr}
                 WHEN MATCHED THEN UPDATE SET *
                 WHEN NOT MATCHED THEN INSERT *"""
-            _get_spark().sql(upsert_sql)
+            _get_spark(self.spark_mode).sql(upsert_sql)
         else:
             self._save_append(update_data)
@@ -399,13 +407,13 @@ def _save(self, data: DataFrame | pd.DataFrame) -> None:
         if schema:
             cols = schema.fieldNames()
             if self._table.dataframe_type == "pandas":
-                data = _get_spark().createDataFrame(
+                data = _get_spark(self.spark_mode).createDataFrame(
                     data.loc[:, cols], schema=self._table.schema()
                 )
             else:
                 data = data.select(*cols)
         elif self._table.dataframe_type == "pandas":
-            data = _get_spark().createDataFrame(data)
+            data = _get_spark(self.spark_mode).createDataFrame(data)
         if self._table.write_mode == "overwrite":
             self._save_overwrite(data)
         elif self._table.write_mode == "upsert":
@@ -440,7 +448,7 @@
         """
         if self._table.catalog:
             try:
-                _get_spark().sql(f"USE CATALOG `{self._table.catalog}`")
+                _get_spark(self.spark_mode).sql(f"USE CATALOG `{self._table.catalog}`")
             except (ParseException, AnalysisException) as exc:
                 logger.warning(
                     "catalog %s not found or unity not enabled. Error message: %s",
@@ -449,7 +457,7 @@
                 )
         try:
             return (
-                _get_spark()
+                _get_spark(self.spark_mode)
                 .sql(f"SHOW TABLES IN `{self._table.database}`")
                 .filter(f"tableName = '{self._table.table}'")
                 .count()
diff --git a/kedro-datasets/kedro_datasets/spark/spark_dataset.py b/kedro-datasets/kedro_datasets/spark/spark_dataset.py
index e077d6390..5ce947cbc 100644
--- a/kedro-datasets/kedro_datasets/spark/spark_dataset.py
+++ b/kedro-datasets/kedro_datasets/spark/spark_dataset.py
@@ -32,26 +32,46 @@
 logger = logging.getLogger(__name__)
 
 
-def _get_spark() -> Any:
+def _get_spark(spark_mode: str | None = None) -> Any:
     """
-    Returns the SparkSession. In case databricks-connect is available we use it for
-    extended configuration mechanisms and notebook compatibility,
-    otherwise we use classic pyspark.
+    Returns the appropriate Spark session for the given spark_mode.
+    Supported modes are 'spark' and 'databricks-connect'.
+    If spark_mode is None, 'databricks-connect' is used when it is installed, otherwise 'spark'.
+
+    To use Spark Connect, set the SPARK_REMOTE environment variable
+    and select the 'spark' mode.
+
+    Authentication should be configured through the corresponding environment variables.
+
+    Args:
+        spark_mode: The mode used to initialize the Spark session. Can be 'spark'
+            or 'databricks-connect'. Defaults to None.
+
+    Returns:
+        A Spark session appropriate to the selected mode.
     """
-    try:
-        # When using databricks-connect >= 13.0.0 (a.k.a databricks-connect-v2)
-        # the remote session is instantiated using the databricks module
-        # If the databricks-connect module is installed, we use a remote session
-        from databricks.connect import DatabricksSession
-
-        # We can't test this as there's no Databricks test env available
-        spark = DatabricksSession.builder.getOrCreate()  # pragma: no cover
+    # If spark_mode is not specified, try to infer it
+    if not spark_mode:
+        # Use databricks-connect if it is importable
+        try:
+            from databricks.connect import DatabricksSession  # noqa: F401
+
+            spark_mode = "databricks-connect"
+        except ImportError:
+            spark_mode = "spark"
-    except ImportError:
-        # For "normal" spark sessions that don't use databricks-connect
-        # we get spark normally
+    if spark_mode == "spark":
         spark = SparkSession.builder.getOrCreate()
+    elif spark_mode == "databricks-connect":
+        from databricks.connect import DatabricksSession
+
+        spark = DatabricksSession.builder.getOrCreate()
+
+    else:
+        raise ValueError(f"Invalid spark_mode: {spark_mode}")
+
     return spark
@@ -272,6 +292,7 @@ def __init__(  # noqa: PLR0913
         save_args: dict[str, Any] | None = None,
         version: Version | None = None,
         credentials: dict[str, Any] | None = None,
+        spark_mode: str | None = None,
         metadata: dict[str, Any] | None = None,
     ) -> None:
         """Creates a new instance of ``SparkDataset``.
@@ -304,14 +325,19 @@
                 ``key``, ``secret``, if ``filepath`` prefix is ``s3a://`` or ``s3n://``.
                 Optional keyword arguments passed to ``hdfs.client.InsecureClient``
                 if ``filepath`` prefix is ``hdfs://``. Ignored otherwise.
+            spark_mode: The mode used to initialize the Spark session. Can be 'spark'
+                or 'databricks-connect'. Defaults to None.
             metadata: Any arbitrary metadata.
                 This is ignored by Kedro, but may be consumed by users or external plugins.
+ """ + credentials = deepcopy(credentials) or {} fs_prefix, filepath = _split_filepath(filepath) path = PurePosixPath(filepath) exists_function = None glob_function = None + self.spark_mode = spark_mode self.metadata = metadata if ( @@ -349,7 +375,7 @@ def __init__( # noqa: PLR0913 elif filepath.startswith("/dbfs/"): # dbfs add prefix to Spark path by default # See https://github.com/kedro-org/kedro-plugins/issues/117 - dbutils = _get_dbutils(_get_spark()) + dbutils = _get_dbutils(_get_spark(self.spark_mode)) if dbutils: glob_function = partial(_dbfs_glob, dbutils=dbutils) exists_function = partial(_dbfs_exists, dbutils=dbutils) @@ -415,7 +441,7 @@ def _describe(self) -> dict[str, Any]: def _load(self) -> DataFrame: load_path = _strip_dbfs_prefix(self._fs_prefix + str(self._get_load_path())) - read_obj = _get_spark().read + read_obj = _get_spark(self.spark_mode).read # Pass schema if defined if self._schema: @@ -431,7 +457,7 @@ def _exists(self) -> bool: load_path = _strip_dbfs_prefix(self._fs_prefix + str(self._get_load_path())) try: - _get_spark().read.load(load_path, self._file_format) + _get_spark(self.spark_mode).read.load(load_path, self._file_format) except AnalysisException as exception: # `AnalysisException.desc` is deprecated with pyspark >= 3.4 message = exception.desc if hasattr(exception, "desc") else str(exception)