|
33 | 33 | from flytekit.exceptions import user as user_exceptions |
34 | 34 | from flytekit.exceptions.base import FlyteException |
35 | 35 | from flytekit.exceptions.scopes import system_entry_point |
36 | | -from flytekit.exceptions.user import FlyteRecoverableException, FlyteUserRuntimeException |
| 36 | +from flytekit.exceptions.user import FlyteRecoverableException, FlyteUserRuntimeException, FlyteUserException |
37 | 37 | from flytekit.models import literals as _literal_models |
38 | 38 | from flytekit.models.core import errors as error_models, execution |
39 | 39 | from flytekit.models.core import execution as execution_models |
@@ -128,19 +128,24 @@ def verify_output(*args, **kwargs): |
128 | 128 | _dispatch_execute(ctx, lambda: python_task, "inputs path", "outputs prefix") |
129 | 129 | assert mock_write_to_file.call_count == 1 |
130 | 130 |
|
| 131 | +class CustomException(FlyteUserException): |
| 132 | + _ERROR_CODE = "USER:CustomError" |
| 133 | + |
131 | 134 | @pytest.mark.parametrize( |
132 | 135 | "exception_value", |
133 | 136 | [ |
134 | | - FlyteException("exception", timestamp=1), |
135 | | - FlyteException("exception"), |
136 | | - Exception("exception"), |
| 137 | + [FlyteException("exception", timestamp=1), FlyteException.error_code], |
| 138 | + [FlyteException("exception"), FlyteException.error_code], |
| 139 | + [Exception("exception"), FlyteUserRuntimeException.error_code], |
| 140 | + [CustomException("exception"), CustomException.error_code], |
137 | 141 | ] |
138 | 142 | ) |
139 | 143 | @mock.patch("flytekit.core.utils.load_proto_from_file") |
140 | 144 | @mock.patch("flytekit.core.data_persistence.FileAccessProvider.get_data") |
141 | 145 | @mock.patch("flytekit.core.data_persistence.FileAccessProvider.put_data") |
142 | 146 | @mock.patch("flytekit.core.utils.write_proto_to_file") |
143 | | -def test_dispatch_execute_exception_with_multi_error_files(mock_write_to_file, mock_upload_dir, mock_get_data, mock_load_proto, exception_value: Exception, monkeypatch): |
| 147 | +def test_dispatch_execute_exception_with_multi_error_files(mock_write_to_file, mock_upload_dir, mock_get_data, mock_load_proto, exception_value: typing.Tuple[Exception, str], monkeypatch): |
| 148 | + exception_value, error_code = exception_value |
144 | 149 | monkeypatch.setenv("_F_DES", "1") |
145 | 150 | monkeypatch.setenv("_F_WN", "worker") |
146 | 151 |
|
@@ -170,7 +175,7 @@ def verify_output(*args, **kwargs): |
170 | 175 | assert error_filename_base.startswith("error-") |
171 | 176 | uuid.UUID(hex=error_filename_base[6:], version=4) |
172 | 177 | assert error_filename_ext == ".pb" |
173 | | - assert container_error.code == "USER:RuntimeError" |
| 178 | + assert container_error.code == error_code |
174 | 179 |
|
175 | 180 | mock_write_to_file.side_effect = verify_output |
176 | 181 | _dispatch_execute(ctx, lambda: python_task, "inputs path", "outputs prefix") |
|
0 commit comments