# This file is a part of Remoulade.
#
# Copyright (C) 2017,2018 CLEARTYPE SRL <bogdan@cleartype.io>
#
# Remoulade is free software; you can redistribute it and/or modify it
# under the terms of the GNU Lesser General Public License as published by
# the Free Software Foundation, either version 3 of the License, or (at
# your option) any later version.
#
# Remoulade is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
# FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public
# License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import os
import time
from contextlib import contextmanager
from functools import partial
from queue import Empty, Full, LifoQueue
from threading import Lock, local
from typing import TYPE_CHECKING, Callable, List, Optional
from amqpstorm import AMQPChannelError, AMQPConnectionError, AMQPError, Channel, UriConnection
from typing_extensions import Final
from ..broker import Broker, Consumer, MessageProxy
from ..common import current_millis
from ..errors import ChannelPoolTimeout, ConnectionClosed, MessageNotDelivered, QueueJoinTimeout
from ..helpers.queues import dq_name, xq_name
from ..logging import get_logger
from ..message import Message
if TYPE_CHECKING:
    from ..middleware import Middleware  # noqa

#: The maximum amount of time a message can be in the dead queue (one week, in milliseconds).
DEAD_MESSAGE_TTL: Final[int] = 86400000 * 7

#: The max number of times to attempt an enqueue operation in case of
#: a connection error.
MAX_ENQUEUE_ATTEMPTS: Final[int] = 6
class RabbitmqBroker(Broker):
    """A broker that can be used with RabbitMQ.

    Examples:

      >>> RabbitmqBroker(url="amqp://guest:guest@127.0.0.1:5672")

    Parameters:
      confirm_delivery(bool): Wait for RabbitMQ to confirm that
        messages have been committed on every call to enqueue.
        Defaults to False.
      url(str): The optional connection URL to use to determine which Rabbit server to connect to.
        If None is provided, connection is made with 'amqp://guest:guest@localhost:5672'
      middleware(list[Middleware]): The set of middleware that apply
        to this broker.
      max_priority(int): Configure the queues with x-max-priority to
        support priority queue in RabbitMQ itself
      channel_pool_size(int): Size of the channel pool
      dead_queue_max_length(int|None): Max size of the dead queue. If None, no max size.
      delivery_mode(int): 2 (persistent) to wait for message to be flushed to disk for confirmation (safer)
        or 1 (transient) which don't (faster)
      group_transaction(bool): If true, use transactions by default when running group and pipelines
    """

    def __init__(
        self,
        *,
        confirm_delivery: bool = False,
        url: Optional[str] = None,
        middleware: Optional[List["Middleware"]] = None,
        max_priority: Optional[int] = None,
        channel_pool_size: int = 200,
        dead_queue_max_length: Optional[int] = None,
        delivery_mode: int = 2,
        group_transaction: bool = False,
    ):
        super().__init__(middleware=middleware)

        if max_priority is not None and not (0 < max_priority <= 255):
            raise ValueError("max_priority must be a value between 0 and 255")
        if dead_queue_max_length is not None and dead_queue_max_length <= 0:
            raise ValueError("dead_queue_max_length must be strictly above 0")
        if delivery_mode not in {1, 2}:
            raise ValueError("Invalid value for delivery_mode, should be 1 for non-persistent, 2 for persistent")

        self.url = url or os.getenv("REMOULADE_RABBITMQ_URL") or ""
        self.confirm_delivery = confirm_delivery
        self.max_priority = max_priority
        self.dead_queue_max_length = dead_queue_max_length
        self._connection = None
        self.queues = {}
        # Thread-local state; tracks the channel holding an open transaction (see tx()).
        self.state = local()
        # Two pools: channels with publisher confirms enabled, and channels without.
        self.channel_pools = {
            "confirm_delivery": ChannelPool(
                channel_factory=partial(self.channel_factory, confirm_delivery=True), pool_size=channel_pool_size
            ),
            "no_confirm_delivery": ChannelPool(
                channel_factory=partial(self.channel_factory, confirm_delivery=False), pool_size=channel_pool_size
            ),
        }
        self.queues_declared = False
        # we need a Lock on self._connection as it can be modified by multiple threads
        self.lock = Lock()
        self.group_transaction = group_transaction
        self.delivery_mode = delivery_mode
        self.actor_options.add("confirm_delivery")

    @property
    def connection(self):
        """The :class:`amqpstorm.Connection` for the current
        process.  This property may change without notice.
        """
        with self.lock:
            if self._connection is None or self._connection.is_closed:
                self._connection = UriConnection(self.url)
            return self._connection

    @connection.deleter
    def connection(self):
        """Close (best effort) and drop the current connection so the next
        access to the ``connection`` property creates a fresh one.
        """
        with self.lock:
            if self._connection is not None:
                try:
                    self._connection.close()
                except AMQPError:
                    pass
                self._connection = None

    def channel_factory(self, confirm_delivery):
        """Create a new channel on the current connection.

        Parameters:
          confirm_delivery(bool): Whether the channel should use publisher confirms.
        """
        channel = self.connection.channel()
        if confirm_delivery:
            channel.confirm_deliveries()
        return channel

    def get_channel_pool(self, confirm_delivery: bool):
        """Return the channel pool matching the requested confirm-delivery mode."""
        if confirm_delivery:
            return self.channel_pools["confirm_delivery"]
        return self.channel_pools["no_confirm_delivery"]

    @property
    def default_channel_pool(self):
        """The channel pool matching this broker's default confirm_delivery setting."""
        return self.get_channel_pool(self.confirm_delivery)

    def clear_channel_pools(self):
        """Dispose of all pooled channels; fresh ones are created lazily afterwards."""
        self.channel_pools["confirm_delivery"].clear()
        self.channel_pools["no_confirm_delivery"].clear()

    def close(self) -> None:
        """Close all open RabbitMQ connections."""
        self.logger.debug("Closing channels and connection...")
        try:
            del self.connection
        except Exception:  # pragma: no cover
            self.logger.debug("Encountered an error while closing the connection.", exc_info=True)
        self._connection = None
        self.clear_channel_pools()
        self.logger.debug("Channels and connections closed.")

    def consume(self, queue_name: str, prefetch: int = 1, timeout: int = 5000) -> "_RabbitmqConsumer":
        """Create a new consumer for a queue.

        Parameters:
          queue_name(str): The queue to consume.
          prefetch(int): The number of messages to prefetch.
          timeout(int): The idle timeout in milliseconds.

        Returns:
          Consumer: A consumer that retrieves messages from RabbitMQ.

        Raises:
          ConnectionClosed: If the underlying channel or connection has been closed.
        """
        try:
            self._declare_rabbitmq_queues()
        except (AMQPConnectionError, AMQPChannelError) as e:
            if isinstance(e, AMQPConnectionError):
                # The connection is dead: drop it along with every channel built on it.
                del self.connection
                self.clear_channel_pools()
            raise ConnectionClosed(e) from None
        return _RabbitmqConsumer(self.connection, queue_name, prefetch, timeout)

    def _declare_rabbitmq_queues(self):
        """Real queue declaration, to happen before enqueuing or consuming.

        Raises:
          AMQPConnectionError or AMQPChannelError: If the underlying channel or connection has been closed.
        """
        with self.default_channel_pool.acquire() as channel:
            for queue_name in self.queues:
                self._declare_queue(channel, queue_name)
                self._declare_dq_queue(channel, queue_name)
                self._declare_xq_queue(channel, queue_name)

    def declare_queue(self, queue_name: str) -> None:
        """Declare a queue.  Has no effect if a queue with the given
        name already exists.

        Parameters:
          queue_name(str): The name of the new queue.

        Raises:
          ConnectionClosed: If the underlying channel or connection
            has been closed.
        """
        if queue_name not in self.queues:
            self.emit_before("declare_queue", queue_name)
            self.queues[queue_name] = None
            self.emit_after("declare_queue", queue_name)

            delayed_name = dq_name(queue_name)
            self.delay_queues.add(delayed_name)
            self.emit_after("declare_delay_queue", delayed_name)

    def _build_queue_arguments(self, queue_name):
        # Rejected/expired messages are dead-lettered to the queue's XQ via the
        # default ("") exchange.
        arguments = {
            "x-dead-letter-exchange": "",
            "x-dead-letter-routing-key": xq_name(queue_name),
        }
        if self.max_priority:
            arguments["x-max-priority"] = self.max_priority

        return arguments

    def _declare_queue(self, channel, queue_name):
        arguments = self._build_queue_arguments(queue_name)
        return channel.queue.declare(queue=queue_name, durable=True, arguments=arguments)

    def _declare_dq_queue(self, channel, queue_name):
        arguments = self._build_queue_arguments(queue_name)
        return channel.queue.declare(queue=dq_name(queue_name), durable=True, arguments=arguments)

    def _declare_xq_queue(self, channel, queue_name):
        arguments = {
            # This HAS to be a static value since messages are expired
            # in order inside of RabbitMQ (head-first).
            "x-message-ttl": DEAD_MESSAGE_TTL,
        }
        if self.dead_queue_max_length:
            arguments["x-max-length"] = self.dead_queue_max_length

        return channel.queue.declare(queue=xq_name(queue_name), durable=True, arguments=arguments)

    def _apply_delay(self, message: "Message", delay: Optional[int] = None) -> "Message":
        """Return a copy of the message rerouted to its delay queue with an "eta"
        option, or the message unchanged when delay is None.
        """
        if delay is not None:
            message_eta = current_millis() + delay
            # Within this branch delay is never None, so the message always goes
            # on the delay queue.  (The original conditional re-testing
            # `delay is None` here was dead code.)
            message = message.copy(queue_name=dq_name(message.queue_name), options={"eta": message_eta})
        return message

    @contextmanager
    def tx(self):
        """Context manager that makes every enqueue in its scope (on the current
        thread) go through a single channel with an open AMQP transaction
        (``channel.tx``).
        """
        with self.get_channel_pool(confirm_delivery=False).acquire() as channel:
            with channel.tx:
                self.state.channel_with_transaction = channel
                try:
                    yield
                finally:
                    self.state.channel_with_transaction = None

    @property
    def _has_transaction(self) -> bool:
        # True when the current thread has a transaction open via tx().
        return bool(getattr(self.state, "channel_with_transaction", None))

    @contextmanager
    def _get_channel(self, confirm_delivery: bool):
        # Reuse the transaction channel when one is open on this thread,
        # otherwise borrow a channel from the appropriate pool.
        if self._has_transaction:
            yield self.state.channel_with_transaction
        else:
            with self.get_channel_pool(confirm_delivery).acquire() as channel:
                yield channel

    def _enqueue(self, message: "Message", *, delay: Optional[int] = None) -> "Message":
        """Enqueue a message.

        Parameters:
          message(Message): The message to enqueue.
          delay(int): The minimum amount of time, in milliseconds, to
            delay the message by.

        Raises:
          ConnectionClosed: If the underlying channel or connection
            has been closed.
        """
        queue_name = message.queue_name
        actor = self.get_actor(message.actor_name)
        properties = {"delivery_mode": self.delivery_mode, "priority": message.options.get("priority", actor.priority)}
        # Per-message option overrides the actor option, which overrides the broker
        # default; confirms are disabled while a transaction is open (see tx()).
        confirm_delivery = (
            message.options.get("confirm_delivery", actor.options.get("confirm_delivery", self.confirm_delivery))
            and not self._has_transaction
        )

        attempts = 1
        while True:
            try:
                # I chose to do queue declaration only on first enqueuing, it should be sufficient but it do not
                # resolve the case of queue deletion at runtime. But we do not want the overhead of queue creation on
                # each enqueue
                if not self.queues_declared:
                    self._declare_rabbitmq_queues()
                    self.queues_declared = True

                self.logger.debug("Enqueueing message %r on queue %r.", message.message_id, queue_name)
                with self._get_channel(confirm_delivery) as channel:
                    confirmation = channel.basic.publish(
                        exchange="", routing_key=queue_name, body=message.encode(), properties=properties
                    )
                if confirm_delivery and not confirmation:
                    raise MessageNotDelivered("Message could not be delivered")
                return message
            except MessageNotDelivered:
                attempts += 1
                if self._has_transaction or attempts > MAX_ENQUEUE_ATTEMPTS:
                    raise
                time.sleep(0.1)  # wait a bit and retry
                self.logger.debug("Retrying enqueue on message not delivered. [%d/%d]", attempts, MAX_ENQUEUE_ATTEMPTS)
            except (AMQPConnectionError, AMQPChannelError) as e:
                # Delete the channel (and the connection if needed) so that the
                # next caller/attempt may initiate new ones of each.
                if isinstance(e, AMQPConnectionError):
                    del self.connection
                    self.clear_channel_pools()
                attempts += 1
                if self._has_transaction or attempts > MAX_ENQUEUE_ATTEMPTS:
                    raise ConnectionClosed(e) from None
                self.logger.debug("Retrying enqueue due to closed connection. [%d/%d]", attempts, MAX_ENQUEUE_ATTEMPTS)

    def get_queue_message_counts(self, queue_name: str):
        """Get the number of messages in a queue.  This method is only
        meant to be used in unit and integration tests.

        Parameters:
          queue_name(str): The queue whose message counts to get.

        Returns:
          tuple: A triple representing the number of messages in the
          queue, its delayed queue and its dead letter queue.
        """
        with self.default_channel_pool.acquire() as channel:
            queue_response = self._declare_queue(channel, queue_name)
            dq_queue_response = self._declare_dq_queue(channel, queue_name)
            xq_queue_response = self._declare_xq_queue(channel, queue_name)

        return (
            queue_response["message_count"],
            dq_queue_response["message_count"],
            xq_queue_response["message_count"],
        )

    def flush(self, queue_name: str) -> None:
        """Drop all the messages from a queue.

        Parameters:
          queue_name(str): The queue to flush.
        """
        # Purge the live queue, its delay queue and its dead-letter queue.
        for name in (queue_name, dq_name(queue_name), xq_name(queue_name)):
            with self.default_channel_pool.acquire() as channel:
                channel.queue.purge(name)

    def flush_all(self) -> None:
        """Drop all messages from all declared queues."""
        for queue_name in self.queues:
            self.flush(queue_name)

    def join(
        self, queue_name: str, min_successes: int = 10, idle_time: int = 100, *, timeout: Optional[int] = None
    ) -> None:
        """Wait for all the messages on the given queue to be
        processed.  This method is only meant to be used in tests to
        wait for all the messages in a queue to be processed.

        Warning:
          This method doesn't wait for unacked messages so it may not
          be completely reliable.  Use the stub broker in your unit
          tests and only use this for simple integration tests.

        Parameters:
          queue_name(str): The queue to wait on.
          min_successes(int): The minimum number of times all the
            polled queues should be empty.
          idle_time(int): The number of milliseconds to wait between
            counts.
          timeout(Optional[int]): The max amount of time, in
            milliseconds, to wait on this queue.

        Raises:
          QueueJoinTimeout: If the queue is still not empty when the timeout runs out.
        """
        # NOTE: a falsy timeout (None or 0) disables the deadline entirely.
        deadline = timeout and time.monotonic() + timeout / 1000
        successes = 0
        while successes < min_successes:
            if deadline and time.monotonic() >= deadline:
                raise QueueJoinTimeout(queue_name)

            # The last element of the triple (the dead-letter count) is excluded.
            total_messages = sum(self.get_queue_message_counts(queue_name)[:-1])
            if total_messages == 0:
                successes += 1
            else:
                successes = 0

            time.sleep(idle_time / 1000)
