Skip to content
Merged
64 changes: 64 additions & 0 deletions flytekit/core/type_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -1087,6 +1087,66 @@ def assert_type(self, t: Type[enum.Enum], v: T):
raise TypeTransformerFailedError(f"Value {v} is not in Enum {t}")


class LiteralTypeTransformer(TypeTransformer[object]):
def __init__(self):
super().__init__("LiteralTypeTransformer", object)

@classmethod
def get_base_type(cls, t: Type) -> Type:
args = get_args(t)
if not args:
raise TypeTransformerFailedError("Literal must have at least one value")

base_type = type(args[0])
if not all(type(a) == base_type for a in args):
raise TypeTransformerFailedError("All values must be of the same type")

return base_type

def get_literal_type(self, t: Type) -> LiteralType:
base_type = self.get_base_type(t)
vals = list(get_args(t))
ann = TypeAnnotationModel(annotations={"literal_values": vals})
if base_type is str:
simple = SimpleType.STRING
elif base_type is int:
simple = SimpleType.INTEGER
elif base_type is float:
simple = SimpleType.FLOAT
elif base_type is bool:
simple = SimpleType.BOOLEAN
elif base_type is datetime.datetime:
simple = SimpleType.DATETIME
elif base_type is datetime.timedelta:
simple = SimpleType.DURATION
else:
raise TypeTransformerFailedError(f"Unsupported type: {base_type}")
return LiteralType(simple=simple, annotation=ann)

def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type, expected: LiteralType) -> Literal:
base_type = self.get_base_type(python_type)
base_transformer: TypeTransformer[object] = TypeEngine.get_transformer(base_type)
return base_transformer.to_literal(ctx, python_val, python_type, expected)

def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: Type) -> object:
base_type = self.get_base_type(expected_python_type)
base_transformer: TypeTransformer[object] = TypeEngine.get_transformer(base_type)
return base_transformer.to_python_value(ctx, lv, base_type)

def guess_python_type(self, literal_type: LiteralType) -> Type:
ann = getattr(literal_type, "annotation", None)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

literal_type will always have an annotation attribute, right? I think we can just do if literal_type.annotation and literal_type.annotation.annotations

if ann and getattr(ann, "annotations", None):
vals = ann.annotations.get("literal_values")
if vals and isinstance(vals, list):
return typing.Literal[tuple(vals)] # type: ignore
raise ValueError(f"LiteralType transformer cannot reverse {literal_type}")

def assert_type(self, python_type: Type, python_val: T):
base_type = self.get_base_type(python_type)
base_transformer: TypeTransformer[object] = TypeEngine.get_transformer(base_type)
return base_transformer.assert_type(base_type, python_val)


def _handle_json_schema_property(
property_key: str,
property_val: dict,
Expand Down Expand Up @@ -1173,6 +1233,7 @@ class TypeEngine(typing.Generic[T]):
_RESTRICTED_TYPES: typing.List[type] = []
_DATACLASS_TRANSFORMER: TypeTransformer = DataclassTransformer() # type: ignore
_ENUM_TRANSFORMER: TypeTransformer = EnumTransformer() # type: ignore
_LITERAL_TYPE_TRANSFORMER: TypeTransformer = LiteralTypeTransformer()
lazy_import_lock = threading.Lock()

@classmethod
Expand Down Expand Up @@ -1222,6 +1283,9 @@ def _get_transformer(cls, python_type: Type) -> Optional[TypeTransformer[T]]:
# Special case: prevent that for a type `FooEnum(str, Enum)`, the str transformer is used.
return cls._ENUM_TRANSFORMER

if get_origin(python_type) == typing.Literal:
return cls._LITERAL_TYPE_TRANSFORMER

if hasattr(python_type, "__origin__"):
# If the type is a generic type, we should check the origin type. But consider the case like Iterator[JSON]
# or List[int] has been specifically registered; we should check for the entire type.
Expand Down
39 changes: 39 additions & 0 deletions tests/flytekit/unit/core/test_type_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -3851,3 +3851,42 @@ async def test_dict_transformer_annotated_type():

literal3 = await TypeEngine.async_to_literal(ctx, nested_dict, nested_dict_type, expected_type)
assert literal3.map.literals["outer"].map.literals["inner"].scalar.primitive.integer == 42

def test_literal_transformer_string_type():
# Python -> Flyte
t = typing.Literal["outcome", "income"]
lt = TypeEngine.get_transformer(t).get_literal_type(t)
assert lt.simple == SimpleType.STRING
assert lt.annotation.annotations["literal_values"] == ["outcome", "income"]

lv = TypeEngine.to_literal(FlyteContext.current_context(), "outcome", t, lt)
assert lv.scalar.primitive.string_value == "outcome"

# Flyte -> Python (reconstruction)
pt = TypeEngine.get_transformer(t).guess_python_type(lt)
assert pt is typing.Literal["outcome", "income"]
pv = TypeEngine.get_transformer(pt).to_python_value(FlyteContext.current_context(), lv, pt)
TypeEngine.get_transformer(pt).assert_type(pt, pv)
assert pv == "outcome"

def test_literal_transformer_int_type():
# Python -> Flyte
t = typing.Literal[1, 2, 3]
lt = TypeEngine.get_transformer(t).get_literal_type(t)
assert lt.simple == SimpleType.INTEGER
assert lt.annotation.annotations["literal_values"] == [1, 2, 3]

lv = TypeEngine.to_literal(FlyteContext.current_context(), 1, t, lt)
assert lv.scalar.primitive.integer == 1

# Flyte -> Python (reconstruction)
pt = TypeEngine.get_transformer(t).guess_python_type(lt)
assert pt is typing.Literal[1, 2, 3]
pv = TypeEngine.get_transformer(pt).to_python_value(FlyteContext.current_context(), lv, pt)
TypeEngine.get_transformer(pt).assert_type(pt, pv)
assert pv == 1

def test_literal_transformer_mixed_base_types():
t = typing.Literal["a", 1]
with pytest.raises(TypeTransformerFailedError):
TypeEngine.get_transformer(t).get_literal_type(t)
Loading