Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[wip] support for cluster pools throughout the sdk #3039

Draft
wants to merge 4 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 17 additions & 2 deletions flytekit/core/launch_plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,7 @@ def create(
trigger: Optional[LaunchPlanTriggerBase] = None,
overwrite_cache: Optional[bool] = None,
auto_activate: bool = False,
cluster_pool: Optional[str] = None,
) -> LaunchPlan:
ctx = FlyteContextManager.current_context()
default_inputs = default_inputs or {}
Expand Down Expand Up @@ -188,6 +189,7 @@ def create(
trigger=trigger,
overwrite_cache=overwrite_cache,
auto_activate=auto_activate,
cluster_pool=cluster_pool,
)

# This is just a convenience - we'll need the fixed inputs LiteralMap for when serializing the Launch Plan out
Expand Down Expand Up @@ -219,6 +221,7 @@ def get_or_create(
trigger: Optional[LaunchPlanTriggerBase] = None,
overwrite_cache: Optional[bool] = None,
auto_activate: bool = False,
cluster_pool: Optional[str] = None,
) -> LaunchPlan:
"""
This function offers a friendlier interface for creating launch plans. If the name for the launch plan is not
Expand Down Expand Up @@ -246,8 +249,10 @@ def get_or_create(
parallelism/concurrency of MapTasks is independent from this.
:param trigger: [alpha] This is a new syntax for specifying schedules.
:param overwrite_cache: If set to True, the execution will always overwrite cache
:param auto_activate: If set to True, the launch plan will be activated automatically on registration.
Default is False.
:param auto_activate: If set to True, the launch plan will be activated automatically on registration. Default is False. # noqa: E501
:param cluster_pool: The cluster pool to use for execution. If not set, the default cluster pool will be used.

:rtype: LaunchPlan
"""
if name is None and (
default_inputs is not None
Expand Down Expand Up @@ -300,6 +305,7 @@ def get_or_create(
("security_context", security_context, cached_outputs["_security_context"]),
("overwrite_cache", overwrite_cache, cached_outputs["_overwrite_cache"]),
("auto_activate", auto_activate, cached_outputs["_auto_activate"]),
("cluster_pool", cluster_pool, cached_outputs["_cluster_pool"]),
]:
if new != cached:
raise AssertionError(
Expand Down Expand Up @@ -332,6 +338,7 @@ def get_or_create(
trigger=trigger,
overwrite_cache=overwrite_cache,
auto_activate=auto_activate,
cluster_pool=cluster_pool,
)
LaunchPlan.CACHE[name or workflow.name] = lp
return lp
Expand All @@ -352,6 +359,7 @@ def __init__(
trigger: Optional[LaunchPlanTriggerBase] = None,
overwrite_cache: Optional[bool] = None,
auto_activate: bool = False,
cluster_pool: Optional[str] = None,
):
self._name = name
self._workflow = workflow
Expand All @@ -372,6 +380,7 @@ def __init__(
self._trigger = trigger
self._overwrite_cache = overwrite_cache
self._auto_activate = auto_activate
self._cluster_pool = cluster_pool

FlyteEntities.entities.append(self)

Expand All @@ -390,6 +399,7 @@ def clone_with(
trigger: Optional[LaunchPlanTriggerBase] = None,
overwrite_cache: Optional[bool] = None,
auto_activate: bool = False,
cluster_pool: Optional[str] = None,
) -> LaunchPlan:
return LaunchPlan(
name=name,
Expand All @@ -406,6 +416,7 @@ def clone_with(
trigger=trigger,
overwrite_cache=overwrite_cache or self.overwrite_cache,
auto_activate=auto_activate,
cluster_pool=cluster_pool,
)

@property
Expand Down Expand Up @@ -480,6 +491,10 @@ def trigger(self) -> Optional[LaunchPlanTriggerBase]:
def should_auto_activate(self) -> bool:
return self._auto_activate

@property
def cluster_pool(self) -> Optional[str]:
    """Name of the cluster pool this launch plan was created with, or None if unset."""
    return self._cluster_pool

def construct_node_metadata(self) -> _workflow_model.NodeMetadata:
    """Delegate node-metadata construction to the underlying workflow."""
    return self.workflow.construct_node_metadata()

Expand Down
5 changes: 5 additions & 0 deletions flytekit/core/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,7 @@ def with_overrides(
cache: Optional[bool] = None,
cache_version: Optional[str] = None,
cache_serialize: Optional[bool] = None,
cluster_pool: Optional[str] = None,
*args,
**kwargs,
):
Expand Down Expand Up @@ -221,6 +222,10 @@ def with_overrides(
assert_not_promise(cache_serialize, "cache_serialize")
self._metadata._cache_serializable = cache_serialize

if cluster_pool is not None:
assert_not_promise(cluster_pool, "cluster_pool")
self._metadata.add_config("cluster_pool", cluster_pool)

return self


Expand Down
2 changes: 2 additions & 0 deletions flytekit/core/promise.py
Original file line number Diff line number Diff line change
Expand Up @@ -597,6 +597,7 @@ def with_overrides(
cache: Optional[bool] = None,
cache_version: Optional[str] = None,
cache_serialize: Optional[bool] = None,
cluster_pool: Optional[str] = None,
*args,
**kwargs,
):
Expand All @@ -617,6 +618,7 @@ def with_overrides(
cache=cache,
cache_version=cache_version,
cache_serialize=cache_serialize,
cluster_pool=cluster_pool,
*args,
**kwargs,
)
Expand Down
12 changes: 12 additions & 0 deletions flytekit/models/core/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,7 @@ def __init__(
cacheable: typing.Optional[bool] = None,
cache_version: typing.Optional[str] = None,
cache_serializable: typing.Optional[bool] = None,
config: typing.Optional[typing.Dict[str, str]] = None,
):
"""
Defines extra information about the Node.
Expand All @@ -183,6 +184,7 @@ def __init__(
:param cacheable: Indicates that this node's outputs should be cached.
:param cache_version: The version of the cached data.
:param cache_serializable: Indicates that cache operations on this node should be serialized.
:param config: Optional config fields for the override
"""
self._name = name
self._timeout = timeout if timeout is not None else datetime.timedelta()
Expand All @@ -191,6 +193,7 @@ def __init__(
self._cacheable = cacheable
self._cache_version = cache_version
self._cache_serializable = cache_serializable
self._config = config or {}

@property
def name(self):
Expand Down Expand Up @@ -229,6 +232,13 @@ def cache_version(self) -> typing.Optional[str]:
def cache_serializable(self) -> typing.Optional[bool]:
return self._cache_serializable

@property
def config(self) -> typing.Optional[typing.Dict[str, str]]:
    """Per-node override config (e.g. "cluster_pool").

    NOTE: despite the Optional annotation, __init__ normalizes this to a dict
    (``config or {}``), so callers always get a mapping, never None.
    """
    return self._config

def add_config(self, key: str, value: str):
    """Set a single config entry on this node's metadata, mutating it in place."""
    self._config[key] = value

def to_flyte_idl(self):
"""
:rtype: flyteidl.core.workflow_pb2.NodeMetadata
Expand All @@ -240,6 +250,7 @@ def to_flyte_idl(self):
cacheable=self.cacheable,
cache_version=self.cache_version,
cache_serializable=self.cache_serializable,
config=self.config,
)
if self.timeout:
node_metadata.timeout.FromTimedelta(self.timeout)
Expand All @@ -255,6 +266,7 @@ def from_flyte_idl(cls, pb2_object):
pb2_object.cacheable if pb2_object.HasField("cacheable") else None,
pb2_object.cache_version if pb2_object.HasField("cache_version") else None,
pb2_object.cache_serializable if pb2_object.HasField("cache_serializable") else None,
config=pb2_object.config,
)


Expand Down
14 changes: 13 additions & 1 deletion flytekit/models/launch_plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from flytekit.models import schedule as _schedule
from flytekit.models import security
from flytekit.models.core import identifier as _identifier
from flytekit.models.execution import ClusterAssignment


class LaunchPlanMetadata(_common.FlyteIdlEntity):
Expand Down Expand Up @@ -133,11 +134,12 @@ def __init__(
fixed_inputs,
labels: _common.Labels,
annotations: _common.Annotations,
auth_role: _common.AuthRole,
auth_role: typing.Optional[_common.AuthRole],
raw_output_data_config: _common.RawOutputDataConfig,
max_parallelism: typing.Optional[int] = None,
security_context: typing.Optional[security.SecurityContext] = None,
overwrite_cache: typing.Optional[bool] = None,
cluster_assignment: typing.Optional[ClusterAssignment] = None,
):
"""
The spec for a Launch Plan.
Expand All @@ -158,6 +160,7 @@ def __init__(
parallelism/concurrency of MapTasks is independent from this.
:param security_context: This can be used to add security information to a LaunchPlan, which will be used by
every execution
:param cluster_assignment: Optional cluster assignment for the launch plan
"""
self._workflow_id = workflow_id
self._entity_metadata = entity_metadata
Expand All @@ -170,6 +173,7 @@ def __init__(
self._max_parallelism = max_parallelism
self._security_context = security_context
self._overwrite_cache = overwrite_cache
self._cluster_assignment = cluster_assignment

@property
def workflow_id(self):
Expand Down Expand Up @@ -246,6 +250,10 @@ def security_context(self) -> typing.Optional[security.SecurityContext]:
def overwrite_cache(self) -> typing.Optional[bool]:
return self._overwrite_cache

@property
def cluster_assignment(self) -> typing.Optional[ClusterAssignment]:
    """Optional cluster assignment carried by this launch plan spec; None when not set."""
    return self._cluster_assignment

def to_flyte_idl(self):
"""
:rtype: flyteidl.admin.launch_plan_pb2.LaunchPlanSpec
Expand All @@ -262,6 +270,7 @@ def to_flyte_idl(self):
max_parallelism=self.max_parallelism,
security_context=self.security_context.to_flyte_idl() if self.security_context else None,
overwrite_cache=self.overwrite_cache if self.overwrite_cache else None,
cluster_assignment=self.cluster_assignment.to_flyte_idl() if self.cluster_assignment else None,
)

@classmethod
Expand Down Expand Up @@ -295,6 +304,9 @@ def from_flyte_idl(cls, pb2):
if pb2.security_context
else None,
overwrite_cache=pb2.overwrite_cache if pb2.overwrite_cache else None,
cluster_assignment=ClusterAssignment.from_flyte_idl(pb2.cluster_assignment)
if pb2.HasField("cluster_assignment")
else None,
)


Expand Down
2 changes: 2 additions & 0 deletions flytekit/tools/translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
from flytekit.models.core.workflow import ApproveCondition, GateNode, SignalCondition, SleepCondition, TaskNodeOverrides
from flytekit.models.core.workflow import ArrayNode as ArrayNodeModel
from flytekit.models.core.workflow import BranchNode as BranchNodeModel
from flytekit.models.execution import ClusterAssignment
from flytekit.models.task import TaskSpec, TaskTemplate

FlyteLocalEntity = Union[
Expand Down Expand Up @@ -368,6 +369,7 @@ def get_serializable_launch_plan(
max_parallelism=options.max_parallelism or entity.max_parallelism,
security_context=options.security_context or entity.security_context,
overwrite_cache=options.overwrite_cache or entity.overwrite_cache,
cluster_assignment=ClusterAssignment(cluster_pool=entity.cluster_pool) if entity.cluster_pool else None,
)

lp_id = _identifier_model.Identifier(
Expand Down
28 changes: 26 additions & 2 deletions tests/flytekit/unit/core/test_launch_plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from flytekit.models.common import Annotations, AuthRole, Labels, RawOutputDataConfig
from flytekit.models.core import execution as _execution_model
from flytekit.models.core import identifier as identifier_models
from flytekit.models.launch_plan import LaunchPlan
from flytekit.tools.translator import get_serializable

default_img = Image(name="default", fqn="test", tag="tag")
Expand Down Expand Up @@ -401,8 +402,8 @@ def my_wf(a: int) -> typing.Tuple[int, int]:
wf_spec = get_serializable(all_entities, serialization_settings, my_wf)
assert wf_spec.template.nodes[1].workflow_node is not None
assert (
wf_spec.template.nodes[1].workflow_node.launchplan_ref.resource_type
== identifier_models.ResourceType.LAUNCH_PLAN
wf_spec.template.nodes[1].workflow_node.launchplan_ref.resource_type
== identifier_models.ResourceType.LAUNCH_PLAN
)
assert wf_spec.template.nodes[1].workflow_node.launchplan_ref.name == "my_sub_wf_lp1"

Expand Down Expand Up @@ -452,6 +453,7 @@ def wf_with_docstring(a: int) -> (str, str):
lp = launch_plan.LaunchPlan.get_or_create(workflow=wf_with_docstring)
assert lp.parameters.parameters["a"].var.description == "foo"


def test_lp_with_wf_with_default_options():
@task
def t1(a: int) -> int:
Expand All @@ -472,3 +474,25 @@ def wf_with_default_options(a: int) -> int:
assert lp.labels.values["label"] == "foo"
assert len(lp.annotations.values) == 1
assert lp.annotations.values["anno"] == "bar"


def test_launchplan_with_cluster_assignment():
    """Verify that ``cluster_pool`` passed to get_or_create is stored on the LaunchPlan
    and round-trips into a ClusterAssignment on the serialized spec."""

    @task
    def t1(a: int) -> int:
        return a + 2

    @workflow
    def wf_with_cluster_assignment(a: int) -> int:
        return t1(a=a)

    lp = launch_plan.LaunchPlan.get_or_create(
        workflow=wf_with_cluster_assignment, name="lp_with_cluster_assignment", cluster_pool="foo"
    )

    # The raw value is exposed via the cluster_pool property.
    assert lp.cluster_pool == "foo"
    assert lp.name == "lp_with_cluster_assignment"

    # Serialization wraps cluster_pool in a ClusterAssignment on the spec.
    lp_model: LaunchPlan = get_serializable(OrderedDict(), serialization_settings, lp)
    assert lp_model is not None
    assert lp_model.spec.cluster_assignment is not None
    assert lp_model.spec.cluster_assignment.cluster_pool == "foo"
31 changes: 26 additions & 5 deletions tests/flytekit/unit/core/test_node_creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import pytest

import flytekit.configuration
from flytekit import Resources, map_task
from flytekit import Resources, map_task, LaunchPlan
from flytekit.configuration import Image, ImageConfig
from flytekit.core.dynamic_workflow_task import dynamic
from flytekit.core.node_creation import create_node
Expand All @@ -16,6 +16,7 @@
from flytekit.extras.accelerators import A100, T4
from flytekit.image_spec.image_spec import ImageBuildEngine
from flytekit.models import literals as _literal_models
from flytekit.models.admin.workflow import WorkflowSpec
from flytekit.models.task import Resources as _resources_models
from flytekit.tools.translator import get_serializable

Expand Down Expand Up @@ -95,7 +96,6 @@ def empty_wf2():
assert wf_spec.template.nodes[0].metadata.name == "t2"

with pytest.raises(FlyteAssertion):

@workflow
def empty_wf2():
create_node(t2, "foo")
Expand Down Expand Up @@ -141,7 +141,6 @@ def t1(a: int) -> nt:

# Test that you can't name an output "outputs"
with pytest.raises(FlyteAssertion):

@workflow
def my_wf(a: int) -> str:
t1_node = create_node(t1, a=a)
Expand Down Expand Up @@ -333,7 +332,6 @@ def t1(a: str) -> str:
return f"*~*~*~{a}*~*~*~"

with pytest.raises(ValueError, match="datetime.timedelta or int seconds"):

@workflow
def my_wf(a: str) -> str:
return t1(a=a).with_overrides(timeout="foo")
Expand Down Expand Up @@ -450,7 +448,6 @@ def my_wf(a: str) -> str:
assert my_wf.nodes[0].flyte_entity.task_config.name == "flyte"

with pytest.raises(ValueError):

@workflow
def my_wf(a: str) -> str:
return t1(a=a).with_overrides(task_config=None)
Expand Down Expand Up @@ -518,3 +515,27 @@ def my_wf(a: str) -> str:
assert wf_spec.template.nodes[0].metadata.cache_serializable
assert wf_spec.template.nodes[0].metadata.cacheable
assert wf_spec.template.nodes[0].metadata.cache_version == "foo"


def test_cluster_pool_overrides():
    """Verify that ``with_overrides(cluster_pool=...)`` on a launch plan node lands in
    the serialized node metadata under the "cluster_pool" config key."""

    @workflow
    def my_wf() -> str:
        return "hello"

    lp = LaunchPlan.create("test", my_wf)

    @workflow
    def caller_wf() -> str:
        return lp().with_overrides(cluster_pool="pool")

    ss = flytekit.configuration.SerializationSettings(
        project="proj",
        domain="dom",
        version="v",
        image_config=ImageConfig(Image(name="name", fqn="image", tag="tag")),
        env={},
    )

    caller_wf_spec: WorkflowSpec = get_serializable(OrderedDict(), ss, caller_wf)

    # The override is stored in NodeMetadata.config, not as a dedicated field.
    assert caller_wf_spec.template.nodes[0].metadata.config["cluster_pool"] == "pool"
Loading
Loading