Skip to content

Commit 19cad35

Browse files
committed
Add container_image at the workflow level.
The container_image should be propagated to all underlying levels of a task. Signed-off-by: Rafael Raposo <[email protected]>
1 parent 8abecfe commit 19cad35

File tree

3 files changed

+52
-1
lines changed

3 files changed

+52
-1
lines changed

flytekit/core/workflow.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@
5050
FlyteValidationException,
5151
FlyteValueException,
5252
)
53+
from flytekit.image_spec.image_spec import ImageSpec
5354
from flytekit.loggers import logger
5455
from flytekit.models import interface as _interface_models
5556
from flytekit.models import literals as _literal_models
@@ -192,6 +193,7 @@ def __init__(
192193
on_failure: Optional[Union[WorkflowBase, Task]] = None,
193194
docs: Optional[Documentation] = None,
194195
default_options: Optional[Options] = None,
196+
container_image: Optional[ImageSpec] = None,
195197
**kwargs,
196198
):
197199
self._name = name
@@ -207,6 +209,7 @@ def __init__(
207209
self._failure_node = None
208210
self._docs = docs
209211
self._default_options = default_options
212+
self._container_image = container_image
210213

211214
if self._python_interface.docstring:
212215
if self.docs is None:
@@ -275,6 +278,10 @@ def failure_node(self) -> Optional[Node]:
275278
def default_options(self) -> Optional[Options]:
276279
return self._default_options
277280

281+
@property
282+
def container_image(self) -> Optional[ImageSpec]:
283+
return self._container_image
284+
278285
def __repr__(self):
279286
return (
280287
f"WorkflowBase - {self._name} && "
@@ -715,6 +722,7 @@ def __init__(
715722
docs: Optional[Documentation] = None,
716723
pickle_untyped: bool = False,
717724
default_options: Optional[Options] = None,
725+
container_image: Optional[ImageSpec] = None,
718726
):
719727
name, _, _, _ = extract_task_module(workflow_function)
720728
self._workflow_function = workflow_function
@@ -734,6 +742,7 @@ def __init__(
734742
on_failure=on_failure,
735743
docs=docs,
736744
default_options=default_options,
745+
container_image=container_image,
737746
)
738747

739748
# Set this here so that the lhs call doesn't fail at least. This is only useful in the context of the
@@ -902,6 +911,7 @@ def workflow(
902911
docs: Optional[Documentation] = ...,
903912
pickle_untyped: bool = ...,
904913
default_options: Optional[Options] = ...,
914+
container_image: Optional[Union[str, ImageSpec]] = ...,
905915
) -> Callable[[Callable[..., FuncOut]], PythonFunctionWorkflow]: ...
906916

907917

@@ -914,6 +924,7 @@ def workflow(
914924
docs: Optional[Documentation] = ...,
915925
pickle_untyped: bool = ...,
916926
default_options: Optional[Options] = ...,
927+
container_image: Optional[Union[str, ImageSpec]] = ...,
917928
) -> Union[Callable[P, FuncOut], PythonFunctionWorkflow]: ...
918929

919930

@@ -925,6 +936,7 @@ def workflow(
925936
docs: Optional[Documentation] = None,
926937
pickle_untyped: bool = False,
927938
default_options: Optional[Options] = None,
939+
container_image: Optional[Union[str, ImageSpec]] = None,
928940
) -> Union[Callable[P, FuncOut], Callable[[Callable[P, FuncOut]], PythonFunctionWorkflow], PythonFunctionWorkflow]:
929941
"""
930942
This decorator declares a function to be a Flyte workflow. Workflows are declarative entities that construct a DAG
@@ -960,6 +972,7 @@ def workflow(
960972
the workflow. This is not recommended for general use.
961973
:param default_options: Default options for the workflow when creating a default launch plan. Currently only
962974
the labels and annotations are allowed to be set as defaults.
975+
:param container_image: A container image spec to use for the workflow.
963976
"""
964977

965978
def wrapper(fn: Callable[P, FuncOut]) -> PythonFunctionWorkflow:
@@ -976,6 +989,7 @@ def wrapper(fn: Callable[P, FuncOut]) -> PythonFunctionWorkflow:
976989
docs=docs,
977990
pickle_untyped=pickle_untyped,
978991
default_options=default_options,
992+
container_image=container_image,
979993
)
980994
update_wrapper(workflow_instance, fn)
981995
return workflow_instance

pydoclint-errors-baseline.txt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,6 @@ flytekit/core/utils.py
134134
--------------------
135135
flytekit/core/workflow.py
136136
DOC101: Function `workflow`: Docstring contains fewer arguments than in function signature.
137-
DOC103: Function `workflow`: Docstring arguments are different from function arguments. (Or could be other formatting issues: https://jsh9.github.io/pydoclint/violation_codes.html#notes-on-doc103 ). Arguments in the function signature but not in the docstring: [_workflow_function: Optional[Callable[P, FuncOut]], default_options: Optional[Options], docs: Optional[Documentation], failure_policy: Optional[WorkflowFailurePolicy], interruptible: bool, on_failure: Optional[Union[WorkflowBase, Task]], pickle_untyped: bool].
138137
DOC201: Function `workflow` does not have a return section in docstring
139138
DOC203: Function `workflow` return type(s) in docstring not consistent with the return annotation. Return annotation has 1 type(s); docstring return section has 0 type(s).
140139
DOC101: Function `reference_workflow`: Docstring contains fewer arguments than in function signature.

tests/flytekit/unit/core/test_workflows.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,44 @@ def wf(a: bool = True) -> bool:
7878
assert wf(a=False) is False
7979

8080

81+
def test_container_image_wf():
82+
modified_image = Image(name="modified", fqn="test1", tag="tag1")
83+
@task
84+
def t1(a: int) -> int:
85+
a = a + 5
86+
return a
87+
@workflow(container_image=modified_image)
88+
def container_image_wf(a: int) -> int:
89+
return t1(a=a)
90+
91+
x = container_image_wf(a=3)
92+
assert x == 8
93+
wf_spec = get_serializable(OrderedDict(), serialization_settings, container_image_wf)
94+
assert wf_spec.template.container_image == modified_image
95+
assert wf_spec.template.nodes[0].task_node.container_image == modified_image
96+
97+
def test_container_image_wf_override_task():
98+
modified_image = Image(name="modified", fqn="test1", tag="tag1")
99+
@task
100+
def t1(a: int) -> int:
101+
a = a + 5
102+
return a
103+
104+
@task(container_image=default_img)
105+
def t2(a: int) -> int:
106+
a = a + 5
107+
return a
108+
@workflow(container_image=modified_image)
109+
def container_image_wf(a: int) -> int:
110+
return t1(a=a) + t2(a=a)
111+
x = container_image_wf(a=3)
112+
assert x == 16
113+
wf_spec = get_serializable(OrderedDict(), serialization_settings, container_image_wf)
114+
assert wf_spec.template.container_image == modified_image
115+
assert wf_spec.template.nodes[0].task_node.container_image == modified_image
116+
assert wf_spec.template.nodes[1].task_node.container_image == default_img
117+
118+
81119
def test_list_output_wf():
82120
@task
83121
def t1(a: int) -> int:

0 commit comments

Comments
 (0)