Bootstrap

基于C++11实现的手写线程池

在实际的项目中,使用线程池是非常广泛的,所以最近学习了线程池的开发,在此做一个总结。
源码:https://github.com/Cheeron955/Handwriting-threadpool-based-on-C-17

项目介绍

项目分为两个部分,在初版的时候,使用了C++11中的知识,自己实现了Any类,Semaphore类以及Result类的开发,代码比较绕,但是有很多细节是值得学习的;最终版使用了C++17提供的future类,使得代码轻量化。接下来先看初版:

test.cpp

先从test.cpp开始剖析项目构造:

#include <iostream>
#include <chrono>
#include <thread>

#include "threadpool.h"

using ULong = unsigned long long;

class MyTask : public Task
{
public:
	MyTask(int begin, int end)
		:begin_(begin)
		,end_(end)
	{
		
	}

	Any run() 
	{
		std::cout << "tid:" << std::this_thread::get_id() << "begin!" << std::endl;
		std::this_thread::sleep_for(std::chrono::seconds(3));
		ULong sum = 0;
		for (ULong i = begin_; i < end_; i++)
		{
			sum += i;
		}

		std::cout << "tid:" << std::this_thread::get_id() << "end!" << std::endl;
		return sum;
	}
private:
	int begin_;
	int end_;
};
int main()
{
	{
		ThreadPool pool;
		pool.setMode(PoolMode::MODE_CACHED);
		pool.start(4);

		Result res1 = pool.submitTask(std::make_shared<MyTask>(1, 1000000));
		Result res2 = pool.submitTask(std::make_shared<MyTask>(1, 1000000));
		pool.submitTask(std::make_shared<MyTask>(1, 1000000));
		pool.submitTask(std::make_shared<MyTask>(1, 1000000));
		pool.submitTask(std::make_shared<MyTask>(1, 1000000));
		pool.submitTask(std::make_shared<MyTask>(1, 1000000));
		ULong sum1 = res1.get().cast_<ULong>();
		std::cout << sum1 << std::endl;
	}
	
	std::cout << "main over!" << std::endl;
	getchar();

}
  • 在main函数中,创建了一个ThreadPool对象,进入ThreadPool中:

ThreadPool

重要成员变量

std::unordered_map<int, std::unique_ptr<Thread>> threads_;

//初始的线程数量 
int initThreadSize_;

//记录当前线程池里面线程的总数量
std::atomic_int curThreadSize_;

//线程数量上限阈值
int threadSizeThresHold_;

//记录空闲线程的数量
std::atomic_int idleThreadSize_;

//任务队列
std::queue<std::shared_ptr<Task>> taskQue_;

//任务数量 需要保证线程安全
std::atomic_int taskSize_;

//任务队列数量上限阈值
int taskQueMaxThresHold_;

//任务队列互斥锁,保证任务队列的线程安全
std::mutex taskQueMtx_;

//表示任务队列不满
std::condition_variable notFull_;

//表示任务队列不空
std::condition_variable notEmpty_;

//等待线程资源全部回收
std::condition_variable exitCond_;

//当前线程池的工作模式
PoolMode poolMode_;

//表示当前线程池的启动状态
std::atomic_bool isPoolRuning_;

const int TASK_MAX_THRESHHOLD = INT32_MAX;
const int THREAD_MAX_THRESHHOLD = 100;
const int THREAD_MAX_IDLE_TIME = 60; //60s
  • 具体含义请看代码中注释

重要成员函数

  1. 构造函数
//线程池构造
ThreadPool::ThreadPool()
	: initThreadSize_(4)
	, taskSize_(0)
	, idleThreadSize_(0)
	, curThreadSize_(0)
	, threadSizeThresHold_(THREAD_MAX_THRESHHOLD)
	, taskQueMaxThresHold_(TASK_MAX_THRESHHOLD)
	, poolMode_(PoolMode::MODE_FIXED)
	, isPoolRuning_(false)
{
}
  • 进行了一系列的初始化,包括线程数量,阙值等等
  1. 析构函数
ThreadPool::~ThreadPool()
{
	isPoolRuning_ = false;
	//notEmpty_.notify_all();//把等待的叫醒 进入阻塞 会死锁

	std::unique_lock<std::mutex> lock(taskQueMtx_);
	//等待线程池里面所有的线程返回用户调用ThreadPool退出 两种状态:阻塞 正在执行任务中
	notEmpty_.notify_all();//把等待的叫醒 进入阻塞
	exitCond_.wait(lock, [&]()->bool {return threads_.size() == 0; });
}
  • 析构函数中,主要是回收线程池的资源,但是这里要注意notEmpty_.notify_all();位置,如果在获得锁之前就唤醒,可能会发生死锁问题,这个在下面还会在提到。
  1. 设置线程池的工作模式
