5 changes: 4 additions & 1 deletion airflow-core/src/airflow/cli/cli_config.py
@@ -669,7 +669,7 @@ def string_lower_type(val):
default=conf.get("api", "ssl_key"),
help="Path to the key to use with the SSL certificate",
)
-ARG_DEV = Arg(("-d", "--dev"), help="Start FastAPI in development mode", action="store_true")
+ARG_DEV = Arg(("-d", "--dev"), help="Start in development mode with hot-reload enabled", action="store_true")

# scheduler
ARG_NUM_RUNS = Arg(
@@ -1923,6 +1923,7 @@ class GroupCommand(NamedTuple):
ARG_LOG_FILE,
ARG_SKIP_SERVE_LOGS,
ARG_VERBOSE,
ARG_DEV,
),
epilog=(
"Signals:\n"
@@ -1946,6 +1947,7 @@ class GroupCommand(NamedTuple):
ARG_CAPACITY,
ARG_VERBOSE,
ARG_SKIP_SERVE_LOGS,
ARG_DEV,
),
),
ActionCommand(
@@ -1961,6 +1963,7 @@ class GroupCommand(NamedTuple):
ARG_STDERR,
ARG_LOG_FILE,
ARG_VERBOSE,
ARG_DEV,
),
),
ActionCommand(
@@ -139,7 +139,7 @@ def api_server(args: Namespace):

get_signing_args()

-if args.dev:
+if cli_utils.should_enable_hot_reload(args):
print(f"Starting the API server on port {args.port} and host {args.host} in development mode.")
log.warning("Running in dev mode, ignoring uvicorn args")
from fastapi_cli.cli import _run
@@ -52,6 +52,15 @@ def dag_processor(args):
"""Start Airflow Dag Processor Job."""
job_runner = _create_dag_processor_job_runner(args)

if cli_utils.should_enable_hot_reload(args):
from airflow.cli.hot_reload import run_with_reloader

run_with_reloader(
lambda: run_job(job=job_runner.job, execute_callable=job_runner._execute),
process_name="dag-processor",
)
return

run_command_with_daemon_option(
args=args,
process_name="dag-processor",
6 changes: 6 additions & 0 deletions airflow-core/src/airflow/cli/commands/scheduler_command.py
@@ -51,6 +51,12 @@ def scheduler(args: Namespace):
"""Start Airflow Scheduler."""
print(settings.HEADER)

if cli_utils.should_enable_hot_reload(args):
from airflow.cli.hot_reload import run_with_reloader

run_with_reloader(lambda: _run_scheduler_job(args), process_name="scheduler")
return

run_command_with_daemon_option(
args=args,
process_name="scheduler",
9 changes: 9 additions & 0 deletions airflow-core/src/airflow/cli/commands/triggerer_command.py
@@ -66,6 +66,15 @@ def triggerer(args):
print(settings.HEADER)
triggerer_heartrate = conf.getfloat("triggerer", "JOB_HEARTBEAT_SEC")

if cli_utils.should_enable_hot_reload(args):
from airflow.cli.hot_reload import run_with_reloader

run_with_reloader(
lambda: triggerer_run(args.skip_serve_logs, args.capacity, triggerer_heartrate),
process_name="triggerer",
)
return

run_command_with_daemon_option(
args=args,
process_name="triggerer",
197 changes: 197 additions & 0 deletions airflow-core/src/airflow/cli/hot_reload.py
@@ -0,0 +1,197 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Hot reload utilities for development mode."""

from __future__ import annotations

import os
import signal
import sys
from collections.abc import Callable, Sequence
from pathlib import Path
from typing import TYPE_CHECKING

import structlog

if TYPE_CHECKING:
import subprocess

log = structlog.getLogger(__name__)


def run_with_reloader(
callback: Callable,
process_name: str = "process",
) -> None:
"""
Run a callback function with automatic reloading on file changes.

This function watches the Airflow source tree for changes and restarts the
process when changes are detected. Useful for development-mode hot-reloading.

:param callback: The function to run. This should be the main entry point
of the command that needs hot-reload support.
:param process_name: Name of the process being run (for logging purposes)
"""
# Default watch paths - watch the airflow source directory
import airflow

airflow_root = Path(airflow.__file__).parent
watch_paths = [airflow_root]

log.info("Starting %s in development mode with hot-reload enabled", process_name)
log.info("Watching paths: %s", watch_paths)

# Check if we're the main process or a reloaded child
reloader_pid = os.environ.get("AIRFLOW_DEV_RELOADER_PID")
if reloader_pid is None:
# We're the main process - set up the reloader
os.environ["AIRFLOW_DEV_RELOADER_PID"] = str(os.getpid())
_run_reloader(watch_paths)
else:
# We're a child process - just run the callback
callback()


def _terminate_process_tree(
process: subprocess.Popen[bytes],
timeout: int = 5,
force_kill_remaining: bool = True,
) -> None:
"""
Terminate a process and all its children recursively.

Uses psutil to ensure all child processes are properly terminated,
which is important for cleaning up subprocesses like serve-log servers.

:param process: The subprocess.Popen process to terminate
:param timeout: Timeout in seconds to wait for graceful termination
:param force_kill_remaining: If True, force kill processes that don't terminate gracefully
"""
import subprocess

import psutil

try:
parent = psutil.Process(process.pid)
# Get all child processes recursively
children = parent.children(recursive=True)

# Terminate all children first
for child in children:
try:
child.terminate()
except (psutil.NoSuchProcess, psutil.AccessDenied):
pass

# Terminate the parent
parent.terminate()

# Wait for all processes to terminate
gone, alive = psutil.wait_procs(children + [parent], timeout=timeout)

# Force kill any remaining processes if requested
if force_kill_remaining:
for proc in alive:
try:
log.warning("Force killing process %s", proc.pid)
proc.kill()
except (psutil.NoSuchProcess, psutil.AccessDenied):
pass

except (psutil.NoSuchProcess, psutil.AccessDenied):
# Process already terminated
pass
except Exception as e:
log.warning("Error terminating process tree: %s", e)
# Fallback to simple termination
try:
process.terminate()
process.wait(timeout=timeout)
except subprocess.TimeoutExpired:
if force_kill_remaining:
log.warning("Process did not terminate gracefully, killing...")
process.kill()
process.wait()


def _run_reloader(watch_paths: Sequence[str | Path]) -> None:
"""
Watch for changes and restart the process.

Watches the provided paths and restarts the process by re-executing the
Python interpreter with the same arguments.

:param watch_paths: List of paths to watch for changes.
"""
import subprocess

from watchfiles import watch

process = None
should_exit = False

def start_process():
"""Start or restart the subprocess."""
nonlocal process
if process is not None:
log.info("Stopping process and all its children...")
_terminate_process_tree(process, timeout=5, force_kill_remaining=True)

log.info("Starting process...")
# Restart the process by re-executing Python with the same arguments
# Note: sys.argv is safe here as it comes from the original CLI invocation
# and is only used in development mode for hot-reloading the same process
process = subprocess.Popen([sys.executable] + sys.argv)
return process

def signal_handler(signum, frame):
"""Handle termination signals."""
nonlocal should_exit, process
should_exit = True
log.info("Received signal %s, shutting down...", signum)
if process:
_terminate_process_tree(process, timeout=5, force_kill_remaining=False)
sys.exit(0)

# Set up signal handlers
signal.signal(signal.SIGINT, signal_handler)
signal.signal(signal.SIGTERM, signal_handler)

# Start the initial process
process = start_process()

log.info("Hot-reload enabled. Watching for file changes...")
log.info("Press Ctrl+C to stop")

try:
for changes in watch(*watch_paths):
if should_exit:
break

log.info("Detected changes: %s", changes)
log.info("Reloading...")

# Restart the process
process = start_process()

except KeyboardInterrupt:
log.info("Shutting down...")
if process:
process.terminate()
process.wait()
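
The file above distinguishes the supervising reloader from the re-executed worker with a single environment variable: the first invocation sets AIRFLOW_DEV_RELOADER_PID and enters the watch loop, while the child it spawns inherits the variable and runs the actual job. A minimal standalone sketch of that sentinel pattern (the names here are illustrative, not part of the Airflow API):

import os
import subprocess
import sys

SENTINEL = "MY_RELOADER_PID"  # illustrative stand-in for AIRFLOW_DEV_RELOADER_PID


def run(callback):
    if os.environ.get(SENTINEL) is None:
        # Supervisor: mark the environment, then re-exec the same command line.
        # The child inherits the sentinel and takes the callback branch below
        # instead of recursing into another supervisor.
        os.environ[SENTINEL] = str(os.getpid())
        child = subprocess.Popen([sys.executable] + sys.argv)
        child.wait()
    else:
        # Worker: run the real entry point.
        callback()

In the real implementation the supervisor does not simply wait: it blocks on watchfiles.watch() and calls start_process() on every batch of changes, terminating the previous worker's whole process tree first.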
7 changes: 7 additions & 0 deletions airflow-core/src/airflow/utils/cli.py
@@ -472,3 +472,10 @@ def validate_dag_bundle_arg(bundle_names: list[str]) -> None:
unknown_bundles: set[str] = set(bundle_names) - known_bundles
if unknown_bundles:
raise SystemExit(f"Bundles not found: {', '.join(unknown_bundles)}")


def should_enable_hot_reload(args) -> bool:
"""Check whether hot-reload should be enabled based on --dev flag or DEV_MODE env var."""
if getattr(args, "dev", False):
return True
return os.getenv("DEV_MODE", "false").lower() == "true"
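
So dev mode can be switched on either per invocation (airflow scheduler --dev) or globally through the environment (DEV_MODE=true), with the flag checked first. A quick sketch of that precedence, using a bare argparse.Namespace to stand in for parsed CLI args:

import os
from argparse import Namespace

from airflow.utils import cli as cli_utils

assert cli_utils.should_enable_hot_reload(Namespace(dev=True))   # flag wins outright
os.environ["DEV_MODE"] = "True"  # the comparison lower-cases the value first
assert cli_utils.should_enable_hot_reload(Namespace(dev=False))  # env-var fallback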
11 changes: 11 additions & 0 deletions airflow-core/tests/unit/cli/commands/test_dag_processor_command.py
@@ -56,3 +56,14 @@ def test_bundle_names_passed(self, mock_runner, configure_testing_dag_bundle):
with configure_testing_dag_bundle(os.devnull):
dag_processor_command.dag_processor(args)
assert mock_runner.call_args.kwargs["processor"].bundle_names_to_parse == ["testing"]

@mock.patch("airflow.cli.hot_reload.run_with_reloader")
def test_dag_processor_with_dev_flag(self, mock_reloader):
"""Ensure that dag-processor with --dev flag uses hot-reload"""
args = self.parser.parse_args(["dag-processor", "--dev"])
dag_processor_command.dag_processor(args)

# Verify that run_with_reloader was called
mock_reloader.assert_called_once()
# The callback function should be callable
assert callable(mock_reloader.call_args[0][0])
10 changes: 10 additions & 0 deletions airflow-core/tests/unit/cli/commands/test_scheduler_command.py
@@ -163,3 +163,13 @@ def test_run_job_exception_handling(self, mock_run_job, mock_process, mock_sched
)
mock_process.assert_called_once_with(target=serve_logs)
mock_process().terminate.assert_called_once_with()

@mock.patch("airflow.cli.hot_reload.run_with_reloader")
def test_scheduler_with_dev_flag(self, mock_reloader):
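"""Ensure that scheduler with --dev flag uses hot-reload"""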
args = self.parser.parse_args(["scheduler", "--dev"])
scheduler_command.scheduler(args)

# Verify that run_with_reloader was called
mock_reloader.assert_called_once()
# The callback function should be callable
assert callable(mock_reloader.call_args[0][0])
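
These tests cover only the --dev flag; a hypothetical companion test for the DEV_MODE environment fallback (not part of this diff) could patch the environment instead:

@mock.patch.dict("os.environ", {"DEV_MODE": "true"})
@mock.patch("airflow.cli.hot_reload.run_with_reloader")
def test_scheduler_with_dev_mode_env(self, mock_reloader):
    """Scheduler should hot-reload when DEV_MODE=true even without --dev."""
    args = self.parser.parse_args(["scheduler"])
    scheduler_command.scheduler(args)
    mock_reloader.assert_called_once()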
11 changes: 11 additions & 0 deletions airflow-core/tests/unit/cli/commands/test_triggerer_command.py
@@ -63,3 +63,14 @@ def test_trigger_run_serve_logs(self, mock_process, mock_run_job, mock_trigger_j
job=mock_trigger_job_runner.return_value.job,
execute_callable=mock_trigger_job_runner.return_value._execute,
)

@mock.patch("airflow.cli.hot_reload.run_with_reloader")
def test_triggerer_with_dev_flag(self, mock_reloader):
"""Ensure that triggerer with --dev flag uses hot-reload"""
args = self.parser.parse_args(["triggerer", "--dev"])
triggerer_command.triggerer(args)

# Verify that run_with_reloader was called
mock_reloader.assert_called_once()
# The callback function should be callable
assert callable(mock_reloader.call_args[0][0])