diff --git a/flytekit/extras/webhook/agent.py b/flytekit/extras/webhook/agent.py index c94cb46226..3b1620f338 100644 --- a/flytekit/extras/webhook/agent.py +++ b/flytekit/extras/webhook/agent.py @@ -1,7 +1,7 @@ import http from typing import Optional -import aiohttp +import httpx from flyteidl.core.execution_pb2 import TaskExecution from flytekit.extend.backend.base_agent import AgentRegistry, Resource, SyncAgentBase @@ -14,18 +14,33 @@ class WebhookAgent(SyncAgentBase): - name = "Webhook Agent" + """ + WebhookAgent is responsible for handling webhook tasks. - def __init__(self): + This agent sends HTTP requests based on the task template and inputs provided, + and processes the responses to determine the success or failure of the task. + + :param client: An optional HTTP client to use for sending requests. + """ + + name: str = "Webhook Agent" + + def __init__(self, client: Optional[httpx.AsyncClient] = None): super().__init__(task_type_name=TASK_TYPE) + self._client = client or httpx.AsyncClient() async def do( self, task_template: TaskTemplate, output_prefix: str, inputs: Optional[LiteralMap] = None, **kwargs ) -> Resource: + """ + This method processes the webhook task and sends an HTTP request. + + It uses asyncio to send the request and process the response using the httpx library. + """ try: final_dict = self._get_final_dict(task_template, inputs) return await self._process_webhook(final_dict) - except aiohttp.ClientError as e: + except Exception as e: return Resource(phase=TaskExecution.FAILED, message=str(e)) def _get_final_dict(self, task_template: TaskTemplate, inputs: LiteralMap) -> dict: @@ -38,16 +53,11 @@ def _get_final_dict(self, task_template: TaskTemplate, inputs: LiteralMap) -> di async def _make_http_request( self, method: http.HTTPMethod, url: str, headers: dict, data: dict, timeout: int ) -> tuple: - # TODO This is a potential performance bottleneck. Consider using a connection pool. 
To do this, we need to - # create a session object and reuse it for multiple requests. This will reduce the overhead of creating a new - # connection for each request. The problem for not doing so is local execution, does not have a common event - # loop and agent executor creates a new event loop for each request (in the mixin). - async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=timeout)) as session: - if method == http.HTTPMethod.GET: - response = await session.get(url, headers=headers, params=data) - else: - response = await session.post(url, json=data, headers=headers) - return response.status, await response.text() + if method == http.HTTPMethod.GET: + response = await self._client.get(url, headers=headers, params=data, timeout=timeout) + else: + response = await self._client.post(url, json=data, headers=headers, timeout=timeout) + return response.status_code, response.text @staticmethod def _build_response( diff --git a/flytekit/extras/webhook/task.py b/flytekit/extras/webhook/task.py index 46600962e4..e6eceb6f81 100644 --- a/flytekit/extras/webhook/task.py +++ b/flytekit/extras/webhook/task.py @@ -13,27 +13,49 @@ class WebhookTask(SyncAgentExecutorMixin, PythonTask): """ - This is the simplest form of a BigQuery Task, that can be used even for tasks that do not produce any output. - """ + The WebhookTask is used to invoke a webhook. The webhook can be invoked with a POST or GET method. + + All the parameters can be formatted using python format strings. 
+ + Example: + ```python + simple_get = WebhookTask( + name="simple-get", + url="http://localhost:8000/", + method=http.HTTPMethod.GET, + headers={"Content-Type": "application/json"}, + ) + + get_with_params = WebhookTask( + name="get-with-params", + url="http://localhost:8000/items/{inputs.item_id}", + method=http.HTTPMethod.GET, + headers={"Content-Type": "application/json"}, + dynamic_inputs={"s": str, "item_id": int}, + show_data=True, + show_url=True, + description="Test Webhook Task", + data={"q": "{inputs.s}"}, + ) - def __init__( - self, - name: str, - url: str, - method: http.HTTPMethod = http.HTTPMethod.POST, - headers: Optional[Dict[str, str]] = None, - data: Optional[Dict[str, Any]] = None, - dynamic_inputs: Optional[Dict[str, Type]] = None, - show_data: bool = False, - show_url: bool = False, - description: Optional[str] = None, - timeout: Union[int, timedelta] = timedelta(seconds=10), - # secret_requests: Optional[List[Secret]] = None, TODO Secret support is coming soon - ): - """ - This task is used to invoke a webhook. The webhook can be invoked with a POST or GET method. - All the parameters can be formatted using python format strings. The following parameters are available for + @fk.workflow + def wf(s: str) -> (dict, dict, dict): + v = hello(s=s) + w = WebhookTask( + name="invoke-slack", + url="https://hooks.slack.com/services/T00000000/B00000000/XXXXXXXXXXXXXXXXXXXXXXXX", + headers={"Content-Type": "application/json"}, + data={"text": "{inputs.s}"}, + show_data=True, + show_url=True, + description="Test Webhook Task", + dynamic_inputs={"s": str}, + ) + return simple_get(), get_with_params(s=v, item_id=10), w(s=v) + ``` + + All the parameters can be formatted using python format strings. The following parameters are available for formatting: - dynamic_inputs: These are the dynamic inputs to the task. The keys are the names of the inputs and the values are the values of the inputs. All inputs are available under the prefix `inputs.`. 
@@ -69,7 +91,22 @@ def __init__( :param description: Description of the task :param timeout: The timeout for the request (connection and read). Default is 10 seconds. If int value is provided, it is considered as seconds. - """ + """ + + def __init__( + self, + name: str, + url: str, + method: http.HTTPMethod = http.HTTPMethod.POST, + headers: Optional[Dict[str, str]] = None, + data: Optional[Dict[str, Any]] = None, + dynamic_inputs: Optional[Dict[str, Type]] = None, + show_data: bool = False, + show_url: bool = False, + description: Optional[str] = None, + timeout: Union[int, timedelta] = timedelta(seconds=10), + # secret_requests: Optional[List[Secret]] = None, TODO Secret support is coming soon + ): if method not in {http.HTTPMethod.GET, http.HTTPMethod.POST}: raise ValueError(f"Method should be either GET or POST. Got {method}") diff --git a/tests/flytekit/unit/extras/webhook/test_agent.py b/tests/flytekit/unit/extras/webhook/test_agent.py index 38484bceb1..a2bbd408b5 100644 --- a/tests/flytekit/unit/extras/webhook/test_agent.py +++ b/tests/flytekit/unit/extras/webhook/test_agent.py @@ -1,12 +1,12 @@ from unittest.mock import patch, MagicMock, AsyncMock import pytest +import httpx from flytekit.core.context_manager import FlyteContextManager from flytekit.core.type_engine import TypeEngine from flytekit.extras.webhook.agent import WebhookAgent -from flytekit.extras.webhook.constants import SHOW_DATA_KEY, DATA_KEY, METHOD_KEY, URL_KEY, HEADERS_KEY, SHOW_URL_KEY, \ - TIMEOUT_SEC +from flytekit.extras.webhook.constants import SHOW_DATA_KEY, DATA_KEY, METHOD_KEY, URL_KEY, HEADERS_KEY, SHOW_URL_KEY, TIMEOUT_SEC from flytekit.models.core.execution import TaskExecutionPhase as TaskExecution from flytekit.models.literals import LiteralMap from flytekit.models.task import TaskTemplate @@ -27,20 +27,15 @@ def mock_task_template(): return task_template -@pytest.fixture -def mock_aiohttp_session(): - with patch('aiohttp.ClientSession') as mock_session: - yield mock_session 
- - @pytest.mark.asyncio -async def test_do_post_success(mock_task_template, mock_aiohttp_session): - mock_response = AsyncMock() - mock_response.status = 200 - mock_response.text = AsyncMock(return_value="Success") - mock_aiohttp_session.return_value.post = AsyncMock(return_value=mock_response) - - agent = WebhookAgent() +async def test_do_post_success(mock_task_template): + mock_response = AsyncMock(name="httpx.Response") + mock_response.status_code = 200 + mock_response.text = "Success" + mock_httpx_client = AsyncMock(name="httpx.AsyncClient") + mock_httpx_client.post.return_value = mock_response + + agent = WebhookAgent(client=mock_httpx_client) result = await agent.do(mock_task_template, output_prefix="", inputs=LiteralMap({})) assert result.phase == TaskExecution.SUCCEEDED @@ -50,17 +45,18 @@ async def test_do_post_success(mock_task_template, mock_aiohttp_session): @pytest.mark.asyncio -async def test_do_get_success(mock_task_template, mock_aiohttp_session): +async def test_do_get_success(mock_task_template): mock_task_template.custom[METHOD_KEY] = "GET" mock_task_template.custom.pop(DATA_KEY) mock_task_template.custom[SHOW_DATA_KEY] = False - mock_response = AsyncMock() - mock_response.status = 200 - mock_response.text = AsyncMock(return_value="Success") - mock_aiohttp_session.return_value.get = AsyncMock(return_value=mock_response) + mock_response = AsyncMock(name="httpx.Response") + mock_response.status_code = 200 + mock_response.text = "Success" + mock_httpx_client = AsyncMock(name="httpx.AsyncClient") + mock_httpx_client.get.return_value = mock_response - agent = WebhookAgent() + agent = WebhookAgent(client=mock_httpx_client) result = await agent.do(mock_task_template, output_prefix="", inputs=LiteralMap({})) assert result.phase == TaskExecution.SUCCEEDED @@ -70,13 +66,33 @@ async def test_do_get_success(mock_task_template, mock_aiohttp_session): @pytest.mark.asyncio -async def test_do_failure(mock_task_template, mock_aiohttp_session): - mock_response = 
AsyncMock() - mock_response.status = 500 - mock_response.text = AsyncMock(return_value="Internal Server Error") - mock_aiohttp_session.return_value.post = AsyncMock(return_value=mock_response) +async def test_do_failure(mock_task_template): + mock_response = AsyncMock(name="httpx.Response") + mock_response.status_code = 500 + mock_response.text = "Internal Server Error" + mock_httpx_client = AsyncMock(name="httpx.AsyncClient") + mock_httpx_client.post.return_value = mock_response + + agent = WebhookAgent(client=mock_httpx_client) + result = await agent.do(mock_task_template, output_prefix="", inputs=LiteralMap({})) - agent = WebhookAgent() + assert result.phase == TaskExecution.FAILED + assert "Webhook failed with status code 500" in result.message + + +@pytest.mark.asyncio +async def test_do_get_failure(mock_task_template): + mock_task_template.custom[METHOD_KEY] = "GET" + mock_task_template.custom.pop(DATA_KEY) + mock_task_template.custom[SHOW_DATA_KEY] = False + + mock_response = AsyncMock(name="httpx.Response") + mock_response.status_code = 500 + mock_response.text = "Internal Server Error" + mock_httpx_client = AsyncMock(name="httpx.AsyncClient") + mock_httpx_client.get.return_value = mock_response + + agent = WebhookAgent(client=mock_httpx_client) result = await agent.do(mock_task_template, output_prefix="", inputs=LiteralMap({})) assert result.phase == TaskExecution.FAILED diff --git a/tests/flytekit/unit/extras/webhook/test_end_to_end.py b/tests/flytekit/unit/extras/webhook/test_end_to_end.py index baf623cbdf..80cd71ae93 100644 --- a/tests/flytekit/unit/extras/webhook/test_end_to_end.py +++ b/tests/flytekit/unit/extras/webhook/test_end_to_end.py @@ -1,5 +1,5 @@ import http -from unittest.mock import patch, AsyncMock +from unittest.mock import patch import flytekit as fk from flytekit.extras.webhook import WebhookTask