Skip to content

Commit de5a137

Browse files
committed
fixes handling of union types
In the dataclasses conversion code type that is a member of a union was not properly checked for if it was a member and so there would always be an error. For instance `FlyteFile.path` is `Union[str,Pathlike]` and so `str != Union[str,Pathlike]`. This patch adds support for checking that a type is part of a union and a satisfactory type. Signed-off-by: Samuel Lotz <[email protected]>
1 parent 1a25939 commit de5a137

File tree

1 file changed

+19
-2
lines changed

1 file changed

+19
-2
lines changed

flytekit/core/type_engine.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from abc import ABC, abstractmethod
1818
from collections import OrderedDict
1919
from functools import lru_cache
20-
from types import GenericAlias
20+
from types import GenericAlias, UnionType
2121
from typing import Any, Dict, List, NamedTuple, Optional, Type, cast
2222

2323
import msgpack
@@ -565,7 +565,14 @@ def assert_type(self, expected_type: Type[DataClassJsonMixin], v: T):
565565
original_type = type(v)
566566
if UnionTransformer.is_optional_type(expected_type):
567567
expected_type = UnionTransformer.get_sub_type_in_optional(expected_type)
568-
if original_type != expected_type:
568+
569+
if (
570+
UnionTransformer.is_union(expected_type) and
571+
UnionTransformer.in_union(original_type, expected_type)
572+
):
573+
pass
574+
575+
elif original_type != expected_type:
569576
raise TypeTransformerFailedError(
570577
f"Type of Val '{original_type}' is not an instance of {expected_type}"
571578
)
@@ -1836,6 +1843,16 @@ class UnionTransformer(AsyncTypeTransformer[T]):
18361843
def __init__(self):
18371844
super().__init__("Typed Union", typing.Union)
18381845

1846+
@staticmethod
1847+
def is_union(t: Type[Any] | UnionType) -> bool:
1848+
1849+
return _is_union_type(t)
1850+
1851+
@staticmethod
1852+
def in_union(t: Type[Any], union: types.UnionType) -> bool:
1853+
1854+
return t in typing.get_args(union)
1855+
18391856
@staticmethod
18401857
def is_optional_type(t: Type) -> bool:
18411858
return _is_union_type(t) and type(None) in get_args(t)

0 commit comments

Comments
 (0)