Skip to content

Commit 23b9dee

Browse files
arbaobaopingsutw
authored andcommitted
Support overriding task pod_template via with_overrides (flyteorg#2981)
Signed-off-by: Nelson Chen <[email protected]> Signed-off-by: Kevin Su <[email protected]> Co-authored-by: Kevin Su <[email protected]> Signed-off-by: Atharva <[email protected]>
1 parent 10fbb22 commit 23b9dee

File tree

7 files changed

+131
-56
lines changed

7 files changed

+131
-56
lines changed

flytekit/core/node.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
from flyteidl.core import tasks_pb2
88

9+
from flytekit.core.pod_template import PodTemplate
910
from flytekit.core.resources import Resources, convert_resources_to_resource_model
1011
from flytekit.core.utils import _dnsify
1112
from flytekit.extras.accelerators import BaseAccelerator
@@ -67,6 +68,7 @@ def __init__(
6768
self._resources: typing.Optional[_resources_model] = None
6869
self._extended_resources: typing.Optional[tasks_pb2.ExtendedResources] = None
6970
self._container_image: typing.Optional[str] = None
71+
self._pod_template: typing.Optional[PodTemplate] = None
7072

7173
def runs_before(self, other: Node):
7274
"""
@@ -191,6 +193,7 @@ def with_overrides(
191193
cache: Optional[bool] = None,
192194
cache_version: Optional[str] = None,
193195
cache_serialize: Optional[bool] = None,
196+
pod_template: Optional[PodTemplate] = None,
194197
*args,
195198
**kwargs,
196199
):
@@ -241,6 +244,10 @@ def with_overrides(
241244

242245
self._override_node_metadata(name, timeout, retries, interruptible, cache, cache_version, cache_serialize)
243246

247+
if pod_template is not None:
248+
assert_not_promise(pod_template, "podtemplate")
249+
self._pod_template = pod_template
250+
244251
return self
245252

246253

flytekit/models/core/workflow.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,21 @@
11
import datetime
2+
import json
23
import typing
34

45
from flyteidl.core import tasks_pb2
56
from flyteidl.core import workflow_pb2 as _core_workflow
7+
from google.protobuf import json_format, struct_pb2
68
from google.protobuf.wrappers_pb2 import BoolValue
79

10+
from flytekit.core.pod_template import PodTemplate
811
from flytekit.models import common as _common
912
from flytekit.models import interface as _interface
1013
from flytekit.models import types as type_models
1114
from flytekit.models.core import condition as _condition
1215
from flytekit.models.core import identifier as _identifier
1316
from flytekit.models.literals import Binding as _Binding
1417
from flytekit.models.literals import RetryStrategy as _RetryStrategy
15-
from flytekit.models.task import Resources
18+
from flytekit.models.task import K8sObjectMetadata, Resources
1619

1720

1821
class IfBlock(_common.FlyteIdlEntity):
@@ -615,10 +618,12 @@ def __init__(
615618
resources: typing.Optional[Resources],
616619
extended_resources: typing.Optional[tasks_pb2.ExtendedResources],
617620
container_image: typing.Optional[str] = None,
621+
pod_template: typing.Optional[PodTemplate] = None,
618622
):
619623
self._resources = resources
620624
self._extended_resources = extended_resources
621625
self._container_image = container_image
626+
self._pod_template = pod_template
622627

623628
@property
624629
def resources(self) -> Resources:
@@ -632,11 +637,27 @@ def extended_resources(self) -> tasks_pb2.ExtendedResources:
632637
def container_image(self) -> typing.Optional[str]:
633638
return self._container_image
634639

640+
@property
641+
def pod_template(self) -> typing.Optional[PodTemplate]:
642+
return self._pod_template
643+
635644
def to_flyte_idl(self):
636645
return _core_workflow.TaskNodeOverrides(
637646
resources=self.resources.to_flyte_idl() if self.resources is not None else None,
638647
extended_resources=self.extended_resources,
639648
container_image=self.container_image,
649+
pod_template=tasks_pb2.K8sPod(
650+
metadata=K8sObjectMetadata(
651+
labels=self.pod_template.labels if self.pod_template else None,
652+
annotations=self.pod_template.annotations if self.pod_template else None,
653+
).to_flyte_idl()
654+
if self.pod_template is not None
655+
else None,
656+
pod_spec=json_format.Parse(json.dumps(self.pod_template.pod_spec), struct_pb2.Struct())
657+
if self.pod_template
658+
else None,
659+
primary_container_name=self.pod_template.primary_container_name if self.pod_template else None,
660+
),
640661
)
641662

642663
@classmethod

flytekit/models/task.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1009,13 +1009,15 @@ def __init__(
10091009
metadata: K8sObjectMetadata = None,
10101010
pod_spec: typing.Dict[str, typing.Any] = None,
10111011
data_config: typing.Optional[DataLoadingConfig] = None,
1012+
primary_container_name: typing.Optional[str] = None,
10121013
):
10131014
"""
10141015
This defines a kubernetes pod target. It will build the pod target during task execution
10151016
"""
10161017
self._metadata = metadata
10171018
self._pod_spec = pod_spec
10181019
self._data_config = data_config
1020+
self._primary_container_name = primary_container_name
10191021

10201022
@property
10211023
def metadata(self) -> K8sObjectMetadata:
@@ -1029,6 +1031,10 @@ def pod_spec(self) -> typing.Dict[str, typing.Any]:
10291031
def data_config(self) -> typing.Optional[DataLoadingConfig]:
10301032
return self._data_config
10311033

1034+
@property
1035+
def primary_container_name(self) -> typing.Optional[str]:
1036+
return self._primary_container_name
1037+
10321038
def to_flyte_idl(self) -> _core_task.K8sPod:
10331039
return _core_task.K8sPod(
10341040
metadata=self._metadata.to_flyte_idl() if self.metadata else None,

flytekit/tools/translator.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
from flyteidl.admin import schedule_pb2
77

8-
from flytekit import ImageSpec, PythonFunctionTask, SourceCode
8+
from flytekit import ImageSpec, PodTemplate, PythonFunctionTask, SourceCode
99
from flytekit.configuration import Image, ImageConfig, SerializationSettings
1010
from flytekit.core import constants as _common_constants
1111
from flytekit.core import context_manager
@@ -25,7 +25,7 @@
2525
from flytekit.core.python_function_task import EagerAsyncPythonFunctionTask
2626
from flytekit.core.reference_entity import ReferenceEntity, ReferenceSpec, ReferenceTemplate
2727
from flytekit.core.task import ReferenceTask
28-
from flytekit.core.utils import ClassDecorator, _dnsify
28+
from flytekit.core.utils import ClassDecorator, _dnsify, _serialize_pod_spec
2929
from flytekit.core.workflow import ReferenceWorkflow, WorkflowBase
3030
from flytekit.models import common as _common_models
3131
from flytekit.models import interface as interface_models
@@ -453,6 +453,13 @@ def get_serializable_node(
453453
# if entity._aliases:
454454
# node_model._output_aliases = entity._aliases
455455
elif isinstance(entity.flyte_entity, PythonTask):
456+
# handle pod template overrides
457+
override_pod_spec = {}
458+
if entity._pod_template is not None and settings.should_fast_serialize():
459+
entity.flyte_entity.set_command_fn(_fast_serialize_command_fn(settings, entity.flyte_entity))
460+
override_pod_spec = _serialize_pod_spec(
461+
entity._pod_template, entity.flyte_entity._get_container(settings), settings
462+
)
456463
task_spec = get_serializable(entity_mapping, settings, entity.flyte_entity, options=options)
457464
node_model = workflow_model.Node(
458465
id=_dnsify(entity.id),
@@ -466,6 +473,16 @@ def get_serializable_node(
466473
resources=entity._resources,
467474
extended_resources=entity._extended_resources,
468475
container_image=entity._container_image,
476+
pod_template=PodTemplate(
477+
pod_spec=override_pod_spec,
478+
labels=entity._pod_template.labels if entity._pod_template.labels else None,
479+
annotations=entity._pod_template.annotations if entity._pod_template.annotations else None,
480+
primary_container_name=entity._pod_template.primary_container_name
481+
if entity._pod_template.primary_container_name
482+
else None,
483+
)
484+
if entity._pod_template
485+
else None,
469486
),
470487
),
471488
)

tests/flytekit/unit/core/test_array_node_map_task.py

Lines changed: 7 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import pytest
1010
from flyteidl.core import workflow_pb2 as _core_workflow
1111

12-
from flytekit import dynamic, map_task, task, workflow, eager, PythonFunctionTask, Resources
12+
from flytekit import dynamic, map_task, task, workflow, eager, PythonFunctionTask
1313
from flytekit.configuration import FastSerializationSettings, Image, ImageConfig, SerializationSettings
1414
from flytekit.core import context_manager
1515
from flytekit.core.array_node_map_task import ArrayNodeMapTask, ArrayNodeMapTaskResolver
@@ -21,7 +21,6 @@
2121
LiteralMap,
2222
LiteralOffloadedMetadata,
2323
)
24-
from flytekit.models.task import Resources as _resources_models
2524
from flytekit.tools.translator import get_serializable
2625
from flytekit.types.directory import FlyteDirectory
2726

@@ -350,59 +349,16 @@ def my_wf1() -> typing.List[typing.Optional[int]]:
350349
assert my_wf1() == [1, None, 3, 4]
351350

352351

353-
@task
354-
def my_mappable_task(a: int) -> typing.Optional[str]:
355-
return str(a)
356-
357-
358-
@task(
359-
container_image="original-image",
360-
timeout=timedelta(seconds=10),
361-
interruptible=False,
362-
retries=10,
363-
cache=True,
364-
cache_version="original-version",
365-
requests=Resources(cpu=1)
366-
)
367-
def my_mappable_task_1(a: int) -> typing.Optional[str]:
368-
return str(a)
369-
370-
371-
@pytest.mark.parametrize(
372-
"task_func",
373-
[my_mappable_task, my_mappable_task_1]
374-
)
375-
def test_map_task_override(serialization_settings, task_func):
376-
array_node_map_task = map_task(task_func)
352+
def test_map_task_override(serialization_settings):
353+
@task
354+
def my_mappable_task(a: int) -> typing.Optional[str]:
355+
return str(a)
377356

378357
@workflow
379358
def wf(x: typing.List[int]):
380-
array_node_map_task(a=x).with_overrides(
381-
container_image="new-image",
382-
timeout=timedelta(seconds=20),
383-
interruptible=True,
384-
retries=5,
385-
cache=True,
386-
cache_version="new-version",
387-
requests=Resources(cpu=2)
388-
)
389-
390-
assert wf.nodes[0]._container_image == "new-image"
391-
392-
od = OrderedDict()
393-
wf_spec = get_serializable(od, serialization_settings, wf)
359+
map_task(my_mappable_task)(a=x).with_overrides(container_image="random:image")
394360

395-
array_node = wf_spec.template.nodes[0]
396-
assert array_node.metadata.timeout == timedelta()
397-
sub_node_spec = array_node.array_node.node
398-
assert sub_node_spec.metadata.timeout == timedelta(seconds=20)
399-
assert sub_node_spec.metadata.interruptible
400-
assert sub_node_spec.metadata.retries.retries == 5
401-
assert sub_node_spec.metadata.cacheable
402-
assert sub_node_spec.metadata.cache_version == "new-version"
403-
assert sub_node_spec.target.overrides.resources.requests == [
404-
_resources_models.ResourceEntry(_resources_models.ResourceName.CPU, "2")
405-
]
361+
assert wf.nodes[0]._container_image == "random:image"
406362

407363

408364
def test_serialization_metadata(serialization_settings):

tests/flytekit/unit/core/test_map_task.py

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,10 @@
33
from collections import OrderedDict
44

55
import pytest
6+
from kubernetes.client import V1PodSpec, V1Container, V1EnvVar
67

78
import flytekit.configuration
8-
from flytekit import LaunchPlan, Resources
9+
from flytekit import LaunchPlan, Resources, PodTemplate
910
from flytekit.configuration import Image, ImageConfig
1011
from flytekit.core.legacy_map_task import MapPythonTask, MapTaskResolver, map_task
1112
from flytekit.core.task import TaskMetadata, task
@@ -354,6 +355,39 @@ def wf(x: typing.List[int]):
354355

355356
assert wf.nodes[0]._container_image == "random:image"
356357

358+
def test_map_task_pod_template_override(serialization_settings):
359+
@task
360+
def my_mappable_task(a: int) -> typing.Optional[str]:
361+
return str(a)
362+
363+
@workflow
364+
def wf(x: typing.List[int]):
365+
map_task(my_mappable_task)(a=x).with_overrides(pod_template=PodTemplate(
366+
primary_container_name="primary1",
367+
labels={"lKeyA": "lValA", "lKeyB": "lValB"},
368+
annotations={"aKeyA": "aValA", "aKeyB": "aValB"},
369+
pod_spec=V1PodSpec(
370+
containers=[
371+
V1Container(
372+
name="primary1",
373+
image="random:image",
374+
env=[V1EnvVar(name="eKeyC", value="eValC"), V1EnvVar(name="eKeyD", value="eValD")],
375+
),
376+
V1Container(
377+
name="primary2",
378+
image="random:image2",
379+
env=[V1EnvVar(name="eKeyC", value="eValC"), V1EnvVar(name="eKeyD", value="eValD")],
380+
),
381+
],
382+
)
383+
))
384+
385+
386+
assert wf.nodes[0]._pod_template.primary_container_name == "primary1"
387+
assert wf.nodes[0]._pod_template.pod_spec.containers[0].image == "random:image"
388+
assert wf.nodes[0]._pod_template.labels == {"lKeyA": "lValA", "lKeyB": "lValB"}
389+
assert wf.nodes[0]._pod_template.annotations["aKeyA"] == "aValA"
390+
357391

358392
def test_bounded_inputs_vars_order(serialization_settings):
359393
mt = map_task(functools.partial(t3, c=1.0, b="hello", a=1))

tests/flytekit/unit/core/test_node_creation.py

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,10 @@
44
from dataclasses import dataclass
55

66
import pytest
7+
from kubernetes.client import V1PodSpec, V1Container, V1EnvVar
78

89
import flytekit.configuration
9-
from flytekit import Resources, map_task
10+
from flytekit import Resources, map_task, PodTemplate
1011
from flytekit.configuration import Image, ImageConfig
1112
from flytekit.core.dynamic_workflow_task import dynamic
1213
from flytekit.core.node_creation import create_node
@@ -470,6 +471,39 @@ def wf() -> str:
470471

471472
assert wf.nodes[0]._container_image == "hello/world"
472473

474+
def test_pod_template_override():
475+
@task
476+
def bar():
477+
print("hello")
478+
479+
@workflow
480+
def wf() -> str:
481+
bar().with_overrides(pod_template=PodTemplate(
482+
primary_container_name="primary1",
483+
labels={"lKeyA": "lValA", "lKeyB": "lValB"},
484+
annotations={"aKeyA": "aValA", "aKeyB": "aValB"},
485+
pod_spec=V1PodSpec(
486+
containers=[
487+
V1Container(
488+
name="primary1",
489+
image="random:image",
490+
env=[V1EnvVar(name="eKeyC", value="eValC"), V1EnvVar(name="eKeyD", value="eValD")],
491+
),
492+
V1Container(
493+
name="primary2",
494+
image="random:image2",
495+
env=[V1EnvVar(name="eKeyC", value="eValC"), V1EnvVar(name="eKeyD", value="eValD")],
496+
),
497+
],
498+
)
499+
))
500+
return "hi"
501+
502+
assert wf.nodes[0]._pod_template.primary_container_name == "primary1"
503+
assert wf.nodes[0]._pod_template.pod_spec.containers[0].image == "random:image"
504+
assert wf.nodes[0]._pod_template.labels == {"lKeyA": "lValA", "lKeyB": "lValB"}
505+
assert wf.nodes[0]._pod_template.annotations["aKeyA"] == "aValA"
506+
473507

474508
def test_override_accelerator():
475509
@task(accelerator=T4)

0 commit comments

Comments
 (0)