-
Notifications
You must be signed in to change notification settings - Fork 333
[Core feature] map_task to support ContainerTask #3249
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from 8 commits
281bd66
1d16454
137ab21
70b8bef
0d34130
e96b480
c53a64e
5c95774
6aae613
0f6dd70
4827e55
006493a
f15415b
0126bf3
d8b4362
2b9387d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -13,9 +13,11 @@ | |
from flytekit.core import tracker | ||
from flytekit.core.array_node import array_node | ||
from flytekit.core.base_task import PythonTask, TaskResolverMixin | ||
from flytekit.core.container_task import ContainerTask | ||
from flytekit.core.context_manager import ExecutionState, FlyteContext, FlyteContextManager | ||
from flytekit.core.interface import transform_interface_to_list_interface | ||
from flytekit.core.launch_plan import LaunchPlan | ||
from flytekit.core.promise import Promise, create_native_named_tuple, create_task_output | ||
from flytekit.core.python_function_task import PythonFunctionTask, PythonInstanceTask | ||
from flytekit.core.task import ReferenceTask | ||
from flytekit.core.type_engine import TypeEngine | ||
|
@@ -36,7 +38,7 @@ | |
def __init__( | ||
self, | ||
# TODO: add support for other Flyte entities | ||
python_function_task: Union[PythonFunctionTask, PythonInstanceTask, functools.partial], | ||
python_function_task: Union[PythonFunctionTask, PythonInstanceTask, ContainerTask, functools.partial], | ||
concurrency: Optional[int] = None, | ||
min_successes: Optional[int] = None, | ||
min_success_ratio: Optional[float] = None, | ||
|
@@ -66,10 +68,10 @@ | |
isinstance(actual_task, PythonFunctionTask) | ||
and actual_task.execution_mode == PythonFunctionTask.ExecutionBehavior.DEFAULT | ||
) | ||
or isinstance(actual_task, PythonInstanceTask) | ||
or isinstance(actual_task, (PythonInstanceTask, ContainerTask)) | ||
): | ||
raise ValueError( | ||
"Only PythonFunctionTask with default execution mode (not @dynamic or @eager) and PythonInstanceTask are supported in map tasks." | ||
"Only PythonFunctionTask with default execution mode (not @dynamic or @eager), PythonInstanceTask, and ContainerTask are supported in map tasks." | ||
) | ||
|
||
n_outputs = len(actual_task.python_interface.outputs) | ||
|
@@ -101,6 +103,9 @@ | |
if isinstance(actual_task, PythonInstanceTask): | ||
mod = actual_task.task_type | ||
f = actual_task.lhs | ||
elif isinstance(actual_task, ContainerTask): | ||
mod = actual_task.task_type | ||
f = actual_task.name | ||
else: | ||
_, mod, f, _ = tracker.extract_task_module(cast(PythonFunctionTask, actual_task).task_function) | ||
sorted_bounded_inputs = ",".join(sorted(self._bound_inputs)) | ||
|
@@ -205,14 +210,20 @@ | |
return self.python_function_task.get_config(settings) | ||
|
||
def get_container(self, settings: SerializationSettings) -> Container: | ||
if isinstance(self._run_task, ContainerTask): | ||
return self.python_function_task.get_container(settings) | ||
with self.prepare_target(): | ||
return self.python_function_task.get_container(settings) | ||
|
||
def get_k8s_pod(self, settings: SerializationSettings) -> K8sPod: | ||
if isinstance(self._run_task, ContainerTask): | ||
return self.python_function_task.get_k8s_pod(settings) | ||
with self.prepare_target(): | ||
return self.python_function_task.get_k8s_pod(settings) | ||
|
||
def get_sql(self, settings: SerializationSettings) -> Sql: | ||
if isinstance(self._run_task, ContainerTask): | ||
return self.python_function_task.get_sql(settings) | ||
with self.prepare_target(): | ||
return self.python_function_task.get_sql(settings) | ||
|
||
|
@@ -261,9 +272,10 @@ | |
return super().__call__(*args, **kwargs) | ||
|
||
def _literal_map_to_python_input( | ||
self, literal_map: _literal_models.LiteralMap, ctx: FlyteContext | ||
self, literal_map: _literal_models.LiteralMap, ctx: Optional[FlyteContext] = None | ||
) -> Dict[str, Any]: | ||
ctx = FlyteContextManager.current_context() | ||
if ctx is None: | ||
ctx = FlyteContextManager.current_context() | ||
inputs_interface = self.python_interface.inputs | ||
inputs_map = literal_map | ||
# If we run locally, we will need to process all of the inputs. If we are running in a remote task execution | ||
|
@@ -368,6 +380,13 @@ | |
single_instance_inputs[k] = kwargs[k] | ||
try: | ||
o = self._run_task.execute(**single_instance_inputs) | ||
# For Container task, it will return the LiteralMap. We need to convert it to native | ||
# type here. | ||
|
||
if isinstance(o, _literal_models.LiteralMap): | ||
vals = [Promise(var, o.literals[var]) for var in o.literals.keys()] | ||
result = create_task_output(vals, self.python_interface) | ||
ctx = FlyteContextManager.current_context() | ||
o = create_native_named_tuple(ctx, result, self._run_task.python_interface) | ||
if outputs_expected: | ||
outputs.append(o) | ||
except Exception as exc: | ||
|
@@ -381,7 +400,7 @@ | |
|
||
|
||
def map_task( | ||
target: Union[LaunchPlan, PythonFunctionTask, "FlyteLaunchPlan"], | ||
target: Union[LaunchPlan, PythonFunctionTask, ContainerTask, "FlyteLaunchPlan"], | ||
concurrency: Optional[int] = None, | ||
min_successes: Optional[int] = None, | ||
min_success_ratio: float = 1.0, | ||
|
@@ -418,7 +437,7 @@ | |
|
||
|
||
def array_node_map_task( | ||
task_function: PythonFunctionTask, | ||
task_function: Union[PythonFunctionTask, ContainerTask], | ||
concurrency: Optional[int] = None, | ||
# TODO why no min_successes? | ||
min_success_ratio: float = 1.0, | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -170,7 +170,7 @@ def get_serializable_task( | |
# ly for certain kinds of tasks. Specifically, | ||
# tasks that rely on user code defined in the container. This should be encapsulated by the auto container | ||
# parent class | ||
container._args = prefix_with_fast_execute(settings, container.args) | ||
container._args = prefix_with_fast_execute(settings, container.args or []) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
|
||
|
||
# If the pod spec is not None, we have to get it again, because the one we retrieved above will be incorrect. | ||
# The reason we have to call get_k8s_pod again, instead of just modifying the command in this file, is because | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: We can modify `self.prepare_target` to do nothing for a container task.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fixed! Thanks