Skip to content

Commit c874e8d

Browse files
authored
Support using asset ref to emit alias event (#48922)
1 parent e84c177 commit c874e8d

File tree

2 files changed

+32
-13
lines changed

2 files changed

+32
-13
lines changed

task-sdk/src/airflow/sdk/execution_time/context.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -371,22 +371,27 @@ def __getitem__(self, key: Asset | AssetAlias | AssetRef) -> Sequence[AssetEvent
371371

372372

373373
@attrs.define
374-
class OutletEventAccessor:
374+
class OutletEventAccessor(_AssetRefResolutionMixin):
375375
"""Wrapper to access an outlet asset event in template."""
376376

377377
key: BaseAssetUniqueKey
378378
extra: dict[str, Any] = attrs.Factory(dict)
379379
asset_alias_events: list[AssetAliasEvent] = attrs.field(factory=list)
380380

381-
def add(self, asset: Asset, extra: dict[str, Any] | None = None) -> None:
381+
def add(self, asset: Asset | AssetRef, extra: dict[str, Any] | None = None) -> None:
382382
"""Add an AssetEvent to an existing Asset."""
383383
if not isinstance(self.key, AssetAliasUniqueKey):
384384
return
385385

386+
if isinstance(asset, AssetRef):
387+
asset_key = self._resolve_asset_ref(asset)
388+
else:
389+
asset_key = AssetUniqueKey.from_asset(asset)
390+
386391
asset_alias_name = self.key.name
387392
event = AssetAliasEvent(
388393
source_alias_name=asset_alias_name,
389-
dest_asset_key=AssetUniqueKey.from_asset(asset),
394+
dest_asset_key=asset_key,
390395
extra=extra or {},
391396
)
392397
self.asset_alias_events.append(event)

task-sdk/tests/task_sdk/execution_time/test_context.py

Lines changed: 24 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -351,6 +351,14 @@ def test_nested_context(self):
351351

352352

353353
class TestOutletEventAccessor:
354+
@pytest.mark.parametrize(
355+
"add_arg",
356+
[
357+
Asset("name", "uri"),
358+
Asset.ref(name="name"),
359+
Asset.ref(uri="uri"),
360+
],
361+
)
354362
@pytest.mark.parametrize(
355363
"key, asset_alias_events",
356364
(
@@ -360,21 +368,28 @@ class TestOutletEventAccessor:
360368
[
361369
AssetAliasEvent(
362370
source_alias_name="test_alias",
363-
dest_asset_key=AssetUniqueKey(uri="test_uri", name="test_uri"),
371+
dest_asset_key=AssetUniqueKey(name="name", uri="uri"),
364372
extra={},
365373
)
366374
],
367375
),
368376
),
369377
)
370-
def test_add(self, key, asset_alias_events, mock_supervisor_comms):
371-
asset = Asset("test_uri")
372-
mock_supervisor_comms.get_message.return_value = asset
378+
def test_add(self, add_arg, key, asset_alias_events, mock_supervisor_comms):
379+
mock_supervisor_comms.get_message.return_value = AssetResponse(name="name", uri="uri", group="")
373380

374381
outlet_event_accessor = OutletEventAccessor(key=key, extra={})
375-
outlet_event_accessor.add(asset)
382+
outlet_event_accessor.add(add_arg)
376383
assert outlet_event_accessor.asset_alias_events == asset_alias_events
377384

385+
@pytest.mark.parametrize(
386+
"add_arg",
387+
[
388+
Asset("name", "uri"),
389+
Asset.ref(name="name"),
390+
Asset.ref(uri="uri"),
391+
],
392+
)
378393
@pytest.mark.parametrize(
379394
"key, asset_alias_events",
380395
(
@@ -384,19 +399,18 @@ def test_add(self, key, asset_alias_events, mock_supervisor_comms):
384399
[
385400
AssetAliasEvent(
386401
source_alias_name="test_alias",
387-
dest_asset_key=AssetUniqueKey(name="test-asset", uri="test://asset-uri/"),
402+
dest_asset_key=AssetUniqueKey(name="name", uri="uri"),
388403
extra={},
389404
)
390405
],
391406
),
392407
),
393408
)
394-
def test_add_with_db(self, key, asset_alias_events, mock_supervisor_comms):
395-
asset = Asset(uri="test://asset-uri", name="test-asset")
396-
mock_supervisor_comms.get_message.return_value = asset
409+
def test_add_with_db(self, add_arg, key, asset_alias_events, mock_supervisor_comms):
410+
mock_supervisor_comms.get_message.return_value = AssetResponse(name="name", uri="uri", group="")
397411

398412
outlet_event_accessor = OutletEventAccessor(key=key, extra={"not": ""})
399-
outlet_event_accessor.add(asset, extra={})
413+
outlet_event_accessor.add(add_arg, extra={})
400414
assert outlet_event_accessor.asset_alias_events == asset_alias_events
401415

402416

0 commit comments

Comments
 (0)