Skip to content
Open
Show file tree
Hide file tree
Changes from 13 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 18 additions & 7 deletions flytekit/core/array_node_map_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from flytekit.core import tracker
from flytekit.core.array_node import array_node
from flytekit.core.base_task import PythonTask, TaskResolverMixin
from flytekit.core.container_task import ContainerTask
from flytekit.core.context_manager import ExecutionState, FlyteContext, FlyteContextManager
from flytekit.core.interface import transform_interface_to_list_interface
from flytekit.core.launch_plan import LaunchPlan
Expand All @@ -36,7 +37,7 @@
def __init__(
self,
# TODO: add support for other Flyte entities
python_function_task: Union[PythonFunctionTask, PythonInstanceTask, functools.partial],
python_function_task: Union[PythonFunctionTask, PythonInstanceTask, ContainerTask, functools.partial],
concurrency: Optional[int] = None,
min_successes: Optional[int] = None,
min_success_ratio: Optional[float] = None,
Expand Down Expand Up @@ -66,10 +67,10 @@
isinstance(actual_task, PythonFunctionTask)
and actual_task.execution_mode == PythonFunctionTask.ExecutionBehavior.DEFAULT
)
or isinstance(actual_task, PythonInstanceTask)
or isinstance(actual_task, (PythonInstanceTask, ContainerTask))
):
raise ValueError(
"Only PythonFunctionTask with default execution mode (not @dynamic or @eager) and PythonInstanceTask are supported in map tasks."
"Only PythonFunctionTask with default execution mode (not @dynamic or @eager), PythonInstanceTask, and ContainerTask are supported in map tasks."
)

n_outputs = len(actual_task.python_interface.outputs)
Expand Down Expand Up @@ -101,6 +102,9 @@
if isinstance(actual_task, PythonInstanceTask):
mod = actual_task.task_type
f = actual_task.lhs
elif isinstance(actual_task, ContainerTask):
mod = actual_task.task_type
f = actual_task.name
else:
_, mod, f, _ = tracker.extract_task_module(cast(PythonFunctionTask, actual_task).task_function)
sorted_bounded_inputs = ",".join(sorted(self._bound_inputs))
Expand Down Expand Up @@ -205,14 +209,20 @@
return self.python_function_task.get_config(settings)

def get_container(self, settings: SerializationSettings) -> Container:
if isinstance(self._run_task, ContainerTask):
return self.python_function_task.get_container(settings)

Check warning on line 213 in flytekit/core/array_node_map_task.py

View check run for this annotation

Codecov / codecov/patch

flytekit/core/array_node_map_task.py#L213

Added line #L213 was not covered by tests
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit:
We can modify self.prepare_target, do nothing for container task

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed! Thanks

with self.prepare_target():
return self.python_function_task.get_container(settings)

def get_k8s_pod(self, settings: SerializationSettings) -> K8sPod:
if isinstance(self._run_task, ContainerTask):
return self.python_function_task.get_k8s_pod(settings)

Check warning on line 219 in flytekit/core/array_node_map_task.py

View check run for this annotation

Codecov / codecov/patch

flytekit/core/array_node_map_task.py#L219

Added line #L219 was not covered by tests
with self.prepare_target():
return self.python_function_task.get_k8s_pod(settings)

def get_sql(self, settings: SerializationSettings) -> Sql:
if isinstance(self._run_task, ContainerTask):
return self.python_function_task.get_sql(settings)

Check warning on line 225 in flytekit/core/array_node_map_task.py

View check run for this annotation

Codecov / codecov/patch

flytekit/core/array_node_map_task.py#L225

Added line #L225 was not covered by tests
with self.prepare_target():
return self.python_function_task.get_sql(settings)

Expand Down Expand Up @@ -261,9 +271,10 @@
return super().__call__(*args, **kwargs)

def _literal_map_to_python_input(
self, literal_map: _literal_models.LiteralMap, ctx: FlyteContext
self, literal_map: _literal_models.LiteralMap, ctx: Optional[FlyteContext] = None
) -> Dict[str, Any]:
ctx = FlyteContextManager.current_context()
if ctx is None:
ctx = FlyteContextManager.current_context()

