diff --git a/aio_pika/exchange.py b/aio_pika/exchange.py index 2e5cc3d1..68850fb0 100644 --- a/aio_pika/exchange.py +++ b/aio_pika/exchange.py @@ -36,7 +36,7 @@ def __init__( self.durable = durable self.internal = internal self.passive = passive - self.arguments = arguments or {} + self.arguments = frozenset(arguments.items()) if arguments else frozenset() def __str__(self) -> str: return self.name @@ -46,9 +46,42 @@ def __repr__(self) -> str: f"<{self.__class__.__name__}({self}):" f" auto_delete={self.auto_delete}," f" durable={self.durable}," - f" arguments={self.arguments!r})>" + f" arguments={dict(self.arguments)!r})>" ) + def __eq__(self, other: object) -> bool: + """ + Defines equality for Exchange objects. + Two exchanges are considered equal if their name, type, and + all boolean flags, and arguments are the same. + """ + if not isinstance(other, Exchange): + return False + return ( + self.name == other.name and + self._type == other._type and + self.auto_delete == other.auto_delete and + self.durable == other.durable and + self.internal == other.internal and + self.passive == other.passive and + self.arguments == other.arguments + ) + + def __hash__(self) -> int: + """ + Computes a hash for the Exchange object. + The hash is based on the same attributes used in __eq__. + """ + return hash(( + self.name, + self._type, + self.auto_delete, + self.durable, + self.internal, + self.passive, + self.arguments # This is a frozenset, which is hashable + )) + async def declare( self, timeout: TimeoutType = None, ) -> aiormq.spec.Exchange.DeclareOk: @@ -60,7 +93,7 @@ async def declare( auto_delete=self.auto_delete, internal=self.internal, passive=self.passive, - arguments=self.arguments, + arguments=dict(self.arguments), timeout=timeout, ) diff --git a/tests/test_amqp.py b/tests/test_amqp.py index a9a5a6ff..e114301d 100644 --- a/tests/test_amqp.py +++ b/tests/test_amqp.py @@ -1636,6 +1636,16 @@ async def test_get_exchange(self, connection, declare_exchange): assert exchange.name == exchange_passive.name + async def test_get_exchange_memory_leak(self, connection, declare_exchange): + channel = await self.create_channel(connection) + _ = await declare_exchange( + "passive", channel=channel, + ) + exchange_1 = await channel.get_exchange("passive") + exchange_2 = await channel.get_exchange("passive") + + assert exchange_1 == exchange_2 + async def test_get_queue(self, connection, declare_queue): channel = await self.create_channel(connection) name = get_random_name("passive", "queue")