Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions flytekit/core/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
FlyteValidationException,
FlyteValueException,
)
from flytekit.image_spec.image_spec import ImageSpec
from flytekit.loggers import logger
from flytekit.models import interface as _interface_models
from flytekit.models import literals as _literal_models
Expand Down Expand Up @@ -192,6 +193,7 @@ def __init__(
on_failure: Optional[Union[WorkflowBase, Task]] = None,
docs: Optional[Documentation] = None,
default_options: Optional[Options] = None,
container_image: Optional[ImageSpec] = None,
**kwargs,
):
self._name = name
Expand All @@ -207,6 +209,7 @@ def __init__(
self._failure_node = None
self._docs = docs
self._default_options = default_options
self._container_image = container_image

if self._python_interface.docstring:
if self.docs is None:
Expand Down Expand Up @@ -275,6 +278,10 @@ def failure_node(self) -> Optional[Node]:
def default_options(self) -> Optional[Options]:
return self._default_options

@property
def container_image(self) -> Optional[ImageSpec]:
return self._container_image

def __repr__(self):
return (
f"WorkflowBase - {self._name} && "
Expand Down Expand Up @@ -715,6 +722,7 @@ def __init__(
docs: Optional[Documentation] = None,
pickle_untyped: bool = False,
default_options: Optional[Options] = None,
container_image: Optional[ImageSpec] = None,
):
name, _, _, _ = extract_task_module(workflow_function)
self._workflow_function = workflow_function
Expand All @@ -734,6 +742,7 @@ def __init__(
on_failure=on_failure,
docs=docs,
default_options=default_options,
container_image=container_image,
)

# Set this here so that the lhs call doesn't fail at least. This is only useful in the context of the
Expand Down Expand Up @@ -902,6 +911,7 @@ def workflow(
docs: Optional[Documentation] = ...,
pickle_untyped: bool = ...,
default_options: Optional[Options] = ...,
container_image: Optional[Union[str, ImageSpec]] = ...,
) -> Callable[[Callable[..., FuncOut]], PythonFunctionWorkflow]: ...


Expand All @@ -914,6 +924,7 @@ def workflow(
docs: Optional[Documentation] = ...,
pickle_untyped: bool = ...,
default_options: Optional[Options] = ...,
container_image: Optional[Union[str, ImageSpec]] = ...,
) -> Union[Callable[P, FuncOut], PythonFunctionWorkflow]: ...


Expand All @@ -925,6 +936,7 @@ def workflow(
docs: Optional[Documentation] = None,
pickle_untyped: bool = False,
default_options: Optional[Options] = None,
container_image: Optional[Union[str, ImageSpec]] = None,
) -> Union[Callable[P, FuncOut], Callable[[Callable[P, FuncOut]], PythonFunctionWorkflow], PythonFunctionWorkflow]:
"""
This decorator declares a function to be a Flyte workflow. Workflows are declarative entities that construct a DAG
Expand Down Expand Up @@ -960,6 +972,7 @@ def workflow(
the workflow. This is not recommended for general use.
:param default_options: Default options for the workflow when creating a default launch plan. Currently only
the labels and annotations are allowed to be set as defaults.
:param container_image: A container image spec to use for the workflow.
"""

def wrapper(fn: Callable[P, FuncOut]) -> PythonFunctionWorkflow:
Expand All @@ -976,6 +989,7 @@ def wrapper(fn: Callable[P, FuncOut]) -> PythonFunctionWorkflow:
docs=docs,
pickle_untyped=pickle_untyped,
default_options=default_options,
container_image=container_image,
)
update_wrapper(workflow_instance, fn)
return workflow_instance
Expand Down
2 changes: 1 addition & 1 deletion pydoclint-errors-baseline.txt
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ flytekit/core/utils.py
--------------------
flytekit/core/workflow.py
DOC101: Function `workflow`: Docstring contains fewer arguments than in function signature.
DOC103: Function `workflow`: Docstring arguments are different from function arguments. (Or could be other formatting issues: https://jsh9.github.io/pydoclint/violation_codes.html#notes-on-doc103 ). Arguments in the function signature but not in the docstring: [_workflow_function: Optional[Callable[P, FuncOut]], default_options: Optional[Options], docs: Optional[Documentation], failure_policy: Optional[WorkflowFailurePolicy], interruptible: bool, on_failure: Optional[Union[WorkflowBase, Task]], pickle_untyped: bool].
DOC103: Function `workflow`: Docstring arguments are different from function arguments. (Or could be other formatting issues: https://jsh9.github.io/pydoclint/violation_codes.html#notes-on-doc103 ). Arguments in the function signature but not in the docstring: [_workflow_function: Optional[Callable[P, FuncOut]], container_image: Optional[Union[str, ImageSpec]], default_options: Optional[Options], docs: Optional[Documentation], failure_policy: Optional[WorkflowFailurePolicy], interruptible: bool, on_failure: Optional[Union[WorkflowBase, Task]], pickle_untyped: bool].
DOC201: Function `workflow` does not have a return section in docstring
DOC203: Function `workflow` return type(s) in docstring not consistent with the return annotation. Return annotation has 1 type(s); docstring return section has 0 type(s).
DOC101: Function `reference_workflow`: Docstring contains fewer arguments than in function signature.
Expand Down
38 changes: 38 additions & 0 deletions tests/flytekit/unit/core/test_workflows.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,44 @@ def wf(a: bool = True) -> bool:
assert wf(a=False) is False


def test_container_image_wf():
modified_image = Image(name="modified", fqn="test1", tag="tag1")
@task
def t1(a: int) -> int:
a = a + 5
return a
@workflow(container_image=modified_image)
def container_image_wf(a: int) -> int:
return t1(a=a)

x = container_image_wf(a=3)
assert x == 8
wf_spec = get_serializable(OrderedDict(), serialization_settings, container_image_wf)
assert wf_spec.template.container_image == modified_image
assert wf_spec.template.nodes[0].task_node.container_image == modified_image

def test_container_image_wf_override_task():
modified_image = Image(name="modified", fqn="test1", tag="tag1")
@task
def t1(a: int) -> int:
a = a + 5
return a

@task(container_image=default_img)
def t2(a: int) -> int:
a = a + 5
return a
@workflow(container_image=modified_image)
def container_image_wf(a: int) -> int:
return t1(a=a) + t2(a=a)
x = container_image_wf(a=3)
assert x == 16
wf_spec = get_serializable(OrderedDict(), serialization_settings, container_image_wf)
assert wf_spec.template.container_image == modified_image
assert wf_spec.template.nodes[0].task_node.container_image == modified_image
assert wf_spec.template.nodes[1].task_node.container_image == default_img


def test_list_output_wf():
@task
def t1(a: int) -> int:
Expand Down
Loading