diff --git a/airflow-core/docs/img/airflow_erd.sha256 b/airflow-core/docs/img/airflow_erd.sha256 index fb89acbe007e2..b9c9c514f82ac 100644 --- a/airflow-core/docs/img/airflow_erd.sha256 +++ b/airflow-core/docs/img/airflow_erd.sha256 @@ -1 +1 @@ -db00d57fce32830b69f2c1481b231e65e67e197b4a96a5fa1c870cd555eac3bd \ No newline at end of file +7d6f2aa31fb10d8006b6b7f572bedcc4be78eb828edfd42386e0b872b6999afc \ No newline at end of file diff --git a/airflow-core/src/airflow/__init__.py b/airflow-core/src/airflow/__init__.py index 474098a5053f9..6a997cecbdfef 100644 --- a/airflow-core/src/airflow/__init__.py +++ b/airflow-core/src/airflow/__init__.py @@ -101,7 +101,7 @@ def __getattr__(name: str): module_path, attr_name, deprecated = __lazy_imports.get(name, ("", "", False)) if not module_path: raise AttributeError(f"module {__name__!r} has no attribute {name!r}") - elif deprecated: + if deprecated: warnings.warn( f"Import {name!r} directly from the airflow module is deprecated and " f"will be removed in the future. Please import it from 'airflow{module_path}.{attr_name}'.", diff --git a/airflow-core/src/airflow/api_fastapi/auth/tokens.py b/airflow-core/src/airflow/api_fastapi/auth/tokens.py index 2afb34557737a..a1e6729e5ea44 100644 --- a/airflow-core/src/airflow/api_fastapi/auth/tokens.py +++ b/airflow-core/src/airflow/api_fastapi/auth/tokens.py @@ -92,10 +92,9 @@ def _guess_best_algorithm(key: AllowedPrivateKeys): if isinstance(key, RSAPrivateKey): return "RS512" - elif isinstance(key, Ed25519PrivateKey): + if isinstance(key, Ed25519PrivateKey): return "EdDSA" - else: - raise ValueError(f"Unknown key object {type(key)}") + raise ValueError(f"Unknown key object {type(key)}") @attrs.define(repr=False) @@ -297,8 +296,7 @@ def __attrs_post_init__(self): "Cannot guess the algorithm when using JWKS - please specify it in the config option " "[api_auth] jwt_algorithm" ) - else: - self.algorithm = ["HS512"] + self.algorithm = ["HS512"] def _get_kid_from_header(self, unvalidated: str) -> str: header = jwt.get_unverified_header(unvalidated) @@ -475,7 +473,7 @@ def generate_private_key(key_type: str = "RSA", key_size: int = 2048): # Generate an RSA private key return rsa.generate_private_key(public_exponent=65537, key_size=key_size, backend=default_backend()) - elif key_type == "Ed25519": + if key_type == "Ed25519": return ed25519.Ed25519PrivateKey.generate() raise ValueError(f"unsupported key type: {key_type}") diff --git a/airflow-core/src/airflow/api_fastapi/common/parameters.py b/airflow-core/src/airflow/api_fastapi/common/parameters.py index e68001d5b539a..f81383be969f0 100644 --- a/airflow-core/src/airflow/api_fastapi/common/parameters.py +++ b/airflow-core/src/airflow/api_fastapi/common/parameters.py @@ -206,8 +206,7 @@ def to_orm(self, select: Select) -> Select: if self.value[0] == "-": return select.order_by(nullscheck, column.desc(), primary_key_column.desc()) - else: - return select.order_by(nullscheck, column.asc(), primary_key_column.asc()) + return select.order_by(nullscheck, column.asc(), primary_key_column.asc()) def get_primary_key_column(self) -> Column: """Get the primary key column of the model of SortParam object.""" diff --git a/airflow-core/src/airflow/api_fastapi/core_api/datamodels/dags.py b/airflow-core/src/airflow/api_fastapi/core_api/datamodels/dags.py index 8f2874c1aa0a0..8f849a259f23b 100644 --- a/airflow-core/src/airflow/api_fastapi/core_api/datamodels/dags.py +++ b/airflow-core/src/airflow/api_fastapi/core_api/datamodels/dags.py @@ -88,7 +88,7 @@ def get_owners(cls, v: Any) 
-> list[str] | None: if v is None: return [] - elif isinstance(v, str): + if isinstance(v, str): return v.split(",") return v diff --git a/airflow-core/src/airflow/api_fastapi/core_api/routes/public/dag_run.py b/airflow-core/src/airflow/api_fastapi/core_api/routes/public/dag_run.py index 985a3934f5ca3..7b12593994f7a 100644 --- a/airflow-core/src/airflow/api_fastapi/core_api/routes/public/dag_run.py +++ b/airflow-core/src/airflow/api_fastapi/core_api/routes/public/dag_run.py @@ -275,15 +275,14 @@ def clear_dag_run( task_instances=cast("list[TaskInstanceResponse]", task_instances), total_entries=len(task_instances), ) - else: - dag.clear( - run_id=dag_run_id, - task_ids=None, - only_failed=body.only_failed, - session=session, - ) - dag_run_cleared = session.scalar(select(DagRun).where(DagRun.id == dag_run.id)) - return dag_run_cleared + dag.clear( + run_id=dag_run_id, + task_ids=None, + only_failed=body.only_failed, + session=session, + ) + dag_run_cleared = session.scalar(select(DagRun).where(DagRun.id == dag_run.id)) + return dag_run_cleared @dag_run_router.get( diff --git a/airflow-core/src/airflow/api_fastapi/core_api/routes/public/log.py b/airflow-core/src/airflow/api_fastapi/core_api/routes/public/log.py index f5b22ddf5e8d0..e54a15b2447c7 100644 --- a/airflow-core/src/airflow/api_fastapi/core_api/routes/public/log.py +++ b/airflow-core/src/airflow/api_fastapi/core_api/routes/public/log.py @@ -140,16 +140,15 @@ def get_log( if not metadata.get("end_of_log", False): encoded_token = URLSafeSerializer(request.app.state.secret_key).dumps(metadata) return TaskInstancesLogResponse.model_construct(continuation_token=encoded_token, content=logs) - else: - # text/plain, or something else we don't understand. Return raw log content - - # We need to exhaust the iterator before we can generate the continuation token. - # We could improve this by making it a streaming/async response, and by then setting the header using - # HTTP Trailers - logs = "".join(task_log_reader.read_log_stream(ti, try_number, metadata)) - headers = None - if not metadata.get("end_of_log", False): - headers = { - "Airflow-Continuation-Token": URLSafeSerializer(request.app.state.secret_key).dumps(metadata) - } - return Response(media_type="application/x-ndjson", content=logs, headers=headers) + # text/plain, or something else we don't understand. Return raw log content + + # We need to exhaust the iterator before we can generate the continuation token. 
+ # We could improve this by making it a streaming/async response, and by then setting the header using + # HTTP Trailers + logs = "".join(task_log_reader.read_log_stream(ti, try_number, metadata)) + headers = None + if not metadata.get("end_of_log", False): + headers = { + "Airflow-Continuation-Token": URLSafeSerializer(request.app.state.secret_key).dumps(metadata) + } + return Response(media_type="application/x-ndjson", content=logs, headers=headers) diff --git a/airflow-core/src/airflow/api_fastapi/core_api/services/public/connections.py b/airflow-core/src/airflow/api_fastapi/core_api/services/public/connections.py index 8f1d3e32a34b6..4c236cf9756be 100644 --- a/airflow-core/src/airflow/api_fastapi/core_api/services/public/connections.py +++ b/airflow-core/src/airflow/api_fastapi/core_api/services/public/connections.py @@ -97,7 +97,7 @@ def handle_bulk_create( status_code=status.HTTP_409_CONFLICT, detail=f"The connections with these connection_ids: {matched_connection_ids} already exist.", ) - elif action.action_on_existence == BulkActionOnExistence.SKIP: + if action.action_on_existence == BulkActionOnExistence.SKIP: create_connection_ids = not_found_connection_ids else: create_connection_ids = to_create_connection_ids @@ -130,7 +130,7 @@ def handle_bulk_update( status_code=status.HTTP_404_NOT_FOUND, detail=f"The connections with these connection_ids: {not_found_connection_ids} were not found.", ) - elif action.action_on_non_existence == BulkActionNotOnExistence.SKIP: + if action.action_on_non_existence == BulkActionNotOnExistence.SKIP: update_connection_ids = matched_connection_ids else: update_connection_ids = to_update_connection_ids @@ -170,7 +170,7 @@ def handle_bulk_delete( status_code=status.HTTP_404_NOT_FOUND, detail=f"The connections with these connection_ids: {not_found_connection_ids} were not found.", ) - elif action.action_on_non_existence == BulkActionNotOnExistence.SKIP: + if action.action_on_non_existence == BulkActionNotOnExistence.SKIP: delete_connection_ids = matched_connection_ids else: delete_connection_ids = to_delete_connection_ids diff --git a/airflow-core/src/airflow/api_fastapi/core_api/services/public/pools.py b/airflow-core/src/airflow/api_fastapi/core_api/services/public/pools.py index 4a78e12e4f4f0..98e16cbdccc28 100644 --- a/airflow-core/src/airflow/api_fastapi/core_api/services/public/pools.py +++ b/airflow-core/src/airflow/api_fastapi/core_api/services/public/pools.py @@ -66,7 +66,7 @@ def handle_bulk_create(self, action: BulkCreateAction[PoolBody], results: BulkAc status_code=status.HTTP_409_CONFLICT, detail=f"The pools with these pool names: {matched_pool_names} already exist.", ) - elif action.action_on_existence == BulkActionOnExistence.SKIP: + if action.action_on_existence == BulkActionOnExistence.SKIP: create_pool_names = not_found_pool_names else: create_pool_names = to_create_pool_names @@ -97,7 +97,7 @@ def handle_bulk_update(self, action: BulkUpdateAction[PoolBody], results: BulkAc status_code=status.HTTP_404_NOT_FOUND, detail=f"The pools with these pool names: {not_found_pool_names} were not found.", ) - elif action.action_on_non_existence == BulkActionNotOnExistence.SKIP: + if action.action_on_non_existence == BulkActionNotOnExistence.SKIP: update_pool_names = matched_pool_names else: update_pool_names = to_update_pool_names @@ -134,7 +134,7 @@ def handle_bulk_delete(self, action: BulkDeleteAction[PoolBody], results: BulkAc status_code=status.HTTP_404_NOT_FOUND, detail=f"The pools with these pool names: {not_found_pool_names} were not 
found.", ) - elif action.action_on_non_existence == BulkActionNotOnExistence.SKIP: + if action.action_on_non_existence == BulkActionNotOnExistence.SKIP: delete_pool_names = matched_pool_names else: delete_pool_names = to_delete_pool_names diff --git a/airflow-core/src/airflow/api_fastapi/core_api/services/public/variables.py b/airflow-core/src/airflow/api_fastapi/core_api/services/public/variables.py index e602faa4e7d00..77ab8568cbc53 100644 --- a/airflow-core/src/airflow/api_fastapi/core_api/services/public/variables.py +++ b/airflow-core/src/airflow/api_fastapi/core_api/services/public/variables.py @@ -57,7 +57,7 @@ def handle_bulk_create(self, action: BulkCreateAction, results: BulkActionRespon status_code=status.HTTP_409_CONFLICT, detail=f"The variables with these keys: {matched_keys} already exist.", ) - elif action.action_on_existence == BulkActionOnExistence.SKIP: + if action.action_on_existence == BulkActionOnExistence.SKIP: create_keys = not_found_keys else: create_keys = to_create_keys @@ -86,7 +86,7 @@ def handle_bulk_update(self, action: BulkUpdateAction, results: BulkActionRespon status_code=status.HTTP_404_NOT_FOUND, detail=f"The variables with these keys: {not_found_keys} were not found.", ) - elif action.action_on_non_existence == BulkActionNotOnExistence.SKIP: + if action.action_on_non_existence == BulkActionNotOnExistence.SKIP: update_keys = matched_keys else: update_keys = to_update_keys @@ -118,7 +118,7 @@ def handle_bulk_delete(self, action: BulkDeleteAction, results: BulkActionRespon status_code=status.HTTP_404_NOT_FOUND, detail=f"The variables with these keys: {not_found_keys} were not found.", ) - elif action.action_on_non_existence == BulkActionNotOnExistence.SKIP: + if action.action_on_non_existence == BulkActionNotOnExistence.SKIP: delete_keys = matched_keys else: delete_keys = to_delete_keys diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/taskinstance.py b/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/taskinstance.py index 7437a01f6db69..cd8287be97b73 100644 --- a/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/taskinstance.py +++ b/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/taskinstance.py @@ -195,13 +195,13 @@ def ti_state_discriminator(v: dict[str, str] | StrictBaseModel) -> str: if state == TIState.SUCCESS: return "success" - elif state in set(TerminalTIState): + if state in set(TerminalTIState): return "_terminal_" - elif state == TIState.DEFERRED: + if state == TIState.DEFERRED: return "deferred" - elif state == TIState.UP_FOR_RESCHEDULE: + if state == TIState.UP_FOR_RESCHEDULE: return "up_for_reschedule" - elif state == TIState.UP_FOR_RETRY: + if state == TIState.UP_FOR_RETRY: return "up_for_retry" return "_other_" diff --git a/airflow-core/src/airflow/cli/commands/api_server_command.py b/airflow-core/src/airflow/cli/commands/api_server_command.py index a10de9d27df86..343890aa8e4e8 100644 --- a/airflow-core/src/airflow/cli/commands/api_server_command.py +++ b/airflow-core/src/airflow/cli/commands/api_server_command.py @@ -129,9 +129,9 @@ def _get_ssl_cert_and_key_filepaths(cli_arguments) -> tuple[str | None, str | No raise AirflowConfigException(error_template_2.format(ssl_key)) return (ssl_cert, ssl_key) - elif ssl_cert: + if ssl_cert: raise AirflowConfigException(error_template_1.format("SSL certificate", "SSL key")) - elif ssl_key: + if ssl_key: raise AirflowConfigException(error_template_1.format("SSL key", "SSL certificate")) return (None, None) diff --git 
a/airflow-core/src/airflow/cli/commands/dag_command.py b/airflow-core/src/airflow/cli/commands/dag_command.py index 197273e4be617..ee75ef34aa170 100644 --- a/airflow-core/src/airflow/cli/commands/dag_command.py +++ b/airflow-core/src/airflow/cli/commands/dag_command.py @@ -234,7 +234,7 @@ def dag_dependencies_show(args) -> None: "Option --save and --imgcat are mutually exclusive. " "Please remove one option to execute the command.", ) - elif filename: + if filename: _save_dot_to_file(dot, filename) elif imgcat: _display_dot_via_imgcat(dot) @@ -255,7 +255,7 @@ def dag_show(args) -> None: "Option --save and --imgcat are mutually exclusive. " "Please remove one option to execute the command.", ) - elif filename: + if filename: _save_dot_to_file(dot, filename) elif imgcat: _display_dot_via_imgcat(dot) @@ -275,8 +275,7 @@ def _display_dot_via_imgcat(dot: Dot) -> None: except OSError as e: if e.errno == errno.ENOENT: raise SystemExit("Failed to execute. Make sure the imgcat executables are on your systems 'PATH'") - else: - raise + raise def _save_dot_to_file(dot: Dot, filename: str) -> None: diff --git a/airflow-core/src/airflow/cli/commands/info_command.py b/airflow-core/src/airflow/cli/commands/info_command.py index 8bee548e36965..96b6550e48297 100644 --- a/airflow-core/src/airflow/cli/commands/info_command.py +++ b/airflow-core/src/airflow/cli/commands/info_command.py @@ -141,11 +141,11 @@ def get_current() -> OperatingSystem: """Get current operating system.""" if os.name == "nt": return OperatingSystem.WINDOWS - elif "linux" in sys.platform: + if "linux" in sys.platform: return OperatingSystem.LINUX - elif "darwin" in sys.platform: + if "darwin" in sys.platform: return OperatingSystem.MACOSX - elif "cygwin" in sys.platform: + if "cygwin" in sys.platform: return OperatingSystem.CYGWIN return OperatingSystem.UNKNOWN @@ -203,8 +203,7 @@ def _get_version(cmd: list[str], grep: bytes | None = None): data = [line for line in data if grep in line] if len(data) != 1: return "NOT AVAILABLE" - else: - return data[0].decode() + return data[0].decode() except OSError: return "NOT AVAILABLE" @@ -216,8 +215,7 @@ def get_fullname(o): module = o.__class__.__module__ if module is None or module == str.__class__.__module__: return o.__class__.__name__ # Avoid reporting __builtin__ - else: - return f"{module}.{o.__class__.__name__}" + return f"{module}.{o.__class__.__name__}" try: handler_names = [get_fullname(handler) for handler in logging.getLogger("airflow.task").handlers] diff --git a/airflow-core/src/airflow/cli/commands/task_command.py b/airflow-core/src/airflow/cli/commands/task_command.py index 0b330ed61cb7f..52466fcdc1fb2 100644 --- a/airflow-core/src/airflow/cli/commands/task_command.py +++ b/airflow-core/src/airflow/cli/commands/task_command.py @@ -108,7 +108,7 @@ def _get_dag_run( ) if dag_run is not None: return dag_run, False - elif not create_if_necessary: + if not create_if_necessary: raise DagRunNotFound( f"DagRun for {dag.dag_id} with run_id or logical_date of {logical_date_or_run_id!r} not found" ) @@ -132,7 +132,7 @@ def _get_dag_run( state=DagRunState.RUNNING, ) return dag_run, True - elif create_if_necessary == "db": + if create_if_necessary == "db": dag_run = dag.create_dagrun( run_id=_generate_temporary_run_id(), logical_date=dag_run_logical_date, diff --git a/airflow-core/src/airflow/configuration.py b/airflow-core/src/airflow/configuration.py index 2357f4433272b..61ba3341a7701 100644 --- a/airflow-core/src/airflow/configuration.py +++ b/airflow-core/src/airflow/configuration.py @@ 
-124,8 +124,7 @@ def expand_env_var(env_var: str | None) -> str | None: interpolated = os.path.expanduser(os.path.expandvars(str(env_var))) if interpolated == env_var: return interpolated - else: - env_var = interpolated + env_var = interpolated def run_command(command: str) -> str: @@ -1160,13 +1159,12 @@ def getboolean(self, section: str, key: str, **kwargs) -> bool: # type: ignore[ val = val.split("#")[0].strip() if val in ("t", "true", "1"): return True - elif val in ("f", "false", "0"): + if val in ("f", "false", "0"): return False - else: - raise AirflowConfigException( - f'Failed to convert value to bool. Please check "{key}" key in "{section}" section. ' - f'Current value: "{val}".' - ) + raise AirflowConfigException( + f'Failed to convert value to bool. Please check "{key}" key in "{section}" section. ' + f'Current value: "{val}".' + ) def getint(self, section: str, key: str, **kwargs) -> int: # type: ignore[override] val = self.get(section, key, _extra_stacklevel=1, **kwargs) @@ -2020,7 +2018,7 @@ def write_default_airflow_configuration_if_needed() -> AirflowConfigParser: f"but got a directory {airflow_config.__fspath__()!r}." ) raise IsADirectoryError(msg) - elif not airflow_config.exists(): + if not airflow_config.exists(): log.debug("Creating new Airflow config file in: %s", airflow_config.__fspath__()) config_directory = airflow_config.parent if not config_directory.exists(): diff --git a/airflow-core/src/airflow/dag_processing/bundles/base.py b/airflow-core/src/airflow/dag_processing/bundles/base.py index 4b3488ee2f18c..5d49bf43fd4a4 100644 --- a/airflow-core/src/airflow/dag_processing/bundles/base.py +++ b/airflow-core/src/airflow/dag_processing/bundles/base.py @@ -48,8 +48,7 @@ def get_bundle_storage_root_path(): if configured_location := conf.get("dag_processor", "dag_bundle_storage_path", fallback=None): return Path(configured_location) - else: - return Path(tempfile.gettempdir(), "airflow", "dag_bundles") + return Path(tempfile.gettempdir(), "airflow", "dag_bundles") STALE_BUNDLE_TRACKING_FOLDER = get_bundle_storage_root_path() / "_tracking" diff --git a/airflow-core/src/airflow/dag_processing/processor.py b/airflow-core/src/airflow/dag_processing/processor.py index 53c8c56a6cc9d..25f9e2a73ed87 100644 --- a/airflow-core/src/airflow/dag_processing/processor.py +++ b/airflow-core/src/airflow/dag_processing/processor.py @@ -147,7 +147,7 @@ def _execute_callbacks( "Haven't coded Task callback yet - https://github.com/apache/airflow/issues/44354!" 
) # _execute_task_callbacks(dagbag, request) - elif isinstance(request, DagCallbackRequest): + if isinstance(request, DagCallbackRequest): _execute_dag_callbacks(dagbag, request, log) @@ -277,7 +277,7 @@ def _handle_request(self, msg: ToManager, log: FilteringBoundLogger) -> None: # if isinstance(msg, DagFileParsingResult): self.parsing_result = msg return - elif isinstance(msg, GetConnection): + if isinstance(msg, GetConnection): conn = self.client.connections.get(msg.conn_id) if isinstance(conn, ConnectionResponse): conn_result = ConnectionResult.from_conn_response(conn) diff --git a/airflow-core/src/airflow/example_dags/example_bash_decorator.py b/airflow-core/src/airflow/example_dags/example_bash_decorator.py index 7448299e1a801..89335c0686465 100644 --- a/airflow-core/src/airflow/example_dags/example_bash_decorator.py +++ b/airflow-core/src/airflow/example_dags/example_bash_decorator.py @@ -73,8 +73,7 @@ def this_will_skip() -> str: def sleep_in(day: str) -> str: if day in (WeekDay.SATURDAY, WeekDay.SUNDAY): return f"sleep {60 * 60}" - else: - raise AirflowSkipException("No sleeping in today!") + raise AirflowSkipException("No sleeping in today!") sleep_in(day="{{ dag_run.logical_date.strftime('%A').lower() }}") # [END howto_decorator_bash_conditional] diff --git a/airflow-core/src/airflow/example_dags/example_branch_python_dop_operator_3.py b/airflow-core/src/airflow/example_dags/example_branch_python_dop_operator_3.py index dd064366a5830..06c210676a856 100644 --- a/airflow-core/src/airflow/example_dags/example_branch_python_dop_operator_3.py +++ b/airflow-core/src/airflow/example_dags/example_branch_python_dop_operator_3.py @@ -39,8 +39,7 @@ def should_run(**kwargs) -> str: print(f"------------- exec dttm = {kwargs['logical_date']} and minute = {kwargs['logical_date'].minute}") if kwargs["logical_date"].minute % 2 == 0: return "empty_task_1" - else: - return "empty_task_2" + return "empty_task_2" with DAG( diff --git a/airflow-core/src/airflow/exceptions.py b/airflow-core/src/airflow/exceptions.py index 10d3cb63b6384..045f9647ade76 100644 --- a/airflow-core/src/airflow/exceptions.py +++ b/airflow-core/src/airflow/exceptions.py @@ -126,9 +126,9 @@ def _render_asset_key(key: AssetUniqueKey | AssetNameRef | AssetUriRef) -> str: if isinstance(key, AssetUniqueKey): return f"Asset(name={key.name!r}, uri={key.uri!r})" - elif isinstance(key, AssetNameRef): + if isinstance(key, AssetNameRef): return f"Asset.ref(name={key.name!r})" - elif isinstance(key, AssetUriRef): + if isinstance(key, AssetUriRef): return f"Asset.ref(uri={key.uri!r})" return repr(key) # Should not happen, but let's fails more gracefully in an exception. diff --git a/airflow-core/src/airflow/executors/executor_loader.py b/airflow-core/src/airflow/executors/executor_loader.py index 4c3f7d70392a9..1f895fee5203f 100644 --- a/airflow-core/src/airflow/executors/executor_loader.py +++ b/airflow-core/src/airflow/executors/executor_loader.py @@ -108,10 +108,9 @@ def _get_executor_names(cls) -> list[ExecutorName]: "Incorrectly formatted executor configuration. 
Second portion of an executor " f"configuration must be a module path but received: {module_path}" ) - else: - executor_names_per_team.append( - ExecutorName(alias=split_name[0], module_path=split_name[1], team_id=team_id) - ) + executor_names_per_team.append( + ExecutorName(alias=split_name[0], module_path=split_name[1], team_id=team_id) + ) else: raise AirflowConfigException(f"Incorrectly formatted executor configuration: {name}") @@ -227,12 +226,11 @@ def lookup_executor_name_by_str(cls, executor_name_str: str) -> ExecutorName: if executor_name := _alias_to_executors.get(executor_name_str): return executor_name - elif executor_name := _module_to_executors.get(executor_name_str): + if executor_name := _module_to_executors.get(executor_name_str): return executor_name - elif executor_name := _classname_to_executors.get(executor_name_str): + if executor_name := _classname_to_executors.get(executor_name_str): return executor_name - else: - raise UnknownExecutorException(f"Unknown executor being loaded: {executor_name_str}") + raise UnknownExecutorException(f"Unknown executor being loaded: {executor_name_str}") @classmethod def load_executor(cls, executor_name: ExecutorName | str | None) -> BaseExecutor: diff --git a/airflow-core/src/airflow/executors/executor_utils.py b/airflow-core/src/airflow/executors/executor_utils.py index 11934970edbab..9b8a76435b3f3 100644 --- a/airflow-core/src/airflow/executors/executor_utils.py +++ b/airflow-core/src/airflow/executors/executor_utils.py @@ -52,8 +52,7 @@ def __eq__(self, other) -> bool: and self.team_id == other.team_id ): return True - else: - return False + return False def __hash__(self) -> int: """Implement hash.""" diff --git a/airflow-core/src/airflow/jobs/job.py b/airflow-core/src/airflow/jobs/job.py index 7233d4083a648..8697c4d2be259 100644 --- a/airflow-core/src/airflow/jobs/job.py +++ b/airflow-core/src/airflow/jobs/job.py @@ -289,12 +289,11 @@ def most_recent_job(self, session: Session = NEW_SESSION) -> Job | None: def _heartrate(job_type: str) -> float: if job_type == "TriggererJob": return conf.getfloat("triggerer", "JOB_HEARTBEAT_SEC") - elif job_type == "SchedulerJob": + if job_type == "SchedulerJob": return conf.getfloat("scheduler", "SCHEDULER_HEARTBEAT_SEC") - else: - # Heartrate used to be hardcoded to scheduler, so in all other - # cases continue to use that value for back compat - return conf.getfloat("scheduler", "JOB_HEARTBEAT_SEC") + # Heartrate used to be hardcoded to scheduler, so in all other + # cases continue to use that value for back compat + return conf.getfloat("scheduler", "JOB_HEARTBEAT_SEC") @staticmethod def _is_alive( diff --git a/airflow-core/src/airflow/jobs/scheduler_job_runner.py b/airflow-core/src/airflow/jobs/scheduler_job_runner.py index 8a84855e29169..c95303a480cbb 100644 --- a/airflow-core/src/airflow/jobs/scheduler_job_runner.py +++ b/airflow-core/src/airflow/jobs/scheduler_job_runner.py @@ -558,8 +558,7 @@ def _executable_task_instances_to_queued(self, max_tis: int, session: Session) - ) starved_tasks.add((task_instance.dag_id, task_instance.task_id)) continue - else: - executor_slots_available[executor_obj.name] -= 1 + executor_slots_available[executor_obj.name] -= 1 else: # This is a defensive guard for if we happen to have a task who's executor cannot be # found. 
The check in the dag parser should make this not realistically possible but the diff --git a/airflow-core/src/airflow/metrics/statsd_logger.py b/airflow-core/src/airflow/metrics/statsd_logger.py index aa400e033ccd5..8d47bc9ae132a 100644 --- a/airflow-core/src/airflow/metrics/statsd_logger.py +++ b/airflow-core/src/airflow/metrics/statsd_logger.py @@ -167,8 +167,7 @@ def get_statsd_logger(cls) -> SafeStatsdLogger: "Your custom StatsD client must extend the statsd.StatsClient in order to ensure " "backwards compatibility." ) - else: - log.info("Successfully loaded custom StatsD client") + log.info("Successfully loaded custom StatsD client") else: stats_class = StatsClient diff --git a/airflow-core/src/airflow/metrics/validators.py b/airflow-core/src/airflow/metrics/validators.py index d103656a58e30..85152426fdd17 100644 --- a/airflow-core/src/airflow/metrics/validators.py +++ b/airflow-core/src/airflow/metrics/validators.py @@ -259,8 +259,7 @@ class PatternAllowListValidator(ListValidator): def test(self, name: str) -> bool: if self.validate_list is not None: return super()._has_pattern_match(name) - else: - return True # default is all metrics are allowed + return True # default is all metrics are allowed class PatternBlockListValidator(ListValidator): @@ -269,5 +268,4 @@ class PatternBlockListValidator(ListValidator): def test(self, name: str) -> bool: if self.validate_list is not None: return not super()._has_pattern_match(name) - else: - return True # default is all metrics are allowed + return True # default is all metrics are allowed diff --git a/airflow-core/src/airflow/migrations/env.py b/airflow-core/src/airflow/migrations/env.py index 9bf9dfec0054b..dfb46ab6c37e0 100644 --- a/airflow-core/src/airflow/migrations/env.py +++ b/airflow-core/src/airflow/migrations/env.py @@ -36,8 +36,7 @@ def include_object(_, name, type_, *args): # Only create migrations for objects that are in the target metadata if type_ == "table" and name not in target_metadata.tables: return False - else: - return True + return True # Make sure everything is imported so that alembic can find it all diff --git a/airflow-core/src/airflow/migrations/versions/0042_3_0_0_add_uuid_primary_key_to_task_instance_.py b/airflow-core/src/airflow/migrations/versions/0042_3_0_0_add_uuid_primary_key_to_task_instance_.py index 41cfddc9cef0b..b45582c35919f 100644 --- a/airflow-core/src/airflow/migrations/versions/0042_3_0_0_add_uuid_primary_key_to_task_instance_.py +++ b/airflow-core/src/airflow/migrations/versions/0042_3_0_0_add_uuid_primary_key_to_task_instance_.py @@ -163,8 +163,7 @@ def _get_type_id_column(dialect_name: str) -> sa.types.TypeEngine: if dialect_name == "postgresql": return postgresql.UUID(as_uuid=False) # For other databases, use String(36) to match UUID format - else: - return sa.String(36) + return sa.String(36) def create_foreign_keys(): diff --git a/airflow-core/src/airflow/migrations/versions/0068_3_0_0_ti_table_id_unique_per_try.py b/airflow-core/src/airflow/migrations/versions/0068_3_0_0_ti_table_id_unique_per_try.py index e0af8511ca724..b5e4382564d70 100644 --- a/airflow-core/src/airflow/migrations/versions/0068_3_0_0_ti_table_id_unique_per_try.py +++ b/airflow-core/src/airflow/migrations/versions/0068_3_0_0_ti_table_id_unique_per_try.py @@ -42,8 +42,7 @@ def _get_uuid_type(dialect_name: str) -> sa.types.TypeEngine: if dialect_name == "sqlite": return sa.String(36) - else: - return UUIDType(binary=False) + return UUIDType(binary=False) def upgrade(): diff --git a/airflow-core/src/airflow/models/asset.py 
b/airflow-core/src/airflow/models/asset.py index 8a56ea8ce8d71..17a90031da26c 100644 --- a/airflow-core/src/airflow/models/asset.py +++ b/airflow-core/src/airflow/models/asset.py @@ -211,8 +211,7 @@ def __eq__(self, other: object) -> bool: if isinstance(other, (self.__class__, AssetAlias)): return self.name == other.name - else: - return NotImplemented + return NotImplemented def to_public(self) -> AssetAlias: from airflow.sdk.definitions.asset import AssetAlias @@ -290,7 +289,7 @@ def from_public(cls, obj: Asset) -> AssetModel: def __init__(self, name: str = "", uri: str = "", **kwargs): if not name and not uri: raise TypeError("must provide either 'name' or 'uri'") - elif not name: + if not name: name = uri elif not uri: uri = name @@ -642,8 +641,7 @@ class AssetDagRunQueue(Base): def __eq__(self, other: object) -> bool: if isinstance(other, self.__class__): return self.asset_id == other.asset_id and self.target_dag_id == other.target_dag_id - else: - return NotImplemented + return NotImplemented def __hash__(self): return hash(self.__mapper__.primary_key) diff --git a/airflow-core/src/airflow/models/backfill.py b/airflow-core/src/airflow/models/backfill.py index 2ca2e87396ceb..355a21de2b20a 100644 --- a/airflow-core/src/airflow/models/backfill.py +++ b/airflow-core/src/airflow/models/backfill.py @@ -297,34 +297,33 @@ def _create_backfill_dag_run( ) ) return - else: - lock = session.execute( - with_row_locks( - query=select(DagRun).where(DagRun.logical_date == info.logical_date), - session=session, - skip_locked=True, - ) + lock = session.execute( + with_row_locks( + query=select(DagRun).where(DagRun.logical_date == info.logical_date), + session=session, + skip_locked=True, ) - if lock: - _handle_clear_run( - session=session, - dag=dag, - dr=dr, - info=info, + ) + if lock: + _handle_clear_run( + session=session, + dag=dag, + dr=dr, + info=info, + backfill_id=backfill_id, + sort_ordinal=backfill_sort_ordinal, + ) + else: + session.add( + BackfillDagRun( backfill_id=backfill_id, + dag_run_id=None, + logical_date=info.logical_date, + exception_reason=BackfillDagRunExceptionReason.IN_FLIGHT, sort_ordinal=backfill_sort_ordinal, ) - else: - session.add( - BackfillDagRun( - backfill_id=backfill_id, - dag_run_id=None, - logical_date=info.logical_date, - exception_reason=BackfillDagRunExceptionReason.IN_FLIGHT, - sort_ordinal=backfill_sort_ordinal, - ) - ) - return + ) + return try: dr = dag.create_dagrun( diff --git a/airflow-core/src/airflow/models/base.py b/airflow-core/src/airflow/models/base.py index e9f86f8d7e672..c146f4619efb3 100644 --- a/airflow-core/src/airflow/models/base.py +++ b/airflow-core/src/airflow/models/base.py @@ -61,21 +61,20 @@ def get_id_collation_args(): collation = conf.get("database", "sql_engine_collation_for_ids", fallback=None) if collation: return {"collation": collation} - else: - # Automatically use utf8mb3_bin collation for mysql - # This is backwards-compatible. All our IDS are ASCII anyway so even if - # we migrate from previously installed database with different collation and we end up mixture of - # COLLATIONS, it's not a problem whatsoever (and we keep it small enough so that our indexes - # for MYSQL will not exceed the maximum index size. - # - # See https://github.com/apache/airflow/pull/17603#issuecomment-901121618. 
- # - # We cannot use session/dialect as at this point we are trying to determine the right connection - # parameters, so we use the connection - conn = conf.get("database", "sql_alchemy_conn", fallback="") - if conn.startswith(("mysql", "mariadb")): - return {"collation": "utf8mb3_bin"} - return {} + # Automatically use utf8mb3_bin collation for mysql + # This is backwards-compatible. All our IDS are ASCII anyway so even if + # we migrate from previously installed database with different collation and we end up mixture of + # COLLATIONS, it's not a problem whatsoever (and we keep it small enough so that our indexes + # for MYSQL will not exceed the maximum index size. + # + # See https://github.com/apache/airflow/pull/17603#issuecomment-901121618. + # + # We cannot use session/dialect as at this point we are trying to determine the right connection + # parameters, so we use the connection + conn = conf.get("database", "sql_alchemy_conn", fallback="") + if conn.startswith(("mysql", "mariadb")): + return {"collation": "utf8mb3_bin"} + return {} COLLATION_ARGS: dict[str, Any] = get_id_collation_args() diff --git a/airflow-core/src/airflow/models/baseoperator.py b/airflow-core/src/airflow/models/baseoperator.py index 0d0f8b92e2d24..4c766007142c2 100644 --- a/airflow-core/src/airflow/models/baseoperator.py +++ b/airflow-core/src/airflow/models/baseoperator.py @@ -495,8 +495,7 @@ def get_direct_relatives(self, upstream: bool = False) -> Iterable[Operator]: """Get list of the direct relatives to the current task, upstream or downstream.""" if upstream: return self.upstream_list - else: - return self.downstream_list + return self.downstream_list @staticmethod def xcom_push( diff --git a/airflow-core/src/airflow/models/connection.py b/airflow-core/src/airflow/models/connection.py index 221b4fba99412..2727c926e0f2a 100644 --- a/airflow-core/src/airflow/models/connection.py +++ b/airflow-core/src/airflow/models/connection.py @@ -321,8 +321,7 @@ def get_password(self) -> str | None: f"FERNET_KEY configuration is missing" ) return fernet.decrypt(bytes(self._password, "utf-8")).decode() - else: - return self._password + return self._password def set_password(self, value: str | None): """Encrypt password and set in object attribute.""" @@ -481,8 +480,7 @@ def get_connection_from_secrets(cls, conn_id: str) -> Connection: if e.error.error == ErrorType.CONNECTION_NOT_FOUND: log.debug("Unable to retrieve connection from MetastoreBackend using Task SDK") raise AirflowNotFoundException(f"The conn_id `{conn_id}` isn't defined") - else: - raise + raise # check cache first # enabled only if SecretCache.init() has been called first diff --git a/airflow-core/src/airflow/models/dag.py b/airflow-core/src/airflow/models/dag.py index 197176233030d..5e72f55f06600 100644 --- a/airflow-core/src/airflow/models/dag.py +++ b/airflow-core/src/airflow/models/dag.py @@ -176,7 +176,7 @@ def _get_model_data_interval( if end is not None: raise InconsistentDataInterval(instance, start_field_name, end_field_name) return None - elif end is None: + if end is None: raise InconsistentDataInterval(instance, start_field_name, end_field_name) return DataInterval(start, end) @@ -2504,8 +2504,7 @@ def calculate_dagrun_date_fields( "Passing a datetime to `DagModel.calculate_dagrun_date_fields` is not supported. " "Provide a data interval instead." 
) - else: - last_automated_data_interval = last_automated_dag_run + last_automated_data_interval = last_automated_dag_run next_dagrun_info = dag.next_dagrun_info(last_automated_data_interval) if next_dagrun_info is None: self.next_dagrun_data_interval = self.next_dagrun = self.next_dagrun_create_after = None diff --git a/airflow-core/src/airflow/models/dagbag.py b/airflow-core/src/airflow/models/dagbag.py index 23b87a9cf2e05..393d01ce7a2c2 100644 --- a/airflow-core/src/airflow/models/dagbag.py +++ b/airflow-core/src/airflow/models/dagbag.py @@ -261,7 +261,7 @@ def get_dag(self, dag_id, session: Session = None): # If the source file no longer exports `dag_id`, delete it from self.dags if found_dags and dag_id in [found_dag.dag_id for found_dag in found_dags]: return self.dags[dag_id] - elif dag_id in self.dags: + if dag_id in self.dags: del self.dags[dag_id] return self.dags.get(dag_id) diff --git a/airflow-core/src/airflow/models/dagcode.py b/airflow-core/src/airflow/models/dagcode.py index 54347df6084d3..2db338f87b50e 100644 --- a/airflow-core/src/airflow/models/dagcode.py +++ b/airflow-core/src/airflow/models/dagcode.py @@ -133,8 +133,7 @@ def _get_code_from_db(cls, dag_id, session: Session = NEW_SESSION) -> str: ) if not dag_code: raise DagCodeNotFound() - else: - code = dag_code.source_code + code = dag_code.source_code return code @staticmethod diff --git a/airflow-core/src/airflow/models/dagrun.py b/airflow-core/src/airflow/models/dagrun.py index df4ecf6f56e65..f902c5755f13b 100644 --- a/airflow-core/src/airflow/models/dagrun.py +++ b/airflow-core/src/airflow/models/dagrun.py @@ -126,10 +126,9 @@ def _creator_note(val): """Creator the ``note`` association proxy.""" if isinstance(val, str): return DagRunNote(content=val) - elif isinstance(val, dict): + if isinstance(val, dict): return DagRunNote(**val) - else: - return DagRunNote(*val) + return DagRunNote(*val) class DagRun(Base, LoggingMixin): diff --git a/airflow-core/src/airflow/models/renderedtifields.py b/airflow-core/src/airflow/models/renderedtifields.py index e60b4406b0893..c971f391d9e24 100644 --- a/airflow-core/src/airflow/models/renderedtifields.py +++ b/airflow-core/src/airflow/models/renderedtifields.py @@ -179,8 +179,7 @@ def get_templated_fields( if result: rendered_fields = result.rendered_fields return rendered_fields - else: - return None + return None @classmethod @provide_session diff --git a/airflow-core/src/airflow/models/serialized_dag.py b/airflow-core/src/airflow/models/serialized_dag.py index 34aa73336a5b9..dc45ffc8546a2 100644 --- a/airflow-core/src/airflow/models/serialized_dag.py +++ b/airflow-core/src/airflow/models/serialized_dag.py @@ -324,7 +324,7 @@ def _sort_serialized_dag_dict(cls, serialized_dag: Any): """Recursively sort json_dict and its nested dictionaries and lists.""" if isinstance(serialized_dag, dict): return {k: cls._sort_serialized_dag_dict(v) for k, v in sorted(serialized_dag.items())} - elif isinstance(serialized_dag, list): + if isinstance(serialized_dag, list): if all(isinstance(i, dict) for i in serialized_dag): if all( isinstance(i.get("__var", {}), Iterable) and "task_id" in i.get("__var", {}) diff --git a/airflow-core/src/airflow/models/taskinstance.py b/airflow-core/src/airflow/models/taskinstance.py index 78ef893ee3a66..a43934fdf5433 100644 --- a/airflow-core/src/airflow/models/taskinstance.py +++ b/airflow-core/src/airflow/models/taskinstance.py @@ -309,9 +309,8 @@ def _run_raw_task( ti.clear_next_method_args() TaskInstance.save_to_db(ti=ti, session=session) return None - else: 
- ti.handle_failure(e, test_mode, context, session=session) - raise + ti.handle_failure(e, test_mode, context, session=session) + raise except SystemExit as e: # We have already handled SystemExit with success codes (0 and None) in the `_execute_task`. # Therefore, here we must handle only error codes. @@ -589,10 +588,9 @@ def _creator_note(val): """Creator the ``note`` association proxy.""" if isinstance(val, str): return TaskInstanceNote(content=val) - elif isinstance(val, dict): + if isinstance(val, dict): return TaskInstanceNote(**val) - else: - return TaskInstanceNote(*val) + return TaskInstanceNote(*val) def _execute_task(task_instance: TaskInstance, context: Context, task_orig: Operator): @@ -3502,7 +3500,7 @@ def duration_expression_update( ), } ) - elif bind.dialect.name == "postgresql": + if bind.dialect.name == "postgresql": return query.values( { "end_date": end_date, diff --git a/airflow-core/src/airflow/models/variable.py b/airflow-core/src/airflow/models/variable.py index 3b8b00192ef51..b5c32904bd742 100644 --- a/airflow-core/src/airflow/models/variable.py +++ b/airflow-core/src/airflow/models/variable.py @@ -120,10 +120,8 @@ def setdefault(cls, key, default, description=None, deserialize_json=False): if default is not None: Variable.set(key=key, value=default, description=description, serialize_json=deserialize_json) return default - else: - raise ValueError("Default Value must be set") - else: - return obj + raise ValueError("Default Value must be set") + return obj @classmethod def get( @@ -169,16 +167,13 @@ def get( if var_val is None: if default_var is not cls.__NO_DEFAULT_SENTINEL: return default_var - else: - raise KeyError(f"Variable {key} does not exist") - else: - if deserialize_json: - obj = json.loads(var_val) - mask_secret(obj, key) - return obj - else: - mask_secret(var_val, key) - return var_val + raise KeyError(f"Variable {key} does not exist") + if deserialize_json: + obj = json.loads(var_val) + mask_secret(obj, key) + return obj + mask_secret(var_val, key) + return var_val @staticmethod def set( diff --git a/airflow-core/src/airflow/providers_manager.py b/airflow-core/src/airflow/providers_manager.py index 4a7f4fe9099cd..eb048dbd0edd3 100644 --- a/airflow-core/src/airflow/providers_manager.py +++ b/airflow-core/src/airflow/providers_manager.py @@ -72,8 +72,7 @@ def _ensure_prefix_for_placeholders(field_behaviors: dict[str, Any], conn_type: def ensure_prefix(field): if field not in conn_attrs and not field.startswith("extra__"): return f"extra__{conn_type}__{field}" - else: - return field + return field if "placeholders" in field_behaviors: placeholders = field_behaviors["placeholders"] diff --git a/airflow-core/src/airflow/secrets/base_secrets.py b/airflow-core/src/airflow/secrets/base_secrets.py index 329eb95cb3dae..a1f063969614a 100644 --- a/airflow-core/src/airflow/secrets/base_secrets.py +++ b/airflow-core/src/airflow/secrets/base_secrets.py @@ -63,8 +63,7 @@ def deserialize_connection(self, conn_id: str, value: str) -> Connection: value = value.strip() if value[0] == "{": return Connection.from_json(conn_id=conn_id, value=value) - else: - return Connection(conn_id=conn_id, uri=value) + return Connection(conn_id=conn_id, uri=value) def get_connection(self, conn_id: str) -> Connection | None: """ @@ -78,8 +77,7 @@ def get_connection(self, conn_id: str) -> Connection | None: if value: return self.deserialize_connection(conn_id=conn_id, value=value) - else: - return None + return None def get_variable(self, key: str) -> str | None: """ diff --git 
a/airflow-core/src/airflow/serialization/helpers.py b/airflow-core/src/airflow/serialization/helpers.py index e32b2cf0f2676..bb6be0d3d902e 100644 --- a/airflow-core/src/airflow/serialization/helpers.py +++ b/airflow-core/src/airflow/serialization/helpers.py @@ -45,9 +45,9 @@ def translate_tuples_to_lists(obj: Any): """Recursively convert tuples to lists.""" if isinstance(obj, tuple): return [translate_tuples_to_lists(item) for item in obj] - elif isinstance(obj, list): + if isinstance(obj, list): return [translate_tuples_to_lists(item) for item in obj] - elif isinstance(obj, dict): + if isinstance(obj, dict): return {key: translate_tuples_to_lists(value) for key, value in obj.items()} return obj @@ -65,17 +65,16 @@ def translate_tuples_to_lists(obj: Any): f"{rendered[: max_length - 79]!r}... " ) return serialized - else: - if not template_field and not isinstance(template_field, tuple): - # Avoid unnecessary serialization steps for empty fields unless they are tuples - # and need to be converted to lists - return template_field - template_field = translate_tuples_to_lists(template_field) - serialized = str(template_field) - if len(serialized) > max_length: - rendered = redact(serialized, name) - return ( - "Truncated. You can change this behaviour in [core]max_templated_field_length. " - f"{rendered[: max_length - 79]!r}... " - ) + if not template_field and not isinstance(template_field, tuple): + # Avoid unnecessary serialization steps for empty fields unless they are tuples + # and need to be converted to lists return template_field + template_field = translate_tuples_to_lists(template_field) + serialized = str(template_field) + if len(serialized) > max_length: + rendered = redact(serialized, name) + return ( + "Truncated. You can change this behaviour in [core]max_templated_field_length. " + f"{rendered[: max_length - 79]!r}... " + ) + return template_field diff --git a/airflow-core/src/airflow/serialization/serde.py b/airflow-core/src/airflow/serialization/serde.py index 7c0c1a0081aa7..0268ad91206d8 100644 --- a/airflow-core/src/airflow/serialization/serde.py +++ b/airflow-core/src/airflow/serialization/serde.py @@ -285,8 +285,7 @@ def _convert(old: dict) -> dict: # Return old style dicts directly as they do not need wrapping if old[OLD_TYPE] == OLD_DICT: return old[OLD_DATA] - else: - return {CLASSNAME: old[OLD_TYPE], VERSION: DEFAULT_VERSION, DATA: old[OLD_DATA]} + return {CLASSNAME: old[OLD_TYPE], VERSION: DEFAULT_VERSION, DATA: old[OLD_DATA]} return old diff --git a/airflow-core/src/airflow/serialization/serialized_objects.py b/airflow-core/src/airflow/serialization/serialized_objects.py index f395569b32e13..6bda483373b09 100644 --- a/airflow-core/src/airflow/serialization/serialized_objects.py +++ b/airflow-core/src/airflow/serialization/serialized_objects.py @@ -16,6 +16,8 @@ # under the License. 
"""Serialized DAG and BaseOperator.""" +# TODO: update test_recursive_serialize_calls_must_forward_kwargs and re-enable RET505 +# ruff: noqa: RET505 from __future__ import annotations import collections.abc @@ -1394,7 +1396,7 @@ def populate_operator(cls, op: Operator, encoded_op: dict[str, Any]) -> None: if k == "label": # Label shouldn't be set anymore -- it's computed from task_id now continue - elif k == "downstream_task_ids": + if k == "downstream_task_ids": v = set(v) elif k in {"retry_delay", "execution_timeout", "max_retry_delay"}: # If operator's execution_timeout is None and core.default_task_execution_timeout is not None, diff --git a/airflow-core/src/airflow/serialization/serializers/timezone.py b/airflow-core/src/airflow/serialization/serializers/timezone.py index 53920593143bc..9f2ef7cef65ac 100644 --- a/airflow-core/src/airflow/serialization/serializers/timezone.py +++ b/airflow-core/src/airflow/serialization/serializers/timezone.py @@ -92,10 +92,10 @@ def _get_tzinfo_name(tzinfo: datetime.tzinfo | None) -> str | None: if hasattr(tzinfo, "key"): # zoneinfo timezone return tzinfo.key - elif hasattr(tzinfo, "name"): + if hasattr(tzinfo, "name"): # Pendulum timezone return tzinfo.name - elif hasattr(tzinfo, "zone"): + if hasattr(tzinfo, "zone"): # pytz timezone return tzinfo.zone # type: ignore[no-any-return] diff --git a/airflow-core/src/airflow/settings.py b/airflow-core/src/airflow/settings.py index 49d74770269c3..1bb3389e1937a 100644 --- a/airflow-core/src/airflow/settings.py +++ b/airflow-core/src/airflow/settings.py @@ -231,8 +231,7 @@ def _get_async_conn_uri_from_sync(sync_uri): aiolib = AIO_LIBS_MAPPING.get(scheme) if aiolib: return f"{scheme}+{aiolib}:{rest}" - else: - return sync_uri + return sync_uri def configure_vars(): diff --git a/airflow-core/src/airflow/timetables/events.py b/airflow-core/src/airflow/timetables/events.py index 91e458b3af208..42b5d13e2ec78 100644 --- a/airflow-core/src/airflow/timetables/events.py +++ b/airflow-core/src/airflow/timetables/events.py @@ -115,10 +115,9 @@ def infer_manual_data_interval(self, *, run_after: DateTime) -> DataInterval: # or for the first event if all events are in the future if run_after < self.event_dates[0]: return DataInterval.exact(self.event_dates[0]) - else: - past_events = itertools.dropwhile(lambda when: when > run_after, self.event_dates[::-1]) - most_recent_event = next(past_events) - return DataInterval.exact(most_recent_event) + past_events = itertools.dropwhile(lambda when: when > run_after, self.event_dates[::-1]) + most_recent_event = next(past_events) + return DataInterval.exact(most_recent_event) def serialize(self): return { diff --git a/airflow-core/src/airflow/timetables/trigger.py b/airflow-core/src/airflow/timetables/trigger.py index 6ee32c4b99f9a..87e52a3f2640d 100644 --- a/airflow-core/src/airflow/timetables/trigger.py +++ b/airflow-core/src/airflow/timetables/trigger.py @@ -251,8 +251,7 @@ def _calc_first_run(self) -> DateTime: buffer_between_runs = max(gap_between_runs / 10, datetime.timedelta(minutes=5)) if gap_to_past <= buffer_between_runs: return past_run_time - else: - return next_run_time + return next_run_time class MultipleCronTriggerTimetable(Timetable): diff --git a/airflow-core/src/airflow/traces/otel_tracer.py b/airflow-core/src/airflow/traces/otel_tracer.py index 3184de8f02aa4..d5e71e3f47e05 100644 --- a/airflow-core/src/airflow/traces/otel_tracer.py +++ b/airflow-core/src/airflow/traces/otel_tracer.py @@ -366,14 +366,12 @@ def generate_span_id(self) -> int: id = self.span_id 
self.span_id = None return id - else: - new_id = random.getrandbits(64) - return new_id + new_id = random.getrandbits(64) + return new_id def generate_trace_id(self) -> int: if self.trace_id is not None: id = self.trace_id return id - else: - new_id = random.getrandbits(128) - return new_id + new_id = random.getrandbits(128) + return new_id diff --git a/airflow-core/src/airflow/utils/dag_cycle_tester.py b/airflow-core/src/airflow/utils/dag_cycle_tester.py index 95d06b5bb9e33..9348d9f7eea6d 100644 --- a/airflow-core/src/airflow/utils/dag_cycle_tester.py +++ b/airflow-core/src/airflow/utils/dag_cycle_tester.py @@ -48,7 +48,7 @@ def _check_adjacent_tasks(task_id, current_task): if visited[adjacent_task] == CYCLE_IN_PROGRESS: msg = f"Cycle detected in DAG: {dag.dag_id}. Faulty task: {task_id}" raise AirflowDagCycleException(msg) - elif visited[adjacent_task] == CYCLE_NEW: + if visited[adjacent_task] == CYCLE_NEW: return adjacent_task return None diff --git a/airflow-core/src/airflow/utils/db.py b/airflow-core/src/airflow/utils/db.py index a11ef22582bce..e4919d82bfd36 100644 --- a/airflow-core/src/airflow/utils/db.py +++ b/airflow-core/src/airflow/utils/db.py @@ -1565,7 +1565,7 @@ def __getitem__(self, key: int | slice) -> T | Sequence[T]: if (row := self._session.execute(stmt.limit(1)).one_or_none()) is None: raise IndexError(key) return self._process_row(row) - elif isinstance(key, slice): + if isinstance(key, slice): # This implements the slicing syntax. We want to optimize negative # slicing (e.g. seq[-10:]) by not doing an additional COUNT query # if possible. We can do this unless the start and stop have diff --git a/airflow-core/src/airflow/utils/decorators.py b/airflow-core/src/airflow/utils/decorators.py index 69475dda84349..66f6bbbeb6ee0 100644 --- a/airflow-core/src/airflow/utils/decorators.py +++ b/airflow-core/src/airflow/utils/decorators.py @@ -39,9 +39,9 @@ def _is_task_decorator(self, decorator_node: cst.Decorator) -> bool: decorator_expr = decorator_node.decorator if isinstance(decorator_expr, cst.Name): return decorator_expr.value in self.decorators_to_remove - elif isinstance(decorator_expr, cst.Attribute) and isinstance(decorator_expr.value, cst.Name): + if isinstance(decorator_expr, cst.Attribute) and isinstance(decorator_expr.value, cst.Name): return f"{decorator_expr.value.value}.{decorator_expr.attr.value}" in self.decorators_to_remove - elif isinstance(decorator_expr, cst.Call): + if isinstance(decorator_expr, cst.Call): return self._is_task_decorator(cst.Decorator(decorator=decorator_expr.func)) return False diff --git a/airflow-core/src/airflow/utils/email.py b/airflow-core/src/airflow/utils/email.py index a33cdcc15dc36..b4c60350f5c1b 100644 --- a/airflow-core/src/airflow/utils/email.py +++ b/airflow-core/src/airflow/utils/email.py @@ -285,12 +285,11 @@ def get_email_address_list(addresses: str | Iterable[str]) -> list[str]: """ if isinstance(addresses, str): return _get_email_list_from_str(addresses) - elif isinstance(addresses, collections.abc.Iterable): + if isinstance(addresses, collections.abc.Iterable): if not all(isinstance(item, str) for item in addresses): raise TypeError("The items in your iterable must be strings.") return list(addresses) - else: - raise TypeError(f"Unexpected argument type: Received '{type(addresses).__name__}'.") + raise TypeError(f"Unexpected argument type: Received '{type(addresses).__name__}'.") def _get_smtp_connection(host: str, port: int, timeout: int, with_ssl: bool) -> smtplib.SMTP: @@ -305,18 +304,17 @@ def 
_get_smtp_connection(host: str, port: int, timeout: int, with_ssl: bool) -> """ if not with_ssl: return smtplib.SMTP(host=host, port=port, timeout=timeout) + ssl_context_string = conf.get("email", "SSL_CONTEXT") + if ssl_context_string == "default": + ssl_context = ssl.create_default_context() + elif ssl_context_string == "none": + ssl_context = None else: - ssl_context_string = conf.get("email", "SSL_CONTEXT") - if ssl_context_string == "default": - ssl_context = ssl.create_default_context() - elif ssl_context_string == "none": - ssl_context = None - else: - raise RuntimeError( - f"The email.ssl_context configuration variable must " - f"be set to 'default' or 'none' and is '{ssl_context_string}." - ) - return smtplib.SMTP_SSL(host=host, port=port, timeout=timeout, context=ssl_context) + raise RuntimeError( + f"The email.ssl_context configuration variable must " + f"be set to 'default' or 'none' and is '{ssl_context_string}." + ) + return smtplib.SMTP_SSL(host=host, port=port, timeout=timeout, context=ssl_context) def _get_email_list_from_str(addresses: str) -> list[str]: diff --git a/airflow-core/src/airflow/utils/file.py b/airflow-core/src/airflow/utils/file.py index ec4ab9421e768..12138043b85d9 100644 --- a/airflow-core/src/airflow/utils/file.py +++ b/airflow-core/src/airflow/utils/file.py @@ -143,8 +143,7 @@ def correct_maybe_zipped(fileloc: None | str | Path) -> None | str | Path: _, archive, _ = search_.groups() if archive and zipfile.is_zipfile(archive): return archive - else: - return fileloc + return fileloc def open_maybe_zipped(fileloc, mode="r"): @@ -159,8 +158,7 @@ def open_maybe_zipped(fileloc, mode="r"): _, archive, filename = ZIP_REGEX.search(fileloc).groups() if archive and zipfile.is_zipfile(archive): return TextIOWrapper(zipfile.ZipFile(archive, mode=mode).open(filename)) - else: - return open(fileloc, mode=mode) + return open(fileloc, mode=mode) def _find_path_from_directory( @@ -236,10 +234,9 @@ def find_path_from_directory( """ if ignore_file_syntax == "glob" or not ignore_file_syntax: return _find_path_from_directory(base_dir_path, ignore_file_name, _GlobIgnoreRule) - elif ignore_file_syntax == "regexp": + if ignore_file_syntax == "regexp": return _find_path_from_directory(base_dir_path, ignore_file_name, _RegexpIgnoreRule) - else: - raise ValueError(f"Unsupported ignore_file_syntax: {ignore_file_syntax}") + raise ValueError(f"Unsupported ignore_file_syntax: {ignore_file_syntax}") def list_py_file_paths( diff --git a/airflow-core/src/airflow/utils/helpers.py b/airflow-core/src/airflow/utils/helpers.py index 665793c89baf0..74b7dacd33de7 100644 --- a/airflow-core/src/airflow/utils/helpers.py +++ b/airflow-core/src/airflow/utils/helpers.py @@ -124,8 +124,7 @@ def parse_template_string(template_string: str) -> tuple[str, None] | tuple[None if "{{" in template_string: # jinja mode return None, jinja2.Template(template_string) - else: - return template_string, None + return template_string, None @cache @@ -136,17 +135,16 @@ def log_filename_template_renderer() -> Callable[..., str]: import jinja2 return jinja2.Template(template).render - else: - def f_str_format(ti: TaskInstance, try_number: int | None = None): - return template.format( - dag_id=ti.dag_id, - task_id=ti.task_id, - logical_date=ti.logical_date.isoformat(), - try_number=try_number or ti.try_number, - ) + def f_str_format(ti: TaskInstance, try_number: int | None = None): + return template.format( + dag_id=ti.dag_id, + task_id=ti.task_id, + logical_date=ti.logical_date.isoformat(), + try_number=try_number or 
ti.try_number, + ) - return f_str_format + return f_str_format def render_log_filename(ti: TaskInstance, try_number, filename_template) -> str: @@ -276,8 +274,7 @@ def at_most_one(*args) -> bool: def is_set(val): if val is NOTSET: return False - else: - return bool(val) + return bool(val) return sum(map(is_set, args)) in (0, 1) @@ -294,7 +291,7 @@ def prune_dict(val: Any, mode="strict"): def is_empty(x): if mode == "strict": return x is None - elif mode == "truthy": + if mode == "truthy": return bool(x) is False raise ValueError("allowable values for `mode` include 'truthy' and 'strict'") @@ -303,27 +300,26 @@ def is_empty(x): for k, v in val.items(): if is_empty(v): continue - elif isinstance(v, (list, dict)): + if isinstance(v, (list, dict)): new_val = prune_dict(v, mode=mode) if not is_empty(new_val): new_dict[k] = new_val else: new_dict[k] = v return new_dict - elif isinstance(val, list): + if isinstance(val, list): new_list = [] for v in val: if is_empty(v): continue - elif isinstance(v, (list, dict)): + if isinstance(v, (list, dict)): new_val = prune_dict(v, mode=mode) if not is_empty(new_val): new_list.append(new_val) else: new_list.append(v) return new_list - else: - return val + return val def prevent_duplicates(kwargs1: dict[str, Any], kwargs2: Mapping[str, Any], *, fail_reason: str) -> None: diff --git a/airflow-core/src/airflow/utils/log/file_task_handler.py b/airflow-core/src/airflow/utils/log/file_task_handler.py index b27cec6c0e494..5314aa859c13f 100644 --- a/airflow-core/src/airflow/utils/log/file_task_handler.py +++ b/airflow-core/src/airflow/utils/log/file_task_handler.py @@ -341,8 +341,7 @@ def _render_filename(self, ti: TaskInstance, try_number: int, session=NEW_SESSIO logical_date=date, try_number=try_number, ) - else: - raise RuntimeError(f"Unable to render log filename for {ti}. This should never happen") + raise RuntimeError(f"Unable to render log filename for {ti}. 
This should never happen") def _get_executor_get_task_log( self, ti: TaskInstance diff --git a/airflow-core/src/airflow/utils/retries.py b/airflow-core/src/airflow/utils/retries.py index 809d176ef6c8e..e885eaededcc4 100644 --- a/airflow-core/src/airflow/utils/retries.py +++ b/airflow-core/src/airflow/utils/retries.py @@ -109,5 +109,4 @@ def wrapped_function(*args, **kwargs): # Allow using decorator with and without arguments if _func is None: return retry_decorator - else: - return retry_decorator(_func) + return retry_decorator(_func) diff --git a/airflow-core/src/airflow/utils/session.py b/airflow-core/src/airflow/utils/session.py index 52a7cfa8e8b2d..e6b04f06461de 100644 --- a/airflow-core/src/airflow/utils/session.py +++ b/airflow-core/src/airflow/utils/session.py @@ -97,9 +97,8 @@ def provide_session(func: Callable[PS, RT]) -> Callable[PS, RT]: def wrapper(*args, **kwargs) -> RT: if "session" in kwargs or session_args_idx < len(args): return func(*args, **kwargs) - else: - with create_session() as session: - return func(*args, session=session, **kwargs) + with create_session() as session: + return func(*args, session=session, **kwargs) return wrapper diff --git a/airflow-core/src/airflow/utils/sqlalchemy.py b/airflow-core/src/airflow/utils/sqlalchemy.py index 84eb8a6f98cf1..9e8cf12fba26c 100644 --- a/airflow-core/src/airflow/utils/sqlalchemy.py +++ b/airflow-core/src/airflow/utils/sqlalchemy.py @@ -70,9 +70,9 @@ def process_bind_param(self, value, dialect): if value is None: return None raise TypeError(f"expected datetime.datetime, not {value!r}") - elif value.tzinfo is None: + if value.tzinfo is None: raise ValueError("naive datetime is disallowed") - elif dialect.name == "mysql": + if dialect.name == "mysql": # For mysql versions prior 8.0.19 we should send timestamps as naive values in UTC # see: https://dev.mysql.com/doc/refman/8.0/en/date-and-time-literals.html return make_naive(value, timezone=utc) @@ -116,8 +116,7 @@ class ExtendedJSON(TypeDecorator): def load_dialect_impl(self, dialect) -> TypeEngine: if dialect.name == "postgresql": return dialect.type_descriptor(JSONB) - else: - return dialect.type_descriptor(JSON) + return dialect.type_descriptor(JSON) def process_bind_param(self, value, dialect): from airflow.serialization.serialized_objects import BaseSerialization @@ -164,13 +163,13 @@ def sanitize_for_serialization(obj: V1Pod): """ if obj is None: return None - elif isinstance(obj, (float, bool, bytes, str, int)): + if isinstance(obj, (float, bool, bytes, str, int)): return obj - elif isinstance(obj, list): + if isinstance(obj, list): return [sanitize_for_serialization(sub_obj) for sub_obj in obj] - elif isinstance(obj, tuple): + if isinstance(obj, tuple): return tuple(sanitize_for_serialization(sub_obj) for sub_obj in obj) - elif isinstance(obj, (datetime.datetime, datetime.date)): + if isinstance(obj, (datetime.datetime, datetime.date)): return obj.isoformat() if isinstance(obj, dict): @@ -288,11 +287,10 @@ def compare_values(self, x, y): """ if self.comparator: return self.comparator(x, y) - else: - try: - return x == y - except AttributeError: - return False + try: + return x == y + except AttributeError: + return False def nulls_first(col, session: Session) -> dict[str, Any]: @@ -305,8 +303,7 @@ def nulls_first(col, session: Session) -> dict[str, Any]: """ if session.bind.dialect.name == "postgresql": return nullsfirst(col) - else: - return col + return col USE_ROW_LEVEL_LOCKING: bool = conf.getboolean("scheduler", "use_row_level_locking", fallback=True) diff --git 
a/airflow-core/tests/integration/otel/dags/otel_test_dag_with_pause_between_tasks.py b/airflow-core/tests/integration/otel/dags/otel_test_dag_with_pause_between_tasks.py index 6489f37c2e6dc..3bdd0d3314eda 100644 --- a/airflow-core/tests/integration/otel/dags/otel_test_dag_with_pause_between_tasks.py +++ b/airflow-core/tests/integration/otel/dags/otel_test_dag_with_pause_between_tasks.py @@ -129,10 +129,9 @@ def paused_task(): logger.info("Task has been paused.") time.sleep(1) continue - else: - logger.info("Resuming task execution.") - # Break the loop and finish with the task execution. - break + logger.info("Resuming task execution.") + # Break the loop and finish with the task execution. + break # Cleanup the control file. if os.path.exists(control_file): diff --git a/airflow-core/tests/integration/otel/dags/otel_test_dag_with_pause_in_task.py b/airflow-core/tests/integration/otel/dags/otel_test_dag_with_pause_in_task.py index 0ff64064e01fc..92c174fb5547c 100644 --- a/airflow-core/tests/integration/otel/dags/otel_test_dag_with_pause_in_task.py +++ b/airflow-core/tests/integration/otel/dags/otel_test_dag_with_pause_in_task.py @@ -65,10 +65,9 @@ def task1(ti): logger.info("Task has been paused.") time.sleep(1) continue - else: - logger.info("Resuming task execution.") - # Break the loop and finish with the task execution. - break + logger.info("Resuming task execution.") + # Break the loop and finish with the task execution. + break otel_task_tracer = otel_tracer.get_otel_tracer_for_task(Trace) tracer_provider = otel_task_tracer.get_otel_tracer_provider() diff --git a/airflow-core/tests/integration/otel/test_otel.py b/airflow-core/tests/integration/otel/test_otel.py index 0f933a700f53c..0b8f1a4cce22c 100644 --- a/airflow-core/tests/integration/otel/test_otel.py +++ b/airflow-core/tests/integration/otel/test_otel.py @@ -1009,9 +1009,8 @@ def test_scheduler_change_in_the_middle_of_first_task_until_the_end( if "pause" in file_contents: log.info("Control file exists and the task has been paused.") break - else: - time.sleep(1) - continue + time.sleep(1) + continue except FileNotFoundError: print("Control file not found. Waiting...") time.sleep(3) @@ -1126,9 +1125,8 @@ def test_scheduler_exits_gracefully_in_the_middle_of_the_first_task( if "pause" in file_contents: log.info("Control file exists and the task has been paused.") break - else: - time.sleep(1) - continue + time.sleep(1) + continue except FileNotFoundError: print("Control file not found. Waiting...") time.sleep(3) @@ -1227,9 +1225,8 @@ def test_scheduler_exits_forcefully_in_the_middle_of_the_first_task( if "pause" in file_contents: log.info("Control file exists and the task has been paused.") break - else: - time.sleep(1) - continue + time.sleep(1) + continue except FileNotFoundError: print("Control file not found. Waiting...") time.sleep(3) @@ -1332,9 +1329,8 @@ def test_scheduler_exits_forcefully_after_the_first_task_finishes( if "pause" in file_contents: log.info("Control file exists and the task has been paused.") break - else: - time.sleep(1) - continue + time.sleep(1) + continue except FileNotFoundError: print("Control file not found. 
Waiting...") time.sleep(3) diff --git a/airflow-core/tests/unit/always/test_providers_manager.py b/airflow-core/tests/unit/always/test_providers_manager.py index d28b5bf971a0c..4f2de48b70bac 100644 --- a/airflow-core/tests/unit/always/test_providers_manager.py +++ b/airflow-core/tests/unit/always/test_providers_manager.py @@ -249,10 +249,9 @@ def test_hook_values(self): # When there is error importing provider that is excluded the provider name is in the message if any(excluded_provider in record.message for excluded_provider in excluded_providers): continue - else: - print(record.message, file=sys.stderr) - print(record.exc_info, file=sys.stderr) - real_warning_count += 1 + print(record.message, file=sys.stderr) + print(record.exc_info, file=sys.stderr) + real_warning_count += 1 if real_warning_count: raise AssertionError("There are warnings generated during hook imports. Please fix them") assert [w.message for w in warning_records if "hook-class-names" in str(w.message)] == [] diff --git a/airflow-core/tests/unit/api_fastapi/auth/managers/test_base_auth_manager.py b/airflow-core/tests/unit/api_fastapi/auth/managers/test_base_auth_manager.py index e6b5017c257a0..a5658623a700c 100644 --- a/airflow-core/tests/unit/api_fastapi/auth/managers/test_base_auth_manager.py +++ b/airflow-core/tests/unit/api_fastapi/auth/managers/test_base_auth_manager.py @@ -354,8 +354,7 @@ def side_effect_func( ): if not details: return False - else: - return access_per_dag.get(details.id, False) + return access_per_dag.get(details.id, False) auth_manager.is_authorized_dag = MagicMock(side_effect=side_effect_func) user = Mock() diff --git a/airflow-core/tests/unit/assets/test_evaluation.py b/airflow-core/tests/unit/assets/test_evaluation.py index e71635f402d1a..4abf42c3e0ea6 100644 --- a/airflow-core/tests/unit/assets/test_evaluation.py +++ b/airflow-core/tests/unit/assets/test_evaluation.py @@ -186,7 +186,7 @@ class _AssetEvaluator(AssetEvaluator): # Can't use mock because AssetEvaluator def _resolve_asset_alias(self, o): if o is asset_alias_1: return [] - elif o is resolved_asset_alias_2: + if o is resolved_asset_alias_2: return [asset] return super()._resolve_asset_alias(o) diff --git a/airflow-core/tests/unit/jobs/test_scheduler_job.py b/airflow-core/tests/unit/jobs/test_scheduler_job.py index 1f4b90e5f97fa..2f9975403c797 100644 --- a/airflow-core/tests/unit/jobs/test_scheduler_job.py +++ b/airflow-core/tests/unit/jobs/test_scheduler_job.py @@ -5662,8 +5662,7 @@ def side_effect(*args, **kwargs): if call_count < retry_times - 1: call_count += 1 raise OperationalError("any_statement", "any_params", "any_orig") - else: - return session.execute(*args, **kwargs) + return session.execute(*args, **kwargs) return side_effect diff --git a/airflow-core/tests/unit/models/test_backfill.py b/airflow-core/tests/unit/models/test_backfill.py index a67e3d958d835..d238d5245446d 100644 --- a/airflow-core/tests/unit/models/test_backfill.py +++ b/airflow-core/tests/unit/models/test_backfill.py @@ -326,9 +326,8 @@ def create_next_run( .limit(1) ) return next_run - else: - dr = dag_maker.create_dagrun(logical_date=next_date, run_id="second_run") - return dr + dr = dag_maker.create_dagrun(logical_date=next_date, run_id="second_run") + return dr @pytest.mark.parametrize("is_backfill", [True, False]) diff --git a/airflow-core/tests/unit/models/test_mappedoperator.py b/airflow-core/tests/unit/models/test_mappedoperator.py index cd7b9852a08e0..7f2f3770bb4d6 100644 --- a/airflow-core/tests/unit/models/test_mappedoperator.py +++ 
b/airflow-core/tests/unit/models/test_mappedoperator.py @@ -479,8 +479,7 @@ def inner(*args, **kwargs): kwargs.update(python_callable=failure_callable()) if partial: return PythonOperator.partial(**kwargs) - else: - return PythonOperator(**kwargs) + return PythonOperator(**kwargs) @pytest.mark.parametrize("type_", ["taskflow", "classic"]) def test_one_to_many_work_failed(self, type_, dag_maker): @@ -745,7 +744,7 @@ def other_teardown(): def my_setup(val): if val == "data2.json": raise ValueError("fail!") - elif val == "data3.json": + if val == "data3.json": raise AirflowSkipException("skip!") print(f"setup: {val}") return val @@ -788,7 +787,7 @@ def other_teardown(): def my_setup_callable(val): if val == "data2.json": raise ValueError("fail!") - elif val == "data3.json": + if val == "data3.json": raise AirflowSkipException("skip!") print(f"setup: {val}") return val @@ -960,7 +959,7 @@ def test_mapped_task_group_simple(self, type_, dag_maker, session): def my_setup(val): if val == "data2.json": raise ValueError("fail!") - elif val == "data3.json": + if val == "data3.json": raise AirflowSkipException("skip!") print(f"setup: {val}") @@ -987,7 +986,7 @@ def file_transforms(filename): def my_setup_callable(val): if val == "data2.json": raise ValueError("fail!") - elif val == "data3.json": + if val == "data3.json": raise AirflowSkipException("skip!") print(f"setup: {val}") @@ -1038,7 +1037,7 @@ def my_setup(val): def my_work(val): if val == "data2.json": raise ValueError("fail!") - elif val == "data3.json": + if val == "data3.json": raise AirflowSkipException("skip!") print(f"work: {val}") @@ -1062,7 +1061,7 @@ def my_work(vals): val = vals[0] if val == "data2.json": raise ValueError("fail!") - elif val == "data3.json": + if val == "data3.json": raise AirflowSkipException("skip!") print(f"work: {val}") @@ -1106,7 +1105,7 @@ def test_teardown_many_one_explicit(self, type_, dag_maker): def my_setup(val): if val == "data2.json": raise ValueError("fail!") - elif val == "data3.json": + if val == "data3.json": raise AirflowSkipException("skip!") print(f"setup: {val}") return val @@ -1128,7 +1127,7 @@ def my_teardown(val): def my_setup_callable(val): if val == "data2.json": raise ValueError("fail!") - elif val == "data3.json": + if val == "data3.json": raise AirflowSkipException("skip!") print(f"setup: {val}") return val diff --git a/airflow-core/tests/unit/models/test_taskinstance.py b/airflow-core/tests/unit/models/test_taskinstance.py index 349b211dd4fc0..ed10101e08fee 100644 --- a/airflow-core/tests/unit/models/test_taskinstance.py +++ b/airflow-core/tests/unit/models/test_taskinstance.py @@ -4861,22 +4861,21 @@ def _get_lazy_xcom_access_expected_sql_lines() -> list[str]: "WHERE xcom.dag_id = 'test_dag' AND xcom.run_id = 'test' " "AND xcom.task_id = 't' AND xcom.map_index = -1 AND xcom.`key` = 'xxx'", ] - elif backend == "postgres": + if backend == "postgres": return [ "SELECT xcom.value", "FROM xcom", "WHERE xcom.dag_id = 'test_dag' AND xcom.run_id = 'test' " "AND xcom.task_id = 't' AND xcom.map_index = -1 AND xcom.key = 'xxx'", ] - elif backend == "sqlite": + if backend == "sqlite": return [ "SELECT xcom.value", "FROM xcom", "WHERE xcom.dag_id = 'test_dag' AND xcom.run_id = 'test' " "AND xcom.task_id = 't' AND xcom.map_index = -1 AND xcom.\"key\" = 'xxx'", ] - else: - raise RuntimeError(f"unknown backend {backend!r}") + raise RuntimeError(f"unknown backend {backend!r}") def test_expand_non_templated_field(dag_maker, session): diff --git 
a/airflow-core/tests/unit/serialization/test_dag_serialization.py b/airflow-core/tests/unit/serialization/test_dag_serialization.py index fe52ea522a5c8..f70cd7e5338a0 100644 --- a/airflow-core/tests/unit/serialization/test_dag_serialization.py +++ b/airflow-core/tests/unit/serialization/test_dag_serialization.py @@ -1002,8 +1002,7 @@ def test_dag_params_roundtrip(self, val, expected_val): dag = DAG(dag_id="simple_dag", schedule=None, params=val) # further tests not relevant return - else: - dag = DAG(dag_id="simple_dag", schedule=None, params=val) + dag = DAG(dag_id="simple_dag", schedule=None, params=val) BaseOperator(task_id="simple_task", dag=dag, start_date=datetime(2019, 8, 1)) serialized_dag_json = SerializedDAG.to_json(dag) @@ -1093,15 +1092,14 @@ def test_task_params_roundtrip(self, val, expected_val): ) # further tests not relevant return - else: - BaseOperator( - task_id="simple_task", - dag=dag, - params=val, - start_date=datetime(2019, 8, 1), - ) - serialized_dag = SerializedDAG.to_dict(dag) - deserialized_dag = SerializedDAG.from_dict(serialized_dag) + BaseOperator( + task_id="simple_task", + dag=dag, + params=val, + start_date=datetime(2019, 8, 1), + ) + serialized_dag = SerializedDAG.to_dict(dag) + deserialized_dag = SerializedDAG.from_dict(serialized_dag) if val: assert "params" in serialized_dag["dag"]["tasks"][0]["__var"] diff --git a/airflow-core/tests/unit/ti_deps/deps/test_mapped_task_upstream_dep.py b/airflow-core/tests/unit/ti_deps/deps/test_mapped_task_upstream_dep.py index 82aba1a0fbb78..ad8dfa081ad90 100644 --- a/airflow-core/tests/unit/ti_deps/deps/test_mapped_task_upstream_dep.py +++ b/airflow-core/tests/unit/ti_deps/deps/test_mapped_task_upstream_dep.py @@ -376,7 +376,7 @@ def test_upstream_mapped_expanded( def m1(x): if x == 0 and upstream_instance_state == FAILED: raise AirflowFailException() - elif x == 0 and upstream_instance_state == SKIPPED: + if x == 0 and upstream_instance_state == SKIPPED: raise AirflowSkipException() return x diff --git a/airflow-core/tests/unit/utils/test_task_group.py b/airflow-core/tests/unit/utils/test_task_group.py index d95d326869881..38db9cd9dd2d5 100644 --- a/airflow-core/tests/unit/utils/test_task_group.py +++ b/airflow-core/tests/unit/utils/test_task_group.py @@ -45,13 +45,11 @@ def make_task(name, type_="classic"): if type_ == "classic": return BashOperator(task_id=name, bash_command="echo 1") - else: + @task_decorator + def my_task(): + pass - @task_decorator - def my_task(): - pass - - return my_task.override(task_id=name)() + return my_task.override(task_id=name)() EXPECTED_JSON_LEGACY = { diff --git a/airflow-ctl/src/airflowctl/api/client.py b/airflow-ctl/src/airflowctl/api/client.py index 756c4a2c5f00b..71ab54f217995 100644 --- a/airflow-ctl/src/airflowctl/api/client.py +++ b/airflow-ctl/src/airflowctl/api/client.py @@ -142,8 +142,7 @@ def load(self) -> Credentials: self.api_url = credentials["api_url"] self.api_token = keyring.get_password("airflowctl", f"api_token-{self.api_environment}") return self - else: - raise AirflowCtlNotFoundException(f"No credentials found in {default_config_dir}") + raise AirflowCtlNotFoundException(f"No credentials found in {default_config_dir}") class BearerAuth(httpx.Auth): diff --git a/airflow-ctl/src/airflowctl/api/operations.py b/airflow-ctl/src/airflowctl/api/operations.py index 0ac1b3427fc3d..01af7c75e5950 100644 --- a/airflow-ctl/src/airflowctl/api/operations.py +++ b/airflow-ctl/src/airflowctl/api/operations.py @@ -101,8 +101,7 @@ def wrapped(self, *args, **kwargs): try: if 
self.exit_in_error: return _exit_if_server_response_error(response=func(self, *args, **kwargs)) - else: - return func(self, *args, **kwargs) + return func(self, *args, **kwargs) except httpx.ConnectError as e: raise e diff --git a/dev/airflow-github b/dev/airflow-github index cef3ac69b6394..79a47fe64588a 100755 --- a/dev/airflow-github +++ b/dev/airflow-github @@ -94,10 +94,9 @@ def get_commit_in_main_associated_with_pr(repo: git.Repo, issue: Issue) -> str | if commit_line and commit_line.endswith(f"(#{issue.number})"): return commit_line.split(" ")[0] return None - else: - pr: PullRequest = issue.as_pull_request() - if pr.is_merged(): - return pr.merge_commit_sha + pr: PullRequest = issue.as_pull_request() + if pr.is_merged(): + return pr.merge_commit_sha return None @@ -186,7 +185,7 @@ def is_core_commit(files: list[str]) -> bool: if file.startswith(ignore): break # Handle renaming. Form: {old_name => new_name}somename.py - elif file.startswith("{"): + if file.startswith("{"): new_files = file[1:].split(" => ") if any(n.strip().startswith(ignore) for n in new_files): break diff --git a/dev/breeze/src/airflow_breeze/commands/ci_image_commands.py b/dev/breeze/src/airflow_breeze/commands/ci_image_commands.py index fa2d03b010dfb..4170cbbdd46a6 100644 --- a/dev/breeze/src/airflow_breeze/commands/ci_image_commands.py +++ b/dev/breeze/src/airflow_breeze/commands/ci_image_commands.py @@ -212,12 +212,11 @@ def get_exitcode(status: int) -> int: # but until then we need to do this ugly conversion if os.WIFSIGNALED(status): return -os.WTERMSIG(status) - elif os.WIFEXITED(status): + if os.WIFEXITED(status): return os.WEXITSTATUS(status) - elif os.WIFSTOPPED(status): + if os.WIFSTOPPED(status): return -os.WSTOPSIG(status) - else: - return 1 + return 1 option_upgrade_to_newer_dependencies = click.option( @@ -770,26 +769,24 @@ def should_we_run_the_build(build_ci_params: BuildCiParams) -> bool: if answer == answer.YES: if is_repo_rebased(build_ci_params.github_repository, build_ci_params.airflow_branch): return True - else: - get_console().print( - "\n[warning]This might take a lot of time (more than 10 minutes) even if you have " - "a good network connection. We think you should attempt to rebase first.[/]\n" - ) - answer = user_confirm( - "But if you really, really want - you can attempt it. Are you really sure?", - timeout=STANDARD_TIMEOUT, - default_answer=Answer.NO, - ) - if answer == Answer.YES: - return True - else: - get_console().print( - f"[info]Please rebase your code to latest {build_ci_params.airflow_branch} " - "before continuing.[/]\nCheck this link to find out how " - "https://github.com/apache/airflow/blob/main/contributing-docs/10_working_with_git.rst\n" - ) - get_console().print("[error]Exiting the process[/]\n") - sys.exit(1) + get_console().print( + "\n[warning]This might take a lot of time (more than 10 minutes) even if you have " + "a good network connection. We think you should attempt to rebase first.[/]\n" + ) + answer = user_confirm( + "But if you really, really want - you can attempt it. 
Are you really sure?", + timeout=STANDARD_TIMEOUT, + default_answer=Answer.NO, + ) + if answer == Answer.YES: + return True + get_console().print( + f"[info]Please rebase your code to latest {build_ci_params.airflow_branch} " + "before continuing.[/]\nCheck this link to find out how " + "https://github.com/apache/airflow/blob/main/contributing-docs/10_working_with_git.rst\n" + ) + get_console().print("[error]Exiting the process[/]\n") + sys.exit(1) elif answer == Answer.NO: instruct_build_image(build_ci_params.python) return False diff --git a/dev/breeze/src/airflow_breeze/commands/developer_commands.py b/dev/breeze/src/airflow_breeze/commands/developer_commands.py index a7295eec231b0..67927a2788323 100644 --- a/dev/breeze/src/airflow_breeze/commands/developer_commands.py +++ b/dev/breeze/src/airflow_breeze/commands/developer_commands.py @@ -1092,9 +1092,8 @@ def find_airflow_container() -> str | None: # On docker-compose v1 we get '--------' as output here stop_exec_on_error(docker_compose_ps_command.returncode) return container_running - else: - stop_exec_on_error(1) - return None + stop_exec_on_error(1) + return None @main.command( diff --git a/dev/breeze/src/airflow_breeze/commands/kubernetes_commands.py b/dev/breeze/src/airflow_breeze/commands/kubernetes_commands.py index 8619c5e6f55fd..fc377f37a46cc 100644 --- a/dev/breeze/src/airflow_breeze/commands/kubernetes_commands.py +++ b/dev/breeze/src/airflow_breeze/commands/kubernetes_commands.py @@ -291,12 +291,11 @@ def _create_cluster( num_tries -= 1 if num_tries == 0: return result.returncode, f"K8S cluster {cluster_name}." - else: - get_console(output=output).print( - f"[warning]Failed to create KinD cluster {cluster_name}. " - f"Retrying! There are {num_tries} tries left.\n" - ) - _delete_cluster(python=python, kubernetes_version=kubernetes_version, output=output) + get_console(output=output).print( + f"[warning]Failed to create KinD cluster {cluster_name}. " + f"Retrying! 
There are {num_tries} tries left.\n" + ) + _delete_cluster(python=python, kubernetes_version=kubernetes_version, output=output) @kubernetes_group.command( @@ -458,8 +457,7 @@ def _get_python_kubernetes_version_from_name(cluster_name: str) -> tuple[str | N python = cluster_match.group(1) kubernetes_version = cluster_match.group(2) return python, kubernetes_version - else: - return None, None + return None, None LIST_CONSOLE_WIDTH = 120 diff --git a/dev/breeze/src/airflow_breeze/commands/release_management_commands.py b/dev/breeze/src/airflow_breeze/commands/release_management_commands.py index 3efdba0b0d127..11bb1d0a33b24 100644 --- a/dev/breeze/src/airflow_breeze/commands/release_management_commands.py +++ b/dev/breeze/src/airflow_breeze/commands/release_management_commands.py @@ -2925,9 +2925,8 @@ def modify_single_file_constraints( constraints_file.write_text(constraint_content) get_console().print("[success]Updated.[/]") return True - else: - get_console().print("[warning]The file has not been modified.[/]") - return False + get_console().print("[warning]The file has not been modified.[/]") + return False def modify_all_constraint_files( @@ -2959,10 +2958,9 @@ def confirm_modifications(constraints_repo: Path) -> bool: confirm = user_confirm("Do you want to continue?") if confirm == Answer.YES: return True - elif confirm == Answer.NO: + if confirm == Answer.NO: return False - else: - sys.exit(1) + sys.exit(1) def commit_constraints_and_tag(constraints_repo: Path, airflow_version: str, commit_message: str) -> None: diff --git a/dev/breeze/src/airflow_breeze/commands/sbom_commands.py b/dev/breeze/src/airflow_breeze/commands/sbom_commands.py index e56f83015ff88..2f2374946350a 100644 --- a/dev/breeze/src/airflow_breeze/commands/sbom_commands.py +++ b/dev/breeze/src/airflow_breeze/commands/sbom_commands.py @@ -226,9 +226,8 @@ def _dir_exists_warn_and_should_skip(dir: Path, force: bool) -> bool: if not force: get_console().print(f"[warning]The {dir} already exists. Skipping") return True - else: - get_console().print(f"[warning]The {dir} already exists. Forcing update") - return False + get_console().print(f"[warning]The {dir} already exists. 
Forcing update") + return False return False apache_airflow_documentation_directory = airflow_site_archive_directory / "apache-airflow" @@ -793,8 +792,7 @@ def export_dependency_information( def sort_deps_key(dependency: dict[str, Any]) -> str: if dependency.get("Vcs"): return "0:" + dependency["Name"] - else: - return "1:" + dependency["Name"] + return "1:" + dependency["Name"] def convert_all_sbom_to_value_dictionaries( diff --git a/dev/breeze/src/airflow_breeze/commands/setup_commands.py b/dev/breeze/src/airflow_breeze/commands/setup_commands.py index 55878e5db0360..9014c069eb1fe 100644 --- a/dev/breeze/src/airflow_breeze/commands/setup_commands.py +++ b/dev/breeze/src/airflow_breeze/commands/setup_commands.py @@ -416,9 +416,8 @@ def write_to_shell(command_to_execute: str, script_path: str, force_setup: bool) "You can force autocomplete installation by adding --force[/]\n" ) return False - else: - backup(script_path_file) - remove_autogenerated_code(script_path) + backup(script_path_file) + remove_autogenerated_code(script_path) text = "" if script_path_file.exists(): get_console().print(f"\nModifying the {script_path} file!\n") diff --git a/dev/breeze/src/airflow_breeze/params/build_prod_params.py b/dev/breeze/src/airflow_breeze/params/build_prod_params.py index fc525edc08a60..2663a0aac4bae 100644 --- a/dev/breeze/src/airflow_breeze/params/build_prod_params.py +++ b/dev/breeze/src/airflow_breeze/params/build_prod_params.py @@ -60,8 +60,7 @@ class BuildProdParams(CommonBuildParams): def airflow_version(self) -> str: if self.install_airflow_version: return self.install_airflow_version - else: - return self._get_version_with_suffix() + return self._get_version_with_suffix() @property def airflow_semver_version(self) -> str: diff --git a/dev/breeze/src/airflow_breeze/params/common_build_params.py b/dev/breeze/src/airflow_breeze/params/common_build_params.py index b206d595b2353..711448cd680e0 100644 --- a/dev/breeze/src/airflow_breeze/params/common_build_params.py +++ b/dev/breeze/src/airflow_breeze/params/common_build_params.py @@ -152,8 +152,7 @@ def _build_arg(self, name: str, value: Any, optional: bool): if value is None or "": if optional: return - else: - raise ValueError(f"Value for {name} cannot be empty or None") + raise ValueError(f"Value for {name} cannot be empty or None") if value is True: str_value = "true" elif value is False: diff --git a/dev/breeze/src/airflow_breeze/params/shell_params.py b/dev/breeze/src/airflow_breeze/params/shell_params.py index 499f6387a77df..d171e94c69a38 100644 --- a/dev/breeze/src/airflow_breeze/params/shell_params.py +++ b/dev/breeze/src/airflow_breeze/params/shell_params.py @@ -699,12 +699,11 @@ def _generate_env_for_docker_compose_file_if_needed(env: dict[str, str]): # we check if the set of env variables had not changed since last run # if so - cool, we do not need to do anything else return - else: - if get_verbose(): - get_console().print( - f"[info]The keys has changed vs last run. Regenerating[/]: " - f"{GENERATED_DOCKER_ENV_PATH} and {GENERATED_DOCKER_COMPOSE_ENV_PATH}" - ) + if get_verbose(): + get_console().print( + f"[info]The keys has changed vs last run. 
Regenerating[/]: " + f"{GENERATED_DOCKER_ENV_PATH} and {GENERATED_DOCKER_COMPOSE_ENV_PATH}" + ) if get_verbose(): get_console().print(f"[info]Generating new docker env file [/]: {GENERATED_DOCKER_ENV_PATH}") GENERATED_DOCKER_ENV_PATH.write_text("\n".join(sorted(env.keys()))) diff --git a/dev/breeze/src/airflow_breeze/prepare_providers/provider_distributions.py b/dev/breeze/src/airflow_breeze/prepare_providers/provider_distributions.py index 80685c2be85d5..cccfcbf02fe55 100644 --- a/dev/breeze/src/airflow_breeze/prepare_providers/provider_distributions.py +++ b/dev/breeze/src/airflow_breeze/prepare_providers/provider_distributions.py @@ -294,6 +294,6 @@ def get_packages_list_to_act_on( and (package.strip() not in removed_provider_ids or include_removed) and (package.strip() not in not_ready_provider_ids or include_not_ready) ] - elif provider_distributions: + if provider_distributions: return list(provider_distributions) return get_available_distributions(include_removed=include_removed, include_not_ready=include_not_ready) diff --git a/dev/breeze/src/airflow_breeze/prepare_providers/provider_documentation.py b/dev/breeze/src/airflow_breeze/prepare_providers/provider_documentation.py index 4f9c59214dc71..a8ade2388ecc1 100644 --- a/dev/breeze/src/airflow_breeze/prepare_providers/provider_documentation.py +++ b/dev/breeze/src/airflow_breeze/prepare_providers/provider_documentation.py @@ -389,10 +389,9 @@ def _get_all_changes_for_package( if not only_min_version_update: _print_changes_table(changes_table) return False, [array_of_changes], changes_table - else: - if not only_min_version_update: - get_console().print(f"[info]No changes for {provider_id}") - return False, [], "" + if not only_min_version_update: + get_console().print(f"[info]No changes for {provider_id}") + return False, [], "" if len(provider_details.versions) == 1: get_console().print( f"[info]The provider '{provider_id}' has never been released but it is ready to release!\n" @@ -754,7 +753,7 @@ def update_release_notes( if answer == Answer.NO: get_console().print(f"\n[warning]Skipping provider: {provider_id} on user request![/]\n") raise PrepareReleaseDocsUserSkippedException() - elif answer == Answer.QUIT: + if answer == Answer.QUIT: raise PrepareReleaseDocsUserQuitException() elif not list_of_list_of_changes: get_console().print( diff --git a/dev/breeze/src/airflow_breeze/utils/cache.py b/dev/breeze/src/airflow_breeze/utils/cache.py index 61eff9419fc9b..e7f618df81042 100644 --- a/dev/breeze/src/airflow_breeze/utils/cache.py +++ b/dev/breeze/src/airflow_breeze/utils/cache.py @@ -41,8 +41,7 @@ def read_from_cache_file(param_name: str) -> str | None: cache_exists = check_if_cache_exists(param_name) if cache_exists: return (Path(BUILD_CACHE_PATH) / f".{param_name}").read_text().strip() - else: - return None + return None def touch_cache_file(param_name: str, root_dir: Path = BUILD_CACHE_PATH): diff --git a/dev/breeze/src/airflow_breeze/utils/cdxgen.py b/dev/breeze/src/airflow_breeze/utils/cdxgen.py index c2907f36ab01d..f8822250dc244 100644 --- a/dev/breeze/src/airflow_breeze/utils/cdxgen.py +++ b/dev/breeze/src/airflow_breeze/utils/cdxgen.py @@ -213,8 +213,7 @@ def get_requirements_for_provider( f"Provider requirements already existed, skipped generation for {provider_id} version " f"{provider_version} python {python_version}", ) - else: - provider_folder_path.mkdir(exist_ok=True) + provider_folder_path.mkdir(exist_ok=True) command = f""" mkdir -pv {DOCKER_FILE_PREFIX} diff --git 
a/dev/breeze/src/airflow_breeze/utils/coertions.py b/dev/breeze/src/airflow_breeze/utils/coertions.py index 6f8c2c21baac8..2c505301d6e0b 100644 --- a/dev/breeze/src/airflow_breeze/utils/coertions.py +++ b/dev/breeze/src/airflow_breeze/utils/coertions.py @@ -23,10 +23,9 @@ def coerce_bool_value(value: str | bool) -> bool: if isinstance(value, bool): return value - elif not value: # handle "" and other false-y coerce-able values + if not value: # handle "" and other false-y coerce-able values return False - else: - return value[0].lower() in ["t", "y"] # handle all kinds of truth-y/yes-y/false-y/non-sy strings + return value[0].lower() in ["t", "y"] # handle all kinds of truth-y/yes-y/false-y/non-sy strings def one_or_none_set(iterable: Iterable[bool]) -> bool: diff --git a/dev/breeze/src/airflow_breeze/utils/confirm.py b/dev/breeze/src/airflow_breeze/utils/confirm.py index f80670d420afd..602b6f078d68d 100644 --- a/dev/breeze/src/airflow_breeze/utils/confirm.py +++ b/dev/breeze/src/airflow_breeze/utils/confirm.py @@ -77,16 +77,14 @@ def user_confirm( if user_status == "": if default_answer: return default_answer - else: - continue + continue if user_status.upper() in ["Y", "YES"]: return Answer.YES - elif user_status.upper() in ["N", "NO"]: + if user_status.upper() in ["N", "NO"]: return Answer.NO - elif user_status.upper() in ["Q", "QUIT"] and quit_allowed: + if user_status.upper() in ["Q", "QUIT"] and quit_allowed: return Answer.QUIT - else: - print(f"Wrong answer given {user_status}. Should be one of {allowed_answers}. Try again.") + print(f"Wrong answer given {user_status}. Should be one of {allowed_answers}. Try again.") except TimeoutOccurred: if default_answer: return default_answer @@ -107,7 +105,7 @@ def confirm_action( answer = user_confirm(message, timeout, default_answer, quit_allowed) if answer == Answer.YES: return True - elif abort: + if abort: sys.exit(1) elif answer == Answer.QUIT: sys.exit(1) diff --git a/dev/breeze/src/airflow_breeze/utils/docker_command_utils.py b/dev/breeze/src/airflow_breeze/utils/docker_command_utils.py index 926e52e6f2428..5e3402931f5ce 100644 --- a/dev/breeze/src/airflow_breeze/utils/docker_command_utils.py +++ b/dev/breeze/src/airflow_breeze/utils/docker_command_utils.py @@ -836,12 +836,11 @@ def enter_shell(shell_params: ShellParams, output: Output | None = None) -> RunC ) if command_result.returncode == 0: return command_result - else: - get_console().print(f"[red]Error {command_result.returncode} returned[/]") - if get_verbose(): - get_console().print(command_result.stderr) - notify_on_unhealthy_backend_container(shell_params.project_name, shell_params.backend, output) - return command_result + get_console().print(f"[red]Error {command_result.returncode} returned[/]") + if get_verbose(): + get_console().print(command_result.stderr) + notify_on_unhealthy_backend_container(shell_params.project_name, shell_params.backend, output) + return command_result def notify_on_unhealthy_backend_container(project_name: str, backend: str, output: Output | None = None): diff --git a/dev/breeze/src/airflow_breeze/utils/docs_publisher.py b/dev/breeze/src/airflow_breeze/utils/docs_publisher.py index 7ac8fd1a46c1d..09cb20e7ea46a 100644 --- a/dev/breeze/src/airflow_breeze/utils/docs_publisher.py +++ b/dev/breeze/src/airflow_breeze/utils/docs_publisher.py @@ -52,8 +52,7 @@ def _build_dir(self) -> str: if self.is_versioned: version = "stable" return f"{GENERATED_PATH}/_build/docs/{self.package_name}/{version}" - else: - return 
f"{GENERATED_PATH}/_build/docs/{self.package_name}" + return f"{GENERATED_PATH}/_build/docs/{self.package_name}" @property def _current_version(self): @@ -76,8 +75,7 @@ def _current_version(self): def _publish_dir(self) -> str: if self.is_versioned: return f"docs-archive/{self.package_name}/{self._current_version}" - else: - return f"docs-archive/{self.package_name}" + return f"docs-archive/{self.package_name}" def publish(self, override_versioned: bool, airflow_site_dir: str): """Copy documentation packages files to airflow-site repository.""" diff --git a/dev/breeze/src/airflow_breeze/utils/image.py b/dev/breeze/src/airflow_breeze/utils/image.py index cd2301ccf487e..0dcde4c33c775 100644 --- a/dev/breeze/src/airflow_breeze/utils/image.py +++ b/dev/breeze/src/airflow_breeze/utils/image.py @@ -69,8 +69,7 @@ def run_pull_in_parallel( def get_right_method() -> Callable[..., tuple[int, str]]: if verify: return run_pull_and_verify_image - else: - return run_pull_image + return run_pull_image def get_kwds(index: int, image_param: BuildCiParams | BuildProdParams): d = { @@ -166,11 +165,10 @@ def run_pull_image( ) return 1, f"Image Python {image_params.python}" continue - else: - get_console(output=output).print( - f"\n[error]There was an error pulling the image {image_params.python}. Failing.[/]\n" - ) - return command_result.returncode, f"Image Python {image_params.python}" + get_console(output=output).print( + f"\n[error]There was an error pulling the image {image_params.python}. Failing.[/]\n" + ) + return command_result.returncode, f"Image Python {image_params.python}" def run_pull_and_verify_image( diff --git a/dev/breeze/src/airflow_breeze/utils/kubernetes_utils.py b/dev/breeze/src/airflow_breeze/utils/kubernetes_utils.py index 94f9093a3751f..9d8249be3bcdd 100644 --- a/dev/breeze/src/airflow_breeze/utils/kubernetes_utils.py +++ b/dev/breeze/src/airflow_breeze/utils/kubernetes_utils.py @@ -146,11 +146,10 @@ def _download_tool_if_needed( f"{K8S_BIN_BASE_PATH}" ) return - else: - get_console().print( - f"[info]Currently installed `{tool}` tool version: {current_version}. " - f"Downloading {expected_version}." - ) + get_console().print( + f"[info]Currently installed `{tool}` tool version: {current_version}. " + f"Downloading {expected_version}." + ) else: get_console().print( f"[warning]The version check of `{tool}` tool returned " @@ -407,11 +406,10 @@ def _attempt_to_connect(port_number: int, output: Output | None, wait_seconds: i f"http://localhost:{port_number}/api/v2/monitor/health and it is healthy." 
) return True - else: - get_console(output=output).print( - f"[warning]Error when connecting to localhost:{port_number} " - f"{response.status_code}: {response.reason}" - ) + get_console(output=output).print( + f"[warning]Error when connecting to localhost:{port_number} " + f"{response.status_code}: {response.reason}" + ) current_time = datetime.now(timezone.utc) if current_time - start_time > timedelta(seconds=wait_seconds): if wait_seconds > 0: diff --git a/dev/breeze/src/airflow_breeze/utils/md5_build_check.py b/dev/breeze/src/airflow_breeze/utils/md5_build_check.py index c40d59321ecab..c2bb827544033 100644 --- a/dev/breeze/src/airflow_breeze/utils/md5_build_check.py +++ b/dev/breeze/src/airflow_breeze/utils/md5_build_check.py @@ -158,13 +158,12 @@ def md5sum_check_if_build_is_needed( get_console().print(f" * [info]{file}[/]") get_console().print("\n[warning]Likely CI image needs rebuild[/]\n") return True - else: - if build_ci_params.skip_image_upgrade_check: - return False - get_console().print( - "[info]Docker image build is not needed for CI build as no important files are changed! " - "You can add --force-build to force it[/]" - ) + if build_ci_params.skip_image_upgrade_check: + return False + get_console().print( + "[info]Docker image build is not needed for CI build as no important files are changed! " + "You can add --force-build to force it[/]" + ) return False diff --git a/dev/breeze/src/airflow_breeze/utils/packages.py b/dev/breeze/src/airflow_breeze/utils/packages.py index 0dcaac4166a07..f0a8751c1496d 100644 --- a/dev/breeze/src/airflow_breeze/utils/packages.py +++ b/dev/breeze/src/airflow_breeze/utils/packages.py @@ -377,12 +377,11 @@ def get_short_package_names(long_form_providers: Iterable[str]) -> tuple[str, .. def get_short_package_name(long_form_provider: str) -> str: if long_form_provider in REGULAR_DOC_PACKAGES: return long_form_provider - else: - if not long_form_provider.startswith(LONG_PROVIDERS_PREFIX): - raise ValueError( - f"Invalid provider name: {long_form_provider}. Should start with {LONG_PROVIDERS_PREFIX}" - ) - return long_form_provider[len(LONG_PROVIDERS_PREFIX) :].replace("-", ".") + if not long_form_provider.startswith(LONG_PROVIDERS_PREFIX): + raise ValueError( + f"Invalid provider name: {long_form_provider}. Should start with {LONG_PROVIDERS_PREFIX}" + ) + return long_form_provider[len(LONG_PROVIDERS_PREFIX) :].replace("-", ".") def find_matching_long_package_names( @@ -641,10 +640,8 @@ def format_version_suffix(version_suffix: str) -> str: if version_suffix: if version_suffix[0] == "." 
or version_suffix[0] == "+": return version_suffix - else: - return f".{version_suffix}" - else: - return "" + return f".{version_suffix}" + return "" def get_provider_jinja_context( diff --git a/dev/breeze/src/airflow_breeze/utils/parallel.py b/dev/breeze/src/airflow_breeze/utils/parallel.py index 11038db457a2c..1d6af2d174865 100644 --- a/dev/breeze/src/airflow_breeze/utils/parallel.py +++ b/dev/breeze/src/airflow_breeze/utils/parallel.py @@ -192,8 +192,7 @@ def get_best_matching_lines(self, output: Output) -> list[str] | None: if self.matcher_for_joined_line is not None and previous_line is not None: list_to_return: list[str] = [previous_line, best_line] return list_to_return - else: - self.last_good_match[output.file_name] = best_line + self.last_good_match[output.file_name] = best_line last_match = self.last_good_match.get(output.file_name) if last_match is None: return None diff --git a/dev/breeze/src/airflow_breeze/utils/run_tests.py b/dev/breeze/src/airflow_breeze/utils/run_tests.py index 80d30c5f3654e..629d0e0a608e5 100644 --- a/dev/breeze/src/airflow_breeze/utils/run_tests.py +++ b/dev/breeze/src/airflow_breeze/utils/run_tests.py @@ -259,8 +259,7 @@ def convert_test_type_to_pytest_args( helm_folder = TEST_GROUP_TO_TEST_FOLDERS[test_group][0] if test_type and test_type != ALL_TEST_TYPE: return [f"{helm_folder}/tests/helm_tests/{test_type}"] - else: - return [helm_folder] + return [helm_folder] if test_type == SelectiveCoreTestType.OTHER.value and test_group == GroupOfTests.CORE: return find_all_other_tests() if test_group in [ diff --git a/dev/breeze/src/airflow_breeze/utils/run_utils.py b/dev/breeze/src/airflow_breeze/utils/run_utils.py index 1ff6ef30ba5c9..d8a372517823b 100644 --- a/dev/breeze/src/airflow_breeze/utils/run_utils.py +++ b/dev/breeze/src/airflow_breeze/utils/run_utils.py @@ -386,8 +386,7 @@ def commit_sha(): command_result = run_command(["git", "rev-parse", "HEAD"], capture_output=True, text=True, check=False) if command_result.stdout: return command_result.stdout.strip() - else: - return "COMMIT_SHA_NOT_FOUND" + return "COMMIT_SHA_NOT_FOUND" def check_if_image_exists(image: str) -> bool: @@ -414,31 +413,30 @@ def _run_compile_internally( text=True, env=env, ) - else: - compile_lock.parent.mkdir(parents=True, exist_ok=True) - compile_lock.unlink(missing_ok=True) - try: - with SoftFileLock(compile_lock, timeout=5): - with open(asset_out, "w") as output_file: - result = run_command( - command_to_execute, - check=False, - no_output_dump_on_exception=True, - text=True, - env=env, - stderr=subprocess.STDOUT, - stdout=output_file, - ) - if result.returncode == 0: - asset_out.unlink(missing_ok=True) - return result - except Timeout: - get_console().print("[error]Another asset compilation is running. 
Exiting[/]\n") - get_console().print("[warning]If you are sure there is no other compilation,[/]") - get_console().print("[warning]Remove the lock file and re-run compilation:[/]") - get_console().print(compile_lock) - get_console().print() - sys.exit(1) + compile_lock.parent.mkdir(parents=True, exist_ok=True) + compile_lock.unlink(missing_ok=True) + try: + with SoftFileLock(compile_lock, timeout=5): + with open(asset_out, "w") as output_file: + result = run_command( + command_to_execute, + check=False, + no_output_dump_on_exception=True, + text=True, + env=env, + stderr=subprocess.STDOUT, + stdout=output_file, + ) + if result.returncode == 0: + asset_out.unlink(missing_ok=True) + return result + except Timeout: + get_console().print("[error]Another asset compilation is running. Exiting[/]\n") + get_console().print("[warning]If you are sure there is no other compilation,[/]") + get_console().print("[warning]Remove the lock file and re-run compilation:[/]") + get_console().print(compile_lock) + get_console().print() + sys.exit(1) def kill_process_group(gid: int): diff --git a/dev/breeze/src/airflow_breeze/utils/selective_checks.py b/dev/breeze/src/airflow_breeze/utils/selective_checks.py index bfe6f3251069e..5eb9973729f73 100644 --- a/dev/breeze/src/airflow_breeze/utils/selective_checks.py +++ b/dev/breeze/src/airflow_breeze/utils/selective_checks.py @@ -666,11 +666,8 @@ def _should_be_run(self, source_area: FileGroupForCi) -> bool: f"[warning]{source_area} enabled because it matched {len(matched_files)} changed files[/]" ) return True - else: - get_console().print( - f"[warning]{source_area} disabled because it did not match any changed files[/]" - ) - return False + get_console().print(f"[warning]{source_area} disabled because it did not match any changed files[/]") + return False @cached_property def mypy_checks(self) -> list[str]: @@ -881,25 +878,22 @@ def _get_providers_test_types_to_run(self, split_to_individual_providers: bool = if self.full_tests_needed or self.run_task_sdk_tests: if split_to_individual_providers: return list(providers_test_type()) - else: - return ["Providers"] - else: - all_providers_source_files = self._matching_files( - FileGroupForCi.ALL_PROVIDERS_PYTHON_FILES, CI_FILE_GROUP_MATCHES - ) - assets_source_files = self._matching_files(FileGroupForCi.ASSET_FILES, CI_FILE_GROUP_MATCHES) + return ["Providers"] + all_providers_source_files = self._matching_files( + FileGroupForCi.ALL_PROVIDERS_PYTHON_FILES, CI_FILE_GROUP_MATCHES + ) + assets_source_files = self._matching_files(FileGroupForCi.ASSET_FILES, CI_FILE_GROUP_MATCHES) - if ( - len(all_providers_source_files) == 0 - and len(assets_source_files) == 0 - and not self.needs_api_tests - ): - # IF API tests are needed, that will trigger extra provider checks - return [] - else: - affected_providers = self._find_all_providers_affected( - include_docs=False, - ) + if ( + len(all_providers_source_files) == 0 + and len(assets_source_files) == 0 + and not self.needs_api_tests + ): + # IF API tests are needed, that will trigger extra provider checks + return [] + affected_providers = self._find_all_providers_affected( + include_docs=False, + ) candidate_test_types: set[str] = set() if isinstance(affected_providers, AllProvidersSentinel): if split_to_individual_providers: @@ -1419,30 +1413,27 @@ def only_new_ui_files(self) -> bool: if all_source_files and new_ui_source_files and not remaining_files: return True - else: - return False + return False @cached_property def testable_core_integrations(self) -> list[str]: if not 
self.run_tests: return [] - else: - return [ - integration - for integration in TESTABLE_CORE_INTEGRATIONS - if integration not in DISABLE_TESTABLE_INTEGRATIONS_FROM_CI - ] + return [ + integration + for integration in TESTABLE_CORE_INTEGRATIONS + if integration not in DISABLE_TESTABLE_INTEGRATIONS_FROM_CI + ] @cached_property def testable_providers_integrations(self) -> list[str]: if not self.run_tests: return [] - else: - return [ - integration - for integration in TESTABLE_PROVIDERS_INTEGRATIONS - if integration not in DISABLE_TESTABLE_INTEGRATIONS_FROM_CI - ] + return [ + integration + for integration in TESTABLE_PROVIDERS_INTEGRATIONS + if integration not in DISABLE_TESTABLE_INTEGRATIONS_FROM_CI + ] @cached_property def is_committer_build(self): diff --git a/dev/breeze/src/airflow_breeze/utils/version_utils.py b/dev/breeze/src/airflow_breeze/utils/version_utils.py index ce36425b46ed1..3c51cda77453b 100644 --- a/dev/breeze/src/airflow_breeze/utils/version_utils.py +++ b/dev/breeze/src/airflow_breeze/utils/version_utils.py @@ -62,15 +62,13 @@ def get_package_version_suffix(version_suffix_for_pypi: str, version_suffix_for_ # if there is a PyPi version suffix, return the combined version. Otherwise just return the local version. if version_suffix_for_pypi: return version_suffix_for_pypi + version_suffix_for_local - else: - return version_suffix_for_local + return version_suffix_for_local def remove_local_version_suffix(version_suffix: str) -> str: if "+" in version_suffix: return version_suffix.split("+")[0] - else: - return version_suffix + return version_suffix def is_local_package_version(version_suffix: str) -> bool: @@ -87,5 +85,4 @@ def is_local_package_version(version_suffix: str) -> bool: """ if version_suffix and ("+" in version_suffix): return True - else: - return False + return False diff --git a/dev/stats/get_important_pr_candidates.py b/dev/stats/get_important_pr_candidates.py index 708ca2791f046..2bf039d14234d 100755 --- a/dev/stats/get_important_pr_candidates.py +++ b/dev/stats/get_important_pr_candidates.py @@ -171,8 +171,7 @@ def num_changed_files(self) -> float: def body_length(self) -> int: if self.pull_request.body is not None: return len(self.pull_request.body) - else: - return 0 + return 0 @cached_property def num_additions(self) -> int: @@ -265,13 +264,12 @@ def __str__(self) -> str: f'"{self.pull_request.title}". ' f"Merged at {self.pull_request.merged_at}: {self.pull_request.html_url}" ) - else: - return ( - f"Score: {self.score:.2f}: PR{self.pull_request.number}" - f"by @{self.pull_request.user.login}: " - f'"{self.pull_request.title}". ' - f"Merged at {self.pull_request.merged_at}: {self.pull_request.html_url}" - ) + return ( + f"Score: {self.score:.2f}: PR{self.pull_request.number}" + f"by @{self.pull_request.user.login}: " + f'"{self.pull_request.title}". 
' + f"Merged at {self.pull_request.merged_at}: {self.pull_request.html_url}" + ) def verboseStr(self) -> str: if self.tagged_protm: diff --git a/devel-common/src/sphinx_exts/docs_build/docs_builder.py b/devel-common/src/sphinx_exts/docs_build/docs_builder.py index 70220e324087e..1a5d46c6d7f60 100644 --- a/devel-common/src/sphinx_exts/docs_build/docs_builder.py +++ b/devel-common/src/sphinx_exts/docs_build/docs_builder.py @@ -82,8 +82,7 @@ def _build_dir(self) -> Path: if self.is_versioned: version = "stable" return GENERATED_PATH / "_build" / "docs" / self.package_name / version - else: - return GENERATED_PATH / "_build" / "docs" / self.package_name + return GENERATED_PATH / "_build" / "docs" / self.package_name @property def log_spelling_filename(self) -> Path: @@ -109,18 +108,17 @@ def log_build_warning_filename(self) -> Path: def _src_dir(self) -> Path: if self.package_name == "helm-chart": return AIRFLOW_CONTENT_ROOT_PATH / "chart" / "docs" - elif self.package_name == "apache-airflow": + if self.package_name == "apache-airflow": return AIRFLOW_CONTENT_ROOT_PATH / "airflow-core" / "docs" - elif self.package_name == "docker-stack": + if self.package_name == "docker-stack": return AIRFLOW_CONTENT_ROOT_PATH / "docker-stack-docs" - elif self.package_name == "apache-airflow-providers": + if self.package_name == "apache-airflow-providers": return AIRFLOW_CONTENT_ROOT_PATH / "providers-summary-docs" - elif self.package_name.startswith("apache-airflow-providers-"): + if self.package_name.startswith("apache-airflow-providers-"): package_paths = self.package_name[len("apache-airflow-providers-") :].split("-") return (AIRFLOW_CONTENT_ROOT_PATH / "providers").joinpath(*package_paths) / "docs" - else: - console.print(f"[red]Unknown package name: {self.package_name}") - sys.exit(1) + console.print(f"[red]Unknown package name: {self.package_name}") + sys.exit(1) @property def _generated_api_dir(self) -> Path: @@ -300,8 +298,7 @@ def get_available_providers_distributions(include_suspended: bool = False): def get_short_form(package_name: str) -> str | None: if package_name.startswith("apache-airflow-providers-"): return package_name.replace("apache-airflow-providers-", "").replace("-", ".") - else: - return None + return None def get_long_form(package_name: str) -> str | None: diff --git a/devel-common/src/sphinx_exts/operators_and_hooks_ref.py b/devel-common/src/sphinx_exts/operators_and_hooks_ref.py index 6c3506a066c49..2730df37baed5 100644 --- a/devel-common/src/sphinx_exts/operators_and_hooks_ref.py +++ b/devel-common/src/sphinx_exts/operators_and_hooks_ref.py @@ -226,10 +226,9 @@ def _get_decorator_details(decorator): def get_full_name(node): if isinstance(node, ast.Attribute): return f"{get_full_name(node.value)}.{node.attr}" - elif isinstance(node, ast.Name): + if isinstance(node, ast.Name): return node.id - else: - return ast.dump(node) + return ast.dump(node) def eval_node(node): try: @@ -242,12 +241,11 @@ def eval_node(node): args = [eval_node(arg) for arg in decorator.args] kwargs = {kw.arg: eval_node(kw.value) for kw in decorator.keywords if kw.arg != "category"} return name, args, kwargs - elif isinstance(decorator, ast.Name): + if isinstance(decorator, ast.Name): return decorator.id, [], {} - elif isinstance(decorator, ast.Attribute): + if isinstance(decorator, ast.Attribute): return decorator.attr, [], {} - else: - return decorator, [], {} + return decorator, [], {} def _iter_module_for_deprecations(ast_node, file_path, class_name=None) -> list[dict[str, Any]]: diff --git 
a/devel-common/src/sphinx_exts/removemarktransform.py b/devel-common/src/sphinx_exts/removemarktransform.py index 6e53dff2ffa58..bb65d026bce93 100644 --- a/devel-common/src/sphinx_exts/removemarktransform.py +++ b/devel-common/src/sphinx_exts/removemarktransform.py @@ -58,7 +58,7 @@ def is_pycode(node: nodes.literal_block) -> bool: language = node.get("language") if language in ("py", "py3", "python", "python3", "default"): return True - elif language == "guess": + if language == "guess": try: lexer = guess_lexer(node.rawsource) return isinstance(lexer, (PythonLexer, Python3Lexer)) diff --git a/devel-common/src/tests_common/_internals/capture_warnings.py b/devel-common/src/tests_common/_internals/capture_warnings.py index 0d873e61db6b7..cc17ae5cf0a11 100644 --- a/devel-common/src/tests_common/_internals/capture_warnings.py +++ b/devel-common/src/tests_common/_internals/capture_warnings.py @@ -122,9 +122,9 @@ def group(self) -> str: """ if "/tests/" in self.filename: return "tests" - elif self.filename.startswith("airflow/"): + if self.filename.startswith("airflow/"): return "airflow" - elif self.filename.startswith("providers/"): + if self.filename.startswith("providers/"): return "providers" return "other" diff --git a/devel-common/src/tests_common/pytest_plugin.py b/devel-common/src/tests_common/pytest_plugin.py index 3ae9c7dffc956..6de117a1f7973 100644 --- a/devel-common/src/tests_common/pytest_plugin.py +++ b/devel-common/src/tests_common/pytest_plugin.py @@ -563,11 +563,9 @@ def skip_db_test(item): if next(item.iter_markers(name="non_db_test_override"), None): # non_db_test can override the db_test set for example on module or class level return - else: - pytest.skip( - f"The test is skipped as it is DB test " - f"and --skip-db-tests is flag is passed to pytest. {item}" - ) + pytest.skip( + f"The test is skipped as it is DB test and --skip-db-tests is flag is passed to pytest. {item}" + ) if next(item.iter_markers(name="backend"), None): # also automatically skip tests marked with `backend` marker as they are implicitly # db tests @@ -582,14 +580,13 @@ def only_run_db_test(item): ): # non_db_test at individual level can override the db_test set for example on module or class level return - else: - if next(item.iter_markers(name="backend"), None): - # Also do not skip the tests marked with `backend` marker - as it is implicitly a db test - return - pytest.skip( - f"The test is skipped as it is not a DB tests " - f"and --run-db-tests-only flag is passed to pytest. {item}" - ) + if next(item.iter_markers(name="backend"), None): + # Also do not skip the tests marked with `backend` marker - as it is implicitly a db test + return + pytest.skip( + f"The test is skipped as it is not a DB tests " + f"and --run-db-tests-only flag is passed to pytest. 
{item}" + ) def skip_if_integration_disabled(marker, item): diff --git a/devel-common/src/tests_common/test_utils/compat.py b/devel-common/src/tests_common/test_utils/compat.py index 1be9dcecb88cc..59e5fbdaf1a86 100644 --- a/devel-common/src/tests_common/test_utils/compat.py +++ b/devel-common/src/tests_common/test_utils/compat.py @@ -112,10 +112,9 @@ def deserialize_operator(serialized_operator: dict[str, Any]) -> Operator: from airflow.serialization.serialized_objects import BaseSerialization return BaseSerialization.deserialize(serialized_operator) - else: - from airflow.serialization.serialized_objects import SerializedBaseOperator + from airflow.serialization.serialized_objects import SerializedBaseOperator - return SerializedBaseOperator.deserialize_operator(serialized_operator) + return SerializedBaseOperator.deserialize_operator(serialized_operator) def connection_to_dict( diff --git a/devel-common/src/tests_common/test_utils/logging_command_executor.py b/devel-common/src/tests_common/test_utils/logging_command_executor.py index d41a3ea1ad9dc..27802e2c86e57 100644 --- a/devel-common/src/tests_common/test_utils/logging_command_executor.py +++ b/devel-common/src/tests_common/test_utils/logging_command_executor.py @@ -37,23 +37,22 @@ def execute_cmd(self, cmd, silent=False, cwd=None, env=None): return subprocess.call( args=cmd, stdout=subprocess.DEVNULL, stderr=subprocess.STDOUT, env=env, cwd=cwd ) - else: - self.log.info("Executing: '%s'", " ".join(shlex.quote(c) for c in cmd)) - with subprocess.Popen( - args=cmd, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - universal_newlines=True, - cwd=cwd, - env=env, - ) as process: - output, err = process.communicate() - retcode = process.poll() - self.log.info("Stdout: %s", output) - self.log.info("Stderr: %s", err) - if retcode: - self.log.error("Error when executing %s", " ".join(shlex.quote(c) for c in cmd)) - return retcode + self.log.info("Executing: '%s'", " ".join(shlex.quote(c) for c in cmd)) + with subprocess.Popen( + args=cmd, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + universal_newlines=True, + cwd=cwd, + env=env, + ) as process: + output, err = process.communicate() + retcode = process.poll() + self.log.info("Stdout: %s", output) + self.log.info("Stderr: %s", err) + if retcode: + self.log.error("Error when executing %s", " ".join(shlex.quote(c) for c in cmd)) + return retcode def check_output(self, cmd): self.log.info("Executing for output: '%s'", " ".join(shlex.quote(c) for c in cmd)) @@ -87,24 +86,23 @@ def execute_cmd(self, cmd, silent=False, cwd=None, env=None): return subprocess.call( args=cmd, stdout=subprocess.DEVNULL, stderr=subprocess.STDOUT, env=env, cwd=cwd ) - else: - self.log.info("Executing: '%s'", " ".join(shlex.quote(c) for c in cmd)) - with subprocess.Popen( - args=cmd, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - universal_newlines=True, - cwd=cwd, - env=env, - ) as process: - output, err = process.communicate() - retcode = process.poll() - if retcode: - raise CommandExecutionError( - f"Error when executing '{' '.join(cmd)}' with stdout: {output}, stderr: {err}" - ) - self.log.info("Stdout: %s", output) - self.log.info("Stderr: %s", err) + self.log.info("Executing: '%s'", " ".join(shlex.quote(c) for c in cmd)) + with subprocess.Popen( + args=cmd, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + universal_newlines=True, + cwd=cwd, + env=env, + ) as process: + output, err = process.communicate() + retcode = process.poll() + if retcode: + raise CommandExecutionError( + f"Error when 
executing '{' '.join(cmd)}' with stdout: {output}, stderr: {err}" + ) + self.log.info("Stdout: %s", output) + self.log.info("Stderr: %s", err) def get_executor() -> LoggingCommandExecutor: diff --git a/devel-common/src/tests_common/test_utils/system_tests.py b/devel-common/src/tests_common/test_utils/system_tests.py index daaa2639afbe5..1ba90d24ec811 100644 --- a/devel-common/src/tests_common/test_utils/system_tests.py +++ b/devel-common/src/tests_common/test_utils/system_tests.py @@ -53,11 +53,10 @@ def callback(context: Context): def add_callback(current: list[Callable] | Callable | None, new: Callable) -> list[Callable] | Callable: if not current: return new - elif isinstance(current, list): + if isinstance(current, list): current.append(new) return current - else: - return [current, new] + return [current, new] @pytest.mark.system def test_run(): diff --git a/docker-tests/tests/docker_tests/command_utils.py b/docker-tests/tests/docker_tests/command_utils.py index bec5fec86f869..d60807a808fd5 100644 --- a/docker-tests/tests/docker_tests/command_utils.py +++ b/docker-tests/tests/docker_tests/command_utils.py @@ -27,15 +27,13 @@ def run_command( try: if return_output: return subprocess.check_output(cmd, **kwargs).decode() - else: - try: - result = subprocess.run(cmd, check=check, **kwargs) - return result.returncode == 0 - except FileNotFoundError: - if check: - raise - else: - return False + try: + result = subprocess.run(cmd, check=check, **kwargs) + return result.returncode == 0 + except FileNotFoundError: + if check: + raise + return False except subprocess.CalledProcessError as ex: if print_output_on_error: print("========================= OUTPUT start ============================") diff --git a/providers/airbyte/src/airflow/providers/airbyte/hooks/airbyte.py b/providers/airbyte/src/airflow/providers/airbyte/hooks/airbyte.py index f2a63080b6898..fd2498532010f 100644 --- a/providers/airbyte/src/airflow/providers/airbyte/hooks/airbyte.py +++ b/providers/airbyte/src/airflow/providers/airbyte/hooks/airbyte.py @@ -147,10 +147,9 @@ def wait_for_job(self, job_id: str | int, wait_seconds: float = 3, timeout: floa break if state == JobStatusEnum.FAILED: raise AirflowException(f"Job failed:\n{job}") - elif state == JobStatusEnum.CANCELLED: + if state == JobStatusEnum.CANCELLED: raise AirflowException(f"Job was cancelled:\n{job}") - else: - raise AirflowException(f"Encountered unexpected state `{state}` for job_id `{job_id}`") + raise AirflowException(f"Encountered unexpected state `{state}` for job_id `{job_id}`") def submit_sync_connection(self, connection_id: str) -> Any: try: @@ -186,7 +185,6 @@ def test_connection(self): health_check = self.airbyte_api.health.get_health_check() if health_check.status_code == 200: return True, "Connection successfully tested" - else: - return False, str(health_check.raw_response) + return False, str(health_check.raw_response) except Exception as e: return False, str(e) diff --git a/providers/airbyte/src/airflow/providers/airbyte/sensors/airbyte.py b/providers/airbyte/src/airflow/providers/airbyte/sensors/airbyte.py index bb524fc322598..10a2976da4d10 100644 --- a/providers/airbyte/src/airflow/providers/airbyte/sensors/airbyte.py +++ b/providers/airbyte/src/airflow/providers/airbyte/sensors/airbyte.py @@ -79,10 +79,10 @@ def poke(self, context: Context) -> bool: if status == JobStatusEnum.FAILED: message = f"Job failed: \n{job}" raise AirflowException(message) - elif status == JobStatusEnum.CANCELLED: + if status == JobStatusEnum.CANCELLED: message = f"Job 
was cancelled: \n{job}" raise AirflowException(message) - elif status == JobStatusEnum.SUCCEEDED: + if status == JobStatusEnum.SUCCEEDED: self.log.info("Job %s completed successfully.", self.airbyte_job_id) return True diff --git a/providers/alibaba/src/airflow/providers/alibaba/cloud/log/oss_task_handler.py b/providers/alibaba/src/airflow/providers/alibaba/cloud/log/oss_task_handler.py index 9b88600e0187e..949556a04c2fe 100644 --- a/providers/alibaba/src/airflow/providers/alibaba/cloud/log/oss_task_handler.py +++ b/providers/alibaba/src/airflow/providers/alibaba/cloud/log/oss_task_handler.py @@ -93,8 +93,7 @@ def read(self, relative_path, ti: RuntimeTI) -> tuple[LogSourceInfo, LogMessages if self.oss_log_exists(relative_path): logs.append(self.oss_read(relative_path, return_error=True)) return messages, logs - else: - return messages, None + return messages, None def oss_log_exists(self, remote_log_location): """ diff --git a/providers/alibaba/tests/unit/alibaba/cloud/utils/test_utils.py b/providers/alibaba/tests/unit/alibaba/cloud/utils/test_utils.py index 6908157e13e5c..c4755cd53e879 100644 --- a/providers/alibaba/tests/unit/alibaba/cloud/utils/test_utils.py +++ b/providers/alibaba/tests/unit/alibaba/cloud/utils/test_utils.py @@ -36,7 +36,6 @@ def wrapper(*args, **kwargs) -> None: if self.hook is not None: return func(*bound_args.args, **bound_args.kwargs) - else: - return None + return None return cast("T", wrapper) diff --git a/providers/amazon/src/airflow/providers/amazon/aws/auth_manager/cli/avp_commands.py b/providers/amazon/src/airflow/providers/amazon/aws/auth_manager/cli/avp_commands.py index fcd9bddaceded..18da2fe1663ba 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/auth_manager/cli/avp_commands.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/auth_manager/cli/avp_commands.py @@ -107,26 +107,23 @@ def _create_policy_store(client: BaseClient, args) -> tuple[str | None, bool]: f"There is already a policy store with description '{args.policy_store_description}' in Amazon Verified Permissions: '{existing_policy_stores[0]['policyStoreId']}'." ) return existing_policy_stores[0]["policyStoreId"], False - else: - print(f"No policy store with description '{args.policy_store_description}' found, creating one.") - if args.dry_run: - print( - f"Dry run, not creating the policy store with description '{args.policy_store_description}'." 
- ) - return None, True - - response = client.create_policy_store( - validationSettings={ - "mode": "STRICT", - }, - description=args.policy_store_description, - ) - if args.verbose: - log.debug("Response from create_policy_store: %s", response) + print(f"No policy store with description '{args.policy_store_description}' found, creating one.") + if args.dry_run: + print(f"Dry run, not creating the policy store with description '{args.policy_store_description}'.") + return None, True + + response = client.create_policy_store( + validationSettings={ + "mode": "STRICT", + }, + description=args.policy_store_description, + ) + if args.verbose: + log.debug("Response from create_policy_store: %s", response) - print(f"Policy store created: '{response['policyStoreId']}'") + print(f"Policy store created: '{response['policyStoreId']}'") - return response["policyStoreId"], True + return response["policyStoreId"], True def _set_schema(client: BaseClient, policy_store_id: str, args) -> None: diff --git a/providers/amazon/src/airflow/providers/amazon/aws/executors/ecs/ecs_executor_config.py b/providers/amazon/src/airflow/providers/amazon/aws/executors/ecs/ecs_executor_config.py index c0b6ea5dd553f..2639fd268f8f2 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/executors/ecs/ecs_executor_config.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/executors/ecs/ecs_executor_config.py @@ -72,7 +72,7 @@ def build_task_kwargs() -> dict: raise ValueError( "capacity_provider_strategy and launch_type are mutually exclusive, you can not provide both." ) - elif "cluster" in task_kwargs and not (has_capacity_provider or has_launch_type): + if "cluster" in task_kwargs and not (has_capacity_provider or has_launch_type): # Default API behavior if neither is provided is to fall back on the default capacity # provider if it exists. Since it is not a required value, check if there is one # before using it, and if there is not then use the FARGATE launch_type as diff --git a/providers/amazon/src/airflow/providers/amazon/aws/executors/ecs/utils.py b/providers/amazon/src/airflow/providers/amazon/aws/executors/ecs/utils.py index 8024e6181db45..c7602d5a92459 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/executors/ecs/utils.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/executors/ecs/utils.py @@ -138,7 +138,7 @@ def get_task_state(self) -> str: """ if self.last_status == "RUNNING": return State.RUNNING - elif self.desired_status == "RUNNING": + if self.desired_status == "RUNNING": return State.QUEUED is_finished = self.desired_status == "STOPPED" has_exit_codes = all(["exit_code" in x for x in self.containers]) diff --git a/providers/amazon/src/airflow/providers/amazon/aws/hooks/athena.py b/providers/amazon/src/airflow/providers/amazon/aws/hooks/athena.py index b6ef37325c1cf..952d471eb2307 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/hooks/athena.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/hooks/athena.py @@ -208,7 +208,7 @@ def get_query_results( if query_state is None: self.log.error("Invalid Query state. Query execution id: %s", query_execution_id) return None - elif query_state in self.INTERMEDIATE_STATES or query_state in self.FAILURE_STATES: + if query_state in self.INTERMEDIATE_STATES or query_state in self.FAILURE_STATES: self.log.error( 'Query is in "%s" state. Cannot fetch results. 
Query execution id: %s', query_state, diff --git a/providers/amazon/src/airflow/providers/amazon/aws/hooks/base_aws.py b/providers/amazon/src/airflow/providers/amazon/aws/hooks/base_aws.py index 0230dea91f14e..8bb48a0c32325 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/hooks/base_aws.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/hooks/base_aws.py @@ -191,15 +191,13 @@ def create_session(self, deferrable: bool = False) -> boto3.session.Session | Ai session = self.get_async_session() self._apply_session_kwargs(session) return session - else: - return boto3.session.Session(region_name=self.region_name) - elif not self.role_arn: + return boto3.session.Session(region_name=self.region_name) + if not self.role_arn: if deferrable: session = self.get_async_session() self._apply_session_kwargs(session) return session - else: - return self.basic_session + return self.basic_session # Values stored in ``AwsConnectionWrapper.session_kwargs`` are intended to be used only # to create the initial boto3 session. @@ -624,7 +622,7 @@ def _resolve_service_name(self, is_resource_type: bool = False) -> str: if is_resource_type: raise LookupError("Requested `resource_type`, but `client_type` was set instead.") return self.client_type - elif self.resource_type: + if self.resource_type: if not is_resource_type: raise LookupError("Requested `client_type`, but `resource_type` was set instead.") return self.resource_type @@ -840,15 +838,14 @@ def expand_role(self, role: str, region_name: str | None = None) -> str: """ if "/" in role: return role - else: - session = self.get_session(region_name=region_name) - _client = session.client( - service_name="iam", - endpoint_url=self.conn_config.get_service_endpoint_url("iam"), - config=self.config, - verify=self.verify, - ) - return _client.get_role(RoleName=role)["Role"]["Arn"] + session = self.get_session(region_name=region_name) + _client = session.client( + service_name="iam", + endpoint_url=self.conn_config.get_service_endpoint_url("iam"), + config=self.config, + verify=self.verify, + ) + return _client.get_role(RoleName=role)["Role"]["Arn"] @staticmethod def retry(should_retry: Callable[[Exception], bool]): diff --git a/providers/amazon/src/airflow/providers/amazon/aws/hooks/cloud_formation.py b/providers/amazon/src/airflow/providers/amazon/aws/hooks/cloud_formation.py index 5f591795af144..ad3b83fb4a7c3 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/hooks/cloud_formation.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/hooks/cloud_formation.py @@ -62,8 +62,7 @@ def get_stack_status(self, stack_name: client | resource) -> dict | None: except ClientError as e: if "does not exist" in str(e): return None - else: - raise e + raise e def create_stack(self, stack_name: str, cloudformation_parameters: dict) -> None: """ diff --git a/providers/amazon/src/airflow/providers/amazon/aws/hooks/datasync.py b/providers/amazon/src/airflow/providers/amazon/aws/hooks/datasync.py index 4b74f7cf92885..aca571fa62c11 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/hooks/datasync.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/hooks/datasync.py @@ -312,9 +312,9 @@ def wait_for_task_execution(self, task_execution_arn: str, max_iterations: int = self.log.info("status=%s", status) if status in self.TASK_EXECUTION_SUCCESS_STATES: return True - elif status in self.TASK_EXECUTION_FAILURE_STATES: + if status in self.TASK_EXECUTION_FAILURE_STATES: return False - elif status is None or status in 
self.TASK_EXECUTION_INTERMEDIATE_STATES: + if status is None or status in self.TASK_EXECUTION_INTERMEDIATE_STATES: time.sleep(self.wait_interval_seconds) else: raise AirflowException(f"Unknown status: {status}") # Should never happen diff --git a/providers/amazon/src/airflow/providers/amazon/aws/hooks/dms.py b/providers/amazon/src/airflow/providers/amazon/aws/hooks/dms.py index f7b91b4b572e9..1bbe414f02c88 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/hooks/dms.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/hooks/dms.py @@ -108,9 +108,8 @@ def get_task_status(self, replication_task_arn: str) -> str | None: status = replication_tasks[0]["Status"] self.log.info('Replication task with ARN(%s) has status "%s".', replication_task_arn, status) return status - else: - self.log.info("Replication task with ARN(%s) is not found.", replication_task_arn) - return None + self.log.info("Replication task with ARN(%s) is not found.", replication_task_arn) + return None def create_replication_task( self, diff --git a/providers/amazon/src/airflow/providers/amazon/aws/hooks/dynamodb.py b/providers/amazon/src/airflow/providers/amazon/aws/hooks/dynamodb.py index 53980a1f93308..082c1c203d076 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/hooks/dynamodb.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/hooks/dynamodb.py @@ -102,5 +102,4 @@ def get_import_status(self, import_arn: str) -> tuple[str, str | None, str | Non error_code = e.response.get("Error", {}).get("Code") if error_code == "ImportNotFoundException": raise AirflowException("S3 import into Dynamodb job not found.") - else: - raise e + raise e diff --git a/providers/amazon/src/airflow/providers/amazon/aws/hooks/emr.py b/providers/amazon/src/airflow/providers/amazon/aws/hooks/emr.py index cfc12e2378c66..2be6426e9247d 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/hooks/emr.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/hooks/emr.py @@ -83,11 +83,10 @@ def get_cluster_id_by_name(self, emr_cluster_name: str, cluster_states: list[str cluster_id = matching_clusters[0]["Id"] self.log.info("Found cluster name = %s id = %s", emr_cluster_name, cluster_id) return cluster_id - elif len(matching_clusters) > 1: + if len(matching_clusters) > 1: raise AirflowException(f"More than one cluster found for name {emr_cluster_name}") - else: - self.log.info("No cluster found for name %s", emr_cluster_name) - return None + self.log.info("No cluster found for name %s", emr_cluster_name) + return None def create_job_flow(self, job_flow_overrides: dict[str, Any]) -> dict[str, Any]: """ @@ -387,12 +386,11 @@ def create_emr_on_eks_cluster( if response["ResponseMetadata"]["HTTPStatusCode"] != 200: raise AirflowException(f"Create EMR EKS Cluster failed: {response}") - else: - self.log.info( - "Create EMR EKS Cluster success - virtual cluster id %s", - response["id"], - ) - return response["id"] + self.log.info( + "Create EMR EKS Cluster success - virtual cluster id %s", + response["id"], + ) + return response["id"] def submit_job( self, @@ -446,13 +444,12 @@ def submit_job( if response["ResponseMetadata"]["HTTPStatusCode"] != 200: raise AirflowException(f"Start Job Run failed: {response}") - else: - self.log.info( - "Start Job Run success - Job Id %s and virtual cluster id %s", - response["id"], - response["virtualClusterId"], - ) - return response["id"] + self.log.info( + "Start Job Run success - Job Id %s and virtual cluster id %s", + response["id"], + response["virtualClusterId"], + ) + return 
response["id"] def get_job_failure_reason(self, job_id: str) -> str | None: """ diff --git a/providers/amazon/src/airflow/providers/amazon/aws/hooks/glue.py b/providers/amazon/src/airflow/providers/amazon/aws/hooks/glue.py index b9cd21bffae29..7093120ceafd0 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/hooks/glue.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/hooks/glue.py @@ -320,8 +320,7 @@ def job_completion( if ret: time.sleep(sleep_before_return) return ret - else: - time.sleep(self.job_poll_interval) + time.sleep(self.job_poll_interval) async def async_job_completion(self, job_name: str, run_id: str, verbose: bool = False) -> dict[str, str]: """ @@ -338,8 +337,7 @@ async def async_job_completion(self, job_name: str, run_id: str, verbose: bool = ret = self._handle_state(job_run_state, job_name, run_id, verbose, next_log_tokens) if ret: return ret - else: - await asyncio.sleep(self.job_poll_interval) + await asyncio.sleep(self.job_poll_interval) def _handle_state( self, @@ -367,13 +365,12 @@ def _handle_state( job_error_message = f"Exiting Job {run_id} Run State: {state}" self.log.info(job_error_message) raise AirflowException(job_error_message) - else: - self.log.info( - "Polling for AWS Glue Job %s current run state with status %s", - job_name, - state, - ) - return None + self.log.info( + "Polling for AWS Glue Job %s current run state with status %s", + job_name, + state, + ) + return None def has_job(self, job_name) -> bool: """ @@ -414,8 +411,7 @@ def update_job(self, **job_kwargs) -> bool: self.conn.update_job(JobName=job_name, JobUpdate=job_kwargs) self.log.info("Updated configurations: %s", update_config) return True - else: - return False + return False def get_or_create_glue_job(self) -> str | None: """ diff --git a/providers/amazon/src/airflow/providers/amazon/aws/hooks/mwaa.py b/providers/amazon/src/airflow/providers/amazon/aws/hooks/mwaa.py index 9ee8635ccbcb0..49d03a503c66d 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/hooks/mwaa.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/hooks/mwaa.py @@ -101,13 +101,12 @@ def invoke_rest_api( "Access Denied due to missing airflow:InvokeRestApi in IAM policy. Trying again by generating local token..." 
) return self._invoke_rest_api_using_local_session_token(**api_kwargs) - else: - to_log = e.response - # ResponseMetadata is removed because it contains data that is either very unlikely to be - # useful in XComs and logs, or redundant given the data already included in the response - to_log.pop("ResponseMetadata", None) - self.log.error(to_log) - raise + to_log = e.response + # ResponseMetadata is removed because it contains data that is either very unlikely to be + # useful in XComs and logs, or redundant given the data already included in the response + to_log.pop("ResponseMetadata", None) + self.log.error(to_log) + raise def _invoke_rest_api_using_local_session_token( self, diff --git a/providers/amazon/src/airflow/providers/amazon/aws/hooks/redshift_data.py b/providers/amazon/src/airflow/providers/amazon/aws/hooks/redshift_data.py index 72879fd8595e0..c57d1d2a26763 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/hooks/redshift_data.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/hooks/redshift_data.py @@ -181,7 +181,7 @@ def parse_statement_response(self, resp: DescribeStatementResponseTypeDef) -> bo if num_rows is not None: self.log.info("Processed %s rows", num_rows) return True - elif status in FAILURE_STATES: + if status in FAILURE_STATES: exception_cls = ( RedshiftDataQueryFailedError if status == FAILED_STATE else RedshiftDataQueryAbortedError ) diff --git a/providers/amazon/src/airflow/providers/amazon/aws/hooks/redshift_sql.py b/providers/amazon/src/airflow/providers/amazon/aws/hooks/redshift_sql.py index 23b9555def84c..c74b23352f93f 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/hooks/redshift_sql.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/hooks/redshift_sql.py @@ -245,15 +245,14 @@ def _get_identifier_from_hostname(self, hostname: str) -> str: parts = hostname.split(".") if hostname.endswith("amazonaws.com") and len(parts) == 6: return f"{parts[0]}.{parts[2]}" - else: - self.log.debug( - """Could not parse identifier from hostname '%s'. + self.log.debug( + """Could not parse identifier from hostname '%s'. You are probably using IP to connect to Redshift cluster. 
Expected format: 'cluster_identifier.id.region_name.redshift.amazonaws.com' Falling back to whole hostname.""", - hostname, - ) - return hostname + hostname, + ) + return hostname def get_openlineage_database_dialect(self, connection: Connection) -> str: """Return redshift dialect.""" diff --git a/providers/amazon/src/airflow/providers/amazon/aws/hooks/s3.py b/providers/amazon/src/airflow/providers/amazon/aws/hooks/s3.py index 5455bc102735b..a904b9cfbb112 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/hooks/s3.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/hooks/s3.py @@ -442,8 +442,7 @@ async def get_head_object_async( except ClientError as e: if e.response["ResponseMetadata"]["HTTPStatusCode"] == 404: return head_object_val - else: - raise e + raise e async def list_prefixes_async( self, @@ -936,8 +935,7 @@ def head_object(self, key: str, bucket_name: str | None = None) -> dict | None: except ClientError as e: if e.response["ResponseMetadata"]["HTTPStatusCode"] == 404: return None - else: - raise e + raise e @unify_bucket_name_and_key @provide_bucket_name @@ -1469,8 +1467,7 @@ def download_file( raise AirflowNotFoundException( f"The source file in Bucket {bucket_name} with path {key} does not exist" ) - else: - raise e + raise e if preserve_file_name: local_dir = local_path or gettempdir() diff --git a/providers/amazon/src/airflow/providers/amazon/aws/hooks/sagemaker.py b/providers/amazon/src/airflow/providers/amazon/aws/hooks/sagemaker.py index 43df07e23bd48..a785cc5ce83cb 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/hooks/sagemaker.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/hooks/sagemaker.py @@ -750,7 +750,7 @@ def check_status( if status in self.failed_states: raise AirflowException(f"SageMaker resource failed because {response['FailureReason']}") - elif status not in non_terminal_states: + if status not in non_terminal_states: break if max_ingestion_time and sec > max_ingestion_time: @@ -1010,8 +1010,7 @@ def _list_request( if "NextToken" not in response or (max_results is not None and len(results) == max_results): # Return when there are no results left (no NextToken) or when we've reached max_results. 
return results - else: - next_token = response["NextToken"] + next_token = response["NextToken"] @staticmethod def _name_matches_pattern( @@ -1172,9 +1171,8 @@ def stop_pipeline( ): self.log.warning("Cannot stop pipeline execution, as it was not running: %s", ce) break - else: - self.log.error(ce) - raise + self.log.error(ce) + raise else: break @@ -1214,9 +1212,8 @@ def create_model_package_group(self, package_group_name: str, package_group_desc # log msg only so it doesn't look like an error self.log.info("%s", e.response["Error"]["Message"]) return False - else: - self.log.error("Error when trying to create Model Package Group: %s", e) - raise + self.log.error("Error when trying to create Model Package Group: %s", e) + raise def _describe_auto_ml_job(self, job_name: str): res = self.conn.describe_auto_ml_job(AutoMLJobName=job_name) diff --git a/providers/amazon/src/airflow/providers/amazon/aws/hooks/sagemaker_unified_studio.py b/providers/amazon/src/airflow/providers/amazon/aws/hooks/sagemaker_unified_studio.py index 4ad327b51c5ff..d7ee54543628c 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/hooks/sagemaker_unified_studio.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/hooks/sagemaker_unified_studio.py @@ -180,9 +180,8 @@ def _handle_state(self, execution_id, status, error_message): if status in finished_states: self.log.info(execution_message) return {"Status": status, "ExecutionId": execution_id} - else: - log_error_message = f"Execution {execution_id} failed with error: {error_message}" - self.log.error(log_error_message) - if error_message == "": - error_message = execution_message - raise AirflowException(error_message) + log_error_message = f"Execution {execution_id} failed with error: {error_message}" + self.log.error(log_error_message) + if error_message == "": + error_message = execution_message + raise AirflowException(error_message) diff --git a/providers/amazon/src/airflow/providers/amazon/aws/links/base_aws.py b/providers/amazon/src/airflow/providers/amazon/aws/links/base_aws.py index e909776bdffd8..c97497b7d8aad 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/links/base_aws.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/links/base_aws.py @@ -49,9 +49,9 @@ class BaseAwsLink(BaseOperatorLink): def get_aws_domain(aws_partition) -> str | None: if aws_partition == "aws": return "aws.amazon.com" - elif aws_partition == "aws-cn": + if aws_partition == "aws-cn": return "amazonaws.cn" - elif aws_partition == "aws-us-gov": + if aws_partition == "aws-us-gov": return "amazonaws-us-gov.com" return None diff --git a/providers/amazon/src/airflow/providers/amazon/aws/links/emr.py b/providers/amazon/src/airflow/providers/amazon/aws/links/emr.py index d81bc93cc9b07..d36aab9db89b5 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/links/emr.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/links/emr.py @@ -127,8 +127,7 @@ def format_link(self, application_id: str | None = None, job_run_id: str | None ) if url: return url._replace(path="/logs/SPARK_DRIVER/stdout.gz").geturl() - else: - return "" + return "" class EmrServerlessDashboardLink(BaseAwsLink): @@ -145,8 +144,7 @@ def format_link(self, application_id: str | None = None, job_run_id: str | None ) if url: return url.geturl() - else: - return "" + return "" class EmrServerlessS3LogsLink(BaseAwsLink): diff --git a/providers/amazon/src/airflow/providers/amazon/aws/log/cloudwatch_task_handler.py 
b/providers/amazon/src/airflow/providers/amazon/aws/log/cloudwatch_task_handler.py index 44460443e4778..4c31e952e0bed 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/log/cloudwatch_task_handler.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/log/cloudwatch_task_handler.py @@ -56,8 +56,7 @@ def json_serialize_legacy(value: Any) -> str | None: """ if isinstance(value, (date, datetime)): return value.isoformat() - else: - return None + return None def json_serialize(value: Any) -> str | None: diff --git a/providers/amazon/src/airflow/providers/amazon/aws/log/s3_task_handler.py b/providers/amazon/src/airflow/providers/amazon/aws/log/s3_task_handler.py index e3586a84299a4..24088ca0d9c19 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/log/s3_task_handler.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/log/s3_task_handler.py @@ -162,8 +162,7 @@ def read(self, relative_path: str, ti: RuntimeTI) -> tuple[LogSourceInfo, LogMes for key in keys: logs.append(self.s3_read(key, return_error=True)) return messages, logs - else: - return messages, None + return messages, None class S3TaskHandler(FileTaskHandler, LoggingMixin): diff --git a/providers/amazon/src/airflow/providers/amazon/aws/operators/athena.py b/providers/amazon/src/airflow/providers/amazon/aws/operators/athena.py index 3d152e513d646..44eb6590aeb83 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/operators/athena.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/operators/athena.py @@ -168,7 +168,7 @@ def execute(self, context: Context) -> str | None: f"Final state of Athena job is {query_status}, query_execution_id is " f"{self.query_execution_id}. Error: {error_message}" ) - elif not query_status or query_status in AthenaHook.INTERMEDIATE_STATES: + if not query_status or query_status in AthenaHook.INTERMEDIATE_STATES: raise AirflowException( f"Final state of Athena job is {query_status}. Max tries of poll status exceeded, " f"query_execution_id is {self.query_execution_id}." 
diff --git a/providers/amazon/src/airflow/providers/amazon/aws/operators/batch.py b/providers/amazon/src/airflow/providers/amazon/aws/operators/batch.py index 49669f77f6202..dfb86611a767d 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/operators/batch.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/operators/batch.py @@ -245,9 +245,9 @@ def execute(self, context: Context) -> str | None: if job_status == self.hook.SUCCESS_STATE: self.log.info("Job completed.") return self.job_id - elif job_status == self.hook.FAILURE_STATE: + if job_status == self.hook.FAILURE_STATE: raise AirflowException(f"Error while running job: {self.job_id} is in {job_status} state") - elif job_status in self.hook.INTERMEDIATE_STATES: + if job_status in self.hook.INTERMEDIATE_STATES: self.defer( timeout=self.execution_timeout, trigger=BatchJobTrigger( diff --git a/providers/amazon/src/airflow/providers/amazon/aws/operators/ecs.py b/providers/amazon/src/airflow/providers/amazon/aws/operators/ecs.py index bb2c753236249..c83ae2690e46d 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/operators/ecs.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/operators/ecs.py @@ -567,8 +567,7 @@ def execute(self, context): if self.do_xcom_push and self.task_log_fetcher: return self.task_log_fetcher.get_last_log_message() - else: - return None + return None def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> str | None: validated_event = validate_execute_complete_event(event) @@ -729,11 +728,10 @@ def _check_success_task(self) -> None: f"This task is not in success state - last {self.number_logs_exception} " f"logs from Cloudwatch:\n{last_logs}" ) - else: - raise AirflowException(f"This task is not in success state {task}") - elif container.get("lastStatus") == "PENDING": + raise AirflowException(f"This task is not in success state {task}") + if container.get("lastStatus") == "PENDING": raise AirflowException(f"This task is still pending {task}") - elif "error" in container.get("reason", "").lower(): + if "error" in container.get("reason", "").lower(): raise AirflowException( f"This containers encounter an error during launching: " f"{container.get('reason', '').lower()}" diff --git a/providers/amazon/src/airflow/providers/amazon/aws/operators/eks.py b/providers/amazon/src/airflow/providers/amazon/aws/operators/eks.py index bac02c4084364..bdaad9a749c1c 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/operators/eks.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/operators/eks.py @@ -276,11 +276,11 @@ def execute(self, context: Context): if self.compute: if self.compute not in SUPPORTED_COMPUTE_VALUES: raise ValueError("Provided compute type is not supported.") - elif (self.compute == "nodegroup") and not self.nodegroup_role_arn: + if (self.compute == "nodegroup") and not self.nodegroup_role_arn: raise ValueError( MISSING_ARN_MSG.format(compute=NODEGROUP_FULL_NAME, requirement="nodegroup_role_arn") ) - elif (self.compute == "fargate") and not self.fargate_pod_execution_role_arn: + if (self.compute == "fargate") and not self.fargate_pod_execution_role_arn: raise ValueError( MISSING_ARN_MSG.format( compute=FARGATE_FULL_NAME, requirement="fargate_pod_execution_role_arn" @@ -349,7 +349,7 @@ def deferrable_create_cluster_next(self, context: Context, event: dict[str, Any] if event is None: self.log.error("Trigger error: event is None") raise AirflowException("Trigger error: event is None") - elif event["status"] == "failed": + if 
event["status"] == "failed": self.log.error("Cluster failed to start and will be torn down.") self.hook.delete_cluster(name=self.cluster_name) self.defer( @@ -414,7 +414,7 @@ def execute_failed(self, context: Context, event: dict[str, Any] | None = None) if event is None: self.log.info("Trigger error: event is None") raise AirflowException("Trigger error: event is None") - elif event["status"] == "deleted": + if event["status"] == "deleted": self.log.info("Cluster deleted") raise AirflowException("Error creating cluster") diff --git a/providers/amazon/src/airflow/providers/amazon/aws/operators/emr.py b/providers/amazon/src/airflow/providers/amazon/aws/operators/emr.py index 2d3fcfedef400..91609a3e143e1 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/operators/emr.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/operators/emr.py @@ -850,9 +850,8 @@ def execute(self, context: Context) -> int: if response["ResponseMetadata"]["HTTPStatusCode"] != 200: raise AirflowException(f"Modify cluster failed: {response}") - else: - self.log.info("Steps concurrency level %d", response["StepConcurrencyLevel"]) - return response["StepConcurrencyLevel"] + self.log.info("Steps concurrency level %d", response["StepConcurrencyLevel"]) + return response["StepConcurrencyLevel"] class EmrTerminateJobFlowOperator(BaseOperator): @@ -1070,7 +1069,7 @@ def start_application_deferred(self, context: Context, event: dict[str, Any] | N if event is None: self.log.error("Trigger error: event is None") raise AirflowException("Trigger error: event is None") - elif event["status"] != "success": + if event["status"] != "success": raise AirflowException(f"Application {event['application_id']} failed to create") self.log.info("Starting application %s", event["application_id"]) self.hook.conn.start_application(applicationId=event["application_id"]) @@ -1533,7 +1532,7 @@ def stop_application(self, context: Context, event: dict[str, Any] | None = None if event is None: self.log.error("Trigger error: event is None") raise AirflowException("Trigger error: event is None") - elif event["status"] == "success": + if event["status"] == "success": self.hook.conn.stop_application(applicationId=self.application_id) self.defer( trigger=EmrServerlessStopApplicationTrigger( diff --git a/providers/amazon/src/airflow/providers/amazon/aws/operators/neptune.py b/providers/amazon/src/airflow/providers/amazon/aws/operators/neptune.py index 73e3c759e9eda..85cf20c7971dc 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/operators/neptune.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/operators/neptune.py @@ -139,7 +139,7 @@ def execute(self, context: Context, event: dict[str, Any] | None = None, **kwarg if status.lower() in NeptuneHook.AVAILABLE_STATES: self.log.info("Neptune cluster %s is already available.", self.cluster_id) return {"db_cluster_id": self.cluster_id} - elif status.lower() in NeptuneHook.ERROR_STATES: + if status.lower() in NeptuneHook.ERROR_STATES: # some states will not allow you to start the cluster self.log.error( "Neptune cluster %s is in error state %s and cannot be started", self.cluster_id, status @@ -259,7 +259,7 @@ def execute(self, context: Context, event: dict[str, Any] | None = None, **kwarg if status.lower() in NeptuneHook.STOPPED_STATES: self.log.info("Neptune cluster %s is already stopped.", self.cluster_id) return {"db_cluster_id": self.cluster_id} - elif status.lower() in NeptuneHook.ERROR_STATES: + if status.lower() in NeptuneHook.ERROR_STATES: # some states will not 
allow you to stop the cluster self.log.error( "Neptune cluster %s is in error state %s and cannot be stopped", self.cluster_id, status diff --git a/providers/amazon/src/airflow/providers/amazon/aws/operators/redshift_data.py b/providers/amazon/src/airflow/providers/amazon/aws/operators/redshift_data.py index 334834d16d3b4..6bc62b0cedb63 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/operators/redshift_data.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/operators/redshift_data.py @@ -224,8 +224,7 @@ def get_sql_results( results: list = [self.hook.conn.get_statement_result(Id=sid) for sid in statement_ids] self.log.debug("Statement result(s): %s", results) return results - else: - return statement_ids + return statement_ids def on_kill(self) -> None: """Cancel the submitted redshift query.""" diff --git a/providers/amazon/src/airflow/providers/amazon/aws/operators/s3.py b/providers/amazon/src/airflow/providers/amazon/aws/operators/s3.py index 406d9f597527b..f486084f34de0 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/operators/s3.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/operators/s3.py @@ -158,9 +158,8 @@ def execute(self, context: Context): if self.hook.check_for_bucket(self.bucket_name): self.log.info("Getting tags for bucket %s", self.bucket_name) return self.hook.get_bucket_tagging(self.bucket_name) - else: - self.log.warning(BUCKET_DOES_NOT_EXIST_MSG, self.bucket_name) - return None + self.log.warning(BUCKET_DOES_NOT_EXIST_MSG, self.bucket_name) + return None class S3PutBucketTaggingOperator(AwsBaseOperator[S3Hook]): @@ -213,9 +212,8 @@ def execute(self, context: Context): return self.hook.put_bucket_tagging( key=self.key, value=self.value, tag_set=self.tag_set, bucket_name=self.bucket_name ) - else: - self.log.warning(BUCKET_DOES_NOT_EXIST_MSG, self.bucket_name) - return None + self.log.warning(BUCKET_DOES_NOT_EXIST_MSG, self.bucket_name) + return None class S3DeleteBucketTaggingOperator(AwsBaseOperator[S3Hook]): @@ -254,9 +252,8 @@ def execute(self, context: Context): if self.hook.check_for_bucket(self.bucket_name): self.log.info("Deleting tags for bucket %s", self.bucket_name) return self.hook.delete_bucket_tagging(self.bucket_name) - else: - self.log.warning(BUCKET_DOES_NOT_EXIST_MSG, self.bucket_name) - return None + self.log.warning(BUCKET_DOES_NOT_EXIST_MSG, self.bucket_name) + return None class S3CopyObjectOperator(AwsBaseOperator[S3Hook]): @@ -725,10 +722,9 @@ def execute(self, context: Context): if process.returncode: raise AirflowException(f"Transform script failed: {process.returncode}") - else: - self.log.info( - "Transform script successful. Output temporarily located at %s", f_dest.name - ) + self.log.info( + "Transform script successful. Output temporarily located at %s", f_dest.name + ) self.log.info("Uploading transformed file to S3") f_dest.flush() diff --git a/providers/amazon/src/airflow/providers/amazon/aws/operators/sagemaker.py b/providers/amazon/src/airflow/providers/amazon/aws/operators/sagemaker.py index 778816151be01..32b66ea6161ed 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/operators/sagemaker.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/operators/sagemaker.py @@ -165,13 +165,10 @@ def _get_unique_name( # in case there is collision. 
if fail_if_exists: raise AirflowException(f"A SageMaker {resource_type} with name {name} already exists.") - else: - max_name_len = 63 - timestamp = str( - time.time_ns() // 1000000000 - ) # only keep the relevant datetime (first 10 digits) - name = f"{proposed_name[: max_name_len - len(timestamp) - 1]}-{timestamp}" # we subtract one to make provision for the dash between the truncated name and timestamp - self.log.info("Changed %s name to '%s' to avoid collision.", resource_type, name) + max_name_len = 63 + timestamp = str(time.time_ns() // 1000000000) # only keep the relevant datetime (first 10 digits) + name = f"{proposed_name[: max_name_len - len(timestamp) - 1]}-{timestamp}" # we subtract one to make provision for the dash between the truncated name and timestamp + self.log.info("Changed %s name to '%s' to avoid collision.", resource_type, name) return name def _check_resource_type(self, resource_type: str): @@ -197,8 +194,7 @@ def _check_if_resource_exists( except ClientError as e: if e.response["Error"]["Code"] == "ValidationException": return False # ValidationException is thrown when the resource could not be found - else: - raise e + raise e def execute(self, context: Context): raise NotImplementedError("Please implement execute() in sub class!") @@ -326,7 +322,7 @@ def execute(self, context: Context) -> dict: status = response["ProcessingJobStatus"] if status in self.hook.failed_states: raise AirflowException(f"SageMaker job failed because {response['FailureReason']}") - elif status == "Completed": + if status == "Completed": self.log.info("%s completed successfully.", self.task_id) return {"Processing": serialize(response)} @@ -430,12 +426,9 @@ def execute(self, context: Context) -> dict: response = self.hook.create_endpoint_config(self.config) if response["ResponseMetadata"]["HTTPStatusCode"] != 200: raise AirflowException(f"Sagemaker endpoint config creation failed: {response}") - else: - return { - "EndpointConfig": serialize( - self.hook.describe_endpoint_config(self.config["EndpointConfigName"]) - ) - } + return { + "EndpointConfig": serialize(self.hook.describe_endpoint_config(self.config["EndpointConfigName"])) + } class SageMakerEndpointOperator(SageMakerBaseOperator): @@ -1038,8 +1031,7 @@ def execute(self, context: Context) -> dict: response = self.hook.create_model(self.config) if response["ResponseMetadata"]["HTTPStatusCode"] != 200: raise AirflowException(f"Sagemaker model creation failed: {response}") - else: - return {"Model": serialize(self.hook.describe_model(self.config["ModelName"]))} + return {"Model": serialize(self.hook.describe_model(self.config["ModelName"]))} class SageMakerTrainingOperator(SageMakerBaseOperator): @@ -1177,7 +1169,7 @@ def execute(self, context: Context) -> dict: if status in self.hook.failed_states: reason = description.get("FailureReason", "(No reason provided)") raise AirflowException(f"SageMaker job failed because {reason}") - elif status == "Completed": + if status == "Completed": log_message = f"{self.task_id} completed successfully." 
if self.print_log: billable_seconds = SageMakerHook.count_billable_seconds( diff --git a/providers/amazon/src/airflow/providers/amazon/aws/secrets/secrets_manager.py b/providers/amazon/src/airflow/providers/amazon/aws/secrets/secrets_manager.py index 4d675771b23df..056c073a87813 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/secrets/secrets_manager.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/secrets/secrets_manager.py @@ -224,8 +224,7 @@ def get_conn_value(self, conn_id: str) -> str | None: standardized_secret_dict = self._standardize_secret_keys(secret_dict) standardized_secret = json.dumps(standardized_secret_dict) return standardized_secret - else: - return secret + return secret def get_variable(self, key: str) -> str | None: """ diff --git a/providers/amazon/src/airflow/providers/amazon/aws/sensors/batch.py b/providers/amazon/src/airflow/providers/amazon/aws/sensors/batch.py index b368edcd093ed..9c8c700b885d2 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/sensors/batch.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/sensors/batch.py @@ -226,8 +226,7 @@ def poke(self, context: Context) -> bool: if not response["jobQueues"]: if self.treat_non_existing_as_deleted: return True - else: - raise AirflowException(f"AWS Batch job queue {self.job_queue} not found") + raise AirflowException(f"AWS Batch job queue {self.job_queue} not found") status = response["jobQueues"][0]["status"] diff --git a/providers/amazon/src/airflow/providers/amazon/aws/sensors/glacier.py b/providers/amazon/src/airflow/providers/amazon/aws/sensors/glacier.py index f52c157c9380b..09469c6c4079b 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/sensors/glacier.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/sensors/glacier.py @@ -89,11 +89,10 @@ def poke(self, context: Context) -> bool: self.log.info("Job status: %s, code status: %s", response["Action"], response["StatusCode"]) self.log.info("Job finished successfully") return True - elif response["StatusCode"] == JobStatus.IN_PROGRESS.value: + if response["StatusCode"] == JobStatus.IN_PROGRESS.value: self.log.info("Processing...") self.log.warning("Code status: %s", response["StatusCode"]) return False - else: - raise AirflowException( - f"Sensor failed. Job status: {response['Action']}, code status: {response['StatusCode']}" - ) + raise AirflowException( + f"Sensor failed. 
Job status: {response['Action']}, code status: {response['StatusCode']}" + ) diff --git a/providers/amazon/src/airflow/providers/amazon/aws/sensors/glue.py b/providers/amazon/src/airflow/providers/amazon/aws/sensors/glue.py index 79c83c88c1d57..fb360f9102171 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/sensors/glue.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/sensors/glue.py @@ -84,12 +84,11 @@ def poke(self, context: Context): if job_state in self.success_states: self.log.info("Exiting Job %s Run State: %s", self.run_id, job_state) return True - elif job_state in self.errored_states: + if job_state in self.errored_states: job_error_message = "Exiting Job %s Run State: %s", self.run_id, job_state self.log.info(job_error_message) raise AirflowException(job_error_message) - else: - return False + return False finally: if self.verbose: self.hook.print_job_logs( @@ -212,7 +211,7 @@ def poke(self, context: Context): return True - elif status in self.FAILURE_STATES: + if status in self.FAILURE_STATES: job_error_message = ( f"Error: AWS Glue data quality ruleset evaluation run RunId: {self.evaluation_run_id} Run " f"Status: {status}" @@ -220,8 +219,7 @@ def poke(self, context: Context): ) self.log.info(job_error_message) raise AirflowException(job_error_message) - else: - return False + return False class GlueDataQualityRuleRecommendationRunSensor(AwsBaseSensor[GlueDataQualityHook]): @@ -327,7 +325,7 @@ def poke(self, context: Context) -> bool: return True - elif status in self.FAILURE_STATES: + if status in self.FAILURE_STATES: job_error_message = ( f"Error: AWS Glue data quality recommendation run RunId: {self.recommendation_run_id} Run " f"Status: {status}" @@ -335,5 +333,4 @@ def poke(self, context: Context) -> bool: ) self.log.info(job_error_message) raise AirflowException(job_error_message) - else: - return False + return False diff --git a/providers/amazon/src/airflow/providers/amazon/aws/sensors/glue_crawler.py b/providers/amazon/src/airflow/providers/amazon/aws/sensors/glue_crawler.py index a68f643864091..d8a0a316375e8 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/sensors/glue_crawler.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/sensors/glue_crawler.py @@ -73,7 +73,5 @@ def poke(self, context: Context): if crawler_status == self.success_statuses: self.log.info("Status: %s", crawler_status) return True - else: - raise AirflowException(f"Status: {crawler_status}") - else: - return False + raise AirflowException(f"Status: {crawler_status}") + return False diff --git a/providers/amazon/src/airflow/providers/amazon/aws/sensors/redshift_cluster.py b/providers/amazon/src/airflow/providers/amazon/aws/sensors/redshift_cluster.py index c6d356b4db216..80075530f8405 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/sensors/redshift_cluster.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/sensors/redshift_cluster.py @@ -93,7 +93,7 @@ def execute_complete(self, context: Context, event: dict[str, Any] | None = None status = validated_event["status"] if status == "error": raise AirflowException(f"{validated_event['status']}: {validated_event['message']}") - elif status == "success": + if status == "success": self.log.info("%s completed successfully.", self.task_id) self.log.info("Cluster Identifier %s is in %s state", self.cluster_identifier, self.target_status) diff --git a/providers/amazon/src/airflow/providers/amazon/aws/sensors/s3.py b/providers/amazon/src/airflow/providers/amazon/aws/sensors/s3.py index 
cdbe7bceff9be..1bea70ae25544 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/sensors/s3.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/sensors/s3.py @@ -176,8 +176,7 @@ def _check_key(self, key, context: Context): def poke(self, context: Context): if isinstance(self.bucket_key, str): return self._check_key(self.bucket_key, context=context) - else: - return all(self._check_key(key, context=context) for key in self.bucket_key) + return all(self._check_key(key, context=context) for key in self.bucket_key) def execute(self, context: Context) -> None: """Airflow runs this method on the worker and defers using the trigger.""" diff --git a/providers/amazon/src/airflow/providers/amazon/aws/sensors/sagemaker_unified_studio.py b/providers/amazon/src/airflow/providers/amazon/aws/sensors/sagemaker_unified_studio.py index ab32b50dbe89b..6c1815d2e166e 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/sensors/sagemaker_unified_studio.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/sensors/sagemaker_unified_studio.py @@ -59,12 +59,11 @@ def poke(self, context=None): log_info_message = f"Exiting Execution {self.execution_id} State: {status}" self.log.info(log_info_message) return True - elif status in self.in_progress_states: + if status in self.in_progress_states: return False - else: - error_message = f"Exiting Execution {self.execution_id} State: {status}" - self.log.info(error_message) - raise AirflowException(error_message) + error_message = f"Exiting Execution {self.execution_id} State: {status}" + self.log.info(error_message) + raise AirflowException(error_message) def execute(self, context: Context): # This will invoke poke method in the base sensor diff --git a/providers/amazon/src/airflow/providers/amazon/aws/sensors/sqs.py b/providers/amazon/src/airflow/providers/amazon/aws/sensors/sqs.py index e5d70ed9e1a7b..371090b1ed9fd 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/sensors/sqs.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/sensors/sqs.py @@ -221,5 +221,4 @@ def poke(self, context: Context): if message_batch: context["ti"].xcom_push(key="messages", value=message_batch) return True - else: - return False + return False diff --git a/providers/amazon/src/airflow/providers/amazon/aws/utils/connection_wrapper.py b/providers/amazon/src/airflow/providers/amazon/aws/utils/connection_wrapper.py index de5a30120fdf0..ef2ca7b4cb343 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/utils/connection_wrapper.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/utils/connection_wrapper.py @@ -183,7 +183,7 @@ def __post_init__(self, conn: Connection | AwsConnectionWrapper | _ConnectionMet # Only replace value if it not equal default value setattr(self, fl.name, value) return - elif not conn: + if not conn: return if TYPE_CHECKING: diff --git a/providers/amazon/src/airflow/providers/amazon/aws/utils/sqs.py b/providers/amazon/src/airflow/providers/amazon/aws/utils/sqs.py index 3c509454655a7..6b46c5bf2baf3 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/utils/sqs.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/utils/sqs.py @@ -73,8 +73,7 @@ def filter_messages( return filter_messages_jsonpath( messages, message_filtering_match_values, message_filtering_config, jsonpath_ng.ext.parse ) - else: - raise NotImplementedError("Override this method to define custom filters") + raise NotImplementedError("Override this method to define custom filters") def filter_messages_literal(messages, 
message_filtering_match_values) -> list[Any]: diff --git a/providers/amazon/src/airflow/providers/amazon/aws/utils/tags.py b/providers/amazon/src/airflow/providers/amazon/aws/utils/tags.py index 5b8eb736bbc2a..9ff6c29cc7a7a 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/utils/tags.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/utils/tags.py @@ -35,7 +35,6 @@ def format_tags(source: Any, *, key_label: str = "Key", value_label: str = "Valu """ if source is None: return [] - elif isinstance(source, dict): + if isinstance(source, dict): return [{key_label: kvp[0], value_label: kvp[1]} for kvp in source.items()] - else: - return source + return source diff --git a/providers/amazon/tests/unit/amazon/aws/hooks/test_base_aws.py b/providers/amazon/tests/unit/amazon/aws/hooks/test_base_aws.py index 22f4ef325b188..3970a766ade64 100644 --- a/providers/amazon/tests/unit/amazon/aws/hooks/test_base_aws.py +++ b/providers/amazon/tests/unit/amazon/aws/hooks/test_base_aws.py @@ -134,8 +134,7 @@ def mock_conn(request): return conn if request.param == "wrapped": return AwsConnectionWrapper(conn=conn) - else: - raise ValueError("invalid internal test config") + raise ValueError("invalid internal test config") class TestSessionFactory: diff --git a/providers/apache/beam/src/airflow/providers/apache/beam/operators/beam.py b/providers/apache/beam/src/airflow/providers/apache/beam/operators/beam.py index 787f353b37ffc..adf76a0ab6610 100644 --- a/providers/apache/beam/src/airflow/providers/apache/beam/operators/beam.py +++ b/providers/apache/beam/src/airflow/providers/apache/beam/operators/beam.py @@ -399,15 +399,14 @@ def execute(self, context: Context): if self.is_dataflow and self.dataflow_hook: return self.execute_on_dataflow(context) - else: - self.beam_hook.start_python_pipeline( - variables=self.snake_case_pipeline_options, - py_file=self.py_file, - py_options=self.py_options, - py_interpreter=self.py_interpreter, - py_requirements=self.py_requirements, - py_system_site_packages=self.py_system_site_packages, - ) + self.beam_hook.start_python_pipeline( + variables=self.snake_case_pipeline_options, + py_file=self.py_file, + py_options=self.py_options, + py_interpreter=self.py_interpreter, + py_requirements=self.py_requirements, + py_system_site_packages=self.py_system_site_packages, + ) def execute_on_dataflow(self, context: Context): """Execute the Apache Beam Pipeline on Dataflow runner.""" @@ -560,12 +559,11 @@ def execute(self, context: Context): if self.is_dataflow and self.dataflow_hook: return self.execute_on_dataflow(context) - else: - self.beam_hook.start_java_pipeline( - variables=self.pipeline_options, - jar=self.jar, - job_class=self.job_class, - ) + self.beam_hook.start_java_pipeline( + variables=self.pipeline_options, + jar=self.jar, + job_class=self.job_class, + ) def execute_on_dataflow(self, context: Context): """Execute the Apache Beam Pipeline on Dataflow runner.""" @@ -768,12 +766,11 @@ def execute(self, context: Context): project_id=self.dataflow_config.project_id, ) return {"dataflow_job_id": self.dataflow_job_id} - else: - go_artifact.start_pipeline( - beam_hook=self.beam_hook, - variables=snake_case_pipeline_options, - process_line_callback=process_line_callback, - ) + go_artifact.start_pipeline( + beam_hook=self.beam_hook, + variables=snake_case_pipeline_options, + process_line_callback=process_line_callback, + ) def on_kill(self) -> None: if self.dataflow_hook and self.dataflow_job_id: diff --git 
a/providers/apache/cassandra/src/airflow/providers/apache/cassandra/hooks/cassandra.py b/providers/apache/cassandra/src/airflow/providers/apache/cassandra/hooks/cassandra.py index 0d2279b754ce3..2e95f36b25727 100644 --- a/providers/apache/cassandra/src/airflow/providers/apache/cassandra/hooks/cassandra.py +++ b/providers/apache/cassandra/src/airflow/providers/apache/cassandra/hooks/cassandra.py @@ -190,8 +190,7 @@ def table_exists(self, table: str) -> bool: def _sanitize_input(input_string: str) -> str: if re.match(r"^\w+$", input_string): return input_string - else: - raise ValueError(f"Invalid input: {input_string}") + raise ValueError(f"Invalid input: {input_string}") def record_exists(self, table: str, keys: dict[str, str]) -> bool: """ diff --git a/providers/apache/druid/src/airflow/providers/apache/druid/hooks/druid.py b/providers/apache/druid/src/airflow/providers/apache/druid/hooks/druid.py index 14899c94ed7a4..6c7b0a9cb90b4 100644 --- a/providers/apache/druid/src/airflow/providers/apache/druid/hooks/druid.py +++ b/providers/apache/druid/src/airflow/providers/apache/druid/hooks/druid.py @@ -114,8 +114,7 @@ def get_status_url(self, ingestion_type): status_endpoint = self.conn.extra_dejson.get("status_endpoint", self.status_endpoint) return f"{conn_type}://{self.conn.host}:{self.conn.port}/{status_endpoint}" - else: - return self.get_conn_url(ingestion_type) + return self.get_conn_url(ingestion_type) def get_auth(self) -> requests.auth.HTTPBasicAuth | None: """ @@ -127,8 +126,7 @@ def get_auth(self) -> requests.auth.HTTPBasicAuth | None: password = self.conn.password if user is not None and password is not None: return requests.auth.HTTPBasicAuth(user, password) - else: - return None + return None def get_verify(self) -> bool | str: ca_bundle_path: str | None = self.conn.extra_dejson.get("ca_bundle_path", None) diff --git a/providers/apache/flink/src/airflow/providers/apache/flink/sensors/flink_kubernetes.py b/providers/apache/flink/src/airflow/providers/apache/flink/sensors/flink_kubernetes.py index e54317f930781..e485e14a8307b 100644 --- a/providers/apache/flink/src/airflow/providers/apache/flink/sensors/flink_kubernetes.py +++ b/providers/apache/flink/src/airflow/providers/apache/flink/sensors/flink_kubernetes.py @@ -128,9 +128,8 @@ def poke(self, context: Context) -> bool: if application_state in self.FAILURE_STATES: message = f"Flink application failed with state: {application_state}" raise AirflowException(message) - elif application_state in self.SUCCESS_STATES: + if application_state in self.SUCCESS_STATES: self.log.info("Flink application ended successfully") return True - else: - self.log.info("Flink application is still in state: %s", application_state) - return False + self.log.info("Flink application is still in state: %s", application_state) + return False diff --git a/providers/apache/hdfs/src/airflow/providers/apache/hdfs/hooks/webhdfs.py b/providers/apache/hdfs/src/airflow/providers/apache/hdfs/hooks/webhdfs.py index 47ab97db4e8df..b710c72f2f5ed 100644 --- a/providers/apache/hdfs/src/airflow/providers/apache/hdfs/hooks/webhdfs.py +++ b/providers/apache/hdfs/src/airflow/providers/apache/hdfs/hooks/webhdfs.py @@ -96,8 +96,7 @@ def _find_valid_server(self) -> Any: self.log.info("Using namenode %s for hook", namenode) host_socket.close() return client - else: - self.log.warning("Could not connect to %s:%s", namenode, connection.port) + self.log.warning("Could not connect to %s:%s", namenode, connection.port) except HdfsError as hdfs_error: self.log.info("Read operation 
on namenode %s failed with error: %s", namenode, hdfs_error) return None diff --git a/providers/apache/hive/src/airflow/providers/apache/hive/hooks/hive.py b/providers/apache/hive/src/airflow/providers/apache/hive/hooks/hive.py index a944b9841a231..5c5b2ae2c669d 100644 --- a/providers/apache/hive/src/airflow/providers/apache/hive/hooks/hive.py +++ b/providers/apache/hive/src/airflow/providers/apache/hive/hooks/hive.py @@ -209,7 +209,7 @@ def _validate_beeline_parameters(self, conn): f"The schema used in beeline command ({conn.schema}) should not contain ';' character)" ) return - elif ":" in conn.host or "/" in conn.host or ";" in conn.host: + if ":" in conn.host or "/" in conn.host or ";" in conn.host: raise ValueError( f"The host used in beeline command ({conn.host}) should not contain ':/;' characters)" ) @@ -609,8 +609,7 @@ def _find_valid_host(self) -> Any: self.log.info("Connected to %s:%s", host, conn.port) host_socket.close() return host - else: - self.log.error("Could not connect to %s:%s", host, conn.port) + self.log.error("Could not connect to %s:%s", host, conn.port) return None def get_conn(self) -> Any: @@ -713,8 +712,7 @@ def get_partitions(self, schema: str, table_name: str, partition_filter: str | N pnames = [p.name for p in table.partitionKeys] return [dict(zip(pnames, p.values)) for p in parts] - else: - raise AirflowException("The table isn't partitioned") + raise AirflowException("The table isn't partitioned") @staticmethod def _get_max_partition_from_part_specs( diff --git a/providers/apache/hive/src/airflow/providers/apache/hive/macros/hive.py b/providers/apache/hive/src/airflow/providers/apache/hive/macros/hive.py index 1d5f6f6a5cead..3f71b81264951 100644 --- a/providers/apache/hive/src/airflow/providers/apache/hive/macros/hive.py +++ b/providers/apache/hive/src/airflow/providers/apache/hive/macros/hive.py @@ -75,8 +75,7 @@ def any_time(d): return min(date_list, key=any_time).date() if before_target: return min(date_list, key=time_before).date() - else: - return min(date_list, key=time_after).date() + return min(date_list, key=time_after).date() def closest_ds_partition( @@ -109,10 +108,9 @@ def closest_ds_partition( part_vals = [next(iter(p.values())) for p in partitions] if ds in part_vals: return ds - else: - parts = [datetime.datetime.strptime(pv, "%Y-%m-%d") for pv in part_vals] - target_dt = datetime.datetime.strptime(ds, "%Y-%m-%d") - closest_ds = _closest_date(target_dt, parts, before_target=before) - if closest_ds is not None: - return closest_ds.isoformat() + parts = [datetime.datetime.strptime(pv, "%Y-%m-%d") for pv in part_vals] + target_dt = datetime.datetime.strptime(ds, "%Y-%m-%d") + closest_ds = _closest_date(target_dt, parts, before_target=before) + if closest_ds is not None: + return closest_ds.isoformat() return None diff --git a/providers/apache/hive/src/airflow/providers/apache/hive/sensors/named_hive_partition.py b/providers/apache/hive/src/airflow/providers/apache/hive/sensors/named_hive_partition.py index c18388477de9f..cce2921d2e5c2 100644 --- a/providers/apache/hive/src/airflow/providers/apache/hive/sensors/named_hive_partition.py +++ b/providers/apache/hive/src/airflow/providers/apache/hive/sensors/named_hive_partition.py @@ -79,8 +79,7 @@ def parse_partition_name(partition: str) -> tuple[Any, ...]: second_split = table_partition.split("/", 1) if len(second_split) == 1: raise ValueError(f"Could not parse {partition}into table, partition") - else: - table, partition = second_split + table, partition = second_split return schema, table, 
partition def poke_partition(self, partition: str) -> Any: diff --git a/providers/apache/hive/src/airflow/providers/apache/hive/transfers/s3_to_hive.py b/providers/apache/hive/src/airflow/providers/apache/hive/transfers/s3_to_hive.py index 7eaf9dc1ac837..86862cec19f07 100644 --- a/providers/apache/hive/src/airflow/providers/apache/hive/transfers/s3_to_hive.py +++ b/providers/apache/hive/src/airflow/providers/apache/hive/transfers/s3_to_hive.py @@ -253,13 +253,12 @@ def _match_headers(self, header_list): test_field_match = all(h1.lower() == h2.lower() for h1, h2 in zip(header_list, field_names)) if test_field_match: return True - else: - self.log.warning( - "Headers do not match field names File headers:\n %s\nField names: \n %s\n", - header_list, - field_names, - ) - return False + self.log.warning( + "Headers do not match field names File headers:\n %s\nField names: \n %s\n", + header_list, + field_names, + ) + return False @staticmethod def _delete_top_row_and_compress(input_file_name, output_file_ext, dest_dir): diff --git a/providers/apache/kafka/tests/integration/apache/kafka/hooks/test_producer.py b/providers/apache/kafka/tests/integration/apache/kafka/hooks/test_producer.py index ad2351ef736b7..8efe517b09e27 100644 --- a/providers/apache/kafka/tests/integration/apache/kafka/hooks/test_producer.py +++ b/providers/apache/kafka/tests/integration/apache/kafka/hooks/test_producer.py @@ -53,10 +53,9 @@ def test_produce(self): def acked(err, msg): if err is not None: raise Exception(f"{err}") - else: - assert msg.topic() == topic - assert msg.partition() == 0 - assert msg.offset() == 0 + assert msg.topic() == topic + assert msg.partition() == 0 + assert msg.offset() == 0 # Standard Init p_hook = KafkaProducerHook(kafka_config_id="kafka_default") diff --git a/providers/apache/kafka/tests/system/apache/kafka/example_dag_event_listener.py b/providers/apache/kafka/tests/system/apache/kafka/example_dag_event_listener.py index 2cabdde66eb3b..a0726a5577c95 100644 --- a/providers/apache/kafka/tests/system/apache/kafka/example_dag_event_listener.py +++ b/providers/apache/kafka/tests/system/apache/kafka/example_dag_event_listener.py @@ -101,11 +101,10 @@ def await_function(message): def wait_for_event(message, **context): if message % 15 == 0: return f"encountered {message}!" 
- else: - if message % 3 == 0: - print(f"encountered {message} FIZZ !") - if message % 5 == 0: - print(f"encountered {message} BUZZ !") + if message % 3 == 0: + print(f"encountered {message} FIZZ !") + if message % 5 == 0: + print(f"encountered {message} BUZZ !") # [START howto_sensor_await_message_trigger_function] listen_for_message = AwaitMessageTriggerFunctionSensor( diff --git a/providers/apache/kylin/tests/unit/apache/kylin/hooks/test_kylin.py b/providers/apache/kylin/tests/unit/apache/kylin/hooks/test_kylin.py index fc0ff2e720647..c029d094af2a4 100644 --- a/providers/apache/kylin/tests/unit/apache/kylin/hooks/test_kylin.py +++ b/providers/apache/kylin/tests/unit/apache/kylin/hooks/test_kylin.py @@ -60,8 +60,7 @@ def invoke_command(self, command, **kwargs): ] if command in invoke_command_list: return {"code": "000", "data": {}} - else: - raise KylinCubeError(f"Unsupported invoke command for datasource: {command}") + raise KylinCubeError(f"Unsupported invoke command for datasource: {command}") cube_source.return_value = MockCubeSource() response_data = {"code": "000", "data": {}} diff --git a/providers/apache/livy/src/airflow/providers/apache/livy/hooks/livy.py b/providers/apache/livy/src/airflow/providers/apache/livy/hooks/livy.py index 9f3cd89fa257c..d1d25a9a2e3e5 100644 --- a/providers/apache/livy/src/airflow/providers/apache/livy/hooks/livy.py +++ b/providers/apache/livy/src/airflow/providers/apache/livy/hooks/livy.py @@ -695,9 +695,8 @@ async def dump_batch_logs(self, session_id: int | str) -> Any: for log_line in log_lines: self.log.info(log_line) return log_lines - else: - self.log.info(result["response"]) - return result["response"] + self.log.info(result["response"]) + return result["response"] @staticmethod def _validate_session_id(session_id: int | str) -> None: diff --git a/providers/apache/spark/src/airflow/providers/apache/spark/hooks/spark_submit.py b/providers/apache/spark/src/airflow/providers/apache/spark/hooks/spark_submit.py index df3d6a41a605b..7a0c0bbd2038e 100644 --- a/providers/apache/spark/src/airflow/providers/apache/spark/hooks/spark_submit.py +++ b/providers/apache/spark/src/airflow/providers/apache/spark/hooks/spark_submit.py @@ -563,10 +563,9 @@ def submit(self, application: str = "", **kwargs: Any) -> None: f"Cannot execute: {self._mask_cmd(spark_submit_cmd)}. Error code is: {returncode}. " f"Kubernetes spark exit code is: {self._spark_exit_code}" ) - else: - raise AirflowException( - f"Cannot execute: {self._mask_cmd(spark_submit_cmd)}. Error code is: {returncode}." - ) + raise AirflowException( + f"Cannot execute: {self._mask_cmd(spark_submit_cmd)}. Error code is: {returncode}." 
+ ) self.log.debug("Should track driver: %s", self._should_track_driver_status) diff --git a/providers/arangodb/src/airflow/providers/arangodb/hooks/arangodb.py b/providers/arangodb/src/airflow/providers/arangodb/hooks/arangodb.py index ca2eb6ebdb03c..edd5ad2d5b487 100644 --- a/providers/arangodb/src/airflow/providers/arangodb/hooks/arangodb.py +++ b/providers/arangodb/src/airflow/providers/arangodb/hooks/arangodb.py @@ -110,10 +110,9 @@ def query(self, query, **kwargs) -> Cursor: if not isinstance(result, Cursor): raise AirflowException("Failed to execute AQLQuery, expected result to be of type Cursor") return result - else: - raise AirflowException( - f"Failed to execute AQLQuery, error connecting to database: {self.database}" - ) + raise AirflowException( + f"Failed to execute AQLQuery, error connecting to database: {self.database}" + ) except AQLQueryExecuteError as error: raise AirflowException(f"Failed to execute AQLQuery, error: {error}") @@ -122,33 +121,29 @@ def create_collection(self, name): self.log.info("Collection '%s' does not exist. Creating a new collection.", name) self.db_conn.create_collection(name) return True - else: - self.log.info("Collection already exists: %s", name) - return False + self.log.info("Collection already exists: %s", name) + return False def delete_collection(self, name): if self.db_conn.has_collection(name): self.db_conn.delete_collection(name) return True - else: - self.log.info("Collection does not exist: %s", name) - return False + self.log.info("Collection does not exist: %s", name) + return False def create_database(self, name): if not self.db_conn.has_database(name): self.db_conn.create_database(name) return True - else: - self.log.info("Database already exists: %s", name) - return False + self.log.info("Database already exists: %s", name) + return False def create_graph(self, name): if not self.db_conn.has_graph(name): self.db_conn.create_graph(name) return True - else: - self.log.info("Graph already exists: %s", name) - return False + self.log.info("Graph already exists: %s", name) + return False def insert_documents(self, collection_name, documents): if not self.db_conn.has_collection(collection_name): diff --git a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/backcompat/backwards_compat_converters.py b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/backcompat/backwards_compat_converters.py index 97483b0865315..75e08cb240a36 100644 --- a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/backcompat/backwards_compat_converters.py +++ b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/backcompat/backwards_compat_converters.py @@ -27,20 +27,18 @@ def _convert_kube_model_object(obj, new_class): convert_op = getattr(obj, "to_k8s_client_obj", None) if callable(convert_op): return obj.to_k8s_client_obj() - elif isinstance(obj, new_class): + if isinstance(obj, new_class): return obj - else: - raise AirflowException(f"Expected {new_class}, got {type(obj)}") + raise AirflowException(f"Expected {new_class}, got {type(obj)}") def _convert_from_dict(obj, new_class): if isinstance(obj, new_class): return obj - elif isinstance(obj, dict): + if isinstance(obj, dict): api_client = ApiClient() return api_client._ApiClient__deserialize_model(obj, new_class) - else: - raise AirflowException(f"Expected dict or {new_class}, got {type(obj)}") + raise AirflowException(f"Expected dict or {new_class}, got {type(obj)}") def convert_volume(volume) -> k8s.V1Volume: @@ -111,8 +109,7 @@ def 
convert_image_pull_secrets(image_pull_secrets) -> list[k8s.V1LocalObjectRefe if isinstance(image_pull_secrets, str): secrets = image_pull_secrets.split(",") return [k8s.V1LocalObjectReference(name=secret) for secret in secrets] - else: - return image_pull_secrets + return image_pull_secrets def convert_configmap(configmaps) -> k8s.V1EnvFromSource: diff --git a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/executors/kubernetes_executor.py b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/executors/kubernetes_executor.py index 7a36666e9fab2..4173ea6a0d785 100644 --- a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/executors/kubernetes_executor.py +++ b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/executors/kubernetes_executor.py @@ -482,7 +482,7 @@ def get_task_log(self, ti: TaskInstance, try_number: int) -> tuple[list[str], li ).items if not pod_list: raise RuntimeError("Cannot find pod for ti %s", ti) - elif len(pod_list) > 1: + if len(pod_list) > 1: raise RuntimeError("Found multiple pods for ti %s: %s", ti, pod_list) res = client.read_namespaced_pod_log( name=pod_list[0].metadata.name, diff --git a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/executors/kubernetes_executor_utils.py b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/executors/kubernetes_executor_utils.py index 7ee6b0d946df9..a8d580d3950ff 100644 --- a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/executors/kubernetes_executor_utils.py +++ b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/executors/kubernetes_executor_utils.py @@ -110,8 +110,7 @@ def _pod_events(self, kube_client: client.CoreV1Api, query_kwargs: dict): try: if self.namespace == ALL_NAMESPACES: return watcher.stream(kube_client.list_pod_for_all_namespaces, **query_kwargs) - else: - return watcher.stream(kube_client.list_namespaced_pod, self.namespace, **query_kwargs) + return watcher.stream(kube_client.list_namespaced_pod, self.namespace, **query_kwargs) except ApiException as e: if str(e.status) == "410": # Resource version is too old if self.namespace == ALL_NAMESPACES: @@ -121,8 +120,7 @@ def _pod_events(self, kube_client: client.CoreV1Api, query_kwargs: dict): resource_version = pods.metadata.resource_version query_kwargs["resource_version"] = resource_version return self._pod_events(kube_client=kube_client, query_kwargs=query_kwargs) - else: - raise + raise def _run( self, @@ -564,5 +562,4 @@ def get_base_pod_from_template(pod_template_file: str | None, kube_config: Any) """ if pod_template_file: return PodGenerator.deserialize_model_file(pod_template_file) - else: - return PodGenerator.deserialize_model_file(kube_config.pod_template_file) + return PodGenerator.deserialize_model_file(kube_config.pod_template_file) diff --git a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/hooks/kubernetes.py b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/hooks/kubernetes.py index 4adef6ba20e69..34d8673d18a96 100644 --- a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/hooks/kubernetes.py +++ b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/hooks/kubernetes.py @@ -177,8 +177,7 @@ def get_connection(cls, conn_id: str) -> Connection: except AirflowNotFoundException: if conn_id == cls.default_conn_name: return Connection(conn_id=cls.default_conn_name) - else: - raise + raise @cached_property def conn_extras(self): @@ -691,9 +690,8 @@ def check_kueue_deployment_running( 
and replicas == ready_replicas ): return - else: - self.log.info("Waiting until Deployment will be ready...") - sleep(polling_period_seconds) + self.log.info("Waiting until Deployment will be ready...") + sleep(polling_period_seconds) _timeout -= polling_period_seconds @@ -713,10 +711,10 @@ def _get_bool(val) -> bool | None: """Convert val to bool if can be done with certainty; if we cannot infer intention we return None.""" if isinstance(val, bool): return val - elif isinstance(val, str): + if isinstance(val, str): if val.strip().lower() == "true": return True - elif val.strip().lower() == "false": + if val.strip().lower() == "false": return False return None diff --git a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/kubernetes_helper_functions.py b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/kubernetes_helper_functions.py index a353b7a5c8026..f0093d4671012 100644 --- a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/kubernetes_helper_functions.py +++ b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/kubernetes_helper_functions.py @@ -89,8 +89,7 @@ def create_unique_id( base_name = slugify(name, lowercase=True)[:max_length].strip(".-") if unique: return add_unique_suffix(name=base_name, rand_len=8, max_len=max_length) - else: - return base_name + return base_name def annotations_to_key(annotations: dict[str, str]) -> TaskInstanceKey: diff --git a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/operators/job.py b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/operators/job.py index b6f06dec04276..31ebbb752fc46 100644 --- a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/operators/job.py +++ b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/operators/job.py @@ -389,7 +389,7 @@ def reconcile_job_specs( return base_spec if not base_spec and client_spec: return client_spec - elif client_spec and base_spec: + if client_spec and base_spec: client_spec.template.spec = PodGenerator.reconcile_specs( base_spec.template.spec, client_spec.template.spec ) diff --git a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/operators/kueue.py b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/operators/kueue.py index b1c28ffcbca9a..dba42c7d2396b 100644 --- a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/operators/kueue.py +++ b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/operators/kueue.py @@ -101,7 +101,7 @@ def __init__(self, queue_name: str, *args, **kwargs) -> None: "The `suspend` parameter can't be False. If you want to use Kueue for running Job" " in a Kubernetes cluster, set the `suspend` parameter to True.", ) - elif self.suspend is None: + if self.suspend is None: self.log.info( "You have not set parameter `suspend` in class %s. 
" "For running a Job in Kueue the `suspend` parameter has been set to True.", diff --git a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/operators/pod.py b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/operators/pod.py index dd50b359be3de..428c63d41d5bb 100644 --- a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/operators/pod.py +++ b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/operators/pod.py @@ -857,7 +857,7 @@ def trigger_reentry(self, context: Context, event: dict[str, Any]) -> Any: message = event.get("stack_trace", event["message"]) raise AirflowException(message) - elif event["status"] == "running": + if event["status"] == "running": if self.get_logs: self.log.info("Resuming logs read from time %r", last_log_time) @@ -1297,7 +1297,7 @@ def __exit__(self, exctype, excinst, exctb) -> bool: matching_error = error and issubclass(exctype, self._exceptions) if (error and not matching_error) or (matching_error and self.reraise): return False - elif matching_error: + if matching_error: self.exception = excinst logger = logging.getLogger(__name__) logger.exception(excinst) diff --git a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/operators/resource.py b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/operators/resource.py index aef972faf26e8..cfedf84e7092a 100644 --- a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/operators/resource.py +++ b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/operators/resource.py @@ -99,8 +99,7 @@ def hook(self) -> KubernetesHook: def get_namespace(self) -> str: if self._namespace: return self._namespace - else: - return self.hook.get_namespace() or "default" + return self.hook.get_namespace() or "default" def get_crd_fields(self, body: dict) -> tuple[str, str, str, str]: api_version = body["apiVersion"] diff --git a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/operators/spark_kubernetes.py b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/operators/spark_kubernetes.py index 75190d6bf25d0..5d3ff3b51a34e 100644 --- a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/operators/spark_kubernetes.py +++ b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/operators/spark_kubernetes.py @@ -245,7 +245,7 @@ def find_spark_job(self, context, exclude_checked: bool = True): pod = None if len(pod_list) > 1: # and self.reattach_on_restart: raise AirflowException(f"More than one pod running with labels: {label_selector}") - elif len(pod_list) == 1: + if len(pod_list) == 1: pod = pod_list[0] self.log.info( "Found matching driver pod %s with labels %s", pod.metadata.name, pod.metadata.labels diff --git a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/pod_generator.py b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/pod_generator.py index 1e073ccc18689..e91353631ae0b 100644 --- a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/pod_generator.py +++ b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/pod_generator.py @@ -162,10 +162,9 @@ def from_obj(obj) -> dict | k8s.V1Pod | None: if isinstance(k8s_object, k8s.V1Pod): return k8s_object - else: - raise TypeError( - "Cannot convert a non-kubernetes.client.models.V1Pod object into a KubernetesExecutorConfig" - ) + raise TypeError( + "Cannot convert a non-kubernetes.client.models.V1Pod object into a KubernetesExecutorConfig" + ) @staticmethod def 
reconcile_pods(base_pod: k8s.V1Pod, client_pod: k8s.V1Pod | None) -> k8s.V1Pod: @@ -203,7 +202,7 @@ def reconcile_metadata(base_meta, client_meta): return base_meta if not base_meta and client_meta: return client_meta - elif client_meta and base_meta: + if client_meta and base_meta: client_meta.labels = merge_objects(base_meta.labels, client_meta.labels) client_meta.annotations = merge_objects(base_meta.annotations, client_meta.annotations) extend_object_field(base_meta, client_meta, "managed_fields") @@ -229,7 +228,7 @@ def reconcile_specs( return base_spec if not base_spec and client_spec: return client_spec - elif client_spec and base_spec: + if client_spec and base_spec: client_spec.containers = PodGenerator.reconcile_containers( base_spec.containers, client_spec.containers ) diff --git a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/resource_convert/env_variable.py b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/resource_convert/env_variable.py index db8c5301cb051..08ce0193476c9 100644 --- a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/resource_convert/env_variable.py +++ b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/resource_convert/env_variable.py @@ -33,7 +33,7 @@ def convert_env_vars(env_vars) -> list[k8s.V1EnvVar]: for k, v in env_vars.items(): res.append(k8s.V1EnvVar(name=k, value=v)) return res - elif isinstance(env_vars, list): + if isinstance(env_vars, list): if all([isinstance(e, k8s.V1EnvVar) for e in env_vars]): return env_vars raise AirflowException(f"Expected dict or list of V1EnvVar, got {type(env_vars)}") diff --git a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/sensors/spark_kubernetes.py b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/sensors/spark_kubernetes.py index 08968d245fde7..5e3ff6964ddca 100644 --- a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/sensors/spark_kubernetes.py +++ b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/sensors/spark_kubernetes.py @@ -128,9 +128,8 @@ def poke(self, context: Context) -> bool: if application_state in self.FAILURE_STATES: message = f"Spark application failed with state: {application_state}" raise AirflowException(message) - elif application_state in self.SUCCESS_STATES: + if application_state in self.SUCCESS_STATES: self.log.info("Spark application ended successfully") return True - else: - self.log.info("Spark application is still in state: %s", application_state) - return False + self.log.info("Spark application is still in state: %s", application_state) + return False diff --git a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/triggers/pod.py b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/triggers/pod.py index 52a12414f2c67..6ec0c99932b22 100644 --- a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/triggers/pod.py +++ b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/triggers/pod.py @@ -236,7 +236,7 @@ async def _wait_for_container_completion(self) -> TriggerEvent: "last_log_time": self.last_log_time, } ) - elif container_state == ContainerState.FAILED: + if container_state == ContainerState.FAILED: return TriggerEvent( { "status": "failed", @@ -289,8 +289,7 @@ def define_container_state(self, pod: V1Pod) -> ContainerState: if state_obj is not None: if state != ContainerState.TERMINATED: return state - else: - return ContainerState.TERMINATED if state_obj.exit_code == 0 else ContainerState.FAILED + return 
ContainerState.TERMINATED if state_obj.exit_code == 0 else ContainerState.FAILED return ContainerState.UNDEFINED @staticmethod diff --git a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/utils/pod_manager.py b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/utils/pod_manager.py index 3bb9609c782d6..390d01e3c60bb 100644 --- a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/utils/pod_manager.py +++ b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/utils/pod_manager.py @@ -520,15 +520,14 @@ def consume_logs(*, since_time: DateTime | None = None) -> tuple[DateTime | None return PodLoggingStatus(running=False, last_log_time=last_log_time) if not follow: return PodLoggingStatus(running=True, last_log_time=last_log_time) - else: - # a timeout is a normal thing and we ignore it and resume following logs - if not isinstance(exc, TimeoutError): - self.log.warning( - "Pod %s log read interrupted but container %s still running. Logs generated in the last one second might get duplicated.", - pod.metadata.name, - container_name, - ) - time.sleep(1) + # a timeout is a normal thing and we ignore it and resume following logs + if not isinstance(exc, TimeoutError): + self.log.warning( + "Pod %s log read interrupted but container %s still running. Logs generated in the last one second might get duplicated.", + pod.metadata.name, + container_name, + ) + time.sleep(1) def _reconcile_requested_log_containers( self, requested: Iterable[str] | str | bool | None, actual: list[str], pod_name diff --git a/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/conftest.py b/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/conftest.py index 38a94c77ea777..7e99e6fa2d60c 100644 --- a/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/conftest.py +++ b/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/conftest.py @@ -30,7 +30,7 @@ def data_file(): if not DATA_FILE_DIRECTORY.exists(): msg = f"Data Directory {DATA_FILE_DIRECTORY.as_posix()!r} does not exist." raise FileNotFoundError(msg) - elif not DATA_FILE_DIRECTORY.is_dir(): + if not DATA_FILE_DIRECTORY.is_dir(): msg = f"Data Directory {DATA_FILE_DIRECTORY.as_posix()!r} expected to be a directory." raise NotADirectoryError(msg) diff --git a/providers/common/io/src/airflow/providers/common/io/xcom/backend.py b/providers/common/io/src/airflow/providers/common/io/xcom/backend.py index c5144222e5cb3..89a485ff0d50c 100644 --- a/providers/common/io/src/airflow/providers/common/io/xcom/backend.py +++ b/providers/common/io/src/airflow/providers/common/io/xcom/backend.py @@ -131,10 +131,9 @@ def serialize_value( # type: ignore[override] if threshold < 0 or len(s_val_encoded) < threshold: # Either no threshold or value is small enough. if AIRFLOW_V_3_0_PLUS: return BaseXCom.serialize_value(value) - else: - # TODO: Remove this branch once we drop support for Airflow 2 - # This is for Airflow 2.10 where the value is expected to be bytes - return s_val_encoded + # TODO: Remove this branch once we drop support for Airflow 2 + # This is for Airflow 2.10 where the value is expected to be bytes + return s_val_encoded base_path = _get_base_path() while True: # Safeguard against collisions. 
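Editorial aside, not part of the patch: every hunk in this file set applies the same mechanical rewrite. When an `if` branch ends in `return` or `raise`, the trailing `elif`/`else` is superfluous and its body can be dedented into a guard-clause chain without changing behavior; this is the kind of cleanup flagged by superfluous-else lint rules (e.g. Ruff's RET505/RET506), though the patch itself does not say which tool drove it. A minimal before/after sketch with a hypothetical function name:

    # Hypothetical example of the refactor applied throughout this patch.
    # Before: elif/else chained after branches that already return or raise.
    def describe_value_old(value):
        if value is None:
            raise ValueError("value must not be None")
        elif isinstance(value, str):
            return f"string of length {len(value)}"
        else:
            return f"other type: {type(value).__name__}"

    # After: each terminating branch becomes a guard clause; the final return
    # is the natural fall-through, so the elif/else keywords disappear.
    def describe_value_new(value):
        if value is None:
            raise ValueError("value must not be None")
        if isinstance(value, str):
            return f"string of length {len(value)}"
        return f"other type: {type(value).__name__}"

    # Both variants behave identically for every input.
    assert describe_value_old("abc") == describe_value_new("abc")
    assert describe_value_old(42) == describe_value_new(42)

The dedent is only safe because every preceding branch terminates; where a branch merely logs and falls through (as in the `_pod_events` and `consume_logs` hunks), the terminating `return` or `raise` still comes first and the dedented code remains the fall-through path.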
diff --git a/providers/common/sql/src/airflow/providers/common/sql/hooks/handlers.py b/providers/common/sql/src/airflow/providers/common/sql/hooks/handlers.py index e1b5fa33b91e7..6589685cb0b2c 100644 --- a/providers/common/sql/src/airflow/providers/common/sql/hooks/handlers.py +++ b/providers/common/sql/src/airflow/providers/common/sql/hooks/handlers.py @@ -58,8 +58,7 @@ def fetch_all_handler(cursor) -> list[tuple] | None: ) if cursor.description is not None: return cursor.fetchall() - else: - return None + return None def fetch_one_handler(cursor) -> list[tuple] | None: @@ -71,5 +70,4 @@ def fetch_one_handler(cursor) -> list[tuple] | None: ) if cursor.description is not None: return cursor.fetchone() - else: - return None + return None diff --git a/providers/common/sql/src/airflow/providers/common/sql/hooks/sql.py b/providers/common/sql/src/airflow/providers/common/sql/hooks/sql.py index 8914aa97b4dc7..2804038e87b42 100644 --- a/providers/common/sql/src/airflow/providers/common/sql/hooks/sql.py +++ b/providers/common/sql/src/airflow/providers/common/sql/hooks/sql.py @@ -166,7 +166,7 @@ def __init__(self, *args, schema: str | None = None, log_sql: bool = True, **kwa super().__init__() if not self.conn_name_attr: raise AirflowException("conn_name_attr is not defined") - elif len(args) == 1: + if len(args) == 1: setattr(self, self.conn_name_attr, args[0]) elif self.conn_name_attr not in kwargs: setattr(self, self.conn_name_attr, self.default_conn_name) @@ -599,8 +599,7 @@ def run( if handlers.return_single_query_results(sql, return_last, split_statements): self.descriptions = [_last_description] return _last_result - else: - return results + return results def _make_common_data_structure(self, result: T | Sequence[T]) -> tuple | list[tuple]: """ diff --git a/providers/common/sql/src/airflow/providers/common/sql/operators/sql.py b/providers/common/sql/src/airflow/providers/common/sql/operators/sql.py index ed1e30674c36f..3714abc82503f 100644 --- a/providers/common/sql/src/airflow/providers/common/sql/operators/sql.py +++ b/providers/common/sql/src/airflow/providers/common/sql/operators/sql.py @@ -534,12 +534,11 @@ def _generate_sql_query(self, column, checks): def _generate_partition_clause(check): if self.partition_clause and "partition_clause" not in checks[check]: return f"WHERE {self.partition_clause}" - elif not self.partition_clause and "partition_clause" in checks[check]: + if not self.partition_clause and "partition_clause" in checks[check]: return f"WHERE {checks[check]['partition_clause']}" - elif self.partition_clause and "partition_clause" in checks[check]: + if self.partition_clause and "partition_clause" in checks[check]: return f"WHERE {self.partition_clause} AND {checks[check]['partition_clause']}" - else: - return "" + return "" checks_sql = "UNION ALL".join( self.sql_check_template.format( @@ -742,12 +741,11 @@ def _generate_sql_query(self): def _generate_partition_clause(check_name): if self.partition_clause and "partition_clause" not in self.checks[check_name]: return f"WHERE {self.partition_clause}" - elif not self.partition_clause and "partition_clause" in self.checks[check_name]: + if not self.partition_clause and "partition_clause" in self.checks[check_name]: return f"WHERE {self.checks[check_name]['partition_clause']}" - elif self.partition_clause and "partition_clause" in self.checks[check_name]: + if self.partition_clause and "partition_clause" in self.checks[check_name]: return f"WHERE {self.partition_clause} AND {self.checks[check_name]['partition_clause']}" - 
else: - return "" + return "" return "UNION ALL".join( self.sql_check_template.format( diff --git a/providers/common/sql/src/airflow/providers/common/sql/sensors/sql.py b/providers/common/sql/src/airflow/providers/common/sql/sensors/sql.py index 322a970c0b848..9d4ff283e15fa 100644 --- a/providers/common/sql/src/airflow/providers/common/sql/sensors/sql.py +++ b/providers/common/sql/src/airflow/providers/common/sql/sensors/sql.py @@ -105,8 +105,7 @@ def poke(self, context: Context) -> bool: if self.fail_on_empty: message = "No rows returned, raising as per fail_on_empty flag" raise AirflowException(message) - else: - return False + return False condition = self.selector(records[0]) if self.failure is not None: @@ -121,7 +120,6 @@ def poke(self, context: Context) -> bool: if self.success is not None: if callable(self.success): return self.success(condition) - else: - message = f"self.success is present, but not callable -> {self.success}" - raise AirflowException(message) + message = f"self.success is present, but not callable -> {self.success}" + raise AirflowException(message) return bool(condition) diff --git a/providers/databricks/src/airflow/providers/databricks/hooks/databricks.py b/providers/databricks/src/airflow/providers/databricks/hooks/databricks.py index f9363aabd7f52..56bf1da68fd11 100644 --- a/providers/databricks/src/airflow/providers/databricks/hooks/databricks.py +++ b/providers/databricks/src/airflow/providers/databricks/hooks/databricks.py @@ -384,8 +384,7 @@ def find_job_id_by_name(self, job_name: str) -> int | None: if not matching_jobs: return None - else: - return matching_jobs[0]["job_id"] + return matching_jobs[0]["job_id"] def list_pipelines( self, batch_size: int = 25, pipeline_name: str | None = None, notebook_path: str | None = None @@ -445,8 +444,7 @@ def find_pipeline_id_by_name(self, pipeline_name: str) -> str | None: if not pipeline_name or len(matching_pipelines) == 0: return None - else: - return matching_pipelines[0]["pipeline_id"] + return matching_pipelines[0]["pipeline_id"] def get_run_page_url(self, run_id: int) -> str: """ @@ -640,8 +638,7 @@ def get_latest_repair_id(self, run_id: int) -> int | None: repair_history = response["repair_history"] if len(repair_history) == 1: return None - else: - return repair_history[-1]["id"] + return repair_history[-1]["id"] def get_cluster_state(self, cluster_id: str) -> ClusterState: """ diff --git a/providers/databricks/src/airflow/providers/databricks/hooks/databricks_base.py b/providers/databricks/src/airflow/providers/databricks/hooks/databricks_base.py index f804f64a32608..b2e08ecaad989 100644 --- a/providers/databricks/src/airflow/providers/databricks/hooks/databricks_base.py +++ b/providers/databricks/src/airflow/providers/databricks/hooks/databricks_base.py @@ -197,9 +197,8 @@ def _parse_host(host: str) -> str: if urlparse_host: # In this case, host = https://xx.cloud.databricks.com return urlparse_host - else: - # In this case, host = xx.cloud.databricks.com - return host + # In this case, host = xx.cloud.databricks.com + return host def _get_retry_object(self) -> Retrying: """ @@ -555,27 +554,27 @@ def _get_token(self, raise_error: bool = False) -> str | None: "Using token auth. 
For security reasons, please set token in Password field instead of extra" ) return self.databricks_conn.extra_dejson["token"] - elif not self.databricks_conn.login and self.databricks_conn.password: + if not self.databricks_conn.login and self.databricks_conn.password: self.log.debug("Using token auth.") return self.databricks_conn.password - elif "azure_tenant_id" in self.databricks_conn.extra_dejson: + if "azure_tenant_id" in self.databricks_conn.extra_dejson: if self.databricks_conn.login == "" or self.databricks_conn.password == "": raise AirflowException("Azure SPN credentials aren't provided") self.log.debug("Using AAD Token for SPN.") return self._get_aad_token(DEFAULT_DATABRICKS_SCOPE) - elif self.databricks_conn.extra_dejson.get("use_azure_managed_identity", False): + if self.databricks_conn.extra_dejson.get("use_azure_managed_identity", False): self.log.debug("Using AAD Token for managed identity.") self._check_azure_metadata_service() return self._get_aad_token(DEFAULT_DATABRICKS_SCOPE) - elif self.databricks_conn.extra_dejson.get(DEFAULT_AZURE_CREDENTIAL_SETTING_KEY, False): + if self.databricks_conn.extra_dejson.get(DEFAULT_AZURE_CREDENTIAL_SETTING_KEY, False): self.log.debug("Using default Azure Credential authentication.") return self._get_aad_token_for_default_az_credential(DEFAULT_DATABRICKS_SCOPE) - elif self.databricks_conn.extra_dejson.get("service_principal_oauth", False): + if self.databricks_conn.extra_dejson.get("service_principal_oauth", False): if self.databricks_conn.login == "" or self.databricks_conn.password == "": raise AirflowException("Service Principal credentials aren't provided") self.log.debug("Using Service Principal Token.") return self._get_sp_token(OIDC_TOKEN_SERVICE_URL.format(self.databricks_conn.host)) - elif raise_error: + if raise_error: raise AirflowException("Token authentication isn't configured") return None @@ -586,28 +585,28 @@ async def _a_get_token(self, raise_error: bool = False) -> str | None: "Using token auth. 
For security reasons, please set token in Password field instead of extra" ) return self.databricks_conn.extra_dejson["token"] - elif not self.databricks_conn.login and self.databricks_conn.password: + if not self.databricks_conn.login and self.databricks_conn.password: self.log.debug("Using token auth.") return self.databricks_conn.password - elif "azure_tenant_id" in self.databricks_conn.extra_dejson: + if "azure_tenant_id" in self.databricks_conn.extra_dejson: if self.databricks_conn.login == "" or self.databricks_conn.password == "": raise AirflowException("Azure SPN credentials aren't provided") self.log.debug("Using AAD Token for SPN.") return await self._a_get_aad_token(DEFAULT_DATABRICKS_SCOPE) - elif self.databricks_conn.extra_dejson.get("use_azure_managed_identity", False): + if self.databricks_conn.extra_dejson.get("use_azure_managed_identity", False): self.log.debug("Using AAD Token for managed identity.") await self._a_check_azure_metadata_service() return await self._a_get_aad_token(DEFAULT_DATABRICKS_SCOPE) - elif self.databricks_conn.extra_dejson.get(DEFAULT_AZURE_CREDENTIAL_SETTING_KEY, False): + if self.databricks_conn.extra_dejson.get(DEFAULT_AZURE_CREDENTIAL_SETTING_KEY, False): self.log.debug("Using AzureDefaultCredential for authentication.") return await self._a_get_aad_token_for_default_az_credential(DEFAULT_DATABRICKS_SCOPE) - elif self.databricks_conn.extra_dejson.get("service_principal_oauth", False): + if self.databricks_conn.extra_dejson.get("service_principal_oauth", False): if self.databricks_conn.login == "" or self.databricks_conn.password == "": raise AirflowException("Service Principal credentials aren't provided") self.log.debug("Using Service Principal Token.") return await self._a_get_sp_token(OIDC_TOKEN_SERVICE_URL.format(self.databricks_conn.host)) - elif raise_error: + if raise_error: raise AirflowException("Token authentication isn't configured") return None diff --git a/providers/databricks/src/airflow/providers/databricks/hooks/databricks_sql.py b/providers/databricks/src/airflow/providers/databricks/hooks/databricks_sql.py index d7d9423aa2965..6f32461291a74 100644 --- a/providers/databricks/src/airflow/providers/databricks/hooks/databricks_sql.py +++ b/providers/databricks/src/airflow/providers/databricks/hooks/databricks_sql.py @@ -283,8 +283,7 @@ def run( return None if return_single_query_results(sql, return_last, split_statements): return results[-1] - else: - return results + return results def _make_common_data_structure(self, result: T | Sequence[T]) -> tuple[Any, ...] | list[tuple[Any, ...]]: """Transform the databricks Row objects into namedtuple.""" @@ -297,12 +296,11 @@ def _make_common_data_structure(self, result: T | Sequence[T]) -> tuple[Any, ... 
rows_fields = tuple(rows[0].__fields__) rows_object = namedtuple("Row", rows_fields, rename=True) # type: ignore return cast("list[tuple[Any, ...]]", [rows_object(*row) for row in rows]) - elif isinstance(result, Row): + if isinstance(result, Row): row_fields = tuple(result.__fields__) row_object = namedtuple("Row", row_fields, rename=True) # type: ignore return cast("tuple[Any, ...]", row_object(*result)) - else: - raise TypeError(f"Expected Sequence[Row] or Row, but got {type(result)}") + raise TypeError(f"Expected Sequence[Row] or Row, but got {type(result)}") def bulk_dump(self, table, tmp_file): raise NotImplementedError() diff --git a/providers/databricks/src/airflow/providers/databricks/operators/databricks.py b/providers/databricks/src/airflow/providers/databricks/operators/databricks.py index 8921dff770292..abafd217009bb 100644 --- a/providers/databricks/src/airflow/providers/databricks/operators/databricks.py +++ b/providers/databricks/src/airflow/providers/databricks/operators/databricks.py @@ -1277,15 +1277,14 @@ def _generate_databricks_task_key( task_key = f"{self.dag_id}__{task_id}".encode() _databricks_task_key = hashlib.md5(task_key).hexdigest() return _databricks_task_key - else: - if not self._databricks_task_key or len(self._databricks_task_key) > 100: - self.log.info( - "databricks_task_key has not be provided or the provided one exceeds 100 characters and will be truncated by the Databricks API. This will cause failure when trying to monitor the task. A task_key will be generated using the hash value of dag_id+task_id" - ) - task_key = f"{self.dag_id}__{self.task_id}".encode() - self._databricks_task_key = hashlib.md5(task_key).hexdigest() - self.log.info("Generated databricks task_key: %s", self._databricks_task_key) - return self._databricks_task_key + if not self._databricks_task_key or len(self._databricks_task_key) > 100: + self.log.info( + "databricks_task_key has not be provided or the provided one exceeds 100 characters and will be truncated by the Databricks API. This will cause failure when trying to monitor the task. A task_key will be generated using the hash value of dag_id+task_id" + ) + task_key = f"{self.dag_id}__{self.task_id}".encode() + self._databricks_task_key = hashlib.md5(task_key).hexdigest() + self.log.info("Generated databricks task_key: %s", self._databricks_task_key) + return self._databricks_task_key @property def _databricks_workflow_task_group(self) -> DatabricksWorkflowTaskGroup | None: diff --git a/providers/databricks/src/airflow/providers/databricks/sensors/databricks_partition.py b/providers/databricks/src/airflow/providers/databricks/sensors/databricks_partition.py index 215aaf07022e5..bee664753ce39 100644 --- a/providers/databricks/src/airflow/providers/databricks/sensors/databricks_partition.py +++ b/providers/databricks/src/airflow/providers/databricks/sensors/databricks_partition.py @@ -226,6 +226,5 @@ def poke(self, context: Context) -> bool: self.log.debug("Partition sensor result: %s", partition_result) if partition_result: return True - else: - message = f"Specified partition(s): {self.partitions} were not found." - raise AirflowException(message) + message = f"Specified partition(s): {self.partitions} were not found." 
+ raise AirflowException(message) diff --git a/providers/databricks/src/airflow/providers/databricks/utils/databricks.py b/providers/databricks/src/airflow/providers/databricks/utils/databricks.py index 8b0fcad654f75..e19bccf290a36 100644 --- a/providers/databricks/src/airflow/providers/databricks/utils/databricks.py +++ b/providers/databricks/src/airflow/providers/databricks/utils/databricks.py @@ -36,17 +36,16 @@ def normalise_json_content(content, json_path: str = "json") -> str | bool | lis normalise = normalise_json_content if isinstance(content, (str, bool)): return content - elif isinstance(content, (int, float)): + if isinstance(content, (int, float)): # Databricks can tolerate either numeric or string types in the API backend. return str(content) - elif isinstance(content, (list, tuple)): + if isinstance(content, (list, tuple)): return [normalise(e, f"{json_path}[{i}]") for i, e in enumerate(content)] - elif isinstance(content, dict): + if isinstance(content, dict): return {k: normalise(v, f"{json_path}[{k}]") for k, v in content.items()} - else: - param_type = type(content) - msg = f"Type {param_type} used for parameter {json_path} is not a number or a string" - raise AirflowException(msg) + param_type = type(content) + msg = f"Type {param_type} used for parameter {json_path} is not a number or a string" + raise AirflowException(msg) def validate_trigger_event(event: dict): diff --git a/providers/databricks/tests/unit/databricks/hooks/test_databricks.py b/providers/databricks/tests/unit/databricks/hooks/test_databricks.py index d0c437ad02940..d3a85942ca8fb 100644 --- a/providers/databricks/tests/unit/databricks/hooks/test_databricks.py +++ b/providers/databricks/tests/unit/databricks/hooks/test_databricks.py @@ -293,11 +293,10 @@ def create_successful_response_mock(content): def create_post_side_effect(exception, status_code=500): if exception != requests_exceptions.HTTPError: return exception() - else: - response = mock.MagicMock() - response.status_code = status_code - response.raise_for_status.side_effect = exception(response=response) - return response + response = mock.MagicMock() + response.status_code = status_code + response.raise_for_status.side_effect = exception(response=response) + return response def setup_mock_requests(mock_requests, exception, status_code=500, error_count=None, response_content=None): diff --git a/providers/dbt/cloud/src/airflow/providers/dbt/cloud/operators/dbt.py b/providers/dbt/cloud/src/airflow/providers/dbt/cloud/operators/dbt.py index 68e06133fd627..1a77ad1d0ca90 100644 --- a/providers/dbt/cloud/src/airflow/providers/dbt/cloud/operators/dbt.py +++ b/providers/dbt/cloud/src/airflow/providers/dbt/cloud/operators/dbt.py @@ -216,30 +216,29 @@ def execute(self, context: Context): raise DbtCloudJobRunException(f"Job run {self.run_id} has failed or has been cancelled.") return self.run_id - else: - end_time = time.time() + self.timeout - job_run_info = JobRunInfo(account_id=self.account_id, run_id=self.run_id) - job_run_status = self.hook.get_job_run_status(**job_run_info) - if not DbtCloudJobRunStatus.is_terminal(job_run_status): - self.defer( - timeout=self.execution_timeout, - trigger=DbtCloudRunJobTrigger( - conn_id=self.dbt_cloud_conn_id, - run_id=self.run_id, - end_time=end_time, - account_id=self.account_id, - poll_interval=self.check_interval, - ), - method_name="execute_complete", - ) - elif job_run_status == DbtCloudJobRunStatus.SUCCESS.value: - self.log.info("Job run %s has completed successfully.", self.run_id) - return self.run_id - 
elif job_run_status in ( - DbtCloudJobRunStatus.CANCELLED.value, - DbtCloudJobRunStatus.ERROR.value, - ): - raise DbtCloudJobRunException(f"Job run {self.run_id} has failed or has been cancelled.") + end_time = time.time() + self.timeout + job_run_info = JobRunInfo(account_id=self.account_id, run_id=self.run_id) + job_run_status = self.hook.get_job_run_status(**job_run_info) + if not DbtCloudJobRunStatus.is_terminal(job_run_status): + self.defer( + timeout=self.execution_timeout, + trigger=DbtCloudRunJobTrigger( + conn_id=self.dbt_cloud_conn_id, + run_id=self.run_id, + end_time=end_time, + account_id=self.account_id, + poll_interval=self.check_interval, + ), + method_name="execute_complete", + ) + elif job_run_status == DbtCloudJobRunStatus.SUCCESS.value: + self.log.info("Job run %s has completed successfully.", self.run_id) + return self.run_id + elif job_run_status in ( + DbtCloudJobRunStatus.CANCELLED.value, + DbtCloudJobRunStatus.ERROR.value, + ): + raise DbtCloudJobRunException(f"Job run {self.run_id} has failed or has been cancelled.") else: if self.deferrable is True: warnings.warn( @@ -255,7 +254,7 @@ def execute_complete(self, context: Context, event: dict[str, Any]) -> int: self.run_id = event["run_id"] if event["status"] == "cancelled": raise DbtCloudJobRunException(f"Job run {self.run_id} has been cancelled.") - elif event["status"] == "error": + if event["status"] == "error": raise DbtCloudJobRunException(f"Job run {self.run_id} has failed.") self.log.info(event["message"]) return int(event["run_id"]) diff --git a/providers/docker/src/airflow/providers/docker/operators/docker.py b/providers/docker/src/airflow/providers/docker/operators/docker.py index a314282adcd39..ba2a7adef9151 100644 --- a/providers/docker/src/airflow/providers/docker/operators/docker.py +++ b/providers/docker/src/airflow/providers/docker/operators/docker.py @@ -60,8 +60,7 @@ def stringify(line: str | bytes): decode_method = getattr(line, "decode", None) if decode_method: return decode_method(encoding="utf-8", errors="surrogateescape") - else: - return line + return line def fetch_logs(log_stream, log: Logger): @@ -424,19 +423,18 @@ def _run_image_with_mounts(self, target_mounts, add_tmp_variable: bool) -> list[ raise DockerContainerFailedSkipException( f"Docker container returned exit code {self.skip_on_exit_code}. 
Skipping.", logs=log_lines ) - elif result["StatusCode"] != 0: + if result["StatusCode"] != 0: raise DockerContainerFailedException(f"Docker container failed: {result!r}", logs=log_lines) if self.retrieve_output: return self._attempt_to_retrieve_result() - elif self.do_xcom_push: + if self.do_xcom_push: if not log_lines: return None try: if self.xcom_all: return log_lines - else: - return log_lines[-1] + return log_lines[-1] except StopIteration: # handle the case when there is not a single line to iterate on return None diff --git a/providers/docker/src/airflow/providers/docker/operators/docker_swarm.py b/providers/docker/src/airflow/providers/docker/operators/docker_swarm.py index 2c403802f2f40..8c07df0bf2826 100644 --- a/providers/docker/src/airflow/providers/docker/operators/docker_swarm.py +++ b/providers/docker/src/airflow/providers/docker/operators/docker_swarm.py @@ -230,7 +230,7 @@ def _run_service(self) -> None: if self.auto_remove == "force": self.cli.remove_service(self.service["ID"]) raise AirflowException(f"Service did not complete: {self.service!r}") - elif self.auto_remove in ["success", "force"]: + if self.auto_remove in ["success", "force"]: if not self.service: raise RuntimeError("The 'service' should be initialized before!") self.cli.remove_service(self.service["ID"]) @@ -296,8 +296,7 @@ def _attempt_to_retrieve_results(self): file_contents.append(file_content) if len(file_contents) == 1: return file_contents[0] - else: - return file_contents + return file_contents except APIError: return None diff --git a/providers/edge/src/airflow/providers/edge/cli/edge_command.py b/providers/edge/src/airflow/providers/edge/cli/edge_command.py index a24a0e2603438..74de6081d6012 100644 --- a/providers/edge/src/airflow/providers/edge/cli/edge_command.py +++ b/providers/edge/src/airflow/providers/edge/cli/edge_command.py @@ -98,8 +98,7 @@ def force_use_internal_api_on_edge_worker(): def _status_signal() -> signal.Signals: if IS_WINDOWS: return signal.SIGBREAK # type: ignore[attr-defined] - else: - return signal.SIGUSR2 + return signal.SIGUSR2 SIG_STATUS = _status_signal() @@ -133,18 +132,16 @@ def _write_pid_to_pidfile(pid_file_path: str): pid_stored_in_pid_file = read_pid_from_pidfile(pid_file_path) if os.getpid() == pid_stored_in_pid_file: raise SystemExit("A PID file has already been written") - else: - # PID file was written by dead or already running instance - if psutil.pid_exists(pid_stored_in_pid_file): - # case 1: another instance uses the same path for its PID file - raise SystemExit( - f"The PID file {pid_file_path} contains the PID of another running process. " - "Configuration issue: edge worker instance must use different PID file paths!" - ) - else: - # case 2: previous instance crashed without cleaning up its PID file - logger.warning("PID file is orphaned. Cleaning up.") - remove_existing_pidfile(pid_file_path) + # PID file was written by dead or already running instance + if psutil.pid_exists(pid_stored_in_pid_file): + # case 1: another instance uses the same path for its PID file + raise SystemExit( + f"The PID file {pid_file_path} contains the PID of another running process. " + "Configuration issue: edge worker instance must use different PID file paths!" + ) + # case 2: previous instance crashed without cleaning up its PID file + logger.warning("PID file is orphaned. 
Cleaning up.") + remove_existing_pidfile(pid_file_path) logger.debug("PID file written to %s.", pid_file_path) write_pid_to_pidfile(pid_file_path) diff --git a/providers/edge/src/airflow/providers/edge/plugins/edge_executor_plugin.py b/providers/edge/src/airflow/providers/edge/plugins/edge_executor_plugin.py index 59d3ef5cdb778..81322756ffb51 100644 --- a/providers/edge/src/airflow/providers/edge/plugins/edge_executor_plugin.py +++ b/providers/edge/src/airflow/providers/edge/plugins/edge_executor_plugin.py @@ -105,7 +105,7 @@ def modify_maintenance_comment_on_update(maintenance_comment: str | None, userna f"[{datetime.now().strftime('%Y-%m-%d %H:%M')}] - {username} updated maintenance mode\nComment:", maintenance_comment, ) - elif re.search(r"^\[[-\d:\s]+\] - .+ updated maintenance mode\r?\nComment:.*", maintenance_comment): + if re.search(r"^\[[-\d:\s]+\] - .+ updated maintenance mode\r?\nComment:.*", maintenance_comment): return re.sub( r"^\[[-\d:\s]+\] - .+ updated maintenance mode\r?\nComment:", f"[{datetime.now().strftime('%Y-%m-%d %H:%M')}] - {username} updated maintenance mode\nComment:", diff --git a/providers/elasticsearch/src/airflow/providers/elasticsearch/log/es_task_handler.py b/providers/elasticsearch/src/airflow/providers/elasticsearch/log/es_task_handler.py index 5bf3d2308c294..ee5fc205902b2 100644 --- a/providers/elasticsearch/src/airflow/providers/elasticsearch/log/es_task_handler.py +++ b/providers/elasticsearch/src/airflow/providers/elasticsearch/log/es_task_handler.py @@ -113,8 +113,7 @@ def _ensure_ti(ti: TaskInstanceKey | TaskInstance, session) -> TaskInstance: if isinstance(val, TaskInstance): val.try_number = ti.try_number return val - else: - raise AirflowException(f"Could not find TaskInstance for {ti}") + raise AirflowException(f"Could not find TaskInstance for {ti}") class ElasticsearchTaskHandler(FileTaskHandler, ExternalLoggingMixin, LoggingMixin): @@ -356,8 +355,7 @@ def _read( ) if AIRFLOW_V_3_0_PLUS: return missing_log_message, metadata - else: - return [("", missing_log_message)], metadata # type: ignore[list-item] + return [("", missing_log_message)], metadata # type: ignore[list-item] if ( # Assume end of log after not receiving new log for N min, cur_ts.diff(last_log_ts).in_minutes() >= 5 diff --git a/providers/elasticsearch/tests/unit/elasticsearch/log/elasticmock/fake_elasticsearch.py b/providers/elasticsearch/tests/unit/elasticsearch/log/elasticmock/fake_elasticsearch.py index 4eb4c5e68d796..c7746001d6856 100644 --- a/providers/elasticsearch/tests/unit/elasticsearch/log/elasticmock/fake_elasticsearch.py +++ b/providers/elasticsearch/tests/unit/elasticsearch/log/elasticmock/fake_elasticsearch.py @@ -424,8 +424,7 @@ def delete(self, index, doc_type, id, params=None, headers=None): if found: return result_dict - else: - raise NotFoundError(404, json.dumps(result_dict)) + raise NotFoundError(404, json.dumps(result_dict)) @query_params("allow_no_indices", "expand_wildcards", "ignore_unavailable", "preference", "routing") def suggest(self, body, index=None): diff --git a/providers/elasticsearch/tests/unit/elasticsearch/log/elasticmock/utilities/__init__.py b/providers/elasticsearch/tests/unit/elasticsearch/log/elasticmock/utilities/__init__.py index 62fef03473aa6..50b883e0f02f3 100644 --- a/providers/elasticsearch/tests/unit/elasticsearch/log/elasticmock/utilities/__init__.py +++ b/providers/elasticsearch/tests/unit/elasticsearch/log/elasticmock/utilities/__init__.py @@ -139,7 +139,7 @@ def _wrapped(*args, **kwargs): if http_auth is not None and api_key 
is not None: raise ValueError("Only one of 'http_auth' and 'api_key' may be passed at a time") - elif http_auth is not None: + if http_auth is not None: headers["authorization"] = f"Basic {_base64_auth_header(http_auth)}" elif api_key is not None: headers["authorization"] = f"ApiKey {_base64_auth_header(api_key)}" diff --git a/providers/exasol/src/airflow/providers/exasol/hooks/exasol.py b/providers/exasol/src/airflow/providers/exasol/hooks/exasol.py index 27bab93afbbe9..aab96627a4354 100644 --- a/providers/exasol/src/airflow/providers/exasol/hooks/exasol.py +++ b/providers/exasol/src/airflow/providers/exasol/hooks/exasol.py @@ -257,8 +257,7 @@ def run( if return_single_query_results(sql, return_last, split_statements): self.descriptions = [_last_columns] return _last_result - else: - return results + return results def set_autocommit(self, conn, autocommit: bool) -> None: """ diff --git a/providers/fab/src/airflow/providers/fab/auth_manager/api/auth/backend/basic_auth.py b/providers/fab/src/airflow/providers/fab/auth_manager/api/auth/backend/basic_auth.py index c71f5ebed927d..bf48beff41c9a 100644 --- a/providers/fab/src/airflow/providers/fab/auth_manager/api/auth/backend/basic_auth.py +++ b/providers/fab/src/airflow/providers/fab/auth_manager/api/auth/backend/basic_auth.py @@ -63,7 +63,6 @@ def requires_authentication(function: T): def decorated(*args, **kwargs): if auth_current_user() is not None or current_app.config.get("AUTH_ROLE_PUBLIC", None): return function(*args, **kwargs) - else: - return Response("Unauthorized", 401, {"WWW-Authenticate": "Basic"}) + return Response("Unauthorized", 401, {"WWW-Authenticate": "Basic"}) return cast("T", decorated) diff --git a/providers/fab/src/airflow/providers/fab/auth_manager/api/auth/backend/kerberos_auth.py b/providers/fab/src/airflow/providers/fab/auth_manager/api/auth/backend/kerberos_auth.py index 769b70bae7585..d419524bf1043 100644 --- a/providers/fab/src/airflow/providers/fab/auth_manager/api/auth/backend/kerberos_auth.py +++ b/providers/fab/src/airflow/providers/fab/auth_manager/api/auth/backend/kerberos_auth.py @@ -101,7 +101,7 @@ def _gssapi_authenticate(token) -> _KerberosAuth | None: user=kerberos.authGSSServerUserName(state), token=kerberos.authGSSServerResponse(state), ) - elif return_code == kerberos.AUTH_GSS_CONTINUE: + if return_code == kerberos.AUTH_GSS_CONTINUE: return _KerberosAuth(return_code=return_code) return _KerberosAuth(return_code=return_code) except kerberos.GSSError: @@ -139,7 +139,7 @@ def decorated(*args, **kwargs): if auth.token is not None: response.headers["WWW-Authenticate"] = f"negotiate {auth.token}" return response - elif auth.return_code != kerberos.AUTH_GSS_CONTINUE: + if auth.return_code != kerberos.AUTH_GSS_CONTINUE: return _forbidden() return _unauthorized() diff --git a/providers/fab/src/airflow/providers/fab/auth_manager/api_fastapi/services/login.py b/providers/fab/src/airflow/providers/fab/auth_manager/api_fastapi/services/login.py index 57f883f047e3b..0792efd18ce46 100644 --- a/providers/fab/src/airflow/providers/fab/auth_manager/api_fastapi/services/login.py +++ b/providers/fab/src/airflow/providers/fab/auth_manager/api_fastapi/services/login.py @@ -54,5 +54,4 @@ def create_token( user=user, expiration_time_in_seconds=expiration_time_in_seconds ) ) - else: - raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid password") + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid password") diff --git 
a/providers/fab/src/airflow/providers/fab/auth_manager/fab_auth_manager.py b/providers/fab/src/airflow/providers/fab/auth_manager/fab_auth_manager.py index 1f96a7d7e23fd..009e9ed9d6d33 100644 --- a/providers/fab/src/airflow/providers/fab/auth_manager/fab_auth_manager.py +++ b/providers/fab/src/airflow/providers/fab/auth_manager/fab_auth_manager.py @@ -329,24 +329,23 @@ def is_authorized_dag( if not access_entity: # Scenario 1 return self._is_authorized_dag(method=method, details=details, user=user) - else: - # Scenario 2 - resource_types = self._get_fab_resource_types(access_entity) - dag_method: ResourceMethod = "GET" if method == "GET" else "PUT" - - if (details and details.id) and not self._is_authorized_dag( - method=dag_method, details=details, user=user - ): - return False - - return all( - ( - self._is_authorized(method=method, resource_type=resource_type, user=user) - if resource_type != RESOURCE_DAG_RUN or not hasattr(permissions, "resource_name") - else self._is_authorized_dag_run(method=method, details=details, user=user) - ) - for resource_type in resource_types + # Scenario 2 + resource_types = self._get_fab_resource_types(access_entity) + dag_method: ResourceMethod = "GET" if method == "GET" else "PUT" + + if (details and details.id) and not self._is_authorized_dag( + method=dag_method, details=details, user=user + ): + return False + + return all( + ( + self._is_authorized(method=method, resource_type=resource_type, user=user) + if resource_type != RESOURCE_DAG_RUN or not hasattr(permissions, "resource_name") + else self._is_authorized_dag_run(method=method, details=details, user=user) ) + for resource_type in resource_types + ) def is_authorized_backfill( self, diff --git a/providers/fab/src/airflow/providers/fab/auth_manager/security_manager/override.py b/providers/fab/src/airflow/providers/fab/auth_manager/security_manager/override.py index 1e05a2e1bfe8d..5bd05cd131da3 100644 --- a/providers/fab/src/airflow/providers/fab/auth_manager/security_manager/override.py +++ b/providers/fab/src/airflow/providers/fab/auth_manager/security_manager/override.py @@ -1401,12 +1401,11 @@ def find_user(self, username=None, email=None): .filter(func.lower(self.user_model.username) == func.lower(username)) .one_or_none() ) - else: - return ( - self.get_session.query(self.user_model) - .filter(func.lower(self.user_model.username) == func.lower(username)) - .one_or_none() - ) + return ( + self.get_session.query(self.user_model) + .filter(func.lower(self.user_model.username) == func.lower(username)) + .one_or_none() + ) except MultipleResultsFound: log.error("Multiple results found for user %s", username) return None @@ -1883,8 +1882,7 @@ def auth_user_ldap(self, username, password): self._rotate_session_id() self.update_user_auth_stat(user) return user - else: - return None + return None except ldap.LDAPError as e: msg = None @@ -1893,9 +1891,8 @@ def auth_user_ldap(self, username, password): if (msg is not None) and ("desc" in msg): log.error(LOGMSG_ERR_SEC_AUTH_LDAP, e.message["desc"]) return None - else: - log.error(e) - return None + log.error(e) + return None def check_password(self, username, password) -> bool: """ @@ -1933,14 +1930,13 @@ def auth_user_db(self, username, password): ) log.info(LOGMSG_WAR_SEC_LOGIN_FAILED, username) return None - elif check_password_hash(user.password, password): + if check_password_hash(user.password, password): self._rotate_session_id() self.update_user_auth_stat(user, True) return user - else: - self.update_user_auth_stat(user, False) - 
log.info(LOGMSG_WAR_SEC_LOGIN_FAILED, username) - return None + self.update_user_auth_stat(user, False) + log.info(LOGMSG_WAR_SEC_LOGIN_FAILED, username) + return None def set_oauth_session(self, provider, oauth_response): """Set the current session with OAuth user secrets.""" @@ -2033,8 +2029,7 @@ def auth_user_oauth(self, userinfo): self._rotate_session_id() self.update_user_auth_stat(user) return user - else: - return None + return None def get_oauth_user_info(self, provider: str, resp: dict[str, Any]) -> dict[str, Any]: """ @@ -2108,8 +2103,7 @@ def get_oauth_user_info(self, provider: str, resp: dict[str, Any]) -> dict[str, "email": data["email"], "role_keys": data.get("groups", []), } - else: - log.error(data.get("error_description")) + log.error(data.get("error_description")) return {} # for Auth0 if provider == "auth0": @@ -2147,8 +2141,7 @@ def get_oauth_user_info(self, provider: str, resp: dict[str, Any]) -> dict[str, "role_keys": me.get("groups", []), } - else: - return {} + return {} @staticmethod def oauth_token_getter(): diff --git a/providers/fab/src/airflow/providers/fab/www/auth.py b/providers/fab/src/airflow/providers/fab/www/auth.py index c4bb1b3701c45..0ac83def61986 100644 --- a/providers/fab/src/airflow/providers/fab/www/auth.py +++ b/providers/fab/src/airflow/providers/fab/www/auth.py @@ -91,9 +91,8 @@ def wraps(self, *args, **kwargs): resource_pk=kwargs.get("pk"), ): return f(self, *args, **kwargs) - else: - log.warning(LOGMSG_ERR_SEC_ACCESS_DENIED, permission_str, self.__class__.__name__) - flash(as_unicode(FLAMSG_ERR_SEC_ACCESS_DENIED), "danger") + log.warning(LOGMSG_ERR_SEC_ACCESS_DENIED, permission_str, self.__class__.__name__) + flash(as_unicode(FLAMSG_ERR_SEC_ACCESS_DENIED), "danger") return redirect(get_auth_manager().get_url_login(next_url=request.url)) f._permission_name = permission_str @@ -139,7 +138,7 @@ def _has_access(*, is_authorized: bool, func: Callable, args, kwargs): """ if is_authorized: return func(*args, **kwargs) - elif get_fab_auth_manager().is_logged_in() and not get_auth_manager().is_authorized_view( + if get_fab_auth_manager().is_logged_in() and not get_auth_manager().is_authorized_view( access_view=AccessView.WEBSITE, user=get_fab_auth_manager().get_user(), ): @@ -151,11 +150,10 @@ def _has_access(*, is_authorized: bool, func: Callable, args, kwargs): ), 403, ) - elif not get_fab_auth_manager().is_logged_in(): + if not get_fab_auth_manager().is_logged_in(): return redirect(get_auth_manager().get_url_login(next_url=request.url)) - else: - access_denied = get_access_denied_message() - flash(access_denied, "danger") + access_denied = get_access_denied_message() + flash(access_denied, "danger") return redirect(url_for("FabIndexView.index")) diff --git a/providers/fab/src/airflow/providers/fab/www/extensions/init_appbuilder.py b/providers/fab/src/airflow/providers/fab/www/extensions/init_appbuilder.py index 5c84f90e76153..5776d2b2aff6a 100644 --- a/providers/fab/src/airflow/providers/fab/www/extensions/init_appbuilder.py +++ b/providers/fab/src/airflow/providers/fab/www/extensions/init_appbuilder.py @@ -236,8 +236,7 @@ def get_app(self): """ if self.app: return self.app - else: - return current_app + return current_app @property def get_session(self): diff --git a/providers/fab/src/airflow/providers/fab/www/extensions/init_views.py b/providers/fab/src/airflow/providers/fab/www/extensions/init_views.py index 2de408b3d5187..76883ae3ad891 100644 --- a/providers/fab/src/airflow/providers/fab/www/extensions/init_views.py +++ 
b/providers/fab/src/airflow/providers/fab/www/extensions/init_views.py @@ -146,15 +146,13 @@ def _handle_api_not_found(ex): # i.e. "no route for it" defined, need to be handled # here on the application level return common_error_handler(ex) - else: - return views.not_found(ex) + return views.not_found(ex) @app.errorhandler(405) def _handle_method_not_allowed(ex): if any([request.path.startswith(p) for p in base_paths]): return common_error_handler(ex) - else: - return views.method_not_allowed(ex) + return views.method_not_allowed(ex) app.register_error_handler(ProblemException, common_error_handler) diff --git a/providers/fab/src/airflow/providers/fab/www/views.py b/providers/fab/src/airflow/providers/fab/www/views.py index 9279b4e42678c..d38faab802cd7 100644 --- a/providers/fab/src/airflow/providers/fab/www/views.py +++ b/providers/fab/src/airflow/providers/fab/www/views.py @@ -74,8 +74,7 @@ def index(self): response.set_cookie(COOKIE_NAME_JWT_TOKEN, token, secure=secure) return response - else: - return redirect(conf.get("api", "base_url", fallback="/"), code=302) + return redirect(conf.get("api", "base_url", fallback="/"), code=302) def show_traceback(error): diff --git a/providers/fab/tests/unit/fab/decorators.py b/providers/fab/tests/unit/fab/decorators.py index 26c7bf701b2bb..f359563e2eaf6 100644 --- a/providers/fab/tests/unit/fab/decorators.py +++ b/providers/fab/tests/unit/fab/decorators.py @@ -59,5 +59,4 @@ def func(*args, **kwargs): if _func is None: return decorator_dont_initialize_flask_app_submodules - else: - return decorator_dont_initialize_flask_app_submodules(_func) + return decorator_dont_initialize_flask_app_submodules(_func) diff --git a/providers/facebook/src/airflow/providers/facebook/ads/hooks/ads.py b/providers/facebook/src/airflow/providers/facebook/ads/hooks/ads.py index 78538cba08123..0cff27bbd68d1 100644 --- a/providers/facebook/src/airflow/providers/facebook/ads/hooks/ads.py +++ b/providers/facebook/src/airflow/providers/facebook/ads/hooks/ads.py @@ -141,14 +141,13 @@ def bulk_facebook_report( "%s Account Id used to extract data from Facebook Ads Iterators successfully", account_id ) return all_insights - else: - return self._facebook_report( - account_id=self.facebook_ads_config["account_id"], - api=api, - params=params, - fields=fields, - sleep_time=sleep_time, - ) + return self._facebook_report( + account_id=self.facebook_ads_config["account_id"], + api=api, + params=params, + fields=fields, + sleep_time=sleep_time, + ) def _facebook_report( self, diff --git a/providers/google/src/airflow/providers/google/cloud/hooks/bigquery.py b/providers/google/src/airflow/providers/google/cloud/hooks/bigquery.py index 46b7732b2831a..a7ae255e5a9a8 100644 --- a/providers/google/src/airflow/providers/google/cloud/hooks/bigquery.py +++ b/providers/google/src/airflow/providers/google/cloud/hooks/bigquery.py @@ -1376,8 +1376,7 @@ def split_tablename( def var_print(var_name): if var_name is None: return "" - else: - return f"Format exception for {var_name}: " + return f"Format exception for {var_name}: " if table_input.count(".") + table_input.count(":") > 3: raise ValueError(f"{var_print(var_name)}Use either : or . to specify project got {table_input}") @@ -1955,8 +1954,7 @@ def split_tablename( def var_print(var_name): if var_name is None: return "" - else: - return f"Format exception for {var_name}: " + return f"Format exception for {var_name}: " if table_input.count(".") + table_input.count(":") > 3: raise ValueError(f"{var_print(var_name)}Use either : or . 
to specify project got {table_input}") diff --git a/providers/google/src/airflow/providers/google/cloud/hooks/cloud_batch.py b/providers/google/src/airflow/providers/google/cloud/hooks/cloud_batch.py index c50c8b6ce28b7..36fbed8c917a2 100644 --- a/providers/google/src/airflow/providers/google/cloud/hooks/cloud_batch.py +++ b/providers/google/src/airflow/providers/google/cloud/hooks/cloud_batch.py @@ -147,19 +147,18 @@ def wait_for_job( status: JobStatus.State = job.status.state if status == JobStatus.State.SUCCEEDED: return job - elif status == JobStatus.State.FAILED: + if status == JobStatus.State.FAILED: message = ( "Unexpected error in the operation: " "Batch job with name {job_name} has failed its execution." ) raise AirflowException(message) - elif status == JobStatus.State.DELETION_IN_PROGRESS: + if status == JobStatus.State.DELETION_IN_PROGRESS: message = ( "Unexpected error in the operation: Batch job with name {job_name} is being deleted." ) raise AirflowException(message) - else: - time.sleep(polling_period_seconds) + time.sleep(polling_period_seconds) except Exception as e: self.log.exception("Exception occurred while checking for job completion.") raise e diff --git a/providers/google/src/airflow/providers/google/cloud/hooks/cloud_sql.py b/providers/google/src/airflow/providers/google/cloud/hooks/cloud_sql.py index bc6dedf459868..8ffdd7c1a484d 100644 --- a/providers/google/src/airflow/providers/google/cloud/hooks/cloud_sql.py +++ b/providers/google/src/airflow/providers/google/cloud/hooks/cloud_sql.py @@ -634,35 +634,32 @@ def start_proxy(self) -> None: self._download_sql_proxy_if_needed() if self.sql_proxy_process: raise AirflowException(f"The sql proxy is already running: {self.sql_proxy_process}") - else: - command_to_run = [self.sql_proxy_path] - command_to_run.extend(self.command_line_parameters) - self.log.info("Creating directory %s", self.cloud_sql_proxy_socket_directory) - Path(self.cloud_sql_proxy_socket_directory).mkdir(parents=True, exist_ok=True) - command_to_run.extend(self._get_credential_parameters()) - self.log.info("Running the command: `%s`", " ".join(command_to_run)) - - self.sql_proxy_process = Popen(command_to_run, stdin=PIPE, stdout=PIPE, stderr=PIPE) - self.log.info("The pid of cloud_sql_proxy: %s", self.sql_proxy_process.pid) - while True: - line = ( - self.sql_proxy_process.stderr.readline().decode("utf-8") - if self.sql_proxy_process.stderr - else "" - ) - return_code = self.sql_proxy_process.poll() - if line == "" and return_code is not None: - self.sql_proxy_process = None - raise AirflowException( - f"The cloud_sql_proxy finished early with return code {return_code}!" 
- ) - if line != "": - self.log.info(line) - if "googleapi: Error" in line or "invalid instance name:" in line: - self.stop_proxy() - raise AirflowException(f"Error when starting the cloud_sql_proxy {line}!") - if "Ready for new connections" in line: - return + command_to_run = [self.sql_proxy_path] + command_to_run.extend(self.command_line_parameters) + self.log.info("Creating directory %s", self.cloud_sql_proxy_socket_directory) + Path(self.cloud_sql_proxy_socket_directory).mkdir(parents=True, exist_ok=True) + command_to_run.extend(self._get_credential_parameters()) + self.log.info("Running the command: `%s`", " ".join(command_to_run)) + + self.sql_proxy_process = Popen(command_to_run, stdin=PIPE, stdout=PIPE, stderr=PIPE) + self.log.info("The pid of cloud_sql_proxy: %s", self.sql_proxy_process.pid) + while True: + line = ( + self.sql_proxy_process.stderr.readline().decode("utf-8") + if self.sql_proxy_process.stderr + else "" + ) + return_code = self.sql_proxy_process.poll() + if line == "" and return_code is not None: + self.sql_proxy_process = None + raise AirflowException(f"The cloud_sql_proxy finished early with return code {return_code}!") + if line != "": + self.log.info(line) + if "googleapi: Error" in line or "invalid instance name:" in line: + self.stop_proxy() + raise AirflowException(f"Error when starting the cloud_sql_proxy {line}!") + if "Ready for new connections" in line: + return def stop_proxy(self) -> None: """ @@ -672,10 +669,9 @@ def stop_proxy(self) -> None: """ if not self.sql_proxy_process: raise AirflowException("The sql proxy is not started yet") - else: - self.log.info("Stopping the cloud_sql_proxy pid: %s", self.sql_proxy_process.pid) - self.sql_proxy_process.kill() - self.sql_proxy_process = None + self.log.info("Stopping the cloud_sql_proxy pid: %s", self.sql_proxy_process.pid) + self.sql_proxy_process.kill() + self.sql_proxy_process = None # Cleanup! self.log.info("Removing the socket directory: %s", self.cloud_sql_proxy_socket_directory) shutil.rmtree(self.cloud_sql_proxy_socket_directory, ignore_errors=True) @@ -704,8 +700,7 @@ def get_proxy_version(self) -> str | None: matched = re.search("[Vv]ersion (.*?);", result) if matched: return matched.group(1) - else: - return None + return None def get_socket_path(self) -> str: """ @@ -908,10 +903,9 @@ def _get_cert_from_secret(self, cert_name: str) -> str | None: secret_data = json.loads(base64.b64decode(secret.payload.data)) if cert_name in secret_data: return secret_data[cert_name] - else: - raise AirflowException( - "Invalid secret format. Expected dictionary with keys: `sslcert`, `sslkey`, `sslrootcert`" - ) + raise AirflowException( + "Invalid secret format. 
Expected dictionary with keys: `sslcert`, `sslkey`, `sslrootcert`" + ) def _set_temporary_ssl_file( self, cert_name: str, cert_path: str | None = None, cert_value: str | None = None @@ -1205,8 +1199,7 @@ def _get_iam_db_login(self) -> str: if self.database_type == "postgres": return self.cloudsql_connection.login.split(".gserviceaccount.com")[0] - else: - return self.cloudsql_connection.login.split("@")[0] + return self.cloudsql_connection.login.split("@")[0] def _generate_login_token(self, service_account) -> str: """Generate an IAM login token for Cloud SQL and return the token.""" diff --git a/providers/google/src/airflow/providers/google/cloud/hooks/cloud_storage_transfer_service.py b/providers/google/src/airflow/providers/google/cloud/hooks/cloud_storage_transfer_service.py index b9d19e81e174f..912c152bb13b8 100644 --- a/providers/google/src/airflow/providers/google/cloud/hooks/cloud_storage_transfer_service.py +++ b/providers/google/src/airflow/providers/google/cloud/hooks/cloud_storage_transfer_service.py @@ -218,7 +218,7 @@ def create_transfer_job(self, body: dict) -> dict: return ( self.get_conn().transferJobs().create(body=body).execute(num_retries=self.num_retries) ) - elif transfer_job.get(STATUS) == GcpTransferJobsStatus.DISABLED: + if transfer_job.get(STATUS) == GcpTransferJobsStatus.DISABLED: return self.enable_transfer_job(job_name=job_name, project_id=body.get(PROJECT_ID)) else: raise e diff --git a/providers/google/src/airflow/providers/google/cloud/hooks/compute_ssh.py b/providers/google/src/airflow/providers/google/cloud/hooks/compute_ssh.py index c4aaa45e3d689..a8e00ff453201 100644 --- a/providers/google/src/airflow/providers/google/cloud/hooks/compute_ssh.py +++ b/providers/google/src/airflow/providers/google/cloud/hooks/compute_ssh.py @@ -148,8 +148,7 @@ def _compute_hook(self) -> ComputeEngineHook: return ComputeEngineHook( gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain ) - else: - return ComputeEngineHook(gcp_conn_id=self.gcp_conn_id) + return ComputeEngineHook(gcp_conn_id=self.gcp_conn_id) def _load_connection_config(self): def _boolify(value): @@ -158,7 +157,7 @@ def _boolify(value): if isinstance(value, str): if value.lower() == "false": return False - elif value.lower() == "true": + if value.lower() == "true": return True return False diff --git a/providers/google/src/airflow/providers/google/cloud/hooks/dataflow.py b/providers/google/src/airflow/providers/google/cloud/hooks/dataflow.py index 03077981d7b8f..2493bb91d9b0a 100644 --- a/providers/google/src/airflow/providers/google/cloud/hooks/dataflow.py +++ b/providers/google/src/airflow/providers/google/cloud/hooks/dataflow.py @@ -262,15 +262,14 @@ def _get_current_jobs(self) -> list[dict]: """ if not self._multiple_jobs and self._job_id: return [self.fetch_job_by_id(self._job_id)] - elif self._jobs: + if self._jobs: return [self.fetch_job_by_id(job["id"]) for job in self._jobs] - elif self._job_name: + if self._job_name: jobs = self._fetch_jobs_by_prefix_name(self._job_name.lower()) if len(jobs) == 1: self._job_id = jobs[0]["id"] return jobs - else: - raise ValueError("Missing both dataflow job ID and name.") + raise ValueError("Missing both dataflow job ID and name.") def fetch_job_by_id(self, job_id: str) -> dict[str, str]: """ @@ -435,12 +434,12 @@ def _check_dataflow_job_state(self, job) -> bool: f"'{current_expected_state}' is invalid." 
f" The value should be any of the following: {terminal_states}" ) - elif is_streaming and current_expected_state == DataflowJobStatus.JOB_STATE_DONE: + if is_streaming and current_expected_state == DataflowJobStatus.JOB_STATE_DONE: raise AirflowException( "Google Cloud Dataflow job's expected terminal state cannot be " "JOB_STATE_DONE while it is a streaming job" ) - elif not is_streaming and current_expected_state == DataflowJobStatus.JOB_STATE_DRAINED: + if not is_streaming and current_expected_state == DataflowJobStatus.JOB_STATE_DRAINED: raise AirflowException( "Google Cloud Dataflow job's expected terminal state cannot be " "JOB_STATE_DRAINED while it is a batch job" diff --git a/providers/google/src/airflow/providers/google/cloud/hooks/datafusion.py b/providers/google/src/airflow/providers/google/cloud/hooks/datafusion.py index 66e073e375ce1..c228b3fa7c0e9 100644 --- a/providers/google/src/airflow/providers/google/cloud/hooks/datafusion.py +++ b/providers/google/src/airflow/providers/google/cloud/hooks/datafusion.py @@ -170,9 +170,9 @@ def _cdap_request( def _check_response_status_and_data(response, message: str) -> None: if response.status == 404: raise AirflowNotFoundException(message) - elif response.status == 409: + if response.status == 409: raise ConflictException("Conflict: Resource is still in use.") - elif response.status != 200: + if response.status != 200: raise AirflowException(message) if response.data is None: raise AirflowException( @@ -572,8 +572,7 @@ async def _get_link(self, url: str, session): raise if pipeline: return pipeline - else: - raise AirflowException("Could not retrieve pipeline. Aborting.") + raise AirflowException("Could not retrieve pipeline. Aborting.") async def get_pipeline( self, diff --git a/providers/google/src/airflow/providers/google/cloud/hooks/dlp.py b/providers/google/src/airflow/providers/google/cloud/hooks/dlp.py index c0730497e904e..a81f279d8f876 100644 --- a/providers/google/src/airflow/providers/google/cloud/hooks/dlp.py +++ b/providers/google/src/airflow/providers/google/cloud/hooks/dlp.py @@ -268,7 +268,7 @@ def create_dlp_job( if job.state == DlpJob.JobState.DONE: return job - elif job.state in [ + if job.state in [ DlpJob.JobState.PENDING, DlpJob.JobState.RUNNING, DlpJob.JobState.JOB_STATE_UNSPECIFIED, diff --git a/providers/google/src/airflow/providers/google/cloud/hooks/gcs.py b/providers/google/src/airflow/providers/google/cloud/hooks/gcs.py index fd8fa26904339..0202b70d5465c 100644 --- a/providers/google/src/airflow/providers/google/cloud/hooks/gcs.py +++ b/providers/google/src/airflow/providers/google/cloud/hooks/gcs.py @@ -358,11 +358,10 @@ def download( ) self.log.info("File downloaded to %s", filename) return filename - else: - get_hook_lineage_collector().add_input_asset( - context=self, scheme="gs", asset_kwargs={"bucket": bucket.name, "key": blob.name} - ) - return blob.download_as_bytes() + get_hook_lineage_collector().add_input_asset( + context=self, scheme="gs", asset_kwargs={"bucket": bucket.name, "key": blob.name} + ) + return blob.download_as_bytes() except GoogleCloudError: if attempt == num_max_attempts - 1: @@ -556,7 +555,7 @@ def _call_with_retry(f: Callable[[], None]) -> None: "specify a single parameter, either 'filename' for " "local file uploads or 'data' for file content uploads." 
) - elif filename: + if filename: if not mime_type: mime_type = "application/octet-stream" if gzip: diff --git a/providers/google/src/airflow/providers/google/cloud/hooks/kubernetes_engine.py b/providers/google/src/airflow/providers/google/cloud/hooks/kubernetes_engine.py index 309eb45d2e69a..e35c9a0a5ebc9 100644 --- a/providers/google/src/airflow/providers/google/cloud/hooks/kubernetes_engine.py +++ b/providers/google/src/airflow/providers/google/cloud/hooks/kubernetes_engine.py @@ -352,8 +352,7 @@ def check_cluster_autoscaling_ability(self, cluster: Cluster | dict): or node_pools_autoscaled ): return True - else: - return False + return False class GKEAsyncHook(GoogleBaseAsyncHook): diff --git a/providers/google/src/airflow/providers/google/cloud/hooks/mlengine.py b/providers/google/src/airflow/providers/google/cloud/hooks/mlengine.py index ff33946cf9cbd..e37bc13c89802 100644 --- a/providers/google/src/airflow/providers/google/cloud/hooks/mlengine.py +++ b/providers/google/src/airflow/providers/google/cloud/hooks/mlengine.py @@ -78,8 +78,7 @@ def _poll_with_exponential_delay( if e.resp.status != 429: log.info("Something went wrong. Not retrying: %s", format(e)) raise - else: - time.sleep((2**i) + random.random()) + time.sleep((2**i) + random.random()) raise ValueError(f"Connection could not be established after {max_n} retries.") @@ -219,12 +218,11 @@ def cancel_job( if e.resp.status == 404: self.log.error("Job with job_id %s does not exist. ", job_id) raise - elif e.resp.status == 400: + if e.resp.status == 400: self.log.info("Job with job_id %s is already complete, cancellation aborted.", job_id) return {} - else: - self.log.error("Failed to cancel MLEngine job: %s", e) - raise + self.log.error("Failed to cancel MLEngine job: %s", e) + raise def get_job(self, project_id: str, job_id: str) -> dict: """ diff --git a/providers/google/src/airflow/providers/google/cloud/hooks/stackdriver.py b/providers/google/src/airflow/providers/google/cloud/hooks/stackdriver.py index a24c540c73d1c..2163747e37c74 100644 --- a/providers/google/src/airflow/providers/google/cloud/hooks/stackdriver.py +++ b/providers/google/src/airflow/providers/google/cloud/hooks/stackdriver.py @@ -121,10 +121,9 @@ def list_alert_policies( ) if format_ == "dict": return [AlertPolicy.to_dict(policy) for policy in policies_] - elif format_ == "json": + if format_ == "json": return [AlertPolicy.to_jsoon(policy) for policy in policies_] - else: - return policies_ + return policies_ @GoogleBaseHook.fallback_to_default_project_id def _toggle_policy_status( @@ -395,10 +394,9 @@ def list_notification_channels( ) if format_ == "dict": return [NotificationChannel.to_dict(channel) for channel in channels] - elif format_ == "json": + if format_ == "json": return [NotificationChannel.to_json(channel) for channel in channels] - else: - return channels + return channels @GoogleBaseHook.fallback_to_default_project_id def _toggle_channel_status( diff --git a/providers/google/src/airflow/providers/google/cloud/hooks/vision.py b/providers/google/src/airflow/providers/google/cloud/hooks/vision.py index a96b75aa13b81..a144ffad04add 100644 --- a/providers/google/src/airflow/providers/google/cloud/hooks/vision.py +++ b/providers/google/src/airflow/providers/google/cloud/hooks/vision.py @@ -107,8 +107,7 @@ def get_entity_with_name( # Not enough parameters to construct the name. Trying to use the name from Product / ProductSet. 
if explicit_name: return entity - else: - raise AirflowException(ERR_UNABLE_TO_CREATE.format(label=self.label, id_label=self.id_label)) + raise AirflowException(ERR_UNABLE_TO_CREATE.format(label=self.label, id_label=self.id_label)) class CloudVisionHook(GoogleBaseHook): diff --git a/providers/google/src/airflow/providers/google/cloud/openlineage/mixins.py b/providers/google/src/airflow/providers/google/cloud/openlineage/mixins.py index 57bbdafef2882..0c2f65b1b1bb2 100644 --- a/providers/google/src/airflow/providers/google/cloud/openlineage/mixins.py +++ b/providers/google/src/airflow/providers/google/cloud/openlineage/mixins.py @@ -207,15 +207,14 @@ def _get_dataset(self, table: dict, dataset_type: str) -> Dataset: name=dataset_name, facets=dataset_facets, ) - elif dataset_type == "output": + if dataset_type == "output": # Logic specific to creating OutputDataset (if needed) return OutputDataset( namespace=BIGQUERY_NAMESPACE, name=dataset_name, facets=dataset_facets, ) - else: - raise ValueError("Invalid dataset_type. Must be 'input' or 'output'") + raise ValueError("Invalid dataset_type. Must be 'input' or 'output'") def _get_table_facets_safely(self, table_name: str) -> dict[str, DatasetFacet]: try: diff --git a/providers/google/src/airflow/providers/google/cloud/operators/alloy_db.py b/providers/google/src/airflow/providers/google/cloud/operators/alloy_db.py index 634f97b6d9878..5346076699bec 100644 --- a/providers/google/src/airflow/providers/google/cloud/operators/alloy_db.py +++ b/providers/google/src/airflow/providers/google/cloud/operators/alloy_db.py @@ -145,8 +145,7 @@ def get_operation_result(self, operation: Operation) -> proto.Message | None: if self.validate_request: # Validation requests are only validated and aren't executed, thus no operation result is expected return None - else: - return self.hook.wait_for_operation(timeout=self.timeout, operation=operation) + return self.hook.wait_for_operation(timeout=self.timeout, operation=operation) class AlloyDBCreateClusterOperator(AlloyDBWriteBaseOperator): diff --git a/providers/google/src/airflow/providers/google/cloud/operators/automl.py b/providers/google/src/airflow/providers/google/cloud/operators/automl.py index 16b6fd47f7a5d..cb106537bf7ea 100644 --- a/providers/google/src/airflow/providers/google/cloud/operators/automl.py +++ b/providers/google/src/airflow/providers/google/cloud/operators/automl.py @@ -259,11 +259,11 @@ def hook(self) -> CloudAutoMLHook | PredictionServiceHook: gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain, ) - else: # endpoint_id defined - return PredictionServiceHook( - gcp_conn_id=self.gcp_conn_id, - impersonation_chain=self.impersonation_chain, - ) + # endpoint_id defined + return PredictionServiceHook( + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, + ) @cached_property def model(self) -> Model | None: diff --git a/providers/google/src/airflow/providers/google/cloud/operators/bigquery.py b/providers/google/src/airflow/providers/google/cloud/operators/bigquery.py index 1952c7261cf87..64f0e07fdba3e 100644 --- a/providers/google/src/airflow/providers/google/cloud/operators/bigquery.py +++ b/providers/google/src/airflow/providers/google/cloud/operators/bigquery.py @@ -287,7 +287,7 @@ def _handle_job_error(job: BigQueryJob | UnknownJob) -> None: def _validate_records(self, records) -> None: if not records: raise AirflowException(f"The following query returned zero rows: {self.sql}") - elif not all(records): + if not all(records): 
self._raise_exception( # type: ignore[attr-defined] f"Test failed.\nQuery:\n{self.sql}\nResults:\n{records!s}" ) @@ -2976,14 +2976,13 @@ def execute(self, context: Any): f"Or, if you want to reattach in this scenario add {job.state} to `reattach_states`" ) - else: - # Job already reached state DONE - if job.state == "DONE": - raise AirflowException("Job is already in state DONE. Can not reattach to this job.") + # Job already reached state DONE + if job.state == "DONE": + raise AirflowException("Job is already in state DONE. Can not reattach to this job.") - # We are reattaching to a job - self.log.info("Reattaching to existing Job in state %s", job.state) - self._handle_job_error(job) + # We are reattaching to a job + self.log.info("Reattaching to existing Job in state %s", job.state) + self._handle_job_error(job) job_types = { LoadJob._JOB_TYPE: ["sourceTable", "destinationTable"], @@ -3036,24 +3035,23 @@ def execute(self, context: Any): self._handle_job_error(job) return self.job_id - else: - if job.running(): - self.defer( - timeout=self.execution_timeout, - trigger=BigQueryInsertJobTrigger( - conn_id=self.gcp_conn_id, - job_id=self.job_id, - project_id=self.project_id, - location=self.location or hook.location, - poll_interval=self.poll_interval, - impersonation_chain=self.impersonation_chain, - cancel_on_kill=self.cancel_on_kill, - ), - method_name="execute_complete", - ) - self.log.info("Current state of job %s is %s", job.job_id, job.state) - self._handle_job_error(job) - return self.job_id + if job.running(): + self.defer( + timeout=self.execution_timeout, + trigger=BigQueryInsertJobTrigger( + conn_id=self.gcp_conn_id, + job_id=self.job_id, + project_id=self.project_id, + location=self.location or hook.location, + poll_interval=self.poll_interval, + impersonation_chain=self.impersonation_chain, + cancel_on_kill=self.cancel_on_kill, + ), + method_name="execute_complete", + ) + self.log.info("Current state of job %s is %s", job.job_id, job.state) + self._handle_job_error(job) + return self.job_id def execute_complete(self, context: Context, event: dict[str, Any]) -> str | None: """ diff --git a/providers/google/src/airflow/providers/google/cloud/operators/cloud_batch.py b/providers/google/src/airflow/providers/google/cloud/operators/cloud_batch.py index 0587076f9e263..bc6da9ab502ec 100644 --- a/providers/google/src/airflow/providers/google/cloud/operators/cloud_batch.py +++ b/providers/google/src/airflow/providers/google/cloud/operators/cloud_batch.py @@ -100,19 +100,18 @@ def execute(self, context: Context): return Job.to_dict(completed_job) - else: - self.defer( - trigger=CloudBatchJobFinishedTrigger( - job_name=job.name, - project_id=self.project_id, - gcp_conn_id=self.gcp_conn_id, - impersonation_chain=self.impersonation_chain, - location=self.region, - polling_period_seconds=self.polling_period_seconds, - timeout=self.timeout_seconds, - ), - method_name="execute_complete", - ) + self.defer( + trigger=CloudBatchJobFinishedTrigger( + job_name=job.name, + project_id=self.project_id, + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, + location=self.region, + polling_period_seconds=self.polling_period_seconds, + timeout=self.timeout_seconds, + ), + method_name="execute_complete", + ) def execute_complete(self, context: Context, event: dict): job_status = event["status"] @@ -120,8 +119,7 @@ def execute_complete(self, context: Context, event: dict): hook: CloudBatchHook = CloudBatchHook(self.gcp_conn_id, self.impersonation_chain) job = 
hook.get_job(job_name=event["job_name"]) return Job.to_dict(job) - else: - raise AirflowException(f"Unexpected error in the operation: {event['message']}") + raise AirflowException(f"Unexpected error in the operation: {event['message']}") class CloudBatchDeleteJobOperator(GoogleCloudBaseOperator): diff --git a/providers/google/src/airflow/providers/google/cloud/operators/cloud_build.py b/providers/google/src/airflow/providers/google/cloud/operators/cloud_build.py index 565b2be13aa27..7ae38b7f236f3 100644 --- a/providers/google/src/airflow/providers/google/cloud/operators/cloud_build.py +++ b/providers/google/src/airflow/providers/google/cloud/operators/cloud_build.py @@ -275,8 +275,7 @@ def execute_complete(self, context: Context, event: dict): build_id=event["id_"], ) return event["instance"] - else: - raise AirflowException(f"Unexpected error in the operation: {event['message']}") + raise AirflowException(f"Unexpected error in the operation: {event['message']}") class CloudBuildCreateBuildTriggerOperator(GoogleCloudBaseOperator): diff --git a/providers/google/src/airflow/providers/google/cloud/operators/cloud_composer.py b/providers/google/src/airflow/providers/google/cloud/operators/cloud_composer.py index 5bc0d11afdbb1..77d411cc08e90 100644 --- a/providers/google/src/airflow/providers/google/cloud/operators/cloud_composer.py +++ b/providers/google/src/airflow/providers/google/cloud/operators/cloud_composer.py @@ -186,18 +186,17 @@ def execute(self, context: Context): if not self.deferrable: environment = hook.wait_for_operation(timeout=self.timeout, operation=result) return Environment.to_dict(environment) - else: - self.defer( - trigger=CloudComposerExecutionTrigger( - project_id=self.project_id, - region=self.region, - operation_name=result.operation.name, - gcp_conn_id=self.gcp_conn_id, - impersonation_chain=self.impersonation_chain, - pooling_period_seconds=self.pooling_period_seconds, - ), - method_name=GOOGLE_DEFAULT_DEFERRABLE_METHOD_NAME, - ) + self.defer( + trigger=CloudComposerExecutionTrigger( + project_id=self.project_id, + region=self.region, + operation_name=result.operation.name, + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, + pooling_period_seconds=self.pooling_period_seconds, + ), + method_name=GOOGLE_DEFAULT_DEFERRABLE_METHOD_NAME, + ) except AlreadyExists: environment = hook.get_environment( project_id=self.project_id, @@ -225,8 +224,7 @@ def execute_complete(self, context: Context, event: dict): metadata=self.metadata, ) return Environment.to_dict(env) - else: - raise AirflowException(f"Unexpected error in the operation: {event['operation_name']}") + raise AirflowException(f"Unexpected error in the operation: {event['operation_name']}") class CloudComposerDeleteEnvironmentOperator(GoogleCloudBaseOperator): @@ -555,18 +553,17 @@ def execute(self, context: Context): if not self.deferrable: environment = hook.wait_for_operation(timeout=self.timeout, operation=result) return Environment.to_dict(environment) - else: - self.defer( - trigger=CloudComposerExecutionTrigger( - project_id=self.project_id, - region=self.region, - operation_name=result.operation.name, - gcp_conn_id=self.gcp_conn_id, - impersonation_chain=self.impersonation_chain, - pooling_period_seconds=self.pooling_period_seconds, - ), - method_name=GOOGLE_DEFAULT_DEFERRABLE_METHOD_NAME, - ) + self.defer( + trigger=CloudComposerExecutionTrigger( + project_id=self.project_id, + region=self.region, + operation_name=result.operation.name, + gcp_conn_id=self.gcp_conn_id, + 
impersonation_chain=self.impersonation_chain, + pooling_period_seconds=self.pooling_period_seconds, + ), + method_name=GOOGLE_DEFAULT_DEFERRABLE_METHOD_NAME, + ) def execute_complete(self, context: Context, event: dict): if event["operation_done"]: @@ -584,8 +581,7 @@ def execute_complete(self, context: Context, event: dict): metadata=self.metadata, ) return Environment.to_dict(env) - else: - raise AirflowException(f"Unexpected error in the operation: {event['operation_name']}") + raise AirflowException(f"Unexpected error in the operation: {event['operation_name']}") class CloudComposerListImageVersionsOperator(GoogleCloudBaseOperator): diff --git a/providers/google/src/airflow/providers/google/cloud/operators/cloud_run.py b/providers/google/src/airflow/providers/google/cloud/operators/cloud_run.py index 1b5e433b7d017..1e0e03d52ef06 100644 --- a/providers/google/src/airflow/providers/google/cloud/operators/cloud_run.py +++ b/providers/google/src/airflow/providers/google/cloud/operators/cloud_run.py @@ -308,19 +308,18 @@ def execute(self, context: Context): self._fail_if_execution_failed(result) job = hook.get_job(job_name=result.job, region=self.region, project_id=self.project_id) return Job.to_dict(job) - else: - self.defer( - trigger=CloudRunJobFinishedTrigger( - operation_name=self.operation.operation.name, - job_name=self.job_name, - project_id=self.project_id, - location=self.region, - gcp_conn_id=self.gcp_conn_id, - impersonation_chain=self.impersonation_chain, - polling_period_seconds=self.polling_period_seconds, - ), - method_name="execute_complete", - ) + self.defer( + trigger=CloudRunJobFinishedTrigger( + operation_name=self.operation.operation.name, + job_name=self.job_name, + project_id=self.project_id, + location=self.region, + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, + polling_period_seconds=self.polling_period_seconds, + ), + method_name="execute_complete", + ) def execute_complete(self, context: Context, event: dict): status = event["status"] diff --git a/providers/google/src/airflow/providers/google/cloud/operators/cloud_sql.py b/providers/google/src/airflow/providers/google/cloud/operators/cloud_sql.py index 18f3d9a443d12..b1ad44b121d22 100644 --- a/providers/google/src/airflow/providers/google/cloud/operators/cloud_sql.py +++ b/providers/google/src/airflow/providers/google/cloud/operators/cloud_sql.py @@ -477,15 +477,14 @@ def execute(self, context: Context): f"Cloud SQL instance with ID {self.instance} does not exist. " "Please specify another instance to patch." ) - else: - CloudSQLInstanceLink.persist( - context=context, - task_instance=self, - cloud_sql_instance=self.instance, - project_id=self.project_id or hook.project_id, - ) + CloudSQLInstanceLink.persist( + context=context, + task_instance=self, + cloud_sql_instance=self.instance, + project_id=self.project_id or hook.project_id, + ) - return hook.patch_instance(project_id=self.project_id, body=self.body, instance=self.instance) + return hook.patch_instance(project_id=self.project_id, body=self.body, instance=self.instance) class CloudSQLDeleteInstanceOperator(CloudSQLBaseOperator): @@ -531,8 +530,7 @@ def execute(self, context: Context) -> bool | None: if not self._check_if_instance_exists(self.instance, hook): print(f"Cloud SQL instance with ID {self.instance} does not exist. 
Aborting delete.") return True - else: - return hook.delete_instance(project_id=self.project_id, instance=self.instance) + return hook.delete_instance(project_id=self.project_id, instance=self.instance) class CloudSQLCloneInstanceOperator(CloudSQLBaseOperator): @@ -612,19 +610,18 @@ def execute(self, context: Context): f"Cloud SQL instance with ID {self.instance} does not exist. " "Please specify another instance to patch." ) - else: - body = { - "cloneContext": { - "kind": "sql#cloneContext", - "destinationInstanceName": self.destination_instance_name, - **self.clone_context, - } + body = { + "cloneContext": { + "kind": "sql#cloneContext", + "destinationInstanceName": self.destination_instance_name, + **self.clone_context, } - return hook.clone_instance( - project_id=self.project_id, - body=body, - instance=self.instance, - ) + } + return hook.clone_instance( + project_id=self.project_id, + body=body, + instance=self.instance, + ) class CloudSQLCreateInstanceDatabaseOperator(CloudSQLBaseOperator): @@ -728,8 +725,7 @@ def execute(self, context: Context) -> bool | None: database, ) return True - else: - return hook.create_database(project_id=self.project_id, instance=self.instance, body=self.body) + return hook.create_database(project_id=self.project_id, instance=self.instance, body=self.body) class CloudSQLPatchInstanceDatabaseOperator(CloudSQLBaseOperator): @@ -824,16 +820,15 @@ def execute(self, context: Context) -> None: f"Cloud SQL instance with ID {self.instance} does not contain database '{self.database}'. " "Please specify another database to patch." ) - else: - CloudSQLInstanceDatabaseLink.persist( - context=context, - task_instance=self, - cloud_sql_instance=self.instance, - project_id=self.project_id or hook.project_id, - ) - return hook.patch_database( - project_id=self.project_id, instance=self.instance, database=self.database, body=self.body - ) + CloudSQLInstanceDatabaseLink.persist( + context=context, + task_instance=self, + cloud_sql_instance=self.instance, + project_id=self.project_id or hook.project_id, + ) + return hook.patch_database( + project_id=self.project_id, instance=self.instance, database=self.database, body=self.body + ) class CloudSQLDeleteInstanceDatabaseOperator(CloudSQLBaseOperator): @@ -910,10 +905,9 @@ def execute(self, context: Context) -> bool | None: f"Aborting database delete." 
) return True - else: - return hook.delete_database( - project_id=self.project_id, instance=self.instance, database=self.database - ) + return hook.delete_database( + project_id=self.project_id, instance=self.instance, database=self.database + ) class CloudSQLExportInstanceOperator(CloudSQLBaseOperator): @@ -1029,17 +1023,16 @@ def execute(self, context: Context) -> None: return hook._wait_for_operation_to_complete( project_id=self.project_id, operation_name=operation_name ) - else: - self.defer( - trigger=CloudSQLExportTrigger( - operation_name=operation_name, - project_id=self.project_id or hook.project_id, - gcp_conn_id=self.gcp_conn_id, - impersonation_chain=self.impersonation_chain, - poke_interval=self.poke_interval, - ), - method_name="execute_complete", - ) + self.defer( + trigger=CloudSQLExportTrigger( + operation_name=operation_name, + project_id=self.project_id or hook.project_id, + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, + poke_interval=self.poke_interval, + ), + method_name="execute_complete", + ) def execute_complete(self, context, event=None) -> None: """ diff --git a/providers/google/src/airflow/providers/google/cloud/operators/compute.py b/providers/google/src/airflow/providers/google/cloud/operators/compute.py index e797e193c3bbc..411e8b666d6e6 100644 --- a/providers/google/src/airflow/providers/google/cloud/operators/compute.py +++ b/providers/google/src/airflow/providers/google/cloud/operators/compute.py @@ -1402,16 +1402,15 @@ def execute(self, context: Context) -> bool | None: request_id=self.request_id, project_id=self.project_id, ) - else: - # Idempotence achieved - ComputeInstanceGroupManagerDetailsLink.persist( - context=context, - task_instance=self, - location_id=self.zone, - resource_id=self.resource_id, - project_id=self.project_id or hook.project_id, - ) - return True + # Idempotence achieved + ComputeInstanceGroupManagerDetailsLink.persist( + context=context, + task_instance=self, + location_id=self.zone, + resource_id=self.resource_id, + project_id=self.project_id or hook.project_id, + ) + return True class ComputeEngineInsertInstanceGroupManagerOperator(ComputeEngineBaseOperator): diff --git a/providers/google/src/airflow/providers/google/cloud/operators/dataproc.py b/providers/google/src/airflow/providers/google/cloud/operators/dataproc.py index 64ba220f4999a..4b47aef0e2d62 100644 --- a/providers/google/src/airflow/providers/google/cloud/operators/dataproc.py +++ b/providers/google/src/airflow/providers/google/cloud/operators/dataproc.py @@ -340,7 +340,7 @@ def _get_init_action_timeout(self) -> dict: unit = match.group(2) if unit == "s": return {"seconds": val} - elif unit == "m": + if unit == "m": return {"seconds": int(timedelta(minutes=val).total_seconds())} raise AirflowException( @@ -822,26 +822,24 @@ def execute(self, context: Context) -> dict: ) self.log.info("Cluster created.") return Cluster.to_dict(cluster) - else: - cluster = hook.get_cluster( - project_id=self.project_id, region=self.region, cluster_name=self.cluster_name - ) - if cluster.status.state == cluster.status.State.RUNNING: - self.log.info("Cluster created.") - return Cluster.to_dict(cluster) - else: - self.defer( - trigger=DataprocClusterTrigger( - cluster_name=self.cluster_name, - project_id=self.project_id, - region=self.region, - gcp_conn_id=self.gcp_conn_id, - impersonation_chain=self.impersonation_chain, - polling_interval_seconds=self.polling_interval_seconds, - delete_on_error=self.delete_on_error, - ), - method_name="execute_complete", - 
) + cluster = hook.get_cluster( + project_id=self.project_id, region=self.region, cluster_name=self.cluster_name + ) + if cluster.status.state == cluster.status.State.RUNNING: + self.log.info("Cluster created.") + return Cluster.to_dict(cluster) + self.defer( + trigger=DataprocClusterTrigger( + cluster_name=self.cluster_name, + project_id=self.project_id, + region=self.region, + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, + polling_interval_seconds=self.polling_interval_seconds, + delete_on_error=self.delete_on_error, + ), + method_name="execute_complete", + ) except AlreadyExists: if not self.use_if_exists: raise @@ -1022,7 +1020,7 @@ def execute_complete(self, context: Context, event: dict[str, Any] | None = None """ if event and event["status"] == "error": raise AirflowException(event["message"]) - elif event is None: + if event is None: raise AirflowException("No event received in trigger callback") self.log.info("Cluster deleted.") @@ -1377,8 +1375,7 @@ def execute(self, context: Context): self.hook.wait_for_job(job_id=job_id, region=self.region, project_id=self.project_id) self.log.info("Job %s completed successfully.", job_id) return job_id - else: - raise AirflowException("Create a job template before") + raise AirflowException("Create a job template before") def execute_complete(self, context, event=None) -> None: """ @@ -1916,9 +1913,9 @@ def execute(self, context: Context): state = job.status.state if state == JobStatus.State.DONE: return self.job_id - elif state == JobStatus.State.ERROR: + if state == JobStatus.State.ERROR: raise AirflowException(f"Job failed:\n{job}") - elif state == JobStatus.State.CANCELLED: + if state == JobStatus.State.CANCELLED: raise AirflowException(f"Job was cancelled:\n{job}") self.defer( trigger=DataprocSubmitTrigger( diff --git a/providers/google/src/airflow/providers/google/cloud/operators/functions.py b/providers/google/src/airflow/providers/google/cloud/operators/functions.py index 9b067fec76518..26782eb3f79aa 100644 --- a/providers/google/src/airflow/providers/google/cloud/operators/functions.py +++ b/providers/google/src/airflow/providers/google/cloud/operators/functions.py @@ -379,10 +379,9 @@ def __init__( def _validate_inputs(self) -> None: if not self.name: raise AttributeError("Empty parameter: name") - else: - pattern = FUNCTION_NAME_COMPILED_PATTERN - if not pattern.match(self.name): - raise AttributeError(f"Parameter name must match pattern: {FUNCTION_NAME_PATTERN}") + pattern = FUNCTION_NAME_COMPILED_PATTERN + if not pattern.match(self.name): + raise AttributeError(f"Parameter name must match pattern: {FUNCTION_NAME_PATTERN}") def execute(self, context: Context): hook = CloudFunctionsHook( @@ -404,9 +403,8 @@ def execute(self, context: Context): if status == 404: self.log.info("The function does not exist in this project") return None - else: - self.log.error("An error occurred. Exiting.") - raise e + self.log.error("An error occurred. 
Exiting.") + raise e class CloudFunctionInvokeFunctionOperator(GoogleCloudBaseOperator): diff --git a/providers/google/src/airflow/providers/google/cloud/operators/kubernetes_engine.py b/providers/google/src/airflow/providers/google/cloud/operators/kubernetes_engine.py index 9e6675ca7af27..300494b91d53b 100644 --- a/providers/google/src/airflow/providers/google/cloud/operators/kubernetes_engine.py +++ b/providers/google/src/airflow/providers/google/cloud/operators/kubernetes_engine.py @@ -457,8 +457,7 @@ def _body_field(self, field_name: str, default_value: Any = None) -> Any: """Extract the value of the given field name.""" if isinstance(self.body, dict): return self.body.get(field_name, default_value) - else: - return getattr(self.body, field_name, default_value) + return getattr(self.body, field_name, default_value) def _alert_deprecated_body_fields(self) -> None: """Generate warning messages if deprecated fields were used in the body.""" diff --git a/providers/google/src/airflow/providers/google/cloud/operators/spanner.py b/providers/google/src/airflow/providers/google/cloud/operators/spanner.py index 36cccd1ce1812..5dced11d3b4f4 100644 --- a/providers/google/src/airflow/providers/google/cloud/operators/spanner.py +++ b/providers/google/src/airflow/providers/google/cloud/operators/spanner.py @@ -188,13 +188,12 @@ def execute(self, context: Context) -> bool | None: ) if hook.get_instance(project_id=self.project_id, instance_id=self.instance_id): return hook.delete_instance(project_id=self.project_id, instance_id=self.instance_id) - else: - self.log.info( - "Instance '%s' does not exist in project '%s'. Aborting delete.", - self.instance_id, - self.project_id, - ) - return True + self.log.info( + "Instance '%s' does not exist in project '%s'. Aborting delete.", + self.instance_id, + self.project_id, + ) + return True class SpannerQueryDatabaseInstanceOperator(GoogleCloudBaseOperator): @@ -401,13 +400,12 @@ def execute(self, context: Context) -> bool | None: database_id=self.database_id, ddl_statements=self.ddl_statements, ) - else: - self.log.info( - "The database '%s' in project '%s' and instance '%s' already exists. Nothing to do. Exiting.", - self.database_id, - self.project_id, - self.instance_id, - ) + self.log.info( + "The database '%s' in project '%s' and instance '%s' already exists. Nothing to do. Exiting.", + self.database_id, + self.project_id, + self.instance_id, + ) return True @@ -496,21 +494,20 @@ def execute(self, context: Context) -> None: f"and instance '{self.instance_id}' is missing. " f"Create the database first before you can update it." 
) - else: - SpannerDatabaseLink.persist( - context=context, - task_instance=self, - instance_id=self.instance_id, - database_id=self.database_id, - project_id=self.project_id or hook.project_id, - ) - return hook.update_database( - project_id=self.project_id, - instance_id=self.instance_id, - database_id=self.database_id, - ddl_statements=self.ddl_statements, - operation_id=self.operation_id, - ) + SpannerDatabaseLink.persist( + context=context, + task_instance=self, + instance_id=self.instance_id, + database_id=self.database_id, + project_id=self.project_id or hook.project_id, + ) + return hook.update_database( + project_id=self.project_id, + instance_id=self.instance_id, + database_id=self.database_id, + ddl_statements=self.ddl_statements, + operation_id=self.operation_id, + ) class SpannerDeleteDatabaseInstanceOperator(GoogleCloudBaseOperator): @@ -589,7 +586,6 @@ def execute(self, context: Context) -> bool: self.instance_id, ) return True - else: - return hook.delete_database( - project_id=self.project_id, instance_id=self.instance_id, database_id=self.database_id - ) + return hook.delete_database( + project_id=self.project_id, instance_id=self.instance_id, database_id=self.database_id + ) diff --git a/providers/google/src/airflow/providers/google/cloud/sensors/cloud_composer.py b/providers/google/src/airflow/providers/google/cloud/sensors/cloud_composer.py index c59a97fe6fb1f..ae5c5d60808a4 100644 --- a/providers/google/src/airflow/providers/google/cloud/sensors/cloud_composer.py +++ b/providers/google/src/airflow/providers/google/cloud/sensors/cloud_composer.py @@ -108,14 +108,12 @@ def _get_logical_dates(self, context) -> tuple[datetime, datetime]: if isinstance(self.execution_range, timedelta): if self.execution_range < timedelta(0): return context["logical_date"], context["logical_date"] - self.execution_range - else: - return context["logical_date"] - self.execution_range, context["logical_date"] - elif isinstance(self.execution_range, list) and len(self.execution_range) > 0: + return context["logical_date"] - self.execution_range, context["logical_date"] + if isinstance(self.execution_range, list) and len(self.execution_range) > 0: return self.execution_range[0], self.execution_range[1] if len( self.execution_range ) > 1 else context["logical_date"] - else: - return context["logical_date"] - timedelta(1), context["logical_date"] + return context["logical_date"] - timedelta(1), context["logical_date"] def poke(self, context: Context) -> bool: start_date, end_date = self._get_logical_dates(context) diff --git a/providers/google/src/airflow/providers/google/cloud/sensors/dataflow.py b/providers/google/src/airflow/providers/google/cloud/sensors/dataflow.py index a38cc2db9ee17..e9721096a5e64 100644 --- a/providers/google/src/airflow/providers/google/cloud/sensors/dataflow.py +++ b/providers/google/src/airflow/providers/google/cloud/sensors/dataflow.py @@ -117,7 +117,7 @@ def poke(self, context: Context) -> bool: if job_status in self.expected_statuses: return True - elif job_status in DataflowJobStatus.TERMINAL_STATES: + if job_status in DataflowJobStatus.TERMINAL_STATES: message = f"Job with id '{self.job_id}' is already in terminal state: {job_status}" raise AirflowException(message) diff --git a/providers/google/src/airflow/providers/google/cloud/sensors/dataproc.py b/providers/google/src/airflow/providers/google/cloud/sensors/dataproc.py index b8385c5b58b51..3138921aae4d4 100644 --- a/providers/google/src/airflow/providers/google/cloud/sensors/dataproc.py +++ 
b/providers/google/src/airflow/providers/google/cloud/sensors/dataproc.py @@ -100,17 +100,17 @@ def poke(self, context: Context) -> bool: if state == JobStatus.State.ERROR: message = f"Job failed:\n{job}" raise AirflowException(message) - elif state in { + if state in { JobStatus.State.CANCELLED, JobStatus.State.CANCEL_PENDING, JobStatus.State.CANCEL_STARTED, }: message = f"Job was cancelled:\n{job}" raise AirflowException(message) - elif state == JobStatus.State.DONE: + if state == JobStatus.State.DONE: self.log.debug("Job %s completed successfully.", self.dataproc_job_id) return True - elif state == JobStatus.State.ATTEMPT_FAILURE: + if state == JobStatus.State.ATTEMPT_FAILURE: self.log.debug("Job %s attempt has failed.", self.dataproc_job_id) self.log.info("Waiting for job %s to complete.", self.dataproc_job_id) @@ -179,13 +179,13 @@ def poke(self, context: Context) -> bool: if state == Batch.State.FAILED: message = "Batch failed" raise AirflowException(message) - elif state in { + if state in { Batch.State.CANCELLED, Batch.State.CANCELLING, }: message = "Batch was cancelled." raise AirflowException(message) - elif state == Batch.State.SUCCEEDED: + if state == Batch.State.SUCCEEDED: self.log.debug("Batch %s completed successfully.", self.batch_id) return True diff --git a/providers/google/src/airflow/providers/google/cloud/sensors/gcs.py b/providers/google/src/airflow/providers/google/cloud/sensors/gcs.py index 569efea172c6c..a2f43149bf8e5 100644 --- a/providers/google/src/airflow/providers/google/cloud/sensors/gcs.py +++ b/providers/google/src/airflow/providers/google/cloud/sensors/gcs.py @@ -306,23 +306,22 @@ def execute(self, context: Context): if not self.deferrable: super().execute(context) return self._matches + if not self.poke(context=context): + self.defer( + timeout=timedelta(seconds=self.timeout), + trigger=GCSPrefixBlobTrigger( + bucket=self.bucket, + prefix=self.prefix, + poke_interval=self.poke_interval, + google_cloud_conn_id=self.google_cloud_conn_id, + hook_params={ + "impersonation_chain": self.impersonation_chain, + }, + ), + method_name="execute_complete", + ) else: - if not self.poke(context=context): - self.defer( - timeout=timedelta(seconds=self.timeout), - trigger=GCSPrefixBlobTrigger( - bucket=self.bucket, - prefix=self.prefix, - poke_interval=self.poke_interval, - google_cloud_conn_id=self.google_cloud_conn_id, - hook_params={ - "impersonation_chain": self.impersonation_chain, - }, - ), - method_name="execute_complete", - ) - else: - return self._matches + return self._matches def execute_complete(self, context: dict[str, Any], event: dict[str, str | list[str]]) -> str | list[str]: """Return immediately and rely on trigger to throw a success event. Callback for the trigger.""" diff --git a/providers/google/src/airflow/providers/google/cloud/sensors/looker.py b/providers/google/src/airflow/providers/google/cloud/sensors/looker.py index ef51abcb2e011..95b2eb36583db 100644 --- a/providers/google/src/airflow/providers/google/cloud/sensors/looker.py +++ b/providers/google/src/airflow/providers/google/cloud/sensors/looker.py @@ -65,13 +65,13 @@ def poke(self, context: Context) -> bool: msg = status_dict["message"] message = f'PDT materialization job failed. Job id: {self.materialization_id}. Message:\n"{msg}"' raise AirflowException(message) - elif status == JobStatus.CANCELLED.value: + if status == JobStatus.CANCELLED.value: message = f"PDT materialization job was cancelled. Job id: {self.materialization_id}." 
raise AirflowException(message) - elif status == JobStatus.UNKNOWN.value: + if status == JobStatus.UNKNOWN.value: message = f"PDT materialization job has unknown status. Job id: {self.materialization_id}." raise AirflowException(message) - elif status == JobStatus.DONE.value: + if status == JobStatus.DONE.value: self.log.debug( "PDT materialization job completed successfully. Job id: %s.", self.materialization_id ) diff --git a/providers/google/src/airflow/providers/google/cloud/sensors/pubsub.py b/providers/google/src/airflow/providers/google/cloud/sensors/pubsub.py index 03cf00f56695b..9b49ad7d7f82c 100644 --- a/providers/google/src/airflow/providers/google/cloud/sensors/pubsub.py +++ b/providers/google/src/airflow/providers/google/cloud/sensors/pubsub.py @@ -168,20 +168,19 @@ def execute(self, context: Context) -> None: if not self.deferrable: super().execute(context) return self._return_value - else: - self.defer( - timeout=timedelta(seconds=self.timeout), - trigger=PubsubPullTrigger( - project_id=self.project_id, - subscription=self.subscription, - max_messages=self.max_messages, - ack_messages=self.ack_messages, - poke_interval=self.poke_interval, - gcp_conn_id=self.gcp_conn_id, - impersonation_chain=self.impersonation_chain, - ), - method_name="execute_complete", - ) + self.defer( + timeout=timedelta(seconds=self.timeout), + trigger=PubsubPullTrigger( + project_id=self.project_id, + subscription=self.subscription, + max_messages=self.max_messages, + ack_messages=self.ack_messages, + poke_interval=self.poke_interval, + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, + ), + method_name="execute_complete", + ) def execute_complete(self, context: Context, event: dict[str, str | list[str]]) -> Any: """If messages_callback is provided, execute it; otherwise, return immediately with trigger event message.""" diff --git a/providers/google/src/airflow/providers/google/cloud/transfers/bigquery_to_gcs.py b/providers/google/src/airflow/providers/google/cloud/transfers/bigquery_to_gcs.py index b005b79971f15..046cfc7a9dcee 100644 --- a/providers/google/src/airflow/providers/google/cloud/transfers/bigquery_to_gcs.py +++ b/providers/google/src/airflow/providers/google/cloud/transfers/bigquery_to_gcs.py @@ -241,14 +241,13 @@ def execute(self, context: Context): f"want to force rerun it consider setting `force_rerun=True`." f"Or, if you want to reattach in this scenario add {job.state} to `reattach_states`" ) - else: - # Job already reached state DONE - if job.state == "DONE": - raise AirflowException("Job is already in state DONE. Can not reattach to this job.") - - # We are reattaching to a job - self.log.info("Reattaching to existing Job in state %s", job.state) - self._handle_job_error(job) + # Job already reached state DONE + if job.state == "DONE": + raise AirflowException("Job is already in state DONE. 
Can not reattach to this job.") + + # We are reattaching to a job + self.log.info("Reattaching to existing Job in state %s", job.state) + self._handle_job_error(job) self.job_id = job.job_id conf = job.to_api_repr()["configuration"]["extract"]["sourceTable"] diff --git a/providers/google/src/airflow/providers/google/cloud/transfers/cassandra_to_gcs.py b/providers/google/src/airflow/providers/google/cloud/transfers/cassandra_to_gcs.py index ec6dba891d004..11500011109e2 100644 --- a/providers/google/src/airflow/providers/google/cloud/transfers/cassandra_to_gcs.py +++ b/providers/google/src/airflow/providers/google/cloud/transfers/cassandra_to_gcs.py @@ -261,29 +261,27 @@ def convert_value(self, value: Any | None) -> Any | None: """Convert value to BQ type.""" if not value or isinstance(value, (str, int, float, bool, dict)): return value - elif isinstance(value, bytes): + if isinstance(value, bytes): return b64encode(value).decode("ascii") - elif isinstance(value, UUID): + if isinstance(value, UUID): if self.encode_uuid: return b64encode(value.bytes).decode("ascii") - else: - return str(value) - elif isinstance(value, (datetime, Date)): return str(value) - elif isinstance(value, Decimal): + if isinstance(value, (datetime, Date)): + return str(value) + if isinstance(value, Decimal): return float(value) - elif isinstance(value, Time): + if isinstance(value, Time): return str(value).split(".")[0] - elif isinstance(value, (list, SortedSet)): + if isinstance(value, (list, SortedSet)): return self.convert_array_types(value) - elif hasattr(value, "_fields"): + if hasattr(value, "_fields"): return self.convert_user_type(value) - elif isinstance(value, tuple): + if isinstance(value, tuple): return self.convert_tuple_type(value) - elif isinstance(value, OrderedMapSerializedKey): + if isinstance(value, OrderedMapSerializedKey): return self.convert_map_type(value) - else: - raise AirflowException(f"Unexpected value: {value}") + raise AirflowException(f"Unexpected value: {value}") def convert_array_types(self, value: list[Any] | SortedSet) -> list[Any]: """Map convert_value over array.""" @@ -376,19 +374,17 @@ def get_bq_type(cls, type_: Any) -> str: """Convert type to equivalent BQ type.""" if cls.is_simple_type(type_): return CassandraToGCSOperator.CQL_TYPE_MAP[type_.cassname] - elif cls.is_record_type(type_): + if cls.is_record_type(type_): return "RECORD" - elif cls.is_array_type(type_): + if cls.is_array_type(type_): return cls.get_bq_type(type_.subtypes[0]) - else: - raise AirflowException("Not a supported type_: " + type_.cassname) + raise AirflowException("Not a supported type_: " + type_.cassname) @classmethod def get_bq_mode(cls, type_: Any) -> str: """Convert type to equivalent BQ mode.""" if cls.is_array_type(type_) or type_.cassname == "MapType": return "REPEATED" - elif cls.is_record_type(type_) or cls.is_simple_type(type_): + if cls.is_record_type(type_) or cls.is_simple_type(type_): return "NULLABLE" - else: - raise AirflowException("Not a supported type_: " + type_.cassname) + raise AirflowException("Not a supported type_: " + type_.cassname) diff --git a/providers/google/src/airflow/providers/google/cloud/transfers/facebook_ads_to_gcs.py b/providers/google/src/airflow/providers/google/cloud/transfers/facebook_ads_to_gcs.py index 8d6adc398d8b1..938c257448bb8 100644 --- a/providers/google/src/airflow/providers/google/cloud/transfers/facebook_ads_to_gcs.py +++ b/providers/google/src/airflow/providers/google/cloud/transfers/facebook_ads_to_gcs.py @@ -157,8 +157,7 @@ def execute(self, 
context: Context): def _generate_rows_with_action(self, type_check: bool): if type_check and self.upload_as_account: return {FlushAction.EXPORT_EVERY_ACCOUNT: []} - else: - return {FlushAction.EXPORT_ONCE: []} + return {FlushAction.EXPORT_ONCE: []} def _prepare_rows_for_upload( self, diff --git a/providers/google/src/airflow/providers/google/cloud/transfers/gcs_to_bigquery.py b/providers/google/src/airflow/providers/google/cloud/transfers/gcs_to_bigquery.py index 259cbd8e327b8..bd1fbf22cfda3 100644 --- a/providers/google/src/airflow/providers/google/cloud/transfers/gcs_to_bigquery.py +++ b/providers/google/src/airflow/providers/google/cloud/transfers/gcs_to_bigquery.py @@ -262,8 +262,7 @@ def __init__( f"{source_format} is not a valid source format. " f"Please use one of the following types: {ALLOWED_FORMATS}." ) - else: - self.source_format = source_format.upper() + self.source_format = source_format.upper() self.compression = compression self.create_disposition = create_disposition self.skip_leading_rows = skip_leading_rows @@ -406,14 +405,13 @@ def execute(self, context: Context): f"want to force rerun it consider setting `force_rerun=True`." f"Or, if you want to reattach in this scenario add {job.state} to `reattach_states`" ) - else: - # Job already reached state DONE - if job.state == "DONE": - raise AirflowException("Job is already in state DONE. Can not reattach to this job.") + # Job already reached state DONE + if job.state == "DONE": + raise AirflowException("Job is already in state DONE. Can not reattach to this job.") - # We are reattaching to a job - self.log.info("Reattaching to existing Job in state %s", job.state) - self._handle_job_error(job) + # We are reattaching to a job + self.log.info("Reattaching to existing Job in state %s", job.state) + self._handle_job_error(job) job_types = { LoadJob._JOB_TYPE: ["sourceTable", "destinationTable"], @@ -506,8 +504,7 @@ def _find_max_value_in_column(self): f"Could not determine MAX value in column {self.max_id_key} " f"since the default value of 'string_field_n' was set by BQ" ) - else: - raise AirflowException(e.message) + raise AirflowException(e.message) if rows: for row in rows: max_id = row[0] if row[0] else 0 @@ -644,11 +641,10 @@ def _use_existing_table(self): "allowed if write_disposition is " "'WRITE_APPEND' or 'WRITE_TRUNCATE'." 
) - else: - # To provide backward compatibility - self.schema_update_options = list(self.schema_update_options or []) - self.log.info("Adding experimental 'schemaUpdateOptions': %s", self.schema_update_options) - self.configuration["load"]["schemaUpdateOptions"] = self.schema_update_options + # To provide backward compatibility + self.schema_update_options = list(self.schema_update_options or []) + self.log.info("Adding experimental 'schemaUpdateOptions': %s", self.schema_update_options) + self.configuration["load"]["schemaUpdateOptions"] = self.schema_update_options if self.max_bad_records: self.configuration["load"]["maxBadRecords"] = self.max_bad_records diff --git a/providers/google/src/airflow/providers/google/cloud/transfers/postgres_to_gcs.py b/providers/google/src/airflow/providers/google/cloud/transfers/postgres_to_gcs.py index c1a731c9b44ac..9839552cd0dc6 100644 --- a/providers/google/src/airflow/providers/google/cloud/transfers/postgres_to_gcs.py +++ b/providers/google/src/airflow/providers/google/cloud/transfers/postgres_to_gcs.py @@ -57,9 +57,8 @@ def __iter__(self): def __next__(self): if self.rows: return self.rows.pop() - else: - self.initialized = True - return next(self.cursor) + self.initialized = True + return next(self.cursor) @property def description(self): diff --git a/providers/google/src/airflow/providers/google/cloud/utils/bigquery.py b/providers/google/src/airflow/providers/google/cloud/utils/bigquery.py index 53b92732d84a7..b0fd98539aeae 100644 --- a/providers/google/src/airflow/providers/google/cloud/utils/bigquery.py +++ b/providers/google/src/airflow/providers/google/cloud/utils/bigquery.py @@ -27,16 +27,15 @@ def bq_cast(string_field: str, bq_type: str) -> None | int | float | bool | str: """ if string_field is None: return None - elif bq_type == "INTEGER": + if bq_type == "INTEGER": return int(string_field) - elif bq_type in ("FLOAT", "TIMESTAMP"): + if bq_type in ("FLOAT", "TIMESTAMP"): return float(string_field) - elif bq_type == "BOOLEAN": + if bq_type == "BOOLEAN": if string_field not in ["true", "false"]: raise ValueError(f"{string_field} must have value 'true' or 'false'") return string_field == "true" - else: - return string_field + return string_field def convert_job_id(job_id: str | list[str], project_id: str, location: str | None) -> Any: @@ -51,5 +50,4 @@ def convert_job_id(job_id: str | list[str], project_id: str, location: str | Non location = location or "US" if isinstance(job_id, list): return [f"{project_id}:{location}:{i}" for i in job_id] - else: - return f"{project_id}:{location}:{job_id}" + return f"{project_id}:{location}:{job_id}" diff --git a/providers/google/src/airflow/providers/google/cloud/utils/dataform.py b/providers/google/src/airflow/providers/google/cloud/utils/dataform.py index c5fe3fd3e2c46..f2baddf0a9626 100644 --- a/providers/google/src/airflow/providers/google/cloud/utils/dataform.py +++ b/providers/google/src/airflow/providers/google/cloud/utils/dataform.py @@ -202,7 +202,7 @@ def make_initialization_workspace_flow( def define_default_location(region: str) -> DataformLocations: if "us" in region: return DataformLocations.US - elif "europe" in region: + if "europe" in region: return DataformLocations.EUROPE regions_mapping: Mapping[str, DataformLocations] = {} diff --git a/providers/google/src/airflow/providers/google/common/utils/id_token_credentials.py b/providers/google/src/airflow/providers/google/common/utils/id_token_credentials.py index cdbffeee0cb4c..c1bdddf5880da 100644 --- 
a/providers/google/src/airflow/providers/google/common/utils/id_token_credentials.py +++ b/providers/google/src/airflow/providers/google/common/utils/id_token_credentials.py @@ -111,7 +111,7 @@ def _load_credentials_from_file( return current_credentials - elif credential_type == _SERVICE_ACCOUNT_TYPE: + if credential_type == _SERVICE_ACCOUNT_TYPE: try: return service_account.IDTokenCredentials.from_service_account_info( info, target_audience=target_audience diff --git a/providers/google/src/airflow/providers/google/leveldb/hooks/leveldb.py b/providers/google/src/airflow/providers/google/leveldb/hooks/leveldb.py index bac8e20a700a6..7e7e971d2860b 100644 --- a/providers/google/src/airflow/providers/google/leveldb/hooks/leveldb.py +++ b/providers/google/src/airflow/providers/google/leveldb/hooks/leveldb.py @@ -96,18 +96,17 @@ def run( if not value: raise ValueError("Please provide `value`!") return self.put(key, value) - elif command == "get": + if command == "get": return self.get(key) - elif command == "delete": + if command == "delete": return self.delete(key) - elif command == "write_batch": + if command == "write_batch": if not keys: raise ValueError("Please provide `keys`!") if not values: raise ValueError("Please provide `values`!") return self.write_batch(keys, values) - else: - raise LevelDBHookException("Unknown command for LevelDB hook") + raise LevelDBHookException("Unknown command for LevelDB hook") def put(self, key: bytes, value: bytes): """ diff --git a/providers/google/tests/system/google/cloud/cloud_build/example_cloud_build_trigger.py b/providers/google/tests/system/google/cloud/cloud_build/example_cloud_build_trigger.py index 50f1843ad6f91..c6a66a88b52b8 100644 --- a/providers/google/tests/system/google/cloud/cloud_build/example_cloud_build_trigger.py +++ b/providers/google/tests/system/google/cloud/cloud_build/example_cloud_build_trigger.py @@ -106,8 +106,7 @@ def get_project_number(): "No project found with specified name, " "or caller does not have permissions to read specified project" ) - else: - raise exc + raise exc with DAG( diff --git a/providers/google/tests/system/google/cloud/cloud_sql/example_cloud_sql_query.py b/providers/google/tests/system/google/cloud/cloud_sql/example_cloud_sql_query.py index cda64455eb525..f395cf35a54d1 100644 --- a/providers/google/tests/system/google/cloud/cloud_sql/example_cloud_sql_query.py +++ b/providers/google/tests/system/google/cloud/cloud_sql/example_cloud_sql_query.py @@ -108,17 +108,16 @@ def ip_configuration() -> dict[str, Any]: "enablePrivatePathForGoogleCloudServices": True, "privateNetwork": f"projects/{PROJECT_ID}/global/networks/default", } - else: - # Use connection to Cloud SQL instance via Public IP from anywhere (mask 0.0.0.0/0). - # Consider specifying your network mask - # for allowing requests only from the trusted sources, not from anywhere. - return { - "ipv4Enabled": True, - "requireSsl": False, - "authorizedNetworks": [ - {"value": "0.0.0.0/0"}, - ], - } + # Use connection to Cloud SQL instance via Public IP from anywhere (mask 0.0.0.0/0). + # Consider specifying your network mask + # for allowing requests only from the trusted sources, not from anywhere. 
+ return { + "ipv4Enabled": True, + "requireSsl": False, + "authorizedNetworks": [ + {"value": "0.0.0.0/0"}, + ], + } def cloud_sql_instance_create_body(database_provider: dict[str, Any]) -> dict[str, Any]: diff --git a/providers/google/tests/system/google/cloud/cloud_sql/example_cloud_sql_query_iam.py b/providers/google/tests/system/google/cloud/cloud_sql/example_cloud_sql_query_iam.py index c1bc5df2b2940..6d69ad2018123 100644 --- a/providers/google/tests/system/google/cloud/cloud_sql/example_cloud_sql_query_iam.py +++ b/providers/google/tests/system/google/cloud/cloud_sql/example_cloud_sql_query_iam.py @@ -92,18 +92,17 @@ def ip_configuration() -> dict[str, Any]: "enablePrivatePathForGoogleCloudServices": True, "privateNetwork": f"projects/{PROJECT_ID}/global/networks/default", } - else: - # Use connection to Cloud SQL instance via Public IP from anywhere (mask 0.0.0.0/0). - # Consider specifying your network mask - # for allowing requests only from the trusted sources, not from anywhere. - return { - "ipv4Enabled": True, - "requireSsl": True, - "sslMode": "TRUSTED_CLIENT_CERTIFICATE_REQUIRED", - "authorizedNetworks": [ - {"value": "0.0.0.0/0"}, - ], - } + # Use connection to Cloud SQL instance via Public IP from anywhere (mask 0.0.0.0/0). + # Consider specifying your network mask + # for allowing requests only from the trusted sources, not from anywhere. + return { + "ipv4Enabled": True, + "requireSsl": True, + "sslMode": "TRUSTED_CLIENT_CERTIFICATE_REQUIRED", + "authorizedNetworks": [ + {"value": "0.0.0.0/0"}, + ], + } def cloud_sql_instance_create_body(database_provider: dict[str, Any]) -> dict[str, Any]: diff --git a/providers/google/tests/system/google/cloud/cloud_sql/example_cloud_sql_query_ssl.py b/providers/google/tests/system/google/cloud/cloud_sql/example_cloud_sql_query_ssl.py index 46e5381812312..8efb38af07f8d 100644 --- a/providers/google/tests/system/google/cloud/cloud_sql/example_cloud_sql_query_ssl.py +++ b/providers/google/tests/system/google/cloud/cloud_sql/example_cloud_sql_query_ssl.py @@ -114,18 +114,17 @@ def ip_configuration() -> dict[str, Any]: "enablePrivatePathForGoogleCloudServices": True, "privateNetwork": f"projects/{PROJECT_ID}/global/networks/default", } - else: - # Use connection to Cloud SQL instance via Public IP from anywhere (mask 0.0.0.0/0). - # Consider specifying your network mask - # for allowing requests only from the trusted sources, not from anywhere. - return { - "ipv4Enabled": True, - "requireSsl": False, - "sslMode": "ENCRYPTED_ONLY", - "authorizedNetworks": [ - {"value": "0.0.0.0/0"}, - ], - } + # Use connection to Cloud SQL instance via Public IP from anywhere (mask 0.0.0.0/0). + # Consider specifying your network mask + # for allowing requests only from the trusted sources, not from anywhere. 
+ return { + "ipv4Enabled": True, + "requireSsl": False, + "sslMode": "ENCRYPTED_ONLY", + "authorizedNetworks": [ + {"value": "0.0.0.0/0"}, + ], + } def cloud_sql_instance_create_body(database_provider: dict[str, Any]) -> dict[str, Any]: diff --git a/providers/google/tests/unit/google/cloud/transfers/test_bigquery_to_bigquery.py b/providers/google/tests/unit/google/cloud/transfers/test_bigquery_to_bigquery.py index 304694126ffc1..24b90b95929c1 100644 --- a/providers/google/tests/unit/google/cloud/transfers/test_bigquery_to_bigquery.py +++ b/providers/google/tests/unit/google/cloud/transfers/test_bigquery_to_bigquery.py @@ -55,13 +55,13 @@ def split_tablename_side_effect(*args, **kwargs): TEST_DATASET, TEST_TABLE_ID, ) - elif kwargs["table_input"] == SOURCE_PROJECT_DATASET_TABLE2: + if kwargs["table_input"] == SOURCE_PROJECT_DATASET_TABLE2: return ( TEST_GCP_PROJECT_ID, TEST_DATASET, TEST_TABLE_ID + "-2", ) - elif kwargs["table_input"] == DESTINATION_PROJECT_DATASET_TABLE: + if kwargs["table_input"] == DESTINATION_PROJECT_DATASET_TABLE: return ( TEST_GCP_PROJECT_ID, TEST_DATASET + "_new", diff --git a/providers/hashicorp/src/airflow/providers/hashicorp/_internal_client/vault_client.py b/providers/hashicorp/src/airflow/providers/hashicorp/_internal_client/vault_client.py index 272ecbd37f48f..010863fd7c1ef 100644 --- a/providers/hashicorp/src/airflow/providers/hashicorp/_internal_client/vault_client.py +++ b/providers/hashicorp/src/airflow/providers/hashicorp/_internal_client/vault_client.py @@ -245,8 +245,7 @@ def _client(self) -> hvac.Client: if _client.is_authenticated(): return _client - else: - raise VaultError("Vault Authentication Error!") + raise VaultError("Vault Authentication Error!") def _auth_userpass(self, _client: hvac.Client) -> None: if self.auth_mount_point: @@ -385,8 +384,7 @@ def _parse_secret_path(self, secret_path: str) -> tuple[str, str]: if len(split_secret_path) < 2: raise InvalidPath return split_secret_path[0], split_secret_path[1] - else: - return self.mount_point, secret_path + return self.mount_point, secret_path def get_secret(self, secret_path: str, secret_version: int | None = None) -> dict | None: """ diff --git a/providers/hashicorp/src/airflow/providers/hashicorp/secrets/vault.py b/providers/hashicorp/src/airflow/providers/hashicorp/secrets/vault.py index cf71a0c2f587d..7329cb3418fb3 100644 --- a/providers/hashicorp/src/airflow/providers/hashicorp/secrets/vault.py +++ b/providers/hashicorp/src/airflow/providers/hashicorp/secrets/vault.py @@ -171,8 +171,7 @@ def _parse_path(self, secret_path: str) -> tuple[str | None, str | None]: if len(split_secret_path) < 2: return None, None return split_secret_path[0], split_secret_path[1] - else: - return "", secret_path + return "", secret_path def get_response(self, conn_id: str) -> dict | None: """ diff --git a/providers/http/src/airflow/providers/http/hooks/http.py b/providers/http/src/airflow/providers/http/hooks/http.py index dab26a1a1bd90..9828cc1301918 100644 --- a/providers/http/src/airflow/providers/http/hooks/http.py +++ b/providers/http/src/airflow/providers/http/hooks/http.py @@ -189,7 +189,7 @@ def _configure_session_from_auth(self, session: Session, connection: Connection) def _extract_auth(self, connection: Connection) -> Any | None: if connection.login: return self.auth_type(connection.login, connection.password) - elif self._auth_type: + if self._auth_type: return self.auth_type() return None diff --git a/providers/http/src/airflow/providers/http/operators/http.py 
b/providers/http/src/airflow/providers/http/operators/http.py index b4b7ce7e012c1..ee03180881b80 100644 --- a/providers/http/src/airflow/providers/http/operators/http.py +++ b/providers/http/src/airflow/providers/http/operators/http.py @@ -276,8 +276,7 @@ def execute_complete( self.paginate_async(context=context, response=response, previous_responses=paginated_responses) return self.process_response(context=context, response=response) - else: - raise AirflowException(f"Unexpected error in the operation: {event['message']}") + raise AirflowException(f"Unexpected error in the operation: {event['message']}") def paginate_async( self, context: Context, response: Response, previous_responses: None | list[Response] = None diff --git a/providers/http/src/airflow/providers/http/sensors/http.py b/providers/http/src/airflow/providers/http/sensors/http.py index 04b6704095d01..d4fb153f6ffa4 100644 --- a/providers/http/src/airflow/providers/http/sensors/http.py +++ b/providers/http/src/airflow/providers/http/sensors/http.py @@ -167,7 +167,7 @@ def poke(self, context: Context) -> bool | PokeReturnValue: def execute(self, context: Context) -> Any: if not self.deferrable or self.response_check: return super().execute(context=context) - elif not self.poke(context): + if not self.poke(context): self.defer( timeout=timedelta(seconds=self.timeout), trigger=HttpSensorTrigger( diff --git a/providers/jenkins/src/airflow/providers/jenkins/operators/jenkins_job_trigger.py b/providers/jenkins/src/airflow/providers/jenkins/operators/jenkins_job_trigger.py index 218a3a7842c8b..a4e5d6eeec3c2 100644 --- a/providers/jenkins/src/airflow/providers/jenkins/operators/jenkins_job_trigger.py +++ b/providers/jenkins/src/airflow/providers/jenkins/operators/jenkins_job_trigger.py @@ -64,10 +64,9 @@ def jenkins_request_with_headers(jenkins_server: Jenkins, req: Request) -> Jenki # Jenkins's funky authentication means its nigh impossible to distinguish errors. if e.code in [401, 403, 500]: raise JenkinsException(f"Error in request. Possibly authentication failed [{e.code}]: {e.reason}") - elif e.code == 404: + if e.code == 404: raise jenkins.NotFoundException("Requested item could not be found") - else: - raise + raise except socket.timeout as e: raise jenkins.TimeoutException(f"Error in request: {e}") except URLError as e: diff --git a/providers/jenkins/src/airflow/providers/jenkins/sensors/jenkins.py b/providers/jenkins/src/airflow/providers/jenkins/sensors/jenkins.py index 29e69b4047afc..fbf3a62e9b487 100644 --- a/providers/jenkins/src/airflow/providers/jenkins/sensors/jenkins.py +++ b/providers/jenkins/src/airflow/providers/jenkins/sensors/jenkins.py @@ -72,9 +72,8 @@ def poke(self, context: Context) -> bool: self.log.info("Build is finished, result is %s", "build_result") if build_result in self.target_states: return True - else: - message = ( - f"Build {build_number} finished with a result {build_result}, " - f"which does not meet the target state {self.target_states}." - ) - raise AirflowException(message) + message = ( + f"Build {build_number} finished with a result {build_result}, " + f"which does not meet the target state {self.target_states}." 
+ ) + raise AirflowException(message) diff --git a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/batch.py b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/batch.py index d4517fd170112..71c5ea1ca9978 100644 --- a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/batch.py +++ b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/batch.py @@ -233,8 +233,7 @@ def create_pool(self, pool: PoolAddParameter) -> None: except batch_models.BatchErrorException as err: if not err.error or err.error.code != "PoolExists": raise - else: - self.log.info("Pool %s already exists", pool.id) + self.log.info("Pool %s already exists", pool.id) def _get_latest_verified_image_vm_and_sku( self, @@ -322,8 +321,7 @@ def create_job(self, job: JobAddParameter) -> None: except batch_models.BatchErrorException as err: if not err.error or err.error.code != "JobExists": raise - else: - self.log.info("Job %s already exists", job.id) + self.log.info("Job %s already exists", job.id) def configure_task( self, @@ -366,8 +364,7 @@ def add_single_task_to_job(self, job_id: str, task: TaskAddParameter) -> None: except batch_models.BatchErrorException as err: if not err.error or err.error.code != "TaskExists": raise - else: - self.log.info("Task %s already exists", task.id) + self.log.info("Task %s already exists", task.id) def wait_for_job_tasks_to_complete(self, job_id: str, timeout: int) -> list[batch_models.CloudTask]: """ diff --git a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/data_lake.py b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/data_lake.py index 806e63466351f..3e479c530e5e1 100644 --- a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/data_lake.py +++ b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/data_lake.py @@ -239,8 +239,7 @@ def list(self, path: str) -> list: """ if "*" in path: return self.get_conn().glob(path) - else: - return self.get_conn().walk(path) + return self.get_conn().walk(path) def remove(self, path: str, recursive: bool = False, ignore_not_found: bool = True) -> None: """ diff --git a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/fileshare.py b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/fileshare.py index 37c8647954d7d..24eb45be5d68f 100644 --- a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/fileshare.py +++ b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/fileshare.py @@ -123,15 +123,14 @@ def share_service_client(self): return ShareServiceClient.from_connection_string( conn_str=self._connection_string, ) - elif self._account_url and (self._sas_token or self._account_access_key): + if self._account_url and (self._sas_token or self._account_access_key): credential = self._sas_token or self._account_access_key return ShareServiceClient(account_url=self._account_url, credential=credential) - else: - return ShareServiceClient( - account_url=self._account_url, - credential=self._get_sync_default_azure_credential(), - token_intent="backup", - ) + return ShareServiceClient( + account_url=self._account_url, + credential=self._get_sync_default_azure_credential(), + token_intent="backup", + ) @property def share_directory_client(self): @@ -142,7 +141,7 @@ def share_directory_client(self): share_name=self.share_name, directory_path=self.directory_path, ) - elif self._account_url and (self._sas_token or self._account_access_key): + if 
self._account_url and (self._sas_token or self._account_access_key): credential = self._sas_token or self._account_access_key return ShareDirectoryClient( account_url=self._account_url, @@ -150,14 +149,13 @@ def share_directory_client(self): directory_path=self.directory_path, credential=credential, ) - else: - return ShareDirectoryClient( - account_url=self._account_url, - share_name=self.share_name, - directory_path=self.directory_path, - credential=self._get_sync_default_azure_credential(), - token_intent="backup", - ) + return ShareDirectoryClient( + account_url=self._account_url, + share_name=self.share_name, + directory_path=self.directory_path, + credential=self._get_sync_default_azure_credential(), + token_intent="backup", + ) @property def share_file_client(self): @@ -168,7 +166,7 @@ def share_file_client(self): share_name=self.share_name, file_path=self.file_path, ) - elif self._account_url and (self._sas_token or self._account_access_key): + if self._account_url and (self._sas_token or self._account_access_key): credential = self._sas_token or self._account_access_key return ShareFileClient( account_url=self._account_url, @@ -176,14 +174,13 @@ def share_file_client(self): file_path=self.file_path, credential=credential, ) - else: - return ShareFileClient( - account_url=self._account_url, - share_name=self.share_name, - file_path=self.file_path, - credential=self._get_sync_default_azure_credential(), - token_intent="backup", - ) + return ShareFileClient( + account_url=self._account_url, + share_name=self.share_name, + file_path=self.file_path, + credential=self._get_sync_default_azure_credential(), + token_intent="backup", + ) def check_for_directory(self) -> bool: """Check if a directory exists on Azure File Share.""" diff --git a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/msgraph.py b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/msgraph.py index e0007570d2067..94ee15847f8ff 100644 --- a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/msgraph.py +++ b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/msgraph.py @@ -89,7 +89,7 @@ async def handle_response_async( status_code = HTTPStatus(resp.status_code) if status_code == HTTPStatus.BAD_REQUEST: raise AirflowBadRequest(message) - elif status_code == HTTPStatus.NOT_FOUND: + if status_code == HTTPStatus.NOT_FOUND: raise AirflowNotFoundException(message) raise AirflowException(message) return value diff --git a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/synapse.py b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/synapse.py index af5a77aebf07e..fe00439a0b916 100644 --- a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/synapse.py +++ b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/synapse.py @@ -347,8 +347,7 @@ def get_conn(self) -> ArtifactsClient: if self._conn is not None: return self._conn - else: - raise ValueError("Failed to create ArtifactsClient") + raise ValueError("Failed to create ArtifactsClient") @staticmethod def _create_client(credential: Credentials, endpoint: str) -> ArtifactsClient: diff --git a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/wasb.py b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/wasb.py index 1162b4edf59da..8eb009592a15f 100644 --- a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/wasb.py +++ 
b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/wasb.py @@ -193,8 +193,7 @@ def get_conn(self) -> BlobServiceClient: if sas_token: if sas_token.startswith("https"): return BlobServiceClient(account_url=sas_token, **extra) - else: - return BlobServiceClient(account_url=f"{account_url.rstrip('/')}/{sas_token}", **extra) + return BlobServiceClient(account_url=f"{account_url.rstrip('/')}/{sas_token}", **extra) # Fall back to old auth (password) or use managed identity if not provided. credential = conn.password diff --git a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/operators/adx.py b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/operators/adx.py index 05c2cd8d54eeb..700d4a963e0fe 100644 --- a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/operators/adx.py +++ b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/operators/adx.py @@ -81,5 +81,4 @@ def execute(self, context: Context) -> KustoResultTable | str: # TODO: Remove this after minimum Airflow version is 3.0 if conf.getboolean("core", "enable_xcom_pickling", fallback=False): return response.primary_results[0] - else: - return str(response.primary_results[0]) + return str(response.primary_results[0]) diff --git a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/operators/container_instances.py b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/operators/container_instances.py index ea2c96d62a5a3..089baefb96c3b 100644 --- a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/operators/container_instances.py +++ b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/operators/container_instances.py @@ -402,8 +402,7 @@ def _monitor_logging(self, resource_group: str, name: str) -> int: "(make sure that the name is unique)." 
) return 1 - else: - self.log.exception("Exception while getting container groups") + self.log.exception("Exception while getting container groups") except Exception: self.log.exception("Exception while getting container groups") diff --git a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/operators/msgraph.py b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/operators/msgraph.py index e25fbc2f46314..241b98431f0cd 100644 --- a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/operators/msgraph.py +++ b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/operators/msgraph.py @@ -254,10 +254,8 @@ def append_result( if append_result_as_list_if_absent: if isinstance(result, list): return result - else: - return [result] - else: - return result + return [result] + return result return results def pull_xcom(self, context: Context | dict[str, Any]) -> list: diff --git a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/transfers/s3_to_wasb.py b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/transfers/s3_to_wasb.py index ce7e3d7371a58..46c99381e7955 100644 --- a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/transfers/s3_to_wasb.py +++ b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/transfers/s3_to_wasb.py @@ -261,7 +261,6 @@ def _create_key(full_path: str | None, prefix: str | None, file_name: str | None """Return a file key using its components.""" if full_path: return full_path - elif prefix and file_name: + if prefix and file_name: return f"{prefix}/{file_name}" - else: - raise InvalidKeyComponents + raise InvalidKeyComponents diff --git a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/utils.py b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/utils.py index faea3b7ba4b52..0c05953702e84 100644 --- a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/utils.py +++ b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/utils.py @@ -77,8 +77,7 @@ def _get_default_azure_credential( workload_identity_tenant_id=workload_identity_tenant_id, additionally_allowed_tenants=[workload_identity_tenant_id], ) - else: - return credential_cls() + return credential_cls() get_sync_default_azure_credential: partial[DefaultAzureCredential] = partial( diff --git a/providers/mongo/src/airflow/providers/mongo/hooks/mongo.py b/providers/mongo/src/airflow/providers/mongo/hooks/mongo.py index e69bdbbc9be39..c71ce92ea355d 100644 --- a/providers/mongo/src/airflow/providers/mongo/hooks/mongo.py +++ b/providers/mongo/src/airflow/providers/mongo/hooks/mongo.py @@ -278,8 +278,7 @@ def find( if find_one: return collection.find_one(query, projection, **kwargs) - else: - return collection.find(query, projection, **kwargs) + return collection.find(query, projection, **kwargs) def insert_one( self, mongo_collection: str, doc: dict, mongo_db: str | None = None, **kwargs diff --git a/providers/mysql/src/airflow/providers/mysql/hooks/mysql.py b/providers/mysql/src/airflow/providers/mysql/hooks/mysql.py index 95eaebd0bc667..a9edb03b453dc 100644 --- a/providers/mysql/src/airflow/providers/mysql/hooks/mysql.py +++ b/providers/mysql/src/airflow/providers/mysql/hooks/mysql.py @@ -111,8 +111,7 @@ def get_autocommit(self, conn: MySQLConnectionTypes) -> bool: """ if hasattr(conn.__class__, "autocommit") and isinstance(conn.__class__.autocommit, property): return conn.autocommit - else: - return conn.get_autocommit() # type: ignore[union-attr] + return 
conn.get_autocommit() # type: ignore[union-attr] def _get_conn_config_mysql_client(self, conn: Connection) -> dict: conn_config = { diff --git a/providers/odbc/src/airflow/providers/odbc/hooks/odbc.py b/providers/odbc/src/airflow/providers/odbc/hooks/odbc.py index 7ba3d3d57f058..8a4797def53c6 100644 --- a/providers/odbc/src/airflow/providers/odbc/hooks/odbc.py +++ b/providers/odbc/src/airflow/providers/odbc/hooks/odbc.py @@ -224,6 +224,5 @@ def _make_common_data_structure(self, result: Sequence[Row] | Row) -> list[tuple field_names = [col[0] for col in result[0].cursor_description] row_object = namedtuple("Row", field_names, rename=True) # type: ignore return cast("list[tuple]", [row_object(*row) for row in result]) - else: - field_names = [col[0] for col in result.cursor_description] - return cast("tuple", namedtuple("Row", field_names, rename=True)(*result)) # type: ignore + field_names = [col[0] for col in result.cursor_description] + return cast("tuple", namedtuple("Row", field_names, rename=True)(*result)) # type: ignore diff --git a/providers/openai/src/airflow/providers/openai/hooks/openai.py b/providers/openai/src/airflow/providers/openai/hooks/openai.py index 3ad02a0ec4337..b2fd19b96478d 100644 --- a/providers/openai/src/airflow/providers/openai/hooks/openai.py +++ b/providers/openai/src/airflow/providers/openai/hooks/openai.py @@ -469,9 +469,9 @@ def wait_for_batch(self, batch_id: str, wait_seconds: float = 3, timeout: float return if batch.status == BatchStatus.FAILED: raise OpenAIBatchJobException(f"Batch failed - \n{batch_id}") - elif batch.status in (BatchStatus.CANCELLED, BatchStatus.CANCELLING): + if batch.status in (BatchStatus.CANCELLED, BatchStatus.CANCELLING): raise OpenAIBatchJobException(f"Batch failed - batch was cancelled:\n{batch_id}") - elif batch.status == BatchStatus.EXPIRED: + if batch.status == BatchStatus.EXPIRED: raise OpenAIBatchJobException( f"Batch failed - batch couldn't be completed within the hour time window :\n{batch_id}" ) diff --git a/providers/openfaas/src/airflow/providers/openfaas/hooks/openfaas.py b/providers/openfaas/src/airflow/providers/openfaas/hooks/openfaas.py index 492128946710c..9bda14775c3c5 100644 --- a/providers/openfaas/src/airflow/providers/openfaas/hooks/openfaas.py +++ b/providers/openfaas/src/airflow/providers/openfaas/hooks/openfaas.py @@ -64,8 +64,7 @@ def deploy_function(self, overwrite_function_if_exist: bool, body: dict[str, Any self.log.error("Response status %d", response.status_code) self.log.error("Failed to deploy") raise AirflowException("failed to deploy") - else: - self.log.info("Function deployed %s", self.function_name) + self.log.info("Function deployed %s", self.function_name) def invoke_async_function(self, body: dict[str, Any]) -> None: """Invoke function asynchronously.""" @@ -100,8 +99,7 @@ def update_function(self, body: dict[str, Any]) -> None: self.log.error("Response status %d", response.status_code) self.log.error("Failed to update response %s", response.content.decode("utf-8")) raise AirflowException("failed to update " + self.function_name) - else: - self.log.info("Function was updated") + self.log.info("Function was updated") def does_function_exist(self) -> bool: """Whether OpenFaaS function exists or not.""" @@ -110,6 +108,5 @@ def does_function_exist(self) -> bool: response = requests.get(url) if response.ok: return True - else: - self.log.error("Failed to find function %s", self.function_name) - return False + self.log.error("Failed to find function %s", self.function_name) + return False diff 
--git a/providers/openlineage/src/airflow/providers/openlineage/extractors/manager.py b/providers/openlineage/src/airflow/providers/openlineage/extractors/manager.py index a3744602132c1..cc899b335e8c5 100644 --- a/providers/openlineage/src/airflow/providers/openlineage/extractors/manager.py +++ b/providers/openlineage/src/airflow/providers/openlineage/extractors/manager.py @@ -298,12 +298,11 @@ def convert_to_ol_dataset(obj) -> Dataset | None: if isinstance(obj, Dataset): return obj - elif isinstance(obj, Table): + if isinstance(obj, Table): return ExtractorManager.convert_to_ol_dataset_from_table(obj) - elif isinstance(obj, File): + if isinstance(obj, File): return ExtractorManager.convert_to_ol_dataset_from_object_storage_uri(obj.url) - else: - return None + return None def validate_task_metadata(self, task_metadata) -> OperatorLineage | None: try: diff --git a/providers/openlineage/src/airflow/providers/openlineage/plugins/adapter.py b/providers/openlineage/src/airflow/providers/openlineage/plugins/adapter.py index 8f26d11e190a6..55fa5fa976961 100644 --- a/providers/openlineage/src/airflow/providers/openlineage/plugins/adapter.py +++ b/providers/openlineage/src/airflow/providers/openlineage/plugins/adapter.py @@ -109,8 +109,7 @@ def get_openlineage_config(self) -> dict | None: if openlineage_config_path: config = self._read_yaml_config(openlineage_config_path) return config - else: - self.log.debug("OpenLineage config_path configuration not found.") + self.log.debug("OpenLineage config_path configuration not found.") # Second, try to get transport config transport_config = conf.transport() diff --git a/providers/openlineage/src/airflow/providers/openlineage/utils/spark.py b/providers/openlineage/src/airflow/providers/openlineage/utils/spark.py index becb4bd76703a..1da6a2a97e6dc 100644 --- a/providers/openlineage/src/airflow/providers/openlineage/utils/spark.py +++ b/providers/openlineage/src/airflow/providers/openlineage/utils/spark.py @@ -111,7 +111,7 @@ def _format_transport(props: dict, transport: dict, name: str | None): props = _format_transport(props, http_transport, name) return props - elif transport.kind == "http": + if transport.kind == "http": return _format_transport({}, _get_transport_information(transport), None) log.info( diff --git a/providers/openlineage/src/airflow/providers/openlineage/utils/utils.py b/providers/openlineage/src/airflow/providers/openlineage/utils/utils.py index e8b0339ac3e03..072944ee05d23 100644 --- a/providers/openlineage/src/airflow/providers/openlineage/utils/utils.py +++ b/providers/openlineage/src/airflow/providers/openlineage/utils/utils.py @@ -195,10 +195,9 @@ def is_selective_lineage_enabled(obj: DAG | BaseOperator | MappedOperator | SdkB return True if isinstance(obj, DAG): return is_dag_lineage_enabled(obj) - elif isinstance(obj, (BaseOperator, MappedOperator, SdkBaseOperator)): + if isinstance(obj, (BaseOperator, MappedOperator, SdkBaseOperator)): return is_task_lineage_enabled(obj) - else: - raise TypeError("is_selective_lineage_enabled can only be used on DAG or Operator objects") + raise TypeError("is_selective_lineage_enabled can only be used on DAG or Operator objects") @provide_session @@ -714,7 +713,7 @@ class AirflowContextDeprecationWarning(UserWarning): ), ) return item - elif is_json_serializable(item) and hasattr(item, "__dict__"): + if is_json_serializable(item) and hasattr(item, "__dict__"): for dict_key, subval in item.__dict__.items(): if type(subval).__name__ == "Proxy": return "<>" @@ -730,8 +729,7 @@ class 
AirflowContextDeprecationWarning(UserWarning): ), ) return item - else: - return super()._redact(item, name, depth, max_depth) + return super()._redact(item, name, depth, max_depth) except Exception as exc: log.warning("Unable to redact %r. Error was: %s: %s", item, type(exc).__name__, exc) return item diff --git a/providers/openlineage/tests/system/openlineage/operator.py b/providers/openlineage/tests/system/openlineage/operator.py index f371cf1de25c9..740eaa044e8b5 100644 --- a/providers/openlineage/tests/system/openlineage/operator.py +++ b/providers/openlineage/tests/system/openlineage/operator.py @@ -82,7 +82,7 @@ def env_var(var: str, default: str | None = None) -> str: """ if var in os.environ: return os.environ[var] - elif default is not None: + if default is not None: return default raise ValueError(f"Env var required but not provided: '{var}'") @@ -166,7 +166,7 @@ def match(expected, result, env: Environment) -> bool: return True log.error("Rendered value %s does not equal 'true' or %s", rendered, result) return False - elif expected != result: + if expected != result: log.error("Expected value %s does not equal result %s", expected, result) return False elif expected != result: diff --git a/providers/opensearch/src/airflow/providers/opensearch/hooks/opensearch.py b/providers/opensearch/src/airflow/providers/opensearch/hooks/opensearch.py index fd19db4a411aa..55184d1b9448f 100644 --- a/providers/opensearch/src/airflow/providers/opensearch/hooks/opensearch.py +++ b/providers/opensearch/src/airflow/providers/opensearch/hooks/opensearch.py @@ -122,7 +122,7 @@ def delete(self, index_name: str, query: dict | None = None, doc_id: int | None if self.log_query: self.log.info("Deleting from %s using Query: %s", index_name, query) return self.client.delete_by_query(index=index_name, body=query) - elif doc_id is not None: + if doc_id is not None: return self.client.delete(index=index_name, id=doc_id) raise AirflowException( "To delete a document you must include one of either a query or a document id." 
diff --git a/providers/opensearch/src/airflow/providers/opensearch/log/os_task_handler.py b/providers/opensearch/src/airflow/providers/opensearch/log/os_task_handler.py index 58561acfb0628..bdb8659c858fd 100644 --- a/providers/opensearch/src/airflow/providers/opensearch/log/os_task_handler.py +++ b/providers/opensearch/src/airflow/providers/opensearch/log/os_task_handler.py @@ -98,8 +98,7 @@ def _ensure_ti(ti: TaskInstanceKey | TaskInstance, session) -> TaskInstance: if isinstance(val, TaskInstance): val.try_number = ti.try_number return val - else: - raise AirflowException(f"Could not find TaskInstance for {ti}") + raise AirflowException(f"Could not find TaskInstance for {ti}") def get_os_kwargs_from_config() -> dict[str, Any]: diff --git a/providers/papermill/src/airflow/providers/papermill/operators/papermill.py b/providers/papermill/src/airflow/providers/papermill/operators/papermill.py index 4dffe10aa5701..a624bf23d5365 100644 --- a/providers/papermill/src/airflow/providers/papermill/operators/papermill.py +++ b/providers/papermill/src/airflow/providers/papermill/operators/papermill.py @@ -148,5 +148,4 @@ def hook(self) -> KernelHook | None: """Get valid hook.""" if self.kernel_conn_id: return KernelHook(kernel_conn_id=self.kernel_conn_id) - else: - return None + return None diff --git a/providers/postgres/src/airflow/providers/postgres/hooks/postgres.py b/providers/postgres/src/airflow/providers/postgres/hooks/postgres.py index 3504e5e0074a0..b201a155bd29c 100644 --- a/providers/postgres/src/airflow/providers/postgres/hooks/postgres.py +++ b/providers/postgres/src/airflow/providers/postgres/hooks/postgres.py @@ -141,9 +141,8 @@ def _get_cursor(self, raw_cursor: str) -> CursorType: } if _cursor in cursor_types: return cursor_types[_cursor] - else: - valid_cursors = ", ".join(cursor_types.keys()) - raise ValueError(f"Invalid cursor passed {_cursor}. Valid options are: {valid_cursors}") + valid_cursors = ", ".join(cursor_types.keys()) + raise ValueError(f"Invalid cursor passed {_cursor}. 
Valid options are: {valid_cursors}") def get_conn(self) -> connection: """Establish a connection to a postgres database.""" diff --git a/providers/presto/src/airflow/providers/presto/hooks/presto.py b/providers/presto/src/airflow/providers/presto/hooks/presto.py index d012510c96d1d..b217d900c5a56 100644 --- a/providers/presto/src/airflow/providers/presto/hooks/presto.py +++ b/providers/presto/src/airflow/providers/presto/hooks/presto.py @@ -75,7 +75,7 @@ def _boolify(value): if isinstance(value, str): if value.lower() == "false": return False - elif value.lower() == "true": + if value.lower() == "true": return True return value @@ -107,7 +107,7 @@ def get_conn(self) -> Connection: auth = None if db.password and extra.get("auth") == "kerberos": raise AirflowException("Kerberos authorization doesn't support password.") - elif db.password: + if db.password: auth = prestodb.auth.BasicAuthentication(db.login, db.password) elif extra.get("auth") == "kerberos": auth = prestodb.auth.KerberosAuthentication( diff --git a/providers/sftp/src/airflow/providers/sftp/hooks/sftp.py b/providers/sftp/src/airflow/providers/sftp/hooks/sftp.py index fe2f1b59f15f5..e13fd80fffcb1 100644 --- a/providers/sftp/src/airflow/providers/sftp/hooks/sftp.py +++ b/providers/sftp/src/airflow/providers/sftp/hooks/sftp.py @@ -244,16 +244,15 @@ def create_directory(self, path: str, mode: int = 0o777) -> None: if self.isdir(path): self.log.info("%s already exists", path) return - elif self.isfile(path): + if self.isfile(path): raise AirflowException(f"{path} already exists and is a file") - else: - dirname, basename = os.path.split(path) - if dirname and not self.isdir(dirname): - self.create_directory(dirname, mode) - if basename: - self.log.info("Creating %s", path) - with self.get_managed_conn() as conn: - conn.mkdir(path, mode=mode) + dirname, basename = os.path.split(path) + if dirname and not self.isdir(dirname): + self.create_directory(dirname, mode) + if basename: + self.log.info("Creating %s", path) + with self.get_managed_conn() as conn: + conn.mkdir(path, mode=mode) def delete_directory(self, path: str, include_files: bool = False) -> None: """ diff --git a/providers/sftp/tests/unit/sftp/hooks/test_sftp.py b/providers/sftp/tests/unit/sftp/hooks/test_sftp.py index cab012921f29b..43f8208dbc640 100644 --- a/providers/sftp/tests/unit/sftp/hooks/test_sftp.py +++ b/providers/sftp/tests/unit/sftp/hooks/test_sftp.py @@ -587,22 +587,19 @@ def __init__(self): async def listdir(self, path: str): if path == "/path/does_not/exist/": raise SFTPNoSuchFile("File does not exist") - else: - return ["..", ".", "file"] + return ["..", ".", "file"] async def readdir(self, path: str): if path == "/path/does_not/exist/": raise SFTPNoSuchFile("File does not exist") - else: - return [SFTPName(".."), SFTPName("."), SFTPName("file")] + return [SFTPName(".."), SFTPName("."), SFTPName("file")] async def stat(self, path: str): if path == "/path/does_not/exist/": raise SFTPNoSuchFile("No files matching") - else: - sftp_obj = SFTPAttrs() - sftp_obj.mtime = 1667302566 - return sftp_obj + sftp_obj = SFTPAttrs() + sftp_obj.mtime = 1667302566 + return sftp_obj class MockSSHClient: diff --git a/providers/slack/src/airflow/providers/slack/transfers/sql_to_slack.py b/providers/slack/src/airflow/providers/slack/transfers/sql_to_slack.py index 82aaf0523b8b1..6cf42418640fc 100644 --- a/providers/slack/src/airflow/providers/slack/transfers/sql_to_slack.py +++ b/providers/slack/src/airflow/providers/slack/transfers/sql_to_slack.py @@ -149,9 +149,9 @@ def 
execute(self, context: Context) -> None: if df_result.empty: if self.action_on_empty_df == "skip": raise AirflowSkipException("SQL output df is empty. Skipping.") - elif self.action_on_empty_df == "error": + if self.action_on_empty_df == "error": raise ValueError("SQL output df must be non-empty. Failing.") - elif self.action_on_empty_df != "send": + if self.action_on_empty_df != "send": raise ValueError(f"Invalid `action_on_empty_df` value {self.action_on_empty_df!r}") if output_file_format == "CSV": df_result.to_csv(output_file_name, **self.df_kwargs) diff --git a/providers/slack/src/airflow/providers/slack/utils/__init__.py b/providers/slack/src/airflow/providers/slack/utils/__init__.py index 6201dafaea597..6c59b85c0531b 100644 --- a/providers/slack/src/airflow/providers/slack/utils/__init__.py +++ b/providers/slack/src/airflow/providers/slack/utils/__init__.py @@ -56,7 +56,7 @@ def get(self, field, default: Any = NOTSET): stacklevel=2, ) return self.extra[field] - elif backcompat_key in self.extra and self.extra[backcompat_key] not in (None, ""): + if backcompat_key in self.extra and self.extra[backcompat_key] not in (None, ""): # Addition validation with non-empty required for connection which created in the UI # in Airflow 2.2. In these connections always present key-value pair for all prefixed extras # even if user do not fill this fields. @@ -64,13 +64,12 @@ def get(self, field, default: Any = NOTSET): # E.g.: `{'extra__slackwebhook__proxy': '', 'extra__slackwebhook__timeout': None}` # From Airflow 2.3, using the prefix is no longer required. return self.extra[backcompat_key] - else: - if default is NOTSET: - raise KeyError( - f"Couldn't find {backcompat_key!r} or {field!r} " - f"in Connection ({self.conn_id!r}) Extra and no default value specified." - ) - return default + if default is NOTSET: + raise KeyError( + f"Couldn't find {backcompat_key!r} or {field!r} " + f"in Connection ({self.conn_id!r}) Extra and no default value specified." + ) + return default def getint(self, field, default: Any = NOTSET) -> Any: """ @@ -107,17 +106,16 @@ def parse_filename( raise ValueError(f"No file extension specified in filename {filename!r}.") if parts[-1] in supported_file_formats: return parts[-1], None - elif len(parts) == 2: + if len(parts) == 2: raise ValueError( f"Unsupported file format {parts[-1]!r}, expected one of {supported_file_formats}." ) - else: - if parts[-2] not in supported_file_formats: - raise ValueError( - f"Unsupported file format '{parts[-2]}.{parts[-1]}', " - f"expected one of {supported_file_formats} with compression extension." - ) - return parts[-2], parts[-1] + if parts[-2] not in supported_file_formats: + raise ValueError( + f"Unsupported file format '{parts[-2]}.{parts[-1]}', " + f"expected one of {supported_file_formats} with compression extension." 
+ ) + return parts[-2], parts[-1] except ValueError as ex: if fallback: return fallback, None diff --git a/providers/smtp/src/airflow/providers/smtp/hooks/smtp.py b/providers/smtp/src/airflow/providers/smtp/hooks/smtp.py index 991dea54867c9..3eb3b6ea6c372 100644 --- a/providers/smtp/src/airflow/providers/smtp/hooks/smtp.py +++ b/providers/smtp/src/airflow/providers/smtp/hooks/smtp.py @@ -331,12 +331,11 @@ def _get_email_address_list(self, addresses: str | Iterable[str]) -> list[str]: """ if isinstance(addresses, str): return self._get_email_list_from_str(addresses) - elif isinstance(addresses, collections.abc.Iterable): + if isinstance(addresses, collections.abc.Iterable): if not all(isinstance(item, str) for item in addresses): raise TypeError("The items in your iterable must be strings.") return list(addresses) - else: - raise TypeError(f"Unexpected argument type: Received '{type(addresses).__name__}'.") + raise TypeError(f"Unexpected argument type: Received '{type(addresses).__name__}'.") def _get_email_list_from_str(self, addresses: str) -> list[str]: """ diff --git a/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake.py b/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake.py index 092c8be9069a8..088c7177171e1 100644 --- a/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake.py +++ b/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake.py @@ -254,7 +254,7 @@ def _get_conn_params(self) -> dict[str, str | None]: "The private_key_file and private_key_content extra fields are mutually exclusive. " "Please remove one." ) - elif private_key_file: + if private_key_file: private_key_file_path = Path(private_key_file) if not private_key_file_path.is_file() or private_key_file_path.stat().st_size == 0: raise ValueError("The private_key_file path points to an empty or invalid file.") @@ -497,8 +497,7 @@ def run( if return_single_query_results(sql, return_last, split_statements): self.descriptions = [_last_description] return _last_result - else: - return results + return results @contextmanager def _get_cursor(self, conn: Any, return_dictionaries: bool): diff --git a/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake_sql_api.py b/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake_sql_api.py index 8770492d06ec1..762f1b10991dc 100644 --- a/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake_sql_api.py +++ b/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake_sql_api.py @@ -117,7 +117,7 @@ def get_private_key(self) -> None: "The private_key_file and private_key_content extra fields are mutually exclusive. " "Please remove one." 
) - elif private_key_file: + if private_key_file: private_key_pem = Path(private_key_file).read_bytes() elif private_key_content: private_key_pem = private_key_content.encode() @@ -289,9 +289,9 @@ def _process_response(self, status_code, resp): self.log.info("Snowflake SQL GET statements status API response: %s", resp) if status_code == 202: return {"status": "running", "message": "Query statements are still running"} - elif status_code == 422: + if status_code == 422: return {"status": "error", "message": resp["message"]} - elif status_code == 200: + if status_code == 200: if resp_statement_handles := resp.get("statementHandles"): statement_handles = resp_statement_handles elif resp_statement_handle := resp.get("statementHandle"): @@ -303,8 +303,7 @@ def _process_response(self, status_code, resp): "message": resp["message"], "statement_handles": statement_handles, } - else: - return {"status": "error", "message": resp["message"]} + return {"status": "error", "message": resp["message"]} def get_sql_api_query_status(self, query_id: str) -> dict[str, str | list[str]]: """ diff --git a/providers/snowflake/src/airflow/providers/snowflake/operators/snowflake.py b/providers/snowflake/src/airflow/providers/snowflake/operators/snowflake.py index b4eb71079f205..1018871654030 100644 --- a/providers/snowflake/src/airflow/providers/snowflake/operators/snowflake.py +++ b/providers/snowflake/src/airflow/providers/snowflake/operators/snowflake.py @@ -433,7 +433,7 @@ def execute(self, context: Context) -> None: statement_status = self._hook.get_sql_api_query_status(query_id) if statement_status.get("status") == "running": break - elif statement_status.get("status") == "success": + if statement_status.get("status") == "success": succeeded_query_ids.append(query_id) else: raise AirflowException(f"{statement_status.get('status')}: {statement_status.get('message')}") @@ -503,7 +503,7 @@ def execute_complete(self, context: Context, event: dict[str, str | list[str]] | if "status" in event and event["status"] == "error": msg = f"{event['status']}: {event['message']}" raise AirflowException(msg) - elif "status" in event and event["status"] == "success": + if "status" in event and event["status"] == "success": hook = SnowflakeSqlApiHook(snowflake_conn_id=self.snowflake_conn_id) query_ids = cast("list[str]", event["statement_query_ids"]) hook.check_query_output(query_ids) diff --git a/providers/snowflake/src/airflow/providers/snowflake/utils/snowpark.py b/providers/snowflake/src/airflow/providers/snowflake/utils/snowpark.py index a6617bb92029f..72b1192ad46e2 100644 --- a/providers/snowflake/src/airflow/providers/snowflake/utils/snowpark.py +++ b/providers/snowflake/src/airflow/providers/snowflake/utils/snowpark.py @@ -40,5 +40,4 @@ def inject_session_into_op_kwargs( signature = inspect.signature(python_callable) if "session" in signature.parameters: return {**op_kwargs, "session": session} - else: - return op_kwargs + return op_kwargs diff --git a/providers/standard/src/airflow/providers/standard/operators/bash.py b/providers/standard/src/airflow/providers/standard/operators/bash.py index 3be7b58d16bc4..84db1e081e7f2 100644 --- a/providers/standard/src/airflow/providers/standard/operators/bash.py +++ b/providers/standard/src/airflow/providers/standard/operators/bash.py @@ -229,7 +229,7 @@ def execute(self, context: Context): if result.exit_code in self.skip_on_exit_code: raise AirflowSkipException(f"Bash command returned exit code {result.exit_code}. 
Skipping.") - elif result.exit_code != 0: + if result.exit_code != 0: raise AirflowException( f"Bash command failed. The command returned a non-zero exit code {result.exit_code}." ) diff --git a/providers/standard/src/airflow/providers/standard/operators/latest_only.py b/providers/standard/src/airflow/providers/standard/operators/latest_only.py index d7f4c636ebfff..fd5cf62bbad2f 100644 --- a/providers/standard/src/airflow/providers/standard/operators/latest_only.py +++ b/providers/standard/src/airflow/providers/standard/operators/latest_only.py @@ -93,9 +93,8 @@ def choose_branch(self, context: Context) -> str | Iterable[str]: # we return an empty list, thus the parent BaseBranchOperator # won't exclude any downstream tasks from skipping. return [] - else: - self.log.info("Latest, allowing execution to proceed.") - return list(context["task"].get_direct_relative_ids(upstream=False)) + self.log.info("Latest, allowing execution to proceed.") + return list(context["task"].get_direct_relative_ids(upstream=False)) def _get_next_run_info(self, context: Context, dag_run: DagRun) -> DagRunInfo | None: dag: DAG = context["dag"] # type: ignore[assignment] diff --git a/providers/standard/src/airflow/providers/standard/operators/python.py b/providers/standard/src/airflow/providers/standard/operators/python.py index c4d0d373e58d2..1f6759ef8b89f 100644 --- a/providers/standard/src/airflow/providers/standard/operators/python.py +++ b/providers/standard/src/airflow/providers/standard/operators/python.py @@ -573,13 +573,12 @@ def _execute_python_callable_in_subprocess(self, python_path: Path): except subprocess.CalledProcessError as e: if e.returncode in self.skip_on_exit_code: raise AirflowSkipException(f"Process exited with code {e.returncode}. Skipping.") - elif termination_log_path.exists() and termination_log_path.stat().st_size > 0: + if termination_log_path.exists() and termination_log_path.stat().st_size > 0: error_msg = f"Process returned non-zero exit status {e.returncode}.\n" with open(termination_log_path) as file: error_msg += file.read() raise AirflowException(error_msg) from None - else: - raise + raise if 0 in self.skip_on_exit_code: raise AirflowSkipException("Process exited with code 0. 
Skipping.") @@ -590,8 +589,7 @@ def determine_kwargs(self, context: Mapping[str, Any]) -> Mapping[str, Any]: keyword_params = KeywordParameters.determine(self.python_callable, self.op_args, context) if AIRFLOW_V_3_0_PLUS: return keyword_params.unpacking() - else: - return keyword_params.serializing() # type: ignore[attr-defined] + return keyword_params.serializing() # type: ignore[attr-defined] class PythonVirtualenvOperator(_BasePythonVirtualenvOperator): @@ -1130,8 +1128,7 @@ def my_task(): from airflow.sdk import get_current_context return get_current_context() - else: - return _get_current_context() + return _get_current_context() def _get_current_context() -> Mapping[str, Any]: diff --git a/providers/standard/src/airflow/providers/standard/operators/trigger_dagrun.py b/providers/standard/src/airflow/providers/standard/operators/trigger_dagrun.py index 2057ee178fc2b..f20e9c97c8ae1 100644 --- a/providers/standard/src/airflow/providers/standard/operators/trigger_dagrun.py +++ b/providers/standard/src/airflow/providers/standard/operators/trigger_dagrun.py @@ -96,11 +96,10 @@ def get_link(self, operator: BaseOperator, *, ti_key: TaskInstanceKey) -> str: from airflow.utils.helpers import build_airflow_dagrun_url return build_airflow_dagrun_url(dag_id=trigger_dag_id, run_id=triggered_dag_run_id) - else: - from airflow.utils.helpers import build_airflow_url_with_query # type:ignore[attr-defined] + from airflow.utils.helpers import build_airflow_url_with_query # type:ignore[attr-defined] - query = {"dag_id": trigger_dag_id, "dag_run_id": triggered_dag_run_id} - return build_airflow_url_with_query(query) + query = {"dag_id": trigger_dag_id, "dag_run_id": triggered_dag_run_id} + return build_airflow_url_with_query(query) class TriggerDagRunOperator(BaseOperator): diff --git a/providers/standard/src/airflow/providers/standard/sensors/bash.py b/providers/standard/src/airflow/providers/standard/sensors/bash.py index 023983e7b90dc..64a3220e48502 100644 --- a/providers/standard/src/airflow/providers/standard/sensors/bash.py +++ b/providers/standard/src/airflow/providers/standard/sensors/bash.py @@ -107,14 +107,12 @@ def poke(self, context: Context): return True # we have a retry exit code, sensor retries if return code matches, otherwise error - elif self.retry_exit_code is not None: + if self.retry_exit_code is not None: if resp.returncode == self.retry_exit_code: self.log.info("Return code matches retry code, will retry later") return False - else: - raise AirflowFailException(f"Command exited with return code {resp.returncode}") + raise AirflowFailException(f"Command exited with return code {resp.returncode}") # backwards compatibility: sensor retries no matter the error code - else: - self.log.info("Non-zero return code and no retry code set, will retry later") - return False + self.log.info("Non-zero return code and no retry code set, will retry later") + return False diff --git a/providers/standard/src/airflow/providers/standard/sensors/external_task.py b/providers/standard/src/airflow/providers/standard/sensors/external_task.py index 012abb272a010..e7fdcf7850f9c 100644 --- a/providers/standard/src/airflow/providers/standard/sensors/external_task.py +++ b/providers/standard/src/airflow/providers/standard/sensors/external_task.py @@ -85,11 +85,10 @@ def get_link(self, operator: BaseOperator, *, ti_key: TaskInstanceKey) -> str: from airflow.utils.helpers import build_airflow_dagrun_url return build_airflow_dagrun_url(dag_id=external_dag_id, run_id=ti_key.run_id) - else: - from 
airflow.utils.helpers import build_airflow_url_with_query # type:ignore[attr-defined] + from airflow.utils.helpers import build_airflow_url_with_query # type:ignore[attr-defined] - query = {"dag_id": external_dag_id, "run_id": ti_key.run_id} - return build_airflow_url_with_query(query) + query = {"dag_id": external_dag_id, "run_id": ti_key.run_id} + return build_airflow_url_with_query(query) class ExternalTaskSensor(BaseSensorOperator): @@ -299,8 +298,7 @@ def poke(self, context: Context) -> bool: if AIRFLOW_V_3_0_PLUS: return self._poke_af3(context, dttm_filter) - else: - return self._poke_af2(dttm_filter) + return self._poke_af2(dttm_filter) def _poke_af3(self, context: Context, dttm_filter: list[datetime.datetime]) -> bool: from airflow.providers.standard.utils.sensor_helper import _get_count_by_matched_states @@ -316,19 +314,18 @@ def _get_count(states: list[str]) -> int: logical_dates=dttm_filter, states=states, ) - elif self.external_task_group_id: + if self.external_task_group_id: run_id_task_state_map = ti.get_task_states( dag_id=self.external_dag_id, task_group_id=self.external_task_group_id, logical_dates=dttm_filter, ) return _get_count_by_matched_states(run_id_task_state_map, states) - else: - return ti.get_dr_count( - dag_id=self.external_dag_id, - logical_dates=dttm_filter, - states=states, - ) + return ti.get_dr_count( + dag_id=self.external_dag_id, + logical_dates=dttm_filter, + states=states, + ) if self.failed_states: count = _get_count(self.failed_states) @@ -348,8 +345,7 @@ def _calculate_count(self, count: int, dttm_filter: list[datetime.datetime]) -> """Calculate the normalized count based on the type of check.""" if self.external_task_ids: return count / len(self.external_task_ids) - else: - return count + return count def _handle_failed_states(self, count_failed: float | int) -> None: """Handle failed states and raise appropriate exceptions.""" @@ -364,7 +360,7 @@ def _handle_failed_states(self, count_failed: float | int) -> None: f"Some of the external tasks {self.external_task_ids} " f"in DAG {self.external_dag_id} failed." ) - elif self.external_task_group_id: + if self.external_task_group_id: if self.soft_fail: raise AirflowSkipException( f"The external task_group '{self.external_task_group_id}' " @@ -374,12 +370,11 @@ def _handle_failed_states(self, count_failed: float | int) -> None: f"The external task_group '{self.external_task_group_id}' " f"in DAG '{self.external_dag_id}' failed." ) - else: - if self.soft_fail: - raise AirflowSkipException( - f"The external DAG {self.external_dag_id} failed. Skipping due to soft_fail." - ) - raise AirflowException(f"The external DAG {self.external_dag_id} failed.") + if self.soft_fail: + raise AirflowSkipException( + f"The external DAG {self.external_dag_id} failed. Skipping due to soft_fail." + ) + raise AirflowException(f"The external DAG {self.external_dag_id} failed.") def _handle_skipped_states(self, count_skipped: float | int) -> None: """Handle skipped states and raise appropriate exceptions.""" @@ -389,16 +384,15 @@ def _handle_skipped_states(self, count_skipped: float | int) -> None: f"Some of the external tasks {self.external_task_ids} " f"in DAG {self.external_dag_id} reached a state in our states-to-skip-on list. Skipping." ) - elif self.external_task_group_id: + if self.external_task_group_id: raise AirflowSkipException( f"The external task_group '{self.external_task_group_id}' " f"in DAG {self.external_dag_id} reached a state in our states-to-skip-on list. Skipping." 
) - else: - raise AirflowSkipException( - f"The external DAG {self.external_dag_id} reached a state in our states-to-skip-on list. " - "Skipping." - ) + raise AirflowSkipException( + f"The external DAG {self.external_dag_id} reached a state in our states-to-skip-on list. " + "Skipping." + ) @provide_session def _poke_af2(self, dttm_filter: list[datetime.datetime], session: Session = NEW_SESSION) -> bool: @@ -450,11 +444,10 @@ def execute_complete(self, context, event=None): else: if self.soft_fail: raise AirflowSkipException("External job has failed skipping.") - else: - raise AirflowException( - "Error occurred while trying to retrieve task status. Please, check the " - "name of executed task and Dag." - ) + raise AirflowException( + "Error occurred while trying to retrieve task status. Please, check the " + "name of executed task and Dag." + ) def _check_for_existence(self, session) -> None: dag_to_wait = DagModel.get_current(self.external_dag_id, session) diff --git a/providers/standard/src/airflow/providers/standard/sensors/python.py b/providers/standard/src/airflow/providers/standard/sensors/python.py index 28f293135fdb3..33d59b903ce83 100644 --- a/providers/standard/src/airflow/providers/standard/sensors/python.py +++ b/providers/standard/src/airflow/providers/standard/sensors/python.py @@ -81,5 +81,4 @@ def poke(self, context: Context) -> PokeReturnValue | bool: return_value = self.python_callable(*self.op_args, **self.op_kwargs) if isinstance(return_value, PokeReturnValue): return return_value - else: - return PokeReturnValue(bool(return_value)) + return PokeReturnValue(bool(return_value)) diff --git a/providers/standard/src/airflow/providers/standard/sensors/weekday.py b/providers/standard/src/airflow/providers/standard/sensors/weekday.py index 705805eb5a0b9..5b2fce987095e 100644 --- a/providers/standard/src/airflow/providers/standard/sensors/weekday.py +++ b/providers/standard/src/airflow/providers/standard/sensors/weekday.py @@ -119,5 +119,4 @@ def poke(self, context: Context) -> bool: ) return determined_weekday_num in self._week_day_num - else: - return timezone.utcnow().isoweekday() in self._week_day_num + return timezone.utcnow().isoweekday() in self._week_day_num diff --git a/providers/standard/src/airflow/providers/standard/triggers/external_task.py b/providers/standard/src/airflow/providers/standard/triggers/external_task.py index d3db7027fefb9..03062ae40f1f0 100644 --- a/providers/standard/src/airflow/providers/standard/triggers/external_task.py +++ b/providers/standard/src/airflow/providers/standard/triggers/external_task.py @@ -165,8 +165,7 @@ async def _get_count_af_3(self, states): ) if self.external_task_ids: return count / len(self.external_task_ids) - else: - return count + return count @sync_to_async def _get_count(self, states: typing.Iterable[str] | None) -> int: diff --git a/providers/standard/src/airflow/providers/standard/triggers/temporal.py b/providers/standard/src/airflow/providers/standard/triggers/temporal.py index 12834509b5126..48ebb223cb6ca 100644 --- a/providers/standard/src/airflow/providers/standard/triggers/temporal.py +++ b/providers/standard/src/airflow/providers/standard/triggers/temporal.py @@ -51,10 +51,9 @@ def __init__(self, moment: datetime.datetime, *, end_from_trigger: bool = False) if not isinstance(moment, datetime.datetime): raise TypeError(f"Expected datetime.datetime type for moment. 
Got {type(moment)}") # Make sure it's in UTC - elif moment.tzinfo is None: + if moment.tzinfo is None: raise ValueError("You cannot pass naive datetimes") - else: - self.moment: pendulum.DateTime = timezone.convert_to_utc(moment) + self.moment: pendulum.DateTime = timezone.convert_to_utc(moment) if not AIRFLOW_V_2_10_PLUS and end_from_trigger: raise AirflowException("end_from_trigger is only supported in Airflow 2.10 and later. ") diff --git a/providers/standard/src/airflow/providers/standard/utils/python_virtualenv.py b/providers/standard/src/airflow/providers/standard/utils/python_virtualenv.py index 9d03e43367a49..e182df55b5d3b 100644 --- a/providers/standard/src/airflow/providers/standard/utils/python_virtualenv.py +++ b/providers/standard/src/airflow/providers/standard/utils/python_virtualenv.py @@ -49,7 +49,7 @@ def _use_uv() -> bool: venv_install_method = conf.get("standard", "venv_install_method", fallback="auto").lower() if venv_install_method == "auto": return _is_uv_installed() - elif venv_install_method == "uv": + if venv_install_method == "uv": return True return False diff --git a/providers/standard/tests/unit/standard/decorators/test_external_python.py b/providers/standard/tests/unit/standard/decorators/test_external_python.py index cd8121c4b43c6..ef2007048f577 100644 --- a/providers/standard/tests/unit/standard/decorators/test_external_python.py +++ b/providers/standard/tests/unit/standard/decorators/test_external_python.py @@ -162,8 +162,7 @@ def test_with_args(self, serializer, dag_maker, venv_python_with_cloudpickle_and def f(a, b, c=False, d=False): if a == 0 and b == 1 and c and not d: return True - else: - raise Exception + raise Exception with dag_maker(serialized=True): ret = f(0, 1, c=True) diff --git a/providers/standard/tests/unit/standard/decorators/test_python_virtualenv.py b/providers/standard/tests/unit/standard/decorators/test_python_virtualenv.py index c5ad3f3b85d8f..caf438a04d1e0 100644 --- a/providers/standard/tests/unit/standard/decorators/test_python_virtualenv.py +++ b/providers/standard/tests/unit/standard/decorators/test_python_virtualenv.py @@ -284,8 +284,7 @@ def test_with_args(self, serializer, extra_requirements, dag_maker): def f(a, b, c=False, d=False): if a == 0 and b == 1 and c and not d: return True - else: - raise Exception + raise Exception with dag_maker(serialized=True): ret = f(0, 1, c=True) diff --git a/providers/standard/tests/unit/standard/operators/test_python.py b/providers/standard/tests/unit/standard/operators/test_python.py index 5d0c99dc354bb..ef84c0dda3d08 100644 --- a/providers/standard/tests/unit/standard/operators/test_python.py +++ b/providers/standard/tests/unit/standard/operators/test_python.py @@ -866,8 +866,7 @@ def test_with_args(self): def f(a, b, c=False, d=False): if a == 0 and b == 1 and c and not d: return True - else: - raise RuntimeError + raise RuntimeError self.run_as_task(f, op_args=[0, 1], op_kwargs={"c": True}) @@ -1528,8 +1527,7 @@ def test_with_args(self): def f(a, b, c=False, d=False): if a == 0 and b == 1 and c and not d: return True - else: - raise RuntimeError + raise RuntimeError with pytest.raises( AirflowException, match=r"Invalid tasks found: {\(False, 'bool'\)}.|'branch_task_ids'.*task.*" diff --git a/providers/teradata/src/airflow/providers/teradata/operators/teradata_compute_cluster.py b/providers/teradata/src/airflow/providers/teradata/operators/teradata_compute_cluster.py index 3257465f3446a..81e4fc22b34cc 100644 --- 
a/providers/teradata/src/airflow/providers/teradata/operators/teradata_compute_cluster.py +++ b/providers/teradata/src/airflow/providers/teradata/operators/teradata_compute_cluster.py @@ -157,7 +157,7 @@ def _compute_cluster_execute(self): def _compute_cluster_execute_complete(self, event: dict[str, Any]) -> None: if event["status"] == "success": return event["message"] - elif event["status"] == "error": + if event["status"] == "error": raise AirflowException(event["message"]) def _handle_cc_status(self, operation_type, sql): @@ -185,8 +185,7 @@ def _hook_run(self, query, handler=None): try: if handler is not None: return self.hook.run(query, handler=handler) - else: - return self.hook.run(query) + return self.hook.run(query) except Exception as ex: self.log.error(str(ex)) raise @@ -305,13 +304,12 @@ def _compute_cluster_execute(self): msg = f"Compute Profile {self.compute_profile_name} is already exists under Compute Group {self.compute_group_name}. Status is {cp_status_result}" self.log.info(msg) return cp_status_result - else: - create_cp_query = self._build_ccp_setup_query() - operation = Constants.CC_CREATE_OPR - initially_suspended = self._get_initially_suspended(create_cp_query) - if initially_suspended == "TRUE": - operation = Constants.CC_CREATE_SUSPEND_OPR - return self._handle_cc_status(operation, create_cp_query) + create_cp_query = self._build_ccp_setup_query() + operation = Constants.CC_CREATE_OPR + initially_suspended = self._get_initially_suspended(create_cp_query) + if initially_suspended == "TRUE": + operation = Constants.CC_CREATE_SUSPEND_OPR + return self._handle_cc_status(operation, create_cp_query) class TeradataComputeClusterDecommissionOperator(_TeradataComputeClusterOperator): @@ -444,10 +442,9 @@ def _compute_cluster_execute(self): if self.compute_group_name: cp_resume_query = f"{cp_resume_query} IN COMPUTE GROUP {self.compute_group_name}" return self._handle_cc_status(Constants.CC_RESUME_OPR, cp_resume_query) - else: - self.log.info( - "Compute Cluster %s already %s", self.compute_profile_name, Constants.CC_RESUME_DB_STATUS - ) + self.log.info( + "Compute Cluster %s already %s", self.compute_profile_name, Constants.CC_RESUME_DB_STATUS + ) class TeradataComputeClusterSuspendOperator(_TeradataComputeClusterOperator): @@ -516,7 +513,6 @@ def _compute_cluster_execute(self): if self.compute_group_name: sql = f"{sql} IN COMPUTE GROUP {self.compute_group_name}" return self._handle_cc_status(Constants.CC_SUSPEND_OPR, sql) - else: - self.log.info( - "Compute Cluster %s already %s", self.compute_profile_name, Constants.CC_SUSPEND_DB_STATUS - ) + self.log.info( + "Compute Cluster %s already %s", self.compute_profile_name, Constants.CC_SUSPEND_DB_STATUS + ) diff --git a/providers/trino/src/airflow/providers/trino/hooks/trino.py b/providers/trino/src/airflow/providers/trino/hooks/trino.py index 728d993472fb1..f907e95f4eec3 100644 --- a/providers/trino/src/airflow/providers/trino/hooks/trino.py +++ b/providers/trino/src/airflow/providers/trino/hooks/trino.py @@ -77,7 +77,7 @@ def _boolify(value): if isinstance(value, str): if value.lower() == "false": return False - elif value.lower() == "true": + if value.lower() == "true": return True return value @@ -146,7 +146,7 @@ def get_conn(self) -> Connection: user = db.login if db.password and extra.get("auth") in ("kerberos", "certs"): raise AirflowException(f"The {extra.get('auth')!r} authorization type doesn't support password.") - elif db.password: + if db.password: auth = trino.auth.BasicAuthentication(db.login, db.password) # 
type: ignore[attr-defined] elif extra.get("auth") == "jwt": if not exactly_one(jwt_file := "jwt__file" in extra, jwt_token := "jwt__token" in extra): @@ -159,7 +159,7 @@ def get_conn(self) -> Connection: else: msg += "none of them provided." raise ValueError(msg) - elif jwt_file: + if jwt_file: token = Path(extra["jwt__file"]).read_text() else: token = extra["jwt__token"] diff --git a/providers/weaviate/src/airflow/providers/weaviate/hooks/weaviate.py b/providers/weaviate/src/airflow/providers/weaviate/hooks/weaviate.py index f70c934845180..7d5fa129492d1 100644 --- a/providers/weaviate/src/airflow/providers/weaviate/hooks/weaviate.py +++ b/providers/weaviate/src/airflow/providers/weaviate/hooks/weaviate.py @@ -571,8 +571,7 @@ def _generate_uuids( raise ValueError( "Property 'id' already in dataset. Consider renaming or specify 'uuid_column'." ) - else: - uuid_column = "id" + uuid_column = "id" if uuid_column in column_names: raise ValueError( @@ -847,7 +846,7 @@ def create_or_replace_document_objects( f"Documents {', '.join(changed_documents)} already exists. You can either skip or replace" f" them by passing 'existing=skip' or 'existing=replace' respectively." ) - elif existing == "skip": + if existing == "skip": data = data[data[document_column].isin(new_documents)] if verbose: self.log.info( diff --git a/pyproject.toml b/pyproject.toml index a364e603db05c..3b6502cc22af8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -588,6 +588,10 @@ extend-select = [ "B019", # Use of functools.lru_cache or functools.cache on methods can lead to memory leaks "B028", # No explicit stacklevel keyword argument found "TRY002", # Prohibit use of `raise Exception`, use specific exceptions instead. + "RET505", # Unnecessary {branch} after return statement + "RET506", # Unnecessary {branch} after raise statement + "RET507", # Unnecessary {branch} after continue statement + "RET508", # Unnecessary {branch} after break statement ] ignore = [ "D100", # Unwanted; Docstring at the top of every file. 
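The four RET rules enabled in pyproject.toml above drive every hunk in this change: once a branch ends in return, raise, continue, or break, the elif/else that follows it is redundant and its body can be dedented. A minimal before/after sketch of the pattern these rules flag (hypothetical function, not taken from any file in this diff):

def classify(code: int) -> str:
    if code == 0:
        return "ok"
    elif code < 0:      # RET505: unnecessary elif after return
        raise ValueError(code)
    else:               # RET506: unnecessary else after raise
        return "error"

def classify_flat(code: int) -> str:
    # Behavior-identical flat form, matching the style adopted across this diff.
    if code == 0:
        return "ok"
    if code < 0:
        raise ValueError(code)
    return "error"

Both functions return and raise for exactly the same inputs; the rewrite only removes nesting.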
diff --git a/scripts/ci/pre_commit/check_deprecations.py b/scripts/ci/pre_commit/check_deprecations.py index fad1009463ed5..c1fb6354c9a91 100755 --- a/scripts/ci/pre_commit/check_deprecations.py +++ b/scripts/ci/pre_commit/check_deprecations.py @@ -91,8 +91,7 @@ def allowed_group_warnings(group: str) -> tuple[str, tuple[str, ...]]: group_warnings = allowed_warnings[group] if len(group_warnings) == 1: return f"expected {group_warnings[0]} type", group_warnings - else: - return f"expected one of {', '.join(group_warnings)} types", group_warnings + return f"expected one of {', '.join(group_warnings)} types", group_warnings def built_import_from(import_from: ast.ImportFrom) -> list[str]: @@ -151,7 +150,7 @@ def resolve_name(obj: ast.Attribute | ast.Name) -> str: if isinstance(obj, ast.Name): name = f"{obj.id}.{name}" if name else obj.id break - elif isinstance(obj, ast.Attribute): + if isinstance(obj, ast.Attribute): name = f"{obj.attr}.{name}" if name else obj.attr obj = obj.value # type: ignore[assignment] else: @@ -189,7 +188,7 @@ def check_decorators(mod: ast.Module, file: str, file_group: str) -> int: f"{expected_types}" ) continue - elif not hasattr(category_keyword, "value"): + if not hasattr(category_keyword, "value"): continue category_value_ast = category_keyword.value diff --git a/scripts/ci/pre_commit/check_integrations_list.py b/scripts/ci/pre_commit/check_integrations_list.py index a525063806fef..9aea158fa124f 100755 --- a/scripts/ci/pre_commit/check_integrations_list.py +++ b/scripts/ci/pre_commit/check_integrations_list.py @@ -110,10 +110,9 @@ def _list_matcher(j): """Filter callable to exclude header and empty cells.""" if len(j) == 0: return False - elif j in ["Description", "Identifier"]: + if j in ["Description", "Identifier"]: return False - else: - return True + return True table_cells = list(filter(_list_matcher, table_cells)) return table_cells diff --git a/scripts/ci/pre_commit/checkout_no_credentials.py b/scripts/ci/pre_commit/checkout_no_credentials.py index 02a720eda6d3a..bb0c906b603b4 100755 --- a/scripts/ci/pre_commit/checkout_no_credentials.py +++ b/scripts/ci/pre_commit/checkout_no_credentials.py @@ -72,14 +72,12 @@ def check_file(the_file: Path) -> int: ) error_num += 1 continue - else: - if persist_credentials: - console.print( - "\n[red]The `with` clause have persist-credentials=True in step:[/]" - f"\n\n{pretty_step}" - ) - error_num += 1 - continue + if persist_credentials: + console.print( + f"\n[red]The `with` clause have persist-credentials=True in step:[/]\n\n{pretty_step}" + ) + error_num += 1 + continue return error_num diff --git a/scripts/ci/pre_commit/generate_pypi_readme.py b/scripts/ci/pre_commit/generate_pypi_readme.py index e263e798ff6c3..c3699b65e0eda 100755 --- a/scripts/ci/pre_commit/generate_pypi_readme.py +++ b/scripts/ci/pre_commit/generate_pypi_readme.py @@ -53,8 +53,7 @@ def extract_section(content, section_name): ) if section_match: return section_match.group(1) - else: - raise RuntimeError(f"Cannot find section {section_name} in README.md") + raise RuntimeError(f"Cannot find section {section_name} in README.md") if __name__ == "__main__": diff --git a/scripts/ci/pre_commit/update_example_dags_paths.py b/scripts/ci/pre_commit/update_example_dags_paths.py index 3a3a6a9807b8d..9ef30fdb5fb22 100755 --- a/scripts/ci/pre_commit/update_example_dags_paths.py +++ b/scripts/ci/pre_commit/update_example_dags_paths.py @@ -72,11 +72,10 @@ def replace_match(file: Path, line: str) -> str | None: if proper_system_tests_url in file.read_text(): 
console.print(f"[yellow] Removing from {file}[/]\n{line.strip()}") return None - else: - new_line = re.sub(EXAMPLE_DAGS_URL_MATCHER, r"\1" + proper_system_tests_url + r"\5", line) - if new_line != line: - console.print(f"[yellow] Replacing in {file}[/]\n{line.strip()}\n{new_line.strip()}") - return new_line + new_line = re.sub(EXAMPLE_DAGS_URL_MATCHER, r"\1" + proper_system_tests_url + r"\5", line) + if new_line != line: + console.print(f"[yellow] Replacing in {file}[/]\n{line.strip()}\n{new_line.strip()}") + return new_line return line diff --git a/scripts/ci/pre_commit/update_installers_and_pre_commit.py b/scripts/ci/pre_commit/update_installers_and_pre_commit.py index 014962b511bac..1f4b07f77e59b 100755 --- a/scripts/ci/pre_commit/update_installers_and_pre_commit.py +++ b/scripts/ci/pre_commit/update_installers_and_pre_commit.py @@ -155,9 +155,9 @@ class Quoting(Enum): def get_replacement(value: str, quoting: Quoting) -> str: if quoting == Quoting.DOUBLE_QUOTED: return f'"{value}"' - elif quoting == Quoting.SINGLE_QUOTED: + if quoting == Quoting.SINGLE_QUOTED: return f"'{value}'" - elif quoting == Quoting.REVERSE_SINGLE_QUOTED: + if quoting == Quoting.REVERSE_SINGLE_QUOTED: return f"`{value}`" return value diff --git a/scripts/ci/testing/summarize_captured_warnings.py b/scripts/ci/testing/summarize_captured_warnings.py index 1286359c3b8b5..f4843b43381fa 100755 --- a/scripts/ci/testing/summarize_captured_warnings.py +++ b/scripts/ci/testing/summarize_captured_warnings.py @@ -149,7 +149,7 @@ def merge_files(files: Iterator[tuple[Path, str]], output_directory: Path) -> Pa record = json.loads(line) if not isinstance(record, dict): raise TypeError - elif not all(field in record for field in REQUIRED_FIELDS): + if not all(field in record for field in REQUIRED_FIELDS): raise ValueError except Exception: bad_records += 1 diff --git a/scripts/in_container/verify_providers.py b/scripts/in_container/verify_providers.py index de2a297ae6fbd..bfcff18f0f39d 100755 --- a/scripts/in_container/verify_providers.py +++ b/scripts/in_container/verify_providers.py @@ -352,8 +352,7 @@ def strip_package_from_class(base_package: str, class_name: str) -> str: """Strips base package name from the class (if it starts with the package name).""" if class_name.startswith(base_package): return class_name[len(base_package) + 1 :] - else: - return class_name + return class_name def convert_class_name_to_url(base_url: str, class_name) -> str: diff --git a/task-sdk/src/airflow/sdk/api/client.py b/task-sdk/src/airflow/sdk/api/client.py index e966a1e5f7a76..7e0f4e657ae93 100644 --- a/task-sdk/src/airflow/sdk/api/client.py +++ b/task-sdk/src/airflow/sdk/api/client.py @@ -504,8 +504,7 @@ def trigger( log.info("DAG Run already exists!", detail=e.detail, dag_id=dag_id, run_id=run_id) return ErrorResponse(error=ErrorType.DAGRUN_ALREADY_EXISTS) - else: - raise + raise return OKResponse(ok=True) diff --git a/task-sdk/src/airflow/sdk/bases/decorator.py b/task-sdk/src/airflow/sdk/bases/decorator.py index 7dfc1a3e00834..6780fe29fe27d 100644 --- a/task-sdk/src/airflow/sdk/bases/decorator.py +++ b/task-sdk/src/airflow/sdk/bases/decorator.py @@ -101,7 +101,7 @@ def _validate_arg_names(self, func: ValidationSource, kwargs: dict[str, Any]) -> ) if len(kwargs_left) == 1: raise TypeError(f"{func}() got an unexpected keyword argument {next(iter(kwargs_left))!r}") - elif kwargs_left: + if kwargs_left: names = ", ".join(repr(n) for n in kwargs_left) raise TypeError(f"{func}() got unexpected keyword arguments {names}") @@ -382,7 +382,7 @@ def 
_validate_arg_names(self, func: ValidationSource, kwargs: dict[str, Any]): if len(context_keys_being_mapped) == 1: (name,) = context_keys_being_mapped raise ValueError(f"cannot call {func}() on task context variable {name!r}") - elif context_keys_being_mapped: + if context_keys_being_mapped: names = ", ".join(repr(n) for n in context_keys_being_mapped) raise ValueError(f"cannot call {func}() on task context variables {names}") @@ -674,7 +674,7 @@ def task_decorator_factory( kwargs=kwargs, ) return cast("TaskDecorator", decorator) - elif python_callable is not None: + if python_callable is not None: raise TypeError("No args allowed while using @task, use kwargs instead") def decorator_factory(python_callable): diff --git a/task-sdk/src/airflow/sdk/bases/operator.py b/task-sdk/src/airflow/sdk/bases/operator.py index 2a141f0f2ac27..09ca8c6c03e5c 100644 --- a/task-sdk/src/airflow/sdk/bases/operator.py +++ b/task-sdk/src/airflow/sdk/bases/operator.py @@ -158,7 +158,7 @@ def get_merged_defaults( def parse_retries(retries: Any) -> int | None: if retries is None: return 0 - elif type(retries) == int: # noqa: E721 + if type(retries) == int: # noqa: E721 return retries try: parsed_retries = int(retries) @@ -494,7 +494,7 @@ def apply_defaults(self: BaseOperator, *args: Any, **kwargs: Any) -> Any: missing_args = non_optional_args.difference(kwargs) if len(missing_args) == 1: raise TypeError(f"missing keyword argument {missing_args.pop()!r}") - elif missing_args: + if missing_args: display = ", ".join(repr(a) for a in sorted(missing_args)) raise TypeError(f"missing keyword arguments {display}") @@ -1307,8 +1307,7 @@ def dag(self) -> DAG: """Returns the Operator's DAG if set, otherwise raises an error.""" if dag := self._dag: return dag - else: - raise RuntimeError(f"Operator {self} has not been assigned to a DAG yet") + raise RuntimeError(f"Operator {self} has not been assigned to a DAG yet") @dag.setter def dag(self, dag: DAG | None) -> None: @@ -1324,7 +1323,7 @@ def _convert__dag(self, dag: DAG | None) -> DAG | None: if not isinstance(dag, DAG): raise TypeError(f"Expected DAG; received {dag.__class__.__name__}") - elif self._dag is not None and self._dag is not dag: + if self._dag is not None and self._dag is not dag: raise ValueError(f"The DAG assigned to {self} can not be changed.") if self.__from_mapped: @@ -1337,7 +1336,7 @@ def _convert__dag(self, dag: DAG | None) -> DAG | None: def _convert_retries(retries: Any) -> int | None: if retries is None: return 0 - elif type(retries) == int: # noqa: E721 + if type(retries) == int: # noqa: E721 return retries try: parsed_retries = int(retries) @@ -1615,8 +1614,7 @@ def resume_execution(self, next_method: str, next_kwargs: dict[str, Any] | None, self.log.error("Trigger failed:\n%s", "\n".join(traceback)) if (error := next_kwargs.get("error", "Unknown")) == TriggerFailureReason.TRIGGER_TIMEOUT: raise TaskDeferralTimeout(error) - else: - raise TaskDeferralError(error) + raise TaskDeferralError(error) # Grab the callable off the Operator/Task and add in any kwargs execute_callable = getattr(self, next_method) return execute_callable(context, **next_kwargs) diff --git a/task-sdk/src/airflow/sdk/bases/sensor.py b/task-sdk/src/airflow/sdk/bases/sensor.py index 3cbbac9a4926b..ca2b80cb97eb6 100644 --- a/task-sdk/src/airflow/sdk/bases/sensor.py +++ b/task-sdk/src/airflow/sdk/bases/sensor.py @@ -212,7 +212,7 @@ def run_duration() -> float: ) as e: if self.soft_fail: raise AirflowSkipException("Skipping due to soft_fail is set to True.") from e - elif 
self.never_fail: + if self.never_fail: raise AirflowSkipException("Skipping due to never_fail is set to True.") from e raise e except AirflowSkipException as e: @@ -240,15 +240,13 @@ def run_duration() -> float: if self.soft_fail: raise AirflowSkipException(message) - else: - raise AirflowSensorTimeout(message) + raise AirflowSensorTimeout(message) if self.reschedule: next_poke_interval = self._get_next_poke_interval(started_at, run_duration, poke_count) reschedule_date = timezone.utcnow() + timedelta(seconds=next_poke_interval) raise AirflowRescheduleException(reschedule_date) - else: - time.sleep(self._get_next_poke_interval(started_at, run_duration, poke_count)) - poke_count += 1 + time.sleep(self._get_next_poke_interval(started_at, run_duration, poke_count)) + poke_count += 1 self.log.info("Success criteria met. Exiting.") return xcom_value diff --git a/task-sdk/src/airflow/sdk/definitions/_internal/node.py b/task-sdk/src/airflow/sdk/definitions/_internal/node.py index b9f73ea132c65..21fa4ede5b1c9 100644 --- a/task-sdk/src/airflow/sdk/definitions/_internal/node.py +++ b/task-sdk/src/airflow/sdk/definitions/_internal/node.py @@ -180,7 +180,7 @@ def _set_relatives( if len(dags) > 1: raise RuntimeError(f"Tried to set relationships between tasks in more than one DAG: {dags}") - elif len(dags) == 1: + if len(dags) == 1: dag = dags.pop() else: raise ValueError( @@ -241,15 +241,13 @@ def get_direct_relative_ids(self, upstream: bool = False) -> set[str]: """Get set of the direct relative ids to the current task, upstream or downstream.""" if upstream: return self.upstream_task_ids - else: - return self.downstream_task_ids + return self.downstream_task_ids def get_direct_relatives(self, upstream: bool = False) -> Iterable[Operator]: """Get list of the direct relatives to the current task, upstream or downstream.""" if upstream: return self.upstream_list - else: - return self.downstream_list + return self.downstream_list def serialize_for_task_group(self) -> tuple[DagAttributeTypes, Any]: """Serialize a task group's content; used by TaskGroupSerialization.""" diff --git a/task-sdk/src/airflow/sdk/definitions/_internal/templater.py b/task-sdk/src/airflow/sdk/definitions/_internal/templater.py index d141b0328b015..024c11ec12306 100644 --- a/task-sdk/src/airflow/sdk/definitions/_internal/templater.py +++ b/task-sdk/src/airflow/sdk/definitions/_internal/templater.py @@ -184,13 +184,13 @@ def render_template( # Fast path for common built-in collections. if value.__class__ is tuple: return tuple(self.render_template(element, context, jinja_env, oids) for element in value) - elif isinstance(value, tuple): # Special case for named tuples. + if isinstance(value, tuple): # Special case for named tuples. return value.__class__(*(self.render_template(el, context, jinja_env, oids) for el in value)) - elif isinstance(value, list): + if isinstance(value, list): return [self.render_template(element, context, jinja_env, oids) for element in value] - elif isinstance(value, dict): + if isinstance(value, dict): return {k: self.render_template(v, context, jinja_env, oids) for k, v in value.items()} - elif isinstance(value, set): + if isinstance(value, set): return {self.render_template(element, context, jinja_env, oids) for element in value} # More complex collections. 
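The render_template hunk above is typical of the isinstance-dispatch chains touched by this change: because every branch ends in a return, replacing elif with plain if leaves control flow identical. A small standalone sketch of the flattened shape (hypothetical names, not the actual Templater code):

from collections.abc import Mapping, Sequence

def describe(value):
    # Each branch returns, so successive `if` statements behave exactly like the old elif chain.
    if isinstance(value, str):
        return "string"
    if isinstance(value, Mapping):
        return "mapping"
    if isinstance(value, Sequence):
        return "sequence"
    return type(value).__name__

Checking str before Sequence matters here, since strings are sequences too; the flattening itself does not change that ordering requirement.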
diff --git a/task-sdk/src/airflow/sdk/definitions/asset/__init__.py b/task-sdk/src/airflow/sdk/definitions/asset/__init__.py index f156df5895365..c81732cf40414 100644 --- a/task-sdk/src/airflow/sdk/definitions/asset/__init__.py +++ b/task-sdk/src/airflow/sdk/definitions/asset/__init__.py @@ -356,7 +356,7 @@ def __init__( ) -> None: if name is None and uri is None: raise TypeError("Asset() requires either 'name' or 'uri'") - elif name is None: + if name is None: name = str(uri) elif uri is None: uri = name diff --git a/task-sdk/src/airflow/sdk/definitions/dag.py b/task-sdk/src/airflow/sdk/definitions/dag.py index 3adc2a98a0fb2..e9915d74cac4d 100644 --- a/task-sdk/src/airflow/sdk/definitions/dag.py +++ b/task-sdk/src/airflow/sdk/definitions/dag.py @@ -172,8 +172,7 @@ def _convert_tags(tags: Collection[str] | None) -> MutableSet[str]: def _convert_access_control(value, self_: DAG): if hasattr(self_, "_upgrade_outdated_dag_access_control"): return self_._upgrade_outdated_dag_access_control(value) - else: - return value + return value def _convert_doc_md(doc_md: str | None) -> str | None: @@ -512,16 +511,15 @@ def _default_timetable(instance: DAG): # delattr(self, "schedule") if isinstance(schedule, Timetable): return schedule - elif isinstance(schedule, BaseAsset): + if isinstance(schedule, BaseAsset): return AssetTriggeredTimetable(schedule) - elif isinstance(schedule, Collection) and not isinstance(schedule, str): + if isinstance(schedule, Collection) and not isinstance(schedule, str): if not all(isinstance(x, BaseAsset) for x in schedule): raise ValueError( "All elements in 'schedule' should be either assets, asset references, or asset aliases" ) return AssetTriggeredTimetable(AssetAll(*schedule)) - else: - return _create_timetable(schedule, instance.timezone) + return _create_timetable(schedule, instance.timezone) @timezone.default def _extract_tz(instance): @@ -944,12 +942,11 @@ def add_task(self, task: Operator) -> None: task_id in self.task_dict and self.task_dict[task_id] is not task ) or task_id in self.task_group.used_group_ids: raise DuplicateTaskIdFound(f"Task id '{task_id}' has already been added to the DAG") - else: - self.task_dict[task_id] = task - # TODO: Task-SDK: this type ignore shouldn't be needed! - task.dag = self # type: ignore[assignment] - # Add task_id to used_group_ids to prevent group_id and task_id collisions. - self.task_group.used_group_ids.add(task_id) + self.task_dict[task_id] = task + # TODO: Task-SDK: this type ignore shouldn't be needed! + task.dag = self # type: ignore[assignment] + # Add task_id to used_group_ids to prevent group_id and task_id collisions. 
+ self.task_group.used_group_ids.add(task_id) FailFastDagInvalidTriggerRule.check(fail_fast=self.fail_fast, trigger_rule=task.trigger_rule) @@ -991,8 +988,7 @@ def get_edge_info(self, upstream_task_id: str, downstream_task_id: str) -> EdgeI empty = cast("EdgeInfoType", {}) if self.edge_info: return self.edge_info.get(upstream_task_id, {}).get(downstream_task_id, empty) - else: - return empty + return empty def set_edge_info(self, upstream_task_id: str, downstream_task_id: str, info: EdgeInfoType): """ diff --git a/task-sdk/src/airflow/sdk/definitions/taskgroup.py b/task-sdk/src/airflow/sdk/definitions/taskgroup.py index 8c39b50e4a92f..3363424dee68e 100644 --- a/task-sdk/src/airflow/sdk/definitions/taskgroup.py +++ b/task-sdk/src/airflow/sdk/definitions/taskgroup.py @@ -413,9 +413,9 @@ def has_non_teardown_downstream(task, exclude: str): for down_task in task.downstream_list: if down_task.task_id == exclude: continue - elif down_task.task_id not in ids: + if down_task.task_id not in ids: continue - elif not down_task.is_teardown: + if not down_task.is_teardown: return True return False diff --git a/task-sdk/src/airflow/sdk/definitions/xcom_arg.py b/task-sdk/src/airflow/sdk/definitions/xcom_arg.py index d5411b9dd8e46..ab3a5a82b22c1 100644 --- a/task-sdk/src/airflow/sdk/definitions/xcom_arg.py +++ b/task-sdk/src/airflow/sdk/definitions/xcom_arg.py @@ -516,7 +516,7 @@ def __getitem__(self, index: Any) -> Any: for value in self.values: if i < 0: break - elif i >= (curlen := len(value)): + if i >= (curlen := len(value)): i -= curlen elif isinstance(value, Sequence): return value[i] diff --git a/task-sdk/src/airflow/sdk/execution_time/context.py b/task-sdk/src/airflow/sdk/execution_time/context.py index 9cad56e4176c8..9a4f02854e55f 100644 --- a/task-sdk/src/airflow/sdk/execution_time/context.py +++ b/task-sdk/src/airflow/sdk/execution_time/context.py @@ -470,9 +470,9 @@ class _AssetEventAccessorsMixin(Generic[T]): def for_asset(self, *, name: str | None = None, uri: str | None = None) -> T: if name and uri: return self[Asset(name=name, uri=uri)] - elif name: + if name: return self[Asset.ref(name=name)] - elif uri: + if uri: return self[Asset.ref(uri=uri)] raise ValueError("name and uri cannot both be None") diff --git a/task-sdk/src/airflow/sdk/execution_time/lazy_sequence.py b/task-sdk/src/airflow/sdk/execution_time/lazy_sequence.py index acbca29c0b954..095ab051fb1c4 100644 --- a/task-sdk/src/airflow/sdk/execution_time/lazy_sequence.py +++ b/task-sdk/src/airflow/sdk/execution_time/lazy_sequence.py @@ -97,7 +97,7 @@ def __len__(self) -> int: msg = SUPERVISOR_COMMS.get_message() if isinstance(msg, ErrorResponse): raise RuntimeError(msg) - elif not isinstance(msg, XComCountResponse): + if not isinstance(msg, XComCountResponse): raise TypeError(f"Got unexpected response to GetXComCount: {msg}") self._len = msg.len return self._len @@ -112,11 +112,10 @@ def __getitem__(self, key: int | slice) -> T | Sequence[T]: if isinstance(key, int): if key >= 0: return self._get_item(key) - else: - # val[-1] etc. - return self._get_item(len(self) + key) + # val[-1] etc. + return self._get_item(len(self) + key) - elif isinstance(key, slice): + if isinstance(key, slice): # This implements the slicing syntax. We want to optimize negative slicing (e.g. seq[-10:]) by not # doing an additional COUNT query (via HEAD http request) if possible. We can do this unless the # start and stop have different signs (i.e. one is positive and another negative). 
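In lazy_sequence.py the negative-index branch of __getitem__ now returns directly rather than sitting under an else. The index arithmetic itself is standard Python sequence semantics; a hedged sketch with a hypothetical helper (not the SDK's _get_item):

def normalize_index(key: int, length: int) -> int:
    # Non-negative keys are used as-is; negative keys count from the end, like list indexing.
    if key >= 0:
        return key
    return length + key

assert normalize_index(2, 5) == 2
assert normalize_index(-1, 5) == 4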
diff --git a/task-sdk/src/airflow/sdk/execution_time/secrets_masker.py b/task-sdk/src/airflow/sdk/execution_time/secrets_masker.py index 6d2b42080c320..8daca414a7d86 100644 --- a/task-sdk/src/airflow/sdk/execution_time/secrets_masker.py +++ b/task-sdk/src/airflow/sdk/execution_time/secrets_masker.py @@ -241,13 +241,12 @@ def _redact_all(self, item: Redactable, depth: int, max_depth: int = MAX_RECURSI return { dict_key: self._redact_all(subval, depth + 1, max_depth) for dict_key, subval in item.items() } - elif isinstance(item, (tuple, set)): + if isinstance(item, (tuple, set)): # Turn set in to tuple! return tuple(self._redact_all(subval, depth + 1, max_depth) for subval in item) - elif isinstance(item, list): + if isinstance(item, list): return list(self._redact_all(subval, depth + 1, max_depth) for subval in item) - else: - return item + return item def _redact(self, item: Redactable, name: str | None, depth: int, max_depth: int) -> Redacted: # Avoid spending too much effort on redacting on deeply nested @@ -264,33 +263,32 @@ def _redact(self, item: Redactable, name: str | None, depth: int, max_depth: int for dict_key, subval in item.items() } return to_return - elif isinstance(item, Enum): + if isinstance(item, Enum): return self._redact(item=item.value, name=name, depth=depth, max_depth=max_depth) - elif _is_v1_env_var(item): + if _is_v1_env_var(item): tmp: dict = item.to_dict() if should_hide_value_for_key(tmp.get("name", "")) and "value" in tmp: tmp["value"] = "***" else: return self._redact(item=tmp, name=name, depth=depth, max_depth=max_depth) return tmp - elif isinstance(item, str): + if isinstance(item, str): if self.replacer: # We can't replace specific values, but the key-based redacting # can still happen, so we can't short-circuit, we need to walk # the structure. return self.replacer.sub("***", str(item)) return item - elif isinstance(item, (tuple, set)): + if isinstance(item, (tuple, set)): # Turn set in to tuple! return tuple( self._redact(subval, name=None, depth=(depth + 1), max_depth=max_depth) for subval in item ) - elif isinstance(item, list): + if isinstance(item, list): return [ self._redact(subval, name=None, depth=(depth + 1), max_depth=max_depth) for subval in item ] - else: - return item + return item # I think this should never happen, but it does not hurt to leave it just in case # Well. It happened (see https://github.com/apache/airflow/issues/19816#issuecomment-983311373) # but it caused infinite recursion, to avoid this we mark the log as already filtered. 
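The secrets_masker.py hunks flatten the same recursion over container types in _redact_all and _redact. A minimal, hypothetical recursive walker showing the flattened shape (not Airflow's SecretsMasker; the mask rule here is deliberately simplified to masking every non-empty string):

MAX_DEPTH = 5

def redact(item, depth: int = 0):
    # Depth guard first, then one `if ... return` per supported container type.
    if depth > MAX_DEPTH:
        return item
    if isinstance(item, dict):
        return {k: redact(v, depth + 1) for k, v in item.items()}
    if isinstance(item, (tuple, set)):
        # Sets come back as tuples, mirroring the "Turn set in to tuple!" comment in the real code.
        return tuple(redact(v, depth + 1) for v in item)
    if isinstance(item, list):
        return [redact(v, depth + 1) for v in item]
    if isinstance(item, str):
        return "***" if item else item
    return item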
diff --git a/task-sdk/src/airflow/sdk/execution_time/supervisor.py b/task-sdk/src/airflow/sdk/execution_time/supervisor.py index 819549ef02872..165e27c0ed79c 100644 --- a/task-sdk/src/airflow/sdk/execution_time/supervisor.py +++ b/task-sdk/src/airflow/sdk/execution_time/supervisor.py @@ -915,7 +915,7 @@ def final_state(self): """ if self._exit_code == 0: return self._terminal_state or TerminalTIState.SUCCESS - elif self._exit_code != 0 and self._terminal_state == SERVER_TERMINATED: + if self._exit_code != 0 and self._terminal_state == SERVER_TERMINATED: return SERVER_TERMINATED return TerminalTIState.FAILED diff --git a/task-sdk/src/airflow/sdk/execution_time/task_runner.py b/task-sdk/src/airflow/sdk/execution_time/task_runner.py index 344746c63b09a..66d45e36f7f1e 100644 --- a/task-sdk/src/airflow/sdk/execution_time/task_runner.py +++ b/task-sdk/src/airflow/sdk/execution_time/task_runner.py @@ -954,7 +954,7 @@ def _handle_trigger_dag_run( method_name="execute_complete", ) return _defer_task(defer, ti, log) - elif drte.wait_for_completion: + if drte.wait_for_completion: while True: log.info( "Waiting for dag run to complete execution in allowed state.", diff --git a/task-sdk/src/airflow/sdk/io/path.py b/task-sdk/src/airflow/sdk/io/path.py index 8c519d0ee27df..5a3517f6527b8 100644 --- a/task-sdk/src/airflow/sdk/io/path.py +++ b/task-sdk/src/airflow/sdk/io/path.py @@ -135,8 +135,7 @@ def container(self) -> str: def bucket(self) -> str: if self._url: return self._url.netloc - else: - return "" + return "" @property def key(self) -> str: @@ -144,8 +143,7 @@ def key(self) -> str: # per convention, we strip the leading slashes to ensure a relative key is returned # we keep the trailing slash to allow for directory-like semantics return self._url.path.lstrip(self.sep) - else: - return "" + return "" @property def namespace(self) -> str: @@ -199,15 +197,13 @@ def replace(self, target) -> ObjectStoragePath: def cwd(cls): if cls is ObjectStoragePath: return get_upath_class("").cwd() - else: - raise NotImplementedError + raise NotImplementedError @classmethod def home(cls): if cls is ObjectStoragePath: return get_upath_class("").home() - else: - raise NotImplementedError + raise NotImplementedError # EXTENDED OPERATIONS diff --git a/task-sdk/src/airflow/sdk/io/store.py b/task-sdk/src/airflow/sdk/io/store.py index 79b3e70b42e04..68c7ad9fbf868 100644 --- a/task-sdk/src/airflow/sdk/io/store.py +++ b/task-sdk/src/airflow/sdk/io/store.py @@ -144,7 +144,7 @@ def attach( if alias: if store := _STORE_CACHE.get(alias): return store - elif not protocol: + if not protocol: raise ValueError(f"No registered store with alias: {alias}") if not protocol: diff --git a/task-sdk/src/airflow/sdk/log.py b/task-sdk/src/airflow/sdk/log.py index 02203ad303a68..ed22360068ee1 100644 --- a/task-sdk/src/airflow/sdk/log.py +++ b/task-sdk/src/airflow/sdk/log.py @@ -26,7 +26,7 @@ import warnings from functools import cache from pathlib import Path -from typing import TYPE_CHECKING, Any, BinaryIO, Callable, Generic, TextIO, TypeVar +from typing import TYPE_CHECKING, Any, BinaryIO, Callable, Generic, TextIO, TypeVar, cast import msgspec import structlog @@ -195,55 +195,54 @@ def logging_processors(enable_pretty_log: bool, mask_secrets: bool = True): "timestamper": timestamper, "console": console, } + dict_exc_formatter = structlog.tracebacks.ExceptionDictTransformer( + use_rich=False, show_locals=False, suppress=suppress + ) + + dict_tracebacks = structlog.processors.ExceptionRenderer(dict_exc_formatter) + if hasattr(__builtins__, 
"BaseExceptionGroup"): + exc_group_processor = exception_group_tracebacks(dict_exc_formatter) + processors.append(exc_group_processor) else: - dict_exc_formatter = structlog.tracebacks.ExceptionDictTransformer( - use_rich=False, show_locals=False, suppress=suppress - ) + exc_group_processor = None + + def json_dumps(msg, default): + # Note: this is likely an "expensive" step, but lets massage the dict order for nice + # viewing of the raw JSON logs. + # Maybe we don't need this once the UI renders the JSON instead of displaying the raw text + msg = { + "timestamp": msg.pop("timestamp"), + "level": msg.pop("level"), + "event": msg.pop("event"), + **msg, + } + return msgspec.json.encode(msg, enc_hook=default) - dict_tracebacks = structlog.processors.ExceptionRenderer(dict_exc_formatter) - if hasattr(__builtins__, "BaseExceptionGroup"): - exc_group_processor = exception_group_tracebacks(dict_exc_formatter) - processors.append(exc_group_processor) - else: - exc_group_processor = None - - def json_dumps(msg, default): - # Note: this is likely an "expensive" step, but lets massage the dict order for nice - # viewing of the raw JSON logs. - # Maybe we don't need this once the UI renders the JSON instead of displaying the raw text - msg = { - "timestamp": msg.pop("timestamp"), - "level": msg.pop("level"), - "event": msg.pop("event"), - **msg, - } - return msgspec.json.encode(msg, enc_hook=default) - - def json_processor(logger: Any, method_name: Any, event_dict: EventDict) -> str: - # Stdlib logging doesn't need the re-ordering, it's fine as it is - return msgspec.json.encode(event_dict).decode("utf-8") - - json = structlog.processors.JSONRenderer(serializer=json_dumps) - - processors.extend( - ( - dict_tracebacks, - structlog.processors.UnicodeDecoder(), - ), - ) + def json_processor(logger: Any, method_name: Any, event_dict: EventDict) -> str: + # Stdlib logging doesn't need the re-ordering, it's fine as it is + return msgspec.json.encode(event_dict).decode("utf-8") - # Include the remote logging provider for tasks if there are any we need (such as upload to Cloudwatch) - if (remote := load_remote_log_handler()) and (remote_processors := getattr(remote, "processors")): - processors.extend(remote_processors) + json = structlog.processors.JSONRenderer(serializer=json_dumps) - processors.append(json) + processors.extend( + ( + dict_tracebacks, + structlog.processors.UnicodeDecoder(), + ), + ) - return processors, { - "timestamper": timestamper, - "exc_group_processor": exc_group_processor, - "dict_tracebacks": dict_tracebacks, - "json": json_processor, - } + # Include the remote logging provider for tasks if there are any we need (such as upload to Cloudwatch) + if (remote := load_remote_log_handler()) and (remote_processors := getattr(remote, "processors")): + processors.extend(remote_processors) + + processors.append(json) + + return processors, { + "timestamper": timestamper, + "exc_group_processor": exc_group_processor, + "dict_tracebacks": dict_tracebacks, + "json": json_processor, + } @cache @@ -320,8 +319,7 @@ def configure_logging( raise ValueError( f"output needed to be a binary stream, but it didn't have a buffer attribute ({output=})" ) - else: - output = output.buffer + output = cast("TextIO", output).buffer if TYPE_CHECKING: # Not all binary streams are isinstance of BinaryIO, so we check via looking at `mode` at # runtime. 
mypy doesn't grok that though diff --git a/task-sdk/tests/task_sdk/api/test_client.py b/task-sdk/tests/task_sdk/api/test_client.py index 333ba0ecc1758..1eacfe1e632b7 100644 --- a/task-sdk/tests/task_sdk/api/test_client.py +++ b/task-sdk/tests/task_sdk/api/test_client.py @@ -978,7 +978,7 @@ def handle_request(request: httpx.Request) -> httpx.Response: } }, ) - elif request.url.path == "/dag-runs/test_trigger_conflict_reset/test_run_id/clear": + if request.url.path == "/dag-runs/test_trigger_conflict_reset/test_run_id/clear": return httpx.Response(status_code=204) return httpx.Response(status_code=422) diff --git a/task-sdk/tests/task_sdk/definitions/decorators/test_setup_teardown.py b/task-sdk/tests/task_sdk/definitions/decorators/test_setup_teardown.py index f4d21c98487da..fbadaa84ac4c6 100644 --- a/task-sdk/tests/task_sdk/definitions/decorators/test_setup_teardown.py +++ b/task-sdk/tests/task_sdk/definitions/decorators/test_setup_teardown.py @@ -28,31 +28,30 @@ def make_task(name, type_, setup_=False, teardown_=False): if type_ == "classic" and setup_: return BashOperator(task_id=name, bash_command="echo 1").as_setup() - elif type_ == "classic" and teardown_: + if type_ == "classic" and teardown_: return BashOperator(task_id=name, bash_command="echo 1").as_teardown() - elif type_ == "classic": + if type_ == "classic": return BashOperator(task_id=name, bash_command="echo 1") - elif setup_: + if setup_: @setup def setuptask(): pass return setuptask.override(task_id=name)() - elif teardown_: + if teardown_: @teardown def teardowntask(): pass return teardowntask.override(task_id=name)() - else: - @task - def my_task(): - pass + @task + def my_task(): + pass - return my_task.override(task_id=name)() + return my_task.override(task_id=name)() class TestSetupTearDownTask: diff --git a/task-sdk/tests/task_sdk/definitions/test_asset.py b/task-sdk/tests/task_sdk/definitions/test_asset.py index 5b7c922cde1e2..8328b061811ce 100644 --- a/task-sdk/tests/task_sdk/definitions/test_asset.py +++ b/task-sdk/tests/task_sdk/definitions/test_asset.py @@ -242,7 +242,7 @@ def assets_equal(a1: BaseAsset, a2: BaseAsset) -> bool: if isinstance(a1, Asset) and isinstance(a2, Asset): return a1.uri == a2.uri - elif isinstance(a1, (AssetAny, AssetAll)) and isinstance(a2, (AssetAny, AssetAll)): + if isinstance(a1, (AssetAny, AssetAll)) and isinstance(a2, (AssetAny, AssetAll)): if len(a1.objects) != len(a2.objects): return False diff --git a/task-sdk/tests/task_sdk/definitions/test_mappedoperator.py b/task-sdk/tests/task_sdk/definitions/test_mappedoperator.py index bdaae8f0be0e7..5373202ab9f17 100644 --- a/task-sdk/tests/task_sdk/definitions/test_mappedoperator.py +++ b/task-sdk/tests/task_sdk/definitions/test_mappedoperator.py @@ -603,7 +603,7 @@ def xcom_get(): if key in expected_values: value = expected_values[key] return XComResult(key="return_value", value=value) - elif last_request.map_index is None: + if last_request.map_index is None: # Get all mapped XComValues for this ti value = [v for k, v in expected_values.items() if k[0] == last_request.task_id] return XComResult(key="return_value", value=value) diff --git a/task-sdk/tests/task_sdk/definitions/test_mixins.py b/task-sdk/tests/task_sdk/definitions/test_mixins.py index b5e5e4bfe26f5..80d4dadbb3d4d 100644 --- a/task-sdk/tests/task_sdk/definitions/test_mixins.py +++ b/task-sdk/tests/task_sdk/definitions/test_mixins.py @@ -41,31 +41,30 @@ def get_task_attr(task_like, attr): def make_task(name, type_, setup_=False, teardown_=False): if type_ == "classic" and 
setup_: return BaseOperator(task_id=name).as_setup() - elif type_ == "classic" and teardown_: + if type_ == "classic" and teardown_: return BaseOperator(task_id=name).as_teardown() - elif type_ == "classic": + if type_ == "classic": return BaseOperator(task_id=name) - elif setup_: + if setup_: @setup def setuptask(): pass return setuptask.override(task_id=name)() - elif teardown_: + if teardown_: @teardown def teardowntask(): pass return teardowntask.override(task_id=name)() - else: - @task - def my_task(): - pass + @task + def my_task(): + pass - return my_task.override(task_id=name)() + return my_task.override(task_id=name)() @pytest.mark.parametrize( diff --git a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py index cc97891c9b9e8..41b2b770996d5 100644 --- a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py +++ b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py @@ -502,19 +502,18 @@ def handle_request(request: httpx.Request) -> httpx.Response: if request_count["count"] == 1: # First request succeeds return httpx.Response(status_code=204) - else: - # Second request returns a conflict status code - return httpx.Response( - 409, - json={ - "reason": "not_running", - "message": "TI is no longer in the 'running' state. Task state might be externally set and task should terminate", - "current_state": "success", - }, - ) - elif request.url.path == f"/task-instances/{ti_id}/run": + # Second request returns a conflict status code + return httpx.Response( + 409, + json={ + "reason": "not_running", + "message": "TI is no longer in the 'running' state. Task state might be externally set and task should terminate", + "current_state": "success", + }, + ) + if request.url.path == f"/task-instances/{ti_id}/run": return httpx.Response(200, json=make_ti_context_dict()) - elif request.url.path == f"/task-instances/{ti_id}/state": + if request.url.path == f"/task-instances/{ti_id}/state": pytest.fail("Should not have sent a state update request") # Return a 204 for all other requests return httpx.Response(status_code=204)
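RET507 and RET508 apply the same rewrite inside loops, after continue and break; the check_deprecations.py and checkout_no_credentials.py hunks above are instances of the continue case. A hypothetical before/after sketch:

def first_positive_flagged(values):
    for v in values:
        if v <= 0:
            continue
        else:          # RET507: unnecessary else after continue
            return v
    return None

def first_positive(values):
    # Flattened form: the branch after `continue` is simply dedented.
    for v in values:
        if v <= 0:
            continue
        return v
    return None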