
Commit 07aef5a

MinuraPunchihewa, noklam, and merelcht authored
feat(datasets): Improved Dependency Management for Spark-based Datasets (#911)
* added the skeleton for the utils sub pkg
* moved the utility funcs from spark_dataset to relevant modules in _utils
* updated the use of utility funcs in spark_dataset
* fixed import in databricks_utils
* renamed _strip_dbfs_prefix to strip_dbfs_prefix
* updated the other modules that import from spark_dataset to use _utils
* updated the use of strip_dbfs_prefix in spark_dataset
* fixed lint issues
* removed the base deps for spark, pandas and delta from databricks datasets
* moved the file based utility funcs to databricks_utils
* fixed the imports of the file based utility funcs
* fixed lint issues
* fixed the use of _get_spark() in tests
* fixed uses of databricks utils in tests
* fixed more tests
* fixed more lint issues
* fixed more tests
* fixed more tests
* improved type hints for spark & databricks utility funcs
* fixed more lint issues
* further improved type hints for utility funcs
* fixed a couple of incorrect type hints
* fixed several incorrect type hints
* updated the release notes
* Reorder release notes

---------

Signed-off-by: Minura Punchihewa <[email protected]>
Signed-off-by: Merel Theisen <[email protected]>
Co-authored-by: Nok Lam Chan <[email protected]>
Co-authored-by: Merel Theisen <[email protected]>
Co-authored-by: Merel Theisen <[email protected]>
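In short, the private Spark and Databricks helpers that previously lived in kedro_datasets.spark.spark_dataset now sit in a shared kedro_datasets._utils sub-package, so Databricks datasets no longer need to import from the Spark dataset module. A minimal sketch of the resulting import change, assembled from the diffs below rather than taken from the commit itself:

# Before this commit: downstream datasets reached into spark_dataset for private helpers.
# from kedro_datasets.spark.spark_dataset import _get_spark, _split_filepath, _strip_dbfs_prefix

# After this commit: the helpers are public functions in the shared _utils sub-package.
from kedro_datasets._utils.databricks_utils import split_filepath, strip_dbfs_prefix
from kedro_datasets._utils.spark_utils import get_spark

spark = get_spark()  # a SparkSession, or a DatabricksSession when databricks-connect is installed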
1 parent 57f6279 commit 07aef5a

15 files changed (+220, -200 lines)

kedro-datasets/RELEASE.md

Lines changed: 1 addition & 1 deletion
@@ -1,7 +1,6 @@
 # Upcoming Release 6.0.0

 ## Major features and improvements
-
 - Added functionality to save Pandas DataFrame directly to Snowflake, facilitating seemless `.csv` ingestion
 - Added Python 3.9, 3.10 and 3.11 support for SnowflakeTableDataset
 - Added the following new **experimental** datasets:
@@ -13,6 +12,7 @@

 ## Bug fixes and other changes
 - Implemented Snowflake's (local testing framework)[https://docs.snowflake.com/en/developer-guide/snowpark/python/testing-locally] for testing purposes
+- Improved the dependency management for Spark-based datasets by refactoring the Spark and Databricks utility functions used across the datasets.

 ## Breaking Changes
 - Demoted `video.VideoDataset` from core to experimental dataset.

kedro-datasets/kedro_datasets/_utils/__init__.py

Whitespace-only changes.
kedro-datasets/kedro_datasets/_utils/databricks_utils.py

Lines changed: 105 additions & 0 deletions
@@ -0,0 +1,105 @@
import os
from fnmatch import fnmatch
from pathlib import PurePosixPath
from typing import TYPE_CHECKING, Union

from pyspark.sql import SparkSession

if TYPE_CHECKING:
    from databricks.connect import DatabricksSession
    from pyspark.dbutils import DBUtils


def parse_glob_pattern(pattern: str) -> str:
    special = ("*", "?", "[")
    clean = []
    for part in pattern.split("/"):
        if any(char in part for char in special):
            break
        clean.append(part)
    return "/".join(clean)


def split_filepath(filepath: str | os.PathLike) -> tuple[str, str]:
    split_ = str(filepath).split("://", 1)
    if len(split_) == 2:  # noqa: PLR2004
        return split_[0] + "://", split_[1]
    return "", split_[0]


def strip_dbfs_prefix(path: str, prefix: str = "/dbfs") -> str:
    return path[len(prefix) :] if path.startswith(prefix) else path


def dbfs_glob(pattern: str, dbutils: "DBUtils") -> list[str]:
    """Perform a custom glob search in DBFS using the provided pattern.
    It is assumed that version paths are managed by Kedro only.

    Args:
        pattern: Glob pattern to search for.
        dbutils: dbutils instance to operate with DBFS.

    Returns:
        List of DBFS paths prefixed with '/dbfs' that satisfy the glob pattern.
    """
    pattern = strip_dbfs_prefix(pattern)
    prefix = parse_glob_pattern(pattern)
    matched = set()
    filename = pattern.split("/")[-1]

    for file_info in dbutils.fs.ls(prefix):
        if file_info.isDir():
            path = str(
                PurePosixPath(strip_dbfs_prefix(file_info.path, "dbfs:")) / filename
            )
            if fnmatch(path, pattern):
                path = "/dbfs" + path
                matched.add(path)
    return sorted(matched)


def get_dbutils(spark: Union[SparkSession, "DatabricksSession"]) -> "DBUtils":
    """Get the instance of 'dbutils' or None if the one could not be found."""
    dbutils = globals().get("dbutils")
    if dbutils:
        return dbutils

    try:
        from pyspark.dbutils import DBUtils

        dbutils = DBUtils(spark)
    except ImportError:
        try:
            import IPython
        except ImportError:
            pass
        else:
            ipython = IPython.get_ipython()
            dbutils = ipython.user_ns.get("dbutils") if ipython else None

    return dbutils


def dbfs_exists(pattern: str, dbutils: "DBUtils") -> bool:
    """Perform an `ls` list operation in DBFS using the provided pattern.
    It is assumed that version paths are managed by Kedro.
    Broad `Exception` is present due to `dbutils.fs.ExecutionError` that
    cannot be imported directly.
    Args:
        pattern: Filepath to search for.
        dbutils: dbutils instance to operate with DBFS.
    Returns:
        Boolean value if filepath exists.
    """
    pattern = strip_dbfs_prefix(pattern)
    file = parse_glob_pattern(pattern)
    try:
        dbutils.fs.ls(file)
        return True
    except Exception:
        return False


def deployed_on_databricks() -> bool:
    """Check if running on Databricks."""
    return "DATABRICKS_RUNTIME_VERSION" in os.environ
kedro-datasets/kedro_datasets/_utils/spark_utils.py

Lines changed: 29 additions & 0 deletions
@@ -0,0 +1,29 @@
from typing import TYPE_CHECKING, Union

from pyspark.sql import SparkSession

if TYPE_CHECKING:
    from databricks.connect import DatabricksSession


def get_spark() -> Union[SparkSession, "DatabricksSession"]:
    """
    Returns the SparkSession. In case databricks-connect is available we use it for
    extended configuration mechanisms and notebook compatibility,
    otherwise we use classic pyspark.
    """
    try:
        # When using databricks-connect >= 13.0.0 (a.k.a databricks-connect-v2)
        # the remote session is instantiated using the databricks module
        # If the databricks-connect module is installed, we use a remote session
        from databricks.connect import DatabricksSession

        # We can't test this as there's no Databricks test env available
        spark = DatabricksSession.builder.getOrCreate()  # pragma: no cover

    except ImportError:
        # For "normal" spark sessions that don't use databricks-connect
        # we get spark normally
        spark = SparkSession.builder.getOrCreate()

    return spark
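A small usage sketch, again not part of the diff: datasets now call get_spark() wherever they previously called the private _get_spark(), and receive whichever session type the environment provides:

from kedro_datasets._utils.spark_utils import get_spark

# Falls back to a local SparkSession when databricks-connect is not installed.
spark = get_spark()
df = spark.createDataFrame([(1, "a"), (2, "b")], ["id", "value"])
df.show()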

kedro-datasets/kedro_datasets/databricks/_base_table_dataset.py

Lines changed: 11 additions & 11 deletions
@@ -19,7 +19,7 @@
 from pyspark.sql.types import StructType
 from pyspark.sql.utils import AnalysisException, ParseException

-from kedro_datasets.spark.spark_dataset import _get_spark
+from kedro_datasets._utils.spark_utils import get_spark

 logger = logging.getLogger(__name__)
 pd.DataFrame.iteritems = pd.DataFrame.items
@@ -183,7 +183,7 @@ def exists(self) -> bool:
         """
         if self.catalog:
             try:
-                _get_spark().sql(f"USE CATALOG `{self.catalog}`")
+                get_spark().sql(f"USE CATALOG `{self.catalog}`")
             except (ParseException, AnalysisException) as exc:
                 logger.warning(
                     "catalog %s not found or unity not enabled. Error message: %s",
@@ -192,7 +192,7 @@ def exists(self) -> bool:
                 )
         try:
             return (
-                _get_spark()
+                get_spark()
                 .sql(f"SHOW TABLES IN `{self.database}`")
                 .filter(f"tableName = '{self.table}'")
                 .count()
@@ -359,15 +359,15 @@ def _load(self) -> DataFrame | pd.DataFrame:
         if self._version and self._version.load >= 0:
             try:
                 data = (
-                    _get_spark()
+                    get_spark()
                     .read.format("delta")
                     .option("versionAsOf", self._version.load)
                     .table(self._table.full_table_location())
                 )
             except Exception as exc:
                 raise VersionNotFoundError(self._version.load) from exc
         else:
-            data = _get_spark().table(self._table.full_table_location())
+            data = get_spark().table(self._table.full_table_location())
         if self._table.dataframe_type == "pandas":
             data = data.toPandas()
         return data
@@ -391,13 +391,13 @@ def _save(self, data: DataFrame | pd.DataFrame) -> None:
         if schema:
             cols = schema.fieldNames()
             if self._table.dataframe_type == "pandas":
-                data = _get_spark().createDataFrame(
+                data = get_spark().createDataFrame(
                     data.loc[:, cols], schema=self._table.schema()
                 )
             else:
                 data = data.select(*cols)
         elif self._table.dataframe_type == "pandas":
-            data = _get_spark().createDataFrame(data)
+            data = get_spark().createDataFrame(data)

         method = getattr(self, f"_save_{self._table.write_mode}", None)
         if method:
@@ -456,7 +456,7 @@ def _save_upsert(self, update_data: DataFrame) -> None:
             update_data (DataFrame): The Spark dataframe to upsert.
         """
         if self._exists():
-            base_data = _get_spark().table(self._table.full_table_location())
+            base_data = get_spark().table(self._table.full_table_location())
             base_columns = base_data.columns
             update_columns = update_data.columns

@@ -479,11 +479,11 @@ def _save_upsert(self, update_data: DataFrame) -> None:
             )

             update_data.createOrReplaceTempView("update")
-            _get_spark().conf.set("fullTableAddress", self._table.full_table_location())
-            _get_spark().conf.set("whereExpr", where_expr)
+            get_spark().conf.set("fullTableAddress", self._table.full_table_location())
+            get_spark().conf.set("whereExpr", where_expr)
             upsert_sql = """MERGE INTO ${fullTableAddress} base USING update ON ${whereExpr}
                 WHEN MATCHED THEN UPDATE SET * WHEN NOT MATCHED THEN INSERT *"""
-            _get_spark().sql(upsert_sql)
+            get_spark().sql(upsert_sql)
         else:
             self._save_append(update_data)
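For context on the _save_upsert hunk above: the MERGE statement leans on Spark SQL variable substitution, so ${fullTableAddress} and ${whereExpr} are resolved from the session configuration set just before sql() is called. An illustrative sketch, with a hypothetical table name and join key, assuming the target Delta table already exists and variable substitution is enabled (the default):

from kedro_datasets._utils.spark_utils import get_spark

spark = get_spark()
update_df = spark.createDataFrame([(1, "new value")], ["id", "value"])
update_df.createOrReplaceTempView("update")

spark.conf.set("fullTableAddress", "`my_db`.`my_table`")  # assumed to be an existing Delta table
spark.conf.set("whereExpr", "base.id = update.id")
spark.sql(
    """MERGE INTO ${fullTableAddress} base USING update ON ${whereExpr}
    WHEN MATCHED THEN UPDATE SET * WHEN NOT MATCHED THEN INSERT *"""
)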

kedro-datasets/kedro_datasets/spark/deltatable_dataset.py

Lines changed: 6 additions & 9 deletions
@@ -10,11 +10,8 @@
 from kedro.io.core import AbstractDataset, DatasetError
 from pyspark.sql.utils import AnalysisException

-from kedro_datasets.spark.spark_dataset import (
-    _get_spark,
-    _split_filepath,
-    _strip_dbfs_prefix,
-)
+from kedro_datasets._utils.databricks_utils import split_filepath, strip_dbfs_prefix
+from kedro_datasets._utils.spark_utils import get_spark


 class DeltaTableDataset(AbstractDataset[None, DeltaTable]):
@@ -81,24 +78,24 @@ def __init__(
             metadata: Any arbitrary metadata.
                 This is ignored by Kedro, but may be consumed by users or external plugins.
         """
-        fs_prefix, filepath = _split_filepath(filepath)
+        fs_prefix, filepath = split_filepath(filepath)

         self._fs_prefix = fs_prefix
         self._filepath = PurePosixPath(filepath)
         self.metadata = metadata

     def load(self) -> DeltaTable:
         load_path = self._fs_prefix + str(self._filepath)
-        return DeltaTable.forPath(_get_spark(), load_path)
+        return DeltaTable.forPath(get_spark(), load_path)

     def save(self, data: None) -> NoReturn:
         raise DatasetError(f"{self.__class__.__name__} is a read only dataset type")

     def _exists(self) -> bool:
-        load_path = _strip_dbfs_prefix(self._fs_prefix + str(self._filepath))
+        load_path = strip_dbfs_prefix(self._fs_prefix + str(self._filepath))

         try:
-            _get_spark().read.load(path=load_path, format="delta")
+            get_spark().read.load(path=load_path, format="delta")
         except AnalysisException as exception:
             # `AnalysisException.desc` is deprecated with pyspark >= 3.4
             message = exception.desc if hasattr(exception, "desc") else str(exception)
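Finally, a brief illustrative example of the refactored dataset in use; the filepath is hypothetical, and load() returns a delta.tables.DeltaTable bound to the session from get_spark():

from kedro_datasets.spark import DeltaTableDataset

dataset = DeltaTableDataset(filepath="/dbfs/data/03_primary/weather_delta")  # hypothetical path
delta_table = dataset.load()
delta_table.toDF().show()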