class _RabbitmqConsumer(Consumer):
    """Consumer that reads messages off a single RabbitMQ queue over a dedicated channel."""

    def __init__(self, connection, queue_name, prefetch, timeout):
        try:
            self.logger = get_logger(__name__, type(self))
            self.channel = connection.channel()
            self.channel.basic.qos(prefetch_count=prefetch)
            self.channel.basic.consume(queue=queue_name, no_ack=False)
            self.timeout = timeout
        except (AMQPConnectionError, AMQPChannelError) as exc:
            raise ConnectionClosed(exc) from None

    def ack(self, message):
        """Acknowledge a message; connection/channel failures surface as ConnectionClosed."""
        try:
            message.ack()
        except (AMQPConnectionError, AMQPChannelError) as exc:
            raise ConnectionClosed(exc) from None
        except Exception:  # pragma: no cover
            self.logger.warning("Failed to ack message.", exc_info=True)

    def nack(self, message):
        """Reject a message without putting it back on the queue."""
        try:
            message.nack(requeue=False)
        except (AMQPConnectionError, AMQPChannelError) as exc:
            raise ConnectionClosed(exc) from None
        except Exception:  # pragma: no cover
            self.logger.warning("Failed to nack message.", exc_info=True)

    def requeue(self, messages):
        """RabbitMQ automatically re-enqueues unacked messages when
        consumers disconnect so this is a no-op.
        """

    def __next__(self):
        """Return None if no value after timeout seconds"""
        try:
            give_up_at = time.monotonic() + self.timeout / 1000
            while time.monotonic() < give_up_at:
                try:
                    raw = next(self.channel.build_inbound_messages(auto_decode=False, break_on_empty=True))
                except StopIteration:
                    # Nothing waiting on the channel; idle briefly before polling again.
                    time.sleep(0.1)
                    continue
                return _RabbitmqMessage(raw) if raw else None
            return None
        except (AMQPConnectionError, AMQPChannelError) as exc:
            raise ConnectionClosed(exc) from None

    def close(self):
        """Close the consumer's channel, ignoring AMQP errors."""
        try:
            self.channel.close()
        except (AMQPConnectionError, AMQPChannelError):
            pass
