Source code for testsuite.databases.rabbitmq.classes
import asyncio
import dataclasses
import aio_pika
class BaseError(Exception):
pass
class RabbitMqDisabledError(BaseError):
pass
[docs]@dataclasses.dataclass(frozen=True)
class ConnectionInfo:
"""RabbitMQ connection parameters"""
host: str
tcp_port: int
class Channel:
def __init__(self, channel: aio_pika.Channel):
self._channel = channel
async def __aenter__(self) -> 'Channel':
if not self._channel.is_initialized:
await self._channel.initialize()
return self
async def __aexit__(self, exc_type, exc_val, exc_tb) -> None:
await self._channel.close(exc_val)
async def declare_exchange(
self,
exchange: str,
exchange_type: aio_pika.ExchangeType,
timeout: float = 1.0,
) -> None:
await self._channel.declare_exchange(
name=exchange,
type=exchange_type,
timeout=timeout,
)
async def declare_queue(self, queue: str, timeout: float = 1.0) -> None:
await self._channel.declare_queue(name=queue, timeout=timeout)
async def bind_queue(
self,
exchange: str,
queue: str,
routing_key: str,
timeout: float = 1.0,
):
async def _do_bind():
rmq_queue = await self._channel.get_queue(queue)
await rmq_queue.bind(exchange=exchange, routing_key=routing_key)
await asyncio.wait_for(_do_bind(), timeout=timeout)
async def publish(
self,
exchange: str,
routing_key: str,
body: bytes,
timeout: float = 1.0,
):
async def _do_publish():
rmq_exchange = await self._channel.get_exchange(name=exchange)
await rmq_exchange.publish(
aio_pika.Message(body=body),
routing_key=routing_key,
)
await asyncio.wait_for(_do_publish(), timeout=timeout)
async def consume(self, queue: str, count: int, timeout: float = 2.0):
async def _do_consume():
result = []
rmq_queue = await self._channel.get_queue(name=queue)
for i in range(count):
incoming_message = await rmq_queue.get()
if incoming_message is not None:
await incoming_message.ack()
result.append(
incoming_message.body[: incoming_message.body_size],
)
return result
return await asyncio.wait_for(_do_consume(), timeout=timeout)
class Client:
def __init__(self, connection_future):
self._connection_future = connection_future
self._connection = None
async def teardown(self):
if self._connection is not None:
await self._connection.close()
async def get_channel(self) -> Channel:
if self._connection is None:
self._connection = await self._connection_future
return Channel(
channel=self._connection.channel(publisher_confirms=True),
)
class Control:
def __init__(self, enabled: bool, conn_info: ConnectionInfo):
self._enabled = enabled
if self._enabled:
self._client = Client(
connection_future=aio_pika.connect_robust(
host=conn_info.host,
port=conn_info.tcp_port,
timeout=2.0,
),
)
async def teardown(self):
if self._enabled:
await self._client.teardown()
def get_channel(self) -> Channel:
if not self._enabled:
raise RabbitMqDisabledError
return self._client.get_channel()