Skip to content

Commit

Permalink
Updated
Browse files Browse the repository at this point in the history
Signed-off-by: Ketan Umare <[email protected]>
  • Loading branch information
kumare3 committed Feb 1, 2025
1 parent 3c9a800 commit 293a704
Show file tree
Hide file tree
Showing 4 changed files with 125 additions and 62 deletions.
38 changes: 24 additions & 14 deletions flytekit/extras/webhook/agent.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import http
from typing import Optional

import aiohttp
import httpx
from flyteidl.core.execution_pb2 import TaskExecution

from flytekit.extend.backend.base_agent import AgentRegistry, Resource, SyncAgentBase
Expand All @@ -14,18 +14,33 @@


class WebhookAgent(SyncAgentBase):
name = "Webhook Agent"
"""
WebhookAgent is responsible for handling webhook tasks.
def __init__(self):
This agent sends HTTP requests based on the task template and inputs provided,
and processes the responses to determine the success or failure of the task.
:param client: An optional HTTP client to use for sending requests.
"""

name: str = "Webhook Agent"

def __init__(self, client: Optional[httpx.AsyncClient] = None):
super().__init__(task_type_name=TASK_TYPE)
self._client = client or httpx.AsyncClient()

async def do(
self, task_template: TaskTemplate, output_prefix: str, inputs: Optional[LiteralMap] = None, **kwargs
) -> Resource:
"""
This method processes the webhook task by sending an HTTP request asynchronously with the httpx client.
Any exception raised along the way is caught and reported as a FAILED resource instead of being propagated.
"""
try:
final_dict = self._get_final_dict(task_template, inputs)
return await self._process_webhook(final_dict)
except aiohttp.ClientError as e:
except Exception as e:
return Resource(phase=TaskExecution.FAILED, message=str(e))

def _get_final_dict(self, task_template: TaskTemplate, inputs: LiteralMap) -> dict:
Expand All @@ -38,16 +53,11 @@ def _get_final_dict(self, task_template: TaskTemplate, inputs: LiteralMap) -> di
async def _make_http_request(
self, method: http.HTTPMethod, url: str, headers: dict, data: dict, timeout: int
) -> tuple:
# TODO This is a potential performance bottleneck. Consider using a connection pool. To do this, we need to
# create a session object and reuse it for multiple requests. This will reduce the overhead of creating a new
# connection for each request. The problem for not doing so is local execution, does not have a common event
# loop and agent executor creates a new event loop for each request (in the mixin).
async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=timeout)) as session:
if method == http.HTTPMethod.GET:
response = await session.get(url, headers=headers, params=data)
else:
response = await session.post(url, json=data, headers=headers)
return response.status, await response.text()
if method == http.HTTPMethod.GET:
response = await self._client.get(url, headers=headers, params=data, timeout=timeout)
else:
response = await self._client.post(url, json=data, headers=headers, timeout=timeout)
return response.status_code, response.text