Check warning on line 277 in flytekit/core/array_node_map_task.py

View check run for this annotation

Codecov / codecov/patch

flytekit/core/array_node_map_task.py#L277

Added line #L277 was not covered by tests
inputs_interface = self.python_interface.inputs
inputs_map = literal_map
# If we run locally, we will need to process all of the inputs. If we are running in a remote task execution
Expand Down Expand Up @@ -381,7 +392,7 @@


def map_task(
target: Union[LaunchPlan, PythonFunctionTask, "FlyteLaunchPlan"],
target: Union[LaunchPlan, PythonFunctionTask, ContainerTask, "FlyteLaunchPlan"],
concurrency: Optional[int] = None,
min_successes: Optional[int] = None,
min_success_ratio: float = 1.0,
Expand Down Expand Up @@ -418,7 +429,7 @@


def array_node_map_task(
task_function: PythonFunctionTask,
task_function: Union[PythonFunctionTask, ContainerTask],
concurrency: Optional[int] = None,
# TODO why no min_successes?
min_success_ratio: float = 1.0,
Expand Down
13 changes: 7 additions & 6 deletions flytekit/core/container_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
from flytekit.image_spec.image_spec import ImageSpec
from flytekit.loggers import logger
from flytekit.models import task as _task_model
from flytekit.models.literals import LiteralMap
from flytekit.models.security import Secret, SecurityContext

_PRIMARY_CONTAINER_NAME_FIELD = "primary_container_name"
Expand Down Expand Up @@ -254,14 +253,12 @@
output_dict[k] = self._convert_output_val_to_correct_type(output_val, output_type)
return output_dict

def execute(self, **kwargs) -> LiteralMap:
def execute(self, **kwargs) -> Any:
try:
import docker
except ImportError:
raise ImportError(DOCKER_IMPORT_ERROR_MESSAGE)

from flytekit.core.type_engine import TypeEngine

ctx = FlyteContext.current_context()

# Normalize the input and output directories
Expand Down Expand Up @@ -289,8 +286,12 @@
container.wait()

output_dict = self._get_output_dict(output_directory)
outputs_literal_map = TypeEngine.dict_to_literal_map(ctx, output_dict)
return outputs_literal_map
if len(output_dict) == 0:
return None

Check warning on line 290 in flytekit/core/container_task.py

View check run for this annotation

Codecov / codecov/patch

flytekit/core/container_task.py#L290

Added line #L290 was not covered by tests
elif len(output_dict) == 1:
return list(output_dict.values())[0]
elif len(output_dict) > 1:
return tuple(output_dict.values())

def get_container(self, settings: SerializationSettings) -> _task_model.Container:
# if pod_template is specified, return None here but in get_k8s_pod, return pod_template merged with container
Expand Down
10 changes: 5 additions & 5 deletions flytekit/tools/translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,11 +166,11 @@ def get_serializable_task(
if settings.should_fast_serialize():
# This handles container tasks.
if container and isinstance(entity, (PythonAutoContainerTask, MapPythonTask, ArrayNodeMapTask)):
# For fast registration, we'll need to muck with the command, but on
# ly for certain kinds of tasks. Specifically,
# tasks that rely on user code defined in the container. This should be encapsulated by the auto container
# parent class
container._args = prefix_with_fast_execute(settings, container.args)
# For fast registration, we'll need to muck with the command, but
# only for certain kinds of tasks. Specifically, tasks that rely
# on user code defined in the container. This should be
# encapsulated by the auto container parent class
container._args = prefix_with_fast_execute(settings, container.args or [])
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

args for ContainerTask is None, adding this to prevent error


# If the pod spec is not None, we have to get it again, because the one we retrieved above will be incorrect.
# The reason we have to call get_k8s_pod again, instead of just modifying the command in this file, is because
Expand Down
Loading
Loading