plugins/flytekit-spark/flytekitplugins/spark/sd_transformers.py (48 additions, 3 deletions)
@@ -12,10 +12,12 @@
     StructuredDatasetTransformerEngine,
 )
 
+import pyspark
+from pyspark.sql import dataframe
+from pyspark.sql.classic.dataframe import DataFrame as ClassicDataFrame
+
 pd = lazy_module("pandas")
-pyspark = lazy_module("pyspark")
-ps_dataframe = lazy_module("pyspark.sql.dataframe")
-DataFrame = ps_dataframe.DataFrame
+DataFrame = dataframe.DataFrame
 
 
 class SparkDataFrameRenderer:
@@ -52,6 +54,30 @@ def encode(
         return literals.StructuredDataset(uri=path, metadata=StructuredDatasetMetadata(structured_dataset_type))
 
 
+class ClassicSparkToParquetEncodingHandler(StructuredDatasetEncoder):
+    def __init__(self):
+        super().__init__(ClassicDataFrame, None, PARQUET)
+
+    def encode(
+        self,
+        ctx: FlyteContext,
+        structured_dataset: StructuredDataset,
+        structured_dataset_type: StructuredDatasetType,
+    ) -> literals.StructuredDataset:
+        path = typing.cast(str, structured_dataset.uri)
+        if not path:
+            path = ctx.file_access.join(
+                ctx.file_access.raw_output_prefix,
+                ctx.file_access.get_random_string(),
+            )
+        df = typing.cast(DataFrame, structured_dataset.dataframe)
+        ss = pyspark.sql.SparkSession.builder.getOrCreate()
+        # Avoid generating SUCCESS files
+        ss.conf.set("mapreduce.fileoutputcommitter.marksuccessfuljobs", "false")
+        df.write.mode("overwrite").parquet(path=path)
+        return literals.StructuredDataset(uri=path, metadata=StructuredDatasetMetadata(structured_dataset_type))
+
+
 class ParquetToSparkDecodingHandler(StructuredDatasetDecoder):
     def __init__(self):
         super().__init__(DataFrame, None, PARQUET)
@@ -69,6 +95,25 @@ def decode(
         return user_ctx.spark_session.read.parquet(flyte_value.uri)
 
 
+class ClassicParquetToSparkDecodingHandler(StructuredDatasetDecoder):
+    def __init__(self):
+        super().__init__(ClassicDataFrame, None, PARQUET)
+
+    def decode(
+        self,
+        ctx: FlyteContext,
+        flyte_value: literals.StructuredDataset,
+        current_task_metadata: StructuredDatasetMetadata,
+    ) -> DataFrame:
+        user_ctx = FlyteContext.current_context().user_space_params
+        if current_task_metadata.structured_dataset_type and current_task_metadata.structured_dataset_type.columns:
+            columns = [c.name for c in current_task_metadata.structured_dataset_type.columns]
+            return user_ctx.spark_session.read.parquet(flyte_value.uri).select(*columns)
+        return user_ctx.spark_session.read.parquet(flyte_value.uri)
+
+
 StructuredDatasetTransformerEngine.register(SparkToParquetEncodingHandler())
 StructuredDatasetTransformerEngine.register(ParquetToSparkDecodingHandler())
+StructuredDatasetTransformerEngine.register(ClassicSparkToParquetEncodingHandler())
+StructuredDatasetTransformerEngine.register(ClassicParquetToSparkDecodingHandler())
 StructuredDatasetTransformerEngine.register_renderer(DataFrame, SparkDataFrameRenderer())
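For context, a minimal usage sketch of what these handlers enable. The task names and spark_conf values below are illustrative, not part of this PR; it assumes flytekit and flytekitplugins-spark are installed and that the handlers above have been registered on import. Returning a Spark DataFrame from a task goes through the Parquet encoding handler, and accepting one as an input goes through the decoding handler, for both the base and the classic DataFrame types.

import pyspark
from flytekit import current_context, task, workflow
from flytekitplugins.spark import Spark


@task(task_config=Spark(spark_conf={"spark.driver.memory": "1g"}))
def make_df() -> pyspark.sql.DataFrame:
    sess = current_context().spark_session
    # On return, the DataFrame is written to Parquet by the encoding handler.
    return sess.createDataFrame([("Alice", 5), ("Bob", 10)], ["name", "age"])


@task(task_config=Spark(spark_conf={"spark.driver.memory": "1g"}))
def count_rows(df: pyspark.sql.DataFrame) -> int:
    # On input, the Parquet files are read back by the decoding handler.
    return df.count()


@workflow
def wf() -> int:
    return count_rows(df=make_df())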