@staticmethod
def _build_response(
Expand Down
77 changes: 57 additions & 20 deletions flytekit/extras/webhook/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,27 +13,49 @@

class WebhookTask(SyncAgentExecutorMixin, PythonTask):
"""
This is the simplest form of a BigQuery Task, that can be used even for tasks that do not produce any output.
"""
The WebhookTask is used to invoke a webhook. The webhook can be invoked with a POST or GET method.
All the parameters can be formatted using Python format strings.
Example:
```python
simple_get = WebhookTask(
name="simple-get",
url="http://localhost:8000/",
method=http.HTTPMethod.GET,
headers={"Content-Type": "application/json"},
)
get_with_params = WebhookTask(
name="get-with-params",
url="http://localhost:8000/items/{inputs.item_id}",
method=http.HTTPMethod.GET,
headers={"Content-Type": "application/json"},
dynamic_inputs={"s": str, "item_id": int},
show_data=True,
show_url=True,
description="Test Webhook Task",
data={"q": "{inputs.s}"},
)
def __init__(
self,
name: str,
url: str,
method: http.HTTPMethod = http.HTTPMethod.POST,
headers: Optional[Dict[str, str]] = None,
data: Optional[Dict[str, Any]] = None,
dynamic_inputs: Optional[Dict[str, Type]] = None,
show_data: bool = False,
show_url: bool = False,
description: Optional[str] = None,
timeout: Union[int, timedelta] = timedelta(seconds=10),
# secret_requests: Optional[List[Secret]] = None, TODO Secret support is coming soon
):
"""
This task is used to invoke a webhook. The webhook can be invoked with a POST or GET method.
All the parameters can be formatted using python format strings. The following parameters are available for
@fk.workflow
def wf(s: str) -> (dict, dict, dict):
v = hello(s=s)
w = WebhookTask(
name="invoke-slack",
url="https://hooks.slack.com/services/xyz/zaa/aaa",
headers={"Content-Type": "application/json"},
data={"text": "{inputs.s}"},
show_data=True,
show_url=True,
description="Test Webhook Task",
dynamic_inputs={"s": str},
)
return simple_get(), get_with_params(s=v, item_id=10), w(s=v)
```
All the parameters can be formatted using python format strings. The following parameters are available for
formatting:
- dynamic_inputs: These are the dynamic inputs to the task. The keys are the names of the inputs and the values
are the types of the inputs. At run time, the input values are available under the prefix `inputs.`.
Expand Down Expand Up @@ -69,7 +91,22 @@ def __init__(
:param description: Description of the task
:param timeout: The timeout for the request (connection and read). Default is 10 seconds. If an int value is provided,
it is interpreted as seconds.
"""
"""

def __init__(
self,
name: str,
url: str,
method: http.HTTPMethod = http.HTTPMethod.POST,
headers: Optional[Dict[str, str]] = None,
data: Optional[Dict[str, Any]] = None,
dynamic_inputs: Optional[Dict[str, Type]] = None,
show_data: bool = False,
show_url: bool = False,
description: Optional[str] = None,
timeout: Union[int, timedelta] = timedelta(seconds=10),
# secret_requests: Optional[List[Secret]] = None, TODO Secret support is coming soon
):
if method not in {http.HTTPMethod.GET, http.HTTPMethod.POST}:
raise ValueError(f"Method should be either GET or POST. Got {method}")

Expand Down
70 changes: 43 additions & 27 deletions tests/flytekit/unit/extras/webhook/test_agent.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
from unittest.mock import patch, MagicMock, AsyncMock

import pytest
import httpx

from flytekit.core.context_manager import FlyteContextManager
from flytekit.core.type_engine import TypeEngine
from flytekit.extras.webhook.agent import WebhookAgent
from flytekit.extras.webhook.constants import SHOW_DATA_KEY, DATA_KEY, METHOD_KEY, URL_KEY, HEADERS_KEY, SHOW_URL_KEY, \
TIMEOUT_SEC
from flytekit.extras.webhook.constants import SHOW_DATA_KEY, DATA_KEY, METHOD_KEY, URL_KEY, HEADERS_KEY, SHOW_URL_KEY, TIMEOUT_SEC
from flytekit.models.core.execution import TaskExecutionPhase as TaskExecution
from flytekit.models.literals import LiteralMap
from flytekit.models.task import TaskTemplate
Expand All @@ -27,20 +27,15 @@ def mock_task_template():
return task_template


@pytest.fixture
def mock_aiohttp_session():
with patch('aiohttp.ClientSession') as mock_session:
yield mock_session


@pytest.mark.asyncio
async def test_do_post_success(mock_task_template, mock_aiohttp_session):
mock_response = AsyncMock()
mock_response.status = 200
mock_response.text = AsyncMock(return_value="Success")
mock_aiohttp_session.return_value.post = AsyncMock(return_value=mock_response)

agent = WebhookAgent()
async def test_do_post_success(mock_task_template):
mock_response = AsyncMock(name="httpx.Response")
mock_response.status_code = 200
mock_response.text = "Success"
mock_httpx_client = AsyncMock(name="httpx.AsyncClient")
mock_httpx_client.post.return_value = mock_response

agent = WebhookAgent(client=mock_httpx_client)
result = await agent.do(mock_task_template, output_prefix="", inputs=LiteralMap({}))

assert result.phase == TaskExecution.SUCCEEDED
Expand All @@ -50,17 +45,18 @@ async def test_do_post_success(mock_task_template, mock_aiohttp_session):


@pytest.mark.asyncio
async def test_do_get_success(mock_task_template, mock_aiohttp_session):
async def test_do_get_success(mock_task_template):
mock_task_template.custom[METHOD_KEY] = "GET"
mock_task_template.custom.pop(DATA_KEY)
mock_task_template.custom[SHOW_DATA_KEY] = False

mock_response = AsyncMock()
mock_response.status = 200
mock_response.text = AsyncMock(return_value="Success")
mock_aiohttp_session.return_value.get = AsyncMock(return_value=mock_response)
mock_response = AsyncMock(name="httpx.Response")
mock_response.status_code = 200
mock_response.text = "Success"
mock_httpx_client = AsyncMock(name="httpx.AsyncClient")
mock_httpx_client.get.return_value = mock_response

agent = WebhookAgent()
agent = WebhookAgent(client=mock_httpx_client)
result = await agent.do(mock_task_template, output_prefix="", inputs=LiteralMap({}))

assert result.phase == TaskExecution.SUCCEEDED
Expand All @@ -70,13 +66,33 @@ async def test_do_get_success(mock_task_template, mock_aiohttp_session):


@pytest.mark.asyncio
async def test_do_failure(mock_task_template, mock_aiohttp_session):
mock_response = AsyncMock()
mock_response.status = 500
mock_response.text = AsyncMock(return_value="Internal Server Error")
mock_aiohttp_session.return_value.post = AsyncMock(return_value=mock_response)
async def test_do_failure(mock_task_template):
mock_response = AsyncMock(name="httpx.Response")
mock_response.status_code = 500
mock_response.text = "Internal Server Error"
mock_httpx_client = AsyncMock(name="httpx.AsyncClient")
mock_httpx_client.post.return_value = mock_response

agent = WebhookAgent(client=mock_httpx_client)
result = await agent.do(mock_task_template, output_prefix="", inputs=LiteralMap({}))

agent = WebhookAgent()
assert result.phase == TaskExecution.FAILED
assert "Webhook failed with status code 500" in result.message


@pytest.mark.asyncio
async def test_do_get_failure(mock_task_template):
mock_task_template.custom[METHOD_KEY] = "GET"
mock_task_template.custom.pop(DATA_KEY)
mock_task_template.custom[SHOW_DATA_KEY] = False

mock_response = AsyncMock(name="httpx.Response")
mock_response.status_code = 500
mock_response.text = "Internal Server Error"
mock_httpx_client = AsyncMock(name="httpx.AsyncClient")
mock_httpx_client.get.return_value = mock_response

agent = WebhookAgent(client=mock_httpx_client)
result = await agent.do(mock_task_template, output_prefix="", inputs=LiteralMap({}))

assert result.phase == TaskExecution.FAILED
Expand Down
2 changes: 1 addition & 1 deletion tests/flytekit/unit/extras/webhook/test_end_to_end.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import http
from unittest.mock import patch, AsyncMock
from unittest.mock import patch

import flytekit as fk
from flytekit.extras.webhook import WebhookTask
Expand Down

0 comments on commit 293a704

Please sign in to comment.