Skip to content

Commit a6da8df

Browse files
authored
AIP-72: Allow retrieving Variable from Task Context (#45431)
1 parent 8639ade commit a6da8df

File tree

10 files changed

+285
-25
lines changed

10 files changed

+285
-25
lines changed

task_sdk/src/airflow/sdk/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@
4646
"EdgeModifier": ".definitions.edges",
4747
"Label": ".definitions.edges",
4848
"Connection": ".definitions.connection",
49+
"Variable": ".definitions.variable",
4950
}
5051

5152

task_sdk/src/airflow/sdk/api/client.py

+13-2
Original file line numberDiff line numberDiff line change
@@ -189,9 +189,20 @@ class VariableOperations:
189189
def __init__(self, client: Client):
190190
self.client = client
191191

192-
def get(self, key: str) -> VariableResponse:
192+
def get(self, key: str) -> VariableResponse | ErrorResponse:
193193
"""Get a variable from the API server."""
194-
resp = self.client.get(f"variables/{key}")
194+
try:
195+
resp = self.client.get(f"variables/{key}")
196+
except ServerResponseError as e:
197+
if e.response.status_code == HTTPStatus.NOT_FOUND:
198+
log.error(
199+
"Variable not found",
200+
key=key,
201+
detail=e.detail,
202+
status_code=e.response.status_code,
203+
)
204+
return ErrorResponse(error=ErrorType.VARIABLE_NOT_FOUND, detail={"key": key})
205+
raise
195206
return VariableResponse.model_validate_json(resp.read())
196207

197208
def set(self, key: str, value: str | None, description: str | None = None):
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
18+
from __future__ import annotations
19+
20+
from typing import Any
21+
22+
import attrs
23+
24+
25+
@attrs.define
26+
class Variable:
27+
"""
28+
A generic way to store and retrieve arbitrary content or settings as a simple key/value store.
29+
30+
:param key: The variable key.
31+
:param value: The variable value.
32+
:param description: The variable description.
33+
34+
"""
35+
36+
key: str
37+
# keeping as any for supporting deserialize_json
38+
value: Any | None = None
39+
description: str | None = None
40+
41+
# TODO: Extend this definition for reading/writing variables without context

