Skip to content

Commit 293a704

Browse files
committed
Updated
Signed-off-by: Ketan Umare <[email protected]>
1 parent 3c9a800 commit 293a704

File tree

4 files changed

+125
-62
lines changed

4 files changed

+125
-62
lines changed

flytekit/extras/webhook/agent.py

Lines changed: 24 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import http
22
from typing import Optional
33

4-
import aiohttp
4+
import httpx
55
from flyteidl.core.execution_pb2 import TaskExecution
66

77
from flytekit.extend.backend.base_agent import AgentRegistry, Resource, SyncAgentBase
@@ -14,18 +14,33 @@
1414

1515

1616
class WebhookAgent(SyncAgentBase):
17-
name = "Webhook Agent"
17+
"""
18+
WebhookAgent is responsible for handling webhook tasks.
1819
19-
def __init__(self):
20+
This agent sends HTTP requests based on the task template and inputs provided,
21+
and processes the responses to determine the success or failure of the task.
22+
23+
:param client: An optional HTTP client to use for sending requests.
24+
"""
25+
26+
name: str = "Webhook Agent"
27+
28+
def __init__(self, client: Optional[httpx.AsyncClient] = None):
2029
super().__init__(task_type_name=TASK_TYPE)
30+
self._client = client or httpx.AsyncClient()
2131

2232
async def do(
2333
self, task_template: TaskTemplate, output_prefix: str, inputs: Optional[LiteralMap] = None, **kwargs
2434
) -> Resource:
35+
"""
36+
This method processes the webhook task and sends an HTTP request.
37+
38+
It uses asyncio to send the request and process the response using the httpx library.
39+
"""
2540
try:
2641
final_dict = self._get_final_dict(task_template, inputs)
2742
return await self._process_webhook(final_dict)
28-
except aiohttp.ClientError as e:
43+
except Exception as e:
2944
return Resource(phase=TaskExecution.FAILED, message=str(e))
3045

3146
def _get_final_dict(self, task_template: TaskTemplate, inputs: LiteralMap) -> dict:
@@ -38,16 +53,11 @@ def _get_final_dict(self, task_template: TaskTemplate, inputs: LiteralMap) -> di
3853
async def _make_http_request(
3954
self, method: http.HTTPMethod, url: str, headers: dict, data: dict, timeout: int
4055
) -> tuple:
41-
# TODO This is a potential performance bottleneck. Consider using a connection pool. To do this, we need to
42-
# create a session object and reuse it for multiple requests. This will reduce the overhead of creating a new
43-
# connection for each request. The problem for not doing so is local execution, does not have a common event
44-
# loop and agent executor creates a new event loop for each request (in the mixin).
45-
async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=timeout)) as session:
46-
if method == http.HTTPMethod.GET:
47-
response = await session.get(url, headers=headers, params=data)
48-
else:
49-
response = await session.post(url, json=data, headers=headers)
50-
return response.status, await response.text()
56+
if method == http.HTTPMethod.GET:
57+
response = await self._client.get(url, headers=headers, params=data, timeout=timeout)
58+
else:
59+
response = await self._client.post(url, json=data, headers=headers, timeout=timeout)
60+
return response.status_code, response.text
5161

