introduce fail_policy
raphaelauv committed Nov 18, 2024
1 parent 007c4b1 commit 0474930
Showing 16 changed files with 200 additions and 147 deletions.
5 changes: 3 additions & 2 deletions airflow/decorators/__init__.pyi
@@ -40,6 +40,7 @@ from airflow.decorators.short_circuit import short_circuit_task
from airflow.decorators.task_group import task_group
from airflow.models.dag import dag
from airflow.providers.cncf.kubernetes.secret import Secret
from airflow.sensors.base import FailPolicy
from airflow.typing_compat import Literal

# Please keep this in sync with __init__.py's __all__.
@@ -708,7 +709,7 @@ class TaskDecoratorCollection:
*,
poke_interval: float = ...,
timeout: float = ...,
soft_fail: bool = False,
fail_policy: FailPolicy = ...,
mode: str = ...,
exponential_backoff: bool = False,
max_wait: timedelta | float | None = None,
@@ -720,7 +721,7 @@
:param poke_interval: Time in seconds that the job should wait in
between each try
:param timeout: Time, in seconds before the task times out and fails.
:param soft_fail: Set to true to mark the task as SKIPPED on failure
:param fail_policy: defines the rule by which the sensor skips itself. Options are:
``{ none | skip_on_any_error | skip_on_timeout | ignore_error }``, default is ``none``.
:param mode: How the sensor operates.
Options are: ``{ poke | reschedule }``, default is ``poke``.
When set to ``poke`` the sensor is taking up a worker slot for its
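The stub above only types the new keyword; as a usage sketch (not part of the commit), a decorated sensor might opt into the policy like this, assuming `@task.sensor` forwards `fail_policy` to `BaseSensorOperator` like its other keywords:

```python
from airflow.decorators import task
from airflow.sensors.base import FailPolicy


@task.sensor(poke_interval=30, timeout=300, fail_policy=FailPolicy.SKIP_ON_TIMEOUT)
def wait_for_file() -> bool:
    """Poke until the file appears; end as SKIPPED (not FAILED) on timeout."""
    import os

    # Returning False keeps the sensor poking; True completes it.
    return os.path.exists("/tmp/temporary_file_for_testing")
```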
22 changes: 17 additions & 5 deletions airflow/example_dags/example_sensors.py
@@ -23,6 +23,7 @@

from airflow.models.dag import DAG
from airflow.providers.standard.operators.bash import BashOperator
from airflow.providers.standard.sensors.base import FailPolicy
from airflow.providers.standard.sensors.bash import BashSensor
from airflow.providers.standard.sensors.filesystem import FileSensor
from airflow.providers.standard.sensors.python import PythonSensor
@@ -68,7 +69,7 @@ def failure_callable():
t2 = TimeSensor(
task_id="timeout_after_second_date_in_the_future",
timeout=1,
soft_fail=True,
fail_policy=FailPolicy.SKIP_ON_TIMEOUT,
target_time=(datetime.datetime.now(tz=datetime.timezone.utc) + datetime.timedelta(hours=1)).time(),
)
# [END example_time_sensors]
@@ -81,15 +82,20 @@ def failure_callable():
t2a = TimeSensorAsync(
task_id="timeout_after_second_date_in_the_future_async",
timeout=1,
soft_fail=True,
fail_policy=FailPolicy.SKIP_ON_TIMEOUT,
target_time=(datetime.datetime.now(tz=datetime.timezone.utc) + datetime.timedelta(hours=1)).time(),
)
# [END example_time_sensors_async]

# [START example_bash_sensors]
t3 = BashSensor(task_id="Sensor_succeeds", bash_command="exit 0")

t4 = BashSensor(task_id="Sensor_fails_after_3_seconds", timeout=3, soft_fail=True, bash_command="exit 1")
t4 = BashSensor(
task_id="Sensor_fails_after_3_seconds",
timeout=3,
fail_policy=FailPolicy.SKIP_ON_TIMEOUT,
bash_command="exit 1",
)
# [END example_bash_sensors]

t5 = BashOperator(task_id="remove_file", bash_command="rm -rf /tmp/temporary_file_for_testing")
@@ -112,13 +118,19 @@ def failure_callable():
t9 = PythonSensor(task_id="success_sensor_python", python_callable=success_callable)

t10 = PythonSensor(
task_id="failure_timeout_sensor_python", timeout=3, soft_fail=True, python_callable=failure_callable
task_id="failure_timeout_sensor_python",
timeout=3,
fail_policy=FailPolicy.SKIP_ON_TIMEOUT,
python_callable=failure_callable,
)
# [END example_python_sensors]

# [START example_day_of_week_sensor]
t11 = DayOfWeekSensor(
task_id="week_day_sensor_failing_on_timeout", timeout=3, soft_fail=True, week_day=WeekDay.MONDAY
task_id="week_day_sensor_failing_on_timeout",
timeout=3,
fail_policy=FailPolicy.SKIP_ON_TIMEOUT,
week_day=WeekDay.MONDAY,
)
# [END example_day_of_week_sensor]

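For DAG authors the migration in this file is mechanical: every `soft_fail=True` becomes `fail_policy=FailPolicy.SKIP_ON_TIMEOUT`. A condensed before/after sketch (task id invented):

```python
from airflow.providers.standard.sensors.base import FailPolicy
from airflow.providers.standard.sensors.bash import BashSensor

# Before: BashSensor(task_id="t", timeout=3, soft_fail=True, bash_command="exit 1")
# After:
t = BashSensor(
    task_id="t",
    timeout=3,
    # FailPolicy subclasses str, so fail_policy="skip_on_timeout" is equivalent.
    fail_policy=FailPolicy.SKIP_ON_TIMEOUT,
    bash_command="exit 1",
)
```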
4 changes: 4 additions & 0 deletions airflow/exceptions.py
@@ -68,6 +68,10 @@ class AirflowSensorTimeout(AirflowException):
"""Raise when there is a timeout on sensor polling."""


class AirflowPokeFailException(AirflowException):
"""Raise when a sensor must not try to poke again."""


class AirflowRescheduleException(AirflowException):
"""
Raise when the task should be re-scheduled at a later time.
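A sketch of how a custom sensor could use the new exception to stop re-poking once success is impossible; the sensor class and its status lookup are hypothetical, not part of this commit:

```python
from airflow.exceptions import AirflowPokeFailException
from airflow.sensors.base import BaseSensorOperator


class UpstreamRecordSensor(BaseSensorOperator):
    """Hypothetical sensor: waits for a record to become ready."""

    def poke(self, context) -> bool:
        status = self._fetch_status()  # hypothetical lookup, e.g. a DB query
        if status == "deleted":
            # Poking again can never succeed, so abort now instead of burning
            # the remaining timeout; fail_policy decides whether this surfaces
            # as a failed or a skipped task.
            raise AirflowPokeFailException("Record was deleted upstream.")
        return status == "ready"

    def _fetch_status(self) -> str:
        raise NotImplementedError("replace with a real lookup")
```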
70 changes: 38 additions & 32 deletions airflow/sensors/base.py
@@ -18,6 +18,7 @@
from __future__ import annotations

import datetime
import enum
import functools
import hashlib
import time
@@ -32,7 +33,7 @@
from airflow.configuration import conf
from airflow.exceptions import (
AirflowException,
AirflowFailException,
AirflowPokeFailException,
AirflowRescheduleException,
AirflowSensorTimeout,
AirflowSkipException,
@@ -109,15 +110,31 @@ def _orig_start_date(
)


class FailPolicy(str, enum.Enum):
"""Sensor fail policies."""

# If the poke method raises an exception, the sensor will not be skipped; the exception propagates.
NONE = "none"

# If the poke method raises any exception, the sensor will be skipped.
SKIP_ON_ANY_ERROR = "skip_on_any_error"

# If the poke method raises AirflowSensorTimeout, AirflowTaskTimeout, AirflowPokeFailException
# or AirflowSkipException, the sensor will be skipped.
SKIP_ON_TIMEOUT = "skip_on_timeout"

# If the poke method raises an exception other than AirflowSensorTimeout, AirflowTaskTimeout,
# AirflowPokeFailException or AirflowSkipException, the sensor will log the error and re-poke until timeout.
IGNORE_ERROR = "ignore_error"


class BaseSensorOperator(BaseOperator, SkipMixin):
"""
Sensor operators are derived from this class and inherit these attributes.
Sensor operators keep executing at a time interval and succeed when
a criterion is met, and fail when they time out.
:param soft_fail: Set to true to mark the task as SKIPPED on failure.
Mutually exclusive with never_fail.
:param poke_interval: Time that the job should wait in between each try.
Can be ``timedelta`` or ``float`` seconds.
:param timeout: Time elapsed before the task times out and fails.
Expand Down Expand Up @@ -145,13 +162,10 @@ class BaseSensorOperator(BaseOperator, SkipMixin):
:param exponential_backoff: allow progressively longer waits between
pokes by using the exponential backoff algorithm
:param max_wait: maximum wait interval between pokes, can be ``timedelta`` or ``float`` seconds
:param silent_fail: If true, and poke method raises an exception different from
AirflowSensorTimeout, AirflowTaskTimeout, AirflowSkipException
and AirflowFailException, the sensor will log the error and continue
its execution. Otherwise, the sensor task fails, and it can be retried
based on the provided `retries` parameter.
:param never_fail: If true, and poke method raises an exception, sensor will be skipped.
Mutually exclusive with soft_fail.
:param fail_policy: defines the rule by which the sensor skips itself. Options are:
``{ none | skip_on_any_error | skip_on_timeout | ignore_error }``,
default is ``none``. Options can be set as a string or using the
constants defined in the enum ``airflow.sensors.base.FailPolicy``.
"""

ui_color: str = "#e6f1f2"
Expand All @@ -166,26 +180,19 @@ def __init__(
*,
poke_interval: timedelta | float = 60,
timeout: timedelta | float = conf.getfloat("sensors", "default_timeout"),
soft_fail: bool = False,
mode: str = "poke",
exponential_backoff: bool = False,
max_wait: timedelta | float | None = None,
silent_fail: bool = False,
never_fail: bool = False,
fail_policy: str = FailPolicy.NONE,
**kwargs,
) -> None:
super().__init__(**kwargs)
self.poke_interval = self._coerce_poke_interval(poke_interval).total_seconds()
self.soft_fail = soft_fail
self.timeout = self._coerce_timeout(timeout).total_seconds()
self.mode = mode
self.exponential_backoff = exponential_backoff
self.max_wait = self._coerce_max_wait(max_wait)
if soft_fail is True and never_fail is True:
raise ValueError("soft_fail and never_fail are mutually exclusive, you can not provide both.")

self.silent_fail = silent_fail
self.never_fail = never_fail
self.fail_policy = fail_policy
self._validate_input_values()

@staticmethod
@@ -282,21 +289,20 @@ def run_duration() -> float:
except (
AirflowSensorTimeout,
AirflowTaskTimeout,
AirflowFailException,
AirflowPokeFailException,
AirflowSkipException,
) as e:
if self.soft_fail:
raise AirflowSkipException("Skipping due to soft_fail is set to True.") from e
elif self.never_fail:
raise AirflowSkipException("Skipping due to never_fail is set to True.") from e
raise e
except AirflowSkipException as e:
if self.fail_policy == FailPolicy.SKIP_ON_TIMEOUT:
raise AirflowSkipException("Skipping due to fail_policy set to SKIP_ON_TIMEOUT.") from e
elif self.fail_policy == FailPolicy.SKIP_ON_ANY_ERROR:
raise AirflowSkipException("Skipping due to fail_policy set to SKIP_ON_ANY_ERROR.") from e
raise e
except Exception as e:
if self.silent_fail:
if self.fail_policy == FailPolicy.IGNORE_ERROR:
self.log.error("Sensor poke failed: \n %s", traceback.format_exc())
poke_return = False
elif self.never_fail:
raise AirflowSkipException("Skipping due to never_fail is set to True.") from e
elif self.fail_policy == FailPolicy.SKIP_ON_ANY_ERROR:
raise AirflowSkipException("Skipping due to SKIP_ON_ANY_ERROR is set to True.") from e
else:
raise e

@@ -306,13 +312,13 @@ def run_duration() -> float:
break

if run_duration() > self.timeout:
# If sensor is in soft fail mode but times out raise AirflowSkipException.
# If the sensor is in SKIP_ON_TIMEOUT mode and times out, raise AirflowSkipException.
message = (
f"Sensor has timed out; run duration of {run_duration()} seconds exceeds "
f"the specified timeout of {self.timeout}."
)

if self.soft_fail:
if self.fail_policy == FailPolicy.SKIP_ON_TIMEOUT:
raise AirflowSkipException(message)
else:
raise AirflowSensorTimeout(message)
@@ -335,7 +341,7 @@ def resume_execution(self, next_method: str, next_kwargs: dict[str, Any] | None,
try:
return super().resume_execution(next_method, next_kwargs, context)
except (AirflowException, TaskDeferralError) as e:
if self.soft_fail:
if self.fail_policy == FailPolicy.SKIP_ON_ANY_ERROR:
raise AirflowSkipException(str(e)) from e
raise

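Reading the two `except` blocks together with the timeout branch, the dispatch works out to the following; this is my summary of the hunks above, not text from the commit:

```python
from airflow.sensors.base import FailPolicy

# Outcome of one poke iteration, per policy:
#
#   event during poke()          NONE   SKIP_ON_TIMEOUT   SKIP_ON_ANY_ERROR   IGNORE_ERROR
#   SensorTimeout / TaskTimeout  fail   skip              skip                fail
#   AirflowPokeFailException     fail   skip              skip                fail
#   AirflowSkipException         skip   skip              skip                skip
#   any other exception          fail   fail              skip                log + re-poke
#   sensor timeout elapsed       fail   skip              fail                fail
#
# FailPolicy subclasses str, so members compare equal to their string values
# and plain strings are accepted wherever the enum is expected, which is why
# __init__ can type the parameter as `fail_policy: str = FailPolicy.NONE`:
assert FailPolicy.SKIP_ON_TIMEOUT == "skip_on_timeout"
assert FailPolicy("ignore_error") is FailPolicy.IGNORE_ERROR
```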
47 changes: 15 additions & 32 deletions airflow/sensors/external_task.py
@@ -23,7 +23,7 @@
from typing import TYPE_CHECKING, Any, Callable, ClassVar, Collection, Iterable

from airflow.configuration import conf
from airflow.exceptions import AirflowException, AirflowSkipException
from airflow.exceptions import AirflowPokeFailException, AirflowSkipException
from airflow.models.baseoperatorlink import BaseOperatorLink
from airflow.models.dag import DagModel
from airflow.models.dagbag import DagBag
@@ -175,7 +175,7 @@ def __init__(
total_states = set(self.allowed_states + self.skipped_states + self.failed_states)

if len(total_states) != len(self.allowed_states) + len(self.skipped_states) + len(self.failed_states):
raise AirflowException(
raise ValueError(
"Duplicate values provided across allowed_states, skipped_states and failed_states."
)

@@ -286,32 +286,18 @@ def poke(self, context: Context, session: Session = NEW_SESSION) -> bool:
# Fail if anything in the list has failed.
if count_failed > 0:
if self.external_task_ids:
if self.soft_fail:
raise AirflowSkipException(
f"Some of the external tasks {self.external_task_ids} "
f"in DAG {self.external_dag_id} failed. Skipping due to soft_fail."
)
raise AirflowException(
raise AirflowPokeFailException(
f"Some of the external tasks {self.external_task_ids} "
f"in DAG {self.external_dag_id} failed."
)
elif self.external_task_group_id:
if self.soft_fail:
raise AirflowSkipException(
f"The external task_group '{self.external_task_group_id}' "
f"in DAG '{self.external_dag_id}' failed. Skipping due to soft_fail."
)
raise AirflowException(
raise AirflowPokeFailException(
f"The external task_group '{self.external_task_group_id}' "
f"in DAG '{self.external_dag_id}' failed."
)

else:
if self.soft_fail:
raise AirflowSkipException(
f"The external DAG {self.external_dag_id} failed. Skipping due to soft_fail."
)
raise AirflowException(f"The external DAG {self.external_dag_id} failed.")
raise AirflowPokeFailException(f"The external DAG {self.external_dag_id} failed.")

count_skipped = -1
if self.skipped_states:
@@ -354,7 +340,7 @@ def execute(self, context: Context) -> None:
logical_dates=self._get_dttm_filter(context),
allowed_states=self.allowed_states,
poke_interval=self.poll_interval,
soft_fail=self.soft_fail,
fail_policy=self.fail_policy,
),
method_name="execute_complete",
)
@@ -364,38 +350,35 @@ def execute_complete(self, context, event=None):
if event["status"] == "success":
self.log.info("External tasks %s has executed successfully.", self.external_task_ids)
elif event["status"] == "skipped":
raise AirflowSkipException("External job has skipped skipping.")
raise AirflowPokeFailException("External job has skipped skipping.")
else:
if self.soft_fail:
raise AirflowSkipException("External job has failed skipping.")
else:
raise AirflowException(
"Error occurred while trying to retrieve task status. Please, check the "
"name of executed task and Dag."
)
raise AirflowPokeFailException(
"Error occurred while trying to retrieve task status. Please, check the "
"name of executed task and Dag."
)

def _check_for_existence(self, session) -> None:
dag_to_wait = DagModel.get_current(self.external_dag_id, session)

if not dag_to_wait:
raise AirflowException(f"The external DAG {self.external_dag_id} does not exist.")
raise AirflowPokeFailException(f"The external DAG {self.external_dag_id} does not exist.")

if not os.path.exists(correct_maybe_zipped(dag_to_wait.fileloc)):
raise AirflowException(f"The external DAG {self.external_dag_id} was deleted.")
raise AirflowPokeFailException(f"The external DAG {self.external_dag_id} was deleted.")

if self.external_task_ids:
refreshed_dag_info = DagBag(dag_to_wait.fileloc).get_dag(self.external_dag_id)
for external_task_id in self.external_task_ids:
if not refreshed_dag_info.has_task(external_task_id):
raise AirflowException(
raise AirflowPokeFailException(
f"The external task {external_task_id} in "
f"DAG {self.external_dag_id} does not exist."
)

if self.external_task_group_id:
refreshed_dag_info = DagBag(dag_to_wait.fileloc).get_dag(self.external_dag_id)
if not refreshed_dag_info.has_task_group(self.external_task_group_id):
raise AirflowException(
raise AirflowPokeFailException(
f"The external task group '{self.external_task_group_id}' in "
f"DAG '{self.external_dag_id}' does not exist."
)
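A usage sketch of the reworked sensor (DAG and task ids invented for the example): failures in the upstream DAG now surface as `AirflowPokeFailException` from `poke()`, and whether that fails or skips the sensor is decided centrally by `fail_policy`:

```python
from airflow.sensors.base import FailPolicy
from airflow.sensors.external_task import ExternalTaskSensor

wait_upstream = ExternalTaskSensor(
    task_id="wait_for_upstream",
    external_dag_id="upstream_dag",          # invented DAG id
    external_task_ids=["produce_dataset"],   # invented task id
    failed_states=["failed"],
    # Turn upstream failures into a SKIPPED sensor task instead of a failure.
    fail_policy=FailPolicy.SKIP_ON_ANY_ERROR,
)
```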