Skip to content

Commit

Permalink
Updated webhook unit tests
Browse files Browse the repository at this point in the history
Signed-off-by: Ketan Umare <[email protected]>
  • Loading branch information
kumare3 committed Jan 16, 2025
1 parent 5c1537b commit 8230a88
Show file tree
Hide file tree
Showing 5 changed files with 101 additions and 35 deletions.
27 changes: 17 additions & 10 deletions flytekit/extras/webhook/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from flytekit.models.task import TaskTemplate
from flytekit.utils.dict_formatter import format_dict

from .constants import DATA_KEY, HEADERS_KEY, METHOD_KEY, SHOW_DATA_KEY, SHOW_URL_KEY, TASK_TYPE, URL_KEY
from .constants import DATA_KEY, HEADERS_KEY, METHOD_KEY, SHOW_DATA_KEY, SHOW_URL_KEY, TASK_TYPE, TIMEOUT_SEC, URL_KEY


class WebhookAgent(SyncAgentBase):
Expand All @@ -23,26 +23,31 @@ async def do(
self, task_template: TaskTemplate, output_prefix: str, inputs: Optional[LiteralMap] = None, **kwargs
) -> Resource:
try:
custom_dict = task_template.custom
input_dict = {
"inputs": literal_map_string_repr(inputs),
}

final_dict = format_dict("test", custom_dict, input_dict)
final_dict = self._get_final_dict(task_template, inputs)
return await self._process_webhook(final_dict)
except aiohttp.ClientError as e:
return Resource(phase=TaskExecution.FAILED, message=str(e))

async def _make_http_request(self, method: http.HTTPMethod, url: str, headers: dict, data: dict = None) -> tuple:
def _get_final_dict(self, task_template: TaskTemplate, inputs: LiteralMap) -> dict:
    """
    Build the dictionary that drives the webhook call.

    The task template's ``custom`` config is passed through ``format_dict`` together with a
    mapping that exposes the string representation of ``inputs`` under the ``"inputs"`` key.

    :param task_template: Task template whose ``custom`` dict holds the webhook configuration.
    :param inputs: Literal map of task inputs; converted with ``literal_map_string_repr``.
    :return: The dictionary returned by ``format_dict``.
    """
    # NOTE(review): "test" is forwarded verbatim as the first argument of format_dict;
    # presumably it is only a label -- confirm against format_dict's signature.
    return format_dict(
        "test",
        task_template.custom,
        {"inputs": literal_map_string_repr(inputs)},
    )

async def _make_http_request(
self, method: http.HTTPMethod, url: str, headers: dict, data: dict, timeout: int
) -> tuple:
# TODO: This is a potential performance bottleneck. Consider using a connection pool: create a single session
# object and reuse it across requests, which avoids the overhead of opening a new connection for each request.
# We do not do this yet because local execution has no common event loop; the agent executor (in the mixin)
# creates a new event loop for each request.
async with aiohttp.ClientSession() as session:
async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=timeout)) as session:
if method == http.HTTPMethod.GET:
response = await session.get(url, headers=headers, params=data)
else:
response = await session.post(url, json=data, headers=headers)
print(f"Response status: {response.status}")
text = await response.text()
return response.status, text

Expand Down Expand Up @@ -72,7 +77,9 @@ async def _process_webhook(self, final_dict: dict) -> Resource:
method = http.HTTPMethod(final_dict.get(METHOD_KEY))
show_data = final_dict.get(SHOW_DATA_KEY, False)
show_url = final_dict.get(SHOW_URL_KEY, False)
status, text = await self._make_http_request(method, url, headers, body)
timeout_sec = final_dict.get(TIMEOUT_SEC, 10)

