Skip to content
81 changes: 81 additions & 0 deletions flytekit/core/type_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -1087,6 +1087,83 @@ def assert_type(self, t: Type[enum.Enum], v: T):
raise TypeTransformerFailedError(f"Value {v} is not in Enum {t}")


class LiteralTypeTransformer(TypeTransformer[object]):
def __init__(self):
super().__init__("LiteralTypeTransformer", object)

def get_literal_type(self, t: Type) -> LiteralType:
args = get_args(t)
if not args:
raise TypeTransformerFailedError("Literal must have at least one value")

base_type = type(args[0])
if not all(type(a) == base_type for a in args):
raise TypeTransformerFailedError("All values must be of the same type")

if base_type == str:
return StrTransformer.get_literal_type(args[0])
elif base_type == int:
return IntTransformer.get_literal_type(args[0])
elif base_type == float:
return FloatTransformer.get_literal_type(args[0])
elif base_type == bool:
return BoolTransformer.get_literal_type(args[0])
elif base_type == datetime.datetime:
return DatetimeTransformer.get_literal_type(args[0])
elif base_type == datetime.timedelta:
return TimedeltaTransformer.get_literal_type(args[0])
else:
raise TypeTransformerFailedError(f"Unsupported Literal base type: {base_type}")

def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type, expected: LiteralType) -> Literal:
if expected.simple == SimpleType.STRING and isinstance(python_val, str):
return StrTransformer.to_literal(ctx, python_val, str, expected)
elif expected.simple == SimpleType.INTEGER and isinstance(python_val, int):
return IntTransformer.to_literal(ctx, python_val, int, expected)
elif expected.simple == SimpleType.FLOAT and isinstance(python_val, float):
return FloatTransformer.to_literal(ctx, python_val, float, expected)
elif expected.simple == SimpleType.BOOLEAN and isinstance(python_val, bool):
return BoolTransformer.to_literal(ctx, python_val, bool, expected)
elif expected.simple == SimpleType.DATETIME and isinstance(python_val, datetime.datetime):
return DatetimeTransformer.to_literal(ctx, python_val, datetime.datetime, expected)
elif expected.simple == SimpleType.DURATION and isinstance(python_val, datetime.timedelta):
return TimedeltaTransformer.to_literal(ctx, python_val, datetime.timedelta, expected)
else:
raise TypeError(f"Unsupported LiteralType for LiteralTypeTransformer: {expected.simple}")

def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: Type) -> object:
if lv.scalar.primitive.string_value is not None:
return StrTransformer.to_python_value(ctx, lv, str)
elif lv.scalar.primitive.integer is not None:
return IntTransformer.to_python_value(ctx, lv, int)
elif lv.scalar.primitive.float_value is not None:
return FloatTransformer.to_python_value(ctx, lv, float)
elif lv.scalar.primitive.boolean is not None:
return BoolTransformer.to_python_value(ctx, lv, bool)
elif lv.scalar.primitive.datetime is not None:
return DatetimeTransformer.to_python_value(ctx, lv, datetime.datetime)
elif lv.scalar.primitive.duration is not None:
return TimedeltaTransformer.to_python_value(ctx, lv, datetime.timedelta)
else:
raise TypeTransformerFailedError("Unsupported Literal value")

def guess_python_type(self, literal_type: LiteralType):
if literal_type.simple == SimpleType.STRING:
return StrTransformer.guess_python_type(literal_type)
elif literal_type.simple == SimpleType.INTEGER:
return IntTransformer.guess_python_type(literal_type)
elif literal_type.simple == SimpleType.FLOAT:
return FloatTransformer.guess_python_type(literal_type)
elif literal_type.simple == SimpleType.BOOLEAN:
return BoolTransformer.guess_python_type(literal_type)
elif literal_type.simple == SimpleType.DATETIME:
return DatetimeTransformer.guess_python_type(literal_type)
elif literal_type.simple == SimpleType.DURATION:
return TimedeltaTransformer.guess_python_type(literal_type)
else:
raise TypeTransformerFailedError(f"LiteralTypeTransformer cannot reverse {literal_type}")


def _handle_json_schema_property(
property_key: str,
property_val: dict,
Expand Down Expand Up @@ -1173,6 +1250,7 @@ class TypeEngine(typing.Generic[T]):
_RESTRICTED_TYPES: typing.List[type] = []
_DATACLASS_TRANSFORMER: TypeTransformer = DataclassTransformer() # type: ignore
_ENUM_TRANSFORMER: TypeTransformer = EnumTransformer() # type: ignore
_LITERAL_TYPE_TRANSFORMER: TypeTransformer = LiteralTypeTransformer()
lazy_import_lock = threading.Lock()

@classmethod
Expand Down Expand Up @@ -1222,6 +1300,9 @@ def _get_transformer(cls, python_type: Type) -> Optional[TypeTransformer[T]]:
# Special case: prevent that for a type `FooEnum(str, Enum)`, the str transformer is used.
return cls._ENUM_TRANSFORMER

if get_origin(python_type) == typing.Literal:
return cls._LITERAL_TYPE_TRANSFORMER

if hasattr(python_type, "__origin__"):
# If the type is a generic type, we should check the origin type. But consider the case like Iterator[JSON]
# or List[int] has been specifically registered; we should check for the entire type.
Expand Down