void ThreadPool::setMode(PoolMode mode)
{
	if (checkRunningState()) return;
	poolMode_ = mode;
}
  1. 设置task任务队列上限阈值
void ThreadPool::setTaskQueMaxThreshHold(int threshhold)
{
	if (checkRunningState()) return;
	taskQueMaxThresHold_ = threshhold;
}
  1. 设置线程池的工作模式,支持fixed以及cached模式
enum class PoolMode
{
	MODE_FIXED, //固定数量的线程
	MODE_CACHED, //线程数量可动态增长
};

void ThreadPool::setMode(PoolMode mode)
{
	if (checkRunningState()) return;
	poolMode_ = mode;
}
  1. 设置task任务队列上限阈值
void ThreadPool::setTaskQueMaxThreshHold(int threshhold)
{
	if (checkRunningState()) return;
	taskQueMaxThresHold_ = threshhold;
}
  1. 设置线程池cached模式下线程阈值
void ThreadPool::setThreadSizeThreshHold(int threshhold)
{
	if (checkRunningState()) return;

	if (poolMode_ == PoolMode::MODE_CACHED)
	{
		threadSizeThresHold_ = threshhold;
	}
}
  1. 给线程池提交任务,这是重中之重,用来生产任务
Result ThreadPool::submitTask(std::shared_ptr<Task> sp)
{
	//获取锁
	std::unique_lock<std::mutex> lock(taskQueMtx_);

	//线程通信 等待任务队列有空余 并且用户提交任务最长不能阻塞超过1s 否则判断提交失败,返回
	if(!notFull_.wait_for(lock, std::chrono::seconds(1),
		[&]()->bool {return taskQue_.size() < (size_t)taskQueMaxThresHold_; }))
	{ 
		
		std::cerr << "task queue is full,submit task fail." << std::endl;
		return Result(sp, false);
	}

	//如果有空余,把任务放入任务队列中
	taskQue_.emplace(sp);
	taskSize_++;

	notEmpty_.notify_all();

	if (poolMode_ == PoolMode::MODE_CACHED 
		&& taskSize_>idleThreadSize_ 
		&& curThreadSize_ < threadSizeThresHold_)
	{

		std::cout << ">>> create new thread" << std::endl;

		//创建thread线程对象
		auto ptr = std::make_unique<Thread>(std::bind(&ThreadPool::threadFunc, this, std::placeholders::_1));
		//threads_.emplace_back(std::move(ptr)); //资源转移
		int threadId = ptr->getId();
		threads_.emplace(threadId, std::move(ptr));
		threads_[threadId]->start(); //启动线程

		//修改线程个数相关的变量
		curThreadSize_++;
		idleThreadSize_++;
	}

	//返回任务的Result对象
	return Result(sp);
}
  • 在submitTask函数中,首先这是生产任务的函数,所以我们要保证线程安全,获取锁;
  • 考虑到了如果有耗时严重的任务一直占用,线程,导致提交任务一直失败,所以等待1s提交失败以后会通知用户;
  • 此时队列里面的任务没有超过阙值,就把任务放入任务队列中,更新任务数;
  • 因为新放了任务,任务队列不空了,在notEmpty_上进行通知,赶快分配线程执行任务;
  • cached模式下,需要根据任务数量和空闲线程的数量,判断是否需要创建新的线程出来,如果任务数大于现有的空闲线程数并且没有超过阙值,就增加线程,修改相关数量;
  • 返回任务的Result对象
  1. 开启线程池
void ThreadPool::start(int initThreadSize)
{
	//设置线程池的运行状态
	isPoolRuning_ = true;

	//记录初始线程个数
	initThreadSize_ = initThreadSize;
	curThreadSize_ = initThreadSize;

	//创建线程对象
	for (int i = 0; i < initThreadSize_; i++)
	{
		auto ptr = std::make_unique<Thread>(std::bind(&ThreadPool::threadFunc, this,std::placeholders::_1));
		
		int threadId = ptr->getId();
		threads_.emplace(threadId, std::move(ptr));
	}
	 
	//启动所有线程 std::vector<Thread*> threads_;
	for (int i = 0; i < initThreadSize_; i++)
	{
		threads_[i]->start(); //需要执行一个线程函数

		//记录初始空闲线程的数量
		idleThreadSize_++;
	}
}
  • 设置线程池的运行状态,如果线程在运行状态了,之前所有的设置相关的函数都不能运行了,记录初始相关数量
  • 创建线程对象,把线程函数threadFunc给到thread线程对象,使用绑定器,获取线程id,方便回收线程资源;
  • 加入线程列表std::unordered_map<int, std::unique_ptr<Thread>> 类型;
  • 启动所有线程,执行线程函数,threadFunc
