Skip to content
Open
Show file tree
Hide file tree
Changes from 15 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 16 additions & 7 deletions flytekit/core/array_node_map_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from flytekit.core import tracker
from flytekit.core.array_node import array_node
from flytekit.core.base_task import PythonTask, TaskResolverMixin
from flytekit.core.container_task import ContainerTask
from flytekit.core.context_manager import ExecutionState, FlyteContext, FlyteContextManager
from flytekit.core.interface import transform_interface_to_list_interface
from flytekit.core.launch_plan import LaunchPlan
Expand All @@ -36,7 +37,7 @@ class ArrayNodeMapTask(PythonTask):
def __init__(
self,
# TODO: add support for other Flyte entities
python_function_task: Union[PythonFunctionTask, PythonInstanceTask, functools.partial],
python_function_task: Union[PythonFunctionTask, PythonInstanceTask, ContainerTask, functools.partial],
concurrency: Optional[int] = None,
min_successes: Optional[int] = None,
min_success_ratio: Optional[float] = None,
Expand Down Expand Up @@ -66,10 +67,10 @@ def __init__(
isinstance(actual_task, PythonFunctionTask)
and actual_task.execution_mode == PythonFunctionTask.ExecutionBehavior.DEFAULT
)
or isinstance(actual_task, PythonInstanceTask)
or isinstance(actual_task, (PythonInstanceTask, ContainerTask))
):
raise ValueError(
"Only PythonFunctionTask with default execution mode (not @dynamic or @eager) and PythonInstanceTask are supported in map tasks."
"Only PythonFunctionTask with default execution mode (not @dynamic or @eager), PythonInstanceTask, and ContainerTask are supported in map tasks."
)

n_outputs = len(actual_task.python_interface.outputs)
Expand Down Expand Up @@ -101,6 +102,9 @@ def __init__(
if isinstance(actual_task, PythonInstanceTask):
mod = actual_task.task_type
f = actual_task.lhs
elif isinstance(actual_task, ContainerTask):
mod = actual_task.task_type
f = actual_task.name
else:
_, mod, f, _ = tracker.extract_task_module(cast(PythonFunctionTask, actual_task).task_function)
sorted_bounded_inputs = ",".join(sorted(self._bound_inputs))
Expand Down Expand Up @@ -192,6 +196,10 @@ def prepare_target(self):
"""
Alters the underlying run_task command to modify it for map task execution and then resets it after.
"""
if isinstance(self._run_task, ContainerTask):
yield
return

Comment on lines +199 to +202
Copy link
Member

@Future-Outlier Future-Outlier Jun 9, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why not just return directly?
update: we need to return a generator, please ignore my comment.

self.python_function_task.set_command_fn(self.get_command)
try:
yield
Expand Down Expand Up @@ -261,9 +269,10 @@ def __call__(self, *args, **kwargs):
return super().__call__(*args, **kwargs)

def _literal_map_to_python_input(
self, literal_map: _literal_models.LiteralMap, ctx: FlyteContext
self, literal_map: _literal_models.LiteralMap, ctx: Optional[FlyteContext] = None
) -> Dict[str, Any]:
ctx = FlyteContextManager.current_context()
if ctx is None:
ctx = FlyteContextManager.current_context()
inputs_interface = self.python_interface.inputs
inputs_map = literal_map
# If we run locally, we will need to process all of the inputs. If we are running in a remote task execution
Expand Down Expand Up @@ -381,7 +390,7 @@ def _raw_execute(self, **kwargs) -> Any:


def map_task(
target: Union[LaunchPlan, PythonFunctionTask, "FlyteLaunchPlan"],
target: Union[LaunchPlan, PythonFunctionTask, ContainerTask, "FlyteLaunchPlan"],
concurrency: Optional[int] = None,
min_successes: Optional[int] = None,
min_success_ratio: float = 1.0,
Expand Down Expand Up @@ -418,7 +427,7 @@ def map_task(


def array_node_map_task(
task_function: PythonFunctionTask,
task_function: Union[PythonFunctionTask, ContainerTask],
concurrency: Optional[int] = None,
# TODO why no min_successes?
min_success_ratio: float = 1.0,
Expand Down
13 changes: 7 additions & 6 deletions flytekit/core/container_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
from flytekit.image_spec.image_spec import ImageSpec
from flytekit.loggers import logger
from flytekit.models import task as _task_model
from flytekit.models.literals import LiteralMap
from flytekit.models.security import Secret, SecurityContext

_PRIMARY_CONTAINER_NAME_FIELD = "primary_container_name"
Expand Down Expand Up @@ -254,14 +253,12 @@ def _get_output_dict(self, output_directory: str) -> Dict[str, Any]:
output_dict[k] = self._convert_output_val_to_correct_type(output_val, output_type)
return output_dict

def execute(self, **kwargs) -> LiteralMap:
def execute(self, **kwargs) -> Any:
try:
import docker
except ImportError:
raise ImportError(DOCKER_IMPORT_ERROR_MESSAGE)

from flytekit.core.type_engine import TypeEngine

ctx = FlyteContext.current_context()

# Normalize the input and output directories
Expand Down Expand Up @@ -289,8 +286,12 @@ def execute(self, **kwargs) -> LiteralMap:
container.wait()

output_dict = self._get_output_dict(output_directory)
outputs_literal_map = TypeEngine.dict_to_literal_map(ctx, output_dict)
return outputs_literal_map
if len(output_dict) == 0:
return None
elif len(output_dict) == 1:
return list(output_dict.values())[0]
elif len(output_dict) > 1:
return tuple(output_dict.values())

def get_container(self, settings: SerializationSettings) -> _task_model.Container:
# if pod_template is specified, return None here but in get_k8s_pod, return pod_template merged with container
Expand Down
10 changes: 5 additions & 5 deletions flytekit/tools/translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,11 +166,11 @@ def get_serializable_task(
if settings.should_fast_serialize():
# This handles container tasks.
if container and isinstance(entity, (PythonAutoContainerTask, MapPythonTask, ArrayNodeMapTask)):
# For fast registration, we'll need to muck with the command, but on
# ly for certain kinds of tasks. Specifically,
# tasks that rely on user code defined in the container. This should be encapsulated by the auto container
# parent class
container._args = prefix_with_fast_execute(settings, container.args)
# For fast registration, we'll need to muck with the command, but
# only for certain kinds of tasks. Specifically, tasks that rely
# on user code defined in the container. This should be
# encapsulated by the auto container parent class
container._args = prefix_with_fast_execute(settings, container.args or [])
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

args for ContainerTask is None, adding this to prevent error


# If the pod spec is not None, we have to get it again, because the one we retrieved above will be incorrect.
# The reason we have to call get_k8s_pod again, instead of just modifying the command in this file, is because
Expand Down
Loading
Loading