11from unittest .mock import patch , MagicMock , AsyncMock
22
33import pytest
4+ import httpx
45
56from flytekit .core .context_manager import FlyteContextManager
67from flytekit .core .type_engine import TypeEngine
78from flytekit .extras .webhook .agent import WebhookAgent
8- from flytekit .extras .webhook .constants import SHOW_DATA_KEY , DATA_KEY , METHOD_KEY , URL_KEY , HEADERS_KEY , SHOW_URL_KEY , \
9- TIMEOUT_SEC
9+ from flytekit .extras .webhook .constants import SHOW_DATA_KEY , DATA_KEY , METHOD_KEY , URL_KEY , HEADERS_KEY , SHOW_URL_KEY , TIMEOUT_SEC
1010from flytekit .models .core .execution import TaskExecutionPhase as TaskExecution
1111from flytekit .models .literals import LiteralMap
1212from flytekit .models .task import TaskTemplate
@@ -27,20 +27,15 @@ def mock_task_template():
2727 return task_template
2828
2929
30- @pytest .fixture
31- def mock_aiohttp_session ():
32- with patch ('aiohttp.ClientSession' ) as mock_session :
33- yield mock_session
34-
35-
3630@pytest .mark .asyncio
37- async def test_do_post_success (mock_task_template , mock_aiohttp_session ):
38- mock_response = AsyncMock ()
39- mock_response .status = 200
40- mock_response .text = AsyncMock (return_value = "Success" )
41- mock_aiohttp_session .return_value .post = AsyncMock (return_value = mock_response )
42-
43- agent = WebhookAgent ()
31+ async def test_do_post_success (mock_task_template ):
32+ mock_response = AsyncMock (name = "httpx.Response" )
33+ mock_response .status_code = 200
34+ mock_response .text = "Success"
35+ mock_httpx_client = AsyncMock (name = "httpx.AsyncClient" )
36+ mock_httpx_client .post .return_value = mock_response
37+
38+ agent = WebhookAgent (client = mock_httpx_client )
4439 result = await agent .do (mock_task_template , output_prefix = "" , inputs = LiteralMap ({}))
4540
4641 assert result .phase == TaskExecution .SUCCEEDED
@@ -50,17 +45,18 @@ async def test_do_post_success(mock_task_template, mock_aiohttp_session):
5045
5146
5247@pytest .mark .asyncio
53- async def test_do_get_success (mock_task_template , mock_aiohttp_session ):
48+ async def test_do_get_success (mock_task_template ):
5449 mock_task_template .custom [METHOD_KEY ] = "GET"
5550 mock_task_template .custom .pop (DATA_KEY )
5651 mock_task_template .custom [SHOW_DATA_KEY ] = False
5752
58- mock_response = AsyncMock ()
59- mock_response .status = 200
60- mock_response .text = AsyncMock (return_value = "Success" )
61- mock_aiohttp_session .return_value .get = AsyncMock (return_value = mock_response )
53+ mock_response = AsyncMock (name = "httpx.Response" )
54+ mock_response .status_code = 200
55+ mock_response .text = "Success"
56+ mock_httpx_client = AsyncMock (name = "httpx.AsyncClient" )
57+ mock_httpx_client .get .return_value = mock_response
6258
63- agent = WebhookAgent ()
59+ agent = WebhookAgent (client = mock_httpx_client )
6460 result = await agent .do (mock_task_template , output_prefix = "" , inputs = LiteralMap ({}))
6561
6662 assert result .phase == TaskExecution .SUCCEEDED
@@ -70,13 +66,33 @@ async def test_do_get_success(mock_task_template, mock_aiohttp_session):
7066
7167
7268@pytest .mark .asyncio
73- async def test_do_failure (mock_task_template , mock_aiohttp_session ):
74- mock_response = AsyncMock ()
75- mock_response .status = 500
76- mock_response .text = AsyncMock (return_value = "Internal Server Error" )
77- mock_aiohttp_session .return_value .post = AsyncMock (return_value = mock_response )
69+ async def test_do_failure (mock_task_template ):
70+ mock_response = AsyncMock (name = "httpx.Response" )
71+ mock_response .status_code = 500
72+ mock_response .text = "Internal Server Error"
73+ mock_httpx_client = AsyncMock (name = "httpx.AsyncClient" )
74+ mock_httpx_client .post .return_value = mock_response
75+
76+ agent = WebhookAgent (client = mock_httpx_client )
77+ result = await agent .do (mock_task_template , output_prefix = "" , inputs = LiteralMap ({}))
7878
79- agent = WebhookAgent ()
79+ assert result .phase == TaskExecution .FAILED
80+ assert "Webhook failed with status code 500" in result .message
81+
82+
83+ @pytest .mark .asyncio
84+ async def test_do_get_failure (mock_task_template ):
85+ mock_task_template .custom [METHOD_KEY ] = "GET"
86+ mock_task_template .custom .pop (DATA_KEY )
87+ mock_task_template .custom [SHOW_DATA_KEY ] = False
88+
89+ mock_response = AsyncMock (name = "httpx.Response" )
90+ mock_response .status_code = 500
91+ mock_response .text = "Internal Server Error"
92+ mock_httpx_client = AsyncMock (name = "httpx.AsyncClient" )
93+ mock_httpx_client .get .return_value = mock_response
94+
95+ agent = WebhookAgent (client = mock_httpx_client )
8096 result = await agent .do (mock_task_template , output_prefix = "" , inputs = LiteralMap ({}))
8197
8298 assert result .phase == TaskExecution .FAILED
0 commit comments