Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 17 additions & 5 deletions flytekit/core/type_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import sys
import textwrap
import threading
import types
import typing
from abc import ABC, abstractmethod
from collections import OrderedDict
Expand Down Expand Up @@ -530,10 +531,7 @@ def assert_type(self, expected_type: Type[DataClassJsonMixin], v: T):
# However, FooSchema is created by flytekit and it's not equal to the user-defined dataclass (Foo).
# Therefore, we should iterate all attributes in the dataclass and check the type of value in dataclass matches the expected_type.

expected_fields_dict = {}

for f in dataclasses.fields(expected_type):
expected_fields_dict[f.name] = f.type
expected_fields_dict = typing.get_type_hints(expected_type)

if isinstance(v, dict):
original_dict = v
Expand Down Expand Up @@ -568,7 +566,13 @@ def assert_type(self, expected_type: Type[DataClassJsonMixin], v: T):
original_type = type(v)
if UnionTransformer.is_optional_type(expected_type):
expected_type = UnionTransformer.get_sub_type_in_optional(expected_type)
if original_type != expected_type:

if UnionTransformer.is_union(expected_type) and UnionTransformer.in_union(
original_type, expected_type
):
pass

elif original_type != expected_type:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we do something like this instead to make code structure better?

is_in_union = (
    UnionTransformer.is_union(expected_type)
    and UnionTransformer.in_union(original_type, expected_type)
)

if not is_in_union and original_type != expected_type:
    raise TypeTransformerFailedError(
        f"Type of Val '{original_type}' is not an instance of {expected_type}"
    )

raise TypeTransformerFailedError(
f"Type of Val '{original_type}' is not an instance of {expected_type}"
)
Expand Down Expand Up @@ -1839,6 +1843,14 @@ class UnionTransformer(AsyncTypeTransformer[T]):
def __init__(self):
super().__init__("Typed Union", typing.Union)

@staticmethod
def is_union(t: Type[Any] | types.UnionType) -> bool:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we use typing.Union here? So that we do not need to import types

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you are maintaining strict type checking, unfortunately no. As you need to check for UnionType and Union separately and UnionType doesn't fall under Type[Any] (union types are weird in Python...).

I noticed a couple other problems that were due to type mismatches so I'll defer to your preferences here as I don't want to overhaul things.

return _is_union_type(t)

@staticmethod
def in_union(t: Type[Any], union: types.UnionType) -> bool:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we use typing.Union here? So that we do not need to import types

return t in typing.get_args(union)
Copy link
Member

@machichima machichima May 24, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we use get_args here instead of typing.get_args? Which is already imported


@staticmethod
def is_optional_type(t: Type) -> bool:
return _is_union_type(t) and type(None) in get_args(t)
Expand Down
Loading