Skip to content

Commit de079f2

Browse files
committed
Support stream-based read for CloudwatchTaskHandler
1 parent 4a6caf9 commit de079f2

File tree

2 files changed

+11
-9
lines changed

2 files changed

+11
-9
lines changed

providers/amazon/src/airflow/providers/amazon/aws/log/cloudwatch_task_handler.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import json
2323
import logging
2424
import os
25+
from collections.abc import Generator
2526
from datetime import date, datetime, timedelta, timezone
2627
from functools import cached_property
2728
from pathlib import Path
@@ -33,6 +34,7 @@
3334
from airflow.configuration import conf
3435
from airflow.providers.amazon.aws.hooks.logs import AwsLogsHook
3536
from airflow.providers.amazon.aws.utils import datetime_to_epoch_utc_ms
37+
from airflow.providers.amazon.version_compat import SUPPORT_STREAM_BASED_READ
3638
from airflow.utils.log.file_task_handler import FileTaskHandler
3739
from airflow.utils.log.logging_mixin import LoggingMixin
3840

@@ -43,9 +45,7 @@
4345
from airflow.sdk.types import RuntimeTaskInstanceProtocol as RuntimeTI
4446
from airflow.utils.log.file_task_handler import (
4547
LegacyLogResponse,
46-
LogMessages,
4748
LogResponse,
48-
LogSourceInfo,
4949
RawLogStream,
5050
)
5151

@@ -299,23 +299,24 @@ def close(self):
299299
# Mark closed so we don't double write if close is called twice
300300
self.closed = True
301301

302-
def _read_remote_logs(
303-
self, task_instance, try_number, metadata=None
304-
) -> tuple[LogSourceInfo, LogMessages]:
302+
def _read_remote_logs(self, task_instance, try_number, metadata=None) -> LegacyLogResponse | LogResponse:
305303
stream_name = self._render_filename(task_instance, try_number)
306-
messages, logs = self.io.read(stream_name, task_instance)
307-
308304
messages = [
309305
f"Reading remote log from Cloudwatch log_group: {self.io.log_group} log_stream: {stream_name}"
310306
]
307+
308+
logs: list[str] | list[Generator[str, None, None]]
311309
try:
312310
events = self.io.get_cloudwatch_logs(stream_name, task_instance)
313-
logs = ["\n".join(self._event_to_str(event) for event in events)]
311+
if SUPPORT_STREAM_BASED_READ:
312+
logs = [(self._event_to_str(event) for event in events)]
313+
else:
314+
logs = ["\n".join(self._event_to_str(event) for event in events)]
314315
except Exception as e:
315316
logs = []
316317
messages.append(str(e))
317318

318-
return messages, logs
319+
return messages, logs # type: ignore[return-value]
319320

320321
def _event_to_str(self, event: dict) -> str:
321322
event_dt = datetime.fromtimestamp(event["timestamp"] / 1000.0, tz=timezone.utc)

providers/amazon/src/airflow/providers/amazon/version_compat.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ def get_base_airflow_version_tuple() -> tuple[int, int, int]:
3434

3535
AIRFLOW_V_3_0_PLUS = get_base_airflow_version_tuple() >= (3, 0, 0)
3636
AIRFLOW_V_3_1_PLUS: bool = get_base_airflow_version_tuple() >= (3, 1, 0)
37+
SUPPORT_STREAM_BASED_READ: bool = get_base_airflow_version_tuple() >= (3, 0, 3)
3738

3839
if AIRFLOW_V_3_1_PLUS:
3940
from airflow.sdk import BaseHook

0 commit comments

Comments (0)