8 changes: 6 additions & 2 deletions airflow-core/src/airflow/logging/remote.py
@@ -24,7 +24,7 @@
import structlog.typing

from airflow.sdk.types import RuntimeTaskInstanceProtocol as RuntimeTI
-from airflow.utils.log.file_task_handler import LogMessages, LogSourceInfo
+from airflow.utils.log.file_task_handler import LegacyLogResponse, LogResponse


class RemoteLogIO(Protocol):
@@ -44,6 +44,10 @@ def upload(self, path: os.PathLike | str, ti: RuntimeTI) -> None:
"""Upload the given log path to the remote storage."""
...

-def read(self, relative_path: str, ti: RuntimeTI) -> tuple[LogSourceInfo, LogMessages | None]:
+def read(self, relative_path: str, ti: RuntimeTI) -> LegacyLogResponse:
"""Read logs from the given remote log path."""
...

+def stream(self, relative_path: str, ti: RuntimeTI) -> LogResponse:
+    """Stream-based read interface for reading logs from the given remote log path."""
+    ...
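
Note (not part of the diff): a minimal sketch of how a provider's RemoteLogIO might satisfy the extended protocol. The class, its in-memory store, and the source names are hypothetical; only the `read`/`stream` shapes come from the protocol above.

```python
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from airflow.utils.log.file_task_handler import LegacyLogResponse, LogResponse


class InMemoryRemoteLogIO:  # hypothetical example, not a real provider
    """Keeps logs in a dict keyed by path; illustrates the read/stream pair."""

    store: dict[str, list[str]] = {}  # toy storage for the sketch

    def upload(self, path, ti) -> None:
        ...

    def stream(self, relative_path: str, ti) -> "LogResponse":
        sources = [f"in-memory://{relative_path}"]
        # Lazy: yields one raw line at a time instead of materializing everything.
        gen = (line for line in self.store.get(relative_path, []))
        return sources, [gen]

    def read(self, relative_path: str, ti) -> "LegacyLogResponse":
        # Legacy callers expect materialized strings, so drain the streams.
        sources, streams = self.stream(relative_path, ti)
        return sources, ["".join(s) for s in streams]
```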
21 changes: 16 additions & 5 deletions airflow-core/src/airflow/utils/log/file_task_handler.py
@@ -73,7 +73,7 @@
"""Information _about_ the log fetching process for display to a user"""
RawLogStream: TypeAlias = Generator[str, None, None]
"""Raw log stream, containing unparsed log lines."""
-LegacyLogResponse: TypeAlias = tuple[LogSourceInfo, LogMessages]
+LegacyLogResponse: TypeAlias = tuple[LogSourceInfo, LogMessages | None]
"""Legacy log response, containing source information and log messages."""
LogResponse: TypeAlias = tuple[LogSourceInfo, list[RawLogStream]]
LogResponseWithSize: TypeAlias = tuple[LogSourceInfo, list[RawLogStream], int]
@@ -197,7 +197,8 @@ def _fetch_logs_from_service(url: str, log_relative_path: str) -> Response:
if not _parse_timestamp:

def _parse_timestamp(line: str):
-timestamp_str, _ = line.split(" ", 1)
+# Make this resilient to all input types, ensure it's always a string.
+timestamp_str, _ = str(line).split(" ", 1)
return pendulum.parse(timestamp_str.strip("[]"))


@@ -262,7 +263,10 @@ def _log_stream_to_parsed_log_stream(
for line in log_stream:
if line:
try:
-log = StructuredLogMessage.model_validate_json(line)
+if isinstance(line, dict):
+    log = StructuredLogMessage.model_validate(line)
+else:
+    log = StructuredLogMessage.model_validate_json(line)
except ValidationError:
with suppress(Exception):
# If we can't parse the timestamp, don't attach one to the row
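
Aside (illustrative, not in the PR): the new branch is needed because upstream stages may hand this function either already-parsed dicts or raw JSON strings, and pydantic v2 has a separate entry point for each. A simplified stand-in model:

```python
from pydantic import BaseModel


class Msg(BaseModel):  # simplified stand-in for StructuredLogMessage
    event: str


Msg.model_validate({"event": "hi"})         # a line that is already a dict
Msg.model_validate_json('{"event": "hi"}')  # a line that is a raw JSON string
```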
@@ -936,5 +940,12 @@ def _read_remote_logs(self, ti, try_number, metadata=None) -> LegacyLogResponse
# This living here is not really a good plan, but it just about works for now.
# Ideally we move all the read+combine logic in to TaskLogReader and out of the task handler.
path = self._render_filename(ti, try_number)
-sources, logs = remote_io.read(path, ti)
-return sources, logs or []
+logs: LogMessages | list[RawLogStream] | None  # extra typing to avoid mypy assignment error
+try:
+    # Use the .stream interface if the provider's RemoteIO supports it
+    sources, logs = remote_io.stream(path, ti)
+    return sources, logs or []
+except (AttributeError, NotImplementedError):
+    # Fall back to the .read interface
+    sources, logs = remote_io.read(path, ti)
+    return sources, logs or []
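
Aside (illustrative): catching `AttributeError` is what makes the fallback duck-typed; a provider released before `stream` was added to the protocol simply lacks the attribute. A toy sketch with hypothetical names:

```python
class OldIO:  # predates the stream() addition
    def read(self, path, ti):
        return ["source"], ["log line"]


io = OldIO()
try:
    sources, logs = io.stream("path", None)  # raises AttributeError: no such method
except (AttributeError, NotImplementedError):
    sources, logs = io.read("path", None)    # legacy path still works
```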
12 changes: 10 additions & 2 deletions providers/amazon/src/airflow/providers/amazon/aws/hooks/logs.py
@@ -19,7 +19,7 @@

import asyncio
from collections.abc import AsyncGenerator, Generator
-from typing import Any
+from typing import Any, TypedDict

from botocore.exceptions import ClientError

@@ -35,6 +35,14 @@
NUM_CONSECUTIVE_EMPTY_RESPONSE_EXIT_THRESHOLD = 3


+class CloudWatchLogEvent(TypedDict):
+    """TypedDict for CloudWatch Log Event."""
+
+    timestamp: int
+    message: str
+    ingestionTime: int


class AwsLogsHook(AwsBaseHook):
"""
Interact with Amazon CloudWatch Logs.
@@ -67,7 +75,7 @@ def get_log_events(
start_from_head: bool | None = None,
continuation_token: ContinuationToken | None = None,
end_time: int | None = None,
-) -> Generator:
+) -> Generator[CloudWatchLogEvent, None, None]:
"""
Return a generator for log items in a single stream; yields all items available at the current moment.

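Aside (illustrative): with the `TypedDict` in place, callers get per-key types when iterating. A hedged usage sketch; the connection id, region, and log group/stream names are placeholders:

```python
from airflow.providers.amazon.aws.hooks.logs import AwsLogsHook

hook = AwsLogsHook(aws_conn_id="aws_default", region_name="us-east-1")  # placeholders
for event in hook.get_log_events(log_group="my-group", log_stream_name="my-stream"):
    ts: int = event["timestamp"]  # epoch milliseconds, per the TypedDict
    msg: str = event["message"]
```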
Contributor: Are the changes to this file backward compatible with Airflow 2.10? PRs that change both core and providers may hide compatibility issues.

Member (author): Yes, and I agree. I will test the CloudWatchHandler with Airflow 2.10 as well later on.
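
On the compatibility question, one relevant point: the annotation change is typing-only. A `TypedDict` is a plain `dict` at runtime, so the generator behaves identically on older Airflow versions. A quick demonstration:

```python
from typing import TypedDict


class CloudWatchLogEvent(TypedDict):
    timestamp: int
    message: str
    ingestionTime: int


event: CloudWatchLogEvent = {"timestamp": 0, "message": "m", "ingestionTime": 0}
assert type(event) is dict  # TypedDicts have no runtime class of their own
```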

providers/amazon/src/airflow/providers/amazon/aws/log/cloudwatch_task_handler.py
@@ -22,6 +22,7 @@
import json
import logging
import os
+from collections.abc import Generator
from datetime import date, datetime, timedelta, timezone
from functools import cached_property
from pathlib import Path
@@ -40,8 +41,15 @@
import structlog.typing

from airflow.models.taskinstance import TaskInstance
+from airflow.providers.amazon.aws.hooks.logs import CloudWatchLogEvent
from airflow.sdk.types import RuntimeTaskInstanceProtocol as RuntimeTI
-from airflow.utils.log.file_task_handler import LogMessages, LogSourceInfo
+from airflow.utils.log.file_task_handler import (
+    LegacyLogResponse,
+    LogMessages,
+    LogResponse,
+    LogSourceInfo,
+    RawLogStream,
+)


def json_serialize_legacy(value: Any) -> str | None:
@@ -163,20 +171,35 @@ def upload(self, path: os.PathLike | str, ti: RuntimeTI):
self.close()
return

-def read(self, relative_path, ti: RuntimeTI) -> tuple[LogSourceInfo, LogMessages | None]:
-    logs: LogMessages | None = []
+def read(self, relative_path: str, ti: RuntimeTI) -> LegacyLogResponse:
+    messages, logs = self.stream(relative_path, ti)
+    str_logs: list[str] = []
+
+    for group in logs:
+        for msg in group:
+            str_logs.append(f"{msg}\n")
+
+    return messages, str_logs
+
+def stream(self, relative_path: str, ti: RuntimeTI) -> LogResponse:
+    logs: list[RawLogStream] = []
messages = [
f"Reading remote log from Cloudwatch log_group: {self.log_group} log_stream: {relative_path}"
]
try:
-    logs = [self.get_cloudwatch_logs(relative_path, ti)]
+    gen: RawLogStream = (
+        self._parse_cloudwatch_log_event(event)
+        for event in self.get_cloudwatch_logs(relative_path, ti)
+    )
+    logs = [gen]
except Exception as e:
logs = None
messages.append(str(e))

return messages, logs

-def get_cloudwatch_logs(self, stream_name: str, task_instance: RuntimeTI):
+def get_cloudwatch_logs(
+    self, stream_name: str, task_instance: RuntimeTI
+) -> Generator[CloudWatchLogEvent, None, None]:
"""
Return all logs from the given log stream.

@@ -192,29 +215,22 @@ def get_cloudwatch_logs(self, stream_name: str, task_instance: RuntimeTI):
if (end_date := getattr(task_instance, "end_date", None)) is None
else datetime_to_epoch_utc_ms(end_date + timedelta(seconds=30))
)
-events = self.hook.get_log_events(
+return self.hook.get_log_events(
log_group=self.log_group,
log_stream_name=stream_name,
end_time=end_time,
)
-return "\n".join(self._event_to_str(event) for event in events)

-def _event_to_dict(self, event: dict) -> dict:
+def _parse_cloudwatch_log_event(self, event: CloudWatchLogEvent) -> str:
event_dt = datetime.fromtimestamp(event["timestamp"] / 1000.0, tz=timezone.utc).isoformat()
-message = event["message"]
+event_msg = event["message"]
try:
-message = json.loads(message)
+message = json.loads(event_msg)
message["timestamp"] = event_dt
-return message
except Exception:
-return {"timestamp": event_dt, "event": message}
+message = {"timestamp": event_dt, "event": event_msg}

-def _event_to_str(self, event: dict) -> str:
-    event_dt = datetime.fromtimestamp(event["timestamp"] / 1000.0, tz=timezone.utc)
-    # Format a datetime object to a string in Zulu time without milliseconds.
-    formatted_event_dt = event_dt.strftime("%Y-%m-%dT%H:%M:%SZ")
-    message = event["message"]
-    return f"[{formatted_event_dt}] {message}"
+return json.dumps(message)


class CloudwatchTaskHandler(FileTaskHandler, LoggingMixin):
@@ -291,4 +307,22 @@ def _read_remote_logs(
) -> tuple[LogSourceInfo, LogMessages]:
stream_name = self._render_filename(task_instance, try_number)
-messages, logs = self.io.read(stream_name, task_instance)
-return messages, logs or []
+
+messages = [
+    f"Reading remote log from Cloudwatch log_group: {self.io.log_group} log_stream: {stream_name}"
+]
+try:
+    events = self.io.get_cloudwatch_logs(stream_name, task_instance)
+    logs = ["\n".join(self._event_to_str(event) for event in events)]
+except Exception as e:
Member: Not sure whether it's possible to list all possible exceptions instead of using Exception.
+    logs = []
+    messages.append(str(e))
+
+return messages, logs

+def _event_to_str(self, event: CloudWatchLogEvent) -> str:
+    event_dt = datetime.fromtimestamp(event["timestamp"] / 1000.0, tz=timezone.utc)
+    # Format a datetime object to a string in Zulu time without milliseconds.
+    formatted_event_dt = event_dt.strftime("%Y-%m-%dT%H:%M:%SZ")
+    message = event["message"]
+    return f"[{formatted_event_dt}] {message}"
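
Aside (illustrative): what `_parse_cloudwatch_log_event` produces for the two message shapes, using the timestamp from the test below (1743112681002 ms is 2025-03-27T21:58:01.002 UTC). This is a standalone copy of the same logic, for demonstration only:

```python
import json
from datetime import datetime, timezone


def parse_event(event: dict) -> str:  # mirrors the PR's _parse_cloudwatch_log_event
    event_dt = datetime.fromtimestamp(event["timestamp"] / 1000.0, tz=timezone.utc).isoformat()
    event_msg = event["message"]
    try:
        # Structured message: inject the timestamp into the parsed payload.
        message = json.loads(event_msg)
        message["timestamp"] = event_dt
    except Exception:
        # Plain-text message: wrap it under an "event" key instead.
        message = {"timestamp": event_dt, "event": event_msg}
    return json.dumps(message)


parse_event({"timestamp": 1743112681002, "message": '{"foo": "bar", "event": "Hi"}'})
# -> '{"foo": "bar", "event": "Hi", "timestamp": "2025-03-27T21:58:01.002000+00:00"}'

parse_event({"timestamp": 1743112681002, "message": "plain text"})
# -> '{"timestamp": "2025-03-27T21:58:01.002000+00:00", "event": "plain text"}'
```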
providers/amazon/tests/unit/amazon/aws/log/test_cloudwatch_task_handler.py
@@ -159,23 +159,9 @@ def test_log_message(self):
assert metadata == [
f"Reading remote log from Cloudwatch log_group: log_group_name log_stream: {stream_name}"
]
-assert logs == ['[2025-03-27T21:58:01Z] {"foo": "bar", "event": "Hi", "level": "info"}']
-
-def test_event_to_str(self):
-    handler = self.subject
-    current_time = int(time.time()) * 1000
-    events = [
-        {"timestamp": current_time - 2000, "message": "First"},
-        {"timestamp": current_time - 1000, "message": "Second"},
-        {"timestamp": current_time, "message": "Third"},
-    ]
-    assert [handler._event_to_str(event) for event in events] == (
-        [
-            f"[{get_time_str(current_time - 2000)}] First",
-            f"[{get_time_str(current_time - 1000)}] Second",
-            f"[{get_time_str(current_time)}] Third",
-        ]
-    )
+assert logs == [
+    '{"foo": "bar", "event": "Hi", "level": "info", "timestamp": "2025-03-27T21:58:01.002000+00:00"}\n'
+]


@pytest.mark.db_test
Expand Down Expand Up @@ -426,6 +412,22 @@ def test_filename_template_for_backward_compatibility(self):
filename_template=None,
)

+def test_event_to_str(self):
+    handler = self.cloudwatch_task_handler
+    current_time = int(time.time()) * 1000
+    events = [
+        {"timestamp": current_time - 2000, "message": "First"},
+        {"timestamp": current_time - 1000, "message": "Second"},
+        {"timestamp": current_time, "message": "Third"},
+    ]
+    assert [handler._event_to_str(event) for event in events] == (
+        [
+            f"[{get_time_str(current_time - 2000)}] First",
+            f"[{get_time_str(current_time - 1000)}] Second",
+            f"[{get_time_str(current_time)}] Third",
+        ]
+    )


def generate_log_events(conn, log_group_name, log_stream_name, log_events):
conn.create_log_group(logGroupName=log_group_name)