5262
@staticmethod
5363
def _build_response(

flytekit/extras/webhook/task.py

Lines changed: 57 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -13,27 +13,49 @@
1313

1414
class WebhookTask(SyncAgentExecutorMixin, PythonTask):
1515
"""
16-
This is the simplest form of a BigQuery Task, that can be used even for tasks that do not produce any output.
17-
"""
16+
The WebhookTask is used to invoke a webhook. The webhook can be invoked with a POST or GET method.
17+
18+
All the parameters can be formatted using python format strings.
19+
20+
Example:
21+
```python
22+
simple_get = WebhookTask(
23+
name="simple-get",
24+
url="http://localhost:8000/",
25+
method=http.HTTPMethod.GET,
26+
headers={"Content-Type": "application/json"},
27+
)
28+
29+
get_with_params = WebhookTask(
30+
name="get-with-params",
31+
url="http://localhost:8000/items/{inputs.item_id}",
32+
method=http.HTTPMethod.GET,
33+
headers={"Content-Type": "application/json"},
34+
dynamic_inputs={"s": str, "item_id": int},
35+
show_data=True,
36+
show_url=True,
37+
description="Test Webhook Task",
38+
data={"q": "{inputs.s}"},
39+
)
1840
19-
def __init__(
20-
self,
21-
name: str,
22-
url: str,
23-
method: http.HTTPMethod = http.HTTPMethod.POST,
24-
headers: Optional[Dict[str, str]] = None,
25-
data: Optional[Dict[str, Any]] = None,
26-
dynamic_inputs: Optional[Dict[str, Type]] = None,
27-
show_data: bool = False,
28-
show_url: bool = False,
29-
description: Optional[str] = None,
30-
timeout: Union[int, timedelta] = timedelta(seconds=10),
31-
# secret_requests: Optional[List[Secret]] = None, TODO Secret support is coming soon
32-
):
33-
"""
34-
This task is used to invoke a webhook. The webhook can be invoked with a POST or GET method.
3541
36-
All the parameters can be formatted using python format strings. The following parameters are available for
42+
@fk.workflow
43+
def wf(s: str) -> (dict, dict, dict):
44+
v = hello(s=s)
45+
w = WebhookTask(
46+
name="invoke-slack",
47+
url="https://hooks.slack.com/services/xyz/zaa/aaa",
48+
headers={"Content-Type": "application/json"},
49+
data={"text": "{inputs.s}"},
50+
show_data=True,
51+
show_url=True,
52+
description="Test Webhook Task",
53+
dynamic_inputs={"s": str},
54+
)
55+
return simple_get(), get_with_params(s=v, item_id=10), w(s=v)
56+
```
57+
58+
All the parameters can be formatted using python format strings. The following parameters are available for
3759
formatting:
3860
- dynamic_inputs: These are the dynamic inputs to the task. The keys are the names of the inputs and the values
3961
are the values of the inputs. All inputs are available under the prefix `inputs.`.
@@ -69,7 +91,22 @@ def __init__(
6991
:param description: Description of the task
7092
:param timeout: The timeout for the request (connection and read). Default is 10 seconds. If int value is provided,
7193
it is considered as seconds.
72-
"""
94+
"""
95+
96+
def __init__(
97+
self,
98+
name: str,
99+
url: str,
100+
method: http.HTTPMethod = http.HTTPMethod.POST,
101+
headers: Optional[Dict[str, str]] = None,
102+
data: Optional[Dict[str, Any]] = None,
103+
dynamic_inputs: Optional[Dict[str, Type]] = None,
104+
show_data: bool = False,
105+
show_url: bool = False,
106+
description: Optional[str] = None,
107+
timeout: Union[int, timedelta] = timedelta(seconds=10),
108+
# secret_requests: Optional[List[Secret]] = None, TODO Secret support is coming soon
109+
):
73110
if method not in {http.HTTPMethod.GET, http.HTTPMethod.POST}:
74111
raise ValueError(f"Method should be either GET or POST. Got {method}")
75112

tests/flytekit/unit/extras/webhook/test_agent.py

Lines changed: 43 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
from unittest.mock import patch, MagicMock, AsyncMock
22

33
import pytest
4+
import httpx
45

56
from flytekit.core.context_manager import FlyteContextManager
67
from flytekit.core.type_engine import TypeEngine
78
from flytekit.extras.webhook.agent import WebhookAgent
8-
from flytekit.extras.webhook.constants import SHOW_DATA_KEY, DATA_KEY, METHOD_KEY, URL_KEY, HEADERS_KEY, SHOW_URL_KEY, \
9-
TIMEOUT_SEC
9+
from flytekit.extras.webhook.constants import SHOW_DATA_KEY, DATA_KEY, METHOD_KEY, URL_KEY, HEADERS_KEY, SHOW_URL_KEY, TIMEOUT_SEC
1010
from flytekit.models.core.execution import TaskExecutionPhase as TaskExecution
1111
from flytekit.models.literals import LiteralMap
1212
from flytekit.models.task import TaskTemplate
@@ -27,20 +27,15 @@ def mock_task_template():
2727
return task_template
2828

2929

30-
@pytest.fixture
31-
def mock_aiohttp_session():
32-
with patch('aiohttp.ClientSession') as mock_session:
33-
yield mock_session
34-
35-
3630
@pytest.mark.asyncio
37-
async def test_do_post_success(mock_task_template, mock_aiohttp_session):
38-
mock_response = AsyncMock()
39-
mock_response.status = 200
40-
mock_response.text = AsyncMock(return_value="Success")
41-
mock_aiohttp_session.return_value.post = AsyncMock(return_value=mock_response)
42-
43-
agent = WebhookAgent()
31+
async def test_do_post_success(mock_task_template):
32+
mock_response = AsyncMock(name="httpx.Response")
33+
mock_response.status_code = 200
34+
mock_response.text = "Success"
35+
mock_httpx_client = AsyncMock(name="httpx.AsyncClient")
36+
mock_httpx_client.post.return_value = mock_response
37+
38+
agent = WebhookAgent(client=mock_httpx_client)
4439
result = await agent.do(mock_task_template, output_prefix="", inputs=LiteralMap({}))
4540

4641
assert result.phase == TaskExecution.SUCCEEDED
@@ -50,17 +45,18 @@ async def test_do_post_success(mock_task_template, mock_aiohttp_session):
5045

5146

5247
@pytest.mark.asyncio
53-
async def test_do_get_success(mock_task_template, mock_aiohttp_session):
48+
async def test_do_get_success(mock_task_template):
5449
mock_task_template.custom[METHOD_KEY] = "GET"
5550
mock_task_template.custom.pop(DATA_KEY)
5651
mock_task_template.custom[SHOW_DATA_KEY] = False
5752

58-
mock_response = AsyncMock()
59-
mock_response.status = 200
60-
mock_response.text = AsyncMock(return_value="Success")
61-
mock_aiohttp_session.return_value.get = AsyncMock(return_value=mock_response)
53+
mock_response = AsyncMock(name="httpx.Response")
54+
mock_response.status_code = 200
55+
mock_response.text = "Success"
56+
mock_httpx_client = AsyncMock(name="httpx.AsyncClient")
57+
mock_httpx_client.get.return_value = mock_response
6258

63-
agent = WebhookAgent()
59+
agent = WebhookAgent(client=mock_httpx_client)
6460
result = await agent.do(mock_task_template, output_prefix="", inputs=LiteralMap({}))
6561

6662
assert result.phase == TaskExecution.SUCCEEDED
@@ -70,13 +66,33 @@ async def test_do_get_success(mock_task_template, mock_aiohttp_session):
7066

7167

7268
@pytest.mark.asyncio
73-
async def test_do_failure(mock_task_template, mock_aiohttp_session):
74-
mock_response = AsyncMock()
75-
mock_response.status = 500
76-
mock_response.text = AsyncMock(return_value="Internal Server Error")
77-
mock_aiohttp_session.return_value.post = AsyncMock(return_value=mock_response)
69+
async def test_do_failure(mock_task_template):
70+
mock_response = AsyncMock(name="httpx.Response")
71+
mock_response.status_code = 500
72+
mock_response.text = "Internal Server Error"
73+
mock_httpx_client = AsyncMock(name="httpx.AsyncClient")
74+
mock_httpx_client.post.return_value = mock_response
75+
76+
agent = WebhookAgent(client=mock_httpx_client)
77+
result = await agent.do(mock_task_template, output_prefix="", inputs=LiteralMap({}))
7878

79-
agent = WebhookAgent()
79+
assert result.phase == TaskExecution.FAILED
80+
assert "Webhook failed with status code 500" in result.message
81+
82+
83+
@pytest.mark.asyncio
84+
async def test_do_get_failure(mock_task_template):
85+
mock_task_template.custom[METHOD_KEY] = "GET"
86+
mock_task_template.custom.pop(DATA_KEY)
87+
mock_task_template.custom[SHOW_DATA_KEY] = False
88+
89+
mock_response = AsyncMock(name="httpx.Response")
90+
mock_response.status_code = 500
91+
mock_response.text = "Internal Server Error"
92+
mock_httpx_client = AsyncMock(name="httpx.AsyncClient")
93+
mock_httpx_client.get.return_value = mock_response
94+
95+
agent = WebhookAgent(client=mock_httpx_client)
8096
result = await agent.do(mock_task_template, output_prefix="", inputs=LiteralMap({}))
8197

8298
assert result.phase == TaskExecution.FAILED

tests/flytekit/unit/extras/webhook/test_end_to_end.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import http
2-
from unittest.mock import patch, AsyncMock
2+
from unittest.mock import patch
33

44
import flytekit as fk
55
from flytekit.extras.webhook import WebhookTask

0 commit comments

Comments
 (0)