From 0474930a4a5950beecdcc967c8d6510f51917196 Mon Sep 17 00:00:00 2001 From: raphaelauv Date: Thu, 25 Jul 2024 18:25:43 +0200 Subject: [PATCH] introduce `fail_policy` --- airflow/decorators/__init__.pyi | 5 +- airflow/example_dags/example_sensors.py | 22 +++-- airflow/exceptions.py | 4 + airflow/sensors/base.py | 70 ++++++++-------- airflow/sensors/external_task.py | 47 ++++------- airflow/triggers/external_task.py | 9 ++- .../src/airflow/providers/ftp/sensors/ftp.py | 3 +- .../airflow/providers/sftp/sensors/sftp.py | 4 +- .../providers/standard/sensors/filesystem.py | 4 +- providers/tests/ftp/sensors/test_ftp.py | 19 +++-- providers/tests/http/sensors/test_http.py | 6 +- providers/tests/sftp/sensors/test_sftp.py | 4 +- tests/decorators/test_sensor.py | 15 +++- tests/sensors/test_base.py | 53 +++++++----- tests/sensors/test_external_task_sensor.py | 80 ++++++++++++------- tests/triggers/test_external_task.py | 2 +- 16 files changed, 200 insertions(+), 147 deletions(-) diff --git a/airflow/decorators/__init__.pyi b/airflow/decorators/__init__.pyi index 7dd887431c6f5..aff3e0762d440 100644 --- a/airflow/decorators/__init__.pyi +++ b/airflow/decorators/__init__.pyi @@ -40,6 +40,7 @@ from airflow.decorators.short_circuit import short_circuit_task from airflow.decorators.task_group import task_group from airflow.models.dag import dag from airflow.providers.cncf.kubernetes.secret import Secret +from airflow.sensors.base import FailPolicy from airflow.typing_compat import Literal # Please keep this in sync with __init__.py's __all__. @@ -708,7 +709,7 @@ class TaskDecoratorCollection: *, poke_interval: float = ..., timeout: float = ..., - soft_fail: bool = False, + fail_policy: FailPolicy = ..., mode: str = ..., exponential_backoff: bool = False, max_wait: timedelta | float | None = None, @@ -720,7 +721,7 @@ class TaskDecoratorCollection: :param poke_interval: Time in seconds that the job should wait in between each try :param timeout: Time, in seconds, before the task times out and fails. - :param soft_fail: Set to true to mark the task as SKIPPED on failure + :param fail_policy: Defines the rule by which the sensor skips itself. Options are: ``{ none | skip_on_any_error | skip_on_timeout | ignore_error }``, default is ``none``. :param mode: How the sensor operates. Options are: ``{ poke | reschedule }``, default is ``poke``. 
When set to ``poke`` the sensor is taking up a worker slot for its diff --git a/airflow/example_dags/example_sensors.py b/airflow/example_dags/example_sensors.py index 39d7b8d29635f..6ca71527112e6 100644 --- a/airflow/example_dags/example_sensors.py +++ b/airflow/example_dags/example_sensors.py @@ -23,6 +23,7 @@ from airflow.models.dag import DAG from airflow.providers.standard.operators.bash import BashOperator from airflow.providers.standard.sensors.bash import BashSensor from airflow.providers.standard.sensors.filesystem import FileSensor from airflow.providers.standard.sensors.python import PythonSensor +from airflow.sensors.base import FailPolicy @@ -68,7 +69,7 @@ def failure_callable(): t2 = TimeSensor( task_id="timeout_after_second_date_in_the_future", timeout=1, - soft_fail=True, + fail_policy=FailPolicy.SKIP_ON_TIMEOUT, target_time=(datetime.datetime.now(tz=datetime.timezone.utc) + datetime.timedelta(hours=1)).time(), ) # [END example_time_sensors] @@ -81,7 +82,7 @@ def failure_callable(): t2a = TimeSensorAsync( task_id="timeout_after_second_date_in_the_future_async", timeout=1, - soft_fail=True, + fail_policy=FailPolicy.SKIP_ON_TIMEOUT, target_time=(datetime.datetime.now(tz=datetime.timezone.utc) + datetime.timedelta(hours=1)).time(), ) # [END example_time_sensors_async] @@ -89,7 +90,12 @@ def failure_callable(): # [START example_bash_sensors] t3 = BashSensor(task_id="Sensor_succeeds", bash_command="exit 0") - t4 = BashSensor(task_id="Sensor_fails_after_3_seconds", timeout=3, soft_fail=True, bash_command="exit 1") + t4 = BashSensor( + task_id="Sensor_fails_after_3_seconds", + timeout=3, + fail_policy=FailPolicy.SKIP_ON_TIMEOUT, + bash_command="exit 1", + ) # [END example_bash_sensors] t5 = BashOperator(task_id="remove_file", bash_command="rm -rf /tmp/temporary_file_for_testing") @@ -112,13 +118,19 @@ def failure_callable(): t9 = PythonSensor(task_id="success_sensor_python", python_callable=success_callable) t10 = PythonSensor( - task_id="failure_timeout_sensor_python", timeout=3, soft_fail=True, python_callable=failure_callable + task_id="failure_timeout_sensor_python", + timeout=3, + fail_policy=FailPolicy.SKIP_ON_TIMEOUT, + python_callable=failure_callable, ) # [END example_python_sensors] # [START example_day_of_week_sensor] t11 = DayOfWeekSensor( - task_id="week_day_sensor_failing_on_timeout", timeout=3, soft_fail=True, week_day=WeekDay.MONDAY + task_id="week_day_sensor_failing_on_timeout", + timeout=3, + fail_policy=FailPolicy.SKIP_ON_TIMEOUT, + week_day=WeekDay.MONDAY, ) # [END example_day_of_week_sensor] diff --git a/airflow/exceptions.py b/airflow/exceptions.py index 3b07b9a6fda96..3647f8b0720e3 100644 --- a/airflow/exceptions.py +++ b/airflow/exceptions.py @@ -68,6 +68,10 @@ class AirflowSensorTimeout(AirflowException): """Raise when there is a timeout on sensor polling.""" +class AirflowPokeFailException(AirflowException): + """Raise when a sensor must not try to poke again.""" + + class AirflowRescheduleException(AirflowException): """ Raise when the task should be re-scheduled at a later time. 
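Reviewer note (not part of the patch): a minimal sketch of how a custom sensor could use the new exception to abort immediately instead of re-poking until timeout. The RecordSensor class and its _fetch_state helper are hypothetical, invented for illustration only.

    from airflow.exceptions import AirflowPokeFailException
    from airflow.sensors.base import BaseSensorOperator


    def _fetch_state() -> str:
        # Hypothetical helper standing in for a real availability check.
        return "corrupt"


    class RecordSensor(BaseSensorOperator):
        def poke(self, context) -> bool:
            state = _fetch_state()
            if state == "missing":
                return False  # not there yet; poke again on the next interval
            if state == "corrupt":
                # Unrecoverable condition: stop poking right away rather than
                # waiting for the sensor timeout to expire.
                raise AirflowPokeFailException("Record exists but is corrupt.")
            return True
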
diff --git a/airflow/sensors/base.py b/airflow/sensors/base.py index 3e5a8565e50c0..d88c21f396c55 100644 --- a/airflow/sensors/base.py +++ b/airflow/sensors/base.py @@ -18,6 +18,7 @@ from __future__ import annotations import datetime +import enum import functools import hashlib import time @@ -32,7 +33,7 @@ from airflow.configuration import conf from airflow.exceptions import ( AirflowException, - AirflowFailException, + AirflowPokeFailException, AirflowRescheduleException, AirflowSensorTimeout, AirflowSkipException, @@ -109,6 +110,24 @@ def _orig_start_date( ) +class FailPolicy(str, enum.Enum): + """Enumeration of sensor fail policies.""" + + # If the poke method raises an exception, the sensor is not skipped; the exception propagates and the task fails. + NONE = "none" + + # If the poke method raises any exception, the sensor is skipped. + SKIP_ON_ANY_ERROR = "skip_on_any_error" + + # If the poke method raises AirflowSensorTimeout, AirflowTaskTimeout, AirflowPokeFailException + # or AirflowSkipException, the sensor is skipped. + SKIP_ON_TIMEOUT = "skip_on_timeout" + + # If the poke method raises an exception other than AirflowSensorTimeout, AirflowTaskTimeout, + # AirflowSkipException or AirflowPokeFailException, the sensor logs the error and re-pokes until timeout. + IGNORE_ERROR = "ignore_error" + + class BaseSensorOperator(BaseOperator, SkipMixin): """ Sensor operators are derived from this class and inherit these attributes. @@ -116,8 +135,6 @@ class BaseSensorOperator(BaseOperator, SkipMixin): Sensor operators keep executing at a time interval and succeed when a criteria is met and fail if and when they time out. - :param soft_fail: Set to true to mark the task as SKIPPED on failure. - Mutually exclusive with never_fail. :param poke_interval: Time that the job should wait in between each try. Can be ``timedelta`` or ``float`` seconds. :param timeout: Time elapsed before the task times out and fails. @@ -145,13 +162,10 @@ class BaseSensorOperator(BaseOperator, SkipMixin): :param exponential_backoff: allow progressive longer waits between pokes by using exponential backoff algorithm :param max_wait: maximum wait interval between pokes, can be ``timedelta`` or ``float`` seconds - :param silent_fail: If true, and poke method raises an exception different from - AirflowSensorTimeout, AirflowTaskTimeout, AirflowSkipException - and AirflowFailException, the sensor will log the error and continue - its execution. Otherwise, the sensor task fails, and it can be retried - based on the provided `retries` parameter. - :param never_fail: If true, and poke method raises an exception, sensor will be skipped. - Mutually exclusive with soft_fail. + :param fail_policy: defines the rule by which the sensor skips itself. Options are: + ``{ none | skip_on_any_error | skip_on_timeout | ignore_error }``, + default is ``none``. 
Options can be set as a string or + by using the constants defined in the enum ``airflow.sensors.base.FailPolicy`` """ ui_color: str = "#e6f1f2" @@ -166,26 +180,19 @@ def __init__( *, poke_interval: timedelta | float = 60, timeout: timedelta | float = conf.getfloat("sensors", "default_timeout"), - soft_fail: bool = False, mode: str = "poke", exponential_backoff: bool = False, max_wait: timedelta | float | None = None, - silent_fail: bool = False, - never_fail: bool = False, + fail_policy: str = FailPolicy.NONE, **kwargs, ) -> None: super().__init__(**kwargs) self.poke_interval = self._coerce_poke_interval(poke_interval).total_seconds() - self.soft_fail = soft_fail self.timeout = self._coerce_timeout(timeout).total_seconds() self.mode = mode self.exponential_backoff = exponential_backoff self.max_wait = self._coerce_max_wait(max_wait) - if soft_fail is True and never_fail is True: - raise ValueError("soft_fail and never_fail are mutually exclusive, you can not provide both.") - - self.silent_fail = silent_fail - self.never_fail = never_fail + self.fail_policy = fail_policy self._validate_input_values() @staticmethod @@ -282,21 +289,20 @@ def run_duration() -> float: except ( AirflowSensorTimeout, AirflowTaskTimeout, - AirflowFailException, + AirflowPokeFailException, + AirflowSkipException, ) as e: - if self.soft_fail: - raise AirflowSkipException("Skipping due to soft_fail is set to True.") from e - elif self.never_fail: - raise AirflowSkipException("Skipping due to never_fail is set to True.") from e - raise e - except AirflowSkipException as e: + if self.fail_policy == FailPolicy.SKIP_ON_TIMEOUT: + raise AirflowSkipException("Skipping due to fail_policy set to SKIP_ON_TIMEOUT.") from e + elif self.fail_policy == FailPolicy.SKIP_ON_ANY_ERROR: + raise AirflowSkipException("Skipping due to fail_policy set to SKIP_ON_ANY_ERROR.") from e raise e except Exception as e: - if self.silent_fail: + if self.fail_policy == FailPolicy.IGNORE_ERROR: self.log.error("Sensor poke failed: \n %s", traceback.format_exc()) poke_return = False - elif self.never_fail: - raise AirflowSkipException("Skipping due to never_fail is set to True.") from e + elif self.fail_policy == FailPolicy.SKIP_ON_ANY_ERROR: + raise AirflowSkipException("Skipping due to fail_policy set to SKIP_ON_ANY_ERROR.") from e else: raise e @@ -306,13 +312,13 @@ def run_duration() -> float: break if run_duration() > self.timeout: - # If sensor is in soft fail mode but times out raise AirflowSkipException. + # If the sensor is in SKIP_ON_TIMEOUT mode and times out, it raises AirflowSkipException. message = ( f"Sensor has timed out; run duration of {run_duration()} seconds exceeds " f"the specified timeout of {self.timeout}." 
) - if self.soft_fail: + if self.fail_policy == FailPolicy.SKIP_ON_TIMEOUT: raise AirflowSkipException(message) else: raise AirflowSensorTimeout(message) @@ -335,7 +341,7 @@ def resume_execution(self, next_method: str, next_kwargs: dict[str, Any] | None, try: return super().resume_execution(next_method, next_kwargs, context) except (AirflowException, TaskDeferralError) as e: - if self.soft_fail: + if self.fail_policy == FailPolicy.SKIP_ON_ANY_ERROR: raise AirflowSkipException(str(e)) from e raise diff --git a/airflow/sensors/external_task.py b/airflow/sensors/external_task.py index d5504ab2e9fb9..2517dea199c77 100644 --- a/airflow/sensors/external_task.py +++ b/airflow/sensors/external_task.py @@ -23,7 +23,7 @@ from typing import TYPE_CHECKING, Any, Callable, ClassVar, Collection, Iterable from airflow.configuration import conf -from airflow.exceptions import AirflowException, AirflowSkipException +from airflow.exceptions import AirflowPokeFailException, AirflowSkipException from airflow.models.baseoperatorlink import BaseOperatorLink from airflow.models.dag import DagModel from airflow.models.dagbag import DagBag @@ -175,7 +175,7 @@ def __init__( total_states = set(self.allowed_states + self.skipped_states + self.failed_states) if len(total_states) != len(self.allowed_states) + len(self.skipped_states) + len(self.failed_states): - raise AirflowException( + raise ValueError( "Duplicate values provided across allowed_states, skipped_states and failed_states." ) @@ -286,32 +286,18 @@ def poke(self, context: Context, session: Session = NEW_SESSION) -> bool: # Fail if anything in the list has failed. if count_failed > 0: if self.external_task_ids: - if self.soft_fail: - raise AirflowSkipException( - f"Some of the external tasks {self.external_task_ids} " - f"in DAG {self.external_dag_id} failed. Skipping due to soft_fail." - ) - raise AirflowException( + raise AirflowPokeFailException( f"Some of the external tasks {self.external_task_ids} " f"in DAG {self.external_dag_id} failed." ) elif self.external_task_group_id: - if self.soft_fail: - raise AirflowSkipException( - f"The external task_group '{self.external_task_group_id}' " - f"in DAG '{self.external_dag_id}' failed. Skipping due to soft_fail." - ) - raise AirflowException( + raise AirflowPokeFailException( f"The external task_group '{self.external_task_group_id}' " f"in DAG '{self.external_dag_id}' failed." ) else: - if self.soft_fail: - raise AirflowSkipException( - f"The external DAG {self.external_dag_id} failed. Skipping due to soft_fail." 
- ) - raise AirflowException(f"The external DAG {self.external_dag_id} failed.") + raise AirflowPokeFailException(f"The external DAG {self.external_dag_id} failed.") count_skipped = -1 if self.skipped_states: @@ -354,7 +340,7 @@ def execute(self, context: Context) -> None: logical_dates=self._get_dttm_filter(context), allowed_states=self.allowed_states, poke_interval=self.poll_interval, - soft_fail=self.soft_fail, + fail_policy=self.fail_policy, ), method_name="execute_complete", ) @@ -364,30 +350,27 @@ def execute_complete(self, context, event=None): if event["status"] == "success": self.log.info("External tasks %s has executed successfully.", self.external_task_ids) elif event["status"] == "skipped": - raise AirflowSkipException("External job has skipped skipping.") + raise AirflowPokeFailException("The external job has been skipped.") else: - if self.soft_fail: - raise AirflowSkipException("External job has failed skipping.") - else: - raise AirflowException( - "Error occurred while trying to retrieve task status. Please, check the " - "name of executed task and Dag." - ) + raise AirflowPokeFailException( + "Error occurred while trying to retrieve task status. Please check the " + "name of the executed task and DAG." + ) def _check_for_existence(self, session) -> None: dag_to_wait = DagModel.get_current(self.external_dag_id, session) if not dag_to_wait: - raise AirflowException(f"The external DAG {self.external_dag_id} does not exist.") + raise AirflowPokeFailException(f"The external DAG {self.external_dag_id} does not exist.") if not os.path.exists(correct_maybe_zipped(dag_to_wait.fileloc)): - raise AirflowException(f"The external DAG {self.external_dag_id} was deleted.") + raise AirflowPokeFailException(f"The external DAG {self.external_dag_id} was deleted.") if self.external_task_ids: refreshed_dag_info = DagBag(dag_to_wait.fileloc).get_dag(self.external_dag_id) for external_task_id in self.external_task_ids: if not refreshed_dag_info.has_task(external_task_id): - raise AirflowException( + raise AirflowPokeFailException( f"The external task {external_task_id} in " f"DAG {self.external_dag_id} does not exist." ) @@ -395,7 +378,7 @@ def _check_for_existence(self, session) -> None: if self.external_task_group_id: refreshed_dag_info = DagBag(dag_to_wait.fileloc).get_dag(self.external_dag_id) if not refreshed_dag_info.has_task_group(self.external_task_group_id): - raise AirflowException( + raise AirflowPokeFailException( f"The external task group '{self.external_task_group_id}' in " f"DAG '{self.external_dag_id}' does not exist." ) diff --git a/airflow/triggers/external_task.py b/airflow/triggers/external_task.py index 159a6df909501..0f26fb6142228 100644 --- a/airflow/triggers/external_task.py +++ b/airflow/triggers/external_task.py @@ -23,6 +23,7 @@ from asgiref.sync import sync_to_async from sqlalchemy import func from airflow.models import DagRun +from airflow.sensors.base import FailPolicy from airflow.triggers.base import BaseTrigger, TriggerEvent from airflow.utils.sensor_helper import _get_count @@ -48,7 +49,7 @@ class WorkflowTrigger(BaseTrigger): :param skipped_states: States considered as skipped for external tasks. :param allowed_states: States considered as successful for external tasks. :param poke_interval: The interval (in seconds) for poking the external tasks. - :param soft_fail: If True, the trigger will not fail the entire DAG on external task failure. 
+ """ def __init__( @@ -61,7 +62,7 @@ def __init__( skipped_states: typing.Iterable[str] | None = None, allowed_states: typing.Iterable[str] | None = None, poke_interval: float = 2.0, - soft_fail: bool = False, + fail_policy: str = FailPolicy.NONE, **kwargs, ): self.external_dag_id = external_dag_id @@ -72,7 +73,7 @@ def __init__( self.allowed_states = allowed_states self.logical_dates = logical_dates self.poke_interval = poke_interval - self.soft_fail = soft_fail + self.fail_policy = fail_policy super().__init__(**kwargs) def serialize(self) -> tuple[str, dict[str, Any]]: @@ -88,7 +89,7 @@ def serialize(self) -> tuple[str, dict[str, Any]]: "allowed_states": self.allowed_states, "logical_dates": self.logical_dates, "poke_interval": self.poke_interval, - "soft_fail": self.soft_fail, + "fail_policy": self.fail_policy, }, ) diff --git a/providers/src/airflow/providers/ftp/sensors/ftp.py b/providers/src/airflow/providers/ftp/sensors/ftp.py index 1ab56c56c0609..93bd82eb7b8cb 100644 --- a/providers/src/airflow/providers/ftp/sensors/ftp.py +++ b/providers/src/airflow/providers/ftp/sensors/ftp.py @@ -21,6 +21,7 @@ import re from typing import TYPE_CHECKING, Sequence +from airflow.exceptions import AirflowPokeFailException from airflow.providers.ftp.hooks.ftp import FTPHook, FTPSHook from airflow.sensors.base import BaseSensorOperator @@ -82,7 +83,7 @@ def poke(self, context: Context) -> bool: if (error_code != 550) and ( self.fail_on_transient_errors or (error_code not in self.transient_errors) ): - raise e + raise AirflowPokeFailException from e return False diff --git a/providers/src/airflow/providers/sftp/sensors/sftp.py b/providers/src/airflow/providers/sftp/sensors/sftp.py index f6b076331fe5b..e7b0b08ed0340 100644 --- a/providers/src/airflow/providers/sftp/sensors/sftp.py +++ b/providers/src/airflow/providers/sftp/sensors/sftp.py @@ -26,7 +26,7 @@ from paramiko.sftp import SFTP_NO_SUCH_FILE from airflow.configuration import conf -from airflow.exceptions import AirflowException +from airflow.exceptions import AirflowException, AirflowPokeFailException from airflow.providers.sftp.hooks.sftp import SFTPHook from airflow.providers.sftp.triggers.sftp import SFTPTrigger from airflow.sensors.base import BaseSensorOperator, PokeReturnValue @@ -98,7 +98,7 @@ def poke(self, context: Context) -> PokeReturnValue | bool: self.log.info("Found File %s last modified: %s", actual_file_to_check, mod_time) except OSError as e: if e.errno != SFTP_NO_SUCH_FILE: - raise AirflowException from e + raise AirflowPokeFailException from e continue if self.newer_than: diff --git a/providers/src/airflow/providers/standard/sensors/filesystem.py b/providers/src/airflow/providers/standard/sensors/filesystem.py index 4496f5d6abfa4..73bbc933c4610 100644 --- a/providers/src/airflow/providers/standard/sensors/filesystem.py +++ b/providers/src/airflow/providers/standard/sensors/filesystem.py @@ -24,7 +24,7 @@ from typing import TYPE_CHECKING, Any, Sequence from airflow.configuration import conf -from airflow.exceptions import AirflowException +from airflow.exceptions import AirflowPokeFailException from airflow.providers.standard.hooks.filesystem import FSHook from airflow.sensors.base import BaseSensorOperator from airflow.triggers.base import StartTriggerArgs @@ -134,5 +134,5 @@ def execute(self, context: Context) -> None: def execute_complete(self, context: Context, event: bool | None = None) -> None: if not event: - raise AirflowException("%s task failed as %s not found.", self.task_id, self.filepath) + raise 
AirflowPokeFailException("%s task failed as %s not found.", self.task_id, self.filepath) self.log.info("%s completed successfully as %s found.", self.task_id, self.filepath) diff --git a/providers/tests/ftp/sensors/test_ftp.py b/providers/tests/ftp/sensors/test_ftp.py index 0a71fbd594c8f..0a8620079cd8b 100644 --- a/providers/tests/ftp/sensors/test_ftp.py +++ b/providers/tests/ftp/sensors/test_ftp.py @@ -22,8 +22,10 @@ import pytest +from airflow.exceptions import AirflowPokeFailException, AirflowSkipException from airflow.providers.ftp.hooks.ftp import FTPHook from airflow.providers.ftp.sensors.ftp import FTPSensor +from airflow.sensors.base import FailPolicy class TestFTPSensor: @@ -51,10 +53,10 @@ def test_poke_fails_due_error(self, mock_hook): "530: Login authentication failed" ) - with pytest.raises(error_perm) as ctx: + with pytest.raises(AirflowPokeFailException) as ctx: op.execute(None) - assert "530" in str(ctx.value) + assert "530" in str(ctx.value.__cause__) @mock.patch("airflow.providers.ftp.sensors.ftp.FTPHook", spec=FTPHook) def test_poke_fail_on_transient_error(self, mock_hook): @@ -64,20 +66,25 @@ def test_poke_fail_on_transient_error(self, mock_hook): "434: Host unavailable" ) - with pytest.raises(error_perm) as ctx: + with pytest.raises(AirflowPokeFailException) as ctx: op.execute(None) - assert "434" in str(ctx.value) + assert "434" in str(ctx.value.__cause__) @mock.patch("airflow.providers.ftp.sensors.ftp.FTPHook", spec=FTPHook) def test_poke_fail_on_transient_error_and_skip(self, mock_hook): - op = FTPSensor(path="foobar.json", ftp_conn_id="bob_ftp", task_id="test_task") + op = FTPSensor( + path="foobar.json", + ftp_conn_id="bob_ftp", + task_id="test_task", + fail_policy=FailPolicy.SKIP_ON_TIMEOUT, + ) mock_hook.return_value.__enter__.return_value.get_mod_time.side_effect = error_perm( "434: Host unavailable" ) - with pytest.raises(error_perm): + with pytest.raises(AirflowSkipException): op.execute(None) @mock.patch("airflow.providers.ftp.sensors.ftp.FTPHook", spec=FTPHook) diff --git a/providers/tests/http/sensors/test_http.py b/providers/tests/http/sensors/test_http.py index 78a11e15bb7c1..b2dbd08e3751b 100644 --- a/providers/tests/http/sensors/test_http.py +++ b/providers/tests/http/sensors/test_http.py @@ -23,7 +23,11 @@ import pytest import requests -from airflow.exceptions import AirflowException, AirflowSensorTimeout, TaskDeferred +from airflow.exceptions import ( + AirflowException, + AirflowSensorTimeout, + TaskDeferred, +) from airflow.models.dag import DAG from airflow.providers.http.operators.http import HttpOperator from airflow.providers.http.sensors.http import HttpSensor diff --git a/providers/tests/sftp/sensors/test_sftp.py b/providers/tests/sftp/sensors/test_sftp.py index 4d1be081af16c..0d209169a7083 100644 --- a/providers/tests/sftp/sensors/test_sftp.py +++ b/providers/tests/sftp/sensors/test_sftp.py @@ -25,7 +25,7 @@ from paramiko.sftp import SFTP_FAILURE, SFTP_NO_SUCH_FILE from pendulum import datetime as pendulum_datetime, timezone -from airflow.exceptions import AirflowException +from airflow.exceptions import AirflowPokeFailException from airflow.providers.sftp.sensors.sftp import SFTPSensor from airflow.sensors.base import PokeReturnValue @@ -58,7 +58,7 @@ def test_sftp_failure(self, sftp_hook_mock): sftp_sensor = SFTPSensor(task_id="unit_test", path="/path/to/file/1970-01-01.txt") context = {"ds": "1970-01-01"} - with pytest.raises(AirflowException): + with pytest.raises(AirflowPokeFailException): sftp_sensor.poke(context) def 
test_hook_not_created_during_init(self): diff --git a/tests/decorators/test_sensor.py b/tests/decorators/test_sensor.py index 3b48d13f52607..d985e20b99801 100644 --- a/tests/decorators/test_sensor.py +++ b/tests/decorators/test_sensor.py @@ -24,7 +24,12 @@ from airflow.exceptions import AirflowSensorTimeout from airflow.models import XCom from airflow.sensors.base import PokeReturnValue +from tests.test_utils.compat import ignore_provider_compatibility_error + +with ignore_provider_compatibility_error("2.10.0", __file__): + from airflow.sensors.base import FailPolicy from airflow.utils.state import State +from tests.test_utils.compat import AIRFLOW_V_2_10_PLUS pytestmark = pytest.mark.db_test @@ -144,9 +149,10 @@ def dummy_f(): if ti.task_id == "dummy_f": assert ti.state == State.NONE + @pytest.mark.skipif(not AIRFLOW_V_2_10_PLUS, reason="FailPolicy is only available from Airflow 2.10.0") @pytest.mark.skip_if_database_isolation_mode # Test is broken in db isolation mode - def test_basic_sensor_soft_fail(self, dag_maker): - @task.sensor(timeout=0, soft_fail=True) + def test_basic_sensor_skip_on_timeout(self, dag_maker): + @task.sensor(timeout=0, fail_policy=FailPolicy.SKIP_ON_TIMEOUT) def sensor_f(): return PokeReturnValue(is_done=False, xcom_value="xcom_value") @@ -169,9 +175,10 @@ def dummy_f(): if ti.task_id == "dummy_f": assert ti.state == State.NONE + @pytest.mark.skipif(not AIRFLOW_V_2_10_PLUS, reason="FailPolicy is only available from Airflow 2.10.0") @pytest.mark.skip_if_database_isolation_mode # Test is broken in db isolation mode - def test_basic_sensor_soft_fail_returns_bool(self, dag_maker): - @task.sensor(timeout=0, soft_fail=True) + def test_basic_sensor_skip_on_timeout_returns_bool(self, dag_maker): + @task.sensor(timeout=0, fail_policy=FailPolicy.SKIP_ON_TIMEOUT) def sensor_f(): return False diff --git a/tests/sensors/test_base.py b/tests/sensors/test_base.py index 9f3bd3f8fe84b..59feb04edf345 100644 --- a/tests/sensors/test_base.py +++ b/tests/sensors/test_base.py @@ -27,6 +27,7 @@ from airflow.exceptions import ( AirflowException, AirflowFailException, + AirflowPokeFailException, AirflowRescheduleException, AirflowSensorTimeout, AirflowSkipException, @@ -51,7 +52,7 @@ from airflow.providers.celery.executors.celery_kubernetes_executor import CeleryKubernetesExecutor from airflow.providers.cncf.kubernetes.executors.kubernetes_executor import KubernetesExecutor from airflow.providers.cncf.kubernetes.executors.local_kubernetes_executor import LocalKubernetesExecutor -from airflow.sensors.base import BaseSensorOperator, PokeReturnValue, poke_mode_only +from airflow.sensors.base import BaseSensorOperator, FailPolicy, PokeReturnValue, poke_mode_only from airflow.ti_deps.deps.ready_to_reschedule import ReadyToRescheduleDep from airflow.utils import timezone from airflow.utils.session import create_session @@ -96,7 +97,7 @@ def __init__(self, return_value=False, **kwargs): self.return_value = return_value def execute_complete(self, context, event=None): - raise AirflowException("Should be skipped") + raise AirflowException() class DummySensorWithXcomValue(BaseSensorOperator): @@ -180,8 +181,8 @@ def test_fail(self, make_sensor): if ti.task_id == DUMMY_OP: assert ti.state == State.NONE - def test_soft_fail(self, make_sensor): - sensor, dr = make_sensor(False, soft_fail=True) + def test_skip_on_timeout(self, make_sensor): - sensor, dr = make_sensor(False, soft_fail=True) + sensor, dr = make_sensor(False, fail_policy=FailPolicy.SKIP_ON_TIMEOUT) self._run(sensor) tis = dr.get_task_instances() @@ -196,8 +197,8 @@ 
"exception_cls", (ValueError,), ) - def test_soft_fail_with_exception(self, make_sensor, exception_cls): - sensor, dr = make_sensor(False, soft_fail=True) + def test_skip_on_timeout_with_exception(self, make_sensor, exception_cls): + sensor, dr = make_sensor(False, fail_policy=FailPolicy.SKIP_ON_TIMEOUT) sensor.poke = Mock(side_effect=[exception_cls(None)]) with pytest.raises(ValueError): self._run(sensor) @@ -215,11 +216,11 @@ def test_soft_fail_with_exception(self, make_sensor, exception_cls): ( AirflowSensorTimeout, AirflowTaskTimeout, - AirflowFailException, + AirflowPokeFailException, ), ) - def test_soft_fail_with_skip_exception(self, make_sensor, exception_cls): - sensor, dr = make_sensor(False, soft_fail=True) + def test_skip_on_timeout_with_skip_exception(self, make_sensor, exception_cls): + sensor, dr = make_sensor(False, fail_policy=FailPolicy.SKIP_ON_TIMEOUT) sensor.poke = Mock(side_effect=[exception_cls(None)]) self._run(sensor) @@ -233,10 +234,10 @@ def test_soft_fail_with_skip_exception(self, make_sensor, exception_cls): @pytest.mark.parametrize( "exception_cls", - (AirflowSensorTimeout, AirflowTaskTimeout, AirflowFailException, Exception), + (AirflowSensorTimeout, AirflowTaskTimeout, AirflowFailException, AirflowPokeFailException, Exception), ) - def test_never_fail_with_skip_exception(self, make_sensor, exception_cls): - sensor, dr = make_sensor(False, never_fail=True) + def test_skip_on_any_error_with_skip_exception(self, make_sensor, exception_cls): + sensor, dr = make_sensor(False, fail_policy=FailPolicy.SKIP_ON_ANY_ERROR) sensor.poke = Mock(side_effect=[exception_cls(None)]) self._run(sensor) @@ -248,9 +249,12 @@ def test_never_fail_with_skip_exception(self, make_sensor, exception_cls): if ti.task_id == DUMMY_OP: assert ti.state == State.NONE - def test_soft_fail_with_retries(self, make_sensor): + def test_skip_on_timeout_with_retries(self, make_sensor): sensor, dr = make_sensor( - return_value=False, soft_fail=True, retries=1, retry_delay=timedelta(milliseconds=1) + return_value=False, + fail_policy=FailPolicy.SKIP_ON_TIMEOUT, + retries=1, + retry_delay=timedelta(milliseconds=1), ) # first run times out and task instance is skipped @@ -358,9 +362,13 @@ def _get_tis(): assert sensor_ti.state == State.FAILED assert dummy_ti.state == State.NONE - def test_soft_fail_with_reschedule(self, make_sensor, time_machine, session): + def test_skip_on_timeout_with_reschedule(self, make_sensor, time_machine, session): sensor, dr = make_sensor( - return_value=False, poke_interval=10, timeout=5, soft_fail=True, mode="reschedule" + return_value=False, + poke_interval=10, + timeout=5, + fail_policy=FailPolicy.SKIP_ON_TIMEOUT, + mode="reschedule", ) def _get_tis(): @@ -912,7 +920,7 @@ def test_reschedule_and_retry_timeout_and_silent_fail(self, make_sensor, time_ma retries=2, retry_delay=timedelta(seconds=3), mode="reschedule", - silent_fail=True, + fail_policy=FailPolicy.IGNORE_ERROR, ) def _get_sensor_ti(): @@ -1112,14 +1120,15 @@ def test_poke_mode_only_bad_poke(self): class TestAsyncSensor: @pytest.mark.parametrize( - "soft_fail, expected_exception", + "fail_policy, expected_exception", [ - (True, AirflowSkipException), - (False, AirflowException), + (FailPolicy.SKIP_ON_TIMEOUT, AirflowException), + (FailPolicy.SKIP_ON_ANY_ERROR, AirflowSkipException), + (FailPolicy.NONE, AirflowException), ], ) - def test_fail_after_resuming_deferred_sensor(self, soft_fail, expected_exception): - async_sensor = DummyAsyncSensor(task_id="dummy_async_sensor", soft_fail=soft_fail) + def 
test_fail_after_resuming_deferred_sensor(self, fail_policy, expected_exception): + async_sensor = DummyAsyncSensor(task_id="dummy_async_sensor", fail_policy=fail_policy) ti = TaskInstance(task=async_sensor) ti.next_method = "execute_complete" with pytest.raises(expected_exception): diff --git a/tests/sensors/test_external_task_sensor.py b/tests/sensors/test_external_task_sensor.py index 4cd7b5e5f8c6f..9b283926f5f0f 100644 --- a/tests/sensors/test_external_task_sensor.py +++ b/tests/sensors/test_external_task_sensor.py @@ -30,7 +30,13 @@ from airflow import settings from airflow.decorators import task as task_deco -from airflow.exceptions import AirflowException, AirflowSensorTimeout, AirflowSkipException, TaskDeferred +from airflow.exceptions import ( + AirflowException, + AirflowPokeFailException, + AirflowSensorTimeout, + AirflowSkipException, + TaskDeferred, +) from airflow.models import DagBag, DagRun, TaskInstance from airflow.models.dag import DAG from airflow.models.serialized_dag import SerializedDagModel @@ -43,6 +49,10 @@ ExternalTaskMarker, ExternalTaskSensor, ) +from tests.test_utils.compat import AIRFLOW_V_2_10_PLUS, ignore_provider_compatibility_error +with ignore_provider_compatibility_error("2.10.0", __file__): + from airflow.sensors.base import FailPolicy + from airflow.serialization.serialized_objects import SerializedBaseOperator from airflow.triggers.external_task import WorkflowTrigger from airflow.utils.hashlib_wrapper import md5 @@ -255,7 +265,7 @@ def test_external_task_group_not_exists_without_check_existence(self): dag=self.dag, poke_interval=0.1, ) - with pytest.raises(AirflowException, match="Sensor has timed out"): + with pytest.raises(AirflowSensorTimeout, match="Sensor has timed out"): op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) @pytest.mark.skip_if_database_isolation_mode # Test is broken in db isolation mode @@ -284,13 +294,13 @@ def test_external_task_group_sensor_failed_states(self): dag=self.dag, ) with pytest.raises( - AirflowException, + AirflowPokeFailException, match=f"The external task_group '{TEST_TASK_GROUP_ID}' in DAG '{TEST_DAG_ID}' failed.", ): op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) def test_catch_overlap_allowed_failed_state(self): - with pytest.raises(AirflowException): + with pytest.raises(ValueError): ExternalTaskSensor( task_id="test_external_task_sensor_check", external_dag_id=TEST_DAG_ID, @@ -336,14 +346,15 @@ def test_external_task_sensor_failed_states_as_success(self, caplog): error_message = rf"Some of the external tasks \['{TEST_TASK_ID}'\] in DAG {TEST_DAG_ID} failed\." with caplog.at_level(logging.INFO, logger=op.log.name): caplog.clear() - with pytest.raises(AirflowException, match=error_message): + with pytest.raises(AirflowPokeFailException, match=error_message): op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) assert ( f"Poking for tasks ['{TEST_TASK_ID}'] in dag {TEST_DAG_ID} on {DEFAULT_DATE.isoformat()} ... 
" ) in caplog.messages + @pytest.mark.skipif(not AIRFLOW_V_2_10_PLUS, reason="FailPolicy present from Airflow 2.10.0") @pytest.mark.skip_if_database_isolation_mode # Test is broken in db isolation mode - def test_external_task_sensor_soft_fail_failed_states_as_skipped(self): + def test_external_task_sensor_skip_on_timeout_failed_states_as_skipped(self): self.add_time_sensor() op = ExternalTaskSensor( task_id="test_external_task_sensor_check", @@ -351,7 +362,7 @@ def test_external_task_sensor_soft_fail_failed_states_as_skipped(self): external_task_id=TEST_TASK_ID, allowed_states=[State.FAILED], failed_states=[State.SUCCESS], - soft_fail=True, + fail_policy=FailPolicy.SKIP_ON_TIMEOUT, dag=self.dag, ) @@ -444,7 +455,7 @@ def test_external_task_sensor_failed_states_as_success_mulitple_task_ids(self, c ) with caplog.at_level(logging.INFO, logger=op.log.name): caplog.clear() - with pytest.raises(AirflowException, match=error_message): + with pytest.raises(AirflowPokeFailException, match=error_message): op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) assert ( f"Poking for tasks ['{TEST_TASK_ID}', '{TEST_TASK_ID_ALTERNATE}'] " @@ -491,8 +502,9 @@ def test_external_dag_sensor_log(self, caplog): op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) assert (f"Poking for DAG 'other_dag' on {DEFAULT_DATE.isoformat()} ... ") in caplog.messages + @pytest.mark.skipif(not AIRFLOW_V_2_10_PLUS, reason="FailPolicy present from Airflow 2.10.0") @pytest.mark.skip_if_database_isolation_mode # Test is broken in db isolation mode - def test_external_dag_sensor_soft_fail_as_skipped(self): + def test_external_dag_sensor_skip_on_timeout_as_skipped(self): other_dag = DAG("other_dag", default_args=self.args, end_date=DEFAULT_DATE, schedule="@once") triggered_by_kwargs = {"triggered_by": DagRunTriggeredByType.TEST} if AIRFLOW_V_3_0_PLUS else {} other_dag.create_dagrun( @@ -509,7 +521,7 @@ def test_external_dag_sensor_soft_fail_as_skipped(self): external_task_id=None, allowed_states=[State.FAILED], failed_states=[State.SUCCESS], - soft_fail=True, + fail_policy=FailPolicy.SKIP_ON_TIMEOUT, dag=self.dag, ) @@ -617,12 +629,12 @@ def test_external_task_sensor_fn_multiple_logical_dates(self): dag=dag, ) - # We need to test for an AirflowException explicitly since + # We need to test for an AirflowPokeFailException explicitly since # AirflowSensorTimeout is a subclass that will be raised if this does # not execute properly. 
- with pytest.raises(AirflowException) as ex_ctx: + with pytest.raises(AirflowPokeFailException) as ex_ctx: task_chain_with_failure.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) - assert type(ex_ctx.value) is AirflowException + assert type(ex_ctx.value) is AirflowPokeFailException @@ -870,7 +882,7 @@ def test_external_task_group_with_mapped_tasks_failed_states(self): dag=self.dag, ) with pytest.raises( - AirflowException, + AirflowPokeFailException, match=f"The external task_group '{TEST_TASK_GROUP_ID}' in DAG '{TEST_DAG_ID}' failed.", ): op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) @@ -896,6 +908,7 @@ def test_external_task_group_when_there_is_no_TIs(self): ignore_ti_state=True, ) + @pytest.mark.skipif(not AIRFLOW_V_2_10_PLUS, reason="FailPolicy is only available from Airflow 2.10.0") @pytest.mark.parametrize( "kwargs, expected_message", ( @@ -922,14 +935,14 @@ def test_external_task_group_when_there_is_no_TIs(self): ), ) @pytest.mark.parametrize( - "soft_fail, expected_exception", + "fail_policy, expected_exception", ( ( - False, - AirflowException, + FailPolicy.NONE, + AirflowPokeFailException, ), ( - True, + FailPolicy.SKIP_ON_TIMEOUT, AirflowSkipException, ), ), ) @mock.patch("airflow.sensors.external_task.ExternalTaskSensor.get_count") @mock.patch("airflow.sensors.external_task.ExternalTaskSensor._get_dttm_filter") def test_fail_poke( - self, _get_dttm_filter, get_count, soft_fail, expected_exception, kwargs, expected_message + self, _get_dttm_filter, get_count, fail_policy, expected_exception, kwargs, expected_message ): _get_dttm_filter.return_value = [] get_count.return_value = 1 @@ -946,13 +959,16 @@ def test_fail_poke( external_dag_id=TEST_DAG_ID, allowed_states=["success"], dag=self.dag, - soft_fail=soft_fail, + fail_policy=fail_policy, deferrable=False, **kwargs, ) + if fail_policy == FailPolicy.SKIP_ON_TIMEOUT: + expected_message = "Skipping due to fail_policy set to SKIP_ON_TIMEOUT." with pytest.raises(expected_exception, match=expected_message): op.execute(context={}) + @pytest.mark.skipif(not AIRFLOW_V_2_10_PLUS, reason="FailPolicy is only available from Airflow 2.10.0") @pytest.mark.parametrize( "response_get_current, response_exists, kwargs, expected_message", ( @@ -979,15 +995,15 @@ def test_fail_poke( ), ) @pytest.mark.parametrize( - "soft_fail, expected_exception", + "fail_policy, expected_exception", ( ( - False, - AirflowException, + FailPolicy.NONE, + AirflowPokeFailException, ), ( - True, - AirflowException, + FailPolicy.SKIP_ON_TIMEOUT, + AirflowSkipException, ), ), ) @@ -1001,7 +1017,7 @@ def test_fail__check_for_existence( exists, get_dag, _get_dttm_filter, - soft_fail, + fail_policy, expected_exception, response_get_current, response_exists, @@ -1020,10 +1036,12 @@ def test_fail__check_for_existence( external_dag_id=TEST_DAG_ID, allowed_states=["success"], dag=self.dag, - soft_fail=soft_fail, + fail_policy=fail_policy, check_existence=True, **kwargs, ) + if fail_policy == FailPolicy.SKIP_ON_TIMEOUT: + expected_message = "Skipping due to fail_policy set to SKIP_ON_TIMEOUT." 
with pytest.raises(expected_exception, match=expected_message): op.execute(context={}) @@ -1052,7 +1070,7 @@ def test_defer_and_fire_task_state_trigger(self): assert isinstance(exc.value.trigger, WorkflowTrigger), "Trigger is not a WorkflowTrigger" def test_defer_and_fire_failed_state_trigger(self): - """Tests that an AirflowException is raised in case of error event""" + """Tests that an AirflowPokeFailException is raised in case of error event""" sensor = ExternalTaskSensor( task_id=TASK_ID, external_task_id=EXTERNAL_TASK_ID, @@ -1060,13 +1078,13 @@ def test_defer_and_fire_failed_state_trigger(self): deferrable=True, ) - with pytest.raises(AirflowException): + with pytest.raises(AirflowPokeFailException): sensor.execute_complete( context=mock.MagicMock(), event={"status": "error", "message": "test failure message"} ) def test_defer_and_fire_timeout_state_trigger(self): - """Tests that an AirflowException is raised in case of timeout event""" + """Tests that an AirflowPokeFailException is raised in case of timeout event""" sensor = ExternalTaskSensor( task_id=TASK_ID, external_task_id=EXTERNAL_TASK_ID, @@ -1074,7 +1092,7 @@ def test_defer_and_fire_timeout_state_trigger(self): deferrable=True, ) - with pytest.raises(AirflowException): + with pytest.raises(AirflowPokeFailException): sensor.execute_complete( context=mock.MagicMock(), event={"status": "timeout", "message": "Dag was not started within 1 minute, assuming fail."}, diff --git a/tests/triggers/test_external_task.py b/tests/triggers/test_external_task.py index 4a193c5fef9c5..b49fadbe6ea70 100644 --- a/tests/triggers/test_external_task.py +++ b/tests/triggers/test_external_task.py @@ -213,7 +213,7 @@ def test_serialization(self): "skipped_states": None, "allowed_states": self.STATES, "poke_interval": 5, - "soft_fail": False, + "fail_policy": "none", }
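Reviewer note (not part of the patch): a minimal usage sketch of the new parameter, under the assumption that the sensor is declared inside a DAG; the task ids and file path below are illustrative only.

    from airflow.providers.standard.sensors.filesystem import FileSensor
    from airflow.sensors.base import FailPolicy

    # Skip the task instead of failing it when the timeout is reached
    # (replaces soft_fail=True).
    wait_then_skip = FileSensor(
        task_id="wait_then_skip",
        filepath="/tmp/trigger_file",
        timeout=60,
        fail_policy=FailPolicy.SKIP_ON_TIMEOUT,
    )

    # Log and keep re-poking through unexpected poke errors until the timeout
    # (replaces silent_fail=True); the plain string form is also accepted.
    wait_through_errors = FileSensor(
        task_id="wait_through_errors",
        filepath="/tmp/trigger_file",
        timeout=60,
        fail_policy="ignore_error",
    )
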