feat(datasets): Improved Dependency Management for Spark-based Datasets (#911)

* added the skeleton for the utils sub pkg

Signed-off-by: Minura Punchihewa <[email protected]>

* moved the utility funcs from spark_dataset to relevant modules in _utils

Signed-off-by: Minura Punchihewa <[email protected]>

* updated the use of utility funcs in spark_dataset

Signed-off-by: Minura Punchihewa <[email protected]>

* fixed import in databricks_utils

Signed-off-by: Minura Punchihewa <[email protected]>

* renamed _strip_dbfs_prefix to strip_dbfs_prefix

Signed-off-by: Minura Punchihewa <[email protected]>

* updated the other modules that import from spark_dataset to use _utils

Signed-off-by: Minura Punchihewa <[email protected]>

* updated the use of strip_dbfs_prefix in spark_dataset

Signed-off-by: Minura Punchihewa <[email protected]>

* fixed lint issues

Signed-off-by: Minura Punchihewa <[email protected]>

* removed the base deps for spark, pandas and delta from databricks datasets

Signed-off-by: Minura Punchihewa <[email protected]>

* moved the file based utility funcs to databricks_utils

Signed-off-by: Minura Punchihewa <[email protected]>

* fixed the imports of the file based utility funcs

Signed-off-by: Minura Punchihewa <[email protected]>

* fixed lint issues

Signed-off-by: Minura Punchihewa <[email protected]>

* fixed the use of _get_spark() in tests

Signed-off-by: Minura Punchihewa <[email protected]>

* fixed uses of databricks utils in tests

Signed-off-by: Minura Punchihewa <[email protected]>

* fixed more tests

Signed-off-by: Minura Punchihewa <[email protected]>

* fixed more lint issues

Signed-off-by: Minura Punchihewa <[email protected]>

* fixed more tests

Signed-off-by: Minura Punchihewa <[email protected]>

* fixed more tests

Signed-off-by: Minura Punchihewa <[email protected]>

* improved type hints for spark & databricks utility funcs

Signed-off-by: Minura Punchihewa <[email protected]>

* fixed more lint issues

Signed-off-by: Minura Punchihewa <[email protected]>

* further improved type hints for utility funcs

Signed-off-by: Minura Punchihewa <[email protected]>

* fixed a couple of incorrect type hints

Signed-off-by: Minura Punchihewa <[email protected]>

* fixed several incorrect type hints

Signed-off-by: Minura Punchihewa <[email protected]>

* updated the release notes

Signed-off-by: Minura Punchihewa <[email protected]>

* Reorder release notes

Signed-off-by: Merel Theisen <[email protected]>

---------

Signed-off-by: Minura Punchihewa <[email protected]>
Signed-off-by: Merel Theisen <[email protected]>
Co-authored-by: Nok Lam Chan <[email protected]>
Co-authored-by: Merel Theisen <[email protected]>
Co-authored-by: Merel Theisen <[email protected]>
4 people authored Nov 14, 2024
1 parent 57f6279 commit 07aef5a
Showing 15 changed files with 220 additions and 200 deletions.
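
At a glance, the refactor moves the shared Spark/Databricks helpers out of `kedro_datasets.spark.spark_dataset` into a new `kedro_datasets._utils` sub-package, dropping their leading underscores. A minimal before/after sketch of the import change, taken from the diffs below:

```python
# Before: helpers were private functions of the SparkDataset module
from kedro_datasets.spark.spark_dataset import (
    _get_spark,
    _split_filepath,
    _strip_dbfs_prefix,
)

# After: helpers live in the shared _utils sub-package
from kedro_datasets._utils.databricks_utils import split_filepath, strip_dbfs_prefix
from kedro_datasets._utils.spark_utils import get_spark
```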
2 changes: 1 addition & 1 deletion kedro-datasets/RELEASE.md
@@ -1,7 +1,6 @@
# Upcoming Release 6.0.0

## Major features and improvements

- Added functionality to save Pandas DataFrame directly to Snowflake, facilitating seamless `.csv` ingestion
- Added Python 3.9, 3.10 and 3.11 support for SnowflakeTableDataset
- Added the following new **experimental** datasets:
@@ -13,6 +12,7 @@

## Bug fixes and other changes
- Implemented Snowflake's [local testing framework](https://docs.snowflake.com/en/developer-guide/snowpark/python/testing-locally) for testing purposes
- Improved the dependency management for Spark-based datasets by refactoring the Spark and Databricks utility functions used across the datasets.

## Breaking Changes
- Demoted `video.VideoDataset` from core to experimental dataset.
Empty file.
105 changes: 105 additions & 0 deletions kedro-datasets/kedro_datasets/_utils/databricks_utils.py
@@ -0,0 +1,105 @@
import os
from fnmatch import fnmatch
from pathlib import PurePosixPath
from typing import TYPE_CHECKING, Union

from pyspark.sql import SparkSession

if TYPE_CHECKING:
from databricks.connect import DatabricksSession
from pyspark.dbutils import DBUtils


def parse_glob_pattern(pattern: str) -> str:
special = ("*", "?", "[")
clean = []
for part in pattern.split("/"):
if any(char in part for char in special):
break
clean.append(part)
return "/".join(clean)


def split_filepath(filepath: str | os.PathLike) -> tuple[str, str]:
split_ = str(filepath).split("://", 1)
if len(split_) == 2: # noqa: PLR2004
return split_[0] + "://", split_[1]
return "", split_[0]


def strip_dbfs_prefix(path: str, prefix: str = "/dbfs") -> str:
return path[len(prefix) :] if path.startswith(prefix) else path


def dbfs_glob(pattern: str, dbutils: "DBUtils") -> list[str]:
"""Perform a custom glob search in DBFS using the provided pattern.
It is assumed that version paths are managed by Kedro only.
Args:
pattern: Glob pattern to search for.
dbutils: dbutils instance to operate with DBFS.
Returns:
List of DBFS paths prefixed with '/dbfs' that satisfy the glob pattern.
"""
pattern = strip_dbfs_prefix(pattern)
prefix = parse_glob_pattern(pattern)
matched = set()
filename = pattern.split("/")[-1]

for file_info in dbutils.fs.ls(prefix):
if file_info.isDir():
path = str(
PurePosixPath(strip_dbfs_prefix(file_info.path, "dbfs:")) / filename
)
if fnmatch(path, pattern):
path = "/dbfs" + path
matched.add(path)
return sorted(matched)


def get_dbutils(spark: Union[SparkSession, "DatabricksSession"]) -> "DBUtils":
"""Get the instance of 'dbutils' or None if the one could not be found."""
dbutils = globals().get("dbutils")
if dbutils:
return dbutils

try:
from pyspark.dbutils import DBUtils

dbutils = DBUtils(spark)
except ImportError:
try:
import IPython
except ImportError:
pass
else:
ipython = IPython.get_ipython()
dbutils = ipython.user_ns.get("dbutils") if ipython else None

return dbutils


def dbfs_exists(pattern: str, dbutils: "DBUtils") -> bool:
"""Perform an `ls` list operation in DBFS using the provided pattern.
It is assumed that version paths are managed by Kedro.
Broad `Exception` is present due to `dbutils.fs.ExecutionError` that
cannot be imported directly.
Args:
pattern: Filepath to search for.
dbutils: dbutils instance to operate with DBFS.
Returns:
Boolean value if filepath exists.
"""
pattern = strip_dbfs_prefix(pattern)
file = parse_glob_pattern(pattern)
try:
dbutils.fs.ls(file)
return True
except Exception:
return False


def deployed_on_databricks() -> bool:
"""Check if running on Databricks."""
return "DATABRICKS_RUNTIME_VERSION" in os.environ
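
A brief usage sketch for the helpers above, assuming a Databricks runtime where `dbutils` is reachable; the DBFS path is purely illustrative:

```python
from kedro_datasets._utils.databricks_utils import (
    dbfs_exists,
    dbfs_glob,
    deployed_on_databricks,
    get_dbutils,
)
from kedro_datasets._utils.spark_utils import get_spark

if deployed_on_databricks():
    dbutils = get_dbutils(get_spark())  # may be None outside a Databricks notebook
    if dbutils:
        # Illustrative versioned dataset path; version folders are managed by Kedro
        pattern = "/dbfs/kedro/data/weather.parquet/*/weather.parquet"
        print(dbfs_exists(pattern, dbutils))  # ls on the non-glob prefix of the pattern
        print(dbfs_glob(pattern, dbutils))    # sorted '/dbfs/...' paths matching the pattern
```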
29 changes: 29 additions & 0 deletions kedro-datasets/kedro_datasets/_utils/spark_utils.py
@@ -0,0 +1,29 @@
from typing import TYPE_CHECKING, Union

from pyspark.sql import SparkSession

if TYPE_CHECKING:
from databricks.connect import DatabricksSession


def get_spark() -> Union[SparkSession, "DatabricksSession"]:
"""
Returns the SparkSession. In case databricks-connect is available we use it for
extended configuration mechanisms and notebook compatibility,
otherwise we use classic pyspark.
"""
try:
# When using databricks-connect >= 13.0.0 (a.k.a databricks-connect-v2)
# the remote session is instantiated using the databricks module
# If the databricks-connect module is installed, we use a remote session
from databricks.connect import DatabricksSession

# We can't test this as there's no Databricks test env available
spark = DatabricksSession.builder.getOrCreate() # pragma: no cover

except ImportError:
# For "normal" spark sessions that don't use databricks-connect
# we get spark normally
spark = SparkSession.builder.getOrCreate()

return spark
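
A quick usage sketch: with `databricks-connect` installed this yields a remote `DatabricksSession`, otherwise it falls back to a plain `SparkSession`:

```python
from kedro_datasets._utils.spark_utils import get_spark

spark = get_spark()  # DatabricksSession if databricks-connect is importable, else SparkSession
df = spark.createDataFrame([(1, "alpha"), (2, "beta")], ["id", "value"])
df.show()
```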
22 changes: 11 additions & 11 deletions kedro-datasets/kedro_datasets/databricks/_base_table_dataset.py
@@ -19,7 +19,7 @@
from pyspark.sql.types import StructType
from pyspark.sql.utils import AnalysisException, ParseException

from kedro_datasets.spark.spark_dataset import _get_spark
from kedro_datasets._utils.spark_utils import get_spark

logger = logging.getLogger(__name__)
pd.DataFrame.iteritems = pd.DataFrame.items
@@ -183,7 +183,7 @@ def exists(self) -> bool:
"""
if self.catalog:
try:
_get_spark().sql(f"USE CATALOG `{self.catalog}`")
get_spark().sql(f"USE CATALOG `{self.catalog}`")
except (ParseException, AnalysisException) as exc:
logger.warning(
"catalog %s not found or unity not enabled. Error message: %s",
@@ -192,7 +192,7 @@ )
)
try:
return (
_get_spark()
get_spark()
.sql(f"SHOW TABLES IN `{self.database}`")
.filter(f"tableName = '{self.table}'")
.count()
@@ -359,15 +359,15 @@ def _load(self) -> DataFrame | pd.DataFrame:
if self._version and self._version.load >= 0:
try:
data = (
_get_spark()
get_spark()
.read.format("delta")
.option("versionAsOf", self._version.load)
.table(self._table.full_table_location())
)
except Exception as exc:
raise VersionNotFoundError(self._version.load) from exc
else:
data = _get_spark().table(self._table.full_table_location())
data = get_spark().table(self._table.full_table_location())
if self._table.dataframe_type == "pandas":
data = data.toPandas()
return data
@@ -391,13 +391,13 @@
if schema:
cols = schema.fieldNames()
if self._table.dataframe_type == "pandas":
data = _get_spark().createDataFrame(
data = get_spark().createDataFrame(
data.loc[:, cols], schema=self._table.schema()
)
else:
data = data.select(*cols)
elif self._table.dataframe_type == "pandas":
data = _get_spark().createDataFrame(data)
data = get_spark().createDataFrame(data)

method = getattr(self, f"_save_{self._table.write_mode}", None)
if method:
@@ -456,7 +456,7 @@ def _save_upsert(self, update_data: DataFrame) -> None:
update_data (DataFrame): The Spark dataframe to upsert.
"""
if self._exists():
base_data = _get_spark().table(self._table.full_table_location())
base_data = get_spark().table(self._table.full_table_location())
base_columns = base_data.columns
update_columns = update_data.columns

@@ -479,11 +479,11 @@ )
)

update_data.createOrReplaceTempView("update")
_get_spark().conf.set("fullTableAddress", self._table.full_table_location())
_get_spark().conf.set("whereExpr", where_expr)
get_spark().conf.set("fullTableAddress", self._table.full_table_location())
get_spark().conf.set("whereExpr", where_expr)
upsert_sql = """MERGE INTO ${fullTableAddress} base USING update ON ${whereExpr}
WHEN MATCHED THEN UPDATE SET * WHEN NOT MATCHED THEN INSERT *"""
_get_spark().sql(upsert_sql)
get_spark().sql(upsert_sql)
else:
self._save_append(update_data)

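The upsert branch above relies on Spark SQL variable substitution: the `${...}` placeholders in the `MERGE` statement are resolved from values set on the session conf. A stripped-down sketch of that pattern, assuming a pre-existing Delta table (the table name is hypothetical):

```python
from kedro_datasets._utils.spark_utils import get_spark

spark = get_spark()
update_df = spark.createDataFrame([(1, "new")], ["id", "value"])
update_df.createOrReplaceTempView("update")

spark.conf.set("fullTableAddress", "`my_db`.`my_table`")  # hypothetical target table
spark.conf.set("whereExpr", "base.id = update.id")
spark.sql(
    """MERGE INTO ${fullTableAddress} base USING update ON ${whereExpr}
    WHEN MATCHED THEN UPDATE SET * WHEN NOT MATCHED THEN INSERT *"""
)
```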
15 changes: 6 additions & 9 deletions kedro-datasets/kedro_datasets/spark/deltatable_dataset.py
@@ -10,11 +10,8 @@
from kedro.io.core import AbstractDataset, DatasetError
from pyspark.sql.utils import AnalysisException

from kedro_datasets.spark.spark_dataset import (
_get_spark,
_split_filepath,
_strip_dbfs_prefix,
)
from kedro_datasets._utils.databricks_utils import split_filepath, strip_dbfs_prefix
from kedro_datasets._utils.spark_utils import get_spark


class DeltaTableDataset(AbstractDataset[None, DeltaTable]):
@@ -81,24 +78,24 @@ def __init__(
metadata: Any arbitrary metadata.
This is ignored by Kedro, but may be consumed by users or external plugins.
"""
fs_prefix, filepath = _split_filepath(filepath)
fs_prefix, filepath = split_filepath(filepath)

self._fs_prefix = fs_prefix
self._filepath = PurePosixPath(filepath)
self.metadata = metadata

def load(self) -> DeltaTable:
load_path = self._fs_prefix + str(self._filepath)
return DeltaTable.forPath(_get_spark(), load_path)
return DeltaTable.forPath(get_spark(), load_path)

def save(self, data: None) -> NoReturn:
raise DatasetError(f"{self.__class__.__name__} is a read only dataset type")

def _exists(self) -> bool:
load_path = _strip_dbfs_prefix(self._fs_prefix + str(self._filepath))
load_path = strip_dbfs_prefix(self._fs_prefix + str(self._filepath))

try:
_get_spark().read.load(path=load_path, format="delta")
get_spark().read.load(path=load_path, format="delta")
except AnalysisException as exception:
# `AnalysisException.desc` is deprecated with pyspark >= 3.4
message = exception.desc if hasattr(exception, "desc") else str(exception)
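From the caller's side nothing changes; a small illustrative example (the path is hypothetical) of a dataset that now obtains its session via `get_spark()` under the hood:

```python
from kedro_datasets.spark import DeltaTableDataset

# split_filepath()/strip_dbfs_prefix() from kedro_datasets._utils.databricks_utils
# take care of any filesystem or '/dbfs' prefix on the path.
dataset = DeltaTableDataset(filepath="/dbfs/kedro/data/orders_delta")
delta_table = dataset.load()  # delta.tables.DeltaTable via DeltaTable.forPath(get_spark(), ...)
delta_table.toDF().show()
```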