diff --git a/packages/faststream-stomp/faststream_stomp/broker.py b/packages/faststream-stomp/faststream_stomp/broker.py index 4158678..6d4563e 100644 --- a/packages/faststream-stomp/faststream_stomp/broker.py +++ b/packages/faststream-stomp/faststream_stomp/broker.py @@ -82,6 +82,7 @@ def __init__( name="stomp", default_context={"channel": ""}, message_id_ln=self.__max_msg_id_ln ), ) + self._attempted_to_connect = False async def start(self) -> None: await super().start() @@ -91,6 +92,9 @@ async def start(self) -> None: await handler.start() async def _connect(self, client: stompman.Client) -> stompman.Client: # type: ignore[override] + if self._attempted_to_connect: + return client + self._attempted_to_connect = True self._producer = StompProducer(client) return await client.__aenter__() diff --git a/packages/faststream-stomp/pyproject.toml b/packages/faststream-stomp/pyproject.toml index fd74a0e..b13182e 100644 --- a/packages/faststream-stomp/pyproject.toml +++ b/packages/faststream-stomp/pyproject.toml @@ -24,7 +24,7 @@ requires = ["hatchling", "hatch-vcs"] build-backend = "hatchling.build" [dependency-groups] -dev = ["faststream[otel,prometheus]~=0.5"] +dev = ["faststream[otel,prometheus]~=0.5", "asgi-lifespan"] [tool.hatch.version] source = "vcs" diff --git a/packages/faststream-stomp/test_faststream_stomp/test_integration.py b/packages/faststream-stomp/test_faststream_stomp/test_integration.py index 33c4bc6..9897d26 100644 --- a/packages/faststream-stomp/test_faststream_stomp/test_integration.py +++ b/packages/faststream-stomp/test_faststream_stomp/test_integration.py @@ -8,7 +8,9 @@ import faststream_stomp import pytest import stompman +from asgi_lifespan import LifespanManager from faststream import BaseMiddleware, Context, FastStream +from faststream.asgi import AsgiFastStream from faststream.broker.message import gen_cor_id from faststream.broker.middlewares.logging import CriticalLogMiddleware from faststream.exceptions import AckMessage, NackMessage, RejectMessage @@ -217,3 +219,9 @@ def some_handler(message_frame: Annotated[stompman.MessageFrame, Context("messag mock.call(logging.ERROR, "MyError: ", extra=extra, exc_info=MyError()), mock.call(logging.INFO, "Processed", extra=extra), ] + + +async def test_broker_connect_twice(broker: faststream_stomp.StompBroker) -> None: + app = AsgiFastStream(broker, on_startup=[broker.connect]) + async with LifespanManager(app): + pass