Skip to content

Commit 82990a6

Browse files
committed
Fix Pydantic deserialization for FlyteFile and FlyteDirectory
Fixes #6669 Signed-off-by: Govert Verkes <[email protected]>
1 parent ff4c79c commit 82990a6

File tree

3 files changed

+57
-2
lines changed

3 files changed

+57
-2
lines changed

flytekit/types/directory/types.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -179,7 +179,17 @@ def serialize_flyte_dir(self) -> Dict[str, str]:
179179
@model_validator(mode="after")
180180
def deserialize_flyte_dir(self, info) -> FlyteDirectory:
181181
if info.context is None or info.context.get("deserialize") is not True:
182-
return self
182+
# Check if all private attributes are already set up (e.g., from __init__)
183+
if hasattr(self, "_downloader") and hasattr(self, "_remote_source"):
184+
return self
185+
186+
# Populate missing private attributes for Pydantic-deserialized instances
187+
dict_obj = {"path": str(self.path)}
188+
189+
return FlyteDirToMultipartBlobTransformer().dict_to_flyte_directory(
190+
dict_obj=dict_obj,
191+
expected_python_type=type(self),
192+
)
183193

184194
pv = FlyteDirToMultipartBlobTransformer().to_python_value(
185195
FlyteContextManager.current_context(),

flytekit/types/file/file.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -196,7 +196,18 @@ def serialize_flyte_file(self) -> Dict[str, typing.Any]:
196196
@model_validator(mode="after")
197197
def deserialize_flyte_file(self, info) -> "FlyteFile":
198198
if info.context is None or info.context.get("deserialize") is not True:
199-
return self
199+
if hasattr(self, "_downloader") and hasattr(self, "_remote_source"):
200+
return self
201+
202+
dict_obj = {"path": str(self.path)}
203+
metadata = getattr(self, "metadata", None)
204+
if metadata is not None:
205+
dict_obj["metadata"] = metadata
206+
207+
return FlyteFilePathTransformer().dict_to_flyte_file(
208+
dict_obj=dict_obj,
209+
expected_python_type=type(self),
210+
)
200211

201212
pv = FlyteFilePathTransformer().to_python_value(
202213
FlyteContextManager.current_context(),

tests/flytekit/unit/extras/pydantic_transformer/test_pydantic_basemodel_transformer.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1022,3 +1022,37 @@ def mock_resolve_remote_path(flyte_uri: str):
10221022

10231023
bm_revived = TypeEngine.to_python_value(ctx, lit, BM)
10241024
assert bm_revived.s.literal.uri == "/my/replaced/val"
1025+
1026+
1027+
def test_flytefile_pydantic_model_dump_validate_cycle():
1028+
class BM(BaseModel):
1029+
ff: FlyteFile
1030+
1031+
bm = BM(ff=FlyteFile.from_source("s3://my-bucket/file.txt"))
1032+
1033+
assert bm.ff.remote_source == "s3://my-bucket/file.txt"
1034+
1035+
bm_dict = bm.model_dump()
1036+
bm2 = BM.model_validate(bm_dict)
1037+
1038+
assert isinstance(bm2.ff, FlyteFile)
1039+
assert bm2.ff.remote_source == "s3://my-bucket/file.txt"
1040+
1041+
bm2.model_dump()
1042+
1043+
1044+
def test_flytedirectory_pydantic_model_dump_validate_cycle():
1045+
class BM(BaseModel):
1046+
fd: FlyteDirectory
1047+
1048+
bm = BM(fd=FlyteDirectory.from_source("s3://my-bucket/my-dir"))
1049+
1050+
assert bm.fd.remote_source == "s3://my-bucket/my-dir"
1051+
1052+
bm_dict = bm.model_dump()
1053+
bm2 = BM.model_validate(bm_dict)
1054+
1055+
assert isinstance(bm2.fd, FlyteDirectory)
1056+
assert bm2.fd.remote_source == "s3://my-bucket/my-dir"
1057+
1058+
bm2.model_dump()

0 commit comments

Comments
 (0)