Skip to content

Commit c3818b3

Browse files
authored
Fix Pydantic transformer to support `from __future__ import annotations` (#3343)
Fixes #6694.

When using `from __future__ import annotations`, the `__annotations__` dictionary contains string literals instead of actual type objects. This caused the PydanticTransformer to fail to recognize standard Python types (str, int, float, etc.) in BaseModel fields, resulting in a fallback to PickleFile serialization.

Changes:
- Updated `PydanticTransformer.get_literal_type()` to resolve field types with `typing.get_type_hints(t)` instead of reading `t.__annotations__` directly. (Note: this description originally named Pydantic's `model_fields` API (v2) or `__fields__` (v1); the change as committed in the diff below uses `typing.get_type_hints`.)
- `typing.get_type_hints` resolves string annotations to actual type objects, regardless of whether future annotations are enabled.
- Added a comprehensive test suite covering simple types, complex types, nested models, and literal type structure validation.

The fix ensures compatibility with PEP 563 postponed evaluation of annotations (`from __future__ import annotations`). Note that Python 3.14 makes deferred evaluation of annotations the default through PEP 649/PEP 749 rather than by adopting PEP 563's string annotations.

Signed-off-by: Govert Verkes <[email protected]>
1 parent 3645cb3 commit c3818b3

File tree

2 files changed

+171
-1
lines changed

2 files changed

+171
-1
lines changed

flytekit/extras/pydantic_transformer/transformer.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import json
22
import os
3+
import typing
34
from typing import Type
45

56
import msgpack
@@ -25,7 +26,7 @@ def __init__(self):
2526
def get_literal_type(self, t: Type[BaseModel]) -> LiteralType:
2627
schema = t.model_json_schema()
2728
literal_type = {}
28-
fields = t.__annotations__.items()
29+
fields = typing.get_type_hints(t).items()
2930

3031
for name, python_type in fields:
3132
try:
Lines changed: 169 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,169 @@
1+
"""
2+
Test that Pydantic BaseModel works correctly with `from __future__ import annotations`.
3+
4+
This addresses the issue where __annotations__ contains string literals instead of
5+
actual type objects when future annotations are enabled, which was causing Flyte
6+
to fail to recognize standard Python types.
7+
"""
8+
from __future__ import annotations
9+
10+
from typing import Dict, List
11+
12+
from pydantic import BaseModel, Field
13+
14+
from flytekit import task, workflow
15+
from flytekit.core.context_manager import FlyteContextManager
16+
from flytekit.core.type_engine import TypeEngine
17+
from flytekit.models.types import SimpleType
18+
19+
20+
# Define models at module level so they're available for get_type_hints resolution
21+
class SimpleModel(BaseModel):
22+
name: str
23+
age: int
24+
height: float
25+
is_active: bool
26+
27+
28+
class ComplexModel(BaseModel):
29+
tags: List[str] = Field(default_factory=lambda: ["tag1", "tag2"])
30+
scores: Dict[str, int] = Field(default_factory=lambda: {"math": 95, "science": 87})
31+
matrix: List[List[int]] = Field(default_factory=lambda: [[1, 2], [3, 4]])
32+
33+
34+
class InnerModel(BaseModel):
35+
value: int
36+
label: str
37+
38+
39+
class OuterModel(BaseModel):
40+
inner: InnerModel
41+
name: str
42+
43+
44+
class TypedModel(BaseModel):
45+
string_field: str
46+
int_field: int
47+
float_field: float
48+
bool_field: bool
49+
list_field: List[str]
50+
51+
52+
def test_simple_types_with_future_annotations():
53+
"""Test that basic Python types work correctly with future annotations."""
54+
55+
@task
56+
def process_simple_model(model: SimpleModel) -> str:
57+
return f"{model.name} is {model.age} years old"
58+
59+
@workflow
60+
def simple_wf(model: SimpleModel) -> str:
61+
return process_simple_model(model=model)
62+
63+
# Test workflow execution
64+
test_model = SimpleModel(name="Alice", age=30, height=5.5, is_active=True)
65+
result = simple_wf(model=test_model)
66+
assert result == "Alice is 30 years old"
67+
68+
# Verify that TypeEngine correctly handles the model
69+
ctx = FlyteContextManager.current_context()
70+
lt = TypeEngine.to_literal_type(SimpleModel)
71+
assert lt.simple is not None
72+
assert lt.metadata is not None
73+
74+
# Verify literal conversion works
75+
literal = TypeEngine.to_literal(ctx, test_model, SimpleModel, lt)
76+
assert literal is not None
77+
78+
# Verify round-trip conversion
79+
converted = TypeEngine.to_python_value(ctx, literal, SimpleModel)
80+
assert converted.name == "Alice"
81+
assert converted.age == 30
82+
assert converted.height == 5.5
83+
assert converted.is_active is True
84+
85+
86+
def test_complex_types_with_future_annotations():
87+
"""Test that complex types (List, Dict) work correctly with future annotations."""
88+
89+
@task
90+
def process_complex_model(model: ComplexModel) -> int:
91+
return sum(model.scores.values())
92+
93+
@workflow
94+
def complex_wf(model: ComplexModel) -> int:
95+
return process_complex_model(model=model)
96+
97+
# Test workflow execution
98+
test_model = ComplexModel()
99+
result = complex_wf(model=test_model)
100+
assert result == 182 # 95 + 87
101+
102+
# Verify that TypeEngine correctly handles the model
103+
ctx = FlyteContextManager.current_context()
104+
lt = TypeEngine.to_literal_type(ComplexModel)
105+
assert lt.simple is not None
106+
assert lt.metadata is not None
107+
108+
# Verify round-trip conversion
109+
literal = TypeEngine.to_literal(ctx, test_model, ComplexModel, lt)
110+
converted = TypeEngine.to_python_value(ctx, literal, ComplexModel)
111+
assert converted.tags == ["tag1", "tag2"]
112+
assert converted.scores == {"math": 95, "science": 87}
113+
assert converted.matrix == [[1, 2], [3, 4]]
114+
115+
116+
def test_nested_basemodels_with_future_annotations():
117+
"""Test that nested BaseModels work correctly with future annotations."""
118+
119+
@task
120+
def process_nested_model(model: OuterModel) -> str:
121+
return f"{model.name}: {model.inner.label} = {model.inner.value}"
122+
123+
@workflow
124+
def nested_wf(model: OuterModel) -> str:
125+
return process_nested_model(model=model)
126+
127+
# Test workflow execution
128+
inner = InnerModel(value=42, label="answer")
129+
outer = OuterModel(inner=inner, name="test")
130+
result = nested_wf(model=outer)
131+
assert result == "test: answer = 42"
132+
133+
# Verify round-trip conversion
134+
ctx = FlyteContextManager.current_context()
135+
lt = TypeEngine.to_literal_type(OuterModel)
136+
literal = TypeEngine.to_literal(ctx, outer, OuterModel, lt)
137+
converted = TypeEngine.to_python_value(ctx, literal, OuterModel)
138+
assert converted.name == "test"
139+
assert converted.inner.value == 42
140+
assert converted.inner.label == "answer"
141+
142+
143+
def test_literal_type_structure_with_future_annotations():
144+
"""Test that LiteralType structure is correctly generated with future annotations."""
145+
146+
# Get the literal type
147+
lt = TypeEngine.to_literal_type(TypedModel)
148+
149+
# Verify structure is created
150+
assert lt.structure is not None
151+
assert lt.structure.dataclass_type is not None
152+
153+
# Verify that each field has the correct literal type
154+
dataclass_type = lt.structure.dataclass_type
155+
156+
# Check that we have all expected fields
157+
assert "string_field" in dataclass_type
158+
assert "int_field" in dataclass_type
159+
assert "float_field" in dataclass_type
160+
assert "bool_field" in dataclass_type
161+
assert "list_field" in dataclass_type
162+
163+
# Verify the types are correctly identified (not as PickleFile)
164+
# This is the key test - with the bug, these would be PickleFile types
165+
assert dataclass_type["string_field"].simple == SimpleType.STRING
166+
assert dataclass_type["int_field"].simple == SimpleType.INTEGER
167+
assert dataclass_type["float_field"].simple == SimpleType.FLOAT
168+
assert dataclass_type["bool_field"].simple == SimpleType.BOOLEAN
169+
assert dataclass_type["list_field"].collection_type is not None

0 commit comments

Comments
 (0)