|
18 | 18 | from __future__ import annotations
|
19 | 19 |
|
20 | 20 | from airflow.sdk.definitions.connection import Connection
|
| 21 | +from airflow.sdk.definitions.variable import Variable |
21 | 22 | from airflow.sdk.exceptions import ErrorType
|
22 |
| -from airflow.sdk.execution_time.comms import ConnectionResult, ErrorResponse |
23 |
| -from airflow.sdk.execution_time.context import ConnectionAccessor, _convert_connection_result_conn |
| 23 | +from airflow.sdk.execution_time.comms import ConnectionResult, ErrorResponse, VariableResult |
| 24 | +from airflow.sdk.execution_time.context import ( |
| 25 | + ConnectionAccessor, |
| 26 | + VariableAccessor, |
| 27 | + _convert_connection_result_conn, |
| 28 | + _convert_variable_result_to_variable, |
| 29 | +) |
24 | 30 |
|
25 | 31 |
|
26 | 32 | def test_convert_connection_result_conn():
|
@@ -48,6 +54,31 @@ def test_convert_connection_result_conn():
|
48 | 54 | )
|
49 | 55 |
|
50 | 56 |
|
| 57 | +def test_convert_variable_result_to_variable(): |
| 58 | + """Test that the VariableResult is converted to a Variable object.""" |
| 59 | + var = VariableResult( |
| 60 | + key="test_key", |
| 61 | + value="test_value", |
| 62 | + ) |
| 63 | + var = _convert_variable_result_to_variable(var, deserialize_json=False) |
| 64 | + assert var == Variable( |
| 65 | + key="test_key", |
| 66 | + value="test_value", |
| 67 | + ) |
| 68 | + |
| 69 | + |
| 70 | +def test_convert_variable_result_to_variable_with_deserialize_json(): |
| 71 | + """Test that the VariableResult is converted to a Variable object with deserialize_json set to True.""" |
| 72 | + var = VariableResult( |
| 73 | + key="test_key", |
| 74 | + value='{\r\n "key1": "value1",\r\n "key2": "value2",\r\n "enabled": true,\r\n "threshold": 42\r\n}', |
| 75 | + ) |
| 76 | + var = _convert_variable_result_to_variable(var, deserialize_json=True) |
| 77 | + assert var == Variable( |
| 78 | + key="test_key", value={"key1": "value1", "key2": "value2", "enabled": True, "threshold": 42} |
| 79 | + ) |
| 80 | + |
| 81 | + |
51 | 82 | class TestConnectionAccessor:
|
52 | 83 | def test_getattr_connection(self, mock_supervisor_comms):
|
53 | 84 | """
|
@@ -90,3 +121,44 @@ def test_get_method_with_default(self, mock_supervisor_comms):
|
90 | 121 |
|
91 | 122 | conn = accessor.get("nonexistent_conn", default_conn=default_conn)
|
92 | 123 | assert conn == default_conn
|
| 124 | + |
| 125 | + |
| 126 | +class TestVariableAccessor: |
| 127 | + def test_getattr_variable(self, mock_supervisor_comms): |
| 128 | + """ |
| 129 | + Test that the variable is fetched when accessed via __getattr__. |
| 130 | + """ |
| 131 | + accessor = VariableAccessor(deserialize_json=False) |
| 132 | + |
| 133 | + # Variable from the supervisor / API Server |
| 134 | + var_result = VariableResult(key="test_key", value="test_value") |
| 135 | + |
| 136 | + mock_supervisor_comms.get_message.return_value = var_result |
| 137 | + |
| 138 | + # Fetch the variable; triggers __getattr__ |
| 139 | + var = accessor.test_key |
| 140 | + |
| 141 | + expected_var = Variable(key="test_key", value="test_value") |
| 142 | + assert var == expected_var |
| 143 | + |
| 144 | + def test_get_method_valid_variable(self, mock_supervisor_comms): |
| 145 | + """Test that the get method returns the requested variable using `var.get`.""" |
| 146 | + accessor = VariableAccessor(deserialize_json=False) |
| 147 | + var_result = VariableResult(key="test_key", value="test_value") |
| 148 | + |
| 149 | + mock_supervisor_comms.get_message.return_value = var_result |
| 150 | + |
| 151 | + var = accessor.get("test_key") |
| 152 | + assert var == Variable(key="test_key", value="test_value") |
| 153 | + |
| 154 | + def test_get_method_with_default(self, mock_supervisor_comms): |
| 155 | + """Test that the get method returns the default variable when the requested variable is not found.""" |
| 156 | + |
| 157 | + accessor = VariableAccessor(deserialize_json=False) |
| 158 | + default_var = {"default_key": "default_value"} |
| 159 | + error_response = ErrorResponse(error=ErrorType.VARIABLE_NOT_FOUND, detail={"test_key": "test_value"}) |
| 160 | + |
| 161 | + mock_supervisor_comms.get_message.return_value = error_response |
| 162 | + |
| 163 | + var = accessor.get("nonexistent_var_key", default_var=default_var) |
| 164 | + assert var == default_var |
0 commit comments