void Thread::start()
{
	std::thread t(func_,threadId_);
	t.detach();
}
  1. 线程函数,从任务队列里面消费任务
void ThreadPool::threadFunc(int threadid) //线程函数返回,相应的线程就结束了
{
	auto lastTime = std::chrono::high_resolution_clock().now();

	for(;;)
	{
		std::shared_ptr<Task> task;
		
		{
			//获取锁
			std::unique_lock<std::mutex> lock(taskQueMtx_);

			std::cout << "tid:" << std::this_thread::get_id() 
				<< "尝试获取任务..." << std::endl;

				while ( taskQue_.size() == 0 )
				{

					if (!isPoolRuning_)
					{
						threads_.erase(threadid);
						std::cout << "threadid:" << std::this_thread::get_id()
							<< "exit!" << std::endl;

						//通知主线程线程被回收了,再次查看是否满足条件
						exitCond_.notify_all();
						return;
					}

					if (poolMode_ == PoolMode::MODE_CACHED)
					{	//超时返回std::cv_status::timeout
						if (std::cv_status::timeout ==
							notEmpty_.wait_for(lock, std::chrono::seconds(1)))
						{
							auto now = std::chrono::high_resolution_clock().now();
							auto dur = std::chrono::duration_cast<std::chrono::seconds>(now - lastTime);
							if (dur.count() >= THREAD_MAX_IDLE_TIME
								&& curThreadSize_ > initThreadSize_)
							{

								threads_.erase(threadid);
								curThreadSize_--;
								idleThreadSize_--;

								std::cout << "threadid:" << std::this_thread::get_id()
									<< "exit!" << std::endl;

								return;
							}
						}
					}
					else
					{
						//等待notEmpty_条件
						notEmpty_.wait(lock);
					}

					/*if (!isPoolRuning_)
					{
						threads_.erase(threadid);
						std::cout << "threadid:" << std::this_thread::get_id()
							<< "exit!" << std::endl;

						exitCond_.notify_all();
						return;
					}*/
				}
			
			idleThreadSize_--;

			std::cout << "tid:" << std::this_thread::get_id()
				<< "获取任务成功..." << std::endl;

			//从任务队列中取一个任务出来
			task = taskQue_.front();
			taskQue_.pop();
			taskSize_--;	

			//若依然有剩余任务,继续通知其他线程执行任务
			if (taskQue_.size() > 0)
			{
				notEmpty_.notify_all();
			}

			notFull_.notify_all();

		}//释放锁,使其他线程获取任务或者提交任务

		if (task != nullptr)
		{
			task->exec();
		}

		
		idleThreadSize_++;
		
		auto lastTime = std::chrono::high_resolution_clock().now();
	}
}
  • 获取任务开始的时间,便于在cached模式下,判断是否需要回收线程
  • 创造一个Task类,获取锁
class Task
{
public:

	Task();
	~Task()=default;

	void exec();

	void setResult(Result*res);

	//用户可以自定义任意任务类型,从Task继承,重写run方法,实现自定义任务处理
	virtual Any run() = 0;
private:
	Result* result_; //Result的生命周期》Task的
};
  • 如果此时任务队列里没有任务,并且主函数退出了,此时会在ThreadPool析构中设置isPoolRuning_为false,这时候就该回收线程资源了,并通知析构函数是否满足条件;
  • 如果isPoolRuning_为ture,但是在cached模式下,根据当前时间和上一次线程使用时间,判断有没有超过60s,如果超过了,并且当前线程数大于初始定义,说明不需要那么多线程了就需要回收线程资源;
  • 如果不在cached模式,就阻塞等待任务队列里面有任务
  • 获取成功任务,取出,如果队列里面还有任务,继续通知。并且取完任务,消费了一个任务 进行通知可以继续提交生产任务了,释放锁,使其他线程获取任务或者提交任务;
  • 执行任务,把任务的返回值通过setVal方法给到Result;
  • 线程处理完了,更新线程执行完任务调度的时间
  1. 检查线程池状态
bool ThreadPool::checkRunningState() const
{
	return isPoolRuning_;
}
  1. 执行任务
void Task::exec()
{
	if (result_ != nullptr)
	{
		result_->setVal(run()); //多态调用,run是用户的任务
	}
}

void Task::setResult(Result* res)
{
	result_ = res;
}
  • 把任务的返回值通过setVal方法给到Result
  1. 信号量类
class Semaphore
{
public:
	Semaphore(int limit = 0)
		:resLimit_(limit)
	{}

	~Semaphore() = default;

	void wait()
	{  
		std::unique_lock<std::mutex> lock(mtx_);
		//等待信号量有资源 没有资源的话 会阻塞当前线程
		cond_.wait(lock, [&]()->bool { return resLimit_ > 0; });
		resLimit_--;
	}

