文章目录
0. 引言
本文将介绍在 C++ 中使用 ZeroMQ 实现并行计算,并结合自定义的序列化机制,以矩阵乘法作为示例,详细解析相关实现细节。
方法:
- 将矩阵分块,每个工作者负责计算一部分的子矩阵结果。
- 生产者负责拆分任务并发送子任务(如矩阵块的操作)。
- 所有工作者完成后,生产者收集并合并结果。
1. 示例概述
整个示例主要由以下几个模块组成:
- 序列化模块 (
serialization.h
):负责将复杂的数据结构(如任务和结果)序列化为字节流,便于通过 ZeroMQ 传输。 - 任务与结果结构体 (
task_result.h
):定义了计算任务(Task
)和计算结果(Result
)的数据结构。 - 并行计算实现 (
zmq_parallel_tasks.cpp
):主程序负责生成任务,分发给多个工作线程进行计算,并收集结果。
完整代码见:zmq_parallel_tasks
2. 自定义序列化模块
不同进程或线程间需要传输复杂的数据结构。serialization.h
提供了一套通用的序列化和反序列化函数,支持多种数据类型,包括基本类型、std::pair
、std::string
、容器和元组。
2.1 关键特性
- 模板化设计:通过模板和类型特征,支持多种数据类型的序列化与反序列化。
- 类型特征判断:利用
std::is_trivially_copyable
判断类型是否可以平凡复制,优化序列化效率。 - 递归支持复杂结构:通过递归调用,实现对嵌套结构(如容器内含容器)的序列化。
2.2 代码示例
以下是序列化模块的关键代码片段:
// serialization.h
#ifndef SERIALIZE_HPP
#define SERIALIZE_HPP
#include <algorithm>
#include <cstdint>
#include <iterator>
#include <sstream>
#include <string>
#include <tuple>
#include <type_traits>
#include <utility>
namespace Serialization {
// 判断类型是否平凡可复制
template <typename T>
constexpr bool IsTriviallyCopyable = std::is_trivially_copyable<T>::value;
// 序列化基本类型
template <typename T>
void serialize(std::ostream &os, const T &val, typename std::enable_if<IsTriviallyCopyable<T>, int>::type = 0) {
os.write(reinterpret_cast<const char *>(&val), sizeof(T));
}
// 反序列化基本类型
template <typename T>
void deserialize(std::istream &is, T &val, typename std::enable_if<IsTriviallyCopyable<T>, int>::type = 0) {
is.read(reinterpret_cast<char *>(&val), sizeof(T));
}
// 序列化 std::pair
template <typename K, typename V>
void serialize(std::ostream &os, const std::pair<K, V> &val) {
serialize(os, val.first);
serialize(os, val.second);
}
// 反序列化 std::pair
template <typename K, typename V>
void deserialize(std::istream &is, std::pair<K, V> &val) {
deserialize(is, val.first);
deserialize(is, val.second);
}
// 序列化 std::string
void serialize(std::ostream &os, const std::string &val) {
const std::size_t size = val.size();
os.write(reinterpret_cast<const char *>(&size), sizeof(size));
os.write(val.data(), size);
}
// 反序列化 std::string
void deserialize(std::istream &is, std::string &val) {
std::size_t size = 0;
is.read(reinterpret_cast<char *>(&size), sizeof(size));
val.resize(size);
is.read(&val[0], size);
}
// 序列化容器
template <typename Container>
void serialize(
std::ostream &os, const Container &container,
typename std::enable_if<std::is_same<typename std::iterator_traits<typename Container::iterator>::value_type,
typename Container::value_type>::value,
int>::type = 0) {
const std::size_t size = container.size();
os.write(reinterpret_cast<const char *>(&size), sizeof(size));
for (const auto &item : container) {
serialize(os, item);
}
}
// 反序列化容器
template <typename Container>
void deserialize(
std::istream &is, Container &container,
typename std::enable_if<std::is_same<typename std::iterator_traits<typename Container::iterator>::value_type,
typename Container::value_type>::value,
int>::type = 0) {
std::size_t size = 0;
is.read(reinterpret_cast<char *>(&size), sizeof(size));
container.clear();
container.reserve(size);
for (std::size_t i = 0; i < size; ++i) {
typename Container::value_type item;
deserialize(is, item);
container.emplace_back(std::move(item));
}
}
// 序列化元组
template <typename Tuple, std::size_t... Indices>
void serializeTuple(std::ostream &os, const Tuple &tup, std::index_sequence<Indices...>) {
int dummy[] = {(serialize(os, std::get<Indices>(tup)), 0)...};
static_cast<void>(dummy);
}
template <typename... Args>
void serialize(std::ostream &os, const std::tuple<Args...> &val) {
serializeTuple(os, val, std::index_sequence_for<Args...>{});
}
// 反序列化元组
template <typename Tuple, std::size_t... Indices>
void DeserializeTuple(std::istream &is, Tuple &tup, std::index_sequence<Indices...>) {
int dummy[] = {(deserialize(is, std::get<Indices>(tup)), 0)...};
static_cast<void>(dummy);
}
template <typename... Args>
void deserialize(std::istream &is, std::tuple<Args...> &val) {
DeserializeTuple(is, val, std::index_sequence_for<Args...>{});
}
} // namespace Serialization
#endif // SERIALIZE_HPP
通过上述模板函数,任意平凡复制的类型都可以直接进行序列化和反序列化,大大简化了数据传输过程。
3. 任务与结果结构体定义
task_result.h
定义了两个主要的数据结构:Task
和 Result
,分别表示一个计算任务和其对应的计算结果。这些结构体将通过序列化模块进行编码和解码,以便在不同线程或进程间传输。
3.1 结构体定义
// task_result.h
#ifndef TASK_RESULT_H_
#define TASK_RESULT_H_
#include <cstdint>
#include <vector>
/// 表示一个计算任务(例如矩阵乘法中的一行乘法)
struct Task {
std::int32_t row_index; ///< 矩阵 A 中的行索引
std::vector<std::int32_t> row; ///< 矩阵 A 的一行
std::vector<std::vector<std::int32_t>> B; ///< 矩阵 B
};
/// 表示计算结果
struct Result {
std::int32_t row_index; ///< 矩阵 C 中的行索引
std::vector<std::int32_t> result_row; ///< 计算得到的矩阵 C 的一行
};
#endif // TASK_RESULT_H_
3.2 设计思路
- Task:包含了矩阵 A 的一行及整个矩阵 B。通过将矩阵 A 的每一行作为独立任务,多个工作线程可以并行计算。
- Result:存储了计算得到的矩阵 C 的一行及其对应的行索引,便于在主线程中汇总结果。
4. ZeroMQ 并行计算实现
zmq_parallel_tasks.cpp
是整个系统的核心,负责生成任务、分发给多个工作线程进行计算,并收集结果。ZeroMQ 提供了多种通信模式,本系统采用了 PUSH/PULL 模式,适用于任务分发和结果收集。
4.1 核心流程
-
初始化:
- 生成随机矩阵 A 和 B。
- 初始化 ZeroMQ 上下文和套接字,分别用于发送任务和接收结果。
- 启动多个工作线程,每个线程连接到任务和结果的内部套接字。
-
任务分发:
- 将矩阵 A 的每一行作为一个
Task
,序列化后通过 ZeroMQ 发送给工作线程。
- 将矩阵 A 的每一行作为一个
-
结果收集:
- 主线程接收所有
Result
,根据行索引将结果填充到矩阵 C 中。
- 主线程接收所有
-
线程管理:
- 发送停止信号 (
"STOP"
) 给所有工作线程,确保其优雅退出。 - 等待所有工作线程完成。
- 发送停止信号 (
4.2 代码结构
// zmq_parallel_tasks.cpp
#include <cstdlib> // For rand()
#include <iostream>
#include <sstream>
#include <thread>
#include <vector>
#include <zmq.hpp>
#include "serialization.h"
#include "task_result.h"
using namespace Serialization;
/// 序列化 Task 结构体
void encode(std::ostream& os, const Task& task) {
Serialization::serialize(os, task.row_index);
Serialization::serialize(os, task.row);
Serialization::serialize(os, task.B);
}
/// 反序列化 Task 结构体
void decode(std::istream& is, Task& task) {
Serialization::deserialize(is, task.row_index);
Serialization::deserialize(is, task.row);
Serialization::deserialize(is, task.B);
}
/// 序列化 Result 结构体
void encode(std::ostream& os, const Result& result) {
Serialization::serialize(os, result.row_index);
Serialization::serialize(os, result.result_row);
}
/// 反序列化 Result 结构体
void decode(std::istream& is, Result& result) {
Serialization::deserialize(is, result.row_index);
Serialization::deserialize(is, result.result_row);
}
/// 生成随机矩阵
std::vector<std::vector<std::int32_t>> generateMatrix(uint32_t rows, uint32_t cols) {
std::vector<std::vector<std::int32_t>> matrix(rows, std::vector<std::int32_t>(cols));
for (auto& row : matrix) {
for (auto& elem : row) {
elem = rand() % 10; // 随机值 0-9
}
}
return matrix;
}
/// 显示矩阵
void displayMatrix(const std::vector<std::vector<std::int32_t>>& matrix) {
for (const auto& row : matrix) {
for (const auto& elem : row) {
std::cout << elem << " ";
}
std::cout << std::endl;
}
}
/// 工作线程函数
void workerThread(zmq::context_t& context) {
try {
// 接收任务的套接字
zmq::socket_t task_receiver(context, ZMQ_PULL);
task_receiver.connect("inproc://tasks");
// 发送结果的套接字
zmq::socket_t result_sender(context, ZMQ_PUSH);
result_sender.connect("inproc://results");
while (true) {
// 接收任务消息
zmq::message_t task_msg;
task_receiver.recv(task_msg, zmq::recv_flags::none);
// 反序列化任务
std::string task_str(static_cast<char*>(task_msg.data()), task_msg.size());
// 检查停止信号
if (task_str == "STOP") {
break; // 退出线程
}
std::istringstream iss(task_str);
Task task;
decode(iss, task);
// 计算结果
Result result;
result.row_index = task.row_index;
uint32_t cols = static_cast<uint32_t>(task.B[0].size());
result.result_row.resize(cols, 0);
for (uint32_t j = 0; j < cols; ++j) {
for (size_t k = 0; k < task.row.size(); ++k) {
result.result_row[j] += task.row[k] * task.B[k][j];
}
}
// 序列化结果
std::ostringstream oss;
encode(oss, result);
std::string result_str = oss.str();
// 发送结果消息
zmq::message_t result_msg(result_str.size());
memcpy(result_msg.data(), result_str.data(), result_str.size());
result_sender.send(result_msg, zmq::send_flags::none);
}
} catch (const zmq::error_t& e) {
std::cerr << "Worker thread ZMQ 错误: " << e.what() << std::endl;
} catch (const std::exception& e) {
std::cerr << "Worker thread 异常: " << e.what() << std::endl;
}
}
int main() {
try {
// 初始化随机数种子
std::srand(static_cast<unsigned int>(std::time(nullptr)));
// 初始化 ZeroMQ 上下文
zmq::context_t context(1);
// 绑定任务发送套接字
zmq::socket_t task_sender(context, ZMQ_PUSH);
task_sender.bind("inproc://tasks");
// 绑定结果接收套接字
zmq::socket_t result_receiver(context, ZMQ_PULL);
result_receiver.bind("inproc://results");
// 启动工作线程
const uint32_t num_workers = 4; // 工作线程数
std::vector<std::thread> workers;
for (uint32_t i = 0; i < num_workers; ++i) {
workers.emplace_back(workerThread, std::ref(context));
}
// 生成矩阵 A 和 B
const uint32_t rows = 10;
const uint32_t cols = 10;
std::vector<std::vector<std::int32_t>> A = generateMatrix(rows, cols);
std::vector<std::vector<std::int32_t>> B = generateMatrix(cols, rows); // 转置以优化乘法
std::cout << "矩阵 A:" << std::endl;
displayMatrix(A);
std::cout << "矩阵 B:" << std::endl;
displayMatrix(B);
// 发送任务
for (uint32_t i = 0; i < rows; ++i) {
Task task;
task.row_index = i;
task.row = A[i];
task.B = B;
// 序列化任务
std::ostringstream oss;
encode(oss, task);
std::string task_str = oss.str();
// 发送任务消息
zmq::message_t task_msg(task_str.size());
memcpy(task_msg.data(), task_str.data(), task_str.size());
task_sender.send(task_msg, zmq::send_flags::none);
}
// 发送停止信号
for (uint32_t i = 0; i < num_workers; ++i) {
std::string stop_str = "STOP";
zmq::message_t stop_msg(stop_str.size());
memcpy(stop_msg.data(), stop_str.data(), stop_str.size());
task_sender.send(stop_msg, zmq::send_flags::none);
}
// 接收结果
std::vector<std::vector<std::int32_t>> C(rows, std::vector<std::int32_t>(cols, 0));
for (uint32_t i = 0; i < rows; ++i) {
zmq::message_t result_msg;
result_receiver.recv(result_msg, zmq::recv_flags::none);
std::string result_str(static_cast<char*>(result_msg.data()), result_msg.size());
// 反序列化结果
std::istringstream iss(result_str);
Result result;
decode(iss, result);
// 存储结果
C[result.row_index] = result.result_row;
}
std::cout << "矩阵 C (结果):" << std::endl;
displayMatrix(C);
// 等待工作线程结束
for (auto& worker : workers) {
if (worker.joinable()) {
worker.join();
}
}
return 0;
} catch (const zmq::error_t& e) {
std::cerr << "主线程 ZMQ 错误: " << e.what() << std::endl;
return 1;
} catch (const std::exception& e) {
std::cerr << "主线程 异常: " << e.what() << std::endl;
return 1;
}
}
5. 代码详解
5.1 序列化与反序列化
在 zmq_parallel_tasks.cpp
中,通过 encode
和 decode
函数,将 Task
和 Result
结构体进行序列化和反序列化。这些函数利用了 serialization.h
中定义的通用序列化机制,确保数据在不同线程间传输时的一致性和完整性。
/// 序列化 Task 结构体
void encode(std::ostream& os, const Task& task) {
Serialization::serialize(os, task.row_index);
Serialization::serialize(os, task.row);
Serialization::serialize(os, task.B);
}
/// 反序列化 Task 结构体
void decode(std::istream& is, Task& task) {
Serialization::deserialize(is, task.row_index);
Serialization::deserialize(is, task.row);
Serialization::deserialize(is, task.B);
}
/// 序列化 Result 结构体
void encode(std::ostream& os, const Result& result) {
Serialization::serialize(os, result.row_index);
Serialization::serialize(os, result.result_row);
}
/// 反序列化 Result 结构体
void decode(std::istream& is, Result& result) {
Serialization::deserialize(is, result.row_index);
Serialization::deserialize(is, result.result_row);
}
5.2 工作线程函数
每个工作线程执行 workerThread
函数,负责接收任务、计算结果并发送回主线程。工作线程通过 ZeroMQ 的 PULL
套接字接收任务,通过 PUSH
套接字发送结果。
void workerThread(zmq::context_t& context) {
try {
// 连接任务接收和结果发送套接字
zmq::socket_t task_receiver(context, ZMQ_PULL);
task_receiver.connect("inproc://tasks");
zmq::socket_t result_sender(context, ZMQ_PUSH);
result_sender.connect("inproc://results");
while (true) {
// 接收任务消息
zmq::message_t task_msg;
task_receiver.recv(task_msg, zmq::recv_flags::none);
// 反序列化任务
std::string task_str(static_cast<char*>(task_msg.data()), task_msg.size());
// 检查停止信号
if (task_str == "STOP") {
break; // 退出线程
}
std::istringstream iss(task_str);
Task task;
decode(iss, task);
// 计算结果
Result result;
result.row_index = task.row_index;
uint32_t cols = static_cast<uint32_t>(task.B[0].size());
result.result_row.resize(cols, 0);
for (uint32_t j = 0; j < cols; ++j) {
for (size_t k = 0; k < task.row.size(); ++k) {
result.result_row[j] += task.row[k] * task.B[k][j];
}
}
// 序列化结果
std::ostringstream oss;
encode(oss, result);
std::string result_str = oss.str();
// 发送结果消息
zmq::message_t result_msg(result_str.size());
memcpy(result_msg.data(), result_str.data(), result_str.size());
result_sender.send(result_msg, zmq::send_flags::none);
}
} catch (const zmq::error_t& e) {
std::cerr << "Worker thread ZMQ 错误: " << e.what() << std::endl;
} catch (const std::exception& e) {
std::cerr << "Worker thread 异常: " << e.what() << std::endl;
}
}
5.3 主线程逻辑
主线程负责生成任务、发送给工作线程,并收集所有结果。其主要步骤包括:
- 初始化:生成随机矩阵 A 和 B,并显示它们。
- 任务分发:将矩阵 A 的每一行封装为一个
Task
,序列化后通过 ZeroMQ 发送给工作线程。 - 发送停止信号:任务发送完毕后,发送
"STOP"
消息给所有工作线程,指示它们退出。 - 结果收集:通过 ZeroMQ 接收所有
Result
,并根据行索引将结果填充到矩阵 C 中。 - 线程管理:等待所有工作线程结束,确保程序的正常退出。
int main() {
try {
// 初始化随机数种子
std::srand(static_cast<unsigned int>(std::time(nullptr)));
// 初始化 ZeroMQ 上下文
zmq::context_t context(1);
// 绑定任务发送套接字
zmq::socket_t task_sender(context, ZMQ_PUSH);
task_sender.bind("inproc://tasks");
// 绑定结果接收套接字
zmq::socket_t result_receiver(context, ZMQ_PULL);
result_receiver.bind("inproc://results");
// 启动工作线程
const uint32_t num_workers = 4; // 工作线程数
std::vector<std::thread> workers;
for (uint32_t i = 0; i < num_workers; ++i) {
workers.emplace_back(workerThread, std::ref(context));
}
// 生成矩阵 A 和 B
const uint32_t rows = 10;
const uint32_t cols = 10;
std::vector<std::vector<std::int32_t>> A = generateMatrix(rows, cols);
std::vector<std::vector<std::int32_t>> B = generateMatrix(cols, rows); // 转置以优化乘法
std::cout << "矩阵 A:" << std::endl;
displayMatrix(A);
std::cout << "矩阵 B:" << std::endl;
displayMatrix(B);
// 发送任务
for (uint32_t i = 0; i < rows; ++i) {
Task task;
task.row_index = i;
task.row = A[i];
task.B = B;
// 序列化任务
std::ostringstream oss;
encode(oss, task);
std::string task_str = oss.str();
// 发送任务消息
zmq::message_t task_msg(task_str.size());
memcpy(task_msg.data(), task_str.data(), task_str.size());
task_sender.send(task_msg, zmq::send_flags::none);
}
// 发送停止信号
for (uint32_t i = 0; i < num_workers; ++i) {
std::string stop_str = "STOP";
zmq::message_t stop_msg(stop_str.size());
memcpy(stop_msg.data(), stop_str.data(), stop_str.size());
task_sender.send(stop_msg, zmq::send_flags::none);
}
// 接收结果
std::vector<std::vector<std::int32_t>> C(rows, std::vector<std::int32_t>(cols, 0));
for (uint32_t i = 0; i < rows; ++i) {
zmq::message_t result_msg;
result_receiver.recv(result_msg, zmq::recv_flags::none);
std::string result_str(static_cast<char*>(result_msg.data()), result_msg.size());
// 反序列化结果
std::istringstream iss(result_str);
Result result;
decode(iss, result);
// 存储结果
C[result.row_index] = result.result_row;
}
std::cout << "矩阵 C (结果):" << std::endl;
displayMatrix(C);
// 等待工作线程结束
for (auto& worker : workers) {
if (worker.joinable()) {
worker.join();
}
}
return 0;
} catch (const zmq::error_t& e) {
std::cerr << "主线程 ZMQ 错误: " << e.what() << std::endl;
return 1;
} catch (const std::exception& e) {
std::cerr << "主线程 异常: " << e.what() << std::endl;
return 1;
}
}
5.4 关键实现细节
- 任务分发:主线程将矩阵 A 的每一行封装为一个
Task
,并序列化后通过ZMQ_PUSH
套接字发送给工作线程。 - 工作线程计算:每个工作线程接收到
Task
后,进行计算,生成Result
,并通过ZMQ_PUSH
套接字发送回主线程。 - 结果收集:主线程通过
ZMQ_PULL
套接字接收所有Result
,并根据行索引将结果填充到矩阵 C 中。 - 线程终止:主线程发送
"STOP"
消息给所有工作线程,工作线程接收到停止信号后退出循环,主线程等待所有线程结束。
6. 运行示例
假设我们已经编译了上述代码并命名为 zmq_parallel_tasks
,运行程序后,输出如下:
// g++ -o zmq_parallel_tasks zmq_parallel_tasks.cpp -O2 -pthread -lzmq
$ ./zmq_parallel_tasks
Matrix A:
2 7 7 9 9 3 7 1 2 2
3 0 1 0 0 4 1 8 3 5
3 0 8 0 2 5 1 4 1 3
3 3 3 0 2 2 6 2 6 8
4 9 0 8 2 0 4 3 1 8
8 4 0 8 6 2 4 8 6 5
1 2 0 4 2 5 9 0 7 5
8 3 6 9 1 8 1 6 3 2
6 3 9 6 2 5 0 8 3 8
5 7 0 5 3 5 0 4 5 9
Matrix B:
1 6 3 8 5 4 6 8 2 2
3 8 5 2 6 7 9 6 5 5
5 0 4 5 6 7 2 8 4 8
8 5 4 1 5 1 7 4 9 0
6 2 8 1 6 5 9 8 3 6
5 8 9 9 6 7 8 8 5 2
8 5 0 2 6 5 5 6 9 5
6 5 9 6 7 8 1 8 8 7
6 3 5 5 4 1 2 2 2 0
7 0 5 7 5 4 4 0 0 4
Matrix C (Result):
287 201 233 154 278 234 305 300 265 205
137 104 161 165 144 139 91 140 109 103
139 90 158 163 158 163 115 184 112 137
201 120 158 175 189 162 161 162 131 137
219 178 177 147 215 174 236 184 193 138
281 226 269 225 281 222 270 284 248 172
225 152 150 151 186 139 189 160 174 99
241 227 258 249 263 229 248 298 242 156
267 183 281 271 282 258 227 288 218 209
226 192 245 219 233 197 245 208 166 137
程序将显示矩阵 A 和 B 的内容,以及最终计算得到的矩阵 C。由于矩阵的元素是随机生成的,每次运行结果会有所不同。