diff --git a/airflow-core/src/airflow/api_fastapi/core_api/routes/public/log.py b/airflow-core/src/airflow/api_fastapi/core_api/routes/public/log.py index 277e705c18d2f..688a563d2b6d4 100644 --- a/airflow-core/src/airflow/api_fastapi/core_api/routes/public/log.py +++ b/airflow-core/src/airflow/api_fastapi/core_api/routes/public/log.py @@ -20,7 +20,8 @@ import contextlib import textwrap -from fastapi import Depends, HTTPException, Request, Response, status +from fastapi import Depends, HTTPException, Request, status +from fastapi.responses import StreamingResponse from itsdangerous import BadSignature, URLSafeSerializer from pydantic import NonNegativeInt, PositiveInt from sqlalchemy.orm import joinedload @@ -120,12 +121,17 @@ def get_log( ) ti = session.scalar(query) if ti is None: - query = select(TaskInstanceHistory).where( - TaskInstanceHistory.task_id == task_id, - TaskInstanceHistory.dag_id == dag_id, - TaskInstanceHistory.run_id == dag_run_id, - TaskInstanceHistory.map_index == map_index, - TaskInstanceHistory.try_number == try_number, + query = ( + select(TaskInstanceHistory) + .where( + TaskInstanceHistory.task_id == task_id, + TaskInstanceHistory.dag_id == dag_id, + TaskInstanceHistory.run_id == dag_run_id, + TaskInstanceHistory.map_index == map_index, + TaskInstanceHistory.try_number == try_number, + ) + .options(joinedload(TaskInstanceHistory.dag_run)) + # we need to joinedload the dag_run, since FileTaskHandler._render_filename needs ti.dag_run ) ti = session.scalar(query) @@ -138,24 +144,27 @@ def get_log( with contextlib.suppress(TaskNotFound): ti.task = dag.get_task(ti.task_id) - if accept == Mimetype.JSON or accept == Mimetype.ANY: # default - logs, metadata = task_log_reader.read_log_chunks(ti, try_number, metadata) - encoded_token = None + if accept == Mimetype.NDJSON: # only specified application/x-ndjson will return streaming response + # LogMetadata(TypedDict) is used as type annotation for log_reader; added ignore to suppress mypy error + log_stream = task_log_reader.read_log_stream(ti, try_number, metadata) # type: ignore[arg-type] + headers = None if not metadata.get("end_of_log", False): - encoded_token = URLSafeSerializer(request.app.state.secret_key).dumps(metadata) - return TaskInstancesLogResponse.model_construct(continuation_token=encoded_token, content=logs) - # text/plain, or something else we don't understand. Return raw log content - - # We need to exhaust the iterator before we can generate the continuation token. - # We could improve this by making it a streaming/async response, and by then setting the header using - # HTTP Trailers - logs = "".join(task_log_reader.read_log_stream(ti, try_number, metadata)) - headers = None - if not metadata.get("end_of_log", False): - headers = { - "Airflow-Continuation-Token": URLSafeSerializer(request.app.state.secret_key).dumps(metadata) - } - return Response(media_type="application/x-ndjson", content=logs, headers=headers) + headers = { + "Airflow-Continuation-Token": URLSafeSerializer(request.app.state.secret_key).dumps(metadata) + } + return StreamingResponse(media_type="application/x-ndjson", content=log_stream, headers=headers) + + # application/json, or something else we don't understand. + # Return JSON format, which will be more easily for users to debug. 
+ + # LogMetadata(TypedDict) is used as type annotation for log_reader; added ignore to suppress mypy error + structured_log_stream, out_metadata = task_log_reader.read_log_chunks(ti, try_number, metadata) # type: ignore[arg-type] + encoded_token = None + if not out_metadata.get("end_of_log", False): + encoded_token = URLSafeSerializer(request.app.state.secret_key).dumps(out_metadata) + return TaskInstancesLogResponse.model_construct( + continuation_token=encoded_token, content=list(structured_log_stream) + ) @task_instances_log_router.get( diff --git a/airflow-core/src/airflow/utils/log/file_task_handler.py b/airflow-core/src/airflow/utils/log/file_task_handler.py index 9b086457f604d..ae151512e436e 100644 --- a/airflow-core/src/airflow/utils/log/file_task_handler.py +++ b/airflow-core/src/airflow/utils/log/file_task_handler.py @@ -19,51 +19,114 @@ from __future__ import annotations -import itertools +import heapq +import io import logging import os -from collections.abc import Callable, Iterable +from collections.abc import Callable, Generator, Iterator from contextlib import suppress from datetime import datetime from enum import Enum +from itertools import chain, islice from pathlib import Path -from typing import TYPE_CHECKING, Any +from types import GeneratorType +from typing import IO, TYPE_CHECKING, TypedDict, cast from urllib.parse import urljoin import pendulum from pydantic import BaseModel, ConfigDict, ValidationError +from typing_extensions import NotRequired from airflow.configuration import conf from airflow.executors.executor_loader import ExecutorLoader from airflow.utils.helpers import parse_template_string, render_template +from airflow.utils.log.log_stream_accumulator import LogStreamAccumulator from airflow.utils.log.logging_mixin import SetContextPropagate from airflow.utils.log.non_caching_file_handler import NonCachingRotatingFileHandler from airflow.utils.session import NEW_SESSION, provide_session from airflow.utils.state import State, TaskInstanceState if TYPE_CHECKING: + from requests import Response + from airflow.executors.base_executor import BaseExecutor from airflow.models.taskinstance import TaskInstance from airflow.models.taskinstancehistory import TaskInstanceHistory from airflow.typing_compat import TypeAlias +CHUNK_SIZE = 1024 * 1024 * 5 # 5MB +DEFAULT_SORT_DATETIME = pendulum.datetime(2000, 1, 1) +DEFAULT_SORT_TIMESTAMP = int(DEFAULT_SORT_DATETIME.timestamp() * 1000) +SORT_KEY_OFFSET = 10000000 +"""An offset used by the _create_sort_key utility. + +Assuming 50 characters per line, an offset of 10,000,000 can represent approximately 500 MB of file data, which is sufficient for use as a constant. 
+""" +HEAP_DUMP_SIZE = 5000 +HALF_HEAP_DUMP_SIZE = HEAP_DUMP_SIZE // 2 # These types are similar, but have distinct names to make processing them less error prone -LogMessages: TypeAlias = list["StructuredLogMessage"] | list[str] -"""The log messages themselves, either in already sturcutured form, or a single string blob to be parsed later""" +LogMessages: TypeAlias = list[str] +"""The legacy format of log messages before 3.0.2""" LogSourceInfo: TypeAlias = list[str] """Information _about_ the log fetching process for display to a user""" -LogMetadata: TypeAlias = dict[str, Any] +RawLogStream: TypeAlias = Generator[str, None, None] +"""Raw log stream, containing unparsed log lines.""" +LegacyLogResponse: TypeAlias = tuple[LogSourceInfo, LogMessages] +"""Legacy log response, containing source information and log messages.""" +LogResponse: TypeAlias = tuple[LogSourceInfo, list[RawLogStream]] +LogResponseWithSize: TypeAlias = tuple[LogSourceInfo, list[RawLogStream], int] +"""Log response, containing source information, stream of log lines, and total log size.""" +StructuredLogStream: TypeAlias = Generator["StructuredLogMessage", None, None] +"""Structured log stream, containing structured log messages.""" +LogHandlerOutputStream: TypeAlias = ( + StructuredLogStream | Iterator["StructuredLogMessage"] | chain["StructuredLogMessage"] +) +"""Output stream, containing structured log messages or a chain of them.""" +ParsedLog: TypeAlias = tuple[datetime | None, int, "StructuredLogMessage"] +"""Parsed log record, containing timestamp, line_num and the structured log message.""" +ParsedLogStream: TypeAlias = Generator[ParsedLog, None, None] +LegacyProvidersLogType: TypeAlias = list["StructuredLogMessage"] | str | list[str] +"""Return type used by legacy `_read` methods for Alibaba Cloud, Elasticsearch, OpenSearch, and Redis log handlers. + +- For Elasticsearch and OpenSearch: returns either a list of structured log messages. +- For Alibaba Cloud: returns a string. +- For Redis: returns a list of strings. +""" + logger = logging.getLogger(__name__) +class LogMetadata(TypedDict): + """Metadata about the log fetching process, including `end_of_log` and `log_pos`.""" + + end_of_log: bool + log_pos: NotRequired[int] + # the following attributes are used for Elasticsearch and OpenSearch log handlers + offset: NotRequired[str | int] + # Ensure a string here. Large offset numbers will get JSON.parsed incorrectly + # on the client. Sending as a string prevents this issue. + # https://developer.mozilla.org/en-US/docs/Web/JavaScript/Reference/Global_Objects/Number/MAX_SAFE_INTEGER + last_log_timestamp: NotRequired[str] + max_offset: NotRequired[str] + + class StructuredLogMessage(BaseModel): """An individual log message.""" timestamp: datetime | None = None event: str + # Collisions of sort_key may occur due to duplicated messages. If this happens, the heap will use the second element, + # which is the StructuredLogMessage for comparison. Therefore, we need to define a comparator for it. + def __lt__(self, other: StructuredLogMessage) -> bool: + return self.sort_key < other.sort_key + + @property + def sort_key(self) -> datetime: + return self.timestamp or DEFAULT_SORT_DATETIME + # We don't need to cache string when parsing in to this, as almost every line will have a different # values; `extra=allow` means we'll create extra properties as needed. 
Only timestamp and event are # required, everything else is up to what ever is producing the logs @@ -100,7 +163,7 @@ def _set_task_deferred_context_var(): h.ctx_task_deferred = True -def _fetch_logs_from_service(url, log_relative_path): +def _fetch_logs_from_service(url: str, log_relative_path: str) -> Response: # Import occurs in function scope for perf. Ref: https://github.com/apache/airflow/pull/21438 import requests @@ -111,7 +174,6 @@ def _fetch_logs_from_service(url, log_relative_path): secret_key=get_signing_key("api", "secret_key"), # Since we are using a secret key, we need to be explicit about the algorithm here too algorithm="HS512", - private_key=None, issuer=None, valid_for=conf.getint("webserver", "log_request_clock_grace", fallback=30), audience="task-instance-logs", @@ -120,6 +182,7 @@ def _fetch_logs_from_service(url, log_relative_path): url, timeout=timeout, headers={"Authorization": generator.generate({"filename": log_relative_path})}, + stream=True, ) response.encoding = "utf-8" return response @@ -134,28 +197,68 @@ def _parse_timestamp(line: str): return pendulum.parse(timestamp_str.strip("[]")) -def _parse_log_lines( - lines: str | LogMessages, -) -> Iterable[tuple[datetime | None, int, StructuredLogMessage]]: +def _stream_lines_by_chunk( + log_io: IO[str], +) -> RawLogStream: + """ + Stream lines from a file-like IO object. + + :param log_io: A file-like IO object to read from. + :return: A generator that yields individual lines within the specified range. + """ + # Skip processing if file is already closed + if log_io.closed: + return + + # Seek to beginning if possible + if log_io.seekable(): + try: + log_io.seek(0) + except Exception as e: + logger.error("Error seeking in log stream: %s", e) + return + + buffer = "" + while True: + # Check if file is already closed + if log_io.closed: + break + + try: + chunk = log_io.read(CHUNK_SIZE) + except Exception as e: + logger.error("Error reading log stream: %s", e) + break + + if not chunk: + break + + buffer += chunk + *lines, buffer = buffer.split("\n") + yield from lines + + if buffer: + yield from buffer.split("\n") + + +def _log_stream_to_parsed_log_stream( + log_stream: RawLogStream, +) -> ParsedLogStream: + """ + Turn a str log stream into a generator of parsed log lines. + + :param log_stream: The stream to parse. + :return: A generator of parsed log lines. + """ from airflow.utils.timezone import coerce_datetime timestamp = None next_timestamp = None - if isinstance(lines, str): - lines = lines.splitlines() - if isinstance(lines, list) and len(lines) and isinstance(lines[0], str): - # A list of content from each location. 
It's a super odd format, but this is what we load - # [['a\nb\n'], ['c\nd\ne\n']] -> ['a', 'b', 'c', 'd', 'e'] - lines = itertools.chain.from_iterable(map(str.splitlines, lines)) # type: ignore[assignment,arg-type] - - # https://github.com/python/mypy/issues/8586 - for idx, line in enumerate[str | StructuredLogMessage](lines): + idx = 0 + for line in log_stream: if line: try: - if isinstance(line, StructuredLogMessage): - log = line - else: - log = StructuredLogMessage.model_validate_json(line) + log = StructuredLogMessage.model_validate_json(line) except ValidationError: with suppress(Exception): # If we can't parse the timestamp, don't attach one to the row @@ -166,17 +269,146 @@ def _parse_log_lines( log.timestamp = coerce_datetime(log.timestamp) timestamp = log.timestamp yield timestamp, idx, log + idx += 1 + + +def _create_sort_key(timestamp: datetime | None, line_num: int) -> int: + """ + Create a sort key for log record, to be used in K-way merge. + + :param timestamp: timestamp of the log line + :param line_num: line number of the log line + :return: a integer as sort key to avoid overhead of memory usage + """ + return int((timestamp or DEFAULT_SORT_DATETIME).timestamp() * 1000) * SORT_KEY_OFFSET + line_num + + +def _is_sort_key_with_default_timestamp(sort_key: int) -> bool: + """ + Check if the sort key was generated with the DEFAULT_SORT_TIMESTAMP. + This is used to identify log records that don't have timestamp. -def _interleave_logs(*logs: str | LogMessages) -> Iterable[StructuredLogMessage]: - min_date = pendulum.datetime(2000, 1, 1) + :param sort_key: The sort key to check + :return: True if the sort key was generated with DEFAULT_SORT_TIMESTAMP, False otherwise + """ + # Extract the timestamp part from the sort key (remove the line number part) + timestamp_part = sort_key // SORT_KEY_OFFSET + return timestamp_part == DEFAULT_SORT_TIMESTAMP + + +def _add_log_from_parsed_log_streams_to_heap( + heap: list[tuple[int, StructuredLogMessage]], + parsed_log_streams: dict[int, ParsedLogStream], +) -> None: + """ + Add one log record from each parsed log stream to the heap, and will remove empty log stream from the dict after iterating. + + :param heap: heap to store log records + :param parsed_log_streams: dict of parsed log streams + """ + # We intend to initialize the list lazily, as in most cases we don't need to remove any log streams. + # This reduces memory overhead, since this function is called repeatedly until all log streams are empty. + log_stream_to_remove: list[int] | None = None + for idx, log_stream in parsed_log_streams.items(): + record: ParsedLog | None = next(log_stream, None) + if record is None: + if log_stream_to_remove is None: + log_stream_to_remove = [] + log_stream_to_remove.append(idx) + continue + # add type hint to avoid mypy error + record = cast("ParsedLog", record) + timestamp, line_num, line = record + # take int as sort key to avoid overhead of memory usage + heapq.heappush(heap, (_create_sort_key(timestamp, line_num), line)) + # remove empty log stream from the dict + if log_stream_to_remove is not None: + for idx in log_stream_to_remove: + del parsed_log_streams[idx] + + +def _flush_logs_out_of_heap( + heap: list[tuple[int, StructuredLogMessage]], + flush_size: int, + last_log_container: list[StructuredLogMessage | None], +) -> Generator[StructuredLogMessage, None, None]: + """ + Flush logs out of the heap, deduplicating them based on the last log. 
+ + :param heap: heap to flush logs from + :param flush_size: number of logs to flush + :param last_log_container: a container to store the last log, to avoid duplicate logs + :return: a generator that yields deduplicated logs + """ + last_log = last_log_container[0] + for _ in range(flush_size): + sort_key, line = heapq.heappop(heap) + if line != last_log or _is_sort_key_with_default_timestamp(sort_key): # dedupe + yield line + last_log = line + # update the last log container with the last log + last_log_container[0] = last_log + + +def _interleave_logs(*log_streams: RawLogStream) -> StructuredLogStream: + """ + Merge parsed log streams using K-way merge. + + By yielding HALF_CHUNK_SIZE records when heap size exceeds CHUNK_SIZE, we can reduce the chance of messing up the global order. + Since there are multiple log streams, we can't guarantee that the records are in global order. + + e.g. + + log_stream1: ---------- + log_stream2: ---- + log_stream3: -------- + + The first record of log_stream3 is later than the fourth record of log_stream1 ! + :param parsed_log_streams: parsed log streams + :return: interleaved log stream + """ + # don't need to push whole tuple into heap, which increases too much overhead + # push only sort_key and line into heap + heap: list[tuple[int, StructuredLogMessage]] = [] + # to allow removing empty streams while iterating, also turn the str stream into parsed log stream + parsed_log_streams: dict[int, ParsedLogStream] = { + idx: _log_stream_to_parsed_log_stream(log_stream) for idx, log_stream in enumerate(log_streams) + } + + # keep adding records from logs until all logs are empty + last_log_container: list[StructuredLogMessage | None] = [None] + while parsed_log_streams: + _add_log_from_parsed_log_streams_to_heap(heap, parsed_log_streams) + + # yield HALF_HEAP_DUMP_SIZE records when heap size exceeds HEAP_DUMP_SIZE + if len(heap) >= HEAP_DUMP_SIZE: + yield from _flush_logs_out_of_heap(heap, HALF_HEAP_DUMP_SIZE, last_log_container) + + # yield remaining records + yield from _flush_logs_out_of_heap(heap, len(heap), last_log_container) + # free memory + del heap + del parsed_log_streams + + +def _is_logs_stream_like(log) -> bool: + """Check if the logs are stream-like.""" + return isinstance(log, (chain, GeneratorType)) + + +def _get_compatible_log_stream( + log_messages: LogMessages, +) -> RawLogStream: + """ + Convert legacy log message blobs into a generator that yields log lines. - records = itertools.chain.from_iterable(_parse_log_lines(log) for log in logs) - last = None - for timestamp, _, msg in sorted(records, key=lambda x: (x[0] or min_date, x[1])): - if msg != last or not timestamp: # dedupe - yield msg - last = msg + :param log_messages: List of legacy log message strings. + :return: A generator that yields interleaved log lines. + """ + yield from chain.from_iterable( + _stream_lines_by_chunk(io.StringIO(log_message)) for log_message in log_messages + ) class FileTaskHandler(logging.Handler): @@ -345,8 +577,8 @@ def _read( self, ti: TaskInstance | TaskInstanceHistory, try_number: int, - metadata: dict[str, Any] | None = None, - ): + metadata: LogMetadata | None = None, + ) -> tuple[LogHandlerOutputStream | LegacyProvidersLogType, LogMetadata]: """ Template method that contains custom logic of reading logs given the try_number. @@ -370,22 +602,38 @@ def _read( # initializing the handler. Thus explicitly getting log location # is needed to get correct log path. 
worker_log_rel_path = self._render_filename(ti, try_number) + sources: LogSourceInfo = [] source_list: list[str] = [] - remote_logs: LogMessages | None = [] - local_logs: list[str] = [] - sources: list[str] = [] - executor_logs: list[str] = [] - served_logs: LogMessages = [] + remote_logs: list[RawLogStream] = [] + local_logs: list[RawLogStream] = [] + executor_logs: list[RawLogStream] = [] + served_logs: list[RawLogStream] = [] with suppress(NotImplementedError): - sources, remote_logs = self._read_remote_logs(ti, try_number, metadata) - + sources, logs = self._read_remote_logs(ti, try_number, metadata) + if not logs: + remote_logs = [] + elif isinstance(logs, list) and isinstance(logs[0], str): + # If the logs are in legacy format, convert them to a generator of log lines + remote_logs = [ + # We don't need to use the log_pos here, as we are using the metadata to track the position + _get_compatible_log_stream(cast("list[str]", logs)) + ] + elif isinstance(logs, list) and _is_logs_stream_like(logs[0]): + # If the logs are already in a stream-like format, we can use them directly + remote_logs = cast("list[RawLogStream]", logs) + else: + # If the logs are in a different format, raise an error + raise TypeError("Logs should be either a list of strings or a generator of log lines.") + # Extend LogSourceInfo source_list.extend(sources) has_k8s_exec_pod = False if ti.state == TaskInstanceState.RUNNING: executor_get_task_log = self._get_executor_get_task_log(ti) response = executor_get_task_log(ti, try_number) if response: - sources, executor_logs = response + sources, logs = response + # make the logs stream-like compatible + executor_logs = [_get_compatible_log_stream(logs)] if sources: source_list.extend(sources) has_k8s_exec_pod = True @@ -404,15 +652,13 @@ def _read( sources, served_logs = self._read_from_logs_server(ti, worker_log_rel_path) source_list.extend(sources) - logs = list( - _interleave_logs( - *local_logs, - (remote_logs or []), - *(executor_logs or []), - *served_logs, - ) + out_stream: LogHandlerOutputStream = _interleave_logs( + *local_logs, + *remote_logs, + *executor_logs, + *served_logs, ) - log_pos = len(logs) + # Log message source details are grouped: they are not relevant for most users and can # distract them from finding the root cause of their errors header = [ @@ -423,12 +669,22 @@ def _read( TaskInstanceState.RUNNING, TaskInstanceState.DEFERRED, ) - if metadata and "log_pos" in metadata: - previous_line = metadata["log_pos"] - logs = logs[previous_line:] # Cut off previously passed log test as new tail - else: - logs = header + logs - return logs, {"end_of_log": end_of_log, "log_pos": log_pos} + + with LogStreamAccumulator(out_stream, HEAP_DUMP_SIZE) as stream_accumulator: + log_pos = stream_accumulator.total_lines + out_stream = stream_accumulator.stream + + # skip log stream until the last position + if metadata and "log_pos" in metadata: + islice(out_stream, metadata["log_pos"]) + else: + # first time reading log, add messages before interleaved log stream + out_stream = chain(header, out_stream) + + return out_stream, { + "end_of_log": end_of_log, + "log_pos": log_pos, + } @staticmethod @staticmethod @@ -469,8 +725,8 @@ def read( self, task_instance: TaskInstance | TaskInstanceHistory, try_number: int | None = None, - metadata: dict[str, Any] | None = None, - ) -> tuple[list[StructuredLogMessage] | str, dict[str, Any]]: + metadata: LogMetadata | None = None, + ) -> tuple[LogHandlerOutputStream, LogMetadata]: """ Read logs of given task instance from local 
machine. @@ -489,7 +745,7 @@ def read( event="Task was skipped, no logs available." ) ] - return logs, {"end_of_log": True} + return chain(logs), {"end_of_log": True} if try_number is None or try_number < 1: logs = [ @@ -497,9 +753,38 @@ def read( level="error", event=f"Error fetching the logs. Try number {try_number} is invalid." ) ] - return logs, {"end_of_log": True} - - return self._read(task_instance, try_number, metadata) + return chain(logs), {"end_of_log": True} + + # compatibility for es_task_handler and os_task_handler + read_result = self._read(task_instance, try_number, metadata) + out_stream, metadata = read_result + # If the out_stream is None or empty, return the read result + if not out_stream: + out_stream = cast("Generator[StructuredLogMessage, None, None]", out_stream) + return out_stream, metadata + + if _is_logs_stream_like(out_stream): + out_stream = cast("Generator[StructuredLogMessage, None, None]", out_stream) + return out_stream, metadata + if isinstance(out_stream, list) and isinstance(out_stream[0], StructuredLogMessage): + out_stream = cast("list[StructuredLogMessage]", out_stream) + return (log for log in out_stream), metadata + if isinstance(out_stream, list) and isinstance(out_stream[0], str): + # If the out_stream is a list of strings, convert it to a generator + out_stream = cast("list[str]", out_stream) + raw_stream = _stream_lines_by_chunk(io.StringIO("".join(out_stream))) + out_stream = (log for _, _, log in _log_stream_to_parsed_log_stream(raw_stream)) + return out_stream, metadata + if isinstance(out_stream, str): + # If the out_stream is a string, convert it to a generator + raw_stream = _stream_lines_by_chunk(io.StringIO(out_stream)) + out_stream = (log for _, _, log in _log_stream_to_parsed_log_stream(raw_stream)) + return out_stream, metadata + raise TypeError( + "Invalid log stream type. Expected a generator of StructuredLogMessage, list of StructuredLogMessage, list of str or str." + f" Got {type(out_stream).__name__} instead." 
+ f" Content type: {type(out_stream[0]).__name__ if isinstance(out_stream, (list, tuple)) and out_stream else 'empty'}" + ) @staticmethod def _prepare_log_folder(directory: Path, new_folder_permissions: int): @@ -565,15 +850,28 @@ def _init_file(self, ti, *, identifier: str | None = None): return full_path @staticmethod - def _read_from_local(worker_log_path: Path) -> tuple[list[str], list[str]]: + def _read_from_local( + worker_log_path: Path, + ) -> LogResponse: + sources: LogSourceInfo = [] + log_streams: list[RawLogStream] = [] paths = sorted(worker_log_path.parent.glob(worker_log_path.name + "*")) - sources = [os.fspath(x) for x in paths] - logs = [file.read_text() for file in paths] - return sources, logs + if not paths: + return sources, log_streams - def _read_from_logs_server(self, ti, worker_log_rel_path) -> tuple[LogSourceInfo, LogMessages]: - sources = [] - logs = [] + for path in paths: + sources.append(os.fspath(path)) + # Read the log file and yield lines + log_streams.append(_stream_lines_by_chunk(open(path, encoding="utf-8"))) + return sources, log_streams + + def _read_from_logs_server( + self, + ti: TaskInstance, + worker_log_rel_path: str, + ) -> LogResponse: + sources: LogSourceInfo = [] + log_streams: list[RawLogStream] = [] try: log_type = LogType.TRIGGER if ti.triggerer_job else LogType.WORKER url, rel_path = self._get_log_retrieval_url(ti, worker_log_rel_path, log_type=log_type) @@ -590,20 +888,26 @@ def _read_from_logs_server(self, ti, worker_log_rel_path) -> tuple[LogSourceInfo else: # Check if the resource was properly fetched response.raise_for_status() - if response.text: + if int(response.headers.get("Content-Length", 0)) > 0: sources.append(url) - logs.append(response.text) + log_streams.append( + _stream_lines_by_chunk(io.TextIOWrapper(cast("IO[bytes]", response.raw))) + ) except Exception as e: from requests.exceptions import InvalidURL - if isinstance(e, InvalidURL) and ti.task.inherits_from_empty_operator is True: + if ( + isinstance(e, InvalidURL) + and ti.task is not None + and ti.task.inherits_from_empty_operator is True + ): sources.append(self.inherits_from_empty_operator_log_message) else: sources.append(f"Could not read served logs: {e}") logger.exception("Could not read served logs") - return sources, logs + return sources, log_streams - def _read_remote_logs(self, ti, try_number, metadata=None) -> tuple[LogSourceInfo, LogMessages]: + def _read_remote_logs(self, ti, try_number, metadata=None) -> LegacyLogResponse | LogResponse: """ Implement in subclasses to read from the remote service. 
diff --git a/airflow-core/src/airflow/utils/log/log_reader.py b/airflow-core/src/airflow/utils/log/log_reader.py index 0bb61c52dbc57..9f61c2f730c36 100644 --- a/airflow-core/src/airflow/utils/log/log_reader.py +++ b/airflow-core/src/airflow/utils/log/log_reader.py @@ -18,13 +18,12 @@ import logging import time -from collections.abc import Iterator +from collections.abc import Generator, Iterator from functools import cached_property -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING from airflow.configuration import conf from airflow.utils.helpers import render_log_filename -from airflow.utils.log.file_task_handler import StructuredLogMessage from airflow.utils.log.logging_mixin import ExternalLoggingMixin from airflow.utils.session import NEW_SESSION, provide_session from airflow.utils.state import TaskInstanceState @@ -35,9 +34,11 @@ from airflow.models.taskinstance import TaskInstance from airflow.models.taskinstancehistory import TaskInstanceHistory from airflow.typing_compat import TypeAlias + from airflow.utils.log.file_task_handler import LogHandlerOutputStream, LogMetadata -LogMessages: TypeAlias = list[StructuredLogMessage] | str -LogMetadata: TypeAlias = dict[str, Any] +LogReaderOutputStream: TypeAlias = Generator[str, None, None] + +READ_BATCH_SIZE = 1024 class TaskLogReader: @@ -54,7 +55,7 @@ def read_log_chunks( ti: TaskInstance | TaskInstanceHistory, try_number: int | None, metadata: LogMetadata, - ) -> tuple[LogMessages, LogMetadata]: + ) -> tuple[LogHandlerOutputStream, LogMetadata]: """ Read chunks of Task Instance logs. @@ -92,24 +93,19 @@ def read_log_stream( try_number = ti.try_number for key in ("end_of_log", "max_offset", "offset", "log_pos"): - metadata.pop(key, None) + # https://mypy.readthedocs.io/en/stable/typed_dict.html#supported-operations + metadata.pop(key, None) # type: ignore[misc] empty_iterations = 0 while True: - logs, out_metadata = self.read_log_chunks(ti, try_number, metadata) - # Update the metadata dict in place so caller can get new values/end-of-log etc. - - for log in logs: - # It's a bit wasteful here to parse the JSON then dump it back again. - # Optimize this so in stream mode we can just pass logs right through, or even better add - # support to 307 redirect to a signed URL etc. - yield (log if isinstance(log, str) else log.model_dump_json()) + "\n" + log_stream, out_metadata = self.read_log_chunks(ti, try_number, metadata) + yield from (f"{log.model_dump_json()}\n" for log in log_stream) if not out_metadata.get("end_of_log", False) and ti.state not in ( TaskInstanceState.RUNNING, TaskInstanceState.DEFERRED, ): - if logs: + if log_stream: empty_iterations = 0 else: # we did not receive any logs in this loop @@ -121,7 +117,8 @@ def read_log_stream( yield "(Log stream stopped - End of log marker not found; logs may be incomplete.)\n" return else: - metadata.clear() + # https://mypy.readthedocs.io/en/stable/typed_dict.html#supported-operations + metadata.clear() # type: ignore[attr-defined] metadata.update(out_metadata) return diff --git a/airflow-core/src/airflow/utils/log/log_stream_accumulator.py b/airflow-core/src/airflow/utils/log/log_stream_accumulator.py new file mode 100644 index 0000000000000..953b47dd9719b --- /dev/null +++ b/airflow-core/src/airflow/utils/log/log_stream_accumulator.py @@ -0,0 +1,155 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +import os +import tempfile +from itertools import islice +from typing import IO, TYPE_CHECKING + +if TYPE_CHECKING: + from airflow.typing_compat import Self + from airflow.utils.log.file_task_handler import ( + LogHandlerOutputStream, + StructuredLogMessage, + StructuredLogStream, + ) + + +class LogStreamAccumulator: + """ + Memory-efficient log stream accumulator that tracks the total number of lines while preserving the original stream. + + This class captures logs from a stream and stores them in a buffer, flushing them to disk when the buffer + exceeds a specified threshold. This approach optimizes memory usage while handling large log streams. + + Usage: + + .. code-block:: python + + with LogStreamAccumulator(stream, threshold) as log_accumulator: + # Get total number of lines captured + total_lines = log_accumulator.get_total_lines() + + # Retrieve the original stream of logs + for log in log_accumulator.get_stream(): + print(log) + """ + + def __init__( + self, + stream: LogHandlerOutputStream, + threshold: int, + ) -> None: + """ + Initialize the LogStreamAccumulator. + + Args: + stream: The input log stream to capture and count. + threshold: Maximum number of lines to keep in memory before flushing to disk. + """ + self._stream = stream + self._threshold = threshold + self._buffer: list[StructuredLogMessage] = [] + self._disk_lines: int = 0 + self._tmpfile: IO[str] | None = None + + def _flush_buffer_to_disk(self) -> None: + """Flush the buffer contents to a temporary file on disk.""" + if self._tmpfile is None: + self._tmpfile = tempfile.NamedTemporaryFile(delete=False, mode="w+", encoding="utf-8") + + self._disk_lines += len(self._buffer) + self._tmpfile.writelines(f"{log.model_dump_json()}\n" for log in self._buffer) + self._tmpfile.flush() + self._buffer.clear() + + def _capture(self) -> None: + """Capture logs from the stream into the buffer, flushing to disk when threshold is reached.""" + while True: + # `islice` will try to get up to `self._threshold` lines from the stream. + self._buffer.extend(islice(self._stream, self._threshold)) + # If there are no more lines to capture, exit the loop. + if len(self._buffer) < self._threshold: + break + self._flush_buffer_to_disk() + + def _cleanup(self) -> None: + """Clean up the temporary file if it exists.""" + self._buffer.clear() + if self._tmpfile: + self._tmpfile.close() + os.remove(self._tmpfile.name) + self._tmpfile = None + + @property + def total_lines(self) -> int: + """ + Return the total number of lines captured from the stream. + + Returns: + The sum of lines stored in the buffer and lines written to disk. + """ + return self._disk_lines + len(self._buffer) + + @property + def stream(self) -> StructuredLogStream: + """ + Return the original stream of logs and clean up resources. 
+ + Important: This method automatically cleans up resources after all logs have been yielded. + Make sure to fully consume the returned generator to ensure proper cleanup. + + Returns: + A stream of the captured log messages. + """ + try: + if not self._tmpfile: + # if no temporary file was created, return from the buffer + yield from self._buffer + else: + # avoid circular import + from airflow.utils.log.file_task_handler import StructuredLogMessage + + with open(self._tmpfile.name, encoding="utf-8") as f: + yield from (StructuredLogMessage.model_validate_json(line.strip()) for line in f) + # yield the remaining buffer + yield from self._buffer + finally: + # Ensure cleanup after yielding + self._cleanup() + + def __enter__(self) -> Self: + """ + Context manager entry point that initiates log capture. + + Returns: + Self instance for use in context manager. + """ + self._capture() + return self + + def __exit__(self, exc_type, exc_val, exc_tb) -> None: + """ + Context manager exit that doesn't perform resource cleanup. + + Note: Resources are not cleaned up here. Cleanup is deferred until + get_stream() is called and fully consumed, ensuring all logs are properly + yielded before cleanup occurs. + """ diff --git a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_log.py b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_log.py index 64dedcfd328e2..adb5aebcb2106 100644 --- a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_log.py +++ b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_log.py @@ -18,6 +18,7 @@ from __future__ import annotations import copy +import json import logging.config import sys from unittest import mock @@ -36,6 +37,7 @@ from airflow.utils.types import DagRunType from tests_common.test_utils.db import clear_db_runs +from tests_common.test_utils.file_task_handler import convert_list_to_stream pytestmark = [pytest.mark.db_test, pytest.mark.need_serialized_dag] @@ -233,6 +235,12 @@ def test_should_respond_200_ndjson(self, request_url, expected_filename, extra_q assert expected_filename in resp_content assert log_content in resp_content + # check content is in ndjson format + for line in resp_content.splitlines(): + log = json.loads(line) + assert "event" in log + assert "timestamp" in log + @pytest.mark.parametrize( "request_url, expected_filename, extra_query_string, try_number", [ @@ -304,11 +312,22 @@ def test_get_logs_response_with_ti_equal_to_none(self, try_number): @pytest.mark.parametrize("try_number", [1, 2]) def test_get_logs_with_metadata_as_download_large_file(self, try_number): + from airflow.utils.log.file_task_handler import StructuredLogMessage + with mock.patch("airflow.utils.log.file_task_handler.FileTaskHandler.read") as read_mock: - first_return = (["", "1st line"], {}) - second_return = (["", "2nd line"], {"end_of_log": False}) - third_return = (["", "3rd line"], {"end_of_log": True}) - fourth_return = (["", "should never be read"], {"end_of_log": True}) + first_return = (convert_list_to_stream([StructuredLogMessage(event="", message="1st line")]), {}) + second_return = ( + convert_list_to_stream([StructuredLogMessage(event="", message="2nd line")]), + {"end_of_log": False}, + ) + third_return = ( + convert_list_to_stream([StructuredLogMessage(event="", message="3rd line")]), + {"end_of_log": True}, + ) + fourth_return = ( + convert_list_to_stream([StructuredLogMessage(event="", message="should never be read")]), + {"end_of_log": True}, + ) read_mock.side_effect = [first_return, 
second_return, third_return, fourth_return] response = self.client.get( diff --git a/airflow-core/tests/unit/utils/log/test_log_reader.py b/airflow-core/tests/unit/utils/log/test_log_reader.py index 5ca675f8d08c4..cd8b430090a27 100644 --- a/airflow-core/tests/unit/utils/log/test_log_reader.py +++ b/airflow-core/tests/unit/utils/log/test_log_reader.py @@ -41,6 +41,7 @@ from tests_common.test_utils.config import conf_vars from tests_common.test_utils.db import clear_db_dags, clear_db_runs +from tests_common.test_utils.file_task_handler import convert_list_to_stream pytestmark = pytest.mark.db_test @@ -127,6 +128,7 @@ def test_test_read_log_chunks_should_read_one_try(self): ti.state = TaskInstanceState.SUCCESS logs, metadata = task_log_reader.read_log_chunks(ti=ti, try_number=1, metadata={}) + logs = list(logs) assert logs[0].event == "::group::Log message source details" assert logs[0].sources == [ f"{self.log_dir}/dag_log_reader/task_log_reader/2017-09-01T00.00.00+00.00/1.log" @@ -141,6 +143,7 @@ def test_test_read_log_chunks_should_read_latest_files(self): ti.state = TaskInstanceState.SUCCESS logs, metadata = task_log_reader.read_log_chunks(ti=ti, try_number=None, metadata={}) + logs = list(logs) assert logs[0].event == "::group::Log message source details" assert logs[0].sources == [ f"{self.log_dir}/dag_log_reader/task_log_reader/2017-09-01T00.00.00+00.00/3.log" @@ -180,16 +183,31 @@ def test_test_test_read_log_stream_should_read_latest_logs(self): @mock.patch("airflow.utils.log.file_task_handler.FileTaskHandler.read") def test_read_log_stream_should_support_multiple_chunks(self, mock_read): - first_return = (["1st line"], {}) - second_return = (["2nd line"], {"end_of_log": False}) - third_return = (["3rd line"], {"end_of_log": True}) - fourth_return = (["should never be read"], {"end_of_log": True}) + from airflow.utils.log.file_task_handler import StructuredLogMessage + + first_return = (convert_list_to_stream([StructuredLogMessage(event="1st line")]), {}) + second_return = ( + convert_list_to_stream([StructuredLogMessage(event="2nd line")]), + {"end_of_log": False}, + ) + third_return = ( + convert_list_to_stream([StructuredLogMessage(event="3rd line")]), + {"end_of_log": True}, + ) + fourth_return = ( + convert_list_to_stream([StructuredLogMessage(event="should never be read")]), + {"end_of_log": True}, + ) mock_read.side_effect = [first_return, second_return, third_return, fourth_return] task_log_reader = TaskLogReader() self.ti.state = TaskInstanceState.SUCCESS log_stream = task_log_reader.read_log_stream(ti=self.ti, try_number=1, metadata={}) - assert list(log_stream) == ["1st line\n", "2nd line\n", "3rd line\n"] + assert list(log_stream) == [ + '{"timestamp":null,"event":"1st line"}\n', + '{"timestamp":null,"event":"2nd line"}\n', + '{"timestamp":null,"event":"3rd line"}\n', + ] # as the metadata is now updated in place, when the latest run update metadata. 
# the metadata stored in the mock_read will also be updated @@ -205,11 +223,18 @@ def test_read_log_stream_should_support_multiple_chunks(self, mock_read): @mock.patch("airflow.utils.log.file_task_handler.FileTaskHandler.read") def test_read_log_stream_should_read_each_try_in_turn(self, mock_read): - mock_read.side_effect = [(["try_number=3."], {"end_of_log": True})] + from airflow.utils.log.file_task_handler import StructuredLogMessage + + mock_read.side_effect = [ + ( + convert_list_to_stream([StructuredLogMessage(event="try_number=3.")]), + {"end_of_log": True}, + ) + ] task_log_reader = TaskLogReader() log_stream = task_log_reader.read_log_stream(ti=self.ti, try_number=None, metadata={}) - assert list(log_stream) == ["try_number=3.\n"] + assert list(log_stream) == ['{"timestamp":null,"event":"try_number=3."}\n'] mock_read.assert_has_calls( [ @@ -220,8 +245,10 @@ def test_read_log_stream_should_read_each_try_in_turn(self, mock_read): @mock.patch("airflow.utils.log.file_task_handler.FileTaskHandler.read") def test_read_log_stream_no_end_of_log_marker(self, mock_read): + from airflow.utils.log.file_task_handler import StructuredLogMessage + mock_read.side_effect = [ - (["hello"], {"end_of_log": False}), + ([StructuredLogMessage(event="hello")], {"end_of_log": False}), *[([], {"end_of_log": False}) for _ in range(10)], ] @@ -230,7 +257,7 @@ def test_read_log_stream_no_end_of_log_marker(self, mock_read): task_log_reader.STREAM_LOOP_SLEEP_SECONDS = 0.001 # to speed up the test log_stream = task_log_reader.read_log_stream(ti=self.ti, try_number=1, metadata={}) assert list(log_stream) == [ - "hello\n", + '{"timestamp":null,"event":"hello"}\n', "(Log stream stopped - End of log marker not found; logs may be incomplete.)\n", ] assert mock_read.call_count == 11 diff --git a/airflow-core/tests/unit/utils/log/test_stream_accumulator.py b/airflow-core/tests/unit/utils/log/test_stream_accumulator.py new file mode 100644 index 0000000000000..fd2856d851e4e --- /dev/null +++ b/airflow-core/tests/unit/utils/log/test_stream_accumulator.py @@ -0,0 +1,165 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +from __future__ import annotations + +import os +from typing import TYPE_CHECKING +from unittest import mock + +import pendulum +import pytest + +from airflow.utils.log.file_task_handler import StructuredLogMessage +from airflow.utils.log.log_stream_accumulator import LogStreamAccumulator + +if TYPE_CHECKING: + from airflow.utils.log.file_task_handler import LogHandlerOutputStream + +LOG_START_DATETIME = pendulum.datetime(2023, 10, 1, 0, 0, 0) +LOG_COUNT = 20 + + +class TestLogStreamAccumulator: + """Test cases for the LogStreamAccumulator class.""" + + @pytest.fixture + def structured_logs(self): + """Create a stream of mock structured log messages.""" + + def generate_logs(): + yield from ( + StructuredLogMessage( + event=f"test_event_{i + 1}", + timestamp=LOG_START_DATETIME.add(seconds=i), + level="INFO", + message=f"Test log message {i + 1}", + ) + for i in range(LOG_COUNT) + ) + + return generate_logs() + + def validate_log_stream(self, log_stream: LogHandlerOutputStream): + """Validate the log stream by checking the number of lines.""" + + count = 0 + for i, log in enumerate(log_stream): + assert log.event == f"test_event_{i + 1}" + assert log.timestamp == LOG_START_DATETIME.add(seconds=i) + count += 1 + assert count == 20 + + def test__capture(self, structured_logs): + """Test that temporary file is properly cleaned up during get_stream, not when exiting context.""" + + accumulator = LogStreamAccumulator(structured_logs, 5) + with ( + mock.patch.object(accumulator, "_capture") as mock_setup, + ): + with accumulator: + mock_setup.assert_called_once() + + def test__flush_buffer_to_disk(self, structured_logs): + """Test flush-to-disk behavior with a small threshold.""" + threshold = 6 + + # Mock the temporary file to verify it's being written to + with ( + mock.patch("tempfile.NamedTemporaryFile") as mock_tmpfile, + ): + mock_file = mock.MagicMock() + mock_tmpfile.return_value = mock_file + + with LogStreamAccumulator(structured_logs, threshold) as accumulator: + mock_tmpfile.assert_called_once_with( + delete=False, + mode="w+", + encoding="utf-8", + ) + # Verify _flush_buffer_to_disk was called multiple times + # (20 logs / 6 threshold = 3 flushes + 2 remaining logs in buffer) + assert accumulator._disk_lines == 18 + assert mock_file.writelines.call_count == 3 + assert len(accumulator._buffer) == 2 + + @pytest.mark.parametrize( + "threshold", + [ + pytest.param(30, id="buffer_only"), + pytest.param(5, id="flush_to_disk"), + ], + ) + def test_get_stream(self, structured_logs, threshold): + """Test that stream property returns all logs regardless of whether they were flushed to disk.""" + + tmpfile_name = None + with LogStreamAccumulator(structured_logs, threshold) as accumulator: + out_stream = accumulator.stream + + # Check if the temporary file was created + if threshold < LOG_COUNT: + tmpfile_name = accumulator._tmpfile.name + assert os.path.exists(tmpfile_name) + else: + assert accumulator._tmpfile is None + + # Validate the log stream + self.validate_log_stream(out_stream) + + # Verify temp file was created and cleaned up + if threshold < LOG_COUNT: + assert accumulator._tmpfile is None + assert not os.path.exists(tmpfile_name) if tmpfile_name else True + + @pytest.mark.parametrize( + "threshold, expected_buffer_size, expected_disk_lines", + [ + pytest.param(30, 20, 0, id="no_flush_needed"), + pytest.param(10, 0, 20, id="single_flush_needed"), + pytest.param(3, 2, 18, id="multiple_flushes_needed"), + ], + ) + def test_total_lines(self, structured_logs, threshold, 
expected_buffer_size, expected_disk_lines): + """Test that LogStreamAccumulator correctly counts lines across buffer and disk.""" + + with LogStreamAccumulator(structured_logs, threshold) as accumulator: + # Check buffer and disk line counts + assert len(accumulator._buffer) == expected_buffer_size + assert accumulator._disk_lines == expected_disk_lines + # Validate the log stream and line counts + self.validate_log_stream(accumulator.stream) + + def test__cleanup(self, structured_logs): + """Test that cleanup happens when stream property is fully consumed, not on context exit.""" + + accumulator = LogStreamAccumulator(structured_logs, 5) + with mock.patch.object(accumulator, "_cleanup") as mock_cleanup: + with accumulator: + # _cleanup should not be called yet + mock_cleanup.assert_not_called() + + # Get the stream but don't iterate through it yet + stream = accumulator.stream + mock_cleanup.assert_not_called() + + # Now iterate through the stream + for _ in stream: + pass + + # After fully consuming the stream, cleanup should be called + mock_cleanup.assert_called_once() diff --git a/airflow-core/tests/unit/utils/test_log_handlers.py b/airflow-core/tests/unit/utils/test_log_handlers.py index 8367aed97ce67..9c32e9b232aa6 100644 --- a/airflow-core/tests/unit/utils/test_log_handlers.py +++ b/airflow-core/tests/unit/utils/test_log_handlers.py @@ -17,15 +17,17 @@ # under the License. from __future__ import annotations +import heapq +import io import itertools import logging import logging.config import os import re -from collections.abc import Iterable from http import HTTPStatus from importlib import reload from pathlib import Path +from typing import cast from unittest import mock from unittest.mock import patch @@ -47,43 +49,42 @@ from airflow.models.trigger import Trigger from airflow.providers.standard.operators.python import PythonOperator from airflow.utils.log.file_task_handler import ( + DEFAULT_SORT_DATETIME, FileTaskHandler, LogType, + ParsedLogStream, StructuredLogMessage, + _add_log_from_parsed_log_streams_to_heap, + _create_sort_key, _fetch_logs_from_service, + _flush_logs_out_of_heap, _interleave_logs, - _parse_log_lines, + _is_logs_stream_like, + _is_sort_key_with_default_timestamp, + _log_stream_to_parsed_log_stream, + _stream_lines_by_chunk, ) from airflow.utils.log.logging_mixin import set_context from airflow.utils.net import get_hostname from airflow.utils.session import create_session from airflow.utils.state import State, TaskInstanceState -from airflow.utils.timezone import datetime from airflow.utils.types import DagRunType from tests_common.test_utils.config import conf_vars +from tests_common.test_utils.file_task_handler import ( + convert_list_to_stream, + extract_events, + mock_parsed_logs_factory, +) from tests_common.test_utils.markers import skip_if_force_lowest_dependencies_marker pytestmark = pytest.mark.db_test -DEFAULT_DATE = datetime(2016, 1, 1) +DEFAULT_DATE = pendulum.datetime(2016, 1, 1) TASK_LOGGER = "airflow.task" FILE_TASK_HANDLER = "task" -def events(logs: Iterable[StructuredLogMessage], skip_source_info=True) -> list[str]: - """Helper function to return just the event (a.k.a message) from a list of StructuredLogMessage""" - logs = iter(logs) - if skip_source_info: - - def is_source_group(log: StructuredLogMessage): - return not hasattr(log, "timestamp") or log.event == "::endgroup" - - logs = itertools.dropwhile(is_source_group, logs) - - return [s.event for s in logs] - - class TestFileTaskLogHandler: def clean_up(self): with create_session() as 
session: @@ -146,9 +147,11 @@ def task_callable(ti): assert hasattr(file_handler, "read") # Return value of read must be a tuple of list and list. # passing invalid `try_number` to read function - log, metadata = file_handler.read(ti, 0) + log_handler_output_stream, metadata = file_handler.read(ti, 0) assert isinstance(metadata, dict) - assert log[0].event == "Error fetching the logs. Try number 0 is invalid." + assert extract_events(log_handler_output_stream) == [ + "Error fetching the logs. Try number 0 is invalid." + ] # Remove the generated tmp log file. os.remove(log_filename) @@ -185,7 +188,8 @@ def task_callable(ti): assert log_filename.endswith("0.log"), log_filename # Return value of read must be a tuple of list and list. - logs, metadata = file_handler.read(ti) + log_handler_output_stream, metadata = file_handler.read(ti) + logs = list(log_handler_output_stream) assert logs[0].event == "Task was skipped, no logs available." # Remove the generated tmp log file. @@ -228,14 +232,16 @@ def task_callable(ti): file_handler.close() assert hasattr(file_handler, "read") - log, metadata = file_handler.read(ti, 1) + log_handler_output_stream, metadata = file_handler.read(ti, 1) assert isinstance(metadata, dict) target_re = re.compile(r"\A\[[^\]]+\] {test_log_handlers.py:\d+} INFO - test\Z") # We should expect our log line from the callable above to appear in # the logs we read back - assert any(re.search(target_re, e) for e in events(log)), "Logs were " + str(log) + assert any(re.search(target_re, e) for e in extract_events(log_handler_output_stream)), ( + f"Logs were {log_handler_output_stream}" + ) # Remove the generated tmp log file. os.remove(log_filename) @@ -356,8 +362,8 @@ def task_callable(ti): logger.info("Test") # Return value of read must be a tuple of list and list. - logs, metadata = file_handler.read(ti) - assert isinstance(logs, list) + log_handler_output_stream, metadata = file_handler.read(ti) + assert _is_logs_stream_like(log_handler_output_stream) # Logs for running tasks should show up too. assert isinstance(metadata, dict) @@ -435,7 +441,7 @@ def task_callable(ti): assert find_rotate_log_1 is True # Logs for running tasks should show up too. - assert isinstance(logs, list) + assert _is_logs_stream_like(logs) # Remove the two generated tmp log files. 
os.remove(log_filename) @@ -450,7 +456,7 @@ def test__read_when_local(self, mock_read_local, create_task_instance): path = Path( "dag_id=dag_for_testing_local_log_read/run_id=scheduled__2016-01-01T00:00:00+00:00/task_id=task_for_testing_local_log_read/attempt=1.log" ) - mock_read_local.return_value = (["the messages"], ["the log"]) + mock_read_local.return_value = (["the messages"], [convert_list_to_stream(["the log"])]) local_log_file_read = create_task_instance( dag_id="dag_for_testing_local_log_read", task_id="task_for_testing_local_log_read", @@ -458,24 +464,23 @@ def test__read_when_local(self, mock_read_local, create_task_instance): logical_date=DEFAULT_DATE, ) fth = FileTaskHandler("") - logs, metadata = fth._read(ti=local_log_file_read, try_number=1) + log_handler_output_stream, metadata = fth._read(ti=local_log_file_read, try_number=1) mock_read_local.assert_called_with(path) - as_text = events(logs) - assert logs[0].sources == ["the messages"] - assert as_text[-1] == "the log" + assert extract_events(log_handler_output_stream) == ["the log"] assert metadata == {"end_of_log": True, "log_pos": 1} def test__read_from_local(self, tmp_path): """Tests the behavior of method _read_from_local""" path1 = tmp_path / "hello1.log" path2 = tmp_path / "hello1.log.suffix.log" - path1.write_text("file1 content") - path2.write_text("file2 content") + path1.write_text("file1 content\nfile1 content2") + path2.write_text("file2 content\nfile2 content2") fth = FileTaskHandler("") - assert fth._read_from_local(path1) == ( - [str(path1), str(path2)], - ["file1 content", "file2 content"], - ) + log_source_info, log_streams = fth._read_from_local(path1) + assert log_source_info == [str(path1), str(path2)] + assert len(log_streams) == 2 + assert list(log_streams[0]) == ["file1 content", "file1 content2"] + assert list(log_streams[1]) == ["file2 content", "file2 content2"] @pytest.mark.parametrize( "remote_logs, local_logs, served_logs_checked", @@ -505,32 +510,57 @@ def test__read_served_logs_checked_when_done_and_no_local_or_remote_logs( logical_date=DEFAULT_DATE, ) ti.state = TaskInstanceState.SUCCESS # we're testing scenario when task is done + expected_logs = ["::group::Log message source details", "::endgroup::"] with conf_vars({("core", "executor"): executor_name}): reload(executor_loader) fth = FileTaskHandler("") if remote_logs: fth._read_remote_logs = mock.Mock() fth._read_remote_logs.return_value = ["found remote logs"], ["remote\nlog\ncontent"] + expected_logs.extend( + [ + "remote", + "log", + "content", + ] + ) if local_logs: fth._read_from_local = mock.Mock() - fth._read_from_local.return_value = ["found local logs"], ["local\nlog\ncontent"] + fth._read_from_local.return_value = ( + ["found local logs"], + [convert_list_to_stream("local\nlog\ncontent".splitlines())], + ) + # only when not read from remote and TI is unfinished will read from local + if not remote_logs: + expected_logs.extend( + [ + "local", + "log", + "content", + ] + ) fth._read_from_logs_server = mock.Mock() - fth._read_from_logs_server.return_value = ["this message"], ["this\nlog\ncontent"] + fth._read_from_logs_server.return_value = ( + ["this message"], + [convert_list_to_stream("this\nlog\ncontent".splitlines())], + ) + # only when not read from remote and not read from local will read from logs server + if served_logs_checked: + expected_logs.extend( + [ + "this", + "log", + "content", + ] + ) + logs, metadata = fth._read(ti=ti, try_number=1) if served_logs_checked: fth._read_from_logs_server.assert_called_once() - 
assert events(logs) == [ - "::group::Log message source details", - "::endgroup::", - "this", - "log", - "content", - ] - assert metadata == {"end_of_log": True, "log_pos": 3} else: fth._read_from_logs_server.assert_not_called() - assert logs - assert metadata + assert extract_events(logs, False) == expected_logs + assert metadata == {"end_of_log": True, "log_pos": 3} def test_add_triggerer_suffix(self): sample = "any/path/to/thing.txt" @@ -738,11 +768,119 @@ def test_log_retrieval_valid_trigger(self, create_task_instance): """ -def test_parse_timestamps(): - actual = [] - for timestamp, _, _ in _parse_log_lines(log_sample.splitlines()): - actual.append(timestamp) - assert actual == [ +@pytest.mark.parametrize( + "chunk_size, expected_read_calls", + [ + (10, 4), + (20, 3), + # will read all logs in one call, but still need another call to get empty string at the end to escape the loop + (50, 2), + (100, 2), + ], +) +def test__stream_lines_by_chunk(chunk_size, expected_read_calls): + # Mock CHUNK_SIZE to a smaller value to test + with mock.patch("airflow.utils.log.file_task_handler.CHUNK_SIZE", chunk_size): + log_io = io.StringIO("line1\nline2\nline3\nline4\n") + log_io.read = mock.MagicMock(wraps=log_io.read) + + # Stream lines using the function + streamed_lines = list(_stream_lines_by_chunk(log_io)) + + # Verify the output matches the input split by lines + expected_output = ["line1", "line2", "line3", "line4"] + assert log_io.read.call_count == expected_read_calls, ( + f"Expected {expected_read_calls} calls to read, got {log_io.read.call_count}" + ) + assert streamed_lines == expected_output, f"Expected {expected_output}, got {streamed_lines}" + + +@pytest.mark.parametrize( + "seekable", + [ + pytest.param(True, id="seekable_stream"), + pytest.param(False, id="non_seekable_stream"), + ], +) +@pytest.mark.parametrize( + "closed", + [ + pytest.param(False, id="not_closed_stream"), + pytest.param(True, id="closed_stream"), + ], +) +@pytest.mark.parametrize( + "unexpected_exception", + [ + pytest.param(None, id="no_exception"), + pytest.param(ValueError, id="value_error"), + pytest.param(IOError, id="io_error"), + pytest.param(Exception, id="generic_exception"), + ], +) +@mock.patch( + "airflow.utils.log.file_task_handler.CHUNK_SIZE", 10 +) # Mock CHUNK_SIZE to a smaller value for testing +def test__stream_lines_by_chunk_error_handling(seekable, closed, unexpected_exception): + """ + Test that _stream_lines_by_chunk handles errors correctly. 
+ """ + log_io = io.StringIO("line1\nline2\nline3\nline4\n") + log_io.seekable = mock.MagicMock(return_value=seekable) + log_io.seek = mock.MagicMock(wraps=log_io.seek) + # Mock the read method to check the call count and handle exceptions + if unexpected_exception: + expected_error = unexpected_exception("An error occurred while reading the log stream.") + log_io.read = mock.MagicMock(side_effect=expected_error) + else: + log_io.read = mock.MagicMock(wraps=log_io.read) + + # Setup closed state if needed - must be done before starting the test + if closed: + log_io.close() + + # If an exception is expected, we mock the read method to raise it + if unexpected_exception and not closed: + # Only expect logger error if stream is not closed and there's an exception + with mock.patch("airflow.utils.log.file_task_handler.logger.error") as mock_logger_error: + result = list(_stream_lines_by_chunk(log_io)) + mock_logger_error.assert_called_once_with("Error reading log stream: %s", expected_error) + else: + # For normal case or closed stream with exception, collect the output + result = list(_stream_lines_by_chunk(log_io)) + + # Check if seekable was called properly + if seekable and not closed: + log_io.seek.assert_called_once_with(0) + if not seekable: + log_io.seek.assert_not_called() + + # Validate the results based on the conditions + if not closed and not unexpected_exception: # Non-seekable streams without errors should still get lines + assert log_io.read.call_count > 1, "Expected read method to be called at least once." + assert result == ["line1", "line2", "line3", "line4"] + elif closed: + assert log_io.read.call_count == 0, "Read method should not be called on a closed stream." + assert result == [], "Expected no lines to be yield from a closed stream." + elif unexpected_exception: # If an exception was raised + assert log_io.read.call_count == 1, "Read method should be called once." + assert result == [], "Expected no lines to be yield from a stream that raised an exception." 
+ + +def test__log_stream_to_parsed_log_stream(): + parsed_log_stream = _log_stream_to_parsed_log_stream(io.StringIO(log_sample)) + + actual_timestamps = [] + last_idx = -1 + for parsed_log in parsed_log_stream: + timestamp, idx, structured_log = parsed_log + actual_timestamps.append(timestamp) + if last_idx != -1: + assert idx > last_idx + last_idx = idx + assert isinstance(structured_log, StructuredLogMessage) + + assert actual_timestamps == [ pendulum.parse("2022-11-16T00:05:54.278000-08:00"), pendulum.parse("2022-11-16T00:05:54.278000-08:00"), pendulum.parse("2022-11-16T00:05:54.278000-08:00"), @@ -766,34 +904,249 @@ def test_parse_timestamps(): ] +def test__create_sort_key(): + # assert _sort_key should return int + sort_key = _create_sort_key(pendulum.parse("2022-11-16T00:05:54.278000-08:00"), 10) + assert sort_key == 16685859542780000010 + + +@pytest.mark.parametrize( + "timestamp, line_num, expected", + [ + pytest.param( + pendulum.parse("2022-11-16T00:05:54.278000-08:00"), + 10, + False, + id="normal_timestamp_1", + ), + pytest.param( + pendulum.parse("2022-11-16T00:05:54.457000-08:00"), + 2025, + False, + id="normal_timestamp_2", + ), + pytest.param( + DEFAULT_SORT_DATETIME, + 200, + True, + id="default_timestamp", + ), + ], +) +def test__is_sort_key_with_default_timestamp(timestamp, line_num, expected): + assert _is_sort_key_with_default_timestamp(_create_sort_key(timestamp, line_num)) == expected + + +@pytest.mark.parametrize( + "log_stream, expected", + [ + pytest.param( + convert_list_to_stream( + [ + "2022-11-16T00:05:54.278000-08:00", + "2022-11-16T00:05:54.457000-08:00", + ] + ), + True, + id="normal_log_stream", + ), + pytest.param( + itertools.chain( + [ + "2022-11-16T00:05:54.278000-08:00", + "2022-11-16T00:05:54.457000-08:00", + ], + convert_list_to_stream( + [ + "2022-11-16T00:05:54.278000-08:00", + "2022-11-16T00:05:54.457000-08:00", + ] + ), + ), + True, + id="chain_log_stream", + ), + pytest.param( + [ + "2022-11-16T00:05:54.278000-08:00", + "2022-11-16T00:05:54.457000-08:00", + ], + False, + id="non_stream_log", + ), + ], +) +def test__is_logs_stream_like(log_stream, expected): + assert _is_logs_stream_like(log_stream) == expected + + +def test__add_log_from_parsed_log_streams_to_heap(): + """ + Test cases: + + Timestamp: 26 27 28 29 30 31 + Source 1: -- + Source 2: -- -- + Source 3: -- -- -- + """ + heap: list[tuple[int, StructuredLogMessage]] = [] + input_parsed_log_streams: dict[int, ParsedLogStream] = { + 0: convert_list_to_stream( + mock_parsed_logs_factory("Source 1", pendulum.parse("2022-11-16T00:05:54.270000-08:00"), 1) + ), + 1: convert_list_to_stream( + mock_parsed_logs_factory("Source 2", pendulum.parse("2022-11-16T00:05:54.290000-08:00"), 2) + ), + 2: convert_list_to_stream( + mock_parsed_logs_factory("Source 3", pendulum.parse("2022-11-16T00:05:54.380000-08:00"), 3) + ), + } + + # Check that we correctly get the first line of each non-empty log stream + + # First call: should add log records for all log streams + _add_log_from_parsed_log_streams_to_heap(heap, input_parsed_log_streams) + assert len(input_parsed_log_streams) == 3 + assert len(heap) == 3 + # Second call: source 1 is empty, should add log records for source 2 and source 3 + _add_log_from_parsed_log_streams_to_heap(heap, input_parsed_log_streams) + assert len(input_parsed_log_streams) == 2 # Source 1 should be removed + assert len(heap) == 5 + # Third call: source 1 and source 2 are empty, should add log records for source 3 + _add_log_from_parsed_log_streams_to_heap(heap, 
input_parsed_log_streams) + assert len(input_parsed_log_streams) == 1 # Source 2 should be removed + assert len(heap) == 6 + # Fourth call: source 1, source 2, and source 3 are empty, should not add any log records + _add_log_from_parsed_log_streams_to_heap(heap, input_parsed_log_streams) + assert len(input_parsed_log_streams) == 0 # Source 3 should be removed + assert len(heap) == 6 + # Fifth call: all sources are empty, should not add any log records + assert len(input_parsed_log_streams) == 0 # remains empty + assert len(heap) == 6 # no change in heap size + # Check heap + expected_logs: list[str] = [ + "Source 1 Event 0", + "Source 2 Event 0", + "Source 3 Event 0", + "Source 2 Event 1", + "Source 3 Event 1", + "Source 3 Event 2", + ] + actual_logs: list[str] = [] + for _ in range(len(heap)): + _, log = heapq.heappop(heap) + actual_logs.append(log.event) + assert actual_logs == expected_logs + + +@pytest.mark.parametrize( + "heap_setup, flush_size, last_log, expected_events", + [ + pytest.param( + [("msg1", "2023-01-01"), ("msg2", "2023-01-02")], + 2, + None, + ["msg1", "msg2"], + id="exact_size_flush", + ), + pytest.param( + [ + ("msg1", "2023-01-01"), + ("msg2", "2023-01-02"), + ("msg3", "2023-01-03"), + ("msg3", "2023-01-03"), + ("msg5", "2023-01-05"), + ], + 5, + None, + ["msg1", "msg2", "msg3", "msg5"], # msg3 is deduplicated, msg5 has default timestamp + id="flush_with_duplicates", + ), + pytest.param( + [("msg1", "2023-01-01"), ("msg1", "2023-01-01"), ("msg2", "2023-01-02")], + 3, + "msg1", + ["msg2"], # The last_log is "msg1", so any duplicates of "msg1" should be skipped + id="flush_with_last_log", + ), + pytest.param( + [("msg1", "DEFAULT"), ("msg1", "DEFAULT"), ("msg2", "DEFAULT")], + 3, + "msg1", + [ + "msg1", + "msg1", + "msg2", + ], # All messages have default timestamp, so they should be flushed even if last_log is "msg1" + id="flush_with_default_timestamp_and_last_log", + ), + pytest.param( + [("msg1", "2023-01-01"), ("msg2", "2023-01-02"), ("msg3", "2023-01-03")], + 2, + None, + ["msg1", "msg2"], # Only the first two messages should be flushed + id="flush_size_smaller_than_heap", + ), + ], +) +def test__flush_logs_out_of_heap(heap_setup, flush_size, last_log, expected_events): + """Test the _flush_logs_out_of_heap function with different scenarios.""" + + # Create structured log messages from the test setup + heap = [] + messages = {} + for i, (event, timestamp_str) in enumerate(heap_setup): + if timestamp_str == "DEFAULT": + timestamp = DEFAULT_SORT_DATETIME + else: + timestamp = pendulum.parse(timestamp_str) + + msg = StructuredLogMessage(event=event, timestamp=timestamp) + messages[event] = msg + heapq.heappush(heap, (_create_sort_key(msg.timestamp, i), msg)) + + # Set last_log if specified in the test case + last_log_obj = messages.get(last_log) if last_log is not None else None + last_log_container = [last_log_obj] + + # Run the function under test + result = list(_flush_logs_out_of_heap(heap, flush_size, last_log_container)) + + # Verify the results + assert len(result) == len(expected_events) + assert len(heap) == (len(heap_setup) - flush_size) + for i, expected_event in enumerate(expected_events): + assert result[i].event == expected_event, f"result = {result}, expected_event = {expected_events}" + + # verify that the last log is updated correctly + last_log_obj = last_log_container[0] + assert last_log_obj is not None + last_log_obj = cast("StructuredLogMessage", last_log_obj) + assert last_log_obj.event == expected_events[-1] + + def 
test_interleave_interleaves(): - log_sample1 = "\n".join( - [ - "[2022-11-16T00:05:54.278-0800] {taskinstance.py:1258} INFO - Starting attempt 1 of 1", - ] - ) - log_sample2 = "\n".join( - [ - "[2022-11-16T00:05:54.295-0800] {taskinstance.py:1278} INFO - Executing on 2022-11-16 08:05:52.324532+00:00", - "[2022-11-16T00:05:54.300-0800] {standard_task_runner.py:55} INFO - Started process 52536 to run task", - "[2022-11-16T00:05:54.300-0800] {standard_task_runner.py:55} INFO - Started process 52536 to run task", - "[2022-11-16T00:05:54.300-0800] {standard_task_runner.py:55} INFO - Started process 52536 to run task", - "[2022-11-16T00:05:54.306-0800] {standard_task_runner.py:82} INFO - Running: ['airflow', 'tasks', 'run', 'simple_async_timedelta', 'wait', 'manual__2022-11-16T08:05:52.324532+00:00', '--job-id', '33648', '--raw', '--subdir', '/Users/dstandish/code/airflow/airflow/example_dags/example_time_delta_sensor_async.py', '--cfg-path', '/var/folders/7_/1xx0hqcs3txd7kqt0ngfdjth0000gn/T/tmp725r305n']", - "[2022-11-16T00:05:54.309-0800] {standard_task_runner.py:83} INFO - Job 33648: Subtask wait", - ] - ) - log_sample3 = "\n".join( - [ - "[2022-11-16T00:05:54.457-0800] {task_command.py:376} INFO - Running on host daniels-mbp-2.lan", - "[2022-11-16T00:05:54.592-0800] {taskinstance.py:1485} INFO - Exporting env vars: AIRFLOW_CTX_DAG_OWNER=airflow", - "AIRFLOW_CTX_DAG_ID=simple_async_timedelta", - "AIRFLOW_CTX_TASK_ID=wait", - "AIRFLOW_CTX_LOGICAL_DATE=2022-11-16T08:05:52.324532+00:00", - "AIRFLOW_CTX_TRY_NUMBER=1", - "AIRFLOW_CTX_DAG_RUN_ID=manual__2022-11-16T08:05:52.324532+00:00", - "[2022-11-16T00:05:54.604-0800] {taskinstance.py:1360} INFO - Pausing task as DEFERRED. dag_id=simple_async_timedelta, task_id=wait, execution_date=20221116T080552, start_date=20221116T080554", - ] - ) + log_sample1 = [ + "[2022-11-16T00:05:54.278-0800] {taskinstance.py:1258} INFO - Starting attempt 1 of 1", + ] + log_sample2 = [ + "[2022-11-16T00:05:54.295-0800] {taskinstance.py:1278} INFO - Executing on 2022-11-16 08:05:52.324532+00:00", + "[2022-11-16T00:05:54.300-0800] {standard_task_runner.py:55} INFO - Started process 52536 to run task", + "[2022-11-16T00:05:54.300-0800] {standard_task_runner.py:55} INFO - Started process 52536 to run task", + "[2022-11-16T00:05:54.300-0800] {standard_task_runner.py:55} INFO - Started process 52536 to run task", + "[2022-11-16T00:05:54.306-0800] {standard_task_runner.py:82} INFO - Running: ['airflow', 'tasks', 'run', 'simple_async_timedelta', 'wait', 'manual__2022-11-16T08:05:52.324532+00:00', '--job-id', '33648', '--raw', '--subdir', '/Users/dstandish/code/airflow/airflow/example_dags/example_time_delta_sensor_async.py', '--cfg-path', '/var/folders/7_/1xx0hqcs3txd7kqt0ngfdjth0000gn/T/tmp725r305n']", + "[2022-11-16T00:05:54.309-0800] {standard_task_runner.py:83} INFO - Job 33648: Subtask wait", + ] + log_sample3 = [ + "[2022-11-16T00:05:54.457-0800] {task_command.py:376} INFO - Running on host daniels-mbp-2.lan", + "[2022-11-16T00:05:54.592-0800] {taskinstance.py:1485} INFO - Exporting env vars: AIRFLOW_CTX_DAG_OWNER=airflow", + "AIRFLOW_CTX_DAG_ID=simple_async_timedelta", + "AIRFLOW_CTX_TASK_ID=wait", + "AIRFLOW_CTX_LOGICAL_DATE=2022-11-16T08:05:52.324532+00:00", + "AIRFLOW_CTX_TRY_NUMBER=1", + "AIRFLOW_CTX_DAG_RUN_ID=manual__2022-11-16T08:05:52.324532+00:00", + "[2022-11-16T00:05:54.604-0800] {taskinstance.py:1360} INFO - Pausing task as DEFERRED. 
dag_id=simple_async_timedelta, task_id=wait, execution_date=20221116T080552, start_date=20221116T080554", + ] # -08:00 tz = pendulum.tz.fixed_timezone(-28800) @@ -870,11 +1223,14 @@ def test_interleave_interleaves(): }, ] # Use a type adapter to durn it in to dicts -- makes it easier to compare/test than a bunch of objects - results = TypeAdapter(list[StructuredLogMessage]).dump_python( - _interleave_logs(log_sample2, log_sample1, log_sample3) + results: list[StructuredLogMessage] = list( + _interleave_logs( + convert_list_to_stream(log_sample2), + convert_list_to_stream(log_sample1), + convert_list_to_stream(log_sample3), + ) ) - # TypeAdapter gives us a generator out when it's generator is an input. Nice, but not useful for testing - results = list(results) + results: list[dict] = TypeAdapter(list[StructuredLogMessage]).dump_python(results) assert results == expected @@ -891,7 +1247,13 @@ def test_interleave_logs_correct_ordering(): [2023-01-17T12:47:11.883-0800] {triggerer_job.py:540} INFO - Trigger (ID 1) fired: TriggerEvent """ - logs = events(_interleave_logs(sample_with_dupe, "", sample_with_dupe)) + logs = extract_events( + _interleave_logs( + convert_list_to_stream(sample_with_dupe.splitlines()), + convert_list_to_stream([]), + convert_list_to_stream(sample_with_dupe.splitlines()), + ) + ) assert sample_with_dupe == "\n".join(logs) @@ -907,7 +1269,12 @@ def test_interleave_logs_correct_dedupe(): test, test""" - logs = events(_interleave_logs(",\n ".join(["test"] * 10))) + input_logs = ",\n ".join(["test"] * 10) + logs = extract_events( + _interleave_logs( + convert_list_to_stream(input_logs.splitlines()), + ) + ) assert sample_without_dupe == "\n".join(logs) diff --git a/devel-common/src/tests_common/test_utils/file_task_handler.py b/devel-common/src/tests_common/test_utils/file_task_handler.py new file mode 100644 index 0000000000000..5153fcfc511a0 --- /dev/null +++ b/devel-common/src/tests_common/test_utils/file_task_handler.py @@ -0,0 +1,76 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+from __future__ import annotations + +import itertools +from collections.abc import Generator, Iterable +from datetime import datetime +from typing import TYPE_CHECKING + +import pendulum + +from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS + +if TYPE_CHECKING: + from airflow.utils.log.file_task_handler import ParsedLog, StructuredLogMessage + + +def extract_events(logs: Iterable[StructuredLogMessage], skip_source_info=True) -> list[str]: + """Helper function to return just the event (a.k.a message) from a list of StructuredLogMessage""" + logs = iter(logs) + if skip_source_info: + + def is_source_group(log: StructuredLogMessage) -> bool: + return not hasattr(log, "timestamp") or log.event == "::endgroup::" or hasattr(log, "sources") + + logs = itertools.dropwhile(is_source_group, logs) + + return [s.event for s in logs] + + +def convert_list_to_stream(input_list: list[str]) -> Generator[str, None, None]: + """ + Convert a list of strings to a stream-like object. + This function yields each string in the list one by one. + """ + yield from input_list + + +def mock_parsed_logs_factory( + event_prefix: str, + start_datetime: datetime, + count: int, +) -> list[ParsedLog]: + """ + Create a list of ParsedLog objects with the specified start datetime and count. + Each ParsedLog object contains a timestamp and a list of StructuredLogMessage objects. + """ + if AIRFLOW_V_3_0_PLUS: + from airflow.utils.log.file_task_handler import StructuredLogMessage + + return [ + ( + pendulum.instance(start_datetime + pendulum.duration(seconds=i)), + i, + StructuredLogMessage( + timestamp=pendulum.instance(start_datetime + pendulum.duration(seconds=i)), + event=f"{event_prefix} Event {i}", + ), + ) + for i in range(count) + ] diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt index 62b5b70fa9bc0..6b20d228bad16 100644 --- a/docs/spelling_wordlist.txt +++ b/docs/spelling_wordlist.txt @@ -1708,6 +1708,7 @@ StrictUndefined Stringified stringified Struct +StructuredLogMessage STS subchart subclassed diff --git a/providers/amazon/src/airflow/providers/amazon/aws/log/cloudwatch_task_handler.py b/providers/amazon/src/airflow/providers/amazon/aws/log/cloudwatch_task_handler.py index b5f6aaa12662b..0961b091a580c 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/log/cloudwatch_task_handler.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/log/cloudwatch_task_handler.py @@ -33,7 +33,6 @@ from airflow.configuration import conf from airflow.providers.amazon.aws.hooks.logs import AwsLogsHook from airflow.providers.amazon.aws.utils import datetime_to_epoch_utc_ms -from airflow.providers.amazon.version_compat import AIRFLOW_V_3_0_PLUS from airflow.utils.log.file_task_handler import FileTaskHandler from airflow.utils.log.logging_mixin import LoggingMixin @@ -170,15 +169,7 @@ def read(self, relative_path, ti: RuntimeTI) -> tuple[LogSourceInfo, LogMessages f"Reading remote log from Cloudwatch log_group: {self.log_group} log_stream: {relative_path}" ] try: - if AIRFLOW_V_3_0_PLUS: - from airflow.utils.log.file_task_handler import StructuredLogMessage - - logs = [ - StructuredLogMessage.model_validate(log) - for log in self.get_cloudwatch_logs(relative_path, ti) - ] - else: - logs = [self.get_cloudwatch_logs(relative_path, ti)] # type: ignore[arg-value] + logs = [self.get_cloudwatch_logs(relative_path, ti)] # type: ignore[arg-value] except Exception as e: logs = None messages.append(str(e)) @@ -206,8 +197,6 @@ def get_cloudwatch_logs(self, stream_name: str, task_instance: 
RuntimeTI): log_stream_name=stream_name, end_time=end_time, ) - if AIRFLOW_V_3_0_PLUS: - return list(self._event_to_dict(e) for e in events) return "\n".join(self._event_to_str(event) for event in events) def _event_to_dict(self, event: dict) -> dict: @@ -222,7 +211,8 @@ def _event_to_dict(self, event: dict) -> dict: def _event_to_str(self, event: dict) -> str: event_dt = datetime.fromtimestamp(event["timestamp"] / 1000.0, tz=timezone.utc) - formatted_event_dt = event_dt.strftime("%Y-%m-%d %H:%M:%S,%f")[:-3] + # Format a datetime object to a string in Zulu time without milliseconds. + formatted_event_dt = event_dt.strftime("%Y-%m-%dT%H:%M:%SZ") message = event["message"] return f"[{formatted_event_dt}] {message}" diff --git a/providers/amazon/tests/unit/amazon/aws/log/test_cloudwatch_task_handler.py b/providers/amazon/tests/unit/amazon/aws/log/test_cloudwatch_task_handler.py index 097ccc9eabbd2..4ca32105256e5 100644 --- a/providers/amazon/tests/unit/amazon/aws/log/test_cloudwatch_task_handler.py +++ b/providers/amazon/tests/unit/amazon/aws/log/test_cloudwatch_task_handler.py @@ -25,11 +25,11 @@ from unittest.mock import ANY, call import boto3 +import pendulum import pytest import time_machine from moto import mock_aws from pydantic import TypeAdapter -from pydantic_core import TzInfo from watchtower import CloudWatchLogHandler from airflow.models import DAG, DagRun, TaskInstance @@ -50,7 +50,7 @@ def get_time_str(time_in_milliseconds): dt_time = dt.fromtimestamp(time_in_milliseconds / 1000.0, tz=timezone.utc) - return dt_time.strftime("%Y-%m-%d %H:%M:%S,000") + return dt_time.strftime("%Y-%m-%dT%H:%M:%SZ") @pytest.fixture(autouse=True) @@ -148,23 +148,12 @@ def test_log_message(self): stream_name = self.task_log_path.replace(":", "_") logs = self.subject.read(stream_name, self.ti) - if AIRFLOW_V_3_0_PLUS: - from airflow.utils.log.file_task_handler import StructuredLogMessage - - metadata, logs = logs + metadata, logs = logs - results = TypeAdapter(list[StructuredLogMessage]).dump_python(logs) - assert metadata == [ - f"Reading remote log from Cloudwatch log_group: log_group_name log_stream: {stream_name}" - ] - assert results == [ - { - "event": "Hi", - "foo": "bar", - "level": "info", - "timestamp": datetime(2025, 3, 27, 21, 58, 1, 2000, tzinfo=TzInfo(0)), - }, - ] + assert metadata == [ + f"Reading remote log from Cloudwatch log_group: log_group_name log_stream: {stream_name}" + ] + assert logs == ['[2025-03-27T21:58:01Z] {"foo": "bar", "event": "Hi", "level": "info"}'] def test_event_to_str(self): handler = self.subject @@ -282,7 +271,11 @@ def test_read(self, monkeypatch): {"timestamp": current_time, "message": "Third"}, ], ) - monkeypatch.setattr(self.cloudwatch_task_handler, "_read_from_logs_server", lambda a, b: ([], [])) + monkeypatch.setattr( + self.cloudwatch_task_handler, + "_read_from_logs_server", + lambda ti, worker_log_rel_path: ([], []), + ) msg_template = textwrap.dedent(""" INFO - ::group::Log message source details *** Reading remote log from Cloudwatch log_group: {} log_stream: {} @@ -294,14 +287,24 @@ def test_read(self, monkeypatch): if AIRFLOW_V_3_0_PLUS: from airflow.utils.log.file_task_handler import StructuredLogMessage + logs = list(logs) results = TypeAdapter(list[StructuredLogMessage]).dump_python(logs) assert results[-4:] == [ {"event": "::endgroup::", "timestamp": None}, - {"event": "First", "timestamp": datetime(2025, 3, 27, 21, 57, 59)}, - {"event": "Second", "timestamp": datetime(2025, 3, 27, 21, 58, 0)}, - {"event": "Third", "timestamp": datetime(2025, 3, 
27, 21, 58, 1)}, + { + "event": "[2025-03-27T21:57:59Z] First", + "timestamp": pendulum.datetime(2025, 3, 27, 21, 57, 59), + }, + { + "event": "[2025-03-27T21:58:00Z] Second", + "timestamp": pendulum.datetime(2025, 3, 27, 21, 58, 0), + }, + { + "event": "[2025-03-27T21:58:01Z] Third", + "timestamp": pendulum.datetime(2025, 3, 27, 21, 58, 1), + }, ] - assert metadata["log_pos"] == 3 + assert metadata == {"end_of_log": False, "log_pos": 3} else: events = "\n".join( [ diff --git a/providers/amazon/tests/unit/amazon/aws/log/test_s3_task_handler.py b/providers/amazon/tests/unit/amazon/aws/log/test_s3_task_handler.py index f254eeeec07a7..7d58b5e5b30b5 100644 --- a/providers/amazon/tests/unit/amazon/aws/log/test_s3_task_handler.py +++ b/providers/amazon/tests/unit/amazon/aws/log/test_s3_task_handler.py @@ -264,6 +264,7 @@ def test_read(self): expected_s3_uri = f"s3://bucket/{self.remote_log_key}" if AIRFLOW_V_3_0_PLUS: + log = list(log) assert log[0].event == "::group::Log message source details" assert expected_s3_uri in log[0].sources assert log[1].event == "::endgroup::" @@ -284,6 +285,7 @@ def test_read_when_s3_log_missing(self): self.s3_task_handler._read_from_logs_server = mock.Mock(return_value=([], [])) log, metadata = self.s3_task_handler.read(ti) if AIRFLOW_V_3_0_PLUS: + log = list(log) assert len(log) == 2 assert metadata == {"end_of_log": True, "log_pos": 0} else: diff --git a/providers/celery/tests/unit/celery/log_handlers/test_log_handlers.py b/providers/celery/tests/unit/celery/log_handlers/test_log_handlers.py index 456f680de41fa..c95ab43236ecf 100644 --- a/providers/celery/tests/unit/celery/log_handlers/test_log_handlers.py +++ b/providers/celery/tests/unit/celery/log_handlers/test_log_handlers.py @@ -37,6 +37,9 @@ from airflow.utils.types import DagRunType from tests_common.test_utils.config import conf_vars +from tests_common.test_utils.file_task_handler import ( + convert_list_to_stream, +) from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS pytestmark = pytest.mark.db_test @@ -78,14 +81,24 @@ def test__read_for_celery_executor_fallbacks_to_worker(self, create_task_instanc fth = FileTaskHandler("") fth._read_from_logs_server = mock.Mock() - fth._read_from_logs_server.return_value = ["this message"], ["this\nlog\ncontent"] + + # compat with 2.x and 3.x + if AIRFLOW_V_3_0_PLUS: + fth._read_from_logs_server.return_value = ( + ["this message"], + [convert_list_to_stream(["this", "log", "content"])], + ) + else: + fth._read_from_logs_server.return_value = ["this message"], ["this\nlog\ncontent"] + logs, metadata = fth._read(ti=ti, try_number=1) fth._read_from_logs_server.assert_called_once() if AIRFLOW_V_3_0_PLUS: - assert metadata == {"end_of_log": False, "log_pos": 3} + logs = list(logs) assert logs[0].sources == ["this message"] assert [x.event for x in logs[-3:]] == ["this", "log", "content"] + assert metadata == {"end_of_log": False, "log_pos": 3} else: assert "*** this message\n" in logs assert logs.endswith("this\nlog\ncontent") diff --git a/providers/elasticsearch/src/airflow/providers/elasticsearch/log/es_task_handler.py b/providers/elasticsearch/src/airflow/providers/elasticsearch/log/es_task_handler.py index 420f33b67f01f..3c6124603c7a4 100644 --- a/providers/elasticsearch/src/airflow/providers/elasticsearch/log/es_task_handler.py +++ b/providers/elasticsearch/src/airflow/providers/elasticsearch/log/es_task_handler.py @@ -29,7 +29,7 @@ from collections import defaultdict from collections.abc import Callable from operator import attrgetter -from typing 
import TYPE_CHECKING, Any, Literal +from typing import TYPE_CHECKING, Any, Literal, cast from urllib.parse import quote, urlparse # Using `from elasticsearch import *` would break elasticsearch mocking used in unit test. @@ -56,6 +56,7 @@ from datetime import datetime from airflow.models.taskinstance import TaskInstance, TaskInstanceKey + from airflow.utils.log.file_task_handler import LogMetadata LOG_LINE_DEFAULTS = {"exc_text": "", "stack_info": ""} @@ -294,8 +295,8 @@ def _read_grouped_logs(self): return True def _read( - self, ti: TaskInstance, try_number: int, metadata: dict | None = None - ) -> tuple[EsLogMsgType, dict]: + self, ti: TaskInstance, try_number: int, metadata: LogMetadata | None = None + ) -> tuple[EsLogMsgType, LogMetadata]: """ Endpoint for streaming log. @@ -306,7 +307,9 @@ def _read( :return: a list of tuple with host and log documents, metadata. """ if not metadata: - metadata = {"offset": 0} + # LogMetadata(TypedDict) is used as type annotation for log_reader; added ignore to suppress mypy error + metadata = {"offset": 0} # type: ignore[assignment] + metadata = cast("LogMetadata", metadata) if "offset" not in metadata: metadata["offset"] = 0 @@ -346,7 +349,9 @@ def _read( "Otherwise, the logs for this task instance may have been removed." ) if AIRFLOW_V_3_0_PLUS: - return missing_log_message, metadata + from airflow.utils.log.file_task_handler import StructuredLogMessage + + return [StructuredLogMessage(event=missing_log_message)], metadata return [("", missing_log_message)], metadata # type: ignore[list-item] if ( # Assume end of log after not receiving new log for N min, diff --git a/providers/elasticsearch/tests/unit/elasticsearch/log/test_es_task_handler.py b/providers/elasticsearch/tests/unit/elasticsearch/log/test_es_task_handler.py index be3e36c8f90f7..2d90c110832d3 100644 --- a/providers/elasticsearch/tests/unit/elasticsearch/log/test_es_task_handler.py +++ b/providers/elasticsearch/tests/unit/elasticsearch/log/test_es_task_handler.py @@ -208,6 +208,7 @@ def test_read(self, ti): ) if AIRFLOW_V_3_0_PLUS: + logs = list(logs) assert logs[0].event == "::group::Log message source details" assert logs[0].sources == ["default_host"] assert logs[1].event == "::endgroup::" @@ -235,6 +236,7 @@ def test_read_with_patterns(self, ti): ) if AIRFLOW_V_3_0_PLUS: + logs = list(logs) assert logs[0].event == "::group::Log message source details" assert logs[0].sources == ["default_host"] assert logs[1].event == "::endgroup::" @@ -304,10 +306,11 @@ def test_read_missing_logs(self, seconds, create_task_instance): ts = pendulum.now().add(seconds=-seconds) logs, metadatas = self.es_task_handler.read(ti, 1, {"offset": 0, "last_log_timestamp": str(ts)}) if AIRFLOW_V_3_0_PLUS: + logs = list(logs) if seconds > 5: # we expect a log not found message when checking began more than 5 seconds ago expected_pattern = r"^\*\*\* Log .* not found in Elasticsearch.*" - assert re.match(expected_pattern, logs) is not None + assert re.match(expected_pattern, logs[0].event) is not None assert metadatas["end_of_log"] is True else: # we've "waited" less than 5 seconds so it should not be "end of log" and should be no log message @@ -360,6 +363,7 @@ def test_read_with_match_phrase_query(self, ti): }, ) if AIRFLOW_V_3_0_PLUS: + logs = list(logs) assert logs[0].event == "::group::Log message source details" assert logs[0].sources == ["default_host"] assert logs[1].event == "::endgroup::" @@ -382,6 +386,7 @@ def test_read_with_match_phrase_query(self, ti): def test_read_with_none_metadata(self, ti): 
logs, metadatas = self.es_task_handler.read(ti, 1) if AIRFLOW_V_3_0_PLUS: + logs = list(logs) assert logs[0].event == "::group::Log message source details" assert logs[0].sources == ["default_host"] assert logs[1].event == "::endgroup::" @@ -431,6 +436,7 @@ def test_read_with_empty_metadata(self, ti): ts = pendulum.now() logs, metadatas = self.es_task_handler.read(ti, 1, {}) if AIRFLOW_V_3_0_PLUS: + logs = list(logs) assert logs[0].event == "::group::Log message source details" assert logs[0].sources == ["default_host"] assert logs[1].event == "::endgroup::" @@ -520,6 +526,7 @@ def test_read_as_download_logs(self, ti): }, ) if AIRFLOW_V_3_0_PLUS: + logs = list(logs) assert logs[0].event == "::group::Log message source details" assert logs[0].sources == ["default_host"] assert logs[1].event == "::endgroup::" @@ -601,6 +608,7 @@ def test_read_with_json_format(self, ti): ) expected_message = "[2020-12-24 19:25:00,962] {taskinstance.py:851} INFO - some random stuff - " if AIRFLOW_V_3_0_PLUS: + logs = list(logs) assert logs[2].event == expected_message else: assert logs[0][0][1] == expected_message @@ -634,6 +642,7 @@ def test_read_with_json_format_with_custom_offset_and_host_fields(self, ti): ) expected_message = "[2020-12-24 19:25:00,962] {taskinstance.py:851} INFO - some random stuff - " if AIRFLOW_V_3_0_PLUS: + logs = list(logs) assert logs[2].event == expected_message else: assert logs[0][0][1] == expected_message diff --git a/providers/google/tests/unit/google/cloud/log/test_gcs_task_handler.py b/providers/google/tests/unit/google/cloud/log/test_gcs_task_handler.py index 205bcfc5cb1bd..cbe6df81c752f 100644 --- a/providers/google/tests/unit/google/cloud/log/test_gcs_task_handler.py +++ b/providers/google/tests/unit/google/cloud/log/test_gcs_task_handler.py @@ -114,6 +114,7 @@ def test_should_read_logs_from_remote(self, mock_blob, mock_client, mock_creds, mock_blob.from_string.assert_called_once_with(expected_gs_uri, mock_client.return_value) if AIRFLOW_V_3_0_PLUS: + logs = list(logs) assert logs[0].event == "::group::Log message source details" assert logs[0].sources == [expected_gs_uri] assert logs[1].event == "::endgroup::" @@ -143,6 +144,7 @@ def test_should_read_from_local_on_logs_read_error(self, mock_blob, mock_client, expected_gs_uri = f"gs://bucket/{mock_obj.name}" if AIRFLOW_V_3_0_PLUS: + log = list(log) assert log[0].event == "::group::Log message source details" assert log[0].sources == [ expected_gs_uri, diff --git a/providers/microsoft/azure/tests/unit/microsoft/azure/log/test_wasb_task_handler.py b/providers/microsoft/azure/tests/unit/microsoft/azure/log/test_wasb_task_handler.py index ba7fe5f1b2a4c..dc87b5a9da758 100644 --- a/providers/microsoft/azure/tests/unit/microsoft/azure/log/test_wasb_task_handler.py +++ b/providers/microsoft/azure/tests/unit/microsoft/azure/log/test_wasb_task_handler.py @@ -117,14 +117,12 @@ def test_wasb_read(self, mock_hook_cls, ti): logs, metadata = self.wasb_task_handler.read(ti) if AIRFLOW_V_3_0_PLUS: + logs = list(logs) assert logs[0].event == "::group::Log message source details" assert logs[0].sources == ["https://wasb-container.blob.core.windows.net/abc/hello.log"] assert logs[1].event == "::endgroup::" assert logs[2].event == "Log line" - assert metadata == { - "end_of_log": True, - "log_pos": 1, - } + assert metadata == {"end_of_log": True, "log_pos": 1} else: assert logs[0][0][0] == "localhost" assert ( diff --git a/providers/opensearch/src/airflow/providers/opensearch/log/os_task_handler.py 
b/providers/opensearch/src/airflow/providers/opensearch/log/os_task_handler.py index 5a6b60aae24d6..7e75ca40b91c6 100644 --- a/providers/opensearch/src/airflow/providers/opensearch/log/os_task_handler.py +++ b/providers/opensearch/src/airflow/providers/opensearch/log/os_task_handler.py @@ -25,7 +25,7 @@ from collections.abc import Callable from datetime import datetime from operator import attrgetter -from typing import TYPE_CHECKING, Any, Literal +from typing import TYPE_CHECKING, Any, Literal, cast import pendulum from opensearchpy import OpenSearch @@ -45,6 +45,7 @@ if TYPE_CHECKING: from airflow.models.taskinstance import TaskInstance, TaskInstanceKey + from airflow.utils.log.file_task_handler import LogMetadata if AIRFLOW_V_3_0_PLUS: @@ -333,8 +334,8 @@ def _render_log_id(self, ti: TaskInstance | TaskInstanceKey, try_number: int) -> ) def _read( - self, ti: TaskInstance, try_number: int, metadata: dict | None = None - ) -> tuple[OsLogMsgType, dict]: + self, ti: TaskInstance, try_number: int, metadata: LogMetadata | None = None + ) -> tuple[OsLogMsgType, LogMetadata]: """ Endpoint for streaming log. @@ -345,7 +346,10 @@ def _read( :return: a list of tuple with host and log documents, metadata. """ if not metadata: - metadata = {"offset": 0} + # LogMetadata(TypedDict) is used as type annotation for log_reader; added ignore to suppress mypy error + metadata = {"offset": 0} # type: ignore[assignment] + metadata = cast("LogMetadata", metadata) + if "offset" not in metadata: metadata["offset"] = 0 @@ -384,6 +388,12 @@ def _read( "If your task started recently, please wait a moment and reload this page. " "Otherwise, the logs for this task instance may have been removed." ) + if AIRFLOW_V_3_0_PLUS: + from airflow.utils.log.file_task_handler import StructuredLogMessage + + # return list of StructuredLogMessage for Airflow 3.0+ + return [StructuredLogMessage(event=missing_log_message)], metadata + return [("", missing_log_message)], metadata # type: ignore[list-item] if ( # Assume end of log after not receiving new log for N min, diff --git a/providers/opensearch/tests/unit/opensearch/log/test_os_task_handler.py b/providers/opensearch/tests/unit/opensearch/log/test_os_task_handler.py index fb51c56e469ec..4faeb2223f97c 100644 --- a/providers/opensearch/tests/unit/opensearch/log/test_os_task_handler.py +++ b/providers/opensearch/tests/unit/opensearch/log/test_os_task_handler.py @@ -204,6 +204,7 @@ def test_read(self, ti): "on 2023-07-09 07:47:32+00:00" ) if AIRFLOW_V_3_0_PLUS: + logs = list(logs) assert logs[0].event == "::group::Log message source details" assert logs[0].sources == ["default_host"] assert logs[1].event == "::endgroup::" @@ -235,6 +236,7 @@ def test_read_with_patterns(self, ti): "on 2023-07-09 07:47:32+00:00" ) if AIRFLOW_V_3_0_PLUS: + logs = list(logs) assert logs[0].event == "::group::Log message source details" assert logs[0].sources == ["default_host"] assert logs[1].event == "::endgroup::" @@ -332,10 +334,11 @@ def test_read_missing_logs(self, seconds, create_task_instance): ): logs, metadatas = self.os_task_handler.read(ti, 1, {"offset": 0, "last_log_timestamp": str(ts)}) if AIRFLOW_V_3_0_PLUS: + logs = list(logs) if seconds > 5: # we expect a log not found message when checking began more than 5 seconds ago - assert len(logs[0]) == 2 - actual_message = logs[0][1] + assert len(logs) == 1 + actual_message = logs[0].event expected_pattern = r"^\*\*\* Log .* not found in Opensearch.*" assert re.match(expected_pattern, actual_message) is not None assert metadatas["end_of_log"] 
is True @@ -374,6 +377,7 @@ def test_read_with_none_metadata(self, ti): "on 2023-07-09 07:47:32+00:00" ) if AIRFLOW_V_3_0_PLUS: + logs = list(logs) assert logs[0].event == "::group::Log message source details" assert logs[0].sources == ["default_host"] assert logs[1].event == "::endgroup::" diff --git a/providers/redis/src/airflow/providers/redis/log/redis_task_handler.py b/providers/redis/src/airflow/providers/redis/log/redis_task_handler.py index d1fd1cc8de0d6..bdc5b6ed0f89a 100644 --- a/providers/redis/src/airflow/providers/redis/log/redis_task_handler.py +++ b/providers/redis/src/airflow/providers/redis/log/redis_task_handler.py @@ -19,7 +19,7 @@ import logging from functools import cached_property -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING from airflow.configuration import conf from airflow.providers.redis.hooks.redis import RedisHook @@ -31,6 +31,7 @@ from redis import Redis from airflow.models import TaskInstance + from airflow.utils.log.file_task_handler import LogMetadata class RedisTaskHandler(FileTaskHandler, LoggingMixin): @@ -75,8 +76,8 @@ def _read( self, ti: TaskInstance, try_number: int, - metadata: dict[str, Any] | None = None, - ): + metadata: LogMetadata | None = None, + ) -> tuple[str | list[str], LogMetadata]: log_str = b"\n".join( self.conn.lrange(self._render_filename(ti, try_number), start=0, end=-1) ).decode() diff --git a/providers/redis/tests/unit/redis/log/test_redis_task_handler.py b/providers/redis/tests/unit/redis/log/test_redis_task_handler.py index 99bb497c56318..2a64f8d7674a8 100644 --- a/providers/redis/tests/unit/redis/log/test_redis_task_handler.py +++ b/providers/redis/tests/unit/redis/log/test_redis_task_handler.py @@ -31,7 +31,11 @@ from airflow.utils.timezone import datetime from tests_common.test_utils.config import conf_vars -from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS +from tests_common.test_utils.file_task_handler import extract_events +from tests_common.test_utils.version_compat import ( + AIRFLOW_V_3_0_PLUS, + get_base_airflow_version_tuple, +) class TestRedisTaskHandler: @@ -120,7 +124,12 @@ def test_read(self, ti): logs = handler.read(ti) if AIRFLOW_V_3_0_PLUS: - assert logs == (["Line 1\nLine 2"], {"end_of_log": True}) + if get_base_airflow_version_tuple() < (3, 1, 0): + assert logs == (["Line 1\nLine 2"], {"end_of_log": True}) + else: + log_stream, metadata = logs + assert extract_events(log_stream) == ["Line 1", "Line 2"] + assert metadata == {"end_of_log": True} else: assert logs == ([[("", "Line 1\nLine 2")]], [{"end_of_log": True}]) lrange.assert_called_once_with(key, start=0, end=-1)
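
A minimal sketch, not part of the patch, of the pattern the provider test updates above repeat: wrapping a mocked log source with the new tests_common helpers so it satisfies the stream-based contract of FileTaskHandler._read. It assumes a TaskInstance set up like the celery fallback test in this patch (task finished, no local or remote logs, so the handler falls back to the logs server); the function name read_served_logs_via_stream is illustrative only, while convert_list_to_stream and extract_events come from the helper module added in devel-common above.

from unittest import mock

from airflow.utils.log.file_task_handler import FileTaskHandler

from tests_common.test_utils.file_task_handler import convert_list_to_stream, extract_events


def read_served_logs_via_stream(ti):
    """Illustrative only: mock the served-logs source as a line stream and read it back via _read()."""
    fth = FileTaskHandler("")
    # In Airflow 3.x each log source returns an iterator of lines instead of one joined string,
    # so the mocked return value wraps the fixture lines with convert_list_to_stream.
    fth._read_from_logs_server = mock.Mock(
        return_value=(["this message"], [convert_list_to_stream(["this", "log", "content"])])
    )
    logs, metadata = fth._read(ti=ti, try_number=1)
    # extract_events skips the ::group::/::endgroup:: source-details header by default.
    assert extract_events(logs) == ["this", "log", "content"]
    return metadata

Passing skip_source_info=False to extract_events keeps the source-details header lines, which is what the updated core test above relies on when it compares against expected_logs.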