	void post()
	{
		std::unique_lock<std::mutex> lock(mtx_);
		resLimit_++;

		cond_.notify_all();
	}
private:

	int resLimit_;
	std::mutex mtx_;
	std::condition_variable cond_;

};
  • 在信号量类中使用了条件变量和互斥锁实现了信号量的实现,等待信号量资源和释放信号量资源。
  1. Any类
class Any
{
public:
	Any() = default;
	~Any() = default;

	//左值
	Any(const Any&) = delete;
	Any& operator=(const Any&) = delete;

	//右值
	Any(Any&&) = default;
	Any& operator=(Any&&) = default;

	template<typename T>


	Any(T data) :base_(std::make_unique<Derive<T>>(data))
	{}


	template<typename T>
	T cast_()
	{
		Derive<T> *pd = dynamic_cast<Derive<T>*>(base_.get());
		if (pd == nullptr)
		{
			throw "type is unmatch";
		}
		return pd->data_;
	}

private:
	//基类类型
	class Base
	{
	public:
		virtual ~Base() = default;
	};

	//派生类类型
	template<typename T>//模板
	class Derive :public Base
	{
	public:
		Derive(T data) : data_(data)
		{}
		T data_; //保存了任意的其他类型
	};

private:
	//定义一个基类指针,基类指针可以指向派生类对象
	std::unique_ptr<Base> base_;
};
  • 定义了一个基类Base
  • 定义了一个模板类的派生类类型,继承Base,其中保存了任意的其他类型;
  • 对象包在派生类对象里面,通过基类指针指向派生类对象,构造函数可以让Any类型接收任意其他的数据类型,用户就可以使用任意期望的类型;
  • cast_()方法把Any对象里面存储的data数据提取出来,基类指针指向 派生类指针 ,使用强转dynamic_cast将基类指针或引用转换为派生类指针或引用,获取了指向的Derive对象,然后提取出来data_;
  1. Result方法的实现
class Result
{
public:

	Result(std::shared_ptr<Task> task, bool isValid = true);
	~Result() = default;

	//setVal方法,获取任务执行完的返回值
	void setVal(Any any);

	//用户调用get方法,获取task的返回值
	Any get();
private:
	//存储任务的返回值
	Any any_;

	//线程通信信号量
	Semaphore sem_;

	//指向对应获取返回值的任务对象
	std::shared_ptr<Task> task_;

	//返回值是否有效
	std::atomic_bool isValid_;
};


Result::Result(std::shared_ptr<Task> task, bool isValid)
		:isValid_(isValid)
		,task_(task)
{
	task_->setResult(this);
}

Any Result::get()
{
	if (!isValid_)
	{
		return " ";
	}

	//task任务如果没有执行完,这里会阻塞用户的线程
	sem_.wait();
	return std::move(any_);
}

void Result::setVal(Any any)
{
	//存储task的返回值
	this->any_ = std::move(any);

	//已经获取了任务的返回值,增加信号量资源
	sem_.post();
}
  • Result 实现接受提交到线程池的task任务执行完成后的返回值类型result;
  • 设置了setVal方法,获取任务执行完的返回值和用户调用get方法,获取task的返回值,使用了信号量等到setVal设置成功,才能获取值,否则会进入阻塞;

回到test.cpp

  • 定义了一个ThreadPool对象,默认是固定的,可以修改为cached模式,然后开启线程(可以使用hardware_concurrency()获取cpu核心数量);
  • 提交任务submitTask;
  • 出 } 进行析构

举个栗子~

在cached模式,代码如上test.cpp
在这里插入图片描述
可以看到,目前四个线程,六个任务,所以创建了两个线程;
六个线程获取任务成功,然后释放资源成功;

固定线程:

int main()
{
	{
		ThreadPool pool;
		pool.start(4);

		Result res1 = pool.submitTask(std::make_shared<MyTask>(1, 1000000));
		Result res2 = pool.submitTask(std::make_shared<MyTask>(1, 1000000));
		pool.submitTask(std::make_shared<MyTask>(1, 1000000));
		pool.submitTask(std::make_shared<MyTask>(1, 1000000));
		pool.submitTask(std::make_shared<MyTask>(1, 1000000));

		ULong sum1 = res1.get().cast_<ULong>();
		std::cout << sum1 << std::endl;
	}
	
	std::cout << "main over!" << std::endl;
	getchar();
}

在这里插入图片描述
有四个线程,五个任务,11676线程获取了两次任务,最后回收线程资源。

好了~基于C++11实现的手写线程池,就到此结束了。除此之外,在GitHub上,提供了linux下的使用方法,感兴趣的小伙伴可以按照步骤实现一下 ~ 注意死锁问题!下一节会剖析基于C++17实现的手写线程池,代码会看起来很轻便,下一节见 ~
;