diff --git a/flytekit/bin/entrypoint.py b/flytekit/bin/entrypoint.py index cc0f7bafbe..3011544dec 100644 --- a/flytekit/bin/entrypoint.py +++ b/flytekit/bin/entrypoint.py @@ -324,7 +324,7 @@ def _dispatch_execute( logger.info(f"Engine folder written successfully to the output prefix {output_prefix}") if task_def is not None and not getattr(task_def, "disable_deck", True): - _output_deck(task_def.name.split(".")[-1], ctx.user_space_params) + _output_deck(task_name=task_def.name.split(".")[-1], new_user_params=ctx.user_space_params) logger.debug("Finished _dispatch_execute") diff --git a/flytekit/core/base_task.py b/flytekit/core/base_task.py index 6430aa9eac..41da032fee 100644 --- a/flytekit/core/base_task.py +++ b/flytekit/core/base_task.py @@ -115,21 +115,18 @@ class TaskMetadata(object): See the :std:ref:`IDL ` for the protobuf definition. - Args: - cache (bool): Indicates if caching should be enabled. See :std:ref:`Caching ` - cache_serialize (bool): Indicates if identical (ie. same inputs) instances of this task should be executed in serial when caching is enabled. See :std:ref:`Caching ` - cache_version (str): Version to be used for the cached value - cache_ignore_input_vars (Tuple[str, ...]): Input variables that should not be included when calculating hash for cache - interruptible (Optional[bool]): Indicates that this task can be interrupted and/or scheduled on nodes with - lower QoS guarantees that can include pre-emption. This can reduce the monetary cost executions incur at the - cost of performance penalties due to potential interruptions - deprecated (str): Can be used to provide a warning message for deprecated task. Absence or empty str indicates - that the task is active and not deprecated + Attributes: + cache (bool): Indicates if caching should be enabled. See :std:ref:`Caching `. + cache_serialize (bool): Indicates if identical (i.e. same inputs) instances of this task should be executed in serial when caching is enabled. See :std:ref:`Caching `. 
+ cache_version (str): Version to be used for the cached value. + cache_ignore_input_vars (Tuple[str, ...]): Input variables that should not be included when calculating hash for cache. + interruptible (Optional[bool]): Indicates that this task can be interrupted and/or scheduled on nodes with lower QoS guarantees that can include pre-emption. + deprecated (str): Can be used to provide a warning message for a deprecated task. An absence or empty string indicates that the task is active and not deprecated. retries (int): for retries=n; n > 0, on failures of this task, the task will be retried at-least n number of times. - timeout (Optional[Union[datetime.timedelta, int]]): the max amount of time for which one execution of this task - should be executed for. The execution will be terminated if the runtime exceeds the given timeout - (approximately) - pod_template_name (Optional[str]): the name of existing PodTemplate resource in the cluster which will be used in this task. + timeout (Optional[Union[datetime.timedelta, int]]): The maximum duration for which one execution of this task should run. The execution will be terminated if the runtime exceeds this timeout. + pod_template_name (Optional[str]): The name of an existing PodTemplate resource in the cluster which will be used for this task. + generates_deck (bool): Indicates whether the task will generate a Deck URI. + is_eager (bool): Indicates whether the task should be treated as eager. 
""" cache: bool = False @@ -141,6 +138,7 @@ class TaskMetadata(object): retries: int = 0 timeout: Optional[Union[datetime.timedelta, int]] = None pod_template_name: Optional[str] = None + generates_deck: bool = False is_eager: bool = False def __post_init__(self): @@ -179,6 +177,7 @@ def to_taskmetadata_model(self) -> _task_model.TaskMetadata: discovery_version=self.cache_version, deprecated_error_message=self.deprecated, cache_serializable=self.cache_serialize, + generates_deck=self.generates_deck, pod_template_name=self.pod_template_name, cache_ignore_input_vars=self.cache_ignore_input_vars, is_eager=self.is_eager, @@ -720,11 +719,15 @@ def dispatch_execute( may be none * ``DynamicJobSpec`` is returned when a dynamic workflow is executed """ - if DeckField.TIMELINE.value in self.deck_fields and ctx.user_space_params is not None: - ctx.user_space_params.decks.append(ctx.user_space_params.timeline_deck) + # Invoked before the task is executed new_user_params = self.pre_execute(ctx.user_space_params) + if self.enable_deck and ctx.user_space_params is not None: + if DeckField.TIMELINE.value in self.deck_fields: + ctx.user_space_params.decks.append(ctx.user_space_params.timeline_deck) + new_user_params = (new_user_params or ctx.user_space_params).with_enable_deck(enable_deck=True).build() + # Create another execution context with the new user params, but let's keep the same working dir with FlyteContextManager.with_context( ctx.with_execution_state( @@ -827,8 +830,20 @@ def disable_deck(self) -> bool: """ If true, this task will not output deck html file """ + warnings.warn( + "`disable_deck` is deprecated and will be removed in the future.\n" "Please use `enable_deck` instead.", + DeprecationWarning, + stacklevel=2, + ) return self._disable_deck + @property + def enable_deck(self) -> bool: + """ + If true, this task will output deck html file + """ + return not self._disable_deck + @property def deck_fields(self) -> List[DeckField]: """ diff --git a/flytekit/core/context_manager.py 
b/flytekit/core/context_manager.py index ab19939522..6cdbce0730 100644 --- a/flytekit/core/context_manager.py +++ b/flytekit/core/context_manager.py @@ -94,6 +94,7 @@ class Builder(object): logging: Optional[_logging.Logger] = None task_id: typing.Optional[_identifier.Identifier] = None output_metadata_prefix: Optional[str] = None + enable_deck: bool = False def __init__(self, current: typing.Optional[ExecutionParameters] = None): self.stats = current.stats if current else None @@ -107,6 +108,7 @@ def __init__(self, current: typing.Optional[ExecutionParameters] = None): self.raw_output_prefix = current.raw_output_prefix if current else None self.task_id = current.task_id if current else None self.output_metadata_prefix = current.output_metadata_prefix if current else None + self.enable_deck = current.enable_deck if current else False def add_attr(self, key: str, v: typing.Any) -> ExecutionParameters.Builder: self.attrs[key] = v @@ -126,6 +128,7 @@ def build(self) -> ExecutionParameters: raw_output_prefix=self.raw_output_prefix, task_id=self.task_id, output_metadata_prefix=self.output_metadata_prefix, + enable_deck=self.enable_deck, **self.attrs, ) @@ -147,6 +150,11 @@ def with_task_sandbox(self) -> Builder: b.working_dir = task_sandbox_dir return b + def with_enable_deck(self, enable_deck: bool) -> Builder: + b = self.new_builder(self) + b.enable_deck = enable_deck + return b + def builder(self) -> Builder: return ExecutionParameters.Builder(current=self) @@ -162,6 +170,7 @@ def __init__( checkpoint=None, decks=None, task_id: typing.Optional[_identifier.Identifier] = None, + enable_deck: bool = False, **kwargs, ): """ @@ -190,6 +199,7 @@ def __init__( self._decks = decks self._task_id = task_id self._timeline_deck = None + self._enable_deck = enable_deck @property def stats(self) -> taggable.TaggableStats: @@ -298,6 +308,13 @@ def timeline_deck(self) -> "TimeLineDeck": # type: ignore self._timeline_deck = time_line_deck return time_line_deck + @property + def 
enable_deck(self) -> bool: + """ + Returns whether deck is enabled or not + """ + return self._enable_deck + def __getattr__(self, attr_name: str) -> typing.Any: """ This houses certain task specific context. For example in Spark, it houses the SparkSession, etc diff --git a/flytekit/deck/deck.py b/flytekit/deck/deck.py index 025306d47b..c8f9fd6644 100644 --- a/flytekit/deck/deck.py +++ b/flytekit/deck/deck.py @@ -41,10 +41,6 @@ class Deck: scatter plots or Markdown text. In addition, users can create new decks to render their data with custom renderers. - .. warning:: - - This feature is in beta. - .. code-block:: python iris_df = px.data.iris() @@ -86,6 +82,19 @@ def name(self) -> str: def html(self) -> str: return self._html + @staticmethod + def publish(): + params = FlyteContextManager.current_context().user_space_params + task_name = params.task_id.name + + if not params.enable_deck: + logger.warning( + f"Attempted to call publish() in task '{task_name}', but Flyte decks will not be generated because enable_deck is currently set to False." 
+ ) + return + + _output_deck(task_name=task_name, new_user_params=params) + class TimeLineDeck(Deck): """ @@ -148,7 +157,8 @@ def generate_time_table(data: dict) -> str: def _get_deck( - new_user_params: ExecutionParameters, ignore_jupyter: bool = False + new_user_params: ExecutionParameters, + ignore_jupyter: bool = False, ) -> typing.Union[str, "IPython.core.display.HTML"]: # type:ignore """ Get flyte deck html string @@ -176,11 +186,12 @@ def _get_deck( def _output_deck(task_name: str, new_user_params: ExecutionParameters): ctx = FlyteContext.current_context() + local_dir = ctx.file_access.get_random_local_directory() local_path = f"{local_dir}{os.sep}{DECK_FILE_NAME}" try: with open(local_path, "w", encoding="utf-8") as f: - f.write(_get_deck(new_user_params, ignore_jupyter=True)) + f.write(_get_deck(new_user_params=new_user_params, ignore_jupyter=True)) logger.info(f"{task_name} task creates flyte deck html to file://{local_path}") if ctx.execution_state.mode == ExecutionState.Mode.TASK_EXECUTION: fs = ctx.file_access.get_filesystem_for_path(new_user_params.output_metadata_prefix) @@ -197,6 +208,7 @@ def _output_deck(task_name: str, new_user_params: ExecutionParameters): def get_deck_template() -> Template: root = os.path.dirname(os.path.abspath(__file__)) templates_dir = os.path.join(root, "html", "template.html") + with open(templates_dir, "r") as f: template_content = f.read() return Template(template_content) diff --git a/flytekit/models/task.py b/flytekit/models/task.py index 9693390458..fb88962f65 100644 --- a/flytekit/models/task.py +++ b/flytekit/models/task.py @@ -8,6 +8,7 @@ from flyteidl.core import tasks_pb2 as _core_task from google.protobuf import json_format as _json_format from google.protobuf import struct_pb2 as _struct +from google.protobuf.wrappers_pb2 import BoolValue from kubernetes.client import ApiClient from flytekit.models import common as _common @@ -184,6 +185,7 @@ def __init__( pod_template_name, cache_ignore_input_vars, is_eager: 
bool = False, + generates_deck: bool = False, ): """ Information needed at runtime to determine behavior such as whether or not outputs are discoverable, timeouts, @@ -203,6 +205,7 @@ def __init__( receive deprecation warnings. :param bool cache_serializable: Whether or not caching operations are executed in serial. This means only a single instance over identical inputs is executed, other concurrent executions wait for the cached results. + :param bool generates_deck: Whether the task will generate a Deck URI. :param pod_template_name: The name of the existing PodTemplate resource which will be used in this task. :param cache_ignore_input_vars: Input variables that should not be included when calculating hash for cache. :param is_eager: @@ -218,6 +221,7 @@ def __init__( self._pod_template_name = pod_template_name self._cache_ignore_input_vars = cache_ignore_input_vars self._is_eager = is_eager + self._generates_deck = generates_deck @property def is_eager(self): @@ -299,6 +303,14 @@ def pod_template_name(self): """ return self._pod_template_name + @property + def generates_deck(self) -> bool: + """ + Whether the task will generate a Deck. 
+ :rtype: bool + """ + return self._generates_deck + @property def cache_ignore_input_vars(self): """ @@ -322,6 +334,7 @@ def to_flyte_idl(self): pod_template_name=self.pod_template_name, cache_ignore_input_vars=self.cache_ignore_input_vars, is_eager=self.is_eager, + generates_deck=BoolValue(value=self.generates_deck), ) if self.timeout: tm.timeout.FromTimedelta(self.timeout) @@ -345,6 +358,7 @@ def from_flyte_idl(cls, pb2_object: _core_task.TaskMetadata): pod_template_name=pb2_object.pod_template_name, cache_ignore_input_vars=pb2_object.cache_ignore_input_vars, is_eager=pb2_object.is_eager, + generates_deck=pb2_object.generates_deck.value if pb2_object.HasField("generates_deck") else False, ) diff --git a/flytekit/tools/translator.py b/flytekit/tools/translator.py index e74f4c1c71..1d7c57cccc 100644 --- a/flytekit/tools/translator.py +++ b/flytekit/tools/translator.py @@ -185,11 +185,13 @@ def get_serializable_task( entity.reset_command_fn() entity_config = entity.get_config(settings) or {} - extra_config = {} - if hasattr(entity, "task_function") and isinstance(entity.task_function, ClassDecorator): - extra_config = entity.task_function.get_extra_config() + if hasattr(entity, "task_function"): + if isinstance(entity.task_function, ClassDecorator): + extra_config = entity.task_function.get_extra_config() + if entity.enable_deck: + entity.metadata.generates_deck = True merged_config = {**entity_config, **extra_config} diff --git a/pydoclint-errors-baseline.txt b/pydoclint-errors-baseline.txt index e3fd80d99d..3f9865114a 100644 --- a/pydoclint-errors-baseline.txt +++ b/pydoclint-errors-baseline.txt @@ -37,8 +37,6 @@ flytekit/core/base_sql_task.py DOC301: Class `SQLTask`: __init__() should not have a docstring; please combine it with the docstring of the class -------------------- flytekit/core/base_task.py - DOC601: Class `TaskMetadata`: Class docstring contains fewer class attributes than actual class attributes. 
(Please read https://jsh9.github.io/pydoclint/checking_class_attributes.html on how to correctly document class attributes.) - DOC603: Class `TaskMetadata`: Class docstring attributes are different from actual class attributes. (Or could be other formatting issues: https://jsh9.github.io/pydoclint/violation_codes.html#notes-on-doc103 ). Attributes in the class definition but not in the docstring: [cache: bool, cache_ignore_input_vars: Tuple[str, ...], cache_serialize: bool, cache_version: str, deprecated: str, interruptible: Optional[bool], is_eager: bool, pod_template_name: Optional[str], retries: int, timeout: Optional[Union[datetime.timedelta, int]]]. (Please read https://jsh9.github.io/pydoclint/checking_class_attributes.html on how to correctly document class attributes.) DOC301: Class `PythonTask`: __init__() should not have a docstring; please combine it with the docstring of the class DOC001: Function/method `post_execute`: Potential formatting errors in docstring. Error message: Expected a colon in 'rval is returned value from call to execute'. (Note: DOC001 could trigger other unrelated violations under this function/method too. Please fix the docstring formatting first.) DOC101: Method `PythonTask.post_execute`: Docstring contains fewer arguments than in function signature. 
diff --git a/tests/flytekit/unit/deck/test_deck.py b/tests/flytekit/unit/deck/test_deck.py index ce07317a94..d91850e925 100644 --- a/tests/flytekit/unit/deck/test_deck.py +++ b/tests/flytekit/unit/deck/test_deck.py @@ -7,6 +7,7 @@ import flytekit from flytekit import Deck, FlyteContextManager, task + from flytekit.deck import DeckField, MarkdownRenderer, SourceCodeRenderer, TopFrameRenderer from flytekit.deck.deck import _output_deck from flytekit.deck.renderer import PythonDependencyRenderer @@ -258,3 +259,43 @@ def test_python_dependency_renderer(): # Assert that the button of copy assert 'button onclick="copyTable()"' in result + +def test_enable_deck_in_task(): + @task(enable_deck=True) + def t1(): + ctx = FlyteContextManager.current_context() + assert ctx.user_space_params.enable_deck is True + return + + ctx = FlyteContextManager.current_context() + assert ctx.user_space_params.enable_deck is False + + t1() + + ctx = FlyteContextManager.current_context() + assert ctx.user_space_params.enable_deck is False + +def test_disable_deck_in_task(): + @task(disable_deck=True) + def t1(): + ctx = FlyteContextManager.current_context() + assert ctx.user_space_params.enable_deck is False + return + + ctx = FlyteContextManager.current_context() + assert ctx.user_space_params.enable_deck is False + t1() + ctx = FlyteContextManager.current_context() + assert ctx.user_space_params.enable_deck is False + + @task + def t2(): + ctx = FlyteContextManager.current_context() + assert ctx.user_space_params.enable_deck is False + return + + ctx = FlyteContextManager.current_context() + assert ctx.user_space_params.enable_deck is False + t2() + ctx = FlyteContextManager.current_context() + assert ctx.user_space_params.enable_deck is False diff --git a/tests/flytekit/unit/test_translator.py b/tests/flytekit/unit/test_translator.py index 24f2c14131..da90ff9079 100644 --- a/tests/flytekit/unit/test_translator.py +++ b/tests/flytekit/unit/test_translator.py @@ -9,9 +9,11 @@ from
flytekit.core.reference_entity import ReferenceSpec, ReferenceTemplate from flytekit.core.task import ReferenceTask, task from flytekit.core.workflow import ReferenceWorkflow, workflow +from flytekit.deck import Deck from flytekit.models.core import identifier as identifier_models from flytekit.models.task import Resources as resource_model from flytekit.tools.translator import get_serializable, Options +import pytest default_img = Image(name="default", fqn="test", tag="tag") serialization_settings = flytekit.configuration.SerializationSettings( @@ -93,14 +95,29 @@ def t1(a: int) -> typing.NamedTuple("OutputsBC", t1_int_output=int, c=str): def t2(a: str, b: str) -> str: return b + a - ssettings = ( + settings = ( serialization_settings.new_builder() .with_fast_serialization_settings(FastSerializationSettings(enabled=True)) .build() ) - task_spec = get_serializable(OrderedDict(), ssettings, t1) + task_spec = get_serializable(OrderedDict(), settings, t1) assert "pyflyte-fast-execute" in task_spec.template.container.args +@pytest.mark.parametrize('enable_deck,expected', [(True, True), (False, False)]) +def test_deck_settings(enable_deck, expected): + @task(enable_deck=enable_deck) + def t_deck(): + if enable_deck: + Deck.publish() + + settings = ( + serialization_settings.new_builder() + .with_fast_serialization_settings(FastSerializationSettings(enabled=True)) + .build() + ) + task_spec = get_serializable(OrderedDict(), settings, t_deck) + assert task_spec.template.metadata.generates_deck == expected + def test_container(): @task