Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 16 additions & 7 deletions flytekit/core/array_node_map_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from flytekit.core import tracker
from flytekit.core.array_node import array_node
from flytekit.core.base_task import PythonTask, TaskResolverMixin
from flytekit.core.container_task import ContainerTask
from flytekit.core.context_manager import ExecutionState, FlyteContext, FlyteContextManager
from flytekit.core.interface import transform_interface_to_list_interface
from flytekit.core.launch_plan import LaunchPlan
Expand All @@ -36,7 +37,7 @@
def __init__(
self,
# TODO: add support for other Flyte entities
python_function_task: Union[PythonFunctionTask, PythonInstanceTask, functools.partial],
python_function_task: Union[PythonFunctionTask, PythonInstanceTask, ContainerTask, functools.partial],
concurrency: Optional[int] = None,
min_successes: Optional[int] = None,
min_success_ratio: Optional[float] = None,
Expand Down Expand Up @@ -66,10 +67,10 @@
isinstance(actual_task, PythonFunctionTask)
and actual_task.execution_mode == PythonFunctionTask.ExecutionBehavior.DEFAULT
)
or isinstance(actual_task, PythonInstanceTask)
or isinstance(actual_task, (PythonInstanceTask, ContainerTask))
):
raise ValueError(
"Only PythonFunctionTask with default execution mode (not @dynamic or @eager) and PythonInstanceTask are supported in map tasks."
"Only PythonFunctionTask with default execution mode (not @dynamic or @eager), PythonInstanceTask, and ContainerTask are supported in map tasks."
)

n_outputs = len(actual_task.python_interface.outputs)
Expand Down Expand Up @@ -101,6 +102,9 @@
if isinstance(actual_task, PythonInstanceTask):
mod = actual_task.task_type
f = actual_task.lhs
elif isinstance(actual_task, ContainerTask):
mod = actual_task.task_type
f = actual_task.name
else:
_, mod, f, _ = tracker.extract_task_module(cast(PythonFunctionTask, actual_task).task_function)
sorted_bounded_inputs = ",".join(sorted(self._bound_inputs))
Expand Down Expand Up @@ -192,6 +196,10 @@
"""
Alters the underlying run_task command to modify it for map task execution and then resets it after.
"""
if isinstance(self._run_task, ContainerTask):
yield
return

Check warning on line 201 in flytekit/core/array_node_map_task.py

View check run for this annotation

Codecov / codecov/patch

flytekit/core/array_node_map_task.py#L200-L201

Added lines #L200 - L201 were not covered by tests

Comment on lines +199 to +202
Copy link
Member

@Future-Outlier Future-Outlier Jun 9, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why not just return directly?
update: we need to return a generator, please ignore my comment.

self.python_function_task.set_command_fn(self.get_command)
try:
yield
Expand Down Expand Up @@ -261,9 +269,10 @@
return super().__call__(*args, **kwargs)

def _literal_map_to_python_input(
self, literal_map: _literal_models.LiteralMap, ctx: FlyteContext
self, literal_map: _literal_models.LiteralMap, ctx: Optional[FlyteContext] = None
) -> Dict[str, Any]:
ctx = FlyteContextManager.current_context()
if ctx is None:
ctx = FlyteContextManager.current_context()

Check warning on line 275 in flytekit/core/array_node_map_task.py

View check run for this annotation

Codecov / codecov/patch

flytekit/core/array_node_map_task.py#L275

Added line #L275 was not covered by tests
inputs_interface = self.python_interface.inputs
inputs_map = literal_map
# If we run locally, we will need to process all of the inputs. If we are running in a remote task execution
Expand Down Expand Up @@ -381,7 +390,7 @@


