
Commit 5b011a2

Fix spark imports (#5795)
1 parent cf4a195 commit 5b011a2

File tree

1 file changed (+11, -4 lines)

  • src/datasets/packaged_modules/spark

src/datasets/packaged_modules/spark/spark.py

Lines changed: 11 additions & 4 deletions
@@ -1,13 +1,13 @@
 import os
+import posixpath
 import uuid
 from dataclasses import dataclass
-from typing import Iterable, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Iterable, Optional, Tuple, Union

 import pyarrow as pa
-import pyspark

 import datasets
-from datasets.arrow_writer import ArrowWriter
+from datasets.arrow_writer import ArrowWriter, ParquetWriter
 from datasets.config import MAX_SHARD_SIZE
 from datasets.filesystems import (
     is_remote_filesystem,
@@ -18,6 +18,9 @@

 logger = datasets.utils.logging.get_logger(__name__)

+if TYPE_CHECKING:
+    import pyspark
+

 @dataclass
 class SparkConfig(datasets.BuilderConfig):
@@ -31,10 +34,12 @@ class Spark(datasets.DatasetBuilder):

     def __init__(
         self,
-        df: pyspark.sql.DataFrame,
+        df: "pyspark.sql.DataFrame",
         cache_dir: str = None,
         **config_kwargs,
     ):
+        import pyspark
+
         self._spark = pyspark.sql.SparkSession.builder.getOrCreate()
         self.df = df
         self._validate_cache_dir(cache_dir)
@@ -86,6 +91,8 @@ def _prepare_split_single(
         file_format: str,
         max_shard_size: int,
     ) -> Iterable[Tuple[int, bool, Union[int, tuple]]]:
+        import pyspark
+
         writer_class = ParquetWriter if file_format == "parquet" else ArrowWriter
         embed_local_files = file_format == "parquet"