Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
112 changes: 112 additions & 0 deletions aio_pika/robust_connection_rrhost.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
from typing import List, Optional, Any
from yarl import URL
from aio_pika.robust_connection import RobustConnection
from aio_pika.connection import make_url
from urllib.parse import urlparse
from .log import get_logger

log = get_logger(__name__)


class RobustConnectionRRHost:
"""
Robust AMQP connection with round-robin host selection.

This class manages a single RobustConnection instance internally,
cycling through provided URLs until a successful connection is made.
"""

def __init__(
self, urls: List[str], default_port: int = 5672, **kwargs: Any
):
"""
Initialize with a list of broker URLs, normalizing and applying
default port if missing.

:param urls: List of AMQP broker URLs as strings.
:param default_port: Default port used if not specified in URLs.
:param kwargs: Additional arguments passed to RobustConnection.
"""
self.urls: List[URL] = []
for url in urls:
parsed = urlparse(url)
if not parsed.scheme:
url = f"amqp://{url}"
url_obj = make_url(url)
if not url_obj.host:
raise ValueError(f"Host missing in URL {url_obj}")
if url_obj.port is None:
url_obj = URL.build(
scheme=url_obj.scheme,
user=url_obj.user,
password=url_obj.password,
host=url_obj.host,
port=default_port,
path=url_obj.path,
query=url_obj.query,
fragment=url_obj.fragment,
)
self.urls.append(url_obj)
self._current_index = 0
self._kwargs = kwargs
self._connection: Optional[RobustConnection] = None
self._connect_timeout: Optional[float] = None

async def connect(self, timeout: Optional[float] = None) -> None:
"""
Attempt to connect to one of the provided URLs in round-robin order.
"""
self._connect_timeout = timeout
last_exc: Optional[Exception] = None
for _ in range(len(self.urls)):
url = self.urls[self._current_index]
try:
self._connection = RobustConnection(url, **self._kwargs)
await self._connection.connect(timeout=timeout)
return
except Exception as e:
last_exc = e
self._current_index = (self._current_index + 1) % len(self.urls)
raise last_exc or RuntimeError("All connection attempts failed")

async def reconnect(self) -> None:
"""
Perform reconnection to the next URL in round-robin order.
"""
self._current_index = (self._current_index + 1) % len(self.urls)
try:
await self.connect(timeout=self._connect_timeout)
if self._connection:
await self._connection.reconnect_callbacks()
except Exception as e:
log.info(
f"Reconnect failed on {self.urls[self._current_index]}: {e}"
)

async def _on_connection_close(self, closing: Any) -> None:
"""
Internal callback triggered on connection close to attempt reconnection.
"""
if self._connection and not self._connection.is_closed:
await self.reconnect()
if self._connection:
await self._connection._on_connection_close(closing)

@property
def is_closed(self) -> bool:
return self._connection.is_closed if self._connection else True

async def close(self) -> None:
if self._connection:
await self._connection.close()

def __getattr__(self, name: str) -> Any:
if self._connection:
return getattr(self._connection, name)
raise AttributeError(
f"'RobustConnectionRRHost' object has no attribute '{name}'"
)

__all__ = (
"RobustConnectionRRHost"
)
55 changes: 55 additions & 0 deletions tests/test_robust_connection_rrhost.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
import pytest
from yarl import URL
from aio_pika.robust_connection_rrhost import RobustConnectionRRHost


@pytest.mark.asyncio
async def test_connect_with_rabbitmq_container(amqp_url):
conn = RobustConnectionRRHost([str(amqp_url)])
await conn.connect(timeout=2)
assert conn.urls[0].scheme == "amqp"
assert conn._connection is not None


@pytest.mark.asyncio
async def test_failover_with_rabbitmq_container(amqp_url):
urls = ["amqp://guest:guest@invalidhost:5672/", str(amqp_url)]
conn = RobustConnectionRRHost(urls)
await conn.connect(timeout=2)
assert any(u.scheme == "amqp" for u in conn.urls)
assert conn._connection is not None


@pytest.mark.asyncio
async def test_amqp_scheme_with_rabbitmq(amqp_url):
url = f"amqp://guest:guest@{amqp_url.host}:5672/"
conn = RobustConnectionRRHost([url])
assert conn.urls[0].scheme == "amqp"


@pytest.mark.asyncio
@pytest.mark.skip(reason="AMQPS not configured on the test server")
async def test_amqps_scheme_with_rabbitmq(amqp_url):
url = f"amqps://guest:guest@{amqp_url.host}:5671/"
conn = RobustConnectionRRHost([url])
await conn.connect(timeout=2)
assert conn.urls[0].scheme == "amqps"
assert conn._connection is not None


@pytest.mark.asyncio
async def test_no_scheme_defaults_to_amqp(amqp_url):
raw_url = f"guest:guest@{amqp_url.host}:5672"
url = f"amqp://{raw_url}"
parsed = URL(url)
conn = RobustConnectionRRHost([str(parsed)])
assert conn.urls[0].scheme == "amqp"


@pytest.mark.asyncio
async def test_host_and_port_only(amqp_url):
raw_url = f"{amqp_url.host}:5672"
url = f"amqp://{raw_url}"
parsed = URL(url)
conn = RobustConnectionRRHost([str(parsed)])
assert conn.urls[0].host == amqp_url.host