status, text = await self._make_http_request(method, url, headers, body, timeout_sec)
if status != 200:
return Resource(
phase=TaskExecution.FAILED,
Expand Down
15 changes: 8 additions & 7 deletions flytekit/extras/webhook/constants.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
TASK_TYPE = "webhook"
TASK_TYPE: str = "webhook"

URL_KEY = "url"
METHOD_KEY = "method"
HEADERS_KEY = "headers"
DATA_KEY = "data"
SHOW_DATA_KEY = "show_data"
SHOW_URL_KEY = "show_url"
URL_KEY: str = "url"
METHOD_KEY: str = "method"
HEADERS_KEY: str = "headers"
DATA_KEY: str = "data"
SHOW_DATA_KEY: str = "show_data"
SHOW_URL_KEY: str = "show_url"
TIMEOUT_SEC: str = "timeout_sec"
10 changes: 8 additions & 2 deletions flytekit/extras/webhook/task.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
import http
from typing import Any, Dict, Optional, Type
from datetime import timedelta
from typing import Any, Dict, Optional, Type, Union

from flytekit import Documentation
from flytekit.configuration import SerializationSettings
from flytekit.core.base_task import PythonTask
from flytekit.extend.backend.base_agent import SyncAgentExecutorMixin

from ...core.interface import Interface
from .constants import DATA_KEY, HEADERS_KEY, METHOD_KEY, SHOW_DATA_KEY, SHOW_URL_KEY, TASK_TYPE, URL_KEY
from .constants import DATA_KEY, HEADERS_KEY, METHOD_KEY, SHOW_DATA_KEY, SHOW_URL_KEY, TASK_TYPE, TIMEOUT_SEC, URL_KEY


class WebhookTask(SyncAgentExecutorMixin, PythonTask):
Expand All @@ -26,6 +27,7 @@ def __init__(
show_data: bool = False,
show_url: bool = False,
description: Optional[str] = None,
timeout: Union[int, timedelta] = timedelta(seconds=10),
# secret_requests: Optional[List[Secret]] = None, TODO Secret support is coming soon
):
"""
Expand Down Expand Up @@ -65,6 +67,8 @@ def __init__(
:param show_data: If True, the body of the request will be logged in the UI as the output of the task.
:param show_url: If True, the URL of the request will be logged in the UI as the output of the task.
:param description: Description of the task
:param timeout: Total timeout for the request, covering both connection and read. Defaults to 10 seconds.
An int value is interpreted as a number of seconds.
"""
if method not in {http.HTTPMethod.GET, http.HTTPMethod.POST}:
raise ValueError(f"Method should be either GET or POST. Got {method}")
Expand All @@ -86,6 +90,7 @@ def __init__(
self._data = data
self._show_data = show_data
self._show_url = show_url
self._timeout_sec = timeout if isinstance(timeout, int) else timeout.total_seconds()

def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]:
config = {
Expand All @@ -95,5 +100,6 @@ def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]:
DATA_KEY: self._data or {},
SHOW_DATA_KEY: self._show_data,
SHOW_URL_KEY: self._show_url,
TIMEOUT_SEC: self._timeout_sec,
}
return config
48 changes: 46 additions & 2 deletions tests/flytekit/unit/extras/webhook/test_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,11 @@

import pytest

from flytekit.core.context_manager import FlyteContextManager
from flytekit.core.type_engine import TypeEngine
from flytekit.extras.webhook.agent import WebhookAgent
from flytekit.extras.webhook.constants import SHOW_DATA_KEY, DATA_KEY, METHOD_KEY, URL_KEY, HEADERS_KEY, SHOW_URL_KEY
from flytekit.extras.webhook.constants import SHOW_DATA_KEY, DATA_KEY, METHOD_KEY, URL_KEY, HEADERS_KEY, SHOW_URL_KEY, \
TIMEOUT_SEC
from flytekit.models.core.execution import TaskExecutionPhase as TaskExecution
from flytekit.models.literals import LiteralMap
from flytekit.models.task import TaskTemplate
Expand All @@ -19,6 +22,7 @@ def mock_task_template():
DATA_KEY: {"key": "value"},
SHOW_DATA_KEY: True,
SHOW_URL_KEY: True,
TIMEOUT_SEC: 60,
}
return task_template

Expand All @@ -40,8 +44,8 @@ async def test_do_post_success(mock_task_template, mock_aiohttp_session):
result = await agent.do(mock_task_template, output_prefix="", inputs=LiteralMap({}))

assert result.phase == TaskExecution.SUCCEEDED
assert result.outputs["info"]
assert result.outputs["info"]["status_code"] == 200
assert result.outputs["info"]["response_data"] == "Success"
assert result.outputs["info"]["url"] == "http://example.com"


Expand All @@ -61,6 +65,7 @@ async def test_do_get_success(mock_task_template, mock_aiohttp_session):

assert result.phase == TaskExecution.SUCCEEDED
assert result.outputs["info"]["status_code"] == 200
assert result.outputs["info"]["response_data"] == "Success"
assert result.outputs["info"]["url"] == "http://example.com"


Expand All @@ -76,3 +81,42 @@ async def test_do_failure(mock_task_template, mock_aiohttp_session):

assert result.phase == TaskExecution.FAILED
assert "Webhook failed with status code 500" in result.message


