Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 11 additions & 5 deletions flytekit/core/promise.py
Original file line number Diff line number Diff line change
Expand Up @@ -885,16 +885,22 @@ async def binding_data_from_python_std(
# akin to what the Type Engine does when it finds a Union type (see the UnionTransformer), but we can't rely on
# that in this case, because of the mix and match of realized values, and Promises.
for i in range(len(expected_literal_type.union_type.variants)):
lt_type = expected_literal_type.union_type.variants[i]
python_type = get_args(t_value_type)[i] if t_value_type else None
try:
lt_type = expected_literal_type.union_type.variants[i]
python_type = get_args(t_value_type)[i] if t_value_type else None
return await binding_data_from_python_std(ctx, lt_type, t_value, python_type, nodes)
except Exception:
except Exception as e:
logger.debug(
f"failed to bind data {t_value} with literal type {expected_literal_type.union_type.variants[i]}."
f"Failed to bind data {t_value} "
f"using variant[{i}] literal type={repr(lt_type)} (expected overall {expected_literal_type}) "
f"and python type={python_type} (expected overall {t_value_type}). "
f"Error: {e}"
)
raise AssertionError(
f"Failed to bind data {t_value} with literal type {expected_literal_type.union_type.variants}."
f"Failed to bind data {t_value} "
f"to any of the expected union variants.\n"
f"Value python type: {type(t_value).__name__}, declared python types: {t_value_type}\n"
f"Expected literal type: {repr(expected_literal_type)}"
)

elif (
Expand Down
33 changes: 18 additions & 15 deletions flytekit/core/type_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -1630,21 +1630,24 @@ async def _literal_map_to_kwargs(
f" than allowed by the input spec {len(python_interface_inputs)}"
)
kwargs = {}
try:
for i, k in enumerate(lm.literals):
kwargs[k] = asyncio.create_task(
TypeEngine.async_to_python_value(ctx, lm.literals[k], python_interface_inputs[k])
)
await asyncio.gather(*kwargs.values())
except Exception as e:
raise TypeTransformerFailedError(
f"Error converting input '{k}' at position {i}:\n"
f"Literal value: {lm.literals[k]}\n"
f"Expected Python type: {python_interface_inputs[k]}\n"
f"Exception: {e}"
for i, k in enumerate(lm.literals):
kwargs[k] = asyncio.create_task(
TypeEngine.async_to_python_value(ctx, lm.literals[k], python_interface_inputs[k])
)
if kwargs:
await asyncio.wait(kwargs.values())

for k, t in kwargs.items():
try:
kwargs[k] = t.result()
except Exception as e:
raise TypeTransformerFailedError(
f"Error converting input '{k}':\n"
f"Literal value: {lm.literals[k]!r}\n"
f"Expected Python type: {python_interface_inputs[k]!r}\n"
f"Exception: {e}"
)

kwargs = {k: v.result() for k, v in kwargs.items() if v is not None}
return kwargs

@classmethod
Expand Down Expand Up @@ -2096,7 +2099,7 @@ async def async_to_python_value(
res_tag = trans.name
found_res = True
except Exception as e:
logger.debug(f"Failed to convert from {lv} to {v} with error: {e}")
logger.debug(f"Failed to convert from {repr(lv)} to {v} with error: {e}")

if is_ambiguous:
raise TypeError(
Expand All @@ -2107,7 +2110,7 @@ async def async_to_python_value(
if found_res:
return res

raise TypeError(f"Cannot convert from {lv} to {expected_python_type} (using tag {union_tag})")
raise TypeError(f"Cannot convert from {repr(lv)} to {expected_python_type} (using tag {union_tag})")

def guess_python_type(self, literal_type: LiteralType) -> type:
if literal_type.union_type is not None:
Expand Down
10 changes: 2 additions & 8 deletions flytekit/models/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,10 +74,10 @@ def __ne__(self, other):
return not (self == other)

def __repr__(self):
return self.short_string()
return str(self.to_flyte_idl())

def __str__(self):
return self.verbose_string()
return self.short_string()

def __hash__(self):
return hash(self.to_flyte_idl().SerializeToString(deterministic=True))
Expand All @@ -90,12 +90,6 @@ def short_string(self):
type_str = type(self).__name__
return f"Flyte Serialized object ({type_str}):" + os.linesep + str_repr

def verbose_string(self):
"""
:rtype: Text
"""
return self.short_string()

def serialize_to_string(self) -> str:
return self.to_flyte_idl().SerializeToString()

Expand Down
25 changes: 14 additions & 11 deletions tests/flytekit/unit/core/test_type_conversion_errors.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Unit tests for type conversion errors."""

import re
from datetime import timedelta
from string import ascii_lowercase
from typing import Tuple
Expand All @@ -9,6 +10,7 @@
from hypothesis import strategies as st

from flytekit import task, workflow
from flytekit.core.type_engine import TypeTransformerFailedError


@task
Expand Down Expand Up @@ -67,11 +69,12 @@ def test_task_input_error(incorrect_input):
@settings(deadline=timedelta(seconds=2))
def test_task_output_error(correct_input):
with pytest.raises(
TypeError,
TypeTransformerFailedError,
match=(
r"Failed to convert outputs of task '{}' at position 0:\n"
r" Expected value of type \<class 'int'\> but got .+ of type .+"
).format(task_incorrect_output.name),
r"Failed to convert outputs of task '{}' at position 0\.\n"
r"Failed to convert type .+ to type .+\.\n"
r"Error Message: Expected value of type \<class 'int'\> but got .+ of type .+"
).format(re.escape(task_incorrect_output.name)),
):
task_incorrect_output(a=correct_input)

Expand All @@ -80,12 +83,12 @@ def test_task_output_error(correct_input):
@settings(deadline=timedelta(seconds=2))
def test_workflow_with_task_error(correct_input):
with pytest.raises(
TypeError,
TypeTransformerFailedError,
match=(
r"Error encountered while executing 'wf_with_task_error':\n"
r" Failed to convert outputs of task '.+' at position 0:\n"
r" Expected value of type \<class 'int'\> but got .+ of type .+"
).format(),
r"Failed to convert outputs of task '.+' at position 0\.\n"
r"Failed to convert type .+ to type .+\.\n"
r"Error Message: Expected value of type \<class 'int'\> but got .+ of type .+"
),
):
wf_with_task_error(a=correct_input)

Expand All @@ -105,7 +108,7 @@ def test_workflow_with_input_error(incorrect_input):
def test_workflow_with_output_error(correct_input):
with pytest.raises(
TypeError,
match=(r"Failed to convert output in position 0 of value .+, expected type \<class 'int'\>"),
match=r"Failed to convert output in position 0 of value [\s\S]+, expected type \<class 'int'\>",
):
wf_with_output_error(a=correct_input)

Expand All @@ -122,6 +125,6 @@ def test_workflow_with_output_error(correct_input):
def test_workflow_with_multioutput_error(workflow, position, correct_input):
with pytest.raises(
TypeError,
match=(r"Failed to convert output in position {} of value .+, expected type \<class 'int'\>").format(position),
match=(r"Failed to convert output in position {} of value [\s\S]+, expected type \<class 'int'\>").format(position),
):
workflow(a=correct_input, b=correct_input)
Loading
Loading