def map_task(
target: Union[LaunchPlan, PythonFunctionTask, "FlyteLaunchPlan"],
target: Union[LaunchPlan, PythonFunctionTask, ContainerTask, "FlyteLaunchPlan"],
concurrency: Optional[int] = None,
min_successes: Optional[int] = None,
min_success_ratio: float = 1.0,
Expand Down Expand Up @@ -418,7 +427,7 @@


def array_node_map_task(
task_function: PythonFunctionTask,
task_function: Union[PythonFunctionTask, ContainerTask],
concurrency: Optional[int] = None,
# TODO why no min_successes?
min_success_ratio: float = 1.0,
Expand Down
13 changes: 7 additions & 6 deletions flytekit/core/container_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
from flytekit.image_spec.image_spec import ImageSpec
from flytekit.loggers import logger
from flytekit.models import task as _task_model
from flytekit.models.literals import LiteralMap
from flytekit.models.security import Secret, SecurityContext

_PRIMARY_CONTAINER_NAME_FIELD = "primary_container_name"
Expand Down Expand Up @@ -254,14 +253,12 @@
output_dict[k] = self._convert_output_val_to_correct_type(output_val, output_type)
return output_dict

def execute(self, **kwargs) -> LiteralMap:
def execute(self, **kwargs) -> Any:
try:
import docker
except ImportError:
raise ImportError(DOCKER_IMPORT_ERROR_MESSAGE)

from flytekit.core.type_engine import TypeEngine

ctx = FlyteContext.current_context()

# Normalize the input and output directories
Expand Down Expand Up @@ -289,8 +286,12 @@
container.wait()

output_dict = self._get_output_dict(output_directory)
outputs_literal_map = TypeEngine.dict_to_literal_map(ctx, output_dict)
return outputs_literal_map
if len(output_dict) == 0:
return None

Check warning on line 290 in flytekit/core/container_task.py

View check run for this annotation

Codecov / codecov/patch

flytekit/core/container_task.py#L290

Added line #L290 was not covered by tests
elif len(output_dict) == 1:
return list(output_dict.values())[0]
elif len(output_dict) > 1:
return tuple(output_dict.values())

def get_container(self, settings: SerializationSettings) -> _task_model.Container:
# if pod_template is specified, return None here but in get_k8s_pod, return pod_template merged with container
Expand Down
10 changes: 5 additions & 5 deletions flytekit/tools/translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,11 +166,11 @@ def get_serializable_task(
if settings.should_fast_serialize():
# This handles container tasks.
if container and isinstance(entity, (PythonAutoContainerTask, MapPythonTask, ArrayNodeMapTask)):
# For fast registration, we'll need to muck with the command, but on
# ly for certain kinds of tasks. Specifically,
# tasks that rely on user code defined in the container. This should be encapsulated by the auto container
# parent class
container._args = prefix_with_fast_execute(settings, container.args)
# For fast registration, we'll need to muck with the command, but
# only for certain kinds of tasks. Specifically, tasks that rely
# on user code defined in the container. This should be
# encapsulated by the auto container parent class
container._args = prefix_with_fast_execute(settings, container.args or [])
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

args for ContainerTask is None, adding this to prevent error


# If the pod spec is not None, we have to get it again, because the one we retrieved above will be incorrect.
# The reason we have to call get_k8s_pod again, instead of just modifying the command in this file, is because
Expand Down
47 changes: 46 additions & 1 deletion tests/flytekit/integration/remote/test_remote.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import math
import functools
import botocore.session
import shutil
from contextlib import ExitStack, contextmanager
Expand All @@ -20,7 +22,7 @@
import string
from dataclasses import asdict, dataclass

from flytekit import LaunchPlan, kwtypes, WorkflowExecutionPhase, task, workflow
from flytekit import LaunchPlan, kwtypes, WorkflowExecutionPhase, task, workflow, ContainerTask, map_task
from flytekit.configuration import Config, ImageConfig, SerializationSettings
from flytekit.core.launch_plan import reference_launch_plan
from flytekit.core.task import reference_task
Expand Down Expand Up @@ -1358,3 +1360,46 @@ def test_run_wf_with_resource_requests_override(register):
],
limits=[],
)


def test_container_task_map_execution():
# NOTE: We only take one output "area" even if this calculate-ellipse-area.py
# produce two output. This is because that map task can only return one value.
calculate_ellipse_area_python_template_style = ContainerTask(
name="calculate_ellipse_area_python_template_style",
input_data_dir="/var/inputs",
output_data_dir="/var/outputs",
inputs=kwtypes(a=float, b=float),
outputs=kwtypes(area=float),
image="ghcr.io/flyteorg/rawcontainers-python:v2",
command=[
"python",
"calculate-ellipse-area.py",
"{{.inputs.a}}",
"{{.inputs.b}}",
"/var/outputs",
],
)

@workflow
def wf(a: list[float], b: float) -> list[float]:
partial_task = functools.partial(
calculate_ellipse_area_python_template_style, b=b
)
res = map_task(partial_task)(a=a)
return res

def calculate_area(a, b):
return math.pi * a * b

expected_area = [
calculate_area(a, b) for a, b in [(3.0, 4.0), (4.0, 4.0), (5.0, 4.0)]
]


remote = FlyteRemote(Config.auto(config_file=CONFIG), PROJECT, DOMAIN)
execution = remote.execute(wf, inputs={"a": [3.0, 4.0, 5.0], "b": 4.0}, wait=True, version=VERSION)
execution = remote.wait(execution=execution, timeout=datetime.timedelta(minutes=2))
assert execution.error is None, f"Execution failed with error: {execution.error}"

assert execution.outputs["o0"] == expected_area
Loading
Loading