Added spark_mode parameter to allow switching between spark and databricks connect #862

Draft · wants to merge 1 commit into base: main
2 changes: 2 additions & 0 deletions kedro-datasets/RELEASE.md
@@ -14,6 +14,8 @@
|----------------------|------------------------------------------------|-------------------------|
| `plotly.HTMLDataset` | A dataset for saving a `plotly` figure as HTML | `kedro_datasets.plotly` |

* Added `spark_mode` parameter to Spark-based datasets to control whether `databricks-connect` or classic `pyspark` is used for the Spark session

## Bug fixes and other changes
* Refactored all datasets to set `fs_args` defaults in the same way as `load_args` and `save_args` and not have hardcoded values in the save methods.
* Fixed bug related to loading/saving models from/to remote storage using `TensorFlowModelDataset`.
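As a quick illustration of the `spark_mode` parameter described in the release note above, here is a minimal sketch of constructing a Spark-based dataset directly in Python. It assumes the existing `SparkDataset` constructor arguments (`filepath`, `file_format`); the path and values are illustrative only.

```python
from kedro_datasets.spark import SparkDataset

# Force a classic pyspark session even when databricks-connect is installed.
# The filepath and file_format values are illustrative only.
dataset = SparkDataset(
    filepath="s3a://my-bucket/trips.parquet",
    file_format="parquet",
    spark_mode="spark",
)
df = dataset.load()  # returns a pyspark.sql.DataFrame
```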
28 changes: 18 additions & 10 deletions kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py
@@ -226,6 +226,7 @@ def __init__( # noqa: PLR0913
dataframe_type: str = "spark",
primary_key: str | list[str] | None = None,
version: Version | None = None,
spark_mode: str | None = None,
# the following parameters are used by project hooks
# to create or update table properties
schema: dict[str, Any] | None = None,
@@ -252,6 +253,8 @@
Can be in the form of a list. Defaults to None.
version: kedro.io.core.Version instance to load the data.
Defaults to None.
spark_mode: The mode to initialize the Spark session. Can be 'spark'
or 'databricks-connect'. Defaults to None, in which case
'databricks-connect' is used if it is installed, otherwise 'spark'.
schema: the schema of the table in JSON form.
Dataframes will be truncated to match the schema if provided.
Used by the hooks to create the table if the schema is provided
@@ -280,6 +283,7 @@ def __init__( # noqa: PLR0913
)

self._version = version
self.spark_mode = spark_mode
self.metadata = metadata

super().__init__(
@@ -303,15 +307,15 @@ def _load(self) -> DataFrame | pd.DataFrame:
if self._version and self._version.load >= 0:
try:
data = (
_get_spark()
_get_spark(self.spark_mode)
.read.format("delta")
.option("versionAsOf", self._version.load)
.table(self._table.full_table_location())
)
except Exception as exc:
raise VersionNotFoundError(self._version.load) from exc
else:
data = _get_spark().table(self._table.full_table_location())
data = _get_spark(self.spark_mode).table(self._table.full_table_location())
if self._table.dataframe_type == "pandas":
data = data.toPandas()
return data
@@ -349,7 +353,9 @@ def _save_upsert(self, update_data: DataFrame) -> None:
update_data (DataFrame): the Spark dataframe to upsert
"""
if self._exists():
base_data = _get_spark().table(self._table.full_table_location())
base_data = _get_spark(self.spark_mode).table(
self._table.full_table_location()
)
base_columns = base_data.columns
update_columns = update_data.columns

@@ -372,11 +378,13 @@ def _save_upsert(self, update_data: DataFrame) -> None:
)

update_data.createOrReplaceTempView("update")
_get_spark().conf.set("fullTableAddress", self._table.full_table_location())
_get_spark().conf.set("whereExpr", where_expr)
_get_spark(self.spark_mode).conf.set(
"fullTableAddress", self._table.full_table_location()
)
_get_spark(self.spark_mode).conf.set("whereExpr", where_expr)
upsert_sql = """MERGE INTO ${fullTableAddress} base USING update ON ${whereExpr}
WHEN MATCHED THEN UPDATE SET * WHEN NOT MATCHED THEN INSERT *"""
_get_spark().sql(upsert_sql)
_get_spark(self.spark_mode).sql(upsert_sql)
else:
self._save_append(update_data)

@@ -399,13 +407,13 @@ def _save(self, data: DataFrame | pd.DataFrame) -> None:
if schema:
cols = schema.fieldNames()
if self._table.dataframe_type == "pandas":
data = _get_spark().createDataFrame(
data = _get_spark(self.spark_mode).createDataFrame(
data.loc[:, cols], schema=self._table.schema()
)
else:
data = data.select(*cols)
elif self._table.dataframe_type == "pandas":
data = _get_spark().createDataFrame(data)
data = _get_spark(self.spark_mode).createDataFrame(data)
if self._table.write_mode == "overwrite":
self._save_overwrite(data)
elif self._table.write_mode == "upsert":
@@ -440,7 +448,7 @@ def _exists(self) -> bool:
"""
if self._table.catalog:
try:
_get_spark().sql(f"USE CATALOG `{self._table.catalog}`")
_get_spark(self.spark_mode).sql(f"USE CATALOG `{self._table.catalog}`")
except (ParseException, AnalysisException) as exc:
logger.warning(
"catalog %s not found or unity not enabled. Error message: %s",
@@ -449,7 +457,7 @@
)
try:
return (
_get_spark()
_get_spark(self.spark_mode)
.sql(f"SHOW TABLES IN `{self._table.database}`")
.filter(f"tableName = '{self._table.table}'")
.count()
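For the Databricks dataset above, a minimal sketch of how `spark_mode` might be passed at construction time. The `table`, `database`, `write_mode`, and `primary_key` arguments are assumed from the existing `ManagedTableDataset` API and are not all shown in this diff excerpt; the names are illustrative.

```python
from kedro_datasets.databricks import ManagedTableDataset

# Explicitly request a databricks-connect session for a managed Delta table.
# Table and database names are illustrative only.
dataset = ManagedTableDataset(
    table="trips",
    database="analytics",
    write_mode="upsert",
    primary_key="trip_id",
    spark_mode="databricks-connect",
)
```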
60 changes: 43 additions & 17 deletions kedro-datasets/kedro_datasets/spark/spark_dataset.py
@@ -32,26 +32,46 @@
logger = logging.getLogger(__name__)


def _get_spark() -> Any:
def _get_spark(spark_mode: str | None = None) -> Any:
"""
Returns the SparkSession. In case databricks-connect is available we use it for
extended configuration mechanisms and notebook compatibility,
otherwise we use classic pyspark.
Returns the appropriate Spark session based on the spark_mode argument.
Supports 'spark' and 'databricks-connect'.
If spark_mode is None, 'databricks-connect' is used if it is installed, otherwise 'spark'.

To use Spark Connect, set the SPARK_REMOTE environment variable and use the 'spark' mode.

To configure authentication, use the corresponding environment variables.

Args:
spark_mode: The mode to initialize the Spark session. Can be 'spark'
or 'databricks-connect'. Defaults to None.

Returns:
A Spark session appropriate to the selected mode.
"""
try:
# When using databricks-connect >= 13.0.0 (a.k.a databricks-connect-v2)
# the remote session is instantiated using the databricks module
# If the databricks-connect module is installed, we use a remote session
from databricks.connect import DatabricksSession

# We can't test this as there's no Databricks test env available
spark = DatabricksSession.builder.getOrCreate() # pragma: no cover
# If the spark_mode is not specified, we will try to infer it
if not spark_mode:
# Try to use databricks-connect if available
try:
from databricks.connect import DatabricksSession

spark_mode = "databricks-connect"
except ImportError:
spark_mode = "spark"

except ImportError:
# For "normal" spark sessions that don't use databricks-connect
# we get spark normally
if spark_mode == "spark":
spark = SparkSession.builder.getOrCreate()

elif spark_mode == "databricks-connect":
from databricks.connect import DatabricksSession

spark = DatabricksSession.builder.getOrCreate()

else:
raise ValueError(f"Invalid spark_mode: {spark_mode}")

return spark
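A short sketch of how the mode resolution above might behave from calling code. This assumes a reachable Databricks or Spark Connect endpoint and the relevant packages installed; the endpoint value is purely illustrative.

```python
import os

# With an explicit mode, the corresponding session builder is used directly.
remote = _get_spark(spark_mode="databricks-connect")  # requires databricks-connect

# With 'spark' mode, pyspark >= 3.4 also honours SPARK_REMOTE (Spark Connect).
os.environ["SPARK_REMOTE"] = "sc://spark-connect-host:15002"  # illustrative endpoint
connect = _get_spark(spark_mode="spark")

# With no mode given, databricks-connect is preferred if importable, otherwise pyspark.
default = _get_spark()
```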


@@ -272,6 +292,7 @@ def __init__( # noqa: PLR0913
save_args: dict[str, Any] | None = None,
version: Version | None = None,
credentials: dict[str, Any] | None = None,
spark_mode: str | None = None,
metadata: dict[str, Any] | None = None,
) -> None:
"""Creates a new instance of ``SparkDataset``.
@@ -304,14 +325,19 @@ def __init__( # noqa: PLR0913
``key``, ``secret``, if ``filepath`` prefix is ``s3a://`` or ``s3n://``.
Optional keyword arguments passed to ``hdfs.client.InsecureClient``
if ``filepath`` prefix is ``hdfs://``. Ignored otherwise.
spark_mode: The mode to initialize the Spark session. Can be 'spark'
or 'databricks-connect'. Defaults to None.
metadata: Any arbitrary metadata.
This is ignored by Kedro, but may be consumed by users or external plugins.

"""

credentials = deepcopy(credentials) or {}
fs_prefix, filepath = _split_filepath(filepath)
path = PurePosixPath(filepath)
exists_function = None
glob_function = None
self.spark_mode = spark_mode
self.metadata = metadata

if (
@@ -349,7 +375,7 @@ def __init__( # noqa: PLR0913
elif filepath.startswith("/dbfs/"):
# dbfs add prefix to Spark path by default
# See https://github.com/kedro-org/kedro-plugins/issues/117
dbutils = _get_dbutils(_get_spark())
dbutils = _get_dbutils(_get_spark(self.spark_mode))
if dbutils:
glob_function = partial(_dbfs_glob, dbutils=dbutils)
exists_function = partial(_dbfs_exists, dbutils=dbutils)
@@ -415,7 +441,7 @@ def _describe(self) -> dict[str, Any]:

def _load(self) -> DataFrame:
load_path = _strip_dbfs_prefix(self._fs_prefix + str(self._get_load_path()))
read_obj = _get_spark().read
read_obj = _get_spark(self.spark_mode).read

# Pass schema if defined
if self._schema:
@@ -431,7 +457,7 @@ def _exists(self) -> bool:
load_path = _strip_dbfs_prefix(self._fs_prefix + str(self._get_load_path()))

try:
_get_spark().read.load(load_path, self._file_format)
_get_spark(self.spark_mode).read.load(load_path, self._file_format)
except AnalysisException as exception:
# `AnalysisException.desc` is deprecated with pyspark >= 3.4
message = exception.desc if hasattr(exception, "desc") else str(exception)