Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 16 additions & 4 deletions flytekit/core/container_task.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import datetime
import os
import typing
from enum import Enum
Expand Down Expand Up @@ -61,6 +62,7 @@ def __init__(
pod_template_name: Optional[str] = None,
local_logs: bool = False,
resources: Optional[Resources] = None,
timeout: Optional["datetime.timedelta"] = None,
**kwargs,
):
sec_ctx = None
Expand All @@ -74,6 +76,9 @@ def __init__(
metadata = metadata or TaskMetadata()
metadata.pod_template_name = pod_template_name

if timeout is not None:
metadata.timeout = timeout

super().__init__(
task_type="raw-container",
name=name,
Expand Down Expand Up @@ -103,6 +108,7 @@ def __init__(
)
self.pod_template = pod_template
self.local_logs = local_logs
self._timeout = timeout

@property
def resources(self) -> ResourceSpec:
Expand Down Expand Up @@ -279,14 +285,16 @@ def execute(self, **kwargs) -> LiteralMap:
container = client.containers.run(
self._image, command=commands, remove=True, volumes=volume_bindings, detach=True
)
# Wait for the container to finish the task
# TODO: Add a 'timeout' parameter to control the max wait time for the container to finish the task.

timeout_seconds = None
if self._timeout is not None:
timeout_seconds = self._timeout.total_seconds()

if self.local_logs:
for log in container.logs(stream=True):
print(f"[Local Container] {log.strip()}")

container.wait()
container.wait(timeout=timeout_seconds)

output_dict = self._get_output_dict(output_directory)
outputs_literal_map = TypeEngine.dict_to_literal_map(ctx, output_dict)
Expand Down Expand Up @@ -330,8 +338,12 @@ def _get_container(self, settings: SerializationSettings) -> _task_model.Contain
def get_k8s_pod(self, settings: SerializationSettings) -> _task_model.K8sPod:
if self.pod_template is None:
return None
pod_spec = _serialize_pod_spec(self.pod_template, self._get_container(settings), settings)
if self._timeout is not None:
timeout_seconds = int(self._timeout.total_seconds())
pod_spec["activeDeadlineSeconds"] = timeout_seconds
return _task_model.K8sPod(
pod_spec=_serialize_pod_spec(self.pod_template, self._get_container(settings), settings),
pod_spec=pod_spec,
metadata=_task_model.K8sObjectMetadata(
labels=self.pod_template.labels,
annotations=self.pod_template.annotations,
Expand Down
115 changes: 115 additions & 0 deletions tests/flytekit/unit/core/test_container_task.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
import os
import sys
import time
import docker

from collections import OrderedDict
from typing import Tuple
from datetime import timedelta

import pytest
from kubernetes.client.models import (
Expand Down Expand Up @@ -238,3 +242,114 @@ def test_container_task_image_spec(mock_image_spec_builder):
pod = ct.get_k8s_pod(default_serialization_settings)
assert pod.pod_spec["containers"][0]["image"] == image_spec_1.image_name()
assert pod.pod_spec["containers"][1]["image"] == image_spec_2.image_name()

@pytest.mark.skipif(
sys.platform in ["darwin", "win32"],
reason="Skip if running on windows or macos due to CI Docker environment setup failure",
)
def test_container_task_timeout():
ct_with_timedelta = ContainerTask(
name="timedelta-timeout-test",
image="busybox",
command=["sleep", "100"],
timeout=timedelta(seconds=1),
)

with pytest.raises((docker.errors.APIError, Exception)):
ct_with_timedelta.execute()

@pytest.mark.skipif(
sys.platform in ["darwin", "win32"],
reason="Skip if running on windows or macos due to CI Docker environment setup failure",
)
def test_container_task_timeout_k8s_serialization():

ps = V1PodSpec(
containers=[], tolerations=[V1Toleration(effect="NoSchedule", key="nvidia.com/gpu", operator="Exists")]
)
pt = PodTemplate(pod_spec=ps, labels={"test": "timeout"})

default_image = Image(name="default", fqn="docker.io/xyz", tag="some-git-hash")
default_image_config = ImageConfig(default_image=default_image)
default_serialization_settings = SerializationSettings(
project="p", domain="d", version="v", image_config=default_image_config
)

ct_timedelta = ContainerTask(
name="timeout-k8s-timedelta-test",
image="busybox",
command=["echo", "hello"],
pod_template=pt,
timeout=timedelta(minutes=2),
)

k8s_pod_timedelta = ct_timedelta.get_k8s_pod(default_serialization_settings)
assert k8s_pod_timedelta.pod_spec["activeDeadlineSeconds"] == 120


@pytest.mark.skipif(
sys.platform in ["darwin", "win32"],
reason="Skip if running on windows or macos due to CI Docker environment setup failure",
)
def test_container_task_timeout_in_metadata():
from flytekit.core.base_task import TaskMetadata

ct_with_timedelta = ContainerTask(
name="timeout-metadata-test",
image="busybox",
command=["echo", "hello"],
timeout=timedelta(minutes=5),
)

assert ct_with_timedelta.metadata.timeout == timedelta(minutes=5)

# Test with custom metadata - timeout should be set in the provided metadata
custom_metadata = TaskMetadata(retries=3)
ct_with_custom_metadata = ContainerTask(
name="custom-metadata-timeout-test",
image="busybox",
command=["echo", "hello"],
metadata=custom_metadata,
timeout=timedelta(seconds=30),
)

# Verify timeout is set in the custom metadata and retries are preserved
assert ct_with_custom_metadata.metadata.timeout == timedelta(seconds=30)
assert ct_with_custom_metadata.metadata.retries == 3

ct_without_timeout = ContainerTask(
name="no-timeout-test",
image="busybox",
command=["echo", "hello"]
)

assert ct_without_timeout.metadata.timeout is None


def test_container_task_timeout_serialization():
ps = V1PodSpec(
containers=[], tolerations=[V1Toleration(effect="NoSchedule", key="nvidia.com/gpu", operator="Exists")]
)
pt = PodTemplate(pod_spec=ps, labels={"test": "timeout"})

default_image = Image(name="default", fqn="docker.io/xyz", tag="some-git-hash")
default_image_config = ImageConfig(default_image=default_image)
default_serialization_settings = SerializationSettings(
project="p", domain="d", version="v", image_config=default_image_config
)

ct_with_timeout = ContainerTask(
name="timeout-serialization-test",
image="busybox",
command=["echo", "hello"],
pod_template=pt,
timeout=timedelta(minutes=10),
)

from flytekit.tools.translator import get_serializable_task
from collections import OrderedDict

serialized_task = get_serializable_task(OrderedDict(), default_serialization_settings, ct_with_timeout)

k8s_pod = ct_with_timeout.get_k8s_pod(default_serialization_settings)
assert k8s_pod.pod_spec["activeDeadlineSeconds"] == 600 # 10 minutes in seconds
Loading