feat(datasets): Improved Dependency Management for Spark-based Datasets #911

Merged
Changes from all commits
30 commits
e4a9e47
added the skeleton for the utils sub pkg
MinuraPunchihewa Oct 26, 2024
1ab51ad
moved the utility funcs from spark_dataset to relevant modules in _utils
MinuraPunchihewa Oct 26, 2024
0187cf1
updated the use of utility funcs in spark_dataset
MinuraPunchihewa Oct 26, 2024
84675bf
fixed import in databricks_utils
MinuraPunchihewa Oct 26, 2024
ae3829d
renamed _strip_dbfs_prefix to strip_dbfs_prefix
MinuraPunchihewa Oct 26, 2024
c5bb7c5
updated the other modules that import from spark_dataset to use _utils
MinuraPunchihewa Oct 26, 2024
48df490
updated the use of strip_dbfs_prefix in spark_dataset
MinuraPunchihewa Oct 26, 2024
f562190
fixed lint issues
MinuraPunchihewa Oct 26, 2024
0a89638
removed the base deps for spark, pandas and delta from databricks dat…
MinuraPunchihewa Oct 27, 2024
72e1925
Merge branch 'main' into feature/improve_spark_dependencies
MinuraPunchihewa Oct 30, 2024
406adfa
Merge branch 'main' into feature/improve_spark_dependencies
noklam Oct 31, 2024
8f7842f
moved the file based utility funcs to databricks_utils
MinuraPunchihewa Oct 31, 2024
6e50319
fixed the imports of the file based utility funcs
MinuraPunchihewa Oct 31, 2024
63834a9
fixed lint issues
MinuraPunchihewa Oct 31, 2024
6df7741
fixed the use of _get_spark() in tests
MinuraPunchihewa Oct 31, 2024
80f9fb5
fixed uses of databricks utils in tests
MinuraPunchihewa Oct 31, 2024
03da501
fixed more tests
MinuraPunchihewa Oct 31, 2024
0534907
fixed more lint issues
MinuraPunchihewa Oct 31, 2024
65d1c30
fixed more tests
MinuraPunchihewa Oct 31, 2024
f7f0b5e
fixed more tests
MinuraPunchihewa Oct 31, 2024
608d2a2
improved type hints for spark & databricks utility funcs
MinuraPunchihewa Oct 31, 2024
e28cdf1
fixed more lint issues
MinuraPunchihewa Oct 31, 2024
781a7ef
further improved type hints for utility funcs
MinuraPunchihewa Oct 31, 2024
768e8ae
fixed a couple of incorrect type hints
MinuraPunchihewa Oct 31, 2024
c44e32a
fixed several incorrect type hints
MinuraPunchihewa Oct 31, 2024
4b31967
Merge branch 'main' into feature/improve_spark_dependencies
noklam Nov 7, 2024
4c427cc
Merge branch 'main' into feature/improve_spark_dependencies
merelcht Nov 13, 2024
7334e1b
Merge branch 'main' into feature/improve_spark_dependencies
merelcht Nov 13, 2024
d7a385d
updated the release notes
MinuraPunchihewa Nov 13, 2024
0f65bdb
Reorder release notes
merelcht Nov 14, 2024
2 changes: 1 addition & 1 deletion kedro-datasets/RELEASE.md
@@ -1,7 +1,6 @@
# Upcoming Release 6.0.0

## Major features and improvements

- Added functionality to save Pandas DataFrame directly to Snowflake, facilitating seamless `.csv` ingestion
- Added Python 3.9, 3.10 and 3.11 support for SnowflakeTableDataset
- Added the following new **experimental** datasets:
@@ -13,6 +12,7 @@

## Bug fixes and other changes
- Implemented Snowflake's [local testing framework](https://docs.snowflake.com/en/developer-guide/snowpark/python/testing-locally) for testing purposes
- Improved the dependency management for Spark-based datasets by refactoring the Spark and Databricks utility functions used across the datasets.

## Breaking Changes
- Demoted `video.VideoDataset` from core to experimental dataset.
Empty file.
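As orientation for the release note above, an illustrative sketch (not part of the diff) of where the shared helpers now live after this refactor; the `_utils` package is internal, so these import paths may change:

# Previously, other datasets imported the private helpers from spark_dataset, e.g.:
#   from kedro_datasets.spark.spark_dataset import _get_spark, _split_filepath, _strip_dbfs_prefix
# After this PR they are imported from the shared _utils package:
from kedro_datasets._utils.databricks_utils import split_filepath, strip_dbfs_prefix
from kedro_datasets._utils.spark_utils import get_spark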
105 changes: 105 additions & 0 deletions kedro-datasets/kedro_datasets/_utils/databricks_utils.py
@@ -0,0 +1,105 @@
from __future__ import annotations

import os
from fnmatch import fnmatch
from pathlib import PurePosixPath
from typing import TYPE_CHECKING, Union

from pyspark.sql import SparkSession

if TYPE_CHECKING:
from databricks.connect import DatabricksSession
from pyspark.dbutils import DBUtils


def parse_glob_pattern(pattern: str) -> str:
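    """Return the leading part of ``pattern`` up to, but not including, the
    first path segment that contains a glob character (``*``, ``?`` or ``[``)."""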
special = ("*", "?", "[")
clean = []
for part in pattern.split("/"):
if any(char in part for char in special):
break
clean.append(part)
return "/".join(clean)


def split_filepath(filepath: str | os.PathLike) -> tuple[str, str]:
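    """Split a filepath into a ``(protocol, path)`` pair, e.g.
    ``"s3a://bucket/key"`` -> ``("s3a://", "bucket/key")``; the protocol is an
    empty string when no ``://`` separator is present."""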
split_ = str(filepath).split("://", 1)
if len(split_) == 2: # noqa: PLR2004
return split_[0] + "://", split_[1]
return "", split_[0]


def strip_dbfs_prefix(path: str, prefix: str = "/dbfs") -> str:
return path[len(prefix) :] if path.startswith(prefix) else path


def dbfs_glob(pattern: str, dbutils: "DBUtils") -> list[str]:
"""Perform a custom glob search in DBFS using the provided pattern.
It is assumed that version paths are managed by Kedro only.

Args:
pattern: Glob pattern to search for.
dbutils: dbutils instance to operate with DBFS.

Returns:
List of DBFS paths prefixed with '/dbfs' that satisfy the glob pattern.
"""
pattern = strip_dbfs_prefix(pattern)
prefix = parse_glob_pattern(pattern)
matched = set()
filename = pattern.split("/")[-1]

for file_info in dbutils.fs.ls(prefix):
if file_info.isDir():
path = str(
PurePosixPath(strip_dbfs_prefix(file_info.path, "dbfs:")) / filename
)
if fnmatch(path, pattern):
path = "/dbfs" + path
matched.add(path)
return sorted(matched)


def get_dbutils(spark: Union[SparkSession, "DatabricksSession"]) -> "DBUtils":
"""Get the instance of 'dbutils' or None if the one could not be found."""
dbutils = globals().get("dbutils")
if dbutils:
return dbutils

try:
from pyspark.dbutils import DBUtils

dbutils = DBUtils(spark)
except ImportError:
try:
import IPython
except ImportError:
pass
else:
ipython = IPython.get_ipython()
dbutils = ipython.user_ns.get("dbutils") if ipython else None

return dbutils


def dbfs_exists(pattern: str, dbutils: "DBUtils") -> bool:
"""Perform an `ls` list operation in DBFS using the provided pattern.
    It is assumed that version paths are managed by Kedro.
    A broad `Exception` is caught because `dbutils.fs.ExecutionError`
    cannot be imported directly.

    Args:
        pattern: Filepath to search for.
        dbutils: dbutils instance to operate with DBFS.

    Returns:
        True if the filepath exists, False otherwise.
"""
pattern = strip_dbfs_prefix(pattern)
file = parse_glob_pattern(pattern)
try:
dbutils.fs.ls(file)
return True
except Exception:
return False


def deployed_on_databricks() -> bool:
"""Check if running on Databricks."""
return "DATABRICKS_RUNTIME_VERSION" in os.environ
29 changes: 29 additions & 0 deletions kedro-datasets/kedro_datasets/_utils/spark_utils.py
@@ -0,0 +1,29 @@
from typing import TYPE_CHECKING, Union

from pyspark.sql import SparkSession

if TYPE_CHECKING:
from databricks.connect import DatabricksSession


def get_spark() -> Union[SparkSession, "DatabricksSession"]:
"""
    Returns the SparkSession. If databricks-connect is available, we use it for
    extended configuration mechanisms and notebook compatibility;
    otherwise we fall back to classic PySpark.
"""
try:
# When using databricks-connect >= 13.0.0 (a.k.a databricks-connect-v2)
# the remote session is instantiated using the databricks module
# If the databricks-connect module is installed, we use a remote session
from databricks.connect import DatabricksSession

# We can't test this as there's no Databricks test env available
spark = DatabricksSession.builder.getOrCreate() # pragma: no cover

except ImportError:
# For "normal" spark sessions that don't use databricks-connect
# we get spark normally
spark = SparkSession.builder.getOrCreate()

return spark
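A minimal usage sketch (not part of the diff): callers of get_spark() do not need to know whether they received a DatabricksSession or a plain SparkSession:

from kedro_datasets._utils.spark_utils import get_spark

spark = get_spark()
df = spark.createDataFrame([(1, "a"), (2, "b")], ["id", "value"])
df.show()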
22 changes: 11 additions & 11 deletions kedro-datasets/kedro_datasets/databricks/_base_table_dataset.py
@@ -19,7 +19,7 @@
from pyspark.sql.types import StructType
from pyspark.sql.utils import AnalysisException, ParseException

from kedro_datasets.spark.spark_dataset import _get_spark
from kedro_datasets._utils.spark_utils import get_spark

logger = logging.getLogger(__name__)
pd.DataFrame.iteritems = pd.DataFrame.items
@@ -183,7 +183,7 @@ def exists(self) -> bool:
"""
if self.catalog:
try:
_get_spark().sql(f"USE CATALOG `{self.catalog}`")
get_spark().sql(f"USE CATALOG `{self.catalog}`")
except (ParseException, AnalysisException) as exc:
logger.warning(
"catalog %s not found or unity not enabled. Error message: %s",
@@ -192,7 +192,7 @@
)
try:
return (
_get_spark()
get_spark()
.sql(f"SHOW TABLES IN `{self.database}`")
.filter(f"tableName = '{self.table}'")
.count()
@@ -359,15 +359,15 @@ def _load(self) -> DataFrame | pd.DataFrame:
if self._version and self._version.load >= 0:
try:
data = (
_get_spark()
get_spark()
.read.format("delta")
.option("versionAsOf", self._version.load)
.table(self._table.full_table_location())
)
except Exception as exc:
raise VersionNotFoundError(self._version.load) from exc
else:
data = _get_spark().table(self._table.full_table_location())
data = get_spark().table(self._table.full_table_location())
if self._table.dataframe_type == "pandas":
data = data.toPandas()
return data
@@ -391,13 +391,13 @@ def _save(self, data: DataFrame | pd.DataFrame) -> None:
if schema:
cols = schema.fieldNames()
if self._table.dataframe_type == "pandas":
data = _get_spark().createDataFrame(
data = get_spark().createDataFrame(
data.loc[:, cols], schema=self._table.schema()
)
else:
data = data.select(*cols)
elif self._table.dataframe_type == "pandas":
data = _get_spark().createDataFrame(data)
data = get_spark().createDataFrame(data)

method = getattr(self, f"_save_{self._table.write_mode}", None)
if method:
@@ -456,7 +456,7 @@ def _save_upsert(self, update_data: DataFrame) -> None:
update_data (DataFrame): The Spark dataframe to upsert.
"""
if self._exists():
base_data = _get_spark().table(self._table.full_table_location())
base_data = get_spark().table(self._table.full_table_location())
base_columns = base_data.columns
update_columns = update_data.columns

@@ -479,11 +479,11 @@ def _save_upsert(self, update_data: DataFrame) -> None:
)

update_data.createOrReplaceTempView("update")
_get_spark().conf.set("fullTableAddress", self._table.full_table_location())
_get_spark().conf.set("whereExpr", where_expr)
get_spark().conf.set("fullTableAddress", self._table.full_table_location())
get_spark().conf.set("whereExpr", where_expr)
upsert_sql = """MERGE INTO ${fullTableAddress} base USING update ON ${whereExpr}
WHEN MATCHED THEN UPDATE SET * WHEN NOT MATCHED THEN INSERT *"""
_get_spark().sql(upsert_sql)
get_spark().sql(upsert_sql)
else:
self._save_append(update_data)

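For context on the _load hunk above, a sketch (not part of the diff) of the versioned Delta read it performs; the table location and version number are hypothetical:

from kedro_datasets._utils.spark_utils import get_spark

data = (
    get_spark()
    .read.format("delta")
    .option("versionAsOf", 2)  # hypothetical load version
    .table("`my_catalog`.`my_db`.`my_table`")  # hypothetical full table location
)
pandas_df = data.toPandas()  # only when dataframe_type == "pandas"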
15 changes: 6 additions & 9 deletions kedro-datasets/kedro_datasets/spark/deltatable_dataset.py
@@ -10,11 +10,8 @@
from kedro.io.core import AbstractDataset, DatasetError
from pyspark.sql.utils import AnalysisException

from kedro_datasets.spark.spark_dataset import (
_get_spark,
_split_filepath,
_strip_dbfs_prefix,
)
from kedro_datasets._utils.databricks_utils import split_filepath, strip_dbfs_prefix
from kedro_datasets._utils.spark_utils import get_spark


class DeltaTableDataset(AbstractDataset[None, DeltaTable]):
@@ -81,24 +78,24 @@ def __init__(
metadata: Any arbitrary metadata.
This is ignored by Kedro, but may be consumed by users or external plugins.
"""
fs_prefix, filepath = _split_filepath(filepath)
fs_prefix, filepath = split_filepath(filepath)

self._fs_prefix = fs_prefix
self._filepath = PurePosixPath(filepath)
self.metadata = metadata

def load(self) -> DeltaTable:
load_path = self._fs_prefix + str(self._filepath)
return DeltaTable.forPath(_get_spark(), load_path)
return DeltaTable.forPath(get_spark(), load_path)

def save(self, data: None) -> NoReturn:
raise DatasetError(f"{self.__class__.__name__} is a read only dataset type")

def _exists(self) -> bool:
load_path = _strip_dbfs_prefix(self._fs_prefix + str(self._filepath))
load_path = strip_dbfs_prefix(self._fs_prefix + str(self._filepath))

try:
_get_spark().read.load(path=load_path, format="delta")
get_spark().read.load(path=load_path, format="delta")
except AnalysisException as exception:
# `AnalysisException.desc` is deprecated with pyspark >= 3.4
message = exception.desc if hasattr(exception, "desc") else str(exception)
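Finally, a sketch (not part of the diff) of the path handling DeltaTableDataset performs with the relocated helpers; the DBFS filepath is hypothetical:

from delta.tables import DeltaTable

from kedro_datasets._utils.databricks_utils import split_filepath, strip_dbfs_prefix
from kedro_datasets._utils.spark_utils import get_spark

fs_prefix, filepath = split_filepath("/dbfs/data/02_intermediate/my_table")  # -> ("", "/dbfs/data/...")
delta_table = DeltaTable.forPath(get_spark(), fs_prefix + filepath)          # as in load()
exists_path = strip_dbfs_prefix(fs_prefix + filepath)                        # as in _exists(); drops "/dbfs"
get_spark().read.load(path=exists_path, format="delta")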