task_sdk/src/airflow/sdk/execution_time/comms.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ def from_variable_response(cls, variable_response: VariableResponse) -> Variable
121121
VariableResponse is autogenerated from the API schema, so we need to convert it to VariableResult
122122
for communication between the Supervisor and the task process.
123123
"""
124-
return cls(**variable_response.model_dump())
124+
return cls(**variable_response.model_dump(exclude_defaults=True), type="VariableResult")
125125

126126

127127
class ErrorResponse(BaseModel):

task_sdk/src/airflow/sdk/execution_time/context.py

+61-2
Original file line numberDiff line numberDiff line change
@@ -21,19 +21,31 @@
2121
import structlog
2222

2323
from airflow.sdk.exceptions import AirflowRuntimeError, ErrorType
24+
from airflow.sdk.types import NOTSET
2425

2526
if TYPE_CHECKING:
2627
from airflow.sdk.definitions.connection import Connection
27-
from airflow.sdk.execution_time.comms import ConnectionResult
28+
from airflow.sdk.definitions.variable import Variable
29+
from airflow.sdk.execution_time.comms import ConnectionResult, VariableResult
2830

2931

30-
def _convert_connection_result_conn(conn_result: ConnectionResult):
32+
def _convert_connection_result_conn(conn_result: ConnectionResult) -> Connection:
3133
from airflow.sdk.definitions.connection import Connection
3234

3335
# `by_alias=True` is used to convert the `schema` field to `schema_` in the Connection model
3436
return Connection(**conn_result.model_dump(exclude={"type"}, by_alias=True))
3537

3638

39+
def _convert_variable_result_to_variable(var_result: VariableResult, deserialize_json: bool) -> Variable:
40+
from airflow.sdk.definitions.variable import Variable
41+
42+
if deserialize_json:
43+
import json
44+
45+
var_result.value = json.loads(var_result.value) # type: ignore
46+
return Variable(**var_result.model_dump(exclude={"type"}))
47+
48+
3749
def _get_connection(conn_id: str) -> Connection:
3850
# TODO: This should probably be moved to a separate module like `airflow.sdk.execution_time.comms`
3951
# or `airflow.sdk.execution_time.connection`
@@ -54,6 +66,26 @@ def _get_connection(conn_id: str) -> Connection:
5466
return _convert_connection_result_conn(msg)
5567

5668

69+
def _get_variable(key: str, deserialize_json: bool) -> Variable:
70+
# TODO: This should probably be moved to a separate module like `airflow.sdk.execution_time.comms`
71+
# or `airflow.sdk.execution_time.variable`
72+
# A reason to not move it to `airflow.sdk.execution_time.comms` is that it
73+
# will make that module depend on Task SDK, which is not ideal because we intend to
74+
# keep Task SDK as a separate package than execution time mods.
75+
from airflow.sdk.execution_time.comms import ErrorResponse, GetVariable
76+
from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS
77+
78+
log = structlog.get_logger(logger_name="task")
79+
SUPERVISOR_COMMS.send_request(log=log, msg=GetVariable(key=key))
80+
msg = SUPERVISOR_COMMS.get_message()
81+
if isinstance(msg, ErrorResponse):
82+
raise AirflowRuntimeError(msg)
83+
84+
if TYPE_CHECKING:
85+
assert isinstance(msg, VariableResult)
86+
return _convert_variable_result_to_variable(msg, deserialize_json)
87+
88+
5789
class ConnectionAccessor:
5890
"""Wrapper to access Connection entries in template."""
5991

@@ -76,3 +108,30 @@ def get(self, conn_id: str, default_conn: Any = None) -> Any:
76108
if e.error.error == ErrorType.CONNECTION_NOT_FOUND:
77109
return default_conn
78110
raise
111+
112+
113+
class VariableAccessor:
114+
"""Wrapper to access Variable values in template."""
115+
116+
def __init__(self, deserialize_json: bool) -> None:
117+
self._deserialize_json = deserialize_json
118+
119+
def __eq__(self, other):
120+
if not isinstance(other, VariableAccessor):
121+
return False
122+
# All instances of VariableAccessor are equal since it is a stateless dynamic accessor
123+
return True
124+
125+
def __repr__(self) -> str:
126+
return "<VariableAccessor (dynamic access)>"
127+
128+
def __getattr__(self, key: str) -> Any:
129+
return _get_variable(key, self._deserialize_json)
130+
131+
def get(self, key, default_var: Any = NOTSET) -> Any:
132+
try:
133+
return _get_variable(key, self._deserialize_json)
134+
except AirflowRuntimeError as e:
135+
if e.error.error == ErrorType.VARIABLE_NOT_FOUND:
136+
return default_var
137+
raise

task_sdk/src/airflow/sdk/execution_time/supervisor.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@
5858
IntermediateTIState,
5959
TaskInstance,
6060
TerminalTIState,
61+
VariableResponse,
6162
)
6263
from airflow.sdk.execution_time.comms import (
6364
ConnectionResult,
@@ -722,8 +723,11 @@ def _handle_request(self, msg: ToSupervisor, log: FilteringBoundLogger):
722723
resp = conn.model_dump_json().encode()
723724
elif isinstance(msg, GetVariable):
724725
var = self.client.variables.get(msg.key)
725-
var_result = VariableResult.from_variable_response(var)
726-
resp = var_result.model_dump_json().encode()
726+
if isinstance(var, VariableResponse):
727+
var_result = VariableResult.from_variable_response(var)
728+
resp = var_result.model_dump_json(exclude_unset=True).encode()
729+
else:
730+
resp = var.model_dump_json().encode()
727731
elif isinstance(msg, GetXCom):
728732
xcom = self.client.xcoms.get(msg.dag_id, msg.run_id, msg.task_id, msg.key, msg.map_index)
729733
xcom_result = XComResult.from_xcom_response(xcom)

task_sdk/src/airflow/sdk/execution_time/task_runner.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@
4444
ToTask,
4545
XComResult,
4646
)
47-
from airflow.sdk.execution_time.context import ConnectionAccessor
47+
from airflow.sdk.execution_time.context import ConnectionAccessor, VariableAccessor
4848

4949
if TYPE_CHECKING:
5050
from structlog.typing import FilteringBoundLogger as Logger
@@ -85,10 +85,10 @@ def get_template_context(self):
8585
# "prev_end_date_success": get_prev_end_date_success(),
8686
# "test_mode": task_instance.test_mode,
8787
# "triggering_asset_events": lazy_object_proxy.Proxy(get_triggering_events),
88-
# "var": {
89-
# "json": VariableAccessor(deserialize_json=True),
90-
# "value": VariableAccessor(deserialize_json=False),
91-
# },
88+
"var": {
89+
"json": VariableAccessor(deserialize_json=True),
90+
"value": VariableAccessor(deserialize_json=False),
91+
},
9292
"conn": ConnectionAccessor(),
9393
}
9494
if self._ti_context_from_server:

task_sdk/tests/api/test_client.py

+26-10
Original file line numberDiff line numberDiff line change
@@ -362,16 +362,32 @@ def handle_request(request: httpx.Request) -> httpx.Response:
362362

363363
client = make_client(transport=httpx.MockTransport(handle_request))
364364

365-
with pytest.raises(ServerResponseError) as err:
366-
client.variables.get(key="non_existent_var")
367-
368-
assert err.value.response.status_code == 404
369-
assert err.value.detail == {
370-
"detail": {
371-
"message": "Variable with key 'non_existent_var' not found",
372-
"reason": "not_found",
373-
}
374-
}
365+
resp = client.variables.get(key="non_existent_var")
366+
367+
assert isinstance(resp, ErrorResponse)
368+
assert resp.error == ErrorType.VARIABLE_NOT_FOUND
369+
assert resp.detail == {"key": "non_existent_var"}
370+
371+
@mock.patch("time.sleep", return_value=None)
372+
def test_variable_get_500_error(self, mock_sleep):
373+
# Simulate a response from the server returning a 500 error
374+
def handle_request(request: httpx.Request) -> httpx.Response:
375+
if request.url.path == "/variables/test_key":
376+
return httpx.Response(
377+
status_code=500,
378+
headers=[("content-Type", "application/json")],
379+
json={
380+
"reason": "internal_server_error",
381+
"message": "Internal Server Error",
382+
},
383+
)
384+
return httpx.Response(status_code=400, json={"detail": "Bad Request"})
385+
386+
client = make_client(transport=httpx.MockTransport(handle_request))
387+
with pytest.raises(ServerResponseError):
388+
client.variables.get(
389+
key="test_key",
390+
)
375391

376392
def test_variable_set_success(self):
377393
# Simulate a successful response from the server when putting a variable

task_sdk/tests/execution_time/test_context.py

+74-2
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,15 @@
1818
from __future__ import annotations
1919

2020
from airflow.sdk.definitions.connection import Connection
21+
from airflow.sdk.definitions.variable import Variable
2122
from airflow.sdk.exceptions import ErrorType
22-
from airflow.sdk.execution_time.comms import ConnectionResult, ErrorResponse
23-
from airflow.sdk.execution_time.context import ConnectionAccessor, _convert_connection_result_conn
23+
from airflow.sdk.execution_time.comms import ConnectionResult, ErrorResponse, VariableResult
24+
from airflow.sdk.execution_time.context import (
25+
ConnectionAccessor,
26+
VariableAccessor,
27+
_convert_connection_result_conn,
28+
_convert_variable_result_to_variable,
29+
)
2430

2531

2632
def test_convert_connection_result_conn():
@@ -48,6 +54,31 @@ def test_convert_connection_result_conn():
4854
)
4955

5056

57+
def test_convert_variable_result_to_variable():
58+
"""Test that the VariableResult is converted to a Variable object."""
59+
var = VariableResult(
60+
key="test_key",
61+
value="test_value",
62+
)
63+
var = _convert_variable_result_to_variable(var, deserialize_json=False)
64+
assert var == Variable(
65+
key="test_key",
66+
value="test_value",
67+
)
68+
69+
70+
def test_convert_variable_result_to_variable_with_deserialize_json():
71+
"""Test that the VariableResult is converted to a Variable object with deserialize_json set to True."""
72+
var = VariableResult(
73+
key="test_key",
74+
value='{\r\n "key1": "value1",\r\n "key2": "value2",\r\n "enabled": true,\r\n "threshold": 42\r\n}',
75+
)
76+
var = _convert_variable_result_to_variable(var, deserialize_json=True)
77+
assert var == Variable(
78+
key="test_key", value={"key1": "value1", "key2": "value2", "enabled": True, "threshold": 42}
79+
)
80+
81+
5182
class TestConnectionAccessor:
5283
def test_getattr_connection(self, mock_supervisor_comms):
5384
"""
@@ -90,3 +121,44 @@ def test_get_method_with_default(self, mock_supervisor_comms):
90121

