Skip to content
Open
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 26 additions & 7 deletions flytekit/core/array_node_map_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,11 @@
from flytekit.core import tracker
from flytekit.core.array_node import array_node
from flytekit.core.base_task import PythonTask, TaskResolverMixin
from flytekit.core.container_task import ContainerTask
from flytekit.core.context_manager import ExecutionState, FlyteContext, FlyteContextManager
from flytekit.core.interface import transform_interface_to_list_interface
from flytekit.core.launch_plan import LaunchPlan
from flytekit.core.promise import Promise, create_native_named_tuple, create_task_output
from flytekit.core.python_function_task import PythonFunctionTask, PythonInstanceTask
from flytekit.core.task import ReferenceTask
from flytekit.core.type_engine import TypeEngine
Expand All @@ -36,7 +38,7 @@
def __init__(
self,
# TODO: add support for other Flyte entities
python_function_task: Union[PythonFunctionTask, PythonInstanceTask, functools.partial],
python_function_task: Union[PythonFunctionTask, PythonInstanceTask, ContainerTask, functools.partial],
concurrency: Optional[int] = None,
min_successes: Optional[int] = None,
min_success_ratio: Optional[float] = None,
Expand Down Expand Up @@ -66,10 +68,10 @@
isinstance(actual_task, PythonFunctionTask)
and actual_task.execution_mode == PythonFunctionTask.ExecutionBehavior.DEFAULT
)
or isinstance(actual_task, PythonInstanceTask)
or isinstance(actual_task, (PythonInstanceTask, ContainerTask))
):
raise ValueError(
"Only PythonFunctionTask with default execution mode (not @dynamic or @eager) and PythonInstanceTask are supported in map tasks."
"Only PythonFunctionTask with default execution mode (not @dynamic or @eager), PythonInstanceTask, and ContainerTask are supported in map tasks."
)

n_outputs = len(actual_task.python_interface.outputs)
Expand Down Expand Up @@ -101,6 +103,9 @@
if isinstance(actual_task, PythonInstanceTask):
mod = actual_task.task_type
f = actual_task.lhs
elif isinstance(actual_task, ContainerTask):
mod = actual_task.task_type
f = actual_task.name

Check warning on line 108 in flytekit/core/array_node_map_task.py

View check run for this annotation

Codecov / codecov/patch

flytekit/core/array_node_map_task.py#L107-L108

Added lines #L107 - L108 were not covered by tests
else:
_, mod, f, _ = tracker.extract_task_module(cast(PythonFunctionTask, actual_task).task_function)
sorted_bounded_inputs = ",".join(sorted(self._bound_inputs))
Expand Down Expand Up @@ -205,14 +210,20 @@
return self.python_function_task.get_config(settings)

def get_container(self, settings: SerializationSettings) -> Container:
if isinstance(self._run_task, ContainerTask):
return self.python_function_task.get_container(settings)

Check warning on line 214 in flytekit/core/array_node_map_task.py

View check run for this annotation

Codecov / codecov/patch

flytekit/core/array_node_map_task.py#L214

Added line #L214 was not covered by tests
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit:
We can modify self.prepare_target, do nothing for container task

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed! Thanks

with self.prepare_target():
return self.python_function_task.get_container(settings)

def get_k8s_pod(self, settings: SerializationSettings) -> K8sPod:
if isinstance(self._run_task, ContainerTask):
return self.python_function_task.get_k8s_pod(settings)

Check warning on line 220 in flytekit/core/array_node_map_task.py

View check run for this annotation

Codecov / codecov/patch

flytekit/core/array_node_map_task.py#L220

Added line #L220 was not covered by tests
with self.prepare_target():
return self.python_function_task.get_k8s_pod(settings)

def get_sql(self, settings: SerializationSettings) -> Sql:
if isinstance(self._run_task, ContainerTask):
return self.python_function_task.get_sql(settings)

Check warning on line 226 in flytekit/core/array_node_map_task.py

View check run for this annotation

Codecov / codecov/patch

flytekit/core/array_node_map_task.py#L226

Added line #L226 was not covered by tests
with self.prepare_target():
return self.python_function_task.get_sql(settings)

Expand Down Expand Up @@ -261,9 +272,10 @@
return super().__call__(*args, **kwargs)

def _literal_map_to_python_input(
self, literal_map: _literal_models.LiteralMap, ctx: FlyteContext
self, literal_map: _literal_models.LiteralMap, ctx: Optional[FlyteContext] = None
) -> Dict[str, Any]:
ctx = FlyteContextManager.current_context()
if ctx is None:
ctx = FlyteContextManager.current_context()

Check warning on line 278 in flytekit/core/array_node_map_task.py

View check run for this annotation

Codecov / codecov/patch

flytekit/core/array_node_map_task.py#L278

Added line #L278 was not covered by tests
inputs_interface = self.python_interface.inputs
inputs_map = literal_map
# If we run locally, we will need to process all of the inputs. If we are running in a remote task execution
Expand Down Expand Up @@ -368,6 +380,13 @@
single_instance_inputs[k] = kwargs[k]
try:
o = self._run_task.execute(**single_instance_inputs)
# For Container task, it will return the LiteralMap. We need to convert it to native
# type here.
Copy link
Contributor

@wild-endeavor wild-endeavor May 22, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@machichima can you add comment please that this is only for local execution for container tasks? this code doesn't run for backend cluster runs of container task.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I make ContainerTask return Python type instead of LiteralMap. So this part is not needed anymore.
Thanks!

if isinstance(o, _literal_models.LiteralMap):
vals = [Promise(var, o.literals[var]) for var in o.literals.keys()]
result = create_task_output(vals, self.python_interface)
ctx = FlyteContextManager.current_context()
o = create_native_named_tuple(ctx, result, self._run_task.python_interface)

Check warning on line 389 in flytekit/core/array_node_map_task.py

View check run for this annotation

Codecov / codecov/patch

flytekit/core/array_node_map_task.py#L386-L389

Added lines #L386 - L389 were not covered by tests
if outputs_expected:
outputs.append(o)
except Exception as exc:
Expand All @@ -381,7 +400,7 @@


def map_task(
target: Union[LaunchPlan, PythonFunctionTask, "FlyteLaunchPlan"],
target: Union[LaunchPlan, PythonFunctionTask, ContainerTask, "FlyteLaunchPlan"],
concurrency: Optional[int] = None,
min_successes: Optional[int] = None,
min_success_ratio: float = 1.0,
Expand Down Expand Up @@ -418,7 +437,7 @@


def array_node_map_task(
task_function: PythonFunctionTask,
task_function: Union[PythonFunctionTask, ContainerTask],
concurrency: Optional[int] = None,
# TODO why no min_successes?
min_success_ratio: float = 1.0,
Expand Down
2 changes: 1 addition & 1 deletion flytekit/tools/translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ def get_serializable_task(
# ly for certain kinds of tasks. Specifically,
# tasks that rely on user code defined in the container. This should be encapsulated by the auto container
# parent class
container._args = prefix_with_fast_execute(settings, container.args)
container._args = prefix_with_fast_execute(settings, container.args or [])
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

args for ContainerTask is None, adding this to prevent error


# If the pod spec is not None, we have to get it again, because the one we retrieved above will be incorrect.
# The reason we have to call get_k8s_pod again, instead of just modifying the command in this file, is because
Expand Down
Loading
Loading