Skip to content

Commit

Permalink
Ensure that the Task SDK regularly sends heartbeats for running tasks
Browse files Browse the repository at this point in the history
There is more nuance and edge cases to support, but this is the crux of the
behaviour we want.

This fixes the payload to be what the server expects, and fixes the URL suffix
to match latest changes too
  • Loading branch information
ashb committed Nov 18, 2024
1 parent 347a83a commit e016cb7
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 6 deletions.
10 changes: 6 additions & 4 deletions task_sdk/src/airflow/sdk/api/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
ConnectionResponse,
TerminalTIState,
TIEnterRunningPayload,
TIHeartbeatInfo,
TITerminalStatePayload,
ValidationError as RemoteValidationError,
)
Expand Down Expand Up @@ -109,16 +110,17 @@ def start(self, id: uuid.UUID, pid: int, when: datetime):
"""Tell the API server that this TI has started running."""
body = TIEnterRunningPayload(pid=pid, hostname=get_hostname(), unixname=getuser(), start_date=when)

self.client.patch(f"task-instance/{id}/state", content=body.model_dump_json())
self.client.patch(f"task-instances/{id}/state", content=body.model_dump_json())

def finish(self, id: uuid.UUID, state: TerminalTIState, when: datetime):
"""Tell the API server that this TI has reached a terminal state."""
body = TITerminalStatePayload(end_date=when, state=TerminalTIState(state))

self.client.patch(f"task-instance/{id}/state", content=body.model_dump_json())
self.client.patch(f"task-instances/{id}/state", content=body.model_dump_json())

def heartbeat(self, id: uuid.UUID):
self.client.put(f"task-instance/{id}/heartbeat")
def heartbeat(self, id: uuid.UUID, pid: int):
body = TIHeartbeatInfo(pid=pid, hostname=get_hostname())
self.client.put(f"task-instances/{id}/heartbeat", content=body.model_dump_json())


class ConnectionOperations:
Expand Down
4 changes: 2 additions & 2 deletions task_sdk/src/airflow/sdk/execution_time/supervisor.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ def _reopen_std_io_handles(child_stdin, child_stdout, child_stderr):
sys.stderr = sys.__stderr__

# Ensure that sys.stdout et al (and the underlying filehandles for C libraries etc) are connected to the
# pipes form the supervisor
# pipes from the supervisor

for handle_name, sock, mode, close in (
("stdin", child_stdin, "r", True),
Expand Down Expand Up @@ -403,7 +403,7 @@ def wait(self) -> int:
continue

try:
self.client.task_instances.heartbeat(self.ti_id)
self.client.task_instances.heartbeat(self.ti_id, pid=self._process.pid)
self._last_heartbeat = time.monotonic()
except Exception:
log.warning("Couldn't heartbeat", exc_info=True)
Expand Down
38 changes: 38 additions & 0 deletions task_sdk/tests/execution_time/test_supervisor.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,10 @@
import os
import signal
import sys
from time import sleep
from typing import TYPE_CHECKING
from unittest.mock import MagicMock
from uuid import UUID

import pytest
import structlog
Expand All @@ -33,6 +36,9 @@
from airflow.sdk.execution_time.supervisor import WatchedSubprocess
from airflow.utils import timezone as tz

if TYPE_CHECKING:
import kgb


def lineno():
"""Returns the current line number in our program."""
Expand Down Expand Up @@ -153,3 +159,35 @@ def subprocess_main():
rc = proc.wait()

assert rc == -9

def test_regular_heartbeat(self, spy_agency: kgb.SpyAgency, monkeypatch):
"""Test that the WatchedSubprocess class regularly sends heartbeat requests, up to a certain frequency"""
import airflow.sdk.execution_time.supervisor

monkeypatch.setattr(airflow.sdk.execution_time.supervisor, "FASTEST_HEARTBEAT_INTERVAL", 0.1)

def subprocess_main():
sys.stdin.readline()

for _ in range(5):
print("output", flush=True)
sleep(0.05)

id = UUID("4d828a62-a417-4936-a7a6-2b3fabacecab")
spy = spy_agency.spy_on(sdk_client.TaskInstanceOperations.heartbeat)
proc = WatchedSubprocess.start(
path=os.devnull,
ti=TaskInstance(
id=id,
task_id="b",
dag_id="c",
run_id="d",
try_number=1,
),
client=sdk_client.Client(base_url="", dry_run=True, token=""),
target=subprocess_main,
)
assert proc.wait() == 0
assert spy.called_with(id, pid=proc.pid) # noqa: PGH005
# The exact number we get will depend on timing behaviour, so be a little lenient
assert 2 <= len(spy.calls) <= 4

0 comments on commit e016cb7

Please sign in to comment.