From 1b1e76ce75b91d8bdef0e0bc404e7a53f3d7bbda Mon Sep 17 00:00:00 2001 From: Wei Lee Date: Thu, 7 Nov 2024 14:45:36 +0800 Subject: [PATCH 01/28] feat(task_sdk): Move asset from Airflow core to task_sdk --- airflow/__init__.py | 4 +- .../api_connexion/endpoints/asset_endpoint.py | 2 +- .../core_api/routes/public/assets.py | 2 +- .../endpoints/rpc_api_endpoint.py | 2 +- airflow/assets/__init__.py | 508 ----------------- airflow/assets/manager.py | 4 +- airflow/assets/metadata.py | 4 +- airflow/dag_processing/collection.py | 2 +- airflow/datasets/__init__.py | 4 +- airflow/decorators/base.py | 2 +- airflow/example_dags/example_asset_alias.py | 2 +- .../example_asset_alias_with_no_taskflow.py | 2 +- airflow/example_dags/example_assets.py | 2 +- .../example_dags/example_inlet_event_extra.py | 2 +- .../example_outlet_event_extra.py | 4 +- airflow/lineage/hook.py | 2 +- airflow/listeners/spec/asset.py | 2 +- airflow/models/asset.py | 2 +- airflow/models/dag.py | 2 +- airflow/models/taskinstance.py | 2 +- airflow/providers_manager.py | 4 +- airflow/serialization/serialized_objects.py | 19 +- airflow/timetables/assets.py | 4 +- airflow/timetables/base.py | 4 +- airflow/timetables/simple.py | 4 +- airflow/utils/context.py | 6 +- airflow/utils/context.pyi | 2 +- airflow/www/views.py | 2 +- .../tests/test_pytest_args_for_test_types.py | 1 + .../authoring-and-scheduling/datasets.rst | 16 +- .../common/compat/assets/__init__.py | 8 +- .../providers/common/io/assets/file.py | 2 +- .../providers/openlineage/utils/utils.py | 2 +- .../microsoft/azure/example_msfabric.py | 2 +- .../check_tests_in_right_folders.py | 1 + task_sdk/src/airflow/sdk/definitions/asset.py | 532 ++++++++++++++++++ task_sdk/src/airflow/sdk/definitions/dag.py | 4 +- .../tests/defintions}/test_asset.py | 74 +-- .../endpoints/test_dag_run_endpoint.py | 2 +- .../schemas/test_asset_schema.py | 2 +- .../api_connexion/schemas/test_dag_schema.py | 2 +- .../core_api/routes/public/test_dag_run.py | 2 +- .../core_api/routes/ui/test_assets.py | 2 +- tests/assets/test_manager.py | 2 +- tests/dags/test_assets.py | 2 +- tests/dags/test_only_empty_tasks.py | 2 +- tests/datasets/__init__.py | 16 + tests/datasets/test_dataset.py | 53 ++ tests/decorators/test_python.py | 2 +- tests/io/test_path.py | 2 +- tests/io/test_wrapper.py | 2 +- tests/jobs/test_scheduler_job.py | 2 +- tests/lineage/test_hook.py | 2 +- tests/listeners/asset_listener.py | 2 +- tests/listeners/test_asset_listener.py | 2 +- tests/models/test_asset.py | 2 +- tests/models/test_dag.py | 2 +- tests/models/test_serialized_dag.py | 2 +- tests/models/test_taskinstance.py | 32 +- tests/serialization/test_dag_serialization.py | 2 +- tests/serialization/test_serde.py | 4 +- .../serialization/test_serialized_objects.py | 2 +- tests/timetables/test_assets_timetable.py | 2 +- tests/utils/test_context.py | 2 +- tests/utils/test_json.py | 2 +- tests/www/views/test_views_asset.py | 2 +- tests/www/views/test_views_grid.py | 2 +- 67 files changed, 740 insertions(+), 658 deletions(-) create mode 100644 task_sdk/src/airflow/sdk/definitions/asset.py rename {tests/assets => task_sdk/tests/defintions}/test_asset.py (92%) create mode 100644 tests/datasets/__init__.py create mode 100644 tests/datasets/test_dataset.py diff --git a/airflow/__init__.py b/airflow/__init__.py index 411aac70fc6f..fed233f01460 100644 --- a/airflow/__init__.py +++ b/airflow/__init__.py @@ -85,15 +85,15 @@ "version": (".version", "", False), # Deprecated lazy imports "AirflowException": (".exceptions", "AirflowException", True), - 
"Dataset": (".assets", "Dataset", True), + "Dataset": (".sdk.definitions.asset", "Dataset", True), } if TYPE_CHECKING: # These objects are imported by PEP-562, however, static analyzers and IDE's # have no idea about typing of these objects. # Add it under TYPE_CHECKING block should help with it. - from airflow.assets import Asset, Dataset from airflow.models.dag import DAG from airflow.models.xcom_arg import XComArg + from airflow.sdk.definitions.asset import Asset, Dataset def __getattr__(name: str): diff --git a/airflow/api_connexion/endpoints/asset_endpoint.py b/airflow/api_connexion/endpoints/asset_endpoint.py index 423e02d1b9c9..97a90ce28094 100644 --- a/airflow/api_connexion/endpoints/asset_endpoint.py +++ b/airflow/api_connexion/endpoints/asset_endpoint.py @@ -43,9 +43,9 @@ queued_event_collection_schema, queued_event_schema, ) -from airflow.assets import Asset from airflow.assets.manager import asset_manager from airflow.models.asset import AssetDagRunQueue, AssetEvent, AssetModel +from airflow.sdk.definitions.asset import Asset from airflow.utils import timezone from airflow.utils.api_migration import mark_fastapi_migration_done from airflow.utils.db import get_query_count diff --git a/airflow/api_fastapi/core_api/routes/public/assets.py b/airflow/api_fastapi/core_api/routes/public/assets.py index b94b825f7629..cca20536032a 100644 --- a/airflow/api_fastapi/core_api/routes/public/assets.py +++ b/airflow/api_fastapi/core_api/routes/public/assets.py @@ -49,9 +49,9 @@ QueuedEventResponse, ) from airflow.api_fastapi.core_api.openapi.exceptions import create_openapi_http_exception_doc -from airflow.assets import Asset from airflow.assets.manager import asset_manager from airflow.models.asset import AssetDagRunQueue, AssetEvent, AssetModel +from airflow.sdk.definitions.asset import Asset from airflow.utils import timezone assets_router = AirflowRouter(tags=["Asset"]) diff --git a/airflow/api_internal/endpoints/rpc_api_endpoint.py b/airflow/api_internal/endpoints/rpc_api_endpoint.py index abe685a009db..2db39c345f9f 100644 --- a/airflow/api_internal/endpoints/rpc_api_endpoint.py +++ b/airflow/api_internal/endpoints/rpc_api_endpoint.py @@ -53,7 +53,6 @@ @functools.lru_cache def initialize_method_map() -> dict[str, Callable]: from airflow.api.common.trigger_dag import trigger_dag - from airflow.assets import expand_alias_to_assets from airflow.assets.manager import AssetManager from airflow.cli.commands.task_command import _get_ti_db_access from airflow.dag_processing.manager import DagFileProcessorManager @@ -76,6 +75,7 @@ def initialize_method_map() -> dict[str, Callable]: _update_ti_heartbeat, _xcom_pull, ) + from airflow.sdk.definitions.asset import expand_alias_to_assets from airflow.secrets.metastore import MetastoreBackend from airflow.utils.cli_action_loggers import _default_action_log_internal from airflow.utils.log.file_task_handler import FileTaskHandler diff --git a/airflow/assets/__init__.py b/airflow/assets/__init__.py index f1d36ac12b73..13a83393a912 100644 --- a/airflow/assets/__init__.py +++ b/airflow/assets/__init__.py @@ -14,511 +14,3 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
- -from __future__ import annotations - -import logging -import os -import urllib.parse -import warnings -from typing import TYPE_CHECKING, Any, Callable, ClassVar, Iterable, Iterator, cast, overload - -import attrs -from sqlalchemy import select - -from airflow.api_internal.internal_api_call import internal_api_call -from airflow.serialization.dag_dependency import DagDependency -from airflow.typing_compat import TypedDict -from airflow.utils.session import NEW_SESSION, provide_session - -if TYPE_CHECKING: - from urllib.parse import SplitResult - - from sqlalchemy.orm.session import Session - -__all__ = ["Asset", "AssetAll", "AssetAny", "Dataset"] - - -log = logging.getLogger(__name__) - - -def normalize_noop(parts: SplitResult) -> SplitResult: - """ - Place-hold a :class:`~urllib.parse.SplitResult`` normalizer. - - :meta private: - """ - return parts - - -def _get_uri_normalizer(scheme: str) -> Callable[[SplitResult], SplitResult] | None: - if scheme == "file": - return normalize_noop - from airflow.providers_manager import ProvidersManager - - return ProvidersManager().asset_uri_handlers.get(scheme) - - -def _get_normalized_scheme(uri: str) -> str: - parsed = urllib.parse.urlsplit(uri) - return parsed.scheme.lower() - - -def _sanitize_uri(uri: str) -> str: - """ - Sanitize an asset URI. - - This checks for URI validity, and normalizes the URI if needed. A fully - normalized URI is returned. - """ - parsed = urllib.parse.urlsplit(uri) - if not parsed.scheme and not parsed.netloc: # Does not look like a URI. - return uri - if not (normalized_scheme := _get_normalized_scheme(uri)): - return uri - if normalized_scheme.startswith("x-"): - return uri - if normalized_scheme == "airflow": - raise ValueError("Asset scheme 'airflow' is reserved") - _, auth_exists, normalized_netloc = parsed.netloc.rpartition("@") - if auth_exists: - # TODO: Collect this into a DagWarning. - warnings.warn( - "An Asset URI should not contain auth info (e.g. username or " - "password). It has been automatically dropped.", - UserWarning, - stacklevel=3, - ) - if parsed.query: - normalized_query = urllib.parse.urlencode(sorted(urllib.parse.parse_qsl(parsed.query))) - else: - normalized_query = "" - parsed = parsed._replace( - scheme=normalized_scheme, - netloc=normalized_netloc, - path=parsed.path.rstrip("/") or "/", # Remove all trailing slashes. - query=normalized_query, - fragment="", # Ignore any fragments. 
- ) - if (normalizer := _get_uri_normalizer(normalized_scheme)) is not None: - parsed = normalizer(parsed) - return urllib.parse.urlunsplit(parsed) - - -def _validate_identifier(instance, attribute, value): - if not isinstance(value, str): - raise ValueError(f"{type(instance).__name__} {attribute.name} must be a string") - if len(value) > 1500: - raise ValueError(f"{type(instance).__name__} {attribute.name} cannot exceed 1500 characters") - if value.isspace(): - raise ValueError(f"{type(instance).__name__} {attribute.name} cannot be just whitespace") - if not value.isascii(): - raise ValueError(f"{type(instance).__name__} {attribute.name} must only consist of ASCII characters") - return value - - -def _validate_non_empty_identifier(instance, attribute, value): - if not _validate_identifier(instance, attribute, value): - raise ValueError(f"{type(instance).__name__} {attribute.name} cannot be empty") - return value - - -def _validate_asset_name(instance, attribute, value): - _validate_non_empty_identifier(instance, attribute, value) - if value == "self" or value == "context": - raise ValueError(f"prohibited name for asset: {value}") - return value - - -def extract_event_key(value: str | Asset | AssetAlias) -> str: - """ - Extract the key of an inlet or an outlet event. - - If the input value is a string, it is treated as a URI and sanitized. If the - input is a :class:`Asset`, the URI it contains is considered sanitized and - returned directly. If the input is a :class:`AssetAlias`, the name it contains - will be returned directly. - - :meta private: - """ - if isinstance(value, AssetAlias): - return value.name - - if isinstance(value, Asset): - return value.uri - return _sanitize_uri(str(value)) - - -@internal_api_call -@provide_session -def expand_alias_to_assets(alias: str | AssetAlias, *, session: Session = NEW_SESSION) -> list[BaseAsset]: - """Expand asset alias to resolved assets.""" - from airflow.models.asset import AssetAliasModel - - alias_name = alias.name if isinstance(alias, AssetAlias) else alias - - asset_alias_obj = session.scalar( - select(AssetAliasModel).where(AssetAliasModel.name == alias_name).limit(1) - ) - if asset_alias_obj: - return [asset.to_public() for asset in asset_alias_obj.assets] - return [] - - -@attrs.define(kw_only=True) -class AssetRef: - """Reference to an asset.""" - - name: str - - -class BaseAsset: - """ - Protocol for all asset triggers to use in ``DAG(schedule=...)``. - - :meta private: - """ - - def __bool__(self) -> bool: - return True - - def __or__(self, other: BaseAsset) -> BaseAsset: - if not isinstance(other, BaseAsset): - return NotImplemented - return AssetAny(self, other) - - def __and__(self, other: BaseAsset) -> BaseAsset: - if not isinstance(other, BaseAsset): - return NotImplemented - return AssetAll(self, other) - - def as_expression(self) -> Any: - """ - Serialize the asset into its scheduling expression. - - The return value is stored in DagModel for display purposes. It must be - JSON-compatible. - - :meta private: - """ - raise NotImplementedError - - def evaluate(self, statuses: dict[str, bool]) -> bool: - raise NotImplementedError - - def iter_assets(self) -> Iterator[tuple[str, Asset]]: - raise NotImplementedError - - def iter_asset_aliases(self) -> Iterator[tuple[str, AssetAlias]]: - raise NotImplementedError - - def iter_dag_dependencies(self, *, source: str, target: str) -> Iterator[DagDependency]: - """ - Iterate a base asset as dag dependency. 
- - :meta private: - """ - raise NotImplementedError - - -@attrs.define(unsafe_hash=False) -class AssetAlias(BaseAsset): - """A represeation of asset alias which is used to create asset during the runtime.""" - - name: str = attrs.field(validator=_validate_non_empty_identifier) - group: str = attrs.field(kw_only=True, default="", validator=_validate_identifier) - - def iter_assets(self) -> Iterator[tuple[str, Asset]]: - return iter(()) - - def iter_asset_aliases(self) -> Iterator[tuple[str, AssetAlias]]: - yield self.name, self - - def iter_dag_dependencies(self, *, source: str, target: str) -> Iterator[DagDependency]: - """ - Iterate an asset alias as dag dependency. - - :meta private: - """ - yield DagDependency( - source=source or "asset-alias", - target=target or "asset-alias", - dependency_type="asset-alias", - dependency_id=self.name, - ) - - -class AssetAliasEvent(TypedDict): - """A represeation of asset event to be triggered by an asset alias.""" - - source_alias_name: str - dest_asset_uri: str - extra: dict[str, Any] - - -def _set_extra_default(extra: dict | None) -> dict: - """ - Automatically convert None to an empty dict. - - This allows the caller site to continue doing ``Asset(uri, extra=None)``, - but still allow the ``extra`` attribute to always be a dict. - """ - if extra is None: - return {} - return extra - - -@attrs.define(init=False, unsafe_hash=False) -class Asset(os.PathLike, BaseAsset): - """A representation of data asset dependencies between workflows.""" - - name: str - uri: str - group: str - extra: dict[str, Any] - - asset_type: ClassVar[str] = "asset" - __version__: ClassVar[int] = 1 - - @overload - def __init__(self, name: str, uri: str, *, group: str = "", extra: dict | None = None) -> None: - """Canonical; both name and uri are provided.""" - - @overload - def __init__(self, name: str, *, group: str = "", extra: dict | None = None) -> None: - """It's possible to only provide the name, either by keyword or as the only positional argument.""" - - @overload - def __init__(self, *, uri: str, group: str = "", extra: dict | None = None) -> None: - """It's possible to only provide the URI as a keyword argument.""" - - def __init__( - self, - name: str | None = None, - uri: str | None = None, - *, - group: str = "", - extra: dict | None = None, - ) -> None: - if name is None and uri is None: - raise TypeError("Asset() requires either 'name' or 'uri'") - elif name is None: - name = uri - elif uri is None: - uri = name - fields = attrs.fields_dict(Asset) - self.name = _validate_asset_name(self, fields["name"], name) - self.uri = _sanitize_uri(_validate_non_empty_identifier(self, fields["uri"], uri)) - self.group = _validate_identifier(self, fields["group"], group) if group else self.asset_type - self.extra = _set_extra_default(extra) - - def __fspath__(self) -> str: - return self.uri - - @property - def normalized_uri(self) -> str | None: - """ - Returns the normalized and AIP-60 compliant URI whenever possible. - - If we can't retrieve the scheme from URI or no normalizer is provided or if parsing fails, - it returns None. - - If a normalizer for the scheme exists and parsing is successful we return the normalizer result. 
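# Illustrative aside (not part of the patch): what the _sanitize_uri logic above (and its
# verbatim copy in the new task-sdk module) normalizes. Outputs are inferred from reading
# the code, not documented guarantees; "myscheme" avoids any provider-registered handler.
from airflow.sdk.definitions.asset import _sanitize_uri  # private helper; new location per this patch

assert _sanitize_uri("MYSCHEME://host/path/?b=2&a=1") == "myscheme://host/path?a=1&b=2"  # scheme case, query order, trailing slash
assert _sanitize_uri("myscheme://user:pass@host/path") == "myscheme://host/path"  # auth dropped, with a UserWarning
assert _sanitize_uri("just-a-name") == "just-a-name"  # non-URI strings pass through unchanged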
- """ - if not (normalized_scheme := _get_normalized_scheme(self.uri)): - return None - - if (normalizer := _get_uri_normalizer(normalized_scheme)) is None: - return None - parsed = urllib.parse.urlsplit(self.uri) - try: - normalized_uri = normalizer(parsed) - return urllib.parse.urlunsplit(normalized_uri) - except ValueError: - return None - - def as_expression(self) -> Any: - """ - Serialize the asset into its scheduling expression. - - :meta private: - """ - return self.uri - - def iter_assets(self) -> Iterator[tuple[str, Asset]]: - yield self.uri, self - - def iter_asset_aliases(self) -> Iterator[tuple[str, AssetAlias]]: - return iter(()) - - def evaluate(self, statuses: dict[str, bool]) -> bool: - return statuses.get(self.uri, False) - - def iter_dag_dependencies(self, *, source: str, target: str) -> Iterator[DagDependency]: - """ - Iterate an asset as dag dependency. - - :meta private: - """ - yield DagDependency( - source=source or "asset", - target=target or "asset", - dependency_type="asset", - dependency_id=self.uri, - ) - - -class Dataset(Asset): - """A representation of dataset dependencies between workflows.""" - - asset_type: ClassVar[str] = "dataset" - - -class Model(Asset): - """A representation of model dependencies between workflows.""" - - asset_type: ClassVar[str] = "model" - - -class _AssetBooleanCondition(BaseAsset): - """Base class for asset boolean logic.""" - - agg_func: Callable[[Iterable], bool] - - def __init__(self, *objects: BaseAsset) -> None: - if not all(isinstance(o, BaseAsset) for o in objects): - raise TypeError("expect asset expressions in condition") - - self.objects = [ - _AssetAliasCondition(obj.name) if isinstance(obj, AssetAlias) else obj for obj in objects - ] - - def evaluate(self, statuses: dict[str, bool]) -> bool: - return self.agg_func(x.evaluate(statuses=statuses) for x in self.objects) - - def iter_assets(self) -> Iterator[tuple[str, Asset]]: - seen = set() # We want to keep the first instance. - for o in self.objects: - for k, v in o.iter_assets(): - if k in seen: - continue - yield k, v - seen.add(k) - - def iter_asset_aliases(self) -> Iterator[tuple[str, AssetAlias]]: - """Filter asset aliases in the condition.""" - for o in self.objects: - yield from o.iter_asset_aliases() - - def iter_dag_dependencies(self, *, source: str, target: str) -> Iterator[DagDependency]: - """ - Iterate asset, asset aliases and their resolved assets as dag dependency. - - :meta private: - """ - for obj in self.objects: - yield from obj.iter_dag_dependencies(source=source, target=target) - - -class AssetAny(_AssetBooleanCondition): - """Use to combine assets schedule references in an "and" relationship.""" - - agg_func = any - - def __or__(self, other: BaseAsset) -> BaseAsset: - if not isinstance(other, BaseAsset): - return NotImplemented - # Optimization: X | (Y | Z) is equivalent to X | Y | Z. - return AssetAny(*self.objects, other) - - def __repr__(self) -> str: - return f"AssetAny({', '.join(map(str, self.objects))})" - - def as_expression(self) -> dict[str, Any]: - """ - Serialize the asset into its scheduling expression. - - :meta private: - """ - return {"any": [o.as_expression() for o in self.objects]} - - -class _AssetAliasCondition(AssetAny): - """ - Use to expand AssetAlias as AssetAny of its resolved Assets. 
- - :meta private: - """ - - def __init__(self, name: str) -> None: - self.name = name - self.objects = expand_alias_to_assets(name) - - def __repr__(self) -> str: - return f"_AssetAliasCondition({', '.join(map(str, self.objects))})" - - def as_expression(self) -> Any: - """ - Serialize the asset alias into its scheduling expression. - - :meta private: - """ - return {"alias": self.name} - - def iter_asset_aliases(self) -> Iterator[tuple[str, AssetAlias]]: - yield self.name, AssetAlias(self.name) - - def iter_dag_dependencies(self, *, source: str = "", target: str = "") -> Iterator[DagDependency]: - """ - Iterate an asset alias and its resolved assets as dag dependency. - - :meta private: - """ - if self.objects: - for obj in self.objects: - asset = cast(Asset, obj) - uri = asset.uri - # asset - yield DagDependency( - source=f"asset-alias:{self.name}" if source else "asset", - target="asset" if source else f"asset-alias:{self.name}", - dependency_type="asset", - dependency_id=uri, - ) - # asset alias - yield DagDependency( - source=source or f"asset:{uri}", - target=target or f"asset:{uri}", - dependency_type="asset-alias", - dependency_id=self.name, - ) - else: - yield DagDependency( - source=source or "asset-alias", - target=target or "asset-alias", - dependency_type="asset-alias", - dependency_id=self.name, - ) - - -class AssetAll(_AssetBooleanCondition): - """Use to combine assets schedule references in an "or" relationship.""" - - agg_func = all - - def __and__(self, other: BaseAsset) -> BaseAsset: - if not isinstance(other, BaseAsset): - return NotImplemented - # Optimization: X & (Y & Z) is equivalent to X & Y & Z. - return AssetAll(*self.objects, other) - - def __repr__(self) -> str: - return f"AssetAll({', '.join(map(str, self.objects))})" - - def as_expression(self) -> Any: - """ - Serialize the assets into its scheduling expression. 
- - :meta private: - """ - return {"all": [o.as_expression() for o in self.objects]} diff --git a/airflow/assets/manager.py b/airflow/assets/manager.py index a06c7c31786f..0616d6015113 100644 --- a/airflow/assets/manager.py +++ b/airflow/assets/manager.py @@ -24,7 +24,6 @@ from sqlalchemy.orm import joinedload from airflow.api_internal.internal_api_call import internal_api_call -from airflow.assets import Asset from airflow.configuration import conf from airflow.listeners.listener import get_listener_manager from airflow.models.asset import ( @@ -36,15 +35,16 @@ DagScheduleAssetReference, ) from airflow.models.dagbag import DagPriorityParsingRequest +from airflow.sdk.definitions.asset import Asset from airflow.stats import Stats from airflow.utils.log.logging_mixin import LoggingMixin if TYPE_CHECKING: from sqlalchemy.orm.session import Session - from airflow.assets import Asset, AssetAlias from airflow.models.dag import DagModel from airflow.models.taskinstance import TaskInstance + from airflow.sdk.definitions.asset import Asset, AssetAlias class AssetManager(LoggingMixin): diff --git a/airflow/assets/metadata.py b/airflow/assets/metadata.py index b7522226230f..8feffe389e3e 100644 --- a/airflow/assets/metadata.py +++ b/airflow/assets/metadata.py @@ -21,10 +21,10 @@ import attrs -from airflow.assets import AssetAlias, extract_event_key +from airflow.sdk.definitions.asset import AssetAlias, extract_event_key if TYPE_CHECKING: - from airflow.assets import Asset + from airflow.sdk.definitions.asset import Asset @attrs.define(init=False) diff --git a/airflow/dag_processing/collection.py b/airflow/dag_processing/collection.py index 0ca121c56185..a1bfafdab4bb 100644 --- a/airflow/dag_processing/collection.py +++ b/airflow/dag_processing/collection.py @@ -34,7 +34,6 @@ from sqlalchemy import func, select, tuple_ from sqlalchemy.orm import joinedload, load_only -from airflow.assets import Asset, AssetAlias from airflow.assets.manager import asset_manager from airflow.models.asset import ( AssetAliasModel, @@ -45,6 +44,7 @@ ) from airflow.models.dag import DAG, DagModel, DagOwnerAttributes, DagTag from airflow.models.dagrun import DagRun +from airflow.sdk.definitions.asset import Asset, AssetAlias from airflow.utils.sqlalchemy import with_row_locks from airflow.utils.timezone import utcnow from airflow.utils.types import DagRunType diff --git a/airflow/datasets/__init__.py b/airflow/datasets/__init__.py index 34729e437805..3524466e58c1 100644 --- a/airflow/datasets/__init__.py +++ b/airflow/datasets/__init__.py @@ -27,13 +27,13 @@ import warnings -from airflow.assets import AssetAlias as DatasetAlias, Dataset +from airflow.sdk.definitions.asset import AssetAlias as DatasetAlias, Dataset # TODO: Remove this module in Airflow 3.2 warnings.warn( "Import from the airflow.dataset module is deprecated and " - "will be removed in the Airflow 3.2. Please import it from 'airflow.assets'.", + "will be removed in the Airflow 3.2. 
Please import it from 'airflow.sdk.definitions.asset'.", DeprecationWarning, stacklevel=2, ) diff --git a/airflow/decorators/base.py b/airflow/decorators/base.py index 3de2f27d04c9..4f40c6ad3b47 100644 --- a/airflow/decorators/base.py +++ b/airflow/decorators/base.py @@ -40,7 +40,6 @@ import re2 import typing_extensions -from airflow.assets import Asset from airflow.models.baseoperator import ( BaseOperator, coerce_resources, @@ -56,6 +55,7 @@ ) from airflow.models.mappedoperator import MappedOperator, ensure_xcomarg_return_value from airflow.models.xcom_arg import XComArg +from airflow.sdk.definitions.asset import Asset from airflow.sdk.definitions.baseoperator import BaseOperator as TaskSDKBaseOperator from airflow.sdk.definitions.contextmanager import DagContext, TaskGroupContext from airflow.typing_compat import ParamSpec, Protocol diff --git a/airflow/example_dags/example_asset_alias.py b/airflow/example_dags/example_asset_alias.py index 4970b1eda266..a7f2aac5845c 100644 --- a/airflow/example_dags/example_asset_alias.py +++ b/airflow/example_dags/example_asset_alias.py @@ -38,8 +38,8 @@ import pendulum from airflow import DAG -from airflow.assets import Asset, AssetAlias from airflow.decorators import task +from airflow.sdk.definitions.asset import Asset, AssetAlias with DAG( dag_id="asset_s3_bucket_producer", diff --git a/airflow/example_dags/example_asset_alias_with_no_taskflow.py b/airflow/example_dags/example_asset_alias_with_no_taskflow.py index c9b04d66d2f6..19f31465ea4f 100644 --- a/airflow/example_dags/example_asset_alias_with_no_taskflow.py +++ b/airflow/example_dags/example_asset_alias_with_no_taskflow.py @@ -36,8 +36,8 @@ import pendulum from airflow import DAG -from airflow.assets import Asset, AssetAlias from airflow.providers.standard.operators.python import PythonOperator +from airflow.sdk.definitions.asset import Asset, AssetAlias with DAG( dag_id="asset_s3_bucket_producer_with_no_taskflow", diff --git a/airflow/example_dags/example_assets.py b/airflow/example_dags/example_assets.py index 451f17a3a3ab..b81cdad9453d 100644 --- a/airflow/example_dags/example_assets.py +++ b/airflow/example_dags/example_assets.py @@ -54,9 +54,9 @@ import pendulum -from airflow.assets import Asset from airflow.models.dag import DAG from airflow.providers.standard.operators.bash import BashOperator +from airflow.sdk.definitions.asset import Asset from airflow.timetables.assets import AssetOrTimeSchedule from airflow.timetables.trigger import CronTriggerTimetable diff --git a/airflow/example_dags/example_inlet_event_extra.py b/airflow/example_dags/example_inlet_event_extra.py index 9773df7a3f91..c503e832a833 100644 --- a/airflow/example_dags/example_inlet_event_extra.py +++ b/airflow/example_dags/example_inlet_event_extra.py @@ -25,10 +25,10 @@ import datetime -from airflow.assets import Asset from airflow.decorators import task from airflow.models.dag import DAG from airflow.providers.standard.operators.bash import BashOperator +from airflow.sdk.definitions.asset import Asset asset = Asset("s3://output/1.txt") diff --git a/airflow/example_dags/example_outlet_event_extra.py b/airflow/example_dags/example_outlet_event_extra.py index 0d097eab0ac2..dd3041e18fc0 100644 --- a/airflow/example_dags/example_outlet_event_extra.py +++ b/airflow/example_dags/example_outlet_event_extra.py @@ -25,11 +25,11 @@ import datetime -from airflow.assets import Asset -from airflow.assets.metadata import Metadata from airflow.decorators import task from airflow.models.dag import DAG from 
airflow.providers.standard.operators.bash import BashOperator +from airflow.sdk.definitions.asset import Asset +from airflow.sdk.definitions.asset.metadata import Metadata ds = Asset("s3://output/1.txt") diff --git a/airflow/lineage/hook.py b/airflow/lineage/hook.py index fd321bcab49c..9e5f8f664822 100644 --- a/airflow/lineage/hook.py +++ b/airflow/lineage/hook.py @@ -24,8 +24,8 @@ import attr -from airflow.assets import Asset from airflow.providers_manager import ProvidersManager +from airflow.sdk.definitions.asset import Asset from airflow.utils.log.logging_mixin import LoggingMixin if TYPE_CHECKING: diff --git a/airflow/listeners/spec/asset.py b/airflow/listeners/spec/asset.py index dba9ac700e41..f99b11eb6843 100644 --- a/airflow/listeners/spec/asset.py +++ b/airflow/listeners/spec/asset.py @@ -22,7 +22,7 @@ from pluggy import HookspecMarker if TYPE_CHECKING: - from airflow.assets import Asset, AssetAlias + from airflow.sdk.definitions.asset import Asset, AssetAlias hookspec = HookspecMarker("airflow") diff --git a/airflow/models/asset.py b/airflow/models/asset.py index 50914d51650b..126bc5dc2d3f 100644 --- a/airflow/models/asset.py +++ b/airflow/models/asset.py @@ -35,8 +35,8 @@ ) from sqlalchemy.orm import relationship -from airflow.assets import Asset, AssetAlias from airflow.models.base import Base, StringID +from airflow.sdk.definitions.asset import Asset, AssetAlias from airflow.settings import json from airflow.utils import timezone from airflow.utils.sqlalchemy import UtcDateTime diff --git a/airflow/models/dag.py b/airflow/models/dag.py index ff4eac87b465..0d7ac61d2b9a 100644 --- a/airflow/models/dag.py +++ b/airflow/models/dag.py @@ -71,7 +71,6 @@ from airflow import settings, utils from airflow.api_internal.internal_api_call import internal_api_call -from airflow.assets import Asset, AssetAlias, BaseAsset from airflow.configuration import conf as airflow_conf, secrets_backend_list from airflow.exceptions import ( AirflowException, @@ -94,6 +93,7 @@ clear_task_instances, ) from airflow.models.tasklog import LogTemplate +from airflow.sdk.definitions.asset import Asset, AssetAlias, BaseAsset from airflow.sdk.definitions.dag import DAG as TaskSDKDag, dag as task_sdk_dag_decorator from airflow.secrets.local_filesystem import LocalFilesystemBackend from airflow.security import permissions diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py index dd0bf3916a4f..a176e0282b60 100644 --- a/airflow/models/taskinstance.py +++ b/airflow/models/taskinstance.py @@ -73,7 +73,6 @@ from airflow import settings from airflow.api_internal.internal_api_call import InternalApiConfig, internal_api_call -from airflow.assets import Asset, AssetAlias from airflow.assets.manager import asset_manager from airflow.configuration import conf from airflow.exceptions import ( @@ -102,6 +101,7 @@ from airflow.models.taskreschedule import TaskReschedule from airflow.models.xcom import LazyXComSelectSequence, XCom from airflow.plugins_manager import integrate_macros_plugins +from airflow.sdk.definitions.asset import Asset, AssetAlias from airflow.sentry import Sentry from airflow.settings import task_instance_mutation_hook from airflow.stats import Stats diff --git a/airflow/providers_manager.py b/airflow/providers_manager.py index e5c02d0113e1..5d38454b1891 100644 --- a/airflow/providers_manager.py +++ b/airflow/providers_manager.py @@ -87,9 +87,9 @@ def ensure_prefix(field): if TYPE_CHECKING: from urllib.parse import SplitResult - from airflow.assets import Asset from 
airflow.decorators.base import TaskDecorator
     from airflow.hooks.base import BaseHook
+    from airflow.sdk.definitions.asset import Asset
     from airflow.typing_compat import Literal
 
 
@@ -905,7 +905,7 @@ def _discover_filesystems(self) -> None:
 
     def _discover_asset_uri_resources(self) -> None:
         """Discovers and registers asset URI handlers, factories, and converters for all providers."""
-        from airflow.assets import normalize_noop
+        from airflow.sdk.definitions.asset import normalize_noop
 
         def _safe_register_resource(
             provider_package_name: str,
diff --git a/airflow/serialization/serialized_objects.py b/airflow/serialization/serialized_objects.py
index cb4a35c2aecd..d9771a2c401f 100644
--- a/airflow/serialization/serialized_objects.py
+++ b/airflow/serialization/serialized_objects.py
@@ -35,15 +35,6 @@
 from pendulum.tz.timezone import FixedTimezone, Timezone
 
 from airflow import macros
-from airflow.assets import (
-    Asset,
-    AssetAlias,
-    AssetAll,
-    AssetAny,
-    AssetRef,
-    BaseAsset,
-    _AssetAliasCondition,
-)
 from airflow.callbacks.callback_requests import DagCallbackRequest, TaskCallbackRequest
 from airflow.exceptions import AirflowException, SerializationError, TaskDeferred
 from airflow.jobs.job import Job
@@ -60,6 +51,15 @@
 from airflow.models.tasklog import LogTemplate
 from airflow.models.xcom_arg import XComArg, deserialize_xcom_arg, serialize_xcom_arg
 from airflow.providers_manager import ProvidersManager
+from airflow.sdk.definitions.asset import (
+    Asset,
+    AssetAlias,
+    AssetAll,
+    AssetAny,
+    AssetRef,
+    BaseAsset,
+    _AssetAliasCondition,
+)
 from airflow.sdk.definitions.baseoperator import BaseOperator as TaskSDKBaseOperator
 from airflow.serialization.dag_dependency import DagDependency
 from airflow.serialization.enums import DagAttributeTypes as DAT, Encoding
diff --git a/airflow/timetables/assets.py b/airflow/timetables/assets.py
index d69a8e4d80cc..6d2331324382 100644
--- a/airflow/timetables/assets.py
+++ b/airflow/timetables/assets.py
@@ -19,8 +19,8 @@
 
 import typing
 
-from airflow.assets import AssetAll, BaseAsset
 from airflow.exceptions import AirflowTimetableInvalid
+from airflow.sdk.definitions.asset import AssetAll, BaseAsset
 from airflow.timetables.simple import AssetTriggeredTimetable
 from airflow.utils.types import DagRunType
 
@@ -29,7 +29,7 @@
 
     import pendulum
 
-    from airflow.assets import Asset
+    from airflow.sdk.definitions.asset import Asset
     from airflow.timetables.base import DagRunInfo, DataInterval, TimeRestriction, Timetable
 
 
diff --git a/airflow/timetables/base.py b/airflow/timetables/base.py
index f8aa4279ebba..1a076747ec59 100644
--- a/airflow/timetables/base.py
+++ b/airflow/timetables/base.py
@@ -18,13 +18,13 @@
 
 from typing import TYPE_CHECKING, Any, Iterator, NamedTuple, Sequence
 
-from airflow.assets import BaseAsset
+from airflow.sdk.definitions.asset import BaseAsset
 from airflow.typing_compat import Protocol, runtime_checkable
 
 if TYPE_CHECKING:
     from pendulum import DateTime
 
-    from airflow.assets import Asset, AssetAlias
+    from airflow.sdk.definitions.asset import Asset, AssetAlias
     from airflow.serialization.dag_dependency import DagDependency
     from airflow.utils.types import DagRunType
 
diff --git a/airflow/timetables/simple.py b/airflow/timetables/simple.py
index 3457c52a08aa..adba135c5785 100644
--- a/airflow/timetables/simple.py
+++ b/airflow/timetables/simple.py
@@ -18,7 +18,7 @@
 
 from typing import TYPE_CHECKING, Any, Collection, Sequence
 
-from airflow.assets import AssetAlias, _AssetAliasCondition
+from airflow.sdk.definitions.asset 
import AssetAlias, _AssetAliasCondition from airflow.timetables.base import DagRunInfo, DataInterval, Timetable from airflow.utils import timezone @@ -26,8 +26,8 @@ from pendulum import DateTime from sqlalchemy import Session - from airflow.assets import BaseAsset from airflow.models.asset import AssetEvent + from airflow.sdk.definitions.asset import BaseAsset from airflow.timetables.base import TimeRestriction from airflow.utils.types import DagRunType diff --git a/airflow/utils/context.py b/airflow/utils/context.py index 3e217e748d02..5e423d4746af 100644 --- a/airflow/utils/context.py +++ b/airflow/utils/context.py @@ -40,15 +40,15 @@ import lazy_object_proxy from sqlalchemy import select -from airflow.assets import ( +from airflow.exceptions import RemovedInAirflow3Warning +from airflow.models.asset import AssetAliasModel, AssetEvent, AssetModel, _fetch_active_assets_by_name +from airflow.sdk.definitions.asset import ( Asset, AssetAlias, AssetAliasEvent, AssetRef, extract_event_key, ) -from airflow.exceptions import RemovedInAirflow3Warning -from airflow.models.asset import AssetAliasModel, AssetEvent, AssetModel, _fetch_active_assets_by_name from airflow.utils.db import LazySelectSequence from airflow.utils.types import NOTSET diff --git a/airflow/utils/context.pyi b/airflow/utils/context.pyi index f4ed77537ff1..069dba2f8f19 100644 --- a/airflow/utils/context.pyi +++ b/airflow/utils/context.pyi @@ -31,7 +31,6 @@ from typing import Any, Collection, Container, Iterable, Iterator, Mapping, Sequ from pendulum import DateTime from sqlalchemy.orm import Session -from airflow.assets import Asset, AssetAlias, AssetAliasEvent from airflow.configuration import AirflowConfigParser from airflow.models.asset import AssetEvent from airflow.models.baseoperator import BaseOperator @@ -39,6 +38,7 @@ from airflow.models.dag import DAG from airflow.models.dagrun import DagRun from airflow.models.param import ParamsDict from airflow.models.taskinstance import TaskInstance +from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetAliasEvent from airflow.serialization.pydantic.asset import AssetEventPydantic from airflow.serialization.pydantic.dag_run import DagRunPydantic from airflow.serialization.pydantic.taskinstance import TaskInstancePydantic diff --git a/airflow/www/views.py b/airflow/www/views.py index e97e585f753c..a239fe20384d 100644 --- a/airflow/www/views.py +++ b/airflow/www/views.py @@ -87,7 +87,6 @@ set_dag_run_state_to_success, set_state, ) -from airflow.assets import Asset, AssetAlias from airflow.auth.managers.models.resource_details import AccessView, DagAccessEntity, DagDetails from airflow.configuration import AIRFLOW_CONFIG, conf from airflow.exceptions import ( @@ -112,6 +111,7 @@ from airflow.models.taskinstance import TaskInstance, TaskInstanceNote from airflow.plugins_manager import PLUGINS_ATTRIBUTES_TO_DUMP from airflow.providers_manager import ProvidersManager +from airflow.sdk.definitions.asset import Asset, AssetAlias from airflow.security import permissions from airflow.ti_deps.dep_context import DepContext from airflow.ti_deps.dependencies_deps import SCHEDULER_QUEUED_DEPS diff --git a/dev/breeze/tests/test_pytest_args_for_test_types.py b/dev/breeze/tests/test_pytest_args_for_test_types.py index 20dc91947c1b..740f8c5b3e53 100644 --- a/dev/breeze/tests/test_pytest_args_for_test_types.py +++ b/dev/breeze/tests/test_pytest_args_for_test_types.py @@ -114,6 +114,7 @@ "tests/cluster_policies", "tests/config_templates", "tests/dag_processing", + "tests/datasets", 
"tests/decorators", "tests/hooks", "tests/io", diff --git a/docs/apache-airflow/authoring-and-scheduling/datasets.rst b/docs/apache-airflow/authoring-and-scheduling/datasets.rst index 7940a9051679..9e777d929958 100644 --- a/docs/apache-airflow/authoring-and-scheduling/datasets.rst +++ b/docs/apache-airflow/authoring-and-scheduling/datasets.rst @@ -27,7 +27,7 @@ In addition to scheduling DAGs based on time, you can also schedule DAGs to run .. code-block:: python - from airflow.assets import Asset + from airflow.sdk.definitions.asset import Asset with DAG(...): MyOperator( @@ -57,7 +57,7 @@ An Airflow asset is a logical grouping of data. Upstream producer tasks can upda .. code-block:: python - from airflow.assets import Asset + from airflow.sdk.definitions.asset import Asset example_asset = Asset("s3://asset-bucket/example.csv") @@ -67,7 +67,7 @@ You must create assets with a valid URI. Airflow core and providers define vario .. code-block:: python - from airflow.assets import Asset + from airflow.sdk.definitions.asset import Asset example_asset = Asset(uri="s3://asset-bucket/example.csv", name="bucket-1") @@ -248,8 +248,8 @@ The easiest way to attach extra information to the asset event is by ``yield``-i .. code-block:: python - from airflow.assets import Asset - from airflow.assets.metadata import Metadata + from airflow.sdk.definitions.asset import Asset + from airflow.sdk.definitions.asset.metadata import Metadata example_s3_asset = Asset("s3://asset/example.csv") @@ -440,7 +440,7 @@ The following example creates an asset event against the S3 URI ``f"s3://bucket/ .. code-block:: python - from airflow.assets import AssetAlias + from airflow.sdk.definitions.asset import AssetAlias @task(outlets=[AssetAlias("my-task-outputs")]) @@ -452,7 +452,7 @@ The following example creates an asset event against the S3 URI ``f"s3://bucket/ .. code-block:: python - from airflow.assets.metadata import Metadata + from airflow.sdk.definitions.asset.metadata import Metadata @task(outlets=[AssetAlias("my-task-outputs")]) @@ -464,7 +464,7 @@ Only one asset event is emitted for an added asset, even if it is added to the a .. 
code-block:: python - from airflow.assets import AssetAlias + from airflow.sdk.definitions.asset import AssetAlias @task( diff --git a/providers/src/airflow/providers/common/compat/assets/__init__.py b/providers/src/airflow/providers/common/compat/assets/__init__.py index e302395f701e..47614e2a6c1e 100644 --- a/providers/src/airflow/providers/common/compat/assets/__init__.py +++ b/providers/src/airflow/providers/common/compat/assets/__init__.py @@ -22,7 +22,8 @@ from airflow import __version__ as AIRFLOW_VERSION if TYPE_CHECKING: - from airflow.assets import ( + from airflow.auth.managers.models.resource_details import AssetDetails + from airflow.sdk.definitions.asset import ( Asset, AssetAlias, AssetAliasEvent, @@ -30,10 +31,10 @@ AssetAny, expand_alias_to_assets, ) - from airflow.auth.managers.models.resource_details import AssetDetails else: try: - from airflow.assets import ( + from airflow.auth.managers.models.resource_details import AssetDetails + from airflow.sdk.definitions.asset import ( Asset, AssetAlias, AssetAliasEvent, @@ -41,7 +42,6 @@ AssetAny, expand_alias_to_assets, ) - from airflow.auth.managers.models.resource_details import AssetDetails except ModuleNotFoundError: from packaging.version import Version diff --git a/providers/src/airflow/providers/common/io/assets/file.py b/providers/src/airflow/providers/common/io/assets/file.py index fadc4cbe1bdc..aeff818bd6ee 100644 --- a/providers/src/airflow/providers/common/io/assets/file.py +++ b/providers/src/airflow/providers/common/io/assets/file.py @@ -20,7 +20,7 @@ from typing import TYPE_CHECKING try: - from airflow.assets import Asset + from airflow.sdk.definitions.asset import Asset except ModuleNotFoundError: from airflow.datasets import Dataset as Asset # type: ignore[no-redef] diff --git a/providers/src/airflow/providers/openlineage/utils/utils.py b/providers/src/airflow/providers/openlineage/utils/utils.py index 99faa3c4d5ce..6c411171edb0 100644 --- a/providers/src/airflow/providers/openlineage/utils/utils.py +++ b/providers/src/airflow/providers/openlineage/utils/utils.py @@ -709,7 +709,7 @@ def translate_airflow_asset(asset: Asset, lineage_context) -> OpenLineageDataset some core Airflow changes are missing and ImportError is raised. 
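# Illustrative aside (not part of the patch): every provider file touched in this series
# follows the same compatibility dance, preferring the new task-sdk module and falling
# back on older Airflow cores. A generic sketch of the pattern the hunks use:
try:
    from airflow.sdk.definitions.asset import Asset
except ModuleNotFoundError:
    # Pre-move Airflow core: Dataset is the closest equivalent.
    from airflow.datasets import Dataset as Asset  # type: ignore[no-redef]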
""" try: - from airflow.assets import _get_normalized_scheme + from airflow.sdk.definitions.asset import _get_normalized_scheme except ModuleNotFoundError: try: from airflow.datasets import _get_normalized_scheme # type: ignore[no-redef, attr-defined] diff --git a/providers/tests/system/microsoft/azure/example_msfabric.py b/providers/tests/system/microsoft/azure/example_msfabric.py index 0f65df2f72f9..9da67b3a0fa6 100644 --- a/providers/tests/system/microsoft/azure/example_msfabric.py +++ b/providers/tests/system/microsoft/azure/example_msfabric.py @@ -19,8 +19,8 @@ from datetime import datetime from airflow import models -from airflow.assets import Asset from airflow.providers.microsoft.azure.operators.msgraph import MSGraphAsyncOperator +from airflow.sdk.definitions.asset import Asset DAG_ID = "example_msfabric" diff --git a/scripts/ci/pre_commit/check_tests_in_right_folders.py b/scripts/ci/pre_commit/check_tests_in_right_folders.py index 11d44efd407a..a04400e1c0cb 100755 --- a/scripts/ci/pre_commit/check_tests_in_right_folders.py +++ b/scripts/ci/pre_commit/check_tests_in_right_folders.py @@ -46,6 +46,7 @@ "dags", "dags_corrupted", "dags_with_system_exit", + "datasets", "decorators", "executors", "hooks", diff --git a/task_sdk/src/airflow/sdk/definitions/asset.py b/task_sdk/src/airflow/sdk/definitions/asset.py new file mode 100644 index 000000000000..761ba50ec46d --- /dev/null +++ b/task_sdk/src/airflow/sdk/definitions/asset.py @@ -0,0 +1,532 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +import logging +import os +import urllib.parse +import warnings +from collections.abc import Iterable, Iterator +from typing import ( + TYPE_CHECKING, + Any, + Callable, + ClassVar, + cast, + overload, +) + +import attrs +from sqlalchemy import select + +from airflow.api_internal.internal_api_call import internal_api_call +from airflow.serialization.dag_dependency import DagDependency +from airflow.typing_compat import TypedDict +from airflow.utils.session import NEW_SESSION, provide_session + +if TYPE_CHECKING: + from urllib.parse import SplitResult + + from sqlalchemy.orm.session import Session + +__all__ = ["Asset", "AssetAll", "AssetAny", "Dataset"] + + +log = logging.getLogger(__name__) + + +def normalize_noop(parts: SplitResult) -> SplitResult: + """ + Place-hold a :class:`~urllib.parse.SplitResult`` normalizer. 
+ + :meta private: + """ + return parts + + +def _get_uri_normalizer(scheme: str) -> Callable[[SplitResult], SplitResult] | None: + if scheme == "file": + return normalize_noop + from airflow.providers_manager import ProvidersManager + + return ProvidersManager().asset_uri_handlers.get(scheme) + + +def _get_normalized_scheme(uri: str) -> str: + parsed = urllib.parse.urlsplit(uri) + return parsed.scheme.lower() + + +def _sanitize_uri(uri: str) -> str: + """ + Sanitize an asset URI. + + This checks for URI validity, and normalizes the URI if needed. A fully + normalized URI is returned. + """ + parsed = urllib.parse.urlsplit(uri) + if not parsed.scheme and not parsed.netloc: # Does not look like a URI. + return uri + if not (normalized_scheme := _get_normalized_scheme(uri)): + return uri + if normalized_scheme.startswith("x-"): + return uri + if normalized_scheme == "airflow": + raise ValueError("Asset scheme 'airflow' is reserved") + _, auth_exists, normalized_netloc = parsed.netloc.rpartition("@") + if auth_exists: + # TODO: Collect this into a DagWarning. + warnings.warn( + "An Asset URI should not contain auth info (e.g. username or " + "password). It has been automatically dropped.", + UserWarning, + stacklevel=3, + ) + if parsed.query: + normalized_query = urllib.parse.urlencode(sorted(urllib.parse.parse_qsl(parsed.query))) + else: + normalized_query = "" + parsed = parsed._replace( + scheme=normalized_scheme, + netloc=normalized_netloc, + path=parsed.path.rstrip("/") or "/", # Remove all trailing slashes. + query=normalized_query, + fragment="", # Ignore any fragments. + ) + if (normalizer := _get_uri_normalizer(normalized_scheme)) is not None: + parsed = normalizer(parsed) + return urllib.parse.urlunsplit(parsed) + + +def _validate_identifier(instance, attribute, value): + if not isinstance(value, str): + raise ValueError(f"{type(instance).__name__} {attribute.name} must be a string") + if len(value) > 1500: + raise ValueError(f"{type(instance).__name__} {attribute.name} cannot exceed 1500 characters") + if value.isspace(): + raise ValueError(f"{type(instance).__name__} {attribute.name} cannot be just whitespace") + if not value.isascii(): + raise ValueError(f"{type(instance).__name__} {attribute.name} must only consist of ASCII characters") + return value + + +def _validate_non_empty_identifier(instance, attribute, value): + if not _validate_identifier(instance, attribute, value): + raise ValueError(f"{type(instance).__name__} {attribute.name} cannot be empty") + return value + + +def _validate_asset_name(instance, attribute, value): + _validate_non_empty_identifier(instance, attribute, value) + if value == "self" or value == "context": + raise ValueError(f"prohibited name for asset: {value}") + return value + + +def extract_event_key(value: str | Asset | AssetAlias) -> str: + """ + Extract the key of an inlet or an outlet event. + + If the input value is a string, it is treated as a URI and sanitized. If the + input is a :class:`Asset`, the URI it contains is considered sanitized and + returned directly. If the input is a :class:`AssetAlias`, the name it contains + will be returned directly. 
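# Illustrative aside (not part of the patch): extract_event_key's dispatch in practice,
# per my reading of the implementation; the URIs and alias name are made up.
from airflow.sdk.definitions.asset import Asset, AssetAlias, extract_event_key

assert extract_event_key(Asset("myscheme://host/obj")) == "myscheme://host/obj"  # Asset: its URI, as-is
assert extract_event_key(AssetAlias("my-alias")) == "my-alias"                   # alias: its name
assert extract_event_key("MYSCHEME://host/obj/") == "myscheme://host/obj"        # plain str: sanitized first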
+
+    :meta private:
+    """
+    if isinstance(value, AssetAlias):
+        return value.name
+
+    if isinstance(value, Asset):
+        return value.uri
+    return _sanitize_uri(str(value))
+
+
+@internal_api_call
+@provide_session
+def expand_alias_to_assets(alias: str | AssetAlias, *, session: Session = NEW_SESSION) -> list[BaseAsset]:
+    """Expand asset alias to resolved assets."""
+    from airflow.models.asset import AssetAliasModel
+
+    alias_name = alias.name if isinstance(alias, AssetAlias) else alias
+
+    asset_alias_obj = session.scalar(
+        select(AssetAliasModel).where(AssetAliasModel.name == alias_name).limit(1)
+    )
+    if asset_alias_obj:
+        return [asset.to_public() for asset in asset_alias_obj.assets]
+    return []
+
+
+@attrs.define(kw_only=True)
+class AssetRef:
+    """Reference to an asset."""
+
+    name: str
+
+
+class BaseAsset:
+    """
+    Protocol for all asset triggers to use in ``DAG(schedule=...)``.
+
+    :meta private:
+    """
+
+    def __bool__(self) -> bool:
+        return True
+
+    def __or__(self, other: BaseAsset) -> BaseAsset:
+        if not isinstance(other, BaseAsset):
+            return NotImplemented
+        return AssetAny(self, other)
+
+    def __and__(self, other: BaseAsset) -> BaseAsset:
+        if not isinstance(other, BaseAsset):
+            return NotImplemented
+        return AssetAll(self, other)
+
+    def as_expression(self) -> Any:
+        """
+        Serialize the asset into its scheduling expression.
+
+        The return value is stored in DagModel for display purposes. It must be
+        JSON-compatible.
+
+        :meta private:
+        """
+        raise NotImplementedError
+
+    def evaluate(self, statuses: dict[str, bool]) -> bool:
+        raise NotImplementedError
+
+    def iter_assets(self) -> Iterator[tuple[str, Asset]]:
+        raise NotImplementedError
+
+    def iter_asset_aliases(self) -> Iterator[tuple[str, AssetAlias]]:
+        raise NotImplementedError
+
+    def iter_dag_dependencies(self, *, source: str, target: str) -> Iterator[DagDependency]:
+        """
+        Iterate a base asset as dag dependency.
+
+        :meta private:
+        """
+        raise NotImplementedError
+
+
+@attrs.define(unsafe_hash=False)
+class AssetAlias(BaseAsset):
+    """A representation of an asset alias, used to create assets at runtime."""
+
+    name: str = attrs.field(validator=_validate_non_empty_identifier)
+    group: str = attrs.field(kw_only=True, default="", validator=_validate_identifier)
+
+    def iter_assets(self) -> Iterator[tuple[str, Asset]]:
+        return iter(())
+
+    def iter_asset_aliases(self) -> Iterator[tuple[str, AssetAlias]]:
+        yield self.name, self
+
+    def iter_dag_dependencies(self, *, source: str, target: str) -> Iterator[DagDependency]:
+        """
+        Iterate an asset alias as dag dependency.
+
+        :meta private:
+        """
+        yield DagDependency(
+            source=source or "asset-alias",
+            target=target or "asset-alias",
+            dependency_type="asset-alias",
+            dependency_id=self.name,
+        )
+
+
+class AssetAliasEvent(TypedDict):
+    """A representation of an asset event to be triggered by an asset alias."""
+
+    source_alias_name: str
+    dest_asset_uri: str
+    extra: dict[str, Any]
+
+
+def _set_extra_default(extra: dict | None) -> dict:
+    """
+    Automatically convert None to an empty dict.
+
+    This allows the call site to continue doing ``Asset(uri, extra=None)``,
+    but still allow the ``extra`` attribute to always be a dict.
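# Illustrative aside (not part of the patch): how the Asset constructor defined just
# below resolves its overloads, and how extra=None is coerced via _set_extra_default.
# Behavior read from the code that follows; URIs are made up.
from airflow.sdk.definitions.asset import Asset

a = Asset("logical-name", "myscheme://host/obj")  # canonical: explicit name and uri
b = Asset("myscheme://host/obj")                  # one positional: used as both name and uri
c = Asset(uri="myscheme://host/obj", extra=None)  # extra=None becomes {}
assert b.name == b.uri == "myscheme://host/obj"
assert c.extra == {}
assert Asset("x")  # bare names that don't parse as URIs are allowed too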
+ """ + if extra is None: + return {} + return extra + + +@attrs.define(init=False, unsafe_hash=False) +class Asset(os.PathLike, BaseAsset): + """A representation of data asset dependencies between workflows.""" + + name: str + uri: str + group: str + extra: dict[str, Any] + + asset_type: ClassVar[str] = "asset" + __version__: ClassVar[int] = 1 + + @overload + def __init__(self, name: str, uri: str, *, group: str = "", extra: dict | None = None) -> None: + """Canonical; both name and uri are provided.""" + + @overload + def __init__(self, name: str, *, group: str = "", extra: dict | None = None) -> None: + """It's possible to only provide the name, either by keyword or as the only positional argument.""" + + @overload + def __init__(self, *, uri: str, group: str = "", extra: dict | None = None) -> None: + """It's possible to only provide the URI as a keyword argument.""" + + def __init__( + self, + name: str | None = None, + uri: str | None = None, + *, + group: str = "", + extra: dict | None = None, + ) -> None: + if name is None and uri is None: + raise TypeError("Asset() requires either 'name' or 'uri'") + elif name is None: + name = uri + elif uri is None: + uri = name + fields = attrs.fields_dict(Asset) + self.name = _validate_asset_name(self, fields["name"], name) + self.uri = _sanitize_uri(_validate_non_empty_identifier(self, fields["uri"], uri)) + self.group = _validate_identifier(self, fields["group"], group) if group else self.asset_type + self.extra = _set_extra_default(extra) + + def __fspath__(self) -> str: + return self.uri + + @property + def normalized_uri(self) -> str | None: + """ + Returns the normalized and AIP-60 compliant URI whenever possible. + + If we can't retrieve the scheme from URI or no normalizer is provided or if parsing fails, + it returns None. + + If a normalizer for the scheme exists and parsing is successful we return the normalizer result. + """ + if not (normalized_scheme := _get_normalized_scheme(self.uri)): + return None + + if (normalizer := _get_uri_normalizer(normalized_scheme)) is None: + return None + parsed = urllib.parse.urlsplit(self.uri) + try: + normalized_uri = normalizer(parsed) + return urllib.parse.urlunsplit(normalized_uri) + except ValueError: + return None + + def as_expression(self) -> Any: + """ + Serialize the asset into its scheduling expression. + + :meta private: + """ + return self.uri + + def iter_assets(self) -> Iterator[tuple[str, Asset]]: + yield self.uri, self + + def iter_asset_aliases(self) -> Iterator[tuple[str, AssetAlias]]: + return iter(()) + + def evaluate(self, statuses: dict[str, bool]) -> bool: + return statuses.get(self.uri, False) + + def iter_dag_dependencies(self, *, source: str, target: str) -> Iterator[DagDependency]: + """ + Iterate an asset as dag dependency. 
+
+        :meta private:
+        """
+        yield DagDependency(
+            source=source or "asset",
+            target=target or "asset",
+            dependency_type="asset",
+            dependency_id=self.uri,
+        )
+
+
+class Dataset(Asset):
+    """A representation of dataset dependencies between workflows."""
+
+    asset_type: ClassVar[str] = "dataset"
+
+
+class Model(Asset):
+    """A representation of model dependencies between workflows."""
+
+    asset_type: ClassVar[str] = "model"
+
+
+class _AssetBooleanCondition(BaseAsset):
+    """Base class for asset boolean logic."""
+
+    agg_func: Callable[[Iterable], bool]
+
+    def __init__(self, *objects: BaseAsset) -> None:
+        if not all(isinstance(o, BaseAsset) for o in objects):
+            raise TypeError("expect asset expressions in condition")
+
+        self.objects = [
+            _AssetAliasCondition(obj.name) if isinstance(obj, AssetAlias) else obj for obj in objects
+        ]
+
+    def evaluate(self, statuses: dict[str, bool]) -> bool:
+        return self.agg_func(x.evaluate(statuses=statuses) for x in self.objects)
+
+    def iter_assets(self) -> Iterator[tuple[str, Asset]]:
+        seen = set()  # We want to keep the first instance.
+        for o in self.objects:
+            for k, v in o.iter_assets():
+                if k in seen:
+                    continue
+                yield k, v
+                seen.add(k)
+
+    def iter_asset_aliases(self) -> Iterator[tuple[str, AssetAlias]]:
+        """Filter asset aliases in the condition."""
+        for o in self.objects:
+            yield from o.iter_asset_aliases()
+
+    def iter_dag_dependencies(self, *, source: str, target: str) -> Iterator[DagDependency]:
+        """
+        Iterate asset, asset aliases and their resolved assets as dag dependency.
+
+        :meta private:
+        """
+        for obj in self.objects:
+            yield from obj.iter_dag_dependencies(source=source, target=target)
+
+
+class AssetAny(_AssetBooleanCondition):
+    """Use to combine assets schedule references in an "or" relationship."""
+
+    agg_func = any
+
+    def __or__(self, other: BaseAsset) -> BaseAsset:
+        if not isinstance(other, BaseAsset):
+            return NotImplemented
+        # Optimization: X | (Y | Z) is equivalent to X | Y | Z.
+        return AssetAny(*self.objects, other)
+
+    def __repr__(self) -> str:
+        return f"AssetAny({', '.join(map(str, self.objects))})"
+
+    def as_expression(self) -> dict[str, Any]:
+        """
+        Serialize the asset into its scheduling expression.
+
+        :meta private:
+        """
+        return {"any": [o.as_expression() for o in self.objects]}
+
+
+class _AssetAliasCondition(AssetAny):
+    """
+    Use to expand AssetAlias as AssetAny of its resolved Assets.
+
+    :meta private:
+    """
+
+    def __init__(self, name: str) -> None:
+        self.name = name
+        self.objects = expand_alias_to_assets(name)
+
+    def __repr__(self) -> str:
+        return f"_AssetAliasCondition({', '.join(map(str, self.objects))})"
+
+    def as_expression(self) -> Any:
+        """
+        Serialize the asset alias into its scheduling expression.
+
+        :meta private:
+        """
+        return {"alias": self.name}
+
+    def iter_asset_aliases(self) -> Iterator[tuple[str, AssetAlias]]:
+        yield self.name, AssetAlias(self.name)
+
+    def iter_dag_dependencies(self, *, source: str = "", target: str = "") -> Iterator[DagDependency]:
+        """
+        Iterate an asset alias and its resolved assets as dag dependency.
+
+        :meta private:
+        """
+        if self.objects:
+            for obj in self.objects:
+                asset = cast(Asset, obj)
+                uri = asset.uri
+                # asset
+                yield DagDependency(
+                    source=f"asset-alias:{self.name}" if source else "asset",
+                    target="asset" if source else f"asset-alias:{self.name}",
+                    dependency_type="asset",
+                    dependency_id=uri,
+                )
+                # asset alias
+                yield DagDependency(
+                    source=source or f"asset:{uri}",
+                    target=target or f"asset:{uri}",
+                    dependency_type="asset-alias",
+                    dependency_id=self.name,
+                )
+        else:
+            yield DagDependency(
+                source=source or "asset-alias",
+                target=target or "asset-alias",
+                dependency_type="asset-alias",
+                dependency_id=self.name,
+            )
+
+
+class AssetAll(_AssetBooleanCondition):
+    """Use to combine assets schedule references in an "and" relationship."""
+
+    agg_func = all
+
+    def __and__(self, other: BaseAsset) -> BaseAsset:
+        if not isinstance(other, BaseAsset):
+            return NotImplemented
+        # Optimization: X & (Y & Z) is equivalent to X & Y & Z.
+        return AssetAll(*self.objects, other)
+
+    def __repr__(self) -> str:
+        return f"AssetAll({', '.join(map(str, self.objects))})"
+
+    def as_expression(self) -> Any:
+        """
+        Serialize the asset into its scheduling expression.
+
+        :meta private:
+        """
+        return {"all": [o.as_expression() for o in self.objects]}
diff --git a/task_sdk/src/airflow/sdk/definitions/dag.py b/task_sdk/src/airflow/sdk/definitions/dag.py
index d427ddde7984..35d3c0ac9c92 100644
--- a/task_sdk/src/airflow/sdk/definitions/dag.py
+++ b/task_sdk/src/airflow/sdk/definitions/dag.py
@@ -45,7 +45,6 @@
 from dateutil.relativedelta import relativedelta
 
 from airflow import settings
-from airflow.assets import Asset, AssetAlias, BaseAsset
 from airflow.exceptions import (
     DuplicateTaskIdFound,
     FailStopDagInvalidTriggerRule,
@@ -54,6 +53,7 @@
 )
 from airflow.models.param import DagParam, ParamsDict
 from airflow.sdk.definitions.abstractoperator import AbstractOperator
+from airflow.sdk.definitions.asset import Asset, AssetAlias, BaseAsset
 from airflow.sdk.definitions.baseoperator import BaseOperator
 from airflow.sdk.types import NOTSET
 from airflow.timetables.base import Timetable
@@ -497,7 +497,7 @@ def _validate_max_active_runs(self, _, max_active_runs):
 
     @timetable.default
     def _default_timetable(instance: DAG):
-        from airflow.assets import AssetAll
+        from airflow.sdk.definitions.asset import AssetAll
 
         schedule = instance.schedule
         # TODO: Once
diff --git a/tests/assets/test_asset.py b/task_sdk/tests/defintions/test_asset.py
similarity index 92%
rename from tests/assets/test_asset.py
rename to task_sdk/tests/defintions/test_asset.py
index a454fd2826bd..a2d89fc38a7b 100644
--- a/tests/assets/test_asset.py
+++ b/task_sdk/tests/defintions/test_asset.py
@@ -25,9 +25,13 @@
 import pytest
 from sqlalchemy.sql import select
 
-from airflow.assets import (
+from airflow.models.asset import AssetAliasModel, AssetDagRunQueue, AssetModel
+from airflow.models.serialized_dag import SerializedDagModel
+from airflow.operators.empty import EmptyOperator
+from airflow.sdk.definitions.asset import (
     Asset,
     AssetAlias,
+    AssetAliasCondition,
     AssetAll,
     AssetAny,
     BaseAsset,
@@ -37,11 +41,10 @@
     _get_normalized_scheme,
     _sanitize_uri,
 )
-from airflow.models.asset import AssetAliasModel, AssetDagRunQueue, AssetModel
-from airflow.models.serialized_dag import SerializedDagModel
-from airflow.operators.empty import EmptyOperator
 from airflow.serialization.serialized_objects import BaseSerialization, SerializedDAG
 
+ASSET_MODULE_PATH = "airflow.sdk.definitions.asset"
+
 
 @pytest.fixture
 def clear_assets():
@@ -269,10
+272,19 @@ def test_asset_logical_conditions_evaluation_and_serialization(inputs, scenario, @pytest.mark.parametrize( "status_values, expected_evaluation", [ - ((False, True, True), False), # AssetAll requires all conditions to be True, but d1 is False + ( + (False, True, True), + False, + ), # AssetAll requires all conditions to be True, but d1 is False ((True, True, True), True), # All conditions are True - ((True, False, True), True), # d1 is True, and AssetAny condition (d2 or d3 being True) is met - ((True, False, False), False), # d1 is True, but neither d2 nor d3 meet the AssetAny condition + ( + (True, False, True), + True, + ), # d1 is True, and AssetAny condition (d2 or d3 being True) is met + ( + (True, False, False), + False, + ), # d1 is True, but neither d2 nor d3 meet the AssetAny condition ], ) def test_nested_asset_conditions_with_serialization(status_values, expected_evaluation): @@ -531,7 +543,10 @@ def normalizer(uri): return normalizer -@patch("airflow.assets._get_uri_normalizer", _mock_get_uri_normalizer_raising_error) +@patch( + "airflow.sdk.definitions.asset._get_uri_normalizer", + _mock_get_uri_normalizer_raising_error, +) def test_sanitize_uri_raises_exception(): with pytest.raises(ValueError) as e_info: _sanitize_uri("postgres://localhost:5432/database.schema.table") @@ -539,20 +554,23 @@ def test_sanitize_uri_raises_exception(): assert str(e_info.value) == "Incorrect URI format" -@patch("airflow.assets._get_uri_normalizer", lambda x: None) +@patch("airflow.sdk.definitions.asset._get_uri_normalizer", lambda x: None) def test_normalize_uri_no_normalizer_found(): asset = Asset(uri="any_uri_without_normalizer_defined") assert asset.normalized_uri is None -@patch("airflow.assets._get_uri_normalizer", _mock_get_uri_normalizer_raising_error) +@patch( + "airflow.sdk.definitions.asset._get_uri_normalizer", + _mock_get_uri_normalizer_raising_error, +) def test_normalize_uri_invalid_uri(): asset = Asset(uri="any_uri_not_aip60_compliant") assert asset.normalized_uri is None -@patch("airflow.assets._get_uri_normalizer", _mock_get_uri_normalizer_noop) -@patch("airflow.assets._get_normalized_scheme", lambda x: "valid_scheme") +@patch("airflow.sdk.definitions.asset._get_uri_normalizer", _mock_get_uri_normalizer_noop) +@patch("airflow.sdk.definitions.asset._get_normalized_scheme", lambda x: "valid_scheme") def test_normalize_uri_valid_uri(): asset = Asset(uri="valid_aip60_uri") assert asset.normalized_uri == "valid_aip60_uri" @@ -645,35 +663,3 @@ def test_only_posarg(self, subcls, group, arg): assert obj.name == arg assert obj.uri == arg assert obj.group == group - - -@pytest.mark.parametrize( - "module_path, attr_name, warning_message", - ( - ( - "airflow", - "Dataset", - ( - "Import 'Dataset' directly from the airflow module is deprecated and will be removed in the future. " - "Please import it from 'airflow.assets.Dataset'." - ), - ), - ( - "airflow.datasets", - "Dataset", - ( - "Import from the airflow.dataset module is deprecated and " - "will be removed in the Airflow 3.2. Please import it from 'airflow.assets'." 
- ), - ), - ), -) -def test_backward_compat_import_before_airflow_3_2(module_path, attr_name, warning_message): - with pytest.warns() as record: - import importlib - - mod = importlib.import_module(module_path, __name__) - getattr(mod, attr_name) - - assert record[0].category is DeprecationWarning - assert str(record[0].message) == warning_message diff --git a/tests/api_connexion/endpoints/test_dag_run_endpoint.py b/tests/api_connexion/endpoints/test_dag_run_endpoint.py index b295c063adc2..1eb92438f493 100644 --- a/tests/api_connexion/endpoints/test_dag_run_endpoint.py +++ b/tests/api_connexion/endpoints/test_dag_run_endpoint.py @@ -24,12 +24,12 @@ import time_machine from airflow.api_connexion.exceptions import EXCEPTIONS_LINK_MAP -from airflow.assets import Asset from airflow.models.asset import AssetEvent, AssetModel from airflow.models.dag import DAG, DagModel from airflow.models.dagrun import DagRun from airflow.models.param import Param from airflow.operators.empty import EmptyOperator +from airflow.sdk.definitions.asset import Asset from airflow.utils import timezone from airflow.utils.session import create_session, provide_session from airflow.utils.state import DagRunState, State diff --git a/tests/api_connexion/schemas/test_asset_schema.py b/tests/api_connexion/schemas/test_asset_schema.py index e403e1c6a286..103af2836328 100644 --- a/tests/api_connexion/schemas/test_asset_schema.py +++ b/tests/api_connexion/schemas/test_asset_schema.py @@ -27,9 +27,9 @@ asset_event_schema, asset_schema, ) -from airflow.assets import Asset from airflow.models.asset import AssetAliasModel, AssetEvent, AssetModel from airflow.operators.empty import EmptyOperator +from airflow.sdk.definitions.asset import Asset from tests_common.test_utils.db import clear_db_assets, clear_db_dags diff --git a/tests/api_connexion/schemas/test_dag_schema.py b/tests/api_connexion/schemas/test_dag_schema.py index a14365f07c1e..e43c1f2ae769 100644 --- a/tests/api_connexion/schemas/test_dag_schema.py +++ b/tests/api_connexion/schemas/test_dag_schema.py @@ -27,9 +27,9 @@ DAGDetailSchema, DAGSchema, ) -from airflow.assets import Asset from airflow.models import DagModel, DagTag from airflow.models.dag import DAG +from airflow.sdk.definitions.asset import Asset UTC_JSON_REPR = "UTC" if pendulum.__version__.startswith("3") else "Timezone('UTC')" diff --git a/tests/api_fastapi/core_api/routes/public/test_dag_run.py b/tests/api_fastapi/core_api/routes/public/test_dag_run.py index 7d28a9237e34..89705ba85ab6 100644 --- a/tests/api_fastapi/core_api/routes/public/test_dag_run.py +++ b/tests/api_fastapi/core_api/routes/public/test_dag_run.py @@ -22,10 +22,10 @@ import pytest from sqlalchemy import select -from airflow import Asset from airflow.models import DagRun from airflow.models.asset import AssetEvent, AssetModel from airflow.operators.empty import EmptyOperator +from airflow.sdk.definitions.asset import Asset from airflow.utils.session import provide_session from airflow.utils.state import DagRunState, State from airflow.utils.types import DagRunTriggeredByType, DagRunType diff --git a/tests/api_fastapi/core_api/routes/ui/test_assets.py b/tests/api_fastapi/core_api/routes/ui/test_assets.py index b5c85b98ba6b..8eafb0f8bdd4 100644 --- a/tests/api_fastapi/core_api/routes/ui/test_assets.py +++ b/tests/api_fastapi/core_api/routes/ui/test_assets.py @@ -18,8 +18,8 @@ import pytest -from airflow.assets import Asset from airflow.operators.empty import EmptyOperator +from airflow.sdk.definitions.asset import Asset from 
tests_common.test_utils.db import initial_db_init diff --git a/tests/assets/test_manager.py b/tests/assets/test_manager.py index 1b3e8216a9a2..b37ac6c912f3 100644 --- a/tests/assets/test_manager.py +++ b/tests/assets/test_manager.py @@ -24,7 +24,6 @@ import pytest from sqlalchemy import delete -from airflow.assets import Asset, AssetAlias from airflow.assets.manager import AssetManager from airflow.listeners.listener import get_listener_manager from airflow.models.asset import ( @@ -37,6 +36,7 @@ ) from airflow.models.dag import DagModel from airflow.models.dagbag import DagPriorityParsingRequest +from airflow.sdk.definitions.asset import Asset, AssetAlias from airflow.serialization.pydantic.taskinstance import TaskInstancePydantic from tests.listeners import asset_listener diff --git a/tests/dags/test_assets.py b/tests/dags/test_assets.py index 30a6e3f147a5..1fbc67a18d32 100644 --- a/tests/dags/test_assets.py +++ b/tests/dags/test_assets.py @@ -19,11 +19,11 @@ from datetime import datetime -from airflow.assets import Asset from airflow.exceptions import AirflowFailException, AirflowSkipException from airflow.models.dag import DAG from airflow.providers.standard.operators.bash import BashOperator from airflow.providers.standard.operators.python import PythonOperator +from airflow.sdk.definitions.asset import Asset skip_task_dag_asset = Asset("s3://dag_with_skip_task/output_1.txt", extra={"hi": "bye"}) fail_task_dag_asset = Asset("s3://dag_with_fail_task/output_1.txt", extra={"hi": "bye"}) diff --git a/tests/dags/test_only_empty_tasks.py b/tests/dags/test_only_empty_tasks.py index 2cea9c3c6b17..99c0224e56e8 100644 --- a/tests/dags/test_only_empty_tasks.py +++ b/tests/dags/test_only_empty_tasks.py @@ -20,9 +20,9 @@ from datetime import datetime from typing import Sequence -from airflow.assets import Asset from airflow.models.dag import DAG from airflow.operators.empty import EmptyOperator +from airflow.sdk.definitions.asset import Asset DEFAULT_DATE = datetime(2016, 1, 1) diff --git a/tests/datasets/__init__.py b/tests/datasets/__init__.py new file mode 100644 index 000000000000..13a83393a912 --- /dev/null +++ b/tests/datasets/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/tests/datasets/test_dataset.py b/tests/datasets/test_dataset.py new file mode 100644 index 000000000000..de1a9a5cc3a9 --- /dev/null +++ b/tests/datasets/test_dataset.py @@ -0,0 +1,53 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + + +from __future__ import annotations + +import pytest + + +@pytest.mark.parametrize( + "module_path, attr_name, warning_message", + ( + ( + "airflow", + "Dataset", + ( + "Import 'Dataset' directly from the airflow module is deprecated and will be removed in the future. " + "Please import it from 'airflow.sdk.definitions.asset.Dataset'." + ), + ), + ( + "airflow.datasets", + "Dataset", + ( + "Import from the airflow.dataset module is deprecated and " + "will be removed in the Airflow 3.2. Please import it from 'airflow.sdk.definitions.asset'." + ), + ), + ), +) +def test_backward_compat_import_before_airflow_3_2(module_path, attr_name, warning_message): + with pytest.warns() as record: + import importlib + + mod = importlib.import_module(module_path, __name__) + getattr(mod, attr_name) + + assert record[0].category is DeprecationWarning + assert str(record[0].message) == warning_message diff --git a/tests/decorators/test_python.py b/tests/decorators/test_python.py index 668d0be99b65..83c6f8ab4dce 100644 --- a/tests/decorators/test_python.py +++ b/tests/decorators/test_python.py @@ -985,7 +985,7 @@ def other(x): ... @pytest.mark.skip_if_database_isolation_mode # Test is broken in db isolation mode def test_task_decorator_asset(dag_maker, session): - from airflow.assets import Asset + from airflow.sdk.definitions.asset import Asset result = None uri = "s3://bucket/name" diff --git a/tests/io/test_path.py b/tests/io/test_path.py index 264e3a6d8c15..fd9844bc4bc5 100644 --- a/tests/io/test_path.py +++ b/tests/io/test_path.py @@ -28,10 +28,10 @@ from fsspec.implementations.memory import MemoryFileSystem from fsspec.registry import _registry as _fsspec_registry, register_implementation -from airflow.assets import Asset from airflow.io import _register_filesystems, get_fs from airflow.io.path import ObjectStoragePath from airflow.io.store import _STORE_CACHE, ObjectStore, attach +from airflow.sdk.definitions.asset import Asset from airflow.utils.module_loading import qualname FAKE = "file:///fake" diff --git a/tests/io/test_wrapper.py b/tests/io/test_wrapper.py index 641eda84d1a4..35469326794e 100644 --- a/tests/io/test_wrapper.py +++ b/tests/io/test_wrapper.py @@ -19,8 +19,8 @@ import uuid from unittest.mock import patch -from airflow.assets import Asset from airflow.io.path import ObjectStoragePath +from airflow.sdk.definitions.asset import Asset @patch("airflow.providers_manager.ProvidersManager") diff --git a/tests/jobs/test_scheduler_job.py b/tests/jobs/test_scheduler_job.py index 12423157ccb5..c3cac7f08d28 100644 --- a/tests/jobs/test_scheduler_job.py +++ b/tests/jobs/test_scheduler_job.py @@ -39,7 +39,6 @@ import airflow.example_dags from airflow import settings -from airflow.assets import Asset from airflow.assets.manager import AssetManager from airflow.callbacks.callback_requests import DagCallbackRequest, TaskCallbackRequest from airflow.callbacks.database_callback_sink import DatabaseCallbackSink @@ -66,6 +65,7 @@ 
from airflow.models.taskinstance import SimpleTaskInstance, TaskInstance from airflow.operators.empty import EmptyOperator from airflow.providers.standard.operators.bash import BashOperator +from airflow.sdk.definitions.asset import Asset from airflow.serialization.serialized_objects import SerializedDAG from airflow.timetables.base import DataInterval from airflow.utils import timezone diff --git a/tests/lineage/test_hook.py b/tests/lineage/test_hook.py index cfb446b8adee..ec6390c77a55 100644 --- a/tests/lineage/test_hook.py +++ b/tests/lineage/test_hook.py @@ -22,7 +22,6 @@ import pytest from airflow import plugins_manager -from airflow.assets import Asset from airflow.hooks.base import BaseHook from airflow.lineage import hook from airflow.lineage.hook import ( @@ -33,6 +32,7 @@ NoOpCollector, get_hook_lineage_collector, ) +from airflow.sdk.definitions.asset import Asset from tests_common.test_utils.mock_plugins import mock_plugin_manager diff --git a/tests/listeners/asset_listener.py b/tests/listeners/asset_listener.py index e7adf580363b..3ceba2d676dd 100644 --- a/tests/listeners/asset_listener.py +++ b/tests/listeners/asset_listener.py @@ -23,7 +23,7 @@ from airflow.listeners import hookimpl if typing.TYPE_CHECKING: - from airflow.assets import Asset + from airflow.sdk.definitions.asset import Asset changed: list[Asset] = [] diff --git a/tests/listeners/test_asset_listener.py b/tests/listeners/test_asset_listener.py index a075b87a7f3d..52cdc39604d5 100644 --- a/tests/listeners/test_asset_listener.py +++ b/tests/listeners/test_asset_listener.py @@ -18,10 +18,10 @@ import pytest -from airflow.assets import Asset from airflow.listeners.listener import get_listener_manager from airflow.models.asset import AssetModel from airflow.operators.empty import EmptyOperator +from airflow.sdk.definitions.asset import Asset from airflow.utils.session import provide_session from tests.listeners import asset_listener diff --git a/tests/models/test_asset.py b/tests/models/test_asset.py index 5b35a0c89529..9763f220adef 100644 --- a/tests/models/test_asset.py +++ b/tests/models/test_asset.py @@ -17,8 +17,8 @@ from __future__ import annotations -from airflow.assets import AssetAlias from airflow.models.asset import AssetAliasModel +from airflow.sdk.definitions.asset import AssetAlias class TestAssetAliasModel: diff --git a/tests/models/test_dag.py b/tests/models/test_dag.py index 28f745b0614e..f2128b205b4d 100644 --- a/tests/models/test_dag.py +++ b/tests/models/test_dag.py @@ -36,7 +36,6 @@ from sqlalchemy import inspect, select from airflow import settings -from airflow.assets import Asset, AssetAlias, AssetAll, AssetAny from airflow.configuration import conf from airflow.decorators import setup, task as task_decorator, teardown from airflow.exceptions import ( @@ -71,6 +70,7 @@ from airflow.providers.standard.operators.bash import BashOperator from airflow.providers.standard.operators.python import PythonOperator from airflow.sdk import TaskGroup +from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetAll, AssetAny from airflow.sdk.definitions.contextmanager import TaskGroupContext from airflow.security import permissions from airflow.templates import NativeEnvironment, SandboxedEnvironment diff --git a/tests/models/test_serialized_dag.py b/tests/models/test_serialized_dag.py index 599cb1396d71..fda83190fd6e 100644 --- a/tests/models/test_serialized_dag.py +++ b/tests/models/test_serialized_dag.py @@ -26,13 +26,13 @@ from sqlalchemy import func, select import airflow.example_dags as 
example_dags_module -from airflow.assets import Asset from airflow.models.dag import DAG from airflow.models.dagbag import DagBag from airflow.models.dagcode import DagCode from airflow.models.serialized_dag import SerializedDagModel as SDM from airflow.operators.empty import EmptyOperator from airflow.providers.standard.operators.bash import BashOperator +from airflow.sdk.definitions.asset import Asset from airflow.serialization.serialized_objects import SerializedDAG from airflow.settings import json from airflow.utils.hashlib_wrapper import md5 diff --git a/tests/models/test_taskinstance.py b/tests/models/test_taskinstance.py index fbe92f006742..81f8ed7e60af 100644 --- a/tests/models/test_taskinstance.py +++ b/tests/models/test_taskinstance.py @@ -39,7 +39,6 @@ from sqlalchemy import select from airflow import settings -from airflow.assets import AssetAlias from airflow.decorators import task, task_group from airflow.exceptions import ( AirflowException, @@ -78,6 +77,7 @@ from airflow.providers.standard.operators.bash import BashOperator from airflow.providers.standard.operators.python import PythonOperator from airflow.providers.standard.sensors.python import PythonSensor +from airflow.sdk.definitions.asset import AssetAlias from airflow.sensors.base import BaseSensorOperator from airflow.serialization.serialized_objects import SerializedBaseOperator, SerializedDAG from airflow.settings import TracebackSessionForTests @@ -2441,7 +2441,7 @@ def test_outlet_assets_skipped(self): @pytest.mark.skip_if_database_isolation_mode # Does not work in db isolation mode def test_outlet_asset_extra(self, dag_maker, session): - from airflow.assets import Asset + from airflow.sdk.definitions.asset import Asset with dag_maker(schedule=None, session=session) as dag: @@ -2483,7 +2483,7 @@ def _write2_post_execute(context, _): @pytest.mark.skip_if_database_isolation_mode # Does not work in db isolation mode def test_outlet_asset_extra_ignore_different(self, dag_maker, session): - from airflow.assets import Asset + from airflow.sdk.definitions.asset import Asset with dag_maker(schedule=None, session=session): @@ -2505,8 +2505,8 @@ def write(*, outlet_events): @pytest.mark.skip_if_database_isolation_mode # Does not work in db isolation mode def test_outlet_asset_extra_yield(self, dag_maker, session): - from airflow.assets import Asset - from airflow.assets.metadata import Metadata + from airflow.sdk.definitions.asset import Asset + from airflow.sdk.definitions.asset.metadata import Metadata with dag_maker(schedule=None, session=session) as dag: @@ -2555,7 +2555,7 @@ def _write2_post_execute(context, result): @pytest.mark.skip_if_database_isolation_mode # Does not work in db isolation mode def test_outlet_asset_alias(self, dag_maker, session): - from airflow.assets import Asset, AssetAlias + from airflow.sdk.definitions.asset import Asset, AssetAlias asset_uri = "test_outlet_asset_alias_test_case_ds" alias_name_1 = "test_outlet_asset_alias_test_case_asset_alias_1" @@ -2604,7 +2604,7 @@ def producer(*, outlet_events): @pytest.mark.skip_if_database_isolation_mode # Does not work in db isolation mode def test_outlet_multiple_asset_alias(self, dag_maker, session): - from airflow.assets import Asset, AssetAlias + from airflow.sdk.definitions.asset import Asset, AssetAlias asset_uri = "test_outlet_maa_ds" asset_alias_name_1 = "test_outlet_maa_asset_alias_1" @@ -2678,8 +2678,8 @@ def producer(*, outlet_events): @pytest.mark.skip_if_database_isolation_mode # Does not work in db isolation mode def 
test_outlet_asset_alias_through_metadata(self, dag_maker, session): - from airflow.assets import AssetAlias - from airflow.assets.metadata import Metadata + from airflow.sdk.definitions.asset import AssetAlias + from airflow.sdk.definitions.asset.metadata import Metadata asset_uri = "test_outlet_asset_alias_through_metadata_ds" asset_alias_name = "test_outlet_asset_alias_through_metadata_asset_alias" @@ -2723,7 +2723,7 @@ def producer(*, outlet_events): @pytest.mark.skip_if_database_isolation_mode # Does not work in db isolation mode def test_outlet_asset_alias_asset_not_exists(self, dag_maker, session): - from airflow.assets import Asset, AssetAlias + from airflow.sdk.definitions.asset import Asset, AssetAlias asset_alias_name = "test_outlet_asset_alias_asset_not_exists_asset_alias" asset_uri = "did_not_exists" @@ -2763,7 +2763,7 @@ def producer(*, outlet_events): @pytest.mark.skip_if_database_isolation_mode # Does not work in db isolation mode def test_inlet_asset_extra(self, dag_maker, session): - from airflow.assets import Asset + from airflow.sdk.definitions.asset import Asset read_task_evaluated = False @@ -2826,7 +2826,7 @@ def test_inlet_asset_alias_extra(self, dag_maker, session): session.add_all([asset_model, asset_alias_model]) session.commit() - from airflow.assets import Asset, AssetAlias + from airflow.sdk.definitions.asset import Asset, AssetAlias read_task_evaluated = False @@ -2885,7 +2885,7 @@ def test_inlet_unresolved_asset_alias(self, dag_maker, session): session.add(asset_alias_model) session.commit() - from airflow.assets import AssetAlias + from airflow.sdk.definitions.asset import AssetAlias with dag_maker(schedule=None, session=session): @@ -2916,7 +2916,7 @@ def read(*, inlet_events): ], ) def test_inlet_asset_extra_slice(self, dag_maker, session, slicer, expected): - from airflow.assets import Asset + from airflow.sdk.definitions.asset import Asset asset_uri = "test_inlet_asset_extra_slice" @@ -2979,7 +2979,7 @@ def test_inlet_asset_alias_extra_slice(self, dag_maker, session, slicer, expecte session.add_all([asset_model, asset_alias_model]) session.commit() - from airflow.assets import Asset + from airflow.sdk.definitions.asset import Asset with dag_maker(dag_id="write", schedule="@daily", params={"i": -1}, session=session): @@ -3024,7 +3024,7 @@ def test_changing_of_asset_when_adrq_is_already_populated(self, dag_maker): Test that when a task that produces asset has ran, that changing the consumer dag asset will not cause primary key blank-out """ - from airflow.assets import Asset + from airflow.sdk.definitions.asset import Asset with dag_maker(schedule=None, serialized=True) as dag1: diff --git a/tests/serialization/test_dag_serialization.py b/tests/serialization/test_dag_serialization.py index f04ff3e2568d..84740af63c15 100644 --- a/tests/serialization/test_dag_serialization.py +++ b/tests/serialization/test_dag_serialization.py @@ -43,7 +43,6 @@ from kubernetes.client import models as k8s import airflow -from airflow.assets import Asset from airflow.decorators import teardown from airflow.decorators.base import DecoratedOperator from airflow.exceptions import ( @@ -65,6 +64,7 @@ from airflow.providers.cncf.kubernetes.pod_generator import PodGenerator from airflow.providers.standard.operators.bash import BashOperator from airflow.providers.standard.sensors.bash import BashSensor +from airflow.sdk.definitions.asset import Asset from airflow.security import permissions from airflow.serialization.enums import Encoding from airflow.serialization.json_schema import 
load_dag_schema_dict diff --git a/tests/serialization/test_serde.py b/tests/serialization/test_serde.py index 11010af86ab9..a3a946124ff9 100644 --- a/tests/serialization/test_serde.py +++ b/tests/serialization/test_serde.py @@ -28,7 +28,7 @@ import pytest from pydantic import BaseModel -from airflow.assets import Asset +from airflow.sdk.definitions.asset import Asset from airflow.serialization.serde import ( CLASSNAME, DATA, @@ -337,7 +337,7 @@ def test_backwards_compat(self): """ uri = "s3://does/not/exist" data = { - "__type": "airflow.assets.Asset", + "__type": "airflow.sdk.definitions.asset.Asset", "__source": None, "__var": { "__var": { diff --git a/tests/serialization/test_serialized_objects.py b/tests/serialization/test_serialized_objects.py index 96f7414b7765..a7d775f82c3d 100644 --- a/tests/serialization/test_serialized_objects.py +++ b/tests/serialization/test_serialized_objects.py @@ -31,7 +31,6 @@ from pendulum.tz.timezone import Timezone from pydantic import BaseModel -from airflow.assets import Asset, AssetAlias, AssetAliasEvent from airflow.exceptions import ( AirflowException, AirflowFailException, @@ -50,6 +49,7 @@ from airflow.models.xcom_arg import XComArg from airflow.operators.empty import EmptyOperator from airflow.providers.standard.operators.python import PythonOperator +from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetAliasEvent from airflow.serialization.enums import DagAttributeTypes as DAT, Encoding from airflow.serialization.pydantic.asset import AssetEventPydantic, AssetPydantic from airflow.serialization.pydantic.dag import DagModelPydantic, DagTagPydantic diff --git a/tests/timetables/test_assets_timetable.py b/tests/timetables/test_assets_timetable.py index bb942a4a01d4..9d572295773a 100644 --- a/tests/timetables/test_assets_timetable.py +++ b/tests/timetables/test_assets_timetable.py @@ -23,8 +23,8 @@ import pytest from pendulum import DateTime -from airflow.assets import Asset, AssetAlias from airflow.models.asset import AssetAliasModel, AssetEvent, AssetModel +from airflow.sdk.definitions.asset import Asset, AssetAlias from airflow.timetables.assets import AssetOrTimeSchedule from airflow.timetables.base import DagRunInfo, DataInterval, TimeRestriction, Timetable from airflow.timetables.simple import AssetTriggeredTimetable diff --git a/tests/utils/test_context.py b/tests/utils/test_context.py index 5d2f7543b629..0e7309075b38 100644 --- a/tests/utils/test_context.py +++ b/tests/utils/test_context.py @@ -20,8 +20,8 @@ import pytest -from airflow.assets import Asset, AssetAlias, AssetAliasEvent from airflow.models.asset import AssetAliasModel, AssetModel +from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetAliasEvent from airflow.utils.context import OutletEventAccessor, OutletEventAccessors diff --git a/tests/utils/test_json.py b/tests/utils/test_json.py index 5a58b5d79032..b99681c22318 100644 --- a/tests/utils/test_json.py +++ b/tests/utils/test_json.py @@ -26,7 +26,7 @@ import pendulum import pytest -from airflow.assets import Asset +from airflow.sdk.definitions.asset import Asset from airflow.utils import json as utils_json diff --git a/tests/www/views/test_views_asset.py b/tests/www/views/test_views_asset.py index f2e860958ca4..e4fda0aeac66 100644 --- a/tests/www/views/test_views_asset.py +++ b/tests/www/views/test_views_asset.py @@ -22,9 +22,9 @@ import pytest from dateutil.tz import UTC -from airflow.assets import Asset from airflow.models.asset import AssetActive, AssetEvent, AssetModel from airflow.operators.empty 
import EmptyOperator +from airflow.sdk.definitions.asset import Asset from tests_common.test_utils.asserts import assert_queries_count from tests_common.test_utils.db import clear_db_assets diff --git a/tests/www/views/test_views_grid.py b/tests/www/views/test_views_grid.py index c94539f05587..c7a453ffaeeb 100644 --- a/tests/www/views/test_views_grid.py +++ b/tests/www/views/test_views_grid.py @@ -24,12 +24,12 @@ import pytest from dateutil.tz import UTC -from airflow.assets import Asset from airflow.decorators import task_group from airflow.lineage.entities import File from airflow.models import DagBag from airflow.models.asset import AssetDagRunQueue, AssetEvent, AssetModel from airflow.operators.empty import EmptyOperator +from airflow.sdk.definitions.asset import Asset from airflow.utils import timezone from airflow.utils.state import DagRunState, TaskInstanceState from airflow.utils.task_group import TaskGroup From 14b916476ab9081ea3a015ed39cdb68e2b4e92fd Mon Sep 17 00:00:00 2001 From: Wei Lee Date: Thu, 7 Nov 2024 15:06:54 +0800 Subject: [PATCH 02/28] feat(task_sdk): Move assets.metadata to task_sdk.definitions.asset --- airflow/assets/metadata.py | 46 ------------------- .../example_outlet_event_extra.py | 3 +- airflow/utils/operator_helpers.py | 2 +- .../authoring-and-scheduling/datasets.rst | 4 +- task_sdk/src/airflow/sdk/definitions/asset.py | 22 +++++++++ tests/models/test_taskinstance.py | 6 +-- 6 files changed, 28 insertions(+), 55 deletions(-) delete mode 100644 airflow/assets/metadata.py diff --git a/airflow/assets/metadata.py b/airflow/assets/metadata.py deleted file mode 100644 index 8feffe389e3e..000000000000 --- a/airflow/assets/metadata.py +++ /dev/null @@ -1,46 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
- -from __future__ import annotations - -from typing import TYPE_CHECKING, Any - -import attrs - -from airflow.sdk.definitions.asset import AssetAlias, extract_event_key - -if TYPE_CHECKING: - from airflow.sdk.definitions.asset import Asset - - -@attrs.define(init=False) -class Metadata: - """Metadata to attach to an AssetEvent.""" - - uri: str - extra: dict[str, Any] - alias_name: str | None = None - - def __init__( - self, target: str | Asset, extra: dict[str, Any], alias: AssetAlias | str | None = None - ) -> None: - self.uri = extract_event_key(target) - self.extra = extra - if isinstance(alias, AssetAlias): - self.alias_name = alias.name - else: - self.alias_name = alias diff --git a/airflow/example_dags/example_outlet_event_extra.py b/airflow/example_dags/example_outlet_event_extra.py index dd3041e18fc0..d07365a9f1f3 100644 --- a/airflow/example_dags/example_outlet_event_extra.py +++ b/airflow/example_dags/example_outlet_event_extra.py @@ -28,8 +28,7 @@ from airflow.decorators import task from airflow.models.dag import DAG from airflow.providers.standard.operators.bash import BashOperator -from airflow.sdk.definitions.asset import Asset -from airflow.sdk.definitions.asset.metadata import Metadata +from airflow.sdk.definitions.asset import Asset, Metadata ds = Asset("s3://output/1.txt") diff --git a/airflow/utils/operator_helpers.py b/airflow/utils/operator_helpers.py index a06cdf42b50c..82bf9d43cd16 100644 --- a/airflow/utils/operator_helpers.py +++ b/airflow/utils/operator_helpers.py @@ -23,7 +23,7 @@ from typing import TYPE_CHECKING, Any, Callable, Collection, Mapping, Protocol, TypeVar from airflow import settings -from airflow.assets.metadata import Metadata +from airflow.sdk.definitions.asset import Metadata from airflow.typing_compat import ParamSpec from airflow.utils.context import Context, lazy_mapping_from_context from airflow.utils.types import NOTSET diff --git a/docs/apache-airflow/authoring-and-scheduling/datasets.rst b/docs/apache-airflow/authoring-and-scheduling/datasets.rst index 9e777d929958..1794fb972799 100644 --- a/docs/apache-airflow/authoring-and-scheduling/datasets.rst +++ b/docs/apache-airflow/authoring-and-scheduling/datasets.rst @@ -249,7 +249,7 @@ The easiest way to attach extra information to the asset event is by ``yield``-i .. code-block:: python from airflow.sdk.definitions.asset import Asset - from airflow.sdk.definitions.asset.metadata import Metadata + from airflow.sdk.definitions.asset import Metadata example_s3_asset = Asset("s3://asset/example.csv") @@ -452,7 +452,7 @@ The following example creates an asset event against the S3 URI ``f"s3://bucket/ .. 
code-block:: python - from airflow.sdk.definitions.asset.metadata import Metadata + from airflow.sdk.definitions.asset import Metadata @task(outlets=[AssetAlias("my-task-outputs")]) diff --git a/task_sdk/src/airflow/sdk/definitions/asset.py b/task_sdk/src/airflow/sdk/definitions/asset.py index 761ba50ec46d..faee618cf25d 100644 --- a/task_sdk/src/airflow/sdk/definitions/asset.py +++ b/task_sdk/src/airflow/sdk/definitions/asset.py @@ -530,3 +530,25 @@ def as_expression(self) -> Any: :meta private: """ return {"all": [o.as_expression() for o in self.objects]} + + +@attrs.define(init=False) +class Metadata: + """Metadata to attach to an AssetEvent.""" + + uri: str + extra: dict[str, Any] + alias_name: str | None = None + + def __init__( + self, + target: str | Asset, + extra: dict[str, Any], + alias: AssetAlias | str | None = None, + ) -> None: + self.uri = extract_event_key(target) + self.extra = extra + if isinstance(alias, AssetAlias): + self.alias_name = alias.name + else: + self.alias_name = alias diff --git a/tests/models/test_taskinstance.py b/tests/models/test_taskinstance.py index 81f8ed7e60af..36297f7b7bce 100644 --- a/tests/models/test_taskinstance.py +++ b/tests/models/test_taskinstance.py @@ -2505,8 +2505,7 @@ def write(*, outlet_events): @pytest.mark.skip_if_database_isolation_mode # Does not work in db isolation mode def test_outlet_asset_extra_yield(self, dag_maker, session): - from airflow.sdk.definitions.asset import Asset - from airflow.sdk.definitions.asset.metadata import Metadata + from airflow.sdk.definitions.asset import Asset, Metadata with dag_maker(schedule=None, session=session) as dag: @@ -2678,8 +2677,7 @@ def producer(*, outlet_events): @pytest.mark.skip_if_database_isolation_mode # Does not work in db isolation mode def test_outlet_asset_alias_through_metadata(self, dag_maker, session): - from airflow.sdk.definitions.asset import AssetAlias - from airflow.sdk.definitions.asset.metadata import Metadata + from airflow.sdk.definitions.asset import AssetAlias, Metadata asset_uri = "test_outlet_asset_alias_through_metadata_ds" asset_alias_name = "test_outlet_asset_alias_through_metadata_asset_alias" From c6243de81476204333e613d11ebbf05a7938d05a Mon Sep 17 00:00:00 2001 From: Wei Lee Date: Thu, 7 Nov 2024 19:32:11 +0800 Subject: [PATCH 03/28] fix(providers/amazon): fix common.compat provider ImportError handling --- .../airflow/providers/amazon/aws/assets/s3.py | 18 +++++++++++++++++- .../aws/auth_manager/test_aws_auth_manager.py | 18 +++++++++++++++++- 2 files changed, 34 insertions(+), 2 deletions(-) diff --git a/providers/src/airflow/providers/amazon/aws/assets/s3.py b/providers/src/airflow/providers/amazon/aws/assets/s3.py index 4d02b156afb0..c291078155ac 100644 --- a/providers/src/airflow/providers/amazon/aws/assets/s3.py +++ b/providers/src/airflow/providers/amazon/aws/assets/s3.py @@ -19,14 +19,30 @@ from typing import TYPE_CHECKING from airflow.providers.amazon.aws.hooks.s3 import S3Hook -from airflow.providers.common.compat.assets import Asset if TYPE_CHECKING: from urllib.parse import SplitResult + from airflow.providers.common.compat.assets import Asset from airflow.providers.common.compat.openlineage.facet import ( Dataset as OpenLineageDataset, ) +else: + # TODO: Remove this try-exception block after bumping common provider to 1.3.0 + # This is due to common provider AssetDetails import error handling + try: + from airflow.providers.common.compat.assets import Asset + except ImportError: + from packaging.version import Version + + from airflow import 
__version__ as AIRFLOW_VERSION + + AIRFLOW_V_3_0_PLUS = Version(Version(AIRFLOW_VERSION).base_version) >= Version("3.0.0") + if AIRFLOW_V_3_0_PLUS: + from airflow.sdk.definitions.asset import Asset + else: + # dataset is renamed to asset since Airflow 3.0 + from airflow.datasets import Dataset as Asset def create_asset(*, bucket: str, key: str, extra=None) -> Asset: diff --git a/providers/tests/amazon/aws/auth_manager/test_aws_auth_manager.py b/providers/tests/amazon/aws/auth_manager/test_aws_auth_manager.py index acca91221480..0700a71a6191 100644 --- a/providers/tests/amazon/aws/auth_manager/test_aws_auth_manager.py +++ b/providers/tests/amazon/aws/auth_manager/test_aws_auth_manager.py @@ -57,9 +57,25 @@ from airflow.auth.managers.models.resource_details import AssetDetails from airflow.security.permissions import RESOURCE_ASSET else: - from airflow.providers.common.compat.assets import AssetDetails from airflow.providers.common.compat.security.permissions import RESOURCE_ASSET + # TODO: Remove this try-exception block after bumping common provider to 1.3.0 + # This is due to common provider AssetDetails import error handling + try: + from airflow.auth.managers.models.resource_details import AssetDetails + except ModuleNotFoundError: + # 2.7.x + pass + except ImportError: + from packaging.version import Version + + from airflow import __version__ as AIRFLOW_VERSION + + _IS_AIRFLOW_2_8_OR_HIGHER = Version(Version(AIRFLOW_VERSION).base_version) >= Version("2.8.0") + if _IS_AIRFLOW_2_8_OR_HIGHER: + from airflow.auth.managers.models.resource_details import DatasetDetails as AssetDetails + + pytestmark = [ pytest.mark.skipif(not AIRFLOW_V_2_9_PLUS, reason="Test requires Airflow 2.9+"), pytest.mark.skip_if_database_isolation_mode, From 9ff2ca769e2f0a5c391c6bb16b2a3c7003536d8f Mon Sep 17 00:00:00 2001 From: Wei Lee Date: Thu, 7 Nov 2024 19:32:31 +0800 Subject: [PATCH 04/28] fix(providers/google): fix common.compat provider ImportError handling --- .../airflow/providers/google/assets/gcs.py | 18 ++++++++++- providers/tests/google/assets/test_gcs.py | 21 ++++++++++++- .../tests/google/cloud/hooks/test_gcs.py | 31 ++++++++++++++----- 3 files changed, 61 insertions(+), 9 deletions(-) diff --git a/providers/src/airflow/providers/google/assets/gcs.py b/providers/src/airflow/providers/google/assets/gcs.py index 4df6995787ec..22206e3f7532 100644 --- a/providers/src/airflow/providers/google/assets/gcs.py +++ b/providers/src/airflow/providers/google/assets/gcs.py @@ -18,13 +18,29 @@ from typing import TYPE_CHECKING -from airflow.providers.common.compat.assets import Asset from airflow.providers.google.cloud.hooks.gcs import _parse_gcs_url if TYPE_CHECKING: from urllib.parse import SplitResult + from airflow.providers.common.compat.assets import Asset from airflow.providers.common.compat.openlineage.facet import Dataset as OpenLineageDataset +else: + # TODO: Remove this try-exception block after bumping common provider to 1.3.0 + # This is due to common provider AssetDetails import error handling + try: + from airflow.providers.common.compat.assets import Asset + except ImportError: + from packaging.version import Version + + from airflow import __version__ as AIRFLOW_VERSION + + AIRFLOW_V_3_0_PLUS = Version(Version(AIRFLOW_VERSION).base_version) >= Version("3.0.0") + if AIRFLOW_V_3_0_PLUS: + from airflow.sdk.definitions.asset import Asset + else: + # dataset is renamed to asset since Airflow 3.0 + from airflow.datasets import Dataset as Asset def create_asset(*, bucket: str, key: str, extra: dict | None = 
None) -> Asset: diff --git a/providers/tests/google/assets/test_gcs.py b/providers/tests/google/assets/test_gcs.py index e9920302b0e0..b72357960d53 100644 --- a/providers/tests/google/assets/test_gcs.py +++ b/providers/tests/google/assets/test_gcs.py @@ -17,12 +17,31 @@ from __future__ import annotations import urllib.parse +from typing import TYPE_CHECKING import pytest -from airflow.providers.common.compat.assets import Asset from airflow.providers.google.assets.gcs import convert_asset_to_openlineage, create_asset, sanitize_uri +if TYPE_CHECKING: + from airflow.providers.common.compat.assets import Asset +else: + # TODO: Remove this try-exception block after bumping common provider to 1.3.0 + # This is due to common provider AssetDetails import error handling + try: + from airflow.providers.common.compat.assets import Asset + except ImportError: + from packaging.version import Version + + from airflow import __version__ as AIRFLOW_VERSION + + AIRFLOW_V_3_0_PLUS = Version(Version(AIRFLOW_VERSION).base_version) >= Version("3.0.0") + if AIRFLOW_V_3_0_PLUS: + from airflow.sdk.definitions.asset import Asset + else: + # dataset is renamed to asset since Airflow 3.0 + from airflow.datasets import Dataset as Asset + def test_sanitize_uri(): uri = sanitize_uri(urllib.parse.urlsplit("gs://bucket/dir/file.txt")) diff --git a/providers/tests/google/cloud/hooks/test_gcs.py b/providers/tests/google/cloud/hooks/test_gcs.py index 48f8c4858117..b33f53fbc029 100644 --- a/providers/tests/google/cloud/hooks/test_gcs.py +++ b/providers/tests/google/cloud/hooks/test_gcs.py @@ -24,6 +24,7 @@ from collections import namedtuple from datetime import datetime, timedelta from io import BytesIO +from typing import TYPE_CHECKING from unittest import mock from unittest.mock import MagicMock @@ -36,7 +37,6 @@ from google.cloud.storage.retry import DEFAULT_RETRY from airflow.exceptions import AirflowException -from airflow.providers.common.compat.assets import Asset from airflow.providers.google.cloud.hooks import gcs from airflow.providers.google.cloud.hooks.gcs import _fallback_object_url_to_object_name_and_bucket_name from airflow.providers.google.common.consts import CLIENT_INFO @@ -46,6 +46,25 @@ from providers.tests.google.cloud.utils.base_gcp_mock import mock_base_gcp_hook_default_project_id from tests_common.test_utils.compat import AIRFLOW_V_2_10_PLUS +if TYPE_CHECKING: + from airflow.providers.common.compat.assets import Asset +else: + # TODO: Remove this try-exception block after bumping common provider to 1.3.0 + # This is due to common provider AssetDetails import error handling + try: + from airflow.providers.common.compat.assets import Asset + except ImportError: + from packaging.version import Version + + from airflow import __version__ as AIRFLOW_VERSION + + AIRFLOW_V_3_0_PLUS = Version(Version(AIRFLOW_VERSION).base_version) >= Version("3.0.0") + if AIRFLOW_V_3_0_PLUS: + from airflow.sdk.definitions.asset import Asset + else: + # dataset is renamed to asset since Airflow 3.0 + from airflow.datasets import Dataset as Asset + BASE_STRING = "airflow.providers.google.common.hooks.base_google.{}" GCS_STRING = "airflow.providers.google.cloud.hooks.gcs.{}" @@ -424,8 +443,8 @@ def test_copy_exposes_lineage(self, mock_service, mock_copy, hook_lineage_collec mock_copy.return_value = storage.Blob( name=destination_object_name, bucket=storage.Bucket(mock_service, destination_bucket_name) ) - mock_service.return_value.bucket.side_effect = ( - lambda name: source_bucket + mock_service.return_value.bucket.side_effect 
= lambda name: ( + source_bucket if name == source_bucket_name else storage.Bucket(mock_service, destination_bucket_name) ) @@ -519,10 +538,8 @@ def test_rewrite_exposes_lineage(self, mock_service, hook_lineage_collector): blob = MagicMock(spec=storage.Blob) blob.rewrite = MagicMock(return_value=(None, None, None)) dest_bucket.blob = MagicMock(return_value=blob) - mock_service.return_value.bucket.side_effect = ( - lambda name: storage.Bucket(mock_service, source_bucket_name) - if name == source_bucket_name - else dest_bucket + mock_service.return_value.bucket.side_effect = lambda name: ( + storage.Bucket(mock_service, source_bucket_name) if name == source_bucket_name else dest_bucket ) self.gcs_hook.rewrite( From c52e090642e9dbbc78fa3c8994817b2d723f6e54 Mon Sep 17 00:00:00 2001 From: Wei Lee Date: Thu, 7 Nov 2024 19:32:46 +0800 Subject: [PATCH 05/28] fix(providers/openlineage): fix common.compat provider ImportError handling --- .../providers/openlineage/utils/utils.py | 17 +++++++++++++++ .../tests/openlineage/plugins/test_utils.py | 21 ++++++++++++++++++- 2 files changed, 37 insertions(+), 1 deletion(-) diff --git a/providers/src/airflow/providers/openlineage/utils/utils.py b/providers/src/airflow/providers/openlineage/utils/utils.py index 6c411171edb0..720fd6929011 100644 --- a/providers/src/airflow/providers/openlineage/utils/utils.py +++ b/providers/src/airflow/providers/openlineage/utils/utils.py @@ -68,7 +68,24 @@ from openlineage.client.facet_v2 import RunFacet, processing_engine_run from airflow.models import TaskInstance + from airflow.providers.common.compat.assets import Asset from airflow.utils.state import DagRunState, TaskInstanceState +else: + # TODO: Remove this try-exception block after bumping common provider to 1.3.0 + # This is due to common provider AssetDetails import error handling + try: + from airflow.providers.common.compat.assets import Asset + except ImportError: + from packaging.version import Version + + from airflow import __version__ as AIRFLOW_VERSION + + AIRFLOW_V_3_0_PLUS = Version(Version(AIRFLOW_VERSION).base_version) >= Version("3.0.0") + if AIRFLOW_V_3_0_PLUS: + from airflow.sdk.definitions.asset import Asset + else: + # dataset is renamed to asset since Airflow 3.0 + from airflow.datasets import Dataset as Asset log = logging.getLogger(__name__) _NOMINAL_TIME_FORMAT = "%Y-%m-%dT%H:%M:%S.%fZ" diff --git a/providers/tests/openlineage/plugins/test_utils.py b/providers/tests/openlineage/plugins/test_utils.py index e84fac118657..2dd55e655be4 100644 --- a/providers/tests/openlineage/plugins/test_utils.py +++ b/providers/tests/openlineage/plugins/test_utils.py @@ -20,7 +20,7 @@ import json import uuid from json import JSONEncoder -from typing import Any +from typing import TYPE_CHECKING, Any from unittest.mock import MagicMock, patch import pytest @@ -55,6 +55,25 @@ BashOperator, ) +if TYPE_CHECKING: + from airflow.providers.common.compat.assets import Asset +else: + # TODO: Remove this try-exception block after bumping common provider to 1.3.0 + # This is due to common provider AssetDetails import error handling + try: + from airflow.providers.common.compat.assets import Asset + except ImportError: + from packaging.version import Version + + from airflow import __version__ as AIRFLOW_VERSION + + AIRFLOW_V_3_0_PLUS = Version(Version(AIRFLOW_VERSION).base_version) >= Version("3.0.0") + if AIRFLOW_V_3_0_PLUS: + from airflow.sdk.definitions.asset import Asset + else: + # dataset is renamed to asset since Airflow 3.0 + from airflow.datasets import Dataset as 
Asset
+
 if AIRFLOW_V_3_0_PLUS:
     from airflow.utils.types import DagRunTriggeredByType
 

From 8fc409529347e46cba7ac78def398e3154d84330 Mon Sep 17 00:00:00 2001
From: Wei Lee
Date: Thu, 7 Nov 2024 19:34:42 +0800
Subject: [PATCH 06/28] fix(provider/common/compat): fix common.compat provider ImportError handling

---
 .../providers/common/compat/assets/__init__.py    | 15 +++++++++++----
 1 file changed, 11 insertions(+), 4 deletions(-)

diff --git a/providers/src/airflow/providers/common/compat/assets/__init__.py b/providers/src/airflow/providers/common/compat/assets/__init__.py
index 47614e2a6c1e..4530f73f595e 100644
--- a/providers/src/airflow/providers/common/compat/assets/__init__.py
+++ b/providers/src/airflow/providers/common/compat/assets/__init__.py
@@ -34,6 +34,17 @@
 else:
     try:
         from airflow.auth.managers.models.resource_details import AssetDetails
+    except ModuleNotFoundError:
+        # 2.7.x
+        pass
+    except ImportError:
+        from packaging.version import Version
+
+        _IS_AIRFLOW_2_8_OR_HIGHER = Version(Version(AIRFLOW_VERSION).base_version) >= Version("2.8.0")
+        if _IS_AIRFLOW_2_8_OR_HIGHER:
+            from airflow.auth.managers.models.resource_details import DatasetDetails as AssetDetails
+
+    try:
         from airflow.sdk.definitions.asset import (
             Asset,
             AssetAlias,
@@ -47,14 +58,10 @@
 
         _IS_AIRFLOW_2_10_OR_HIGHER = Version(Version(AIRFLOW_VERSION).base_version) >= Version("2.10.0")
         _IS_AIRFLOW_2_9_OR_HIGHER = Version(Version(AIRFLOW_VERSION).base_version) >= Version("2.9.0")
-        _IS_AIRFLOW_2_8_OR_HIGHER = Version(Version(AIRFLOW_VERSION).base_version) >= Version("2.8.0")
 
         # dataset is renamed to asset since Airflow 3.0
         from airflow.datasets import Dataset as Asset
 
-        if _IS_AIRFLOW_2_8_OR_HIGHER:
-            from airflow.auth.managers.models.resource_details import DatasetDetails as AssetDetails
-
         if _IS_AIRFLOW_2_9_OR_HIGHER:
             from airflow.datasets import (
                 DatasetAll as AssetAll,

From 9d97eaa2d547d115bb02b77deb237c74d2033c9d Mon Sep 17 00:00:00 2001
From: Wei Lee
Date: Thu, 7 Nov 2024 19:37:19 +0800
Subject: [PATCH 07/28] feat(task_sdk): expose Model

---
 task_sdk/src/airflow/sdk/definitions/asset.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/task_sdk/src/airflow/sdk/definitions/asset.py b/task_sdk/src/airflow/sdk/definitions/asset.py
index faee618cf25d..a257685ee216 100644
--- a/task_sdk/src/airflow/sdk/definitions/asset.py
+++ b/task_sdk/src/airflow/sdk/definitions/asset.py
@@ -44,7 +44,8 @@
 
     from sqlalchemy.orm.session import Session
 
-__all__ = ["Asset", "AssetAll", "AssetAny", "Dataset"]
+
+__all__ = ["Asset", "AssetAll", "AssetAny", "Dataset", "Model"]
 
 log = logging.getLogger(__name__)

From 8e1e1cdb56f1b3151b8ea5c0ede0a7ae364b0912 Mon Sep 17 00:00:00 2001
From: Wei Lee
Date: Thu, 7 Nov 2024 19:37:34 +0800
Subject: [PATCH 08/28] docs(newsfragments): update how asset module should be imported

---
 newsfragments/41348.significant.rst | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/newsfragments/41348.significant.rst b/newsfragments/41348.significant.rst
index eeda04d3985c..f966baccc177 100644
--- a/newsfragments/41348.significant.rst
+++ b/newsfragments/41348.significant.rst
@@ -17,7 +17,7 @@
 * Rename class ``DatasetEventCollectionSchema`` as ``AssetEventCollectionSchema``
 * Rename class ``CreateDatasetEventSchema`` as ``CreateAssetEventSchema``
 
-* Rename module ``airflow.datasets`` as ``airflow.assets``
+* Move module ``airflow.datasets`` to ``airflow.sdk.definitions.asset``
 
 * Rename class ``DatasetAlias`` as ``AssetAlias``
 * Rename class ``DatasetAll`` as ``AssetAll``

From 
4b5452f6354f162201cd12e6f75f6c0724201758 Mon Sep 17 00:00:00 2001
From: Wei Lee
Date: Fri, 8 Nov 2024 11:27:38 +0800
Subject: [PATCH 09/28] fix(task_sdk): fix 2_10 compatibility

---
 task_sdk/src/airflow/sdk/definitions/asset.py | 6 ++++++
 1 file changed, 6 insertions(+)

diff --git a/task_sdk/src/airflow/sdk/definitions/asset.py b/task_sdk/src/airflow/sdk/definitions/asset.py
index a257685ee216..f093ea68c8f4 100644
--- a/task_sdk/src/airflow/sdk/definitions/asset.py
+++ b/task_sdk/src/airflow/sdk/definitions/asset.py
@@ -63,8 +63,14 @@ def normalize_noop(parts: SplitResult) -> SplitResult:
 def _get_uri_normalizer(scheme: str) -> Callable[[SplitResult], SplitResult] | None:
     if scheme == "file":
         return normalize_noop
+    from packaging.version import Version
+
+    from airflow import __version__ as AIRFLOW_VERSION
     from airflow.providers_manager import ProvidersManager
 
+    AIRFLOW_V_2 = Version(Version(AIRFLOW_VERSION).base_version) < Version("3.0.0")
+    if AIRFLOW_V_2:
+        return ProvidersManager().dataset_uri_handlers.get(scheme)  # type: ignore[attr-defined]
     return ProvidersManager().asset_uri_handlers.get(scheme)

From f5a2f9ae8ed2d6d37dbf16d7211bdc165c4e1f0b Mon Sep 17 00:00:00 2001
From: Wei Lee
Date: Fri, 8 Nov 2024 15:41:55 +0800
Subject: [PATCH 10/28] feat(common.compat): use version to decide how to import assets instead of exception

---
 .../common/compat/assets/__init__.py          | 32 ++++++++-----------
 1 file changed, 13 insertions(+), 19 deletions(-)

diff --git a/providers/src/airflow/providers/common/compat/assets/__init__.py b/providers/src/airflow/providers/common/compat/assets/__init__.py
index 4530f73f595e..ea073840fe00 100644
--- a/providers/src/airflow/providers/common/compat/assets/__init__.py
+++ b/providers/src/airflow/providers/common/compat/assets/__init__.py
@@ -32,19 +32,15 @@
         expand_alias_to_assets,
     )
 else:
-    try:
-        from airflow.auth.managers.models.resource_details import AssetDetails
-    except ModuleNotFoundError:
-        # 2.7.x
-        pass
-    except ImportError:
-        from packaging.version import Version
+    from packaging.version import Version
 
-        _IS_AIRFLOW_2_8_OR_HIGHER = Version(Version(AIRFLOW_VERSION).base_version) >= Version("2.8.0")
-        if _IS_AIRFLOW_2_8_OR_HIGHER:
-            from airflow.auth.managers.models.resource_details import DatasetDetails as AssetDetails
+    AIRFLOW_V_3_0_PLUS = Version(Version(AIRFLOW_VERSION).base_version) >= Version("3.0.0")
+    AIRFLOW_V_2_10_PLUS = Version(Version(AIRFLOW_VERSION).base_version) >= Version("2.10.0")
+    AIRFLOW_V_2_9_PLUS = Version(Version(AIRFLOW_VERSION).base_version) >= Version("2.9.0")
+    AIRFLOW_V_2_8_PLUS = Version(Version(AIRFLOW_VERSION).base_version) >= Version("2.8.0")
 
-    try:
+    if AIRFLOW_V_3_0_PLUS:
+        from airflow.auth.managers.models.resource_details import AssetDetails
         from airflow.sdk.definitions.asset import (
             Asset,
             AssetAlias,
@@ -53,22 +49,20 @@
             AssetAny,
             expand_alias_to_assets,
         )
-    except ModuleNotFoundError:
-        from packaging.version import Version
-
-        _IS_AIRFLOW_2_10_OR_HIGHER = Version(Version(AIRFLOW_VERSION).base_version) >= Version("2.10.0")
-        _IS_AIRFLOW_2_9_OR_HIGHER = Version(Version(AIRFLOW_VERSION).base_version) >= Version("2.9.0")
-
+    else:
         # dataset is renamed to asset since Airflow 3.0
         from airflow.datasets import Dataset as Asset
 
-        if _IS_AIRFLOW_2_9_OR_HIGHER:
+        if AIRFLOW_V_2_8_PLUS:
+            from airflow.auth.managers.models.resource_details import DatasetDetails as AssetDetails
+
+        if AIRFLOW_V_2_9_PLUS:
             from airflow.datasets import (
                 DatasetAll as AssetAll,
                 DatasetAny as AssetAny,
             )
 
-        if _IS_AIRFLOW_2_10_OR_HIGHER:
+        if 
From 5e0fe68ff73364a8c5884a502fa9c5a59e194cc2 Mon Sep 17 00:00:00 2001
From: Wei Lee
Date: Fri, 8 Nov 2024 18:10:33 +0800
Subject: [PATCH 11/28] feat(providers/common.compat): use airflow version
 instead of exception to return compat method

---
 .../providers/common/compat/lineage/hook.py | 35 +++++++++----------
 1 file changed, 17 insertions(+), 18 deletions(-)

diff --git a/providers/src/airflow/providers/common/compat/lineage/hook.py b/providers/src/airflow/providers/common/compat/lineage/hook.py
index 50fbc3d0996a..bf080de37ffb 100644
--- a/providers/src/airflow/providers/common/compat/lineage/hook.py
+++ b/providers/src/airflow/providers/common/compat/lineage/hook.py
@@ -16,7 +16,7 @@
 # under the License.
 from __future__ import annotations
 
-from importlib.util import find_spec
+from airflow.providers.common.compat.assets import AIRFLOW_V_2_10_PLUS, AIRFLOW_V_3_0_PLUS
 
 
 def _get_asset_compat_hook_lineage_collector():
@@ -79,28 +79,27 @@ def collected_assets_compat(collector) -> HookLineage:
 
 
 def get_hook_lineage_collector():
-    # HookLineageCollector added in 2.10
-    try:
-        if find_spec("airflow.assets"):
-            # Dataset has been renamed as Asset in 3.0
-            from airflow.lineage.hook import get_hook_lineage_collector
+    # Dataset has been renamed as Asset in 3.0
+    if AIRFLOW_V_3_0_PLUS:
+        from airflow.lineage.hook import get_hook_lineage_collector
 
-            return get_hook_lineage_collector()
+        return get_hook_lineage_collector()
 
+    # HookLineageCollector added in 2.10
+    if AIRFLOW_V_2_10_PLUS:
         return _get_asset_compat_hook_lineage_collector()
-    except ImportError:
 
-        class NoOpCollector:
-            """
-            NoOpCollector is a hook lineage collector that does nothing.
+    class NoOpCollector:
+        """
+        NoOpCollector is a hook lineage collector that does nothing.
 
-            It is used when you want to disable lineage collection.
-            """
+        It is used when you want to disable lineage collection.
+        """
 
-            def add_input_asset(self, *_, **__):
-                pass
+        def add_input_asset(self, *_, **__):
+            pass
 
-            def add_output_asset(self, *_, **__):
-                pass
+        def add_output_asset(self, *_, **__):
+            pass
 
-        return NoOpCollector()
+    return NoOpCollector()
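The payoff of the rewrite above is that provider hooks can report lineage unconditionally and let the compat layer decide what the call means. A hedged usage sketch; the hook class is hypothetical and the ``add_input_asset`` keyword arguments assume the 2.10-style collector interface wrapped above:

    from airflow.providers.common.compat.lineage.hook import get_hook_lineage_collector

    class MyFileHook:  # illustrative only
        def retrieve_file(self, bucket: str, key: str) -> None:
            ...  # perform the actual transfer
            # No-op on Airflow < 2.10, dataset-backed on 2.10.x, asset-backed on 3.0+.
            get_hook_lineage_collector().add_input_asset(context=self, uri=f"s3://{bucket}/{key}")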
+ """ - def add_input_asset(self, *_, **__): - pass + def add_input_asset(self, *_, **__): + pass - def add_output_asset(self, *_, **__): - pass + def add_output_asset(self, *_, **__): + pass - return NoOpCollector() + return NoOpCollector() From 2d64067edfe096ad769c1a7d6dd625e6b6ce8832 Mon Sep 17 00:00:00 2001 From: Wei Lee Date: Fri, 8 Nov 2024 18:16:12 +0800 Subject: [PATCH 12/28] refactor(providers/common/compat): extract airflow version to __init__ --- .../airflow/providers/common/compat/__init__.py | 14 +++++++++----- .../providers/common/compat/assets/__init__.py | 14 ++++++-------- .../providers/common/compat/lineage/hook.py | 2 +- 3 files changed, 16 insertions(+), 14 deletions(-) diff --git a/providers/src/airflow/providers/common/compat/__init__.py b/providers/src/airflow/providers/common/compat/__init__.py index 38c5f8c6cdea..1f9eab88c17d 100644 --- a/providers/src/airflow/providers/common/compat/__init__.py +++ b/providers/src/airflow/providers/common/compat/__init__.py @@ -23,17 +23,21 @@ # from __future__ import annotations -import packaging.version +from packaging.version import Version -from airflow import __version__ as airflow_version +from airflow import __version__ as AIRFLOW_VERSION __all__ = ["__version__"] __version__ = "1.2.2" -if packaging.version.parse(packaging.version.parse(airflow_version).base_version) < packaging.version.parse( - "2.8.0" -): + +AIRFLOW_V_3_0_PLUS = Version(Version(AIRFLOW_VERSION).base_version) >= Version("3.0.0") +AIRFLOW_V_2_10_PLUS = Version(Version(AIRFLOW_VERSION).base_version) >= Version("2.10.0") +AIRFLOW_V_2_9_PLUS = Version(Version(AIRFLOW_VERSION).base_version) >= Version("2.9.0") +AIRFLOW_V_2_8_PLUS = Version(Version(AIRFLOW_VERSION).base_version) >= Version("2.8.0") + +if not AIRFLOW_V_2_8_PLUS: raise RuntimeError( f"The package `apache-airflow-providers-common-compat:{__version__}` needs Apache Airflow 2.8.0+" ) diff --git a/providers/src/airflow/providers/common/compat/assets/__init__.py b/providers/src/airflow/providers/common/compat/assets/__init__.py index ea073840fe00..66178cf0c68d 100644 --- a/providers/src/airflow/providers/common/compat/assets/__init__.py +++ b/providers/src/airflow/providers/common/compat/assets/__init__.py @@ -19,7 +19,12 @@ from typing import TYPE_CHECKING -from airflow import __version__ as AIRFLOW_VERSION +from airflow.providers.common.compat import ( + AIRFLOW_V_2_8_PLUS, + AIRFLOW_V_2_9_PLUS, + AIRFLOW_V_2_10_PLUS, + AIRFLOW_V_3_0_PLUS, +) if TYPE_CHECKING: from airflow.auth.managers.models.resource_details import AssetDetails @@ -32,13 +37,6 @@ expand_alias_to_assets, ) else: - from packaging.version import Version - - AIRFLOW_V_3_0_PLUS = Version(Version(AIRFLOW_VERSION).base_version) >= Version("3.0.0") - AIRFLOW_V_2_10_PLUS = Version(Version(AIRFLOW_VERSION).base_version) >= Version("2.10.0") - AIRFLOW_V_2_9_PLUS = Version(Version(AIRFLOW_VERSION).base_version) >= Version("2.9.0") - AIRFLOW_V_2_8_PLUS = Version(Version(AIRFLOW_VERSION).base_version) >= Version("2.8.0") - if AIRFLOW_V_3_0_PLUS: from airflow.auth.managers.models.resource_details import AssetDetails from airflow.sdk.definitions.asset import ( diff --git a/providers/src/airflow/providers/common/compat/lineage/hook.py b/providers/src/airflow/providers/common/compat/lineage/hook.py index bf080de37ffb..63214a9051c1 100644 --- a/providers/src/airflow/providers/common/compat/lineage/hook.py +++ b/providers/src/airflow/providers/common/compat/lineage/hook.py @@ -16,7 +16,7 @@ # under the License. 
 from __future__ import annotations
 
-from airflow.providers.common.compat.assets import AIRFLOW_V_2_10_PLUS, AIRFLOW_V_3_0_PLUS
+from airflow.providers.common.compat import AIRFLOW_V_2_10_PLUS, AIRFLOW_V_3_0_PLUS
 
 
 def _get_asset_compat_hook_lineage_collector():

From dd2e888098d0f1cb86143fe2cf39e1695f8f6d67 Mon Sep 17 00:00:00 2001
From: Wei Lee
Date: Sat, 9 Nov 2024 11:05:06 +0800
Subject: [PATCH 13/28] fix(providers): use version compare to decide whether
 to import asset

---
 .../airflow/providers/common/io/assets/file.py   | 10 ++++++++--
 .../airflow/providers/openlineage/utils/utils.py | 10 ++++++++--
 2 files changed, 16 insertions(+), 4 deletions(-)

diff --git a/providers/src/airflow/providers/common/io/assets/file.py b/providers/src/airflow/providers/common/io/assets/file.py
index aeff818bd6ee..6277e48c0a8a 100644
--- a/providers/src/airflow/providers/common/io/assets/file.py
+++ b/providers/src/airflow/providers/common/io/assets/file.py
@@ -19,9 +19,15 @@
 import urllib.parse
 from typing import TYPE_CHECKING
 
-try:
+from packaging.version import Version
+
+from airflow import __version__ as AIRFLOW_VERSION
+
+# TODO: Remove version check block after bumping common provider to 1.3.0
+AIRFLOW_V_3_0_PLUS = Version(Version(AIRFLOW_VERSION).base_version) >= Version("3.0.0")
+if AIRFLOW_V_3_0_PLUS:
     from airflow.sdk.definitions.asset import Asset
-except ModuleNotFoundError:
+else:
     from airflow.datasets import Dataset as Asset  # type: ignore[no-redef]
 
 if TYPE_CHECKING:
diff --git a/providers/src/airflow/providers/openlineage/utils/utils.py b/providers/src/airflow/providers/openlineage/utils/utils.py
index 720fd6929011..17911685b269 100644
--- a/providers/src/airflow/providers/openlineage/utils/utils.py
+++ b/providers/src/airflow/providers/openlineage/utils/utils.py
@@ -725,9 +725,15 @@ def translate_airflow_asset(asset: Asset, lineage_context) -> OpenLineageDataset
     This function returns None if no URI normalizer is defined, no asset converter is found or
     some core Airflow changes are missing and ImportError is raised.
     """
-    try:
+    # TODO: Remove version check block after bumping common provider to 1.3.0
+    from packaging.version import Version
+
+    from airflow import __version__ as AIRFLOW_VERSION
+
+    AIRFLOW_V_3_0_PLUS = Version(Version(AIRFLOW_VERSION).base_version) >= Version("3.0.0")
+    if AIRFLOW_V_3_0_PLUS:
         from airflow.sdk.definitions.asset import _get_normalized_scheme
-    except ModuleNotFoundError:
+    else:
         try:
             from airflow.datasets import _get_normalized_scheme  # type: ignore[no-redef, attr-defined]
         except ImportError:
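Both TODO comments in this patch point to the same follow-up: once these providers can require common.compat 1.3.0+, which re-exports the flags extracted in PATCH 12 above, each inline check collapses to an import. A sketch of that assumed end state, not part of this series:

    from airflow.providers.common.compat import AIRFLOW_V_3_0_PLUS

    if AIRFLOW_V_3_0_PLUS:
        from airflow.sdk.definitions.asset import Asset
    else:
        from airflow.datasets import Dataset as Asset  # type: ignore[no-redef]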
From 78e01c65b9192a7a340524d29696a56576a4652f Mon Sep 17 00:00:00 2001
From: Wei Lee
Date: Fri, 15 Nov 2024 16:34:19 +0800
Subject: [PATCH 14/28] feat(decorators/asset): move @asset to task_sdk

---
 airflow/decorators/assets.py                   | 131 ------------------
 .../example_dags/example_asset_decorator.py    |   4 +-
 .../src/airflow/sdk/definitions/decorators.py  | 115 +++++++++++++++
 .../tests/defintions/test_decorators.py        |  13 +-
 4 files changed, 125 insertions(+), 138 deletions(-)
 delete mode 100644 airflow/decorators/assets.py
 rename tests/decorators/test_assets.py => task_sdk/tests/defintions/test_decorators.py (94%)

diff --git a/airflow/decorators/assets.py b/airflow/decorators/assets.py
deleted file mode 100644
index 2f5052c2d5c9..000000000000
--- a/airflow/decorators/assets.py
+++ /dev/null
@@ -1,131 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-from __future__ import annotations
-
-import inspect
-from typing import TYPE_CHECKING, Any, Callable, Iterator, Mapping
-
-import attrs
-
-from airflow.assets import Asset, AssetRef
-from airflow.models.asset import _fetch_active_assets_by_name
-from airflow.models.dag import DAG, ScheduleArg
-from airflow.providers.standard.operators.python import PythonOperator
-from airflow.utils.session import create_session
-
-if TYPE_CHECKING:
-    from airflow.io.path import ObjectStoragePath
-
-
-class _AssetMainOperator(PythonOperator):
-    def __init__(self, *, definition_name: str, uri: str | None = None, **kwargs) -> None:
-        super().__init__(**kwargs)
-        self._definition_name = definition_name
-        self._uri = uri
-
-    def _iter_kwargs(
-        self, context: Mapping[str, Any], active_assets: dict[str, Asset]
-    ) -> Iterator[tuple[str, Any]]:
-        value: Any
-        for key in inspect.signature(self.python_callable).parameters:
-            if key == "self":
-                value = active_assets.get(self._definition_name)
-            elif key == "context":
-                value = context
-            else:
-                value = active_assets.get(key, Asset(name=key))
-            yield key, value
-
-    def determine_kwargs(self, context: Mapping[str, Any]) -> Mapping[str, Any]:
-        active_assets: dict[str, Asset] = {}
-        asset_names = [asset_ref.name for asset_ref in self.inlets if isinstance(asset_ref, AssetRef)]
-        if "self" in inspect.signature(self.python_callable).parameters:
-            asset_names.append(self._definition_name)
-
-        if asset_names:
-            with create_session() as session:
-                active_assets = _fetch_active_assets_by_name(asset_names, session)
-        return dict(self._iter_kwargs(context, active_assets))
-
-
-@attrs.define(kw_only=True)
-class AssetDefinition(Asset):
-    """
-    Asset representation from decorating a function with ``@asset``.
- - :meta private: - """ - - function: Callable - schedule: ScheduleArg - - def __attrs_post_init__(self) -> None: - parameters = inspect.signature(self.function).parameters - - with DAG(dag_id=self.name, schedule=self.schedule, auto_register=True): - _AssetMainOperator( - task_id="__main__", - inlets=[ - AssetRef(name=inlet_asset_name) - for inlet_asset_name in parameters - if inlet_asset_name not in ("self", "context") - ], - outlets=[self.to_asset()], - python_callable=self.function, - definition_name=self.name, - uri=self.uri, - ) - - def to_asset(self) -> Asset: - return Asset( - name=self.name, - uri=self.uri, - group=self.group, - extra=self.extra, - ) - - def serialize(self): - return { - "uri": self.uri, - "name": self.name, - "group": self.group, - "extra": self.extra, - } - - -@attrs.define(kw_only=True) -class asset: - """Create an asset by decorating a materialization function.""" - - schedule: ScheduleArg - uri: str | ObjectStoragePath | None = None - group: str = "" - extra: dict[str, Any] = attrs.field(factory=dict) - - def __call__(self, f: Callable) -> AssetDefinition: - if (name := f.__name__) != f.__qualname__: - raise ValueError("nested function not supported") - - return AssetDefinition( - name=name, - uri=name if self.uri is None else str(self.uri), - group=self.group, - extra=self.extra, - function=f, - schedule=self.schedule, - ) diff --git a/airflow/example_dags/example_asset_decorator.py b/airflow/example_dags/example_asset_decorator.py index b4de09c23146..5be9540faee3 100644 --- a/airflow/example_dags/example_asset_decorator.py +++ b/airflow/example_dags/example_asset_decorator.py @@ -18,9 +18,9 @@ import pendulum -from airflow.assets import Asset from airflow.decorators import dag, task -from airflow.decorators.assets import asset +from airflow.sdk.definitions.asset import Asset +from airflow.sdk.definitions.decorators import asset @asset(uri="s3://bucket/asset1_producer", schedule=None) diff --git a/task_sdk/src/airflow/sdk/definitions/decorators.py b/task_sdk/src/airflow/sdk/definitions/decorators.py index ab73ba0c9242..be40e597ca9b 100644 --- a/task_sdk/src/airflow/sdk/definitions/decorators.py +++ b/task_sdk/src/airflow/sdk/definitions/decorators.py @@ -17,6 +17,22 @@ from __future__ import annotations +import inspect +from collections.abc import Iterator, Mapping +from typing import TYPE_CHECKING, Any, Callable + +import attrs + +from airflow.models.asset import _fetch_active_assets_by_name +from airflow.providers.standard.operators.python import PythonOperator +from airflow.sdk.definitions.asset import Asset, AssetRef +from airflow.sdk.definitions.dag import DAG, ScheduleArg +from airflow.utils.session import create_session + +if TYPE_CHECKING: + from airflow.io.path import ObjectStoragePath + + import sys from types import FunctionType @@ -40,3 +56,102 @@ def fixup_decorator_warning_stack(func: FunctionType): # Yes, this is more than slightly hacky, but it _automatically_ sets the right stacklevel parameter to # `warnings.warn` to ignore the decorator. 
func.__globals__["warnings"] = _autostacklevel_warn() + + +class _AssetMainOperator(PythonOperator): + def __init__(self, *, definition_name: str, uri: str | None = None, **kwargs) -> None: + super().__init__(**kwargs) + self._definition_name = definition_name + self._uri = uri + + def _iter_kwargs( + self, context: Mapping[str, Any], active_assets: dict[str, Asset] + ) -> Iterator[tuple[str, Any]]: + value: Any + for key in inspect.signature(self.python_callable).parameters: + if key == "self": + value = active_assets.get(self._definition_name) + elif key == "context": + value = context + else: + value = active_assets.get(key, Asset(name=key)) + yield key, value + + def determine_kwargs(self, context: Mapping[str, Any]) -> Mapping[str, Any]: + active_assets: dict[str, Asset] = {} + asset_names = [asset_ref.name for asset_ref in self.inlets if isinstance(asset_ref, AssetRef)] + if "self" in inspect.signature(self.python_callable).parameters: + asset_names.append(self._definition_name) + + if asset_names: + with create_session() as session: + active_assets = _fetch_active_assets_by_name(asset_names, session) + return dict(self._iter_kwargs(context, active_assets)) + + +@attrs.define(kw_only=True) +class AssetDefinition(Asset): + """ + Asset representation from decorating a function with ``@asset``. + + :meta private: + """ + + function: Callable + schedule: ScheduleArg + + def __attrs_post_init__(self) -> None: + parameters = inspect.signature(self.function).parameters + + with DAG(dag_id=self.name, schedule=self.schedule, auto_register=True): + _AssetMainOperator( + task_id="__main__", + inlets=[ + AssetRef(name=inlet_asset_name) + for inlet_asset_name in parameters + if inlet_asset_name not in ("self", "context") + ], + outlets=[self.to_asset()], + python_callable=self.function, + definition_name=self.name, + uri=self.uri, + ) + + def to_asset(self) -> Asset: + return Asset( + name=self.name, + uri=self.uri, + group=self.group, + extra=self.extra, + ) + + def serialize(self): + return { + "uri": self.uri, + "name": self.name, + "group": self.group, + "extra": self.extra, + } + + +@attrs.define(kw_only=True) +class asset: + """Create an asset by decorating a materialization function.""" + + schedule: ScheduleArg + uri: str | ObjectStoragePath | None = None + group: str = "" + extra: dict[str, Any] = attrs.field(factory=dict) + + def __call__(self, f: Callable) -> AssetDefinition: + if (name := f.__name__) != f.__qualname__: + raise ValueError("nested function not supported") + + return AssetDefinition( + name=name, + uri=name if self.uri is None else str(self.uri), + group=self.group, + extra=self.extra, + function=f, + schedule=self.schedule, + ) diff --git a/tests/decorators/test_assets.py b/task_sdk/tests/defintions/test_decorators.py similarity index 94% rename from tests/decorators/test_assets.py rename to task_sdk/tests/defintions/test_decorators.py index a3821140e548..f7b5a0a74658 100644 --- a/tests/decorators/test_assets.py +++ b/task_sdk/tests/defintions/test_decorators.py @@ -21,9 +21,9 @@ import pytest -from airflow.assets import Asset -from airflow.decorators.assets import AssetRef, _AssetMainOperator, asset from airflow.models.asset import AssetActive, AssetModel +from airflow.sdk.definitions.asset import Asset +from airflow.sdk.definitions.decorators import AssetRef, _AssetMainOperator, asset pytestmark = pytest.mark.db_test @@ -119,8 +119,8 @@ def test_serialzie(self, example_asset_definition): "uri": "s3://bucket/object", } - 
@mock.patch("airflow.decorators.assets._AssetMainOperator") - @mock.patch("airflow.decorators.assets.DAG") + @mock.patch("airflow.sdk.definitions.decorators._AssetMainOperator") + @mock.patch("airflow.sdk.definitions.decorators.DAG") def test__attrs_post_init__( self, DAG, _AssetMainOperator, example_asset_func_with_valid_arg_as_inlet_asset ): @@ -169,7 +169,10 @@ def test_determine_kwargs(self, example_asset_func_with_valid_arg_as_inlet_asset ) assert op.determine_kwargs(context={"k": "v"}) == { "self": Asset( - name="example_asset_func", uri="s3://bucket/object", group="MLModel", extra={"k": "v"} + name="example_asset_func", + uri="s3://bucket/object", + group="MLModel", + extra={"k": "v"}, ), "context": {"k": "v"}, "inlet_asset_1": Asset(name="inlet_asset_1", uri="s3://bucket/object1"), From 763a6d096c4ae86003fbfecf2eee47fc694990ec Mon Sep 17 00:00:00 2001 From: Wei Lee Date: Fri, 15 Nov 2024 16:37:35 +0800 Subject: [PATCH 15/28] refactor(asset): rename _AssetAliasCondition as AssetAliasCondition --- airflow/serialization/serialized_objects.py | 3 +-- airflow/timetables/simple.py | 4 ++-- task_sdk/src/airflow/sdk/definitions/asset.py | 6 +++--- task_sdk/tests/defintions/test_asset.py | 13 ++++++------- 4 files changed, 12 insertions(+), 14 deletions(-) diff --git a/airflow/serialization/serialized_objects.py b/airflow/serialization/serialized_objects.py index d9771a2c401f..dd4e52692f16 100644 --- a/airflow/serialization/serialized_objects.py +++ b/airflow/serialization/serialized_objects.py @@ -59,7 +59,6 @@ AssetAny, AssetRef, BaseAsset, - _AssetAliasCondition, ) from airflow.sdk.definitions.baseoperator import BaseOperator as TaskSDKBaseOperator from airflow.serialization.dag_dependency import DagDependency @@ -1054,7 +1053,7 @@ def detect_task_dependencies(task: Operator) -> list[DagDependency]: ) ) elif isinstance(obj, AssetAlias): - cond = _AssetAliasCondition(obj.name) + cond = AssetAliasCondition(obj.name) deps.extend(cond.iter_dag_dependencies(source=task.dag_id, target="")) return deps diff --git a/airflow/timetables/simple.py b/airflow/timetables/simple.py index adba135c5785..8ce498c9e049 100644 --- a/airflow/timetables/simple.py +++ b/airflow/timetables/simple.py @@ -18,7 +18,7 @@ from typing import TYPE_CHECKING, Any, Collection, Sequence -from airflow.sdk.definitions.asset import AssetAlias, _AssetAliasCondition +from airflow.sdk.definitions.asset import AssetAlias, AssetAliasCondition from airflow.timetables.base import DagRunInfo, DataInterval, Timetable from airflow.utils import timezone @@ -169,7 +169,7 @@ def __init__(self, assets: BaseAsset) -> None: super().__init__() self.asset_condition = assets if isinstance(self.asset_condition, AssetAlias): - self.asset_condition = _AssetAliasCondition(self.asset_condition.name) + self.asset_condition = AssetAliasCondition(self.asset_condition.name) if not next(self.asset_condition.iter_assets(), False): self._summary = AssetTriggeredTimetable.UNRESOLVED_ALIAS_SUMMARY diff --git a/task_sdk/src/airflow/sdk/definitions/asset.py b/task_sdk/src/airflow/sdk/definitions/asset.py index f093ea68c8f4..8e0224a41770 100644 --- a/task_sdk/src/airflow/sdk/definitions/asset.py +++ b/task_sdk/src/airflow/sdk/definitions/asset.py @@ -405,7 +405,7 @@ def __init__(self, *objects: BaseAsset) -> None: raise TypeError("expect asset expressions in condition") self.objects = [ - _AssetAliasCondition(obj.name) if isinstance(obj, AssetAlias) else obj for obj in objects + AssetAliasCondition(obj.name) if isinstance(obj, AssetAlias) else obj for obj in 
From 763a6d096c4ae86003fbfecf2eee47fc694990ec Mon Sep 17 00:00:00 2001
From: Wei Lee
Date: Fri, 15 Nov 2024 16:37:35 +0800
Subject: [PATCH 15/28] refactor(asset): rename _AssetAliasCondition as
 AssetAliasCondition

---
 airflow/serialization/serialized_objects.py   |  3 +--
 airflow/timetables/simple.py                  |  4 ++--
 task_sdk/src/airflow/sdk/definitions/asset.py |  6 +++---
 task_sdk/tests/defintions/test_asset.py       | 13 ++++++-------
 4 files changed, 12 insertions(+), 14 deletions(-)

diff --git a/airflow/serialization/serialized_objects.py b/airflow/serialization/serialized_objects.py
index d9771a2c401f..dd4e52692f16 100644
--- a/airflow/serialization/serialized_objects.py
+++ b/airflow/serialization/serialized_objects.py
@@ -59,7 +59,6 @@
     AssetAny,
     AssetRef,
     BaseAsset,
-    _AssetAliasCondition,
 )
 from airflow.sdk.definitions.baseoperator import BaseOperator as TaskSDKBaseOperator
 from airflow.serialization.dag_dependency import DagDependency
@@ -1054,7 +1053,7 @@ def detect_task_dependencies(task: Operator) -> list[DagDependency]:
                 )
             )
         elif isinstance(obj, AssetAlias):
-            cond = _AssetAliasCondition(obj.name)
+            cond = AssetAliasCondition(obj.name)
             deps.extend(cond.iter_dag_dependencies(source=task.dag_id, target=""))
     return deps
 
diff --git a/airflow/timetables/simple.py b/airflow/timetables/simple.py
index adba135c5785..8ce498c9e049 100644
--- a/airflow/timetables/simple.py
+++ b/airflow/timetables/simple.py
@@ -18,7 +18,7 @@
 
 from typing import TYPE_CHECKING, Any, Collection, Sequence
 
-from airflow.sdk.definitions.asset import AssetAlias, _AssetAliasCondition
+from airflow.sdk.definitions.asset import AssetAlias, AssetAliasCondition
 from airflow.timetables.base import DagRunInfo, DataInterval, Timetable
 from airflow.utils import timezone
 
@@ -169,7 +169,7 @@ def __init__(self, assets: BaseAsset) -> None:
         super().__init__()
         self.asset_condition = assets
         if isinstance(self.asset_condition, AssetAlias):
-            self.asset_condition = _AssetAliasCondition(self.asset_condition.name)
+            self.asset_condition = AssetAliasCondition(self.asset_condition.name)
 
         if not next(self.asset_condition.iter_assets(), False):
             self._summary = AssetTriggeredTimetable.UNRESOLVED_ALIAS_SUMMARY
diff --git a/task_sdk/src/airflow/sdk/definitions/asset.py b/task_sdk/src/airflow/sdk/definitions/asset.py
index f093ea68c8f4..8e0224a41770 100644
--- a/task_sdk/src/airflow/sdk/definitions/asset.py
+++ b/task_sdk/src/airflow/sdk/definitions/asset.py
@@ -405,7 +405,7 @@ def __init__(self, *objects: BaseAsset) -> None:
             raise TypeError("expect asset expressions in condition")
 
         self.objects = [
-            _AssetAliasCondition(obj.name) if isinstance(obj, AssetAlias) else obj for obj in objects
+            AssetAliasCondition(obj.name) if isinstance(obj, AssetAlias) else obj for obj in objects
         ]
 
     def evaluate(self, statuses: dict[str, bool]) -> bool:
@@ -458,7 +458,7 @@ def as_expression(self) -> dict[str, Any]:
         return {"any": [o.as_expression() for o in self.objects]}
 
 
-class _AssetAliasCondition(AssetAny):
+class AssetAliasCondition(AssetAny):
     """
     Use to expand AssetAlias as AssetAny of its resolved Assets.
 
@@ -470,7 +470,7 @@ def __init__(self, name: str) -> None:
         self.objects = expand_alias_to_assets(name)
 
     def __repr__(self) -> str:
-        return f"_AssetAliasCondition({', '.join(map(str, self.objects))})"
+        return f"AssetAliasCondition({', '.join(map(str, self.objects))})"
 
     def as_expression(self) -> Any:
         """
diff --git a/task_sdk/tests/defintions/test_asset.py b/task_sdk/tests/defintions/test_asset.py
index a2d89fc38a7b..9c6b147ff3c6 100644
--- a/task_sdk/tests/defintions/test_asset.py
+++ b/task_sdk/tests/defintions/test_asset.py
@@ -37,7 +37,6 @@
     BaseAsset,
     Dataset,
     Model,
-    _AssetAliasCondition,
     _get_normalized_scheme,
     _sanitize_uri,
 )
@@ -579,7 +578,7 @@ def test_normalize_uri_valid_uri():
 @pytest.mark.skip_if_database_isolation_mode
 @pytest.mark.db_test
 @pytest.mark.usefixtures("clear_assets")
-class Test_AssetAliasCondition:
+class TestAssetAliasCondition:
     @pytest.fixture
     def asset_1(self, session):
        """Example asset links to asset alias resolved_asset_alias_2."""
@@ -615,22 +614,22 @@ def resolved_asset_alias_2(self, session, asset_1):
         return asset_alias_2
 
     def test_init(self, asset_alias_1, asset_1, resolved_asset_alias_2):
-        cond = _AssetAliasCondition(name=asset_alias_1.name)
+        cond = AssetAliasCondition(name=asset_alias_1.name)
         assert cond.objects == []
 
-        cond = _AssetAliasCondition(name=resolved_asset_alias_2.name)
+        cond = AssetAliasCondition(name=resolved_asset_alias_2.name)
         assert cond.objects == [Asset(uri=asset_1.uri)]
 
     def test_as_expression(self, asset_alias_1, resolved_asset_alias_2):
         for assset_alias in (asset_alias_1, resolved_asset_alias_2):
-            cond = _AssetAliasCondition(assset_alias.name)
+            cond = AssetAliasCondition(assset_alias.name)
             assert cond.as_expression() == {"alias": assset_alias.name}
 
     def test_evalute(self, asset_alias_1, resolved_asset_alias_2, asset_1):
-        cond = _AssetAliasCondition(asset_alias_1.name)
+        cond = AssetAliasCondition(asset_alias_1.name)
         assert cond.evaluate({asset_1.uri: True}) is False
 
-        cond = _AssetAliasCondition(resolved_asset_alias_2.name)
+        cond = AssetAliasCondition(resolved_asset_alias_2.name)
         assert cond.evaluate({asset_1.uri: True}) is True
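Since the rename makes ``AssetAliasCondition`` public, a short usage note: resolution happens eagerly in ``__init__`` via ``expand_alias_to_assets``, so constructing one needs a configured metadata database (the tests above are marked ``db_test`` for exactly this reason). A sketch with a hypothetical alias name:

    from airflow.sdk.definitions.asset import AssetAliasCondition

    cond = AssetAliasCondition(name="my-alias")  # resolves the alias's assets now
    cond.as_expression()  # -> {"alias": "my-alias"}
    # True only if the alias resolved to an asset whose URI has a truthy status:
    cond.evaluate({"s3://bucket/object": True})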
From 9c6af0dc7d28bda91d3d73697241c2ef20d633ed Mon Sep 17 00:00:00 2001
From: Wei Lee
Date: Fri, 15 Nov 2024 17:40:25 +0800
Subject: [PATCH 16/28] feat(task_sdk): make airflow.sdk.definitions.decorators
 a package

---
 .../example_dags/example_asset_decorator.py   |  2 +-
 .../sdk/definitions/decorators/__init__.py    | 42 +++++++++++++++++++
 .../{decorators.py => decorators/asset.py}    | 27 +-----------
 task_sdk/tests/defintions/test_decorators.py  |  6 +--
 4 files changed, 47 insertions(+), 30 deletions(-)
 create mode 100644 task_sdk/src/airflow/sdk/definitions/decorators/__init__.py
 rename task_sdk/src/airflow/sdk/definitions/{decorators.py => decorators/asset.py} (84%)

diff --git a/airflow/example_dags/example_asset_decorator.py b/airflow/example_dags/example_asset_decorator.py
index 5be9540faee3..ab1d94552d75 100644
--- a/airflow/example_dags/example_asset_decorator.py
+++ b/airflow/example_dags/example_asset_decorator.py
@@ -20,7 +20,7 @@
 
 from airflow.decorators import dag, task
 from airflow.sdk.definitions.asset import Asset
-from airflow.sdk.definitions.decorators import asset
+from airflow.sdk.definitions.decorators.asset import asset
 
 
 @asset(uri="s3://bucket/asset1_producer", schedule=None)
diff --git a/task_sdk/src/airflow/sdk/definitions/decorators/__init__.py b/task_sdk/src/airflow/sdk/definitions/decorators/__init__.py
new file mode 100644
index 000000000000..ab73ba0c9242
--- /dev/null
+++ b/task_sdk/src/airflow/sdk/definitions/decorators/__init__.py
@@ -0,0 +1,42 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+import sys
+from types import FunctionType
+
+
+class _autostacklevel_warn:
+    def __init__(self):
+        self.warnings = __import__("warnings")
+
+    def __getattr__(self, name: str):
+        return getattr(self.warnings, name)
+
+    def __dir__(self):
+        return dir(self.warnings)
+
+    def warn(self, message, category=None, stacklevel=1, source=None):
+        self.warnings.warn(message, category, stacklevel + 2, source)
+
+
+def fixup_decorator_warning_stack(func: FunctionType):
+    if func.__globals__.get("warnings") is sys.modules["warnings"]:
+        # Yes, this is more than slightly hacky, but it _automatically_ sets the right stacklevel parameter to
+        # `warnings.warn` to ignore the decorator.
+        func.__globals__["warnings"] = _autostacklevel_warn()
diff --git a/task_sdk/src/airflow/sdk/definitions/decorators.py b/task_sdk/src/airflow/sdk/definitions/decorators/asset.py
similarity index 84%
rename from task_sdk/src/airflow/sdk/definitions/decorators.py
rename to task_sdk/src/airflow/sdk/definitions/decorators/asset.py
index be40e597ca9b..337dc83f8cdc 100644
--- a/task_sdk/src/airflow/sdk/definitions/decorators.py
+++ b/task_sdk/src/airflow/sdk/definitions/decorators/asset.py
@@ -24,40 +24,15 @@
 import attrs
 
 from airflow.models.asset import _fetch_active_assets_by_name
+from airflow.models.dag import DAG, ScheduleArg
 from airflow.providers.standard.operators.python import PythonOperator
 from airflow.sdk.definitions.asset import Asset, AssetRef
-from airflow.sdk.definitions.dag import DAG, ScheduleArg
 from airflow.utils.session import create_session
 
 if TYPE_CHECKING:
     from airflow.io.path import ObjectStoragePath
 
 
-import sys
-from types import FunctionType
-
-
-class _autostacklevel_warn:
-    def __init__(self):
-        self.warnings = __import__("warnings")
-
-    def __getattr__(self, name: str):
-        return getattr(self.warnings, name)
-
-    def __dir__(self):
-        return dir(self.warnings)
-
-    def warn(self, message, category=None, stacklevel=1, source=None):
-        self.warnings.warn(message, category, stacklevel + 2, source)
-
-
-def fixup_decorator_warning_stack(func: FunctionType):
-    if func.__globals__.get("warnings") is sys.modules["warnings"]:
-        # Yes, this is more than slightly hacky, but it _automatically_ sets the right stacklevel parameter to
-        # `warnings.warn` to ignore the decorator.
-        func.__globals__["warnings"] = _autostacklevel_warn()
-
-
 class _AssetMainOperator(PythonOperator):
     def __init__(self, *, definition_name: str, uri: str | None = None, **kwargs) -> None:
         super().__init__(**kwargs)
diff --git a/task_sdk/tests/defintions/test_decorators.py b/task_sdk/tests/defintions/test_decorators.py
index f7b5a0a74658..3e0ce1ef68a7 100644
--- a/task_sdk/tests/defintions/test_decorators.py
+++ b/task_sdk/tests/defintions/test_decorators.py
@@ -23,7 +23,7 @@
 
 from airflow.models.asset import AssetActive, AssetModel
 from airflow.sdk.definitions.asset import Asset
-from airflow.sdk.definitions.decorators import AssetRef, _AssetMainOperator, asset
+from airflow.sdk.definitions.decorators.asset import AssetRef, _AssetMainOperator, asset
 
 pytestmark = pytest.mark.db_test
 
@@ -119,8 +119,8 @@ def test_serialzie(self, example_asset_definition):
             "uri": "s3://bucket/object",
         }
 
-    @mock.patch("airflow.sdk.definitions.decorators._AssetMainOperator")
-    @mock.patch("airflow.sdk.definitions.decorators.DAG")
+    @mock.patch("airflow.sdk.definitions.decorators.asset._AssetMainOperator")
+    @mock.patch("airflow.sdk.definitions.decorators.asset.DAG")
     def test__attrs_post_init__(
         self, DAG, _AssetMainOperator, example_asset_func_with_valid_arg_as_inlet_asset
     ):
- func.__globals__["warnings"] = _autostacklevel_warn() - - class _AssetMainOperator(PythonOperator): def __init__(self, *, definition_name: str, uri: str | None = None, **kwargs) -> None: super().__init__(**kwargs) diff --git a/task_sdk/tests/defintions/test_decorators.py b/task_sdk/tests/defintions/test_decorators.py index f7b5a0a74658..3e0ce1ef68a7 100644 --- a/task_sdk/tests/defintions/test_decorators.py +++ b/task_sdk/tests/defintions/test_decorators.py @@ -23,7 +23,7 @@ from airflow.models.asset import AssetActive, AssetModel from airflow.sdk.definitions.asset import Asset -from airflow.sdk.definitions.decorators import AssetRef, _AssetMainOperator, asset +from airflow.sdk.definitions.decorators.asset import AssetRef, _AssetMainOperator, asset pytestmark = pytest.mark.db_test @@ -119,8 +119,8 @@ def test_serialzie(self, example_asset_definition): "uri": "s3://bucket/object", } - @mock.patch("airflow.sdk.definitions.decorators._AssetMainOperator") - @mock.patch("airflow.sdk.definitions.decorators.DAG") + @mock.patch("airflow.sdk.definitions.decorators.asset._AssetMainOperator") + @mock.patch("airflow.sdk.definitions.decorators.asset.DAG") def test__attrs_post_init__( self, DAG, _AssetMainOperator, example_asset_func_with_valid_arg_as_inlet_asset ): From dc5cb6a5fe36f5536f3ba29d258d58a8b2969e8c Mon Sep 17 00:00:00 2001 From: Wei Lee Date: Fri, 15 Nov 2024 22:56:56 +0800 Subject: [PATCH 17/28] Revert "feat(task_sdk): make airflow.sdk.definitions.decoratos a package" This reverts commit 324efc079ec1c5c2618bb19d48d188a1363f3931. --- .../example_dags/example_asset_decorator.py | 2 +- .../{decorators/asset.py => decorators.py} | 27 +++++++++++- .../sdk/definitions/decorators/__init__.py | 42 ------------------- task_sdk/tests/defintions/test_decorators.py | 6 +-- 4 files changed, 30 insertions(+), 47 deletions(-) rename task_sdk/src/airflow/sdk/definitions/{decorators/asset.py => decorators.py} (84%) delete mode 100644 task_sdk/src/airflow/sdk/definitions/decorators/__init__.py diff --git a/airflow/example_dags/example_asset_decorator.py b/airflow/example_dags/example_asset_decorator.py index ab1d94552d75..b7560f213426 100644 --- a/airflow/example_dags/example_asset_decorator.py +++ b/airflow/example_dags/example_asset_decorator.py @@ -20,7 +20,7 @@ from airflow.decorators import dag, task from airflow.sdk.definitions.asset import Asset -from airflow.sdk.definitions.decorators.asset import asset +from airflow.sdk.definitions.asset.decorators import asset @asset(uri="s3://bucket/asset1_producer", schedule=None) diff --git a/task_sdk/src/airflow/sdk/definitions/decorators/asset.py b/task_sdk/src/airflow/sdk/definitions/decorators.py similarity index 84% rename from task_sdk/src/airflow/sdk/definitions/decorators/asset.py rename to task_sdk/src/airflow/sdk/definitions/decorators.py index 337dc83f8cdc..be40e597ca9b 100644 --- a/task_sdk/src/airflow/sdk/definitions/decorators/asset.py +++ b/task_sdk/src/airflow/sdk/definitions/decorators.py @@ -24,15 +24,40 @@ import attrs from airflow.models.asset import _fetch_active_assets_by_name -from airflow.models.dag import DAG, ScheduleArg from airflow.providers.standard.operators.python import PythonOperator from airflow.sdk.definitions.asset import Asset, AssetRef +from airflow.sdk.definitions.dag import DAG, ScheduleArg from airflow.utils.session import create_session if TYPE_CHECKING: from airflow.io.path import ObjectStoragePath +import sys +from types import FunctionType + + +class _autostacklevel_warn: + def __init__(self): + self.warnings = 
__import__("warnings") + + def __getattr__(self, name: str): + return getattr(self.warnings, name) + + def __dir__(self): + return dir(self.warnings) + + def warn(self, message, category=None, stacklevel=1, source=None): + self.warnings.warn(message, category, stacklevel + 2, source) + + +def fixup_decorator_warning_stack(func: FunctionType): + if func.__globals__.get("warnings") is sys.modules["warnings"]: + # Yes, this is more than slightly hacky, but it _automatically_ sets the right stacklevel parameter to + # `warnings.warn` to ignore the decorator. + func.__globals__["warnings"] = _autostacklevel_warn() + + class _AssetMainOperator(PythonOperator): def __init__(self, *, definition_name: str, uri: str | None = None, **kwargs) -> None: super().__init__(**kwargs) diff --git a/task_sdk/src/airflow/sdk/definitions/decorators/__init__.py b/task_sdk/src/airflow/sdk/definitions/decorators/__init__.py deleted file mode 100644 index ab73ba0c9242..000000000000 --- a/task_sdk/src/airflow/sdk/definitions/decorators/__init__.py +++ /dev/null @@ -1,42 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -from __future__ import annotations - -import sys -from types import FunctionType - - -class _autostacklevel_warn: - def __init__(self): - self.warnings = __import__("warnings") - - def __getattr__(self, name: str): - return getattr(self.warnings, name) - - def __dir__(self): - return dir(self.warnings) - - def warn(self, message, category=None, stacklevel=1, source=None): - self.warnings.warn(message, category, stacklevel + 2, source) - - -def fixup_decorator_warning_stack(func: FunctionType): - if func.__globals__.get("warnings") is sys.modules["warnings"]: - # Yes, this is more than slightly hacky, but it _automatically_ sets the right stacklevel parameter to - # `warnings.warn` to ignore the decorator. 
- func.__globals__["warnings"] = _autostacklevel_warn() diff --git a/task_sdk/tests/defintions/test_decorators.py b/task_sdk/tests/defintions/test_decorators.py index 3e0ce1ef68a7..f7b5a0a74658 100644 --- a/task_sdk/tests/defintions/test_decorators.py +++ b/task_sdk/tests/defintions/test_decorators.py @@ -23,7 +23,7 @@ from airflow.models.asset import AssetActive, AssetModel from airflow.sdk.definitions.asset import Asset -from airflow.sdk.definitions.decorators.asset import AssetRef, _AssetMainOperator, asset +from airflow.sdk.definitions.decorators import AssetRef, _AssetMainOperator, asset pytestmark = pytest.mark.db_test @@ -119,8 +119,8 @@ def test_serialzie(self, example_asset_definition): "uri": "s3://bucket/object", } - @mock.patch("airflow.sdk.definitions.decorators.asset._AssetMainOperator") - @mock.patch("airflow.sdk.definitions.decorators.asset.DAG") + @mock.patch("airflow.sdk.definitions.decorators._AssetMainOperator") + @mock.patch("airflow.sdk.definitions.decorators.DAG") def test__attrs_post_init__( self, DAG, _AssetMainOperator, example_asset_func_with_valid_arg_as_inlet_asset ): From 4d17aa0784802776ca8405b6bc34799b7b144756 Mon Sep 17 00:00:00 2001 From: Wei Lee Date: Fri, 15 Nov 2024 23:06:35 +0800 Subject: [PATCH 18/28] feat(task_sdk): move asset related logic in airflow.sdk.definitions.decorators to airflow.sdk.definitions.asset.* --- task_sdk/src/airflow/sdk/definitions/asset.py | 142 +++++++++++++++--- .../src/airflow/sdk/definitions/decorators.py | 115 -------------- ...decorators.py => test_asset_decorators.py} | 3 +- 3 files changed, 124 insertions(+), 136 deletions(-) rename task_sdk/tests/defintions/{test_decorators.py => test_asset_decorators.py} (98%) diff --git a/task_sdk/src/airflow/sdk/definitions/asset.py b/task_sdk/src/airflow/sdk/definitions/asset.py index 8e0224a41770..54440c144440 100644 --- a/task_sdk/src/airflow/sdk/definitions/asset.py +++ b/task_sdk/src/airflow/sdk/definitions/asset.py @@ -17,11 +17,12 @@ from __future__ import annotations +import inspect import logging import os import urllib.parse import warnings -from collections.abc import Iterable, Iterator +from collections.abc import Iterable, Iterator, Mapping from typing import ( TYPE_CHECKING, Any, @@ -35,17 +36,21 @@ from sqlalchemy import select from airflow.api_internal.internal_api_call import internal_api_call +from airflow.models.asset import _fetch_active_assets_by_name +from airflow.providers.standard.operators.python import PythonOperator +from airflow.sdk.definitions.dag import DAG, ScheduleArg from airflow.serialization.dag_dependency import DagDependency from airflow.typing_compat import TypedDict -from airflow.utils.session import NEW_SESSION, provide_session +from airflow.utils.session import NEW_SESSION, create_session, provide_session if TYPE_CHECKING: from urllib.parse import SplitResult from sqlalchemy.orm.session import Session + from airflow.io.path import ObjectStoragePath -__all__ = ["Asset", "AssetAll", "AssetAny", "Dataset", "Model"] +__all__ = ["Asset", "AssetAll", "AssetAny", "Dataset", "Model", "AssetRef", "asset"] log = logging.getLogger(__name__) @@ -180,11 +185,16 @@ def expand_alias_to_assets(alias: str | AssetAlias, *, session: Session = NEW_SE return [] -@attrs.define(kw_only=True) -class AssetRef: - """Reference to an asset.""" +def _set_extra_default(extra: dict | None) -> dict: + """ + Automatically convert None to an empty dict. 
- name: str + This allows the caller site to continue doing ``Asset(uri, extra=None)``, + but still allow the ``extra`` attribute to always be a dict. + """ + if extra is None: + return {} + return extra class BaseAsset: @@ -271,18 +281,6 @@ class AssetAliasEvent(TypedDict): extra: dict[str, Any] -def _set_extra_default(extra: dict | None) -> dict: - """ - Automatically convert None to an empty dict. - - This allows the caller site to continue doing ``Asset(uri, extra=None)``, - but still allow the ``extra`` attribute to always be a dict. - """ - if extra is None: - return {} - return extra - - @attrs.define(init=False, unsafe_hash=False) class Asset(os.PathLike, BaseAsset): """A representation of data asset dependencies between workflows.""" @@ -559,3 +557,109 @@ def __init__( self.alias_name = alias.name else: self.alias_name = alias + + +@attrs.define(kw_only=True) +class AssetRef: + """Reference to an asset.""" + + name: str + + +class _AssetMainOperator(PythonOperator): + def __init__(self, *, definition_name: str, uri: str | None = None, **kwargs) -> None: + super().__init__(**kwargs) + self._definition_name = definition_name + self._uri = uri + + def _iter_kwargs( + self, context: Mapping[str, Any], active_assets: dict[str, Asset] + ) -> Iterator[tuple[str, Any]]: + value: Any + for key in inspect.signature(self.python_callable).parameters: + if key == "self": + value = active_assets.get(self._definition_name) + elif key == "context": + value = context + else: + value = active_assets.get(key, Asset(name=key)) + yield key, value + + def determine_kwargs(self, context: Mapping[str, Any]) -> Mapping[str, Any]: + active_assets: dict[str, Asset] = {} + asset_names = [asset_ref.name for asset_ref in self.inlets if isinstance(asset_ref, AssetRef)] + if "self" in inspect.signature(self.python_callable).parameters: + asset_names.append(self._definition_name) + + if asset_names: + with create_session() as session: + active_assets = _fetch_active_assets_by_name(asset_names, session) + return dict(self._iter_kwargs(context, active_assets)) + + +@attrs.define(kw_only=True) +class AssetDefinition(Asset): + """ + Asset representation from decorating a function with ``@asset``. 
+ + :meta private: + """ + + function: Callable + schedule: ScheduleArg + + def __attrs_post_init__(self) -> None: + parameters = inspect.signature(self.function).parameters + + with DAG(dag_id=self.name, schedule=self.schedule, auto_register=True): + _AssetMainOperator( + task_id="__main__", + inlets=[ + AssetRef(name=inlet_asset_name) + for inlet_asset_name in parameters + if inlet_asset_name not in ("self", "context") + ], + outlets=[self.to_asset()], + python_callable=self.function, + definition_name=self.name, + uri=self.uri, + ) + + def to_asset(self) -> Asset: + return Asset( + name=self.name, + uri=self.uri, + group=self.group, + extra=self.extra, + ) + + def serialize(self): + return { + "uri": self.uri, + "name": self.name, + "group": self.group, + "extra": self.extra, + } + + +@attrs.define(kw_only=True) +class asset: + """Create an asset by decorating a materialization function.""" + + schedule: ScheduleArg + uri: str | ObjectStoragePath | None = None + group: str = "" + extra: dict[str, Any] = attrs.field(factory=dict) + + def __call__(self, f: Callable) -> AssetDefinition: + if (name := f.__name__) != f.__qualname__: + raise ValueError("nested function not supported") + + return AssetDefinition( + name=name, + uri=name if self.uri is None else str(self.uri), + group=self.group, + extra=self.extra, + function=f, + schedule=self.schedule, + ) diff --git a/task_sdk/src/airflow/sdk/definitions/decorators.py b/task_sdk/src/airflow/sdk/definitions/decorators.py index be40e597ca9b..ab73ba0c9242 100644 --- a/task_sdk/src/airflow/sdk/definitions/decorators.py +++ b/task_sdk/src/airflow/sdk/definitions/decorators.py @@ -17,22 +17,6 @@ from __future__ import annotations -import inspect -from collections.abc import Iterator, Mapping -from typing import TYPE_CHECKING, Any, Callable - -import attrs - -from airflow.models.asset import _fetch_active_assets_by_name -from airflow.providers.standard.operators.python import PythonOperator -from airflow.sdk.definitions.asset import Asset, AssetRef -from airflow.sdk.definitions.dag import DAG, ScheduleArg -from airflow.utils.session import create_session - -if TYPE_CHECKING: - from airflow.io.path import ObjectStoragePath - - import sys from types import FunctionType @@ -56,102 +40,3 @@ def fixup_decorator_warning_stack(func: FunctionType): # Yes, this is more than slightly hacky, but it _automatically_ sets the right stacklevel parameter to # `warnings.warn` to ignore the decorator. 
 func.__globals__["warnings"] = _autostacklevel_warn()
-
-
-class _AssetMainOperator(PythonOperator):
-    def __init__(self, *, definition_name: str, uri: str | None = None, **kwargs) -> None:
-        super().__init__(**kwargs)
-        self._definition_name = definition_name
-        self._uri = uri
-
-    def _iter_kwargs(
-        self, context: Mapping[str, Any], active_assets: dict[str, Asset]
-    ) -> Iterator[tuple[str, Any]]:
-        value: Any
-        for key in inspect.signature(self.python_callable).parameters:
-            if key == "self":
-                value = active_assets.get(self._definition_name)
-            elif key == "context":
-                value = context
-            else:
-                value = active_assets.get(key, Asset(name=key))
-            yield key, value
-
-    def determine_kwargs(self, context: Mapping[str, Any]) -> Mapping[str, Any]:
-        active_assets: dict[str, Asset] = {}
-        asset_names = [asset_ref.name for asset_ref in self.inlets if isinstance(asset_ref, AssetRef)]
-        if "self" in inspect.signature(self.python_callable).parameters:
-            asset_names.append(self._definition_name)
-
-        if asset_names:
-            with create_session() as session:
-                active_assets = _fetch_active_assets_by_name(asset_names, session)
-        return dict(self._iter_kwargs(context, active_assets))
-
-
-@attrs.define(kw_only=True)
-class AssetDefinition(Asset):
-    """
-    Asset representation from decorating a function with ``@asset``.
-
-    :meta private:
-    """
-
-    function: Callable
-    schedule: ScheduleArg
-
-    def __attrs_post_init__(self) -> None:
-        parameters = inspect.signature(self.function).parameters
-
-        with DAG(dag_id=self.name, schedule=self.schedule, auto_register=True):
-            _AssetMainOperator(
-                task_id="__main__",
-                inlets=[
-                    AssetRef(name=inlet_asset_name)
-                    for inlet_asset_name in parameters
-                    if inlet_asset_name not in ("self", "context")
-                ],
-                outlets=[self.to_asset()],
-                python_callable=self.function,
-                definition_name=self.name,
-                uri=self.uri,
-            )
-
-    def to_asset(self) -> Asset:
-        return Asset(
-            name=self.name,
-            uri=self.uri,
-            group=self.group,
-            extra=self.extra,
-        )
-
-    def serialize(self):
-        return {
-            "uri": self.uri,
-            "name": self.name,
-            "group": self.group,
-            "extra": self.extra,
-        }
-
-
-@attrs.define(kw_only=True)
-class asset:
-    """Create an asset by decorating a materialization function."""
-
-    schedule: ScheduleArg
-    uri: str | ObjectStoragePath | None = None
-    group: str = ""
-    extra: dict[str, Any] = attrs.field(factory=dict)
-
-    def __call__(self, f: Callable) -> AssetDefinition:
-        if (name := f.__name__) != f.__qualname__:
-            raise ValueError("nested function not supported")
-
-        return AssetDefinition(
-            name=name,
-            uri=name if self.uri is None else str(self.uri),
-            group=self.group,
-            extra=self.extra,
-            function=f,
-            schedule=self.schedule,
-        )
diff --git a/task_sdk/tests/defintions/test_decorators.py b/task_sdk/tests/defintions/test_asset_decorators.py
similarity index 98%
rename from task_sdk/tests/defintions/test_decorators.py
rename to task_sdk/tests/defintions/test_asset_decorators.py
index f7b5a0a74658..f1945d446a27 100644
--- a/task_sdk/tests/defintions/test_decorators.py
+++ b/task_sdk/tests/defintions/test_asset_decorators.py
@@ -22,8 +22,7 @@
 import pytest
 
 from airflow.models.asset import AssetActive, AssetModel
-from airflow.sdk.definitions.asset import Asset
-from airflow.sdk.definitions.decorators import AssetRef, _AssetMainOperator, asset
+from airflow.sdk.definitions.asset import Asset, AssetRef, _AssetMainOperator, asset
 
 pytestmark = pytest.mark.db_test
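Consolidating ``_AssetMainOperator`` next to ``Asset`` makes the parameter binding in ``determine_kwargs`` the core contract of ``@asset``: arguments are filled by name. An illustrative sketch (asset names are hypothetical; the import works at this point in the series and moves to ``asset.decorators`` in the next patch):

    from airflow.sdk.definitions.asset import asset

    @asset(uri="s3://bucket/report", schedule=None)
    def report(self, context, upstream_table):
        # self           -> the active Asset backing this definition (None if inactive)
        # context        -> the runtime task context mapping
        # upstream_table -> the active Asset named "upstream_table", or a plain
        #                   Asset(name="upstream_table") if none is registered
        ...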
From 952cdec158937ab085d3381acbb383f525eb4a36 Mon Sep 17 00:00:00 2001
From: Wei Lee
Date: Sat, 16 Nov 2024 11:58:47 +0800
Subject: [PATCH 19/28] refactor(task_sdk): move @asset to
 airflow.sdk.definitions.asset.decorators

---
 .../example_outlet_event_extra.py             |   3 +-
 airflow/utils/context.py                      |   2 +-
 airflow/utils/operator_helpers.py             |   2 +-
 .../authoring-and-scheduling/datasets.rst     |   4 +-
 .../{asset.py => asset/__init__.py}           | 276 +++++-------------
 .../sdk/definitions/asset/decorators.py       | 136 +++++++++
 .../airflow/sdk/definitions/asset/metadata.py |  69 +++++
 .../tests/defintions/test_asset_decorators.py |   3 +-
 tests/models/test_taskinstance.py             |   6 +-
 9 files changed, 287 insertions(+), 214 deletions(-)
 rename task_sdk/src/airflow/sdk/definitions/{asset.py => asset/__init__.py} (77%)
 create mode 100644 task_sdk/src/airflow/sdk/definitions/asset/decorators.py
 create mode 100644 task_sdk/src/airflow/sdk/definitions/asset/metadata.py

diff --git a/airflow/example_dags/example_outlet_event_extra.py b/airflow/example_dags/example_outlet_event_extra.py
index d07365a9f1f3..dd3041e18fc0 100644
--- a/airflow/example_dags/example_outlet_event_extra.py
+++ b/airflow/example_dags/example_outlet_event_extra.py
@@ -28,7 +28,8 @@
 from airflow.decorators import task
 from airflow.models.dag import DAG
 from airflow.providers.standard.operators.bash import BashOperator
-from airflow.sdk.definitions.asset import Asset, Metadata
+from airflow.sdk.definitions.asset import Asset
+from airflow.sdk.definitions.asset.metadata import Metadata
 
 ds = Asset("s3://output/1.txt")
 
diff --git a/airflow/utils/context.py b/airflow/utils/context.py
index 5e423d4746af..b954a5e1f2f9 100644
--- a/airflow/utils/context.py
+++ b/airflow/utils/context.py
@@ -47,8 +47,8 @@
     AssetAlias,
     AssetAliasEvent,
     AssetRef,
-    extract_event_key,
 )
+from airflow.sdk.definitions.asset.metadata import extract_event_key
 from airflow.utils.db import LazySelectSequence
 from airflow.utils.types import NOTSET
 
diff --git a/airflow/utils/operator_helpers.py b/airflow/utils/operator_helpers.py
index 82bf9d43cd16..f841d968ad6e 100644
--- a/airflow/utils/operator_helpers.py
+++ b/airflow/utils/operator_helpers.py
@@ -23,7 +23,7 @@
 from typing import TYPE_CHECKING, Any, Callable, Collection, Mapping, Protocol, TypeVar
 
 from airflow import settings
-from airflow.sdk.definitions.asset import Metadata
+from airflow.sdk.definitions.asset.metadata import Metadata
 from airflow.typing_compat import ParamSpec
 from airflow.utils.context import Context, lazy_mapping_from_context
 from airflow.utils.types import NOTSET
 
diff --git a/docs/apache-airflow/authoring-and-scheduling/datasets.rst b/docs/apache-airflow/authoring-and-scheduling/datasets.rst
index 1794fb972799..9e777d929958 100644
--- a/docs/apache-airflow/authoring-and-scheduling/datasets.rst
+++ b/docs/apache-airflow/authoring-and-scheduling/datasets.rst
@@ -249,7 +249,7 @@ The easiest way to attach extra information to the asset event is by ``yield``-i
 .. code-block:: python
 
     from airflow.sdk.definitions.asset import Asset
-    from airflow.sdk.definitions.asset import Metadata
+    from airflow.sdk.definitions.asset.metadata import Metadata
 
     example_s3_asset = Asset("s3://asset/example.csv")
 
@@ -452,7 +452,7 @@ The following example creates an asset event against the S3 URI ``f"s3://bucket/
 
 ..
code-block:: python - from airflow.sdk.definitions.asset import Metadata + from airflow.sdk.definitions.asset.metadata import Metadata @task(outlets=[AssetAlias("my-task-outputs")]) diff --git a/task_sdk/src/airflow/sdk/definitions/asset.py b/task_sdk/src/airflow/sdk/definitions/asset/__init__.py similarity index 77% rename from task_sdk/src/airflow/sdk/definitions/asset.py rename to task_sdk/src/airflow/sdk/definitions/asset/__init__.py index 54440c144440..cb574c3df96d 100644 --- a/task_sdk/src/airflow/sdk/definitions/asset.py +++ b/task_sdk/src/airflow/sdk/definitions/asset/__init__.py @@ -17,12 +17,11 @@ from __future__ import annotations -import inspect import logging import os import urllib.parse import warnings -from collections.abc import Iterable, Iterator, Mapping +from collections.abc import Iterable, Iterator from typing import ( TYPE_CHECKING, Any, @@ -36,21 +35,26 @@ from sqlalchemy import select from airflow.api_internal.internal_api_call import internal_api_call -from airflow.models.asset import _fetch_active_assets_by_name -from airflow.providers.standard.operators.python import PythonOperator -from airflow.sdk.definitions.dag import DAG, ScheduleArg from airflow.serialization.dag_dependency import DagDependency from airflow.typing_compat import TypedDict -from airflow.utils.session import NEW_SESSION, create_session, provide_session +from airflow.utils.session import NEW_SESSION, provide_session if TYPE_CHECKING: from urllib.parse import SplitResult from sqlalchemy.orm.session import Session - from airflow.io.path import ObjectStoragePath -__all__ = ["Asset", "AssetAll", "AssetAny", "Dataset", "Model", "AssetRef", "asset"] +__all__ = [ + "Asset", + "Dataset", + "Model", + "AssetRef", + "AssetAlias", + "AssetAliasCondition", + "AssetAll", + "AssetAny", +] log = logging.getLogger(__name__) @@ -150,41 +154,6 @@ def _validate_asset_name(instance, attribute, value): return value -def extract_event_key(value: str | Asset | AssetAlias) -> str: - """ - Extract the key of an inlet or an outlet event. - - If the input value is a string, it is treated as a URI and sanitized. If the - input is a :class:`Asset`, the URI it contains is considered sanitized and - returned directly. If the input is a :class:`AssetAlias`, the name it contains - will be returned directly. - - :meta private: - """ - if isinstance(value, AssetAlias): - return value.name - - if isinstance(value, Asset): - return value.uri - return _sanitize_uri(str(value)) - - -@internal_api_call -@provide_session -def expand_alias_to_assets(alias: str | AssetAlias, *, session: Session = NEW_SESSION) -> list[BaseAsset]: - """Expand asset alias to resolved assets.""" - from airflow.models.asset import AssetAliasModel - - alias_name = alias.name if isinstance(alias, AssetAlias) else alias - - asset_alias_obj = session.scalar( - select(AssetAliasModel).where(AssetAliasModel.name == alias_name).limit(1) - ) - if asset_alias_obj: - return [asset.to_public() for asset in asset_alias_obj.assets] - return [] - - def _set_extra_default(extra: dict | None) -> dict: """ Automatically convert None to an empty dict. 
@@ -246,41 +215,6 @@ def iter_dag_dependencies(self, *, source: str, target: str) -> Iterator[DagDepe raise NotImplementedError -@attrs.define(unsafe_hash=False) -class AssetAlias(BaseAsset): - """A represeation of asset alias which is used to create asset during the runtime.""" - - name: str = attrs.field(validator=_validate_non_empty_identifier) - group: str = attrs.field(kw_only=True, default="", validator=_validate_identifier) - - def iter_assets(self) -> Iterator[tuple[str, Asset]]: - return iter(()) - - def iter_asset_aliases(self) -> Iterator[tuple[str, AssetAlias]]: - yield self.name, self - - def iter_dag_dependencies(self, *, source: str, target: str) -> Iterator[DagDependency]: - """ - Iterate an asset alias as dag dependency. - - :meta private: - """ - yield DagDependency( - source=source or "asset-alias", - target=target or "asset-alias", - dependency_type="asset-alias", - dependency_id=self.name, - ) - - -class AssetAliasEvent(TypedDict): - """A represeation of asset event to be triggered by an asset alias.""" - - source_alias_name: str - dest_asset_uri: str - extra: dict[str, Any] - - @attrs.define(init=False, unsafe_hash=False) class Asset(os.PathLike, BaseAsset): """A representation of data asset dependencies between workflows.""" @@ -381,6 +315,13 @@ def iter_dag_dependencies(self, *, source: str, target: str) -> Iterator[DagDepe ) +@attrs.define(kw_only=True) +class AssetRef: + """Reference to an asset.""" + + name: str + + class Dataset(Asset): """A representation of dataset dependencies between workflows.""" @@ -393,6 +334,41 @@ class Model(Asset): asset_type: ClassVar[str] = "model" +@attrs.define(unsafe_hash=False) +class AssetAlias(BaseAsset): + """A represeation of asset alias which is used to create asset during the runtime.""" + + name: str = attrs.field(validator=_validate_non_empty_identifier) + group: str = attrs.field(kw_only=True, default="", validator=_validate_identifier) + + def iter_assets(self) -> Iterator[tuple[str, Asset]]: + return iter(()) + + def iter_asset_aliases(self) -> Iterator[tuple[str, AssetAlias]]: + yield self.name, self + + def iter_dag_dependencies(self, *, source: str, target: str) -> Iterator[DagDependency]: + """ + Iterate an asset alias as dag dependency. + + :meta private: + """ + yield DagDependency( + source=source or "asset-alias", + target=target or "asset-alias", + dependency_type="asset-alias", + dependency_id=self.name, + ) + + +class AssetAliasEvent(TypedDict): + """A represeation of asset event to be triggered by an asset alias.""" + + source_alias_name: str + dest_asset_uri: str + extra: dict[str, Any] + + class _AssetBooleanCondition(BaseAsset): """Base class for asset boolean logic.""" @@ -456,6 +432,22 @@ def as_expression(self) -> dict[str, Any]: return {"any": [o.as_expression() for o in self.objects]} +@internal_api_call +@provide_session +def expand_alias_to_assets(alias: str | AssetAlias, *, session: Session = NEW_SESSION) -> list[BaseAsset]: + """Expand asset alias to resolved assets.""" + from airflow.models.asset import AssetAliasModel + + alias_name = alias.name if isinstance(alias, AssetAlias) else alias + + asset_alias_obj = session.scalar( + select(AssetAliasModel).where(AssetAliasModel.name == alias_name).limit(1) + ) + if asset_alias_obj: + return [asset.to_public() for asset in asset_alias_obj.assets] + return [] + + class AssetAliasCondition(AssetAny): """ Use to expand AssetAlias as AssetAny of its resolved Assets. 
@@ -535,131 +527,3 @@ def as_expression(self) -> Any: :meta private: """ return {"all": [o.as_expression() for o in self.objects]} - - -@attrs.define(init=False) -class Metadata: - """Metadata to attach to an AssetEvent.""" - - uri: str - extra: dict[str, Any] - alias_name: str | None = None - - def __init__( - self, - target: str | Asset, - extra: dict[str, Any], - alias: AssetAlias | str | None = None, - ) -> None: - self.uri = extract_event_key(target) - self.extra = extra - if isinstance(alias, AssetAlias): - self.alias_name = alias.name - else: - self.alias_name = alias - - -@attrs.define(kw_only=True) -class AssetRef: - """Reference to an asset.""" - - name: str - - -class _AssetMainOperator(PythonOperator): - def __init__(self, *, definition_name: str, uri: str | None = None, **kwargs) -> None: - super().__init__(**kwargs) - self._definition_name = definition_name - self._uri = uri - - def _iter_kwargs( - self, context: Mapping[str, Any], active_assets: dict[str, Asset] - ) -> Iterator[tuple[str, Any]]: - value: Any - for key in inspect.signature(self.python_callable).parameters: - if key == "self": - value = active_assets.get(self._definition_name) - elif key == "context": - value = context - else: - value = active_assets.get(key, Asset(name=key)) - yield key, value - - def determine_kwargs(self, context: Mapping[str, Any]) -> Mapping[str, Any]: - active_assets: dict[str, Asset] = {} - asset_names = [asset_ref.name for asset_ref in self.inlets if isinstance(asset_ref, AssetRef)] - if "self" in inspect.signature(self.python_callable).parameters: - asset_names.append(self._definition_name) - - if asset_names: - with create_session() as session: - active_assets = _fetch_active_assets_by_name(asset_names, session) - return dict(self._iter_kwargs(context, active_assets)) - - -@attrs.define(kw_only=True) -class AssetDefinition(Asset): - """ - Asset representation from decorating a function with ``@asset``. 
-
-    :meta private:
-    """
-
-    function: Callable
-    schedule: ScheduleArg
-
-    def __attrs_post_init__(self) -> None:
-        parameters = inspect.signature(self.function).parameters
-
-        with DAG(dag_id=self.name, schedule=self.schedule, auto_register=True):
-            _AssetMainOperator(
-                task_id="__main__",
-                inlets=[
-                    AssetRef(name=inlet_asset_name)
-                    for inlet_asset_name in parameters
-                    if inlet_asset_name not in ("self", "context")
-                ],
-                outlets=[self.to_asset()],
-                python_callable=self.function,
-                definition_name=self.name,
-                uri=self.uri,
-            )
-
-    def to_asset(self) -> Asset:
-        return Asset(
-            name=self.name,
-            uri=self.uri,
-            group=self.group,
-            extra=self.extra,
-        )
-
-    def serialize(self):
-        return {
-            "uri": self.uri,
-            "name": self.name,
-            "group": self.group,
-            "extra": self.extra,
-        }
-
-
-@attrs.define(kw_only=True)
-class asset:
-    """Create an asset by decorating a materialization function."""
-
-    schedule: ScheduleArg
-    uri: str | ObjectStoragePath | None = None
-    group: str = ""
-    extra: dict[str, Any] = attrs.field(factory=dict)
-
-    def __call__(self, f: Callable) -> AssetDefinition:
-        if (name := f.__name__) != f.__qualname__:
-            raise ValueError("nested function not supported")
-
-        return AssetDefinition(
-            name=name,
-            uri=name if self.uri is None else str(self.uri),
-            group=self.group,
-            extra=self.extra,
-            function=f,
-            schedule=self.schedule,
-        )
diff --git a/task_sdk/src/airflow/sdk/definitions/asset/decorators.py b/task_sdk/src/airflow/sdk/definitions/asset/decorators.py
new file mode 100644
index 000000000000..55467c8d63a3
--- /dev/null
+++ b/task_sdk/src/airflow/sdk/definitions/asset/decorators.py
@@ -0,0 +1,136 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+import inspect
+from collections.abc import Iterator, Mapping
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Callable,
+)
+
+import attrs
+
+from airflow.models.asset import _fetch_active_assets_by_name
+from airflow.models.dag import DAG, ScheduleArg
+from airflow.providers.standard.operators.python import PythonOperator
+from airflow.sdk.definitions.asset import Asset, AssetRef
+from airflow.utils.session import create_session
+
+if TYPE_CHECKING:
+    from airflow.io.path import ObjectStoragePath
+
+
+class _AssetMainOperator(PythonOperator):
+    def __init__(self, *, definition_name: str, uri: str | None = None, **kwargs) -> None:
+        super().__init__(**kwargs)
+        self._definition_name = definition_name
+        self._uri = uri
+
+    def _iter_kwargs(
+        self, context: Mapping[str, Any], active_assets: dict[str, Asset]
+    ) -> Iterator[tuple[str, Any]]:
+        value: Any
+        for key in inspect.signature(self.python_callable).parameters:
+            if key == "self":
+                value = active_assets.get(self._definition_name)
+            elif key == "context":
+                value = context
+            else:
+                value = active_assets.get(key, Asset(name=key))
+            yield key, value
+
+    def determine_kwargs(self, context: Mapping[str, Any]) -> Mapping[str, Any]:
+        active_assets: dict[str, Asset] = {}
+        asset_names = [asset_ref.name for asset_ref in self.inlets if isinstance(asset_ref, AssetRef)]
+        if "self" in inspect.signature(self.python_callable).parameters:
+            asset_names.append(self._definition_name)
+
+        if asset_names:
+            with create_session() as session:
+                active_assets = _fetch_active_assets_by_name(asset_names, session)
+        return dict(self._iter_kwargs(context, active_assets))
+
+
+@attrs.define(kw_only=True)
+class AssetDefinition(Asset):
+    """
+    Asset representation from decorating a function with ``@asset``.
+
+    :meta private:
+    """
+
+    function: Callable
+    schedule: ScheduleArg
+
+    def __attrs_post_init__(self) -> None:
+        parameters = inspect.signature(self.function).parameters
+
+        with DAG(dag_id=self.name, schedule=self.schedule, auto_register=True):
+            _AssetMainOperator(
+                task_id="__main__",
+                inlets=[
+                    AssetRef(name=inlet_asset_name)
+                    for inlet_asset_name in parameters
+                    if inlet_asset_name not in ("self", "context")
+                ],
+                outlets=[self.to_asset()],
+                python_callable=self.function,
+                definition_name=self.name,
+                uri=self.uri,
+            )
+
+    def to_asset(self) -> Asset:
+        return Asset(
+            name=self.name,
+            uri=self.uri,
+            group=self.group,
+            extra=self.extra,
+        )
+
+    def serialize(self):
+        return {
+            "uri": self.uri,
+            "name": self.name,
+            "group": self.group,
+            "extra": self.extra,
+        }
+
+
+@attrs.define(kw_only=True)
+class asset:
+    """Create an asset by decorating a materialization function."""
+
+    schedule: ScheduleArg
+    uri: str | ObjectStoragePath | None = None
+    group: str = ""
+    extra: dict[str, Any] = attrs.field(factory=dict)
+
+    def __call__(self, f: Callable) -> AssetDefinition:
+        if (name := f.__name__) != f.__qualname__:
+            raise ValueError("nested function not supported")
+
+        return AssetDefinition(
+            name=name,
+            uri=name if self.uri is None else str(self.uri),
+            group=self.group,
+            extra=self.extra,
+            function=f,
+            schedule=self.schedule,
+        )
diff --git a/task_sdk/src/airflow/sdk/definitions/asset/metadata.py b/task_sdk/src/airflow/sdk/definitions/asset/metadata.py
new file mode 100644
index 000000000000..23d5b96dc128
--- /dev/null
+++ b/task_sdk/src/airflow/sdk/definitions/asset/metadata.py
@@ -0,0 +1,69 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+from typing import (
+    Any,
+)
+
+import attrs
+
+from airflow.sdk.definitions.asset import Asset, AssetAlias, _sanitize_uri
+
+__all__ = ["Metadata"]
+
+
+def extract_event_key(value: str | Asset | AssetAlias) -> str:
+    """
+    Extract the key of an inlet or an outlet event.
+
+    If the input value is a string, it is treated as a URI and sanitized. If the
+    input is an :class:`Asset`, the URI it contains is considered sanitized and
+    returned directly. If the input is an :class:`AssetAlias`, the name it contains
+    will be returned directly.
+
+    :meta private:
+    """
+    if isinstance(value, AssetAlias):
+        return value.name
+
+    if isinstance(value, Asset):
+        return value.uri
+    return _sanitize_uri(str(value))
+
+
+@attrs.define(init=False)
+class Metadata:
+    """Metadata to attach to an AssetEvent."""
+
+    uri: str
+    extra: dict[str, Any]
+    alias_name: str | None = None
+
+    def __init__(
+        self,
+        target: str | Asset,
+        extra: dict[str, Any],
+        alias: AssetAlias | str | None = None,
+    ) -> None:
+        self.uri = extract_event_key(target)
+        self.extra = extra
+        if isinstance(alias, AssetAlias):
+            self.alias_name = alias.name
+        else:
+            self.alias_name = alias
diff --git a/task_sdk/tests/defintions/test_asset_decorators.py b/task_sdk/tests/defintions/test_asset_decorators.py
index f1945d446a27..04650bc66444 100644
--- a/task_sdk/tests/defintions/test_asset_decorators.py
+++ b/task_sdk/tests/defintions/test_asset_decorators.py
@@ -22,7 +22,8 @@
 import pytest
 
 from airflow.models.asset import AssetActive, AssetModel
-from airflow.sdk.definitions.asset import Asset, AssetRef, _AssetMainOperator, asset
+from airflow.sdk.definitions.asset import Asset, AssetRef
+from airflow.sdk.definitions.asset.decorators import _AssetMainOperator, asset
 
 pytestmark = pytest.mark.db_test
 
diff --git a/tests/models/test_taskinstance.py b/tests/models/test_taskinstance.py
index 36297f7b7bce..81f8ed7e60af 100644
--- a/tests/models/test_taskinstance.py
+++ b/tests/models/test_taskinstance.py
@@ -2505,7 +2505,8 @@ def write(*, outlet_events):
 
     @pytest.mark.skip_if_database_isolation_mode  # Does not work in db isolation mode
     def test_outlet_asset_extra_yield(self, dag_maker, session):
-        from airflow.sdk.definitions.asset import Asset, Metadata
+        from airflow.sdk.definitions.asset import Asset
+        from airflow.sdk.definitions.asset.metadata import Metadata
 
         with dag_maker(schedule=None, session=session) as dag:
 
@@ -2677,7 +2678,8 @@ def producer(*, outlet_events):
 
     @pytest.mark.skip_if_database_isolation_mode  # Does not work in db isolation mode
     def test_outlet_asset_alias_through_metadata(self, dag_maker, session):
-        from airflow.sdk.definitions.asset import AssetAlias, Metadata
+        from airflow.sdk.definitions.asset import AssetAlias
+        from airflow.sdk.definitions.asset.metadata import Metadata
 
         asset_uri = "test_outlet_asset_alias_through_metadata_ds"
         asset_alias_name = "test_outlet_asset_alias_through_metadata_asset_alias"

From 6d86152bdc6770c4fe6d24def47d25029e7a3e8e Mon Sep 17 00:00:00 2001
From: Wei Lee
Date: Sat, 16 Nov 2024 12:08:58 +0800
Subject: [PATCH 20/28] test(providers/amazon): remove unnecessary compat handling

---
 .../aws/auth_manager/test_aws_auth_manager.py | 17 +----------------
 1 file changed, 1 insertion(+), 16 deletions(-)

diff --git a/providers/tests/amazon/aws/auth_manager/test_aws_auth_manager.py b/providers/tests/amazon/aws/auth_manager/test_aws_auth_manager.py
index 0700a71a6191..e973c8433b27 100644
--- a/providers/tests/amazon/aws/auth_manager/test_aws_auth_manager.py
+++ b/providers/tests/amazon/aws/auth_manager/test_aws_auth_manager.py
@@ -57,24 +57,9 @@
     from airflow.auth.managers.models.resource_details import AssetDetails
     from airflow.security.permissions import RESOURCE_ASSET
 else:
+    from airflow.providers.common.compat.assets import AssetDetails
     from airflow.providers.common.compat.security.permissions import RESOURCE_ASSET
 
-    # TODO: Remove this try-exception block after bumping common provider to 1.3.0
-    # This is due to common provider AssetDetails import error handling
-    try:
-        from airflow.auth.managers.models.resource_details import AssetDetails
-    except ModuleNotFoundError:
-        # 2.7.x
-        pass
-    except ImportError:
-        from packaging.version import Version
-
-        from airflow import __version__ as AIRFLOW_VERSION
-
-        _IS_AIRFLOW_2_8_OR_HIGHER = Version(Version(AIRFLOW_VERSION).base_version) >= Version("2.8.0")
-        if _IS_AIRFLOW_2_8_OR_HIGHER:
-            from airflow.auth.managers.models.resource_details import DatasetDetails as AssetDetails
-
 
 pytestmark = [
     pytest.mark.skipif(not AIRFLOW_V_2_9_PLUS, reason="Test requires Airflow 2.9+"),

From dd8bfd0102ca08d038f43388bd0c30d98eafb49f Mon Sep 17 00:00:00 2001
From: Wei Lee
Date: Sat, 16 Nov 2024 12:16:31 +0800
Subject: [PATCH 21/28] test(providers/google): remove unnecessary compat handling

---
 providers/tests/google/assets/test_gcs.py | 21 +------------------
 .../tests/google/cloud/hooks/test_gcs.py  | 21 +------------------
 2 files changed, 2 insertions(+), 40 deletions(-)

diff --git a/providers/tests/google/assets/test_gcs.py b/providers/tests/google/assets/test_gcs.py
index b72357960d53..e9920302b0e0 100644
--- a/providers/tests/google/assets/test_gcs.py
+++ b/providers/tests/google/assets/test_gcs.py
@@ -17,31 +17,12 @@
 from __future__ import annotations
 
 import urllib.parse
-from typing import TYPE_CHECKING
 
 import pytest
 
+from airflow.providers.common.compat.assets import Asset
 from airflow.providers.google.assets.gcs import convert_asset_to_openlineage, create_asset, sanitize_uri
 
-if TYPE_CHECKING:
-    from airflow.providers.common.compat.assets import Asset
-else:
-    # TODO: Remove this try-exception block after bumping common provider to 1.3.0
-    # This is due to common provider AssetDetails import error handling
-    try:
-        from airflow.providers.common.compat.assets import Asset
-    except ImportError:
-        from packaging.version import Version
-
-        from airflow import __version__ as AIRFLOW_VERSION
-
-        AIRFLOW_V_3_0_PLUS = Version(Version(AIRFLOW_VERSION).base_version) >= Version("3.0.0")
-        if AIRFLOW_V_3_0_PLUS:
-            from airflow.sdk.definitions.asset import Asset
-        else:
-            # dataset is renamed to asset since Airflow 3.0
-            from airflow.datasets import Dataset as Asset
-
 
 def test_sanitize_uri():
     uri = sanitize_uri(urllib.parse.urlsplit("gs://bucket/dir/file.txt"))
diff --git a/providers/tests/google/cloud/hooks/test_gcs.py b/providers/tests/google/cloud/hooks/test_gcs.py
index b33f53fbc029..766d5a4120f8 100644
--- a/providers/tests/google/cloud/hooks/test_gcs.py
+++ b/providers/tests/google/cloud/hooks/test_gcs.py
@@ -24,7 +24,6 @@
 from collections import namedtuple
 from datetime import datetime, timedelta
 from io import BytesIO
-from typing import TYPE_CHECKING
 from unittest import mock
 from unittest.mock import MagicMock
 
@@ -37,6 +36,7 @@
 from google.cloud.storage.retry import DEFAULT_RETRY
 
 from airflow.exceptions import AirflowException
+from airflow.providers.google.assets.gcs import Asset
 from airflow.providers.google.cloud.hooks import gcs
 from airflow.providers.google.cloud.hooks.gcs import _fallback_object_url_to_object_name_and_bucket_name
 from airflow.providers.google.common.consts import CLIENT_INFO
@@ -46,25 +46,6 @@
 from providers.tests.google.cloud.utils.base_gcp_mock import mock_base_gcp_hook_default_project_id
 from tests_common.test_utils.compat import AIRFLOW_V_2_10_PLUS
 
-if TYPE_CHECKING:
-    from airflow.providers.common.compat.assets import Asset
-else:
-    # TODO: Remove this try-exception block after bumping common provider to 1.3.0
-    # This is due to common provider AssetDetails import error handling
-    try:
-        from airflow.providers.common.compat.assets import Asset
-    except ImportError:
-        from packaging.version import Version
-
-        from airflow import __version__ as AIRFLOW_VERSION
-
-        AIRFLOW_V_3_0_PLUS = Version(Version(AIRFLOW_VERSION).base_version) >= Version("3.0.0")
-        if AIRFLOW_V_3_0_PLUS:
-            from airflow.sdk.definitions.asset import Asset
-        else:
-            # dataset is renamed to asset since Airflow 3.0
-            from airflow.datasets import Dataset as Asset
-
 BASE_STRING = "airflow.providers.google.common.hooks.base_google.{}"
 GCS_STRING = "airflow.providers.google.cloud.hooks.gcs.{}"
 

From 14af1b09c147bfe54d9542884e4fa8b555037aa6 Mon Sep 17 00:00:00 2001
From: Wei Lee
Date: Sat, 16 Nov 2024 12:16:45 +0800
Subject: [PATCH 22/28] test(openlineage): remove unnecessary compat handling

---
 .../tests/openlineage/plugins/test_utils.py | 21 +------------------
 1 file changed, 1 insertion(+), 20 deletions(-)

diff --git a/providers/tests/openlineage/plugins/test_utils.py b/providers/tests/openlineage/plugins/test_utils.py
index 2dd55e655be4..e84fac118657 100644
--- a/providers/tests/openlineage/plugins/test_utils.py
+++ b/providers/tests/openlineage/plugins/test_utils.py
@@ -20,7 +20,7 @@
 import json
 import uuid
 from json import JSONEncoder
-from typing import TYPE_CHECKING, Any
+from typing import Any
 from unittest.mock import MagicMock, patch
 
 import pytest
@@ -55,25 +55,6 @@
     BashOperator,
 )
 
-if TYPE_CHECKING:
-    from airflow.providers.common.compat.assets import Asset
-else:
-    # TODO: Remove this try-exception block after bumping common provider to 1.3.0
-    # This is due to common provider AssetDetails import error handling
-    try:
-        from airflow.providers.common.compat.assets import Asset
-    except ImportError:
-        from packaging.version import Version
-
-        from airflow import __version__ as AIRFLOW_VERSION
-
-        AIRFLOW_V_3_0_PLUS = Version(Version(AIRFLOW_VERSION).base_version) >= Version("3.0.0")
-        if AIRFLOW_V_3_0_PLUS:
-            from airflow.sdk.definitions.asset import Asset
-        else:
-            # dataset is renamed to asset since Airflow 3.0
-            from airflow.datasets import Dataset as Asset
-
 if AIRFLOW_V_3_0_PLUS:
     from airflow.utils.types import DagRunTriggeredByType
 

From 2dee5f5bec4f031dc7adceaf74299ffe5e93ea54 Mon Sep 17 00:00:00 2001
From: Wei Lee
Date: Sat, 16 Nov 2024 12:17:10 +0800
Subject: [PATCH 23/28] fix(provider/openlineage): fix how asset compat is handled

---
 providers/src/airflow/providers/openlineage/utils/utils.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/providers/src/airflow/providers/openlineage/utils/utils.py b/providers/src/airflow/providers/openlineage/utils/utils.py
index 17911685b269..a37b94b85c19 100644
--- a/providers/src/airflow/providers/openlineage/utils/utils.py
+++ b/providers/src/airflow/providers/openlineage/utils/utils.py
@@ -37,7 +37,6 @@
 # TODO: move this maybe to Airflow's logic?
 from airflow.models import DAG, BaseOperator, DagRun, MappedOperator
-from airflow.providers.common.compat.assets import Asset
 from airflow.providers.openlineage import __version__ as OPENLINEAGE_PROVIDER_VERSION, conf
 from airflow.providers.openlineage.plugins.facets import (
     AirflowDagRunFacet,

From 61b140250cc64c61e4c4f2c87c0ea3ed9727e6de Mon Sep 17 00:00:00 2001
From: Wei Lee
Date: Mon, 18 Nov 2024 14:05:29 +0800
Subject: [PATCH 24/28] feat(task_sdk/asset): expose extract_event_key

---
 task_sdk/src/airflow/sdk/definitions/asset/metadata.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/task_sdk/src/airflow/sdk/definitions/asset/metadata.py b/task_sdk/src/airflow/sdk/definitions/asset/metadata.py
index 23d5b96dc128..088191970396 100644
--- a/task_sdk/src/airflow/sdk/definitions/asset/metadata.py
+++ b/task_sdk/src/airflow/sdk/definitions/asset/metadata.py
@@ -25,7 +25,7 @@
 
 from airflow.sdk.definitions.asset import Asset, AssetAlias, _sanitize_uri
 
-__all__ = ["Metadata"]
+__all__ = ["Metadata", "extract_event_key"]
 
 
 def extract_event_key(value: str | Asset | AssetAlias) -> str:

From 0ea225826e6921590c97020e253f566356e4e32f Mon Sep 17 00:00:00 2001
From: Wei Lee
Date: Mon, 18 Nov 2024 14:08:12 +0800
Subject: [PATCH 25/28] test(providers/google): change Asset import back to common.compat

---
 providers/tests/google/cloud/hooks/test_gcs.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/providers/tests/google/cloud/hooks/test_gcs.py b/providers/tests/google/cloud/hooks/test_gcs.py
index 766d5a4120f8..8dc5966e3d72 100644
--- a/providers/tests/google/cloud/hooks/test_gcs.py
+++ b/providers/tests/google/cloud/hooks/test_gcs.py
@@ -36,7 +36,7 @@
 from google.cloud.storage.retry import DEFAULT_RETRY
 
 from airflow.exceptions import AirflowException
-from airflow.providers.google.assets.gcs import Asset
+from airflow.providers.common.compat.assets import Asset
 from airflow.providers.google.cloud.hooks import gcs
 from airflow.providers.google.cloud.hooks.gcs import _fallback_object_url_to_object_name_and_bucket_name
 from airflow.providers.google.common.consts import CLIENT_INFO

From 1d06fe335d2dc44b3f0b3bf3e2916d8522b33d6c Mon Sep 17 00:00:00 2001
From: Wei Lee
Date: Tue, 19 Nov 2024 14:43:52 +0800
Subject: [PATCH 26/28] docs(newsfragments): fix error naming

---
 newsfragments/41348.significant.rst | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/newsfragments/41348.significant.rst b/newsfragments/41348.significant.rst
index f966baccc177..6e2044f8d196 100644
--- a/newsfragments/41348.significant.rst
+++ b/newsfragments/41348.significant.rst
@@ -25,7 +25,7 @@
 
     * Rename function ``expand_alias_to_datasets`` as ``expand_alias_to_assets``
     * Rename class ``DatasetAliasEvent`` as ``AssetAliasEvent``
-    * Rename method ``dest_dataset_uri`` as ``dest_asset_uri``
+    * Rename attribute ``dest_dataset_uri`` as ``dest_asset_uri``
 
 * Rename class ``BaseDataset`` as ``BaseAsset``
 

From fdab604c53da01add052f9b19221db6890a8857a Mon Sep 17 00:00:00 2001
From: Wei Lee
Date: Tue, 19 Nov 2024 14:51:05 +0800
Subject: [PATCH 27/28] docs(newsfragments): fix typo

---
 newsfragments/41348.significant.rst | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/newsfragments/41348.significant.rst b/newsfragments/41348.significant.rst
index 6e2044f8d196..b2ce9a3aa979 100644
--- a/newsfragments/41348.significant.rst
+++ b/newsfragments/41348.significant.rst
@@ -52,7 +52,7 @@
 
     * Rename method ``create_datasets`` as ``create_assets``
     * Rename method ``register_dataset_change`` as ``notify_asset_created``
     * Rename method ``notify_dataset_changed`` as ``notify_asset_changed``
-    * Renme method ``notify_dataset_alias_created`` as ``notify_asset_alias_created``
+    * Rename method ``notify_dataset_alias_created`` as ``notify_asset_alias_created``
 
 * Rename module ``airflow.models.dataset`` as ``airflow.models.asset``
 

From 9331e85c92f264646b941497a62655020f8de4a6 Mon Sep 17 00:00:00 2001
From: Wei Lee
Date: Tue, 19 Nov 2024 14:56:42 +0800
Subject: [PATCH 28/28] docs(newsfragment): add missing metadata

---
 newsfragments/41348.significant.rst | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/newsfragments/41348.significant.rst b/newsfragments/41348.significant.rst
index b2ce9a3aa979..eca66b78708f 100644
--- a/newsfragments/41348.significant.rst
+++ b/newsfragments/41348.significant.rst
@@ -84,7 +84,7 @@
 * Rename class ``DatasetPydantic`` as ``AssetPydantic``
 * Rename class ``DatasetEventPydantic`` as ``AssetEventPydantic``
 
-* Rename module ``airflow.datasets.metadata`` as ``airflow.assets.metadata``
+* Rename module ``airflow.datasets.metadata`` as ``airflow.sdk.definitions.asset.metadata``
 
 * In module ``airflow.jobs.scheduler_job_runner``