-
Notifications
You must be signed in to change notification settings - Fork 311
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[BUG] Fix StructuredDataset empty-str
file_format
in dc attr access (…
…#3027) * fix: Retain user-specified file format info Signed-off-by: JiaWei Jiang <[email protected]> * fix: Set sdt format based on user-specified file_format Signed-off-by: JiaWei Jiang <[email protected]> * Remove redundant modification Signed-off-by: JiaWei Jiang <[email protected]> * test: Test file_format attribute alignment in dc.sd Signed-off-by: JiaWei Jiang <[email protected]> * Merge master and support pqt file upload Signed-off-by: JiaWei Jiang <[email protected]> * Remove redundant condition to always copy file_format over Signed-off-by: JiangJiaWei1103 <[email protected]> * Prioritize file_format in type hint over the user-specified one Signed-off-by: JiangJiaWei1103 <[email protected]> --------- Signed-off-by: JiaWei Jiang <[email protected]> Signed-off-by: JiangJiaWei1103 <[email protected]> Co-authored-by: Future-Outlier <[email protected]>
- Loading branch information
1 parent
4208a64
commit 88ac611
Showing
3 changed files
with
135 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
68 changes: 68 additions & 0 deletions
68
tests/flytekit/integration/remote/workflows/basic/sd_attr.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
from dataclasses import dataclass | ||
|
||
import pandas as pd | ||
from flytekit import task, workflow | ||
from flytekit.types.structured import StructuredDataset | ||
|
||
|
||
@dataclass | ||
class DC: | ||
sd: StructuredDataset | ||
|
||
|
||
@task | ||
def create_dc(uri: str, file_format: str) -> DC: | ||
"""Create a dataclass with a StructuredDataset attribute. | ||
Args: | ||
uri: File URI. | ||
file_format: File format, e.g., parquet, csv. | ||
Returns: | ||
dc: A dataclass with a StructuredDataset attribute. | ||
""" | ||
dc = DC(sd=StructuredDataset(uri=uri, file_format=file_format)) | ||
|
||
return dc | ||
|
||
|
||
@task | ||
def check_file_format(sd: StructuredDataset, true_file_format: str) -> StructuredDataset: | ||
"""Check StructuredDataset file_format attribute. | ||
StruturedDataset file_format should align with what users specify. | ||
Args: | ||
sd: Python native StructuredDataset. | ||
true_file_format: User-specified file_format. | ||
""" | ||
assert sd.file_format == true_file_format, ( | ||
f"StructuredDataset file_format should align with the user-specified file_format: {true_file_format}." | ||
) | ||
assert sd._literal_sd.metadata.structured_dataset_type.format == true_file_format, ( | ||
f"StructuredDatasetType format should align with the user-specified file_format: {true_file_format}." | ||
) | ||
print(f">>> SD <<<\n{sd}") | ||
print(f">>> Literal SD <<<\n{sd._literal_sd}") | ||
print(f">>> SDT <<<\n{sd._literal_sd.metadata.structured_dataset_type}") | ||
print(f">>> DF <<<\n{sd.open(pd.DataFrame).all()}") | ||
|
||
return sd | ||
|
||
|
||
@workflow | ||
def wf(dc: DC, file_format: str) -> StructuredDataset: | ||
# Fail to use dc.sd.file_format as the input | ||
sd = check_file_format(sd=dc.sd, true_file_format=file_format) | ||
|
||
return sd | ||
|
||
|
||
if __name__ == "__main__": | ||
# Define inputs | ||
uri = "tests/flytekit/integration/remote/workflows/basic/data/df.parquet" | ||
file_format = "parquet" | ||
|
||
dc = create_dc(uri=uri, file_format=file_format) | ||
sd = wf(dc=dc, file_format=file_format) | ||
print(sd.file_format) |