Skip to content
Merged
61 changes: 61 additions & 0 deletions flytekit/core/type_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -1088,6 +1088,63 @@ def assert_type(self, t: Type[enum.Enum], v: T):
raise TypeTransformerFailedError(f"Value {v} is not in Enum {t}")


class LiteralTypeTransformer(TypeTransformer[object]):
def __init__(self):
super().__init__("LiteralTypeTransformer", object)

@classmethod
def get_base_type(cls, t: Type) -> Type:
args = get_args(t)
if not args:
raise TypeTransformerFailedError("Literal must have at least one value")

base_type = type(args[0])
if not all(type(a) == base_type for a in args):
raise TypeTransformerFailedError("All values must be of the same type")

return base_type

def get_literal_type(self, t: Type) -> LiteralType:
base_type = self.get_base_type(t)
vals = list(get_args(t))
ann = TypeAnnotationModel(annotations={"literal_values": vals})
if base_type is str:
simple = SimpleType.STRING
elif base_type is int:
simple = SimpleType.INTEGER
elif base_type is float:
simple = SimpleType.FLOAT
elif base_type is bool:
simple = SimpleType.BOOLEAN
elif base_type is datetime.datetime:
simple = SimpleType.DATETIME
elif base_type is datetime.timedelta:
simple = SimpleType.DURATION
else:
raise TypeTransformerFailedError(f"Unsupported type: {base_type}")
return LiteralType(simple=simple, annotation=ann)

def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type, expected: LiteralType) -> Literal:
base_type = self.get_base_type(python_type)
base_transformer: TypeTransformer[object] = TypeEngine.get_transformer(base_type)
return base_transformer.to_literal(ctx, python_val, python_type, expected)

def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: Type) -> object:
base_type = self.get_base_type(expected_python_type)
base_transformer: TypeTransformer[object] = TypeEngine.get_transformer(base_type)
return base_transformer.to_python_value(ctx, lv, base_type)

def guess_python_type(self, literal_type: LiteralType) -> Type:
if literal_type.annotation and literal_type.annotation.annotations:
return typing.Literal[tuple(literal_type.annotation.annotations.get("literal_values"))] # type: ignore
raise ValueError(f"LiteralType transformer cannot reverse {literal_type}")

def assert_type(self, python_type: Type, python_val: T):
base_type = self.get_base_type(python_type)
base_transformer: TypeTransformer[object] = TypeEngine.get_transformer(base_type)
return base_transformer.assert_type(base_type, python_val)


def _handle_json_schema_property(
property_key: str,
property_val: dict,
Expand Down Expand Up @@ -1174,6 +1231,7 @@ class TypeEngine(typing.Generic[T]):
_RESTRICTED_TYPES: typing.List[type] = []
_DATACLASS_TRANSFORMER: TypeTransformer = DataclassTransformer() # type: ignore
_ENUM_TRANSFORMER: TypeTransformer = EnumTransformer() # type: ignore
_LITERAL_TYPE_TRANSFORMER: TypeTransformer = LiteralTypeTransformer()
lazy_import_lock = threading.Lock()
_LITERAL_CACHE: LRUCache = LRUCache(maxsize=128)

Expand Down Expand Up @@ -1224,6 +1282,9 @@ def _get_transformer(cls, python_type: Type) -> Optional[TypeTransformer[T]]:
# Special case: prevent that for a type `FooEnum(str, Enum)`, the str transformer is used.
return cls._ENUM_TRANSFORMER

if get_origin(python_type) == typing.Literal:
return cls._LITERAL_TYPE_TRANSFORMER

if hasattr(python_type, "__origin__"):
# If the type is a generic type, we should check the origin type. But consider the case like Iterator[JSON]
# or List[int] has been specifically registered; we should check for the entire type.
Expand Down
41 changes: 41 additions & 0 deletions tests/flytekit/unit/core/test_type_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -3993,3 +3993,44 @@ def test_type_engine_cache_with_flytefile():
assert mock_async_to_literal.call_count == 1

assert lv1 is lv2

def test_literal_transformer_string_type():
# Python -> Flyte
t = typing.Literal["outcome", "income"]
lt = TypeEngine.get_transformer(t).get_literal_type(t)
assert lt.simple == SimpleType.STRING
assert lt.annotation.annotations["literal_values"] == ["outcome", "income"]
assert lt == LiteralType.from_flyte_idl(lt.to_flyte_idl())

lv = TypeEngine.to_literal(FlyteContext.current_context(), "outcome", t, lt)
assert lv.scalar.primitive.string_value == "outcome"

# Flyte -> Python (reconstruction)
pt = TypeEngine.get_transformer(t).guess_python_type(lt)
assert pt is typing.Literal["outcome", "income"]
pv = TypeEngine.get_transformer(pt).to_python_value(FlyteContext.current_context(), lv, pt)
TypeEngine.get_transformer(pt).assert_type(pt, pv)
assert pv == "outcome"

def test_literal_transformer_int_type():
# Python -> Flyte
t = typing.Literal[1, 2, 3]
lt = TypeEngine.get_transformer(t).get_literal_type(t)
assert lt.simple == SimpleType.INTEGER
assert lt.annotation.annotations["literal_values"] == [1, 2, 3]
assert lt == LiteralType.from_flyte_idl(lt.to_flyte_idl())

lv = TypeEngine.to_literal(FlyteContext.current_context(), 1, t, lt)
assert lv.scalar.primitive.integer == 1

# Flyte -> Python (reconstruction)
pt = TypeEngine.get_transformer(t).guess_python_type(lt)
assert pt is typing.Literal[1, 2, 3]
pv = TypeEngine.get_transformer(pt).to_python_value(FlyteContext.current_context(), lv, pt)
TypeEngine.get_transformer(pt).assert_type(pt, pv)
assert pv == 1

def test_literal_transformer_mixed_base_types():
t = typing.Literal["a", 1]
with pytest.raises(TypeTransformerFailedError):
TypeEngine.get_transformer(t).get_literal_type(t)