class _RabbitmqMessage(MessageProxy):
    """Proxy around an inbound RabbitMQ message, decoded into a remoulade Message."""

    def __init__(self, rabbitmq_message):
        decoded = Message.decode(rabbitmq_message.body)
        super().__init__(decoded)
        self._rabbitmq_message = rabbitmq_message

    def ack(self):
        # Delegate acknowledgement to the underlying RabbitMQ message.
        self._rabbitmq_message.ack()

    def nack(self, requeue):
        # Delegate rejection to the underlying RabbitMQ message.
        self._rabbitmq_message.nack(requeue)
class ChannelPool:
    """A pool of channels that can be used by the RabbitmqBroker.

    The pool uses a synchronized queue as a backend, making sure that two threads never end up sharing a channel.
    The channels are created lazily as the reservation requests come.

    The ChannelPool should be used via the `acquire` context manager, to make sure that a used channel is properly
    put back into the pool.

    Examples:

      >>> channel_pool = ChannelPool(channel_factory=lambda connection: connection.channel(), pool_size=5)
      >>> with channel_pool.acquire() as channel:
      ...     channel.basic.publish(...)

    Parameters:
      channel_factory(function): Function that will be called to create new channels.
      pool_size(int): The max size of the pool.
    """

    # The annotation is a forward reference so the name is resolved lazily
    # (equivalent at runtime; module-level Channel import still satisfies it).
    def __init__(self, channel_factory: Callable[[], "Channel"], *, pool_size: int):
        self._channel_factory = channel_factory
        self._pool: LifoQueue = LifoQueue(pool_size)
        self._pool_size = pool_size
        for _ in range(pool_size):
            # The goal is to create the channels lazily, so Nones are put as stand-ins
            # for the channels that can be created.
            self.put(None)

    @contextmanager
    def acquire(self, timeout=None):
        """Reserve a channel for exclusive use, creating it lazily if needed.

        Parameters:
          timeout(int): The max number of seconds to wait when fetching a channel from the pool. If None, it will
            wait indefinitely to get a channel. Default None.

        Raises:
          ChannelPoolTimeout: when the timeout for reserving a channel is run out.
        """
        channel = self.get(timeout=timeout)
        if channel is None:
            try:
                channel = self._channel_factory()
            except BaseException:
                # Bug fix: the factory failed, so hand the reserved slot back.
                # Without this the pool permanently shrinks by one slot on every
                # failed channel creation, eventually deadlocking acquire().
                self.put(None)
                raise
        try:
            yield channel
        finally:
            if channel.is_closed:
                # A closed channel is unusable; put back a None placeholder so a
                # fresh channel gets created on the next acquisition.
                self.put(None)
            else:
                self.put(channel)

    def get(self, *, timeout=None):
        """Take a channel (or a None placeholder) out of the pool.

        Raises:
          ChannelPoolTimeout: If no slot became available within `timeout` seconds.
        """
        try:
            return self._pool.get(timeout=timeout)
        except Empty as e:
            raise ChannelPoolTimeout("Could not get any channel from the pool") from e

    def put(self, channel):
        """Return a channel (or a None placeholder) to the pool; no-op if the pool is full."""
        try:
            return self._pool.put_nowait(channel)
        except Full:
            pass

    def __len__(self):
        # Number of slots currently available for reservation.
        return self._pool.qsize()

    def clear(self):
        """
        This will empty the pool and fill it back with None.
        It is best to use it inside a lock to avoid doing it multiple times.
        """
        while len(self) > 0:
            channel = self._pool.get_nowait()
            if channel is not None:
                # Close real channels best-effort; placeholders need no cleanup.
                try:
                    channel.close()
                except AMQPError:
                    pass
        for _ in range(self._pool_size):
            self.put(None)