Bootstrap

python多线程使用rabbitmq

python多线程使用rabbitmq

1. 介绍

RabbitMQ是一个开源的消息代理软件,遵循AMQP(高级消息队列协议),主要用于在不同的应用程序之间进行异步通信。RabbitMQ以其可靠性、灵活性、可扩展性和多语言支持等特点,在分布式系统、微服务架构等场景中得到了广泛应用。

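RabbitMQ的核心概念包括队列、交换机、路由键、绑定、生产者和消费者等。队列用于存储和转发消息,具有先进先出(FIFO)的特性,并且可以持久化存储消息以防止丢失。交换机用于实现消息路由,根据不同的路由规则将消息推送到指定的队列:绑定(binding)声明了交换机与队列之间的关联及其绑定键,生产者发送消息时附带路由键(routing key),交换机再按自身类型决定投递目标——direct 要求路由键与绑定键完全相同,topic 支持通配符匹配,fanout 则忽略路由键、广播到所有已绑定的队列。生产者负责发送消息,而消费者则负责接收并处理这些消息。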

RabbitMQ的架构基于生产者-消费者模型,通过队列实现消息的存储和转发。此外,RabbitMQ还支持虚拟主机(vhost),用于逻辑隔离,管理各自的Exchange、Queue和Binding。这些特性使得RabbitMQ在分布式系统中能够提供高效、可靠的消息传递服务,提高系统的可扩展性和响应速度。

2. 封装库

import json
import logging
import signal
import sys
import threading
import time
import uuid
from concurrent.futures import ThreadPoolExecutor

import pika


class FlaskRabbitMQ:
    """RabbitMQ helper for Flask apps, built on pika's ``BlockingConnection``.

    A ``BlockingConnection`` must not be shared between threads, so each thread
    gets its own connection (stored in ``threading.local``) and every operation
    opens a short-lived channel on it. Consumers registered through a ``Queue``
    run on ``self.executor``, one blocking consumer per worker thread.
    """

    def __init__(self, app=None, queue=None, heartbeat=60, max_retries=3, max_workers=500):
        """Create the helper; call ``init_app`` later when ``app`` is not given.

        :param app: application object exposing a ``config`` mapping (``get``).
        :param queue: ``Queue`` registry whose callbacks ``run`` starts consuming.
        :param heartbeat: AMQP heartbeat (seconds) used for every connection.
        :param max_retries: stored on the instance; nothing in this class reads it.
        :param max_workers: size of the consumer thread pool.
        """
        self.app = app
        self.queue = queue
        self.config = None
        self.heartbeat = heartbeat
        self.max_retries = max_retries  # kept for callers; currently unused here

        self.rabbitmq_server_host = None
        self.rabbitmq_server_username = None
        self.rabbitmq_server_password = None

        self._channel = None  # legacy attribute; methods use per-call channels instead
        self._rpc_class_list = []
        self.data = {}  # corr_id -> state of an in-flight send_sync call

        self.logger = logging.getLogger('Logger')
        self._thread_local = threading.local()  # one connection per thread
        self.executor = ThreadPoolExecutor(max_workers=max_workers)
        # Running consumers as (connection, channel) pairs so stop_consumers()
        # can make each worker leave start_consuming(); guarded by the lock.
        self._consumers = []
        self._consumers_lock = threading.Lock()
        self._stopping = threading.Event()
        if app:
            self.init_app(app)

    def init_app(self, app=None):
        """Bind to ``app``, forward it to the ``Queue`` registry and read the config."""
        self.app = app
        if self.queue:
            self.queue.init_app(app)
        self.config = self.app.config
        self.valid_config()

    def valid_config(self):
        """Copy RABBITMQ_HOST / RABBITMQ_USERNAME / RABBITMQ_PASSWORD from the config."""
        self.rabbitmq_server_host = self.config.get('RABBITMQ_HOST')
        self.rabbitmq_server_username = self.config.get('RABBITMQ_USERNAME')
        self.rabbitmq_server_password = self.config.get('RABBITMQ_PASSWORD')

    def _create_new_connection(self):
        """Open a new blocking connection with the configured host and credentials."""
        credentials = pika.PlainCredentials(
            self.rabbitmq_server_username,
            self.rabbitmq_server_password
        )
        parameters = pika.ConnectionParameters(
            self.rabbitmq_server_host,
            credentials=credentials,
            heartbeat=self.heartbeat
        )
        return pika.BlockingConnection(parameters)

    def _get_connection(self):
        """Return this thread's connection, (re)opening it when missing or closed."""
        connection = getattr(self._thread_local, 'connection', None)
        if connection is None or connection.is_closed:
            connection = self._create_new_connection()
            self._thread_local.connection = connection
            self.logger.info(f'创建新连接_thread_local.connection:{connection}')
        return connection

    def _get_channel(self):
        """Open a new channel on this thread's connection; the caller must close it."""
        return self._get_connection().channel()

    def _close_channel(self, channel):
        """Close ``channel`` if it is still open, logging instead of raising.

        Calling ``close()`` on a channel the broker already closed raises and
        would mask the exception that is being propagated.
        """
        if channel is None or not channel.is_open:
            return
        try:
            channel.close()
        except pika.exceptions.AMQPError as e:
            self.logger.warning(f'关闭通道失败: {e}')

    def temporary_queue_declare(self):
        """Declare a server-named, exclusive, auto-delete queue and return its name."""
        return self.queue_declare(exclusive=True, auto_delete=True)

    def queue_declare(self, queue_name='', passive=False, durable=False, exclusive=False, auto_delete=False,
                      arguments=None):
        """Declare a queue and return its name (server-generated when ``queue_name`` is '').

        If the broker rejects the declaration because the existing queue has a
        different ``durable`` flag (reply code 406), the queue is DELETED,
        discarding any messages in it, and declared again.
        """
        declare_args = dict(
            queue=queue_name,
            passive=passive,
            durable=durable,
            exclusive=exclusive,
            auto_delete=auto_delete,
            arguments=arguments,
        )
        channel = self._get_channel()
        try:
            return channel.queue_declare(**declare_args).method.queue
        except pika.exceptions.ChannelClosedByBroker as e:
            if e.reply_code == 406 and "inequivalent arg 'durable'" in e.reply_text:
                self.logger.error(f"队列 '{queue_name}' 的持久化参数不匹配,正在删除并重新声明。")
                # The broker closed `channel` together with the 406, so the
                # recovery must run on a fresh channel.
                channel = self._get_channel()
                channel.queue_delete(queue=queue_name)
                return channel.queue_declare(**declare_args).method.queue
            self.logger.error(f"声明队列 '{queue_name}' 时出错: {e}")
            raise
        finally:
            self._close_channel(channel)

    def queue_delete(self, queue_name):
        """Delete ``queue_name``; errors are logged and re-raised."""
        channel = self._get_channel()
        try:
            channel.queue_delete(queue=queue_name)
            self.logger.info(f"队列 '{queue_name}' 已成功删除。")
        except Exception as e:
            self.logger.error(f"删除队列 '{queue_name}' 失败: {e}")
            raise
        finally:
            self._close_channel(channel)

    def exchange_bind_to_queue(self, type, exchange_name, routing_key, queue):
        """Declare exchange ``exchange_name`` of kind ``type`` and bind ``queue`` to it."""
        channel = self._get_channel()
        try:
            channel.exchange_declare(exchange=exchange_name, exchange_type=type)
            channel.queue_bind(queue=queue, exchange=exchange_name, routing_key=routing_key)
        except Exception as e:
            self.logger.error(f"绑定队列 '{queue}' 到交换机 '{exchange_name}' 时出错: {e}")
            raise
        finally:
            self._close_channel(channel)

    def exchange_declare(self, exchange_name, exchange_type):
        """Declare an exchange; errors are logged and re-raised."""
        channel = self._get_channel()
        try:
            channel.exchange_declare(exchange=exchange_name, exchange_type=exchange_type)
        except Exception as e:
            self.logger.error(f"交换机 '{exchange_name}' 声明失败: {e}")
            raise
        finally:
            self._close_channel(channel)

    def queue_bind(self, exchange_name, routing_key, queue_name):
        """Bind an existing queue to an existing exchange; errors are logged and re-raised."""
        channel = self._get_channel()
        try:
            channel.queue_bind(queue=queue_name, exchange=exchange_name, routing_key=routing_key)
        except Exception as e:
            self.logger.error(f"队列 '{queue_name}' 绑定到交换机 '{exchange_name}' 时出错: {e}")
            raise
        finally:
            self._close_channel(channel)

    def basic_consuming(self, queue_name, callback, arguments=None, auto_ack=False):
        """Start a background consumer for ``queue_name`` on the thread pool.

        ``callback`` receives ``(ch, method, props, body)`` for each delivery
        until ``stop_consumers`` is called or the consumer fails.

        :returns: the ``concurrent.futures.Future`` of the consumer task.
        """
        return self.executor.submit(self._start_thread_consumer, queue_name, callback, auto_ack, arguments)

    def send_expire(self, body, exchange, key, properties=None, max_retries=3):
        """Publish ``body`` best-effort: failures are logged, never raised.

        :param properties: optional ``pika.BasicProperties`` to attach.
        :param max_retries: accepted for compatibility; not used.
        """
        channel = None
        try:
            channel = self._get_channel()
            # properties=None is pika's default, i.e. publish without properties.
            channel.basic_publish(
                exchange=exchange,
                routing_key=key,
                body=body,
                properties=properties or None
            )
        except Exception as e:
            self.logger.error(f'推送消息异常:{e}')
        finally:
            self._close_channel(channel)

    def send(self, body, exchange, key, corr_id=None):
        """Publish ``body``; a truthy ``corr_id`` is attached as the correlation_id."""
        properties = pika.BasicProperties(correlation_id=corr_id) if corr_id else None
        channel = self._get_channel()
        try:
            channel.basic_publish(exchange=exchange, routing_key=key, body=body, properties=properties)
        finally:
            self._close_channel(channel)

    def send_json(self, body, exchange, key, corr_id=None):
        """JSON-encode ``body`` and publish it via ``send``."""
        self.send(json.dumps(body), exchange=exchange, key=key, corr_id=corr_id)

    def send_sync(self, body, key=None, timeout=5):
        """Publish ``body`` to queue ``key`` and wait up to ``timeout`` seconds for a reply.

        The reply is read from an exclusive, auto-delete queue declared on this
        thread's connection and matched by correlation_id.

        :returns: the reply body decoded as UTF-8 text, or None on timeout.
        :raises Exception: when ``key`` is empty.
        """
        if not key:
            raise Exception("The routing key is not present.")

        corr_id = str(uuid.uuid4())
        callback_queue = self.temporary_queue_declare()
        self.data[corr_id] = {
            'isAccept': False,
            'result': None,
            'reply_queue_name': callback_queue
        }

        def on_reply(ch, method, props, reply_body):
            self.logger.info("接收到响应 => {}".format(reply_body))
            if props.correlation_id == corr_id:
                self._store_response(corr_id, reply_body)

        channel = self._get_channel()
        try:
            channel.basic_consume(queue=callback_queue, on_message_callback=on_reply, auto_ack=True)
            channel.basic_publish(
                exchange='',
                routing_key=key,
                body=body,
                properties=pika.BasicProperties(
                    reply_to=callback_queue,
                    correlation_id=corr_id,
                )
            )

            deadline = time.monotonic() + timeout
            while not self.data[corr_id]['isAccept']:
                remaining = deadline - time.monotonic()
                if remaining <= 0:
                    self.logger.error("获取响应超时。")
                    return None
                # A BlockingConnection only dispatches deliveries while it is
                # processing I/O; sleeping alone would never run on_reply.
                channel.connection.process_data_events(time_limit=min(0.3, remaining))
            result = self.data[corr_id]['result']
            self.logger.info("已接收到 RPC 服务器的响应 => {}".format(result))
            return result
        finally:
            # Closing the channel cancels the consumer, after which the broker
            # removes the auto-delete reply queue.
            self._close_channel(channel)
            self.data.pop(corr_id, None)

    def send_json_sync(self, body, key=None, timeout=5):
        """JSON-encode ``body`` and call ``send_sync``; see it for the return value."""
        if not key:
            raise Exception("The routing key is not present.")
        return self.send_sync(json.dumps(body), key=key, timeout=timeout)

    @staticmethod
    def _decode_body(body):
        """Return ``body`` as text: bytes are decoded as UTF-8, anything else via str()."""
        if isinstance(body, (bytes, bytearray)):
            return bytes(body).decode('utf-8', errors='replace')
        return str(body)

    def _store_response(self, corr_id, result):
        """Record ``result`` for a pending call; returns its entry, or None if unknown."""
        entry = self.data.get(corr_id)
        if entry is None:
            return None
        entry['result'] = self._decode_body(result)
        entry['isAccept'] = True  # set last: waiters read 'result' once this flips
        return entry

    def accept(self, key, result):
        """Record a reply for correlation id ``key`` and delete its reply queue.

        Replies for unknown ids (e.g. arriving after a timeout) are logged and ignored.
        """
        entry = self._store_response(key, result)
        if entry is None:
            self.logger.warning(f'忽略未知 correlation_id 的响应: {key}')
            return
        channel = self._get_channel()
        try:
            channel.queue_delete(queue=entry['reply_queue_name'])
        finally:
            self._close_channel(channel)

    def on_response(self, ch, method, props, body):
        """pika consumer callback that hands a reply to ``accept``."""
        self.logger.info("接收到响应 => {}".format(body))
        self.accept(props.correlation_id, body)

    def register_class(self, rpc_class):
        """Register a class whose instances' ``declare()`` is called by ``run``."""
        if not hasattr(rpc_class, 'declare'):
            raise AttributeError("The registered class must contains the declare method")
        self._rpc_class_list.append(rpc_class)

    def _declare_worker_queue(self):
        """Declare a server-named auto-delete queue that pool workers can consume.

        It is deliberately NOT exclusive: an exclusive queue belongs to the
        declaring connection, while consumers run on other threads' connections.
        """
        return self.queue_declare(auto_delete=True)

    def _run(self):
        """Call ``declare()`` on registered classes, then start the queue consumers."""
        for rpc_class in self._rpc_class_list:
            rpc_class().declare()

        registrations = self.queue._rpc_class_list if self.queue is not None else []
        for (exchange_type, queue_name, exchange_name, routing_key, version, callback, auto_ack,
             thread_num) in registrations:
            if exchange_type == ExchangeType.DEFAULT:
                if not queue_name:
                    queue_name = self._declare_worker_queue()
                elif version != 1:
                    self.queue_declare(queue_name=queue_name, auto_delete=True)
            elif exchange_type in (ExchangeType.FANOUT, ExchangeType.DIRECT, ExchangeType.TOPIC):
                if not queue_name:
                    queue_name = self._declare_worker_queue()
                    if exchange_name:
                        # An unbound server-named queue would never receive anything.
                        self.exchange_bind_to_queue(exchange_type, exchange_name, routing_key, queue_name)
                elif version != 1:
                    self.queue_declare(queue_name=queue_name)
                    self.exchange_bind_to_queue(exchange_type, exchange_name, routing_key, queue_name)
            # version == 1: the queue and its binding are expected to exist already.

            for _ in range(thread_num):
                self.executor.submit(self._start_thread_consumer, queue_name, callback, auto_ack)

        self.logger.info(" * Flask RabbitMQ 应用正在消费中")

    def _start_thread_consumer(self, queue_name, callback, auto_ack, arguments=None) -> None:
        """Consume ``queue_name`` on this worker thread until stopped or an error occurs.

        prefetch_count=1 means the broker sends at most one unacknowledged
        message to this consumer at a time.
        """
        channel = None
        registered = None
        try:
            connection = self._get_connection()
            channel = connection.channel()
            channel.basic_qos(prefetch_count=1)
            channel.basic_consume(queue=queue_name, on_message_callback=callback,
                                  auto_ack=auto_ack, arguments=arguments)
            # Check-and-register under the lock so stop_consumers() either sees
            # this consumer or this worker sees the stop request.
            with self._consumers_lock:
                if self._stopping.is_set():
                    return
                registered = (connection, channel)
                self._consumers.append(registered)
            channel.start_consuming()
        except Exception as e:
            self.logger.exception(f"Error in consumer thread: {e}")
        finally:
            if registered is not None:
                with self._consumers_lock:
                    self._consumers.remove(registered)
            self._close_channel(channel)
            if self._stopping.is_set():
                # Shutting down: this thread owns its connection, so close it here.
                self.close_rabbitmq_connection()

    def stop_consumers(self):
        """Ask every running consumer thread to leave ``start_consuming``.

        pika connections are not thread-safe; ``add_callback_threadsafe`` is the
        method meant to be called from another thread, so the stop request is
        scheduled onto each consumer's own connection.
        """
        with self._consumers_lock:
            self._stopping.set()
            consumers = list(self._consumers)
        for connection, channel in consumers:
            try:
                connection.add_callback_threadsafe(channel.stop_consuming)
            except pika.exceptions.AMQPError as e:
                self.logger.warning(f'停止消费者失败: {e}')

    def shutdown_executor(self):
        """Stop the consumers, then wait for the thread pool to finish."""
        if self.executor:
            self.logger.info("关闭线程池...")
            # Without this, wait=True blocks forever on workers in start_consuming().
            self.stop_consumers()
            self.executor.shutdown(wait=True)

    def close_rabbitmq_connection(self):
        """Close the calling thread's connection, if it has an open one."""
        connection = getattr(self._thread_local, 'connection', None)
        if connection is not None and connection.is_open:
            try:
                connection.close()
            except pika.exceptions.AMQPError as e:
                self.logger.warning(f'关闭连接失败: {e}')

    def signal_handler(self, sig, frame):
        """SIGINT/SIGTERM handler: stop consumers and the pool, close the connection, exit."""
        self.logger.info('RabbitMQ开始停止...')
        self.shutdown_executor()
        self.close_rabbitmq_connection()
        sys.exit(0)

    def run(self):
        """Install shutdown signal handlers (main thread only) and start consuming."""
        signal.signal(signal.SIGINT, self.signal_handler)
        signal.signal(signal.SIGTERM, self.signal_handler)
        self._run()


class ExchangeType:
    """String constants naming the exchange kinds used when registering consumers.

    The values are plain strings, so the AMQP ones can be passed directly as
    pika's ``exchange_type``. ``DEFAULT`` is not an AMQP exchange type; the
    consumer runner treats it as "consume the named queue, bind no exchange".
    """

    # AMQP built-in exchange types.
    DIRECT = "direct"
    FANOUT = "fanout"
    TOPIC = "topic"

    # Library-specific marker: no exchange binding.
    DEFAULT = "default"


class Queue:
    """Registry of consumer callbacks, filled by using an instance as a decorator.

    Each registration is stored as the tuple
    ``(type, queue, exchange, routing_key, version, func, auto_ack, thread_num)``
    in ``_rpc_class_list``.
    """

    def __init__(self) -> None:
        self._rpc_class_list = []
        self.app = None

    def __call__(self, queue=None, type=ExchangeType.DEFAULT, version=0, exchange='', routing_key='', auto_ack=False,
                 thread_num=1):
        """Return a decorator that registers the decorated function as a consumer.

        :param queue: queue name; empty means a server-named queue is used.
        :param type: an ``ExchangeType`` value.
        :param version: registration version flag, stored as given.
        :param exchange: exchange name to bind to.
        :param routing_key: routing key for the binding.
        :param auto_ack: whether deliveries are acknowledged automatically.
        :param thread_num: number of consumer threads to start for this callback.
        """
        def decorator(func):
            self._rpc_class_list.append((type, queue, exchange, routing_key, version, func, auto_ack, thread_num))
            # Return the function so the decorated name stays bound to it.
            return func

        return decorator

    def init_app(self, app=None):
        """Remember the application object."""
        self.app = app

3. 初始化rabbitmq

# `queue` is the Queue() registry whose @queue(...) decorators declare consumers.
rcv = FlaskRabbitMQ(queue=queue, heartbeat=15)
# consume_mq() calls rpc.init_app(app), so binding to the `current_app` proxy
# at import time is unnecessary.
rpc = FlaskRabbitMQ()


def consume_mq(app):
    """Bind both clients to ``app``, declare queues and bindings, then start consuming.

    ``rcv.run()`` installs SIGINT/SIGTERM handlers, and Python only allows that
    from the main thread, so call this function from the main thread.
    """
    rpc.init_app(app)
    rcv.init_app(app)

    # Queues to declare (queue_declare's defaults: non-durable).
    queue_names = [
        'test',
    ]
    for queue_name in queue_names:
        rcv.queue_declare(queue_name=queue_name)

    # (exchange, queue, routing_key) triples, bound on direct exchanges.
    bindings = [
        ('record_exchange', 'test', 'test'),
    ]
    for exchange_name, queue_name, routing_key in bindings:
        rcv.exchange_bind_to_queue(
            exchange_name=exchange_name, type=ExchangeType.DIRECT, queue=queue_name, routing_key=routing_key
        )
    rcv.run()

4. 定义消费者

@queue(queue='test', type=ExchangeType.DIRECT, exchange='record_exchange',
       routing_key='test', version=1, thread_num=15)
def prop_code_signal_callback(ch, method, props, body):
    """Handle one JSON message from the ``test`` queue and acknowledge it.

    auto_ack is left at False, so every message must be acked or nacked here.
    """
    try:
        data = json.loads(body)
        loger.info(f'prop_code_signal -> data:{data}')

        # Log which worker thread handled the message.
        current_thread = threading.current_thread().name
        loger.info(f'prop_code_signal ->当前线程名称: {current_thread}')

        # Business logic goes here.
        # Processing succeeded: acknowledge the message.
        ch.basic_ack(delivery_tag=method.delivery_tag)

    except Exception as e:
        loger.error(f"test出现异常: {e}")
        # Reject the message. Left unacked, it would hold this consumer's single
        # prefetch slot (prefetch_count=1) and block all further deliveries.
        # requeue=False avoids redelivering a message that will fail again
        # (it is dropped, or dead-lettered if the queue has a DLX).
        ch.basic_nack(delivery_tag=method.delivery_tag, requeue=False)

5. 推送消息

send_body = dict(test_id=test_id)
# Every FlaskRabbitMQ instance keeps its own per-thread connection, and
# send_expire only closes the channel. A throwaway instance must close its
# connection, or each push leaks a broker connection. Reusing one long-lived
# instance avoids reconnecting on every push.
rpc = FlaskRabbitMQ(app=current_app)
try:
    rpc.send_expire(exchange='record_exchange', key='test', body=json.dumps(send_body))
finally:
    rpc.close_rabbitmq_connection()