91122
conn = accessor.get("nonexistent_conn", default_conn=default_conn)
92123
assert conn == default_conn
124+
125+
126+
class TestVariableAccessor:
127+
def test_getattr_variable(self, mock_supervisor_comms):
128+
"""
129+
Test that the variable is fetched when accessed via __getattr__.
130+
"""
131+
accessor = VariableAccessor(deserialize_json=False)
132+
133+
# Variable from the supervisor / API Server
134+
var_result = VariableResult(key="test_key", value="test_value")
135+
136+
mock_supervisor_comms.get_message.return_value = var_result
137+
138+
# Fetch the variable; triggers __getattr__
139+
var = accessor.test_key
140+
141+
expected_var = Variable(key="test_key", value="test_value")
142+
assert var == expected_var
143+
144+
def test_get_method_valid_variable(self, mock_supervisor_comms):
145+
"""Test that the get method returns the requested variable using `var.get`."""
146+
accessor = VariableAccessor(deserialize_json=False)
147+
var_result = VariableResult(key="test_key", value="test_value")
148+
149+
mock_supervisor_comms.get_message.return_value = var_result
150+
151+
var = accessor.get("test_key")
152+
assert var == Variable(key="test_key", value="test_value")
153+
154+
def test_get_method_with_default(self, mock_supervisor_comms):
155+
"""Test that the get method returns the default variable when the requested variable is not found."""
156+
157+
accessor = VariableAccessor(deserialize_json=False)
158+
default_var = {"default_key": "default_value"}
159+
error_response = ErrorResponse(error=ErrorType.VARIABLE_NOT_FOUND, detail={"test_key": "test_value"})
160+
161+
mock_supervisor_comms.get_message.return_value = error_response
162+
163+
var = accessor.get("nonexistent_var_key", default_var=default_var)
164+
assert var == default_var

0 commit comments

Comments
 (0)