def test_conversion_of_inputs():
    """Check the fields ``_build_response`` reports for a successful call.

    NOTE(review): the two trailing ``True`` arguments are presumably the show-data and
    show-url flags; confirm against ``_build_response``'s signature.
    """
    webhook_agent = WebhookAgent()
    response = webhook_agent._build_response(
        200, "Success", {"key": "value"}, "http://example.com", True, True
    )

    expected_fields = {
        "status_code": 200,
        "response_data": "Success",
        "input_data": {"key": "value"},
        "url": "http://example.com",
    }
    # Check each field on its own so extra keys in the response do not fail the test.
    for field_name, expected_value in expected_fields.items():
        assert response[field_name] == expected_value


def test_get_final_dict():
    """Verify that ``_get_final_dict`` fills ``{inputs.<name>}`` placeholders in the custom config.

    The URL, a header value and a data value each reference a different input; every other
    entry is expected to come back unchanged.
    """
    webhook_agent = WebhookAgent()
    literal_inputs = TypeEngine.dict_to_literal_map(
        FlyteContextManager.current_context(),
        {"x": "value_x", "y": "value_y", "z": "value_z"},
    )

    # A spec'd mock stands in for a real TaskTemplate; only its ``custom`` attribute is set here.
    template = MagicMock(spec=TaskTemplate)
    template.custom = {
        URL_KEY: "http://example.com/{inputs.x}",
        METHOD_KEY: "POST",
        HEADERS_KEY: {"Content-Type": "application/json", "Authorization": "{inputs.y}"},
        DATA_KEY: {"key": "{inputs.z}"},
        SHOW_DATA_KEY: True,
        SHOW_URL_KEY: True,
        TIMEOUT_SEC: 60,
    }

    formatted = webhook_agent._get_final_dict(template, literal_inputs)

    assert formatted == {
        "url": "http://example.com/value_x",
        "method": "POST",
        "headers": {"Content-Type": "application/json", "Authorization": "value_y"},
        "data": {"key": "value_z"},
        "show_data": True,
        "show_url": True,
        "timeout_sec": 60,
    }
36 changes: 22 additions & 14 deletions tests/flytekit/unit/extras/webhook/test_task.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
import http
from asyncio import timeout
from datetime import timedelta

from flytekit.configuration import SerializationSettings, ImageConfig
from flytekit.extras.webhook.constants import HEADERS_KEY, URL_KEY, METHOD_KEY, DATA_KEY, SHOW_DATA_KEY, SHOW_URL_KEY, \
TIMEOUT_SEC
from flytekit.extras.webhook.task import WebhookTask


Expand All @@ -10,20 +14,22 @@ def test_webhook_task_constructor():
url="http://example.com",
method=http.HTTPMethod.POST,
headers={"Content-Type": "application/json"},
body={"key": "value"},
show_body=True,
data={"key": "value"},
show_data=True,
show_url=True,
description="Test Webhook Task"
description="Test Webhook Task",
timeout=60,
)

assert task.name == "test_task"
assert task._url == "http://example.com"
assert task._method == http.HTTPMethod.POST
assert task._headers == {"Content-Type": "application/json"}
assert task._body == {"key": "value"}
assert task._show_body is True
assert task._data == {"key": "value"}
assert task._show_data is True
assert task._show_url is True
assert task.docs.short_description == "Test Webhook Task"
assert task._timeout_sec == 60


def test_webhook_task_get_custom():
Expand All @@ -32,17 +38,19 @@ def test_webhook_task_get_custom():
url="http://example.com",
method=http.HTTPMethod.POST,
headers={"Content-Type": "application/json"},
body={"key": "value"},
show_body=True,
show_url=True
data={"key": "value"},
show_data=True,
show_url=True,
timeout=timedelta(minutes=1),
)

settings = SerializationSettings(image_config=ImageConfig.auto_default_image())
custom = task.get_custom(settings)

assert custom["url"] == "http://example.com"
assert custom["method"] == "POST"
assert custom["headers"] == {"Content-Type": "application/json"}
assert custom["body"] == {"key": "value"}
assert custom["show_body"] is True
assert custom["show_url"] is True
assert custom[URL_KEY] == "http://example.com"
assert custom[METHOD_KEY] == "POST"
assert custom[HEADERS_KEY] == {"Content-Type": "application/json"}
assert custom[DATA_KEY] == {"key": "value"}
assert custom[SHOW_DATA_KEY] is True
assert custom[SHOW_URL_KEY] is True
assert custom[TIMEOUT_SEC] == 60

0 comments on commit 8230a88

Please sign in to comment.