|
22 | 22 | import json |
23 | 23 | import logging |
24 | 24 | import os |
| 25 | +from collections.abc import Generator |
25 | 26 | from datetime import date, datetime, timedelta, timezone |
26 | 27 | from functools import cached_property |
27 | 28 | from pathlib import Path |
|
33 | 34 | from airflow.configuration import conf |
34 | 35 | from airflow.providers.amazon.aws.hooks.logs import AwsLogsHook |
35 | 36 | from airflow.providers.amazon.aws.utils import datetime_to_epoch_utc_ms |
| 37 | +from airflow.providers.amazon.version_compat import SUPPORT_STREAM_BASED_READ |
36 | 38 | from airflow.utils.log.file_task_handler import FileTaskHandler |
37 | 39 | from airflow.utils.log.logging_mixin import LoggingMixin |
38 | 40 |
|
|
43 | 45 | from airflow.sdk.types import RuntimeTaskInstanceProtocol as RuntimeTI |
44 | 46 | from airflow.utils.log.file_task_handler import ( |
45 | 47 | LegacyLogResponse, |
46 | | - LogMessages, |
47 | 48 | LogResponse, |
48 | | - LogSourceInfo, |
49 | 49 | RawLogStream, |
50 | 50 | ) |
51 | 51 |
|
@@ -299,23 +299,24 @@ def close(self): |
299 | 299 | # Mark closed so we don't double write if close is called twice |
300 | 300 | self.closed = True |
301 | 301 |
|
302 | | - def _read_remote_logs( |
303 | | - self, task_instance, try_number, metadata=None |
304 | | - ) -> tuple[LogSourceInfo, LogMessages]: |
| 302 | + def _read_remote_logs(self, task_instance, try_number, metadata=None) -> LegacyLogResponse | LogResponse: |
305 | 303 | stream_name = self._render_filename(task_instance, try_number) |
306 | | - messages, logs = self.io.read(stream_name, task_instance) |
307 | | - |
308 | 304 | messages = [ |
309 | 305 | f"Reading remote log from Cloudwatch log_group: {self.io.log_group} log_stream: {stream_name}" |
310 | 306 | ] |
| 307 | + |
| 308 | + logs: list[str] | list[Generator[str, None, None]] |
311 | 309 | try: |
312 | 310 | events = self.io.get_cloudwatch_logs(stream_name, task_instance) |
313 | | - logs = ["\n".join(self._event_to_str(event) for event in events)] |
| 311 | + if SUPPORT_STREAM_BASED_READ: |
| 312 | + logs = [(self._event_to_str(event) for event in events)] |
| 313 | + else: |
| 314 | + logs = ["\n".join(self._event_to_str(event) for event in events)] |
314 | 315 | except Exception as e: |
315 | 316 | logs = [] |
316 | 317 | messages.append(str(e)) |
317 | 318 |
|
318 | | - return messages, logs |
| 319 | + return messages, logs # type: ignore[return-value] |
319 | 320 |
|
320 | 321 | def _event_to_str(self, event: dict) -> str: |
321 | 322 | event_dt = datetime.fromtimestamp(event["timestamp"] / 1000.0, tz=timezone.utc) |
|
0 commit comments