Skip to content

Commit

Permalink
[BUG] Fix StructuredDataset empty-str file_format in dc attr access (
Browse files Browse the repository at this point in the history
…#3027)

* fix: Retain user-specified file format info

Signed-off-by: JiaWei Jiang <[email protected]>

* fix: Set sdt format based on user-specified file_format

Signed-off-by: JiaWei Jiang <[email protected]>

* Remove redundant modification

Signed-off-by: JiaWei Jiang <[email protected]>

* test: Test file_format attribute alignment in dc.sd

Signed-off-by: JiaWei Jiang <[email protected]>

* Merge master and support pqt file upload

Signed-off-by: JiaWei Jiang <[email protected]>

* Remove redundant condition to always copy file_format over

Signed-off-by: JiangJiaWei1103 <[email protected]>

* Prioritize file_format in type hint over the user-specified one

Signed-off-by: JiangJiaWei1103 <[email protected]>

---------

Signed-off-by: JiaWei Jiang <[email protected]>
Signed-off-by: JiangJiaWei1103 <[email protected]>
Co-authored-by: Future-Outlier <[email protected]>
  • Loading branch information
JiangJiaWei1103 and Future-Outlier authored Jan 31, 2025
1 parent 4208a64 commit 88ac611
Show file tree
Hide file tree
Showing 3 changed files with 135 additions and 0 deletions.
21 changes: 21 additions & 0 deletions flytekit/types/structured/structured_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -739,10 +739,31 @@ async def async_to_literal(
# return StructuredDataset(uri=uri)
if python_val.dataframe is None:
uri = python_val.uri
file_format = python_val.file_format

# Check the user-specified uri
if not uri:
raise ValueError(f"If dataframe is not specified, then the uri should be specified. {python_val}")
if not ctx.file_access.is_remote(uri):
uri = await ctx.file_access.async_put_raw_data(uri)

# Check the user-specified file_format
# When users specify file_format for a StructuredDataset, the file_format should be retained conditionally.
# For details, please refer to https://github.com/flyteorg/flyte/issues/6096.
# Following illustrates why we can't always copy the user-specified file_format over:
#
# @task
# def modify_format(sd: Annotated[StructuredDataset, {}, "task-format"]) -> StructuredDataset:
# return sd
#
# sd = StructuredDataset(uri="s3://my-s3-bucket/df.parquet", file_format="user-format")
# sd2 = modify_format(sd=sd)
#
# In this case, we expect sd2.file_format to be task-format (as shown in Annotated), not user-format.
# If we directly copy the user-specified file_format over, the type hint information will be missing.
if sdt.format == GENERIC_FORMAT and file_format != GENERIC_FORMAT:
sdt.format = file_format

sd_model = literals.StructuredDataset(
uri=uri,
metadata=StructuredDatasetMetadata(structured_dataset_type=sdt),
Expand Down
46 changes: 46 additions & 0 deletions tests/flytekit/integration/remote/test_remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import uuid
import pytest
from unittest import mock
from dataclasses import dataclass

from flytekit import LaunchPlan, kwtypes, WorkflowExecutionPhase
from flytekit.configuration import Config, ImageConfig, SerializationSettings
Expand All @@ -27,6 +28,7 @@
from flytekit.remote.remote import FlyteRemote
from flyteidl.service import dataproxy_pb2 as _data_proxy_pb2
from flytekit.types.schema import FlyteSchema
from flytekit.types.structured import StructuredDataset
from flytekit.clients.friendly import SynchronousFlyteClient as _SynchronousFlyteClient
from flytekit.configuration import PlatformConfig

Expand Down Expand Up @@ -877,6 +879,50 @@ def test_attr_access_sd():
bucket, key = url.netloc, url.path.lstrip("/")
file_transfer.delete_file(bucket=bucket, key=key)


def test_sd_attr():
"""Test correctness of StructuredDataset attributes.
This test considers only the following condition:
1. Check StructuredDataset (wrapped in a dataclass) file_format attribute
We'll make sure uri aligns with the user-specified one in the future.
"""
from workflows.basic.sd_attr import wf

@dataclass
class DC:
sd: StructuredDataset

FILE_FORMAT = "parquet"

# Upload a file to minio s3 bucket
file_transfer = SimpleFileTransfer()
remote_file_path = file_transfer.upload_file(file_type=FILE_FORMAT)

# Create a dataclass as the workflow input because `pyflyte run`
# can't properly handle input arg `dc` as a json str so far
dc = DC(sd=StructuredDataset(uri=remote_file_path, file_format=FILE_FORMAT))

remote = FlyteRemote(Config.auto(config_file=CONFIG), PROJECT, DOMAIN, interactive_mode_enabled=True)
wf_exec = remote.execute(
wf,
inputs={"dc": dc, "file_format": FILE_FORMAT},
wait=True,
version=VERSION,
image_config=ImageConfig.from_images(IMAGE),
)
assert wf_exec.closure.phase == WorkflowExecutionPhase.SUCCEEDED, f"Execution failed with phase: {wf_exec.closure.phase}"
assert wf_exec.outputs["o0"].file_format == FILE_FORMAT, (
f"Workflow output StructuredDataset file_format should align with the user-specified file_format: {FILE_FORMAT}."
)

# Delete the remote file to free the space
url = urlparse(remote_file_path)
bucket, key = url.netloc, url.path.lstrip("/")
file_transfer.delete_file(bucket=bucket, key=key)


def test_signal_approve_reject(register):
from flytekit.models.types import LiteralType, SimpleType
from time import sleep
Expand Down
68 changes: 68 additions & 0 deletions tests/flytekit/integration/remote/workflows/basic/sd_attr.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
from dataclasses import dataclass

import pandas as pd
from flytekit import task, workflow
from flytekit.types.structured import StructuredDataset


@dataclass
class DC:
sd: StructuredDataset


@task
def create_dc(uri: str, file_format: str) -> DC:
"""Create a dataclass with a StructuredDataset attribute.
Args:
uri: File URI.
file_format: File format, e.g., parquet, csv.
Returns:
dc: A dataclass with a StructuredDataset attribute.
"""
dc = DC(sd=StructuredDataset(uri=uri, file_format=file_format))

return dc


@task
def check_file_format(sd: StructuredDataset, true_file_format: str) -> StructuredDataset:
"""Check StructuredDataset file_format attribute.
StruturedDataset file_format should align with what users specify.
Args:
sd: Python native StructuredDataset.
true_file_format: User-specified file_format.
"""
assert sd.file_format == true_file_format, (
f"StructuredDataset file_format should align with the user-specified file_format: {true_file_format}."
)
assert sd._literal_sd.metadata.structured_dataset_type.format == true_file_format, (
f"StructuredDatasetType format should align with the user-specified file_format: {true_file_format}."
)
print(f">>> SD <<<\n{sd}")
print(f">>> Literal SD <<<\n{sd._literal_sd}")
print(f">>> SDT <<<\n{sd._literal_sd.metadata.structured_dataset_type}")
print(f">>> DF <<<\n{sd.open(pd.DataFrame).all()}")

return sd


@workflow
def wf(dc: DC, file_format: str) -> StructuredDataset:
# Fail to use dc.sd.file_format as the input
sd = check_file_format(sd=dc.sd, true_file_format=file_format)

return sd


if __name__ == "__main__":
# Define inputs
uri = "tests/flytekit/integration/remote/workflows/basic/data/df.parquet"
file_format = "parquet"

dc = create_dc(uri=uri, file_format=file_format)
sd = wf(dc=dc, file_format=file_format)
print(sd.file_format)

0 comments on commit 88ac611

Please sign in to comment.