Skip to content

Commit d286b77

Browse files
authored
Add MessageEvent.await_with_auto_ack() (#8)
1 parent d1c3a36 commit d286b77

File tree

3 files changed

+81
-10
lines changed

3 files changed

+81
-10
lines changed

README.md

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -81,30 +81,32 @@ async for event in client.listen_to_events():
8181
print(f"{short_description}:\n{body!s}")
8282
```
8383

84-
More complex example, that involves handling all possible events:
84+
More complex example, that involves handling all possible events, and auto-acknowledgement:
8585

8686
```python
8787
async with asyncio.TaskGroup() as task_group:
8888
async for event in client.listen_to_events():
8989
match event:
9090
case stompman.MessageEvent(body=body):
91-
# Validate message ASAP and ack/nack, so that server won't assume we're not reliable
92-
try:
93-
validated_message = MyMessageModel.model_validate_json(body)
94-
except ValidationError:
95-
await event.nack()
96-
raise
97-
98-
await event.ack()
99-
task_group.create_task(run_business_logic(validated_message))
91+
task_group.create_task(event.await_with_auto_ack(handle_message(body)))
10092
case stompman.ErrorEvent(message_header=short_description, body=body):
10193
logger.error(
10294
"Received an error from server", short_description=short_description, body=body, event=event
10395
)
10496
case stompman.HeartbeatEvent():
10597
task_group.create_task(update_healthcheck_status())
98+
99+
100+
async def handle_message(body: bytes) -> None:
101+
try:
102+
validated_message = MyMessageModel.model_validate_json(body)
103+
await run_business_logic(validated_message)
104+
except Exception:
105+
logger.exception("Failed to handle message", body=body)
106106
```
107107

108+
You can pass awaitable object (coroutine, for example) to `Message.await_with_auto_ack()`. In case of error, it will catch any exceptions, send NACK to server and propagate them to the caller. Otherwise, it will send ACK, acknowledging the message was processed successfully.
109+
108110
### Cleaning Up
109111

110112
stompman takes care of cleaning up resources automatically. When you leave the context of async context managers `stompman.Client()`, `client.subscribe()`, or `client.enter_transaction()`, the necessary frames will be sent to the server.

stompman/listening_events.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from collections.abc import Awaitable
12
from dataclasses import dataclass, field
23
from typing import TYPE_CHECKING
34

@@ -36,6 +37,21 @@ async def nack(self) -> None:
3637
)
3738
)
3839

40+
async def await_with_auto_ack(
41+
self, awaitable: Awaitable[None], exception_types: tuple[type[Exception],] = (Exception,)
42+
) -> None:
43+
called_nack = False
44+
45+
try:
46+
await awaitable
47+
except exception_types:
48+
await self.nack()
49+
called_nack = True
50+
raise
51+
finally:
52+
if not called_nack:
53+
await self.ack()
54+
3955

4056
@dataclass
4157
class ErrorEvent:

tests/test_client.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -346,6 +346,59 @@ async def test_ack_nack() -> None:
346346
assert_frames_between_lifespan_match(collected_frames, [message_frame, nack_frame, ack_frame])
347347

348348

349+
def get_mocked_message_event() -> tuple[MessageEvent, mock.AsyncMock, mock.AsyncMock]:
350+
ack_mock, nack_mock = mock.AsyncMock(), mock.AsyncMock()
351+
352+
class CustomMessageEvent(MessageEvent):
353+
ack = ack_mock
354+
nack = nack_mock
355+
356+
return (
357+
CustomMessageEvent(
358+
_frame=MessageFrame(
359+
headers={"destination": "destination", "message-id": "message-id", "subscription": "subscription"},
360+
body=b"",
361+
),
362+
_client=mock.Mock(),
363+
),
364+
ack_mock,
365+
nack_mock,
366+
)
367+
368+
369+
async def test_message_event_await_with_auto_ack_nack() -> None:
370+
event, ack, nack = get_mocked_message_event()
371+
372+
async def raises_runtime_error() -> None: # noqa: RUF029
373+
raise RuntimeError
374+
375+
with suppress(RuntimeError):
376+
await event.await_with_auto_ack(raises_runtime_error(), exception_types=(Exception,))
377+
378+
ack.assert_not_called()
379+
nack.assert_called_once_with()
380+
381+
382+
async def test_message_event_await_with_auto_ack_ack_raises() -> None:
383+
event, ack, nack = get_mocked_message_event()
384+
385+
async def func() -> None: # noqa: RUF029
386+
raise Exception # noqa: TRY002
387+
388+
with suppress(Exception):
389+
await event.await_with_auto_ack(func(), exception_types=(RuntimeError,))
390+
391+
ack.assert_called_once_with()
392+
nack.assert_not_called()
393+
394+
395+
async def test_message_event_await_with_auto_ack_ack_ok() -> None:
396+
event, ack, nack = get_mocked_message_event()
397+
await event.await_with_auto_ack(mock.AsyncMock()())
398+
ack.assert_called_once_with()
399+
nack.assert_not_called()
400+
401+
349402
async def test_send_message_and_enter_transaction_ok(monkeypatch: pytest.MonkeyPatch) -> None:
350403
body = b"hello"
351404
destination = "/queue/test"

0 commit comments

Comments
 (0)