This example builds a simple thread pool out of std::packaged_task and lambda expressions.
The core of a thread pool is a group of background worker threads plus a threadsafe_queue that tasks are continually submitted to; the workers keep pulling tasks off the queue and executing them.
This is a deliberately simple pool with no extras such as work stealing; the following two sections should cover those.
The code comes from the book C++ Concurrency in Action, but the source that ships with the book has numerous errors and will not run as-is.
This example is a working sample.
The code is as follows.
conanfile.txt
[requires]
boost/1.72.0
[generators]
cmake
CMakeLists.txt
cmake_minimum_required(VERSION 3.3)
project(9_3_parallel_accumulate_thread_pool)

set(ENV{PKG_CONFIG_PATH} "$ENV{PKG_CONFIG_PATH}:/usr/local/lib/pkgconfig/")
set(CMAKE_CXX_FLAGS "-pthread")
set(CMAKE_CXX_STANDARD 17)
add_definitions(-g)

include(${CMAKE_BINARY_DIR}/conanbuildinfo.cmake)
conan_basic_setup()

include_directories(${INCLUDE_DIRS})
link_directories(${LINK_DIRS})

file(GLOB main_file_list ${CMAKE_CURRENT_SOURCE_DIR}/*.cpp)
foreach(main_file ${main_file_list})
    file(RELATIVE_PATH filename ${CMAKE_CURRENT_SOURCE_DIR} ${main_file})
    string(REPLACE ".cpp" "" file ${filename})
    add_executable(${file} ${main_file})
    target_link_libraries(${file} ${CONAN_LIBS} pthread)
endforeach(main_file ${main_file_list})
threadsafe_queue.hpp
#ifndef _FREDRIC_THREAD_SAFE_QUEUE_HPP_
#define _FREDRIC_THREAD_SAFE_QUEUE_HPP_
#include <condition_variable>
#include <memory>
#include <mutex>
template <typename T>
class threadsafe_queue {
private:
    // Singly linked list with a dummy node: head always points at the dummy
    // when the queue is empty, so head and tail can be guarded by separate
    // mutexes and push/pop can proceed concurrently.
    struct node {
        std::shared_ptr<T> data;
        std::unique_ptr<node> next;
    };

    std::mutex head_mutex;
    std::mutex tail_mutex;
    std::unique_ptr<node> head;
    node* tail;
    std::condition_variable data_cond;

    node* get_tail() {
        std::lock_guard<std::mutex> tail_lock(tail_mutex);
        return tail;
    }

    std::unique_ptr<node> pop_head() {
        std::unique_ptr<node> old_head = std::move(head);
        head = std::move(old_head->next);
        return old_head;
    }

    std::unique_lock<std::mutex> wait_for_data() {
        std::unique_lock<std::mutex> head_lock(head_mutex);
        // The queue is empty exactly when head is the dummy node, i.e. head == tail.
        data_cond.wait(head_lock, [&]() {
            return head.get() != get_tail();
        });
        return std::move(head_lock);
    }

    std::unique_ptr<node> wait_pop_head() {
        std::unique_lock<std::mutex> head_lock(wait_for_data());
        return pop_head();
    }

    std::unique_ptr<node> wait_pop_head(T& value) {
        std::unique_lock<std::mutex> head_lock(wait_for_data());
        value = std::move(*head->data);
        return pop_head();
    }

    std::unique_ptr<node> try_pop_head() {
        std::lock_guard<std::mutex> head_lock(head_mutex);
        if(head.get() == get_tail()) {
            return std::unique_ptr<node>();
        }
        return pop_head();
    }

    std::unique_ptr<node> try_pop_head(T& value) {
        std::lock_guard<std::mutex> head_lock(head_mutex);
        if(head.get() == get_tail()) {
            return std::unique_ptr<node>();
        }
        value = std::move(*head->data);
        return pop_head();
    }

public:
    threadsafe_queue():
        head(new node), tail(head.get()) {}
    threadsafe_queue(threadsafe_queue const&) = delete;
    threadsafe_queue& operator=(threadsafe_queue const&) = delete;

    void push(T new_value) {
        std::shared_ptr<T> new_data(std::make_shared<T>(std::move(new_value)));
        std::unique_ptr<node> p(new node);
        {
            // Fill the current dummy node and append a fresh dummy behind it.
            std::lock_guard<std::mutex> tail_lock(tail_mutex);
            tail->data = new_data;
            node* const new_tail = p.get();
            tail->next = std::move(p);
            tail = new_tail;
        }
        data_cond.notify_one();
    }

    std::shared_ptr<T> wait_and_pop() {
        std::unique_ptr<node> const old_head = wait_pop_head();
        return old_head->data;
    }

    void wait_and_pop(T& value) {
        wait_pop_head(value);
    }

    bool empty() {
        std::lock_guard<std::mutex> head_lock(head_mutex);
        return (head.get() == get_tail());
    }

    std::shared_ptr<T> try_pop() {
        std::unique_ptr<node> old_head = try_pop_head();
        return old_head ? old_head->data : std::shared_ptr<T>();
    }

    bool try_pop(T& value) {
        std::unique_ptr<node> old_head = try_pop_head(value);
        return old_head != nullptr;
    }
};
#endif
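Before wiring the queue into the pool, it can be exercised on its own. The following is a minimal sketch of my own (not part of the book's code) with one producer and one consumer:

#include "threadsafe_queue.hpp"
#include <iostream>
#include <thread>

int main() {
    threadsafe_queue<int> q;
    std::thread producer([&q]() {
        for(int i = 0; i < 5; ++i) {
            q.push(i);
        }
    });
    std::thread consumer([&q]() {
        for(int i = 0; i < 5; ++i) {
            int value = 0;
            q.wait_and_pop(value);  // blocks until the producer has pushed
            std::cout << "popped " << value << std::endl;
        }
    });
    producer.join();
    consumer.join();
    return 0;
}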
thread_pool.hpp
#ifndef _FREDRIC_THREAD_POOL_HPP_
#define _FREDRIC_THREAD_POOL_HPP_
#include "thread_safe_queue.hpp"
#include <thread>
#include <vector>
#include <atomic>
#include <functional>
#include <utility>
#include <future>
#include <utility>
#include <functional>
#include <memory>
// RAII guard: joins every owned thread on destruction.
struct join_threads {
    std::thread& operator[](int index) {
        return threads[index];
    }

    void add_thread(std::thread&& thread) {
        threads.emplace_back(std::move(thread));
    }

    ~join_threads() {
        for(std::thread& thread: threads) {
            if(thread.joinable()) {
                thread.join();
            }
        }
    }

private:
    std::vector<std::thread> threads;
};
// Type-erased, move-only callable holder. std::function requires a copyable
// target, but std::packaged_task is move-only, so the work queue stores
// function_wrapper instead.
class function_wrapper {
    struct impl_base {
        virtual void call() = 0;
        virtual ~impl_base() {}
    };

    template <typename F>
    struct impl_type: impl_base {
        F f;
        impl_type(F&& f_): f(std::move(f_)) {}
        void call() {
            f();
        }
    };

    std::unique_ptr<impl_base> impl;

public:
    function_wrapper() {}

    // What this wrapper wraps here is a packaged_task.
    template <typename F>
    function_wrapper(F&& f):
        impl(new impl_type<F>(std::move(f))) {}

    void call() {
        impl->call();
    }

    function_wrapper(function_wrapper&& other): impl(std::move(other.impl)) {}
    function_wrapper& operator=(function_wrapper&& other) {
        impl = std::move(other.impl);
        return *this;
    }

    function_wrapper(function_wrapper const&) = delete;
    function_wrapper(function_wrapper&) = delete;
    function_wrapper& operator=(function_wrapper const&) = delete;
};
class thread_pool {
    // Declaration order matters: done and work_queue must be declared before
    // joiner, so the worker threads are joined before either one is destroyed.
    std::atomic<bool> done;
    threadsafe_queue<function_wrapper> work_queue;
    join_threads joiner;

    void work_thread() {
        while(!done) {
            function_wrapper task;
            if(work_queue.try_pop(task)) {
                task.call();
            } else {
                // Nothing to do; give up the time slice instead of spinning hard.
                std::this_thread::yield();
            }
        }
    }

public:
    thread_pool():
        done(false) {
        unsigned const thread_count = std::thread::hardware_concurrency();
        try {
            for(unsigned i=0; i<thread_count; ++i) {
                joiner.add_thread(std::thread(&thread_pool::work_thread, this));
            }
        } catch(...) {
            // If thread creation fails, stop the workers already started.
            done = true;
            throw;
        }
    }

    ~thread_pool() {
        done = true;
    }

    // Note: std::result_of is deprecated in C++17 and removed in C++20;
    // under a newer standard, use std::invoke_result instead.
    template <typename FunctionType>
    std::future<typename std::result_of<FunctionType()>::type> submit(FunctionType f) {
        typedef typename std::result_of<FunctionType()>::type result_type;
        std::packaged_task<result_type()> task(std::move(f));
        std::future<result_type> res = task.get_future();
        work_queue.push(std::move(task));
        return res;
    }
};
#endif
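A quick smoke test of submit, again a sketch of my own rather than the book's code, could look like this:

#include "thread_pool.hpp"
#include <iostream>

int main() {
    thread_pool pool;
    // submit moves the lambda into a packaged_task and returns its future.
    std::future<int> f = pool.submit([]() {
        return 6 * 7;
    });
    std::cout << f.get() << std::endl;  // prints 42
    return 0;
}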
main.cpp
#include "thread_pool.hpp"
#include <iostream>
#include <algorithm>
#include <numeric>
template <typename Iterator, typename T>
struct accumulate_block {
    T operator()(Iterator first, Iterator last) {
        return std::accumulate(first, last, T());
    }
};

template <typename Iterator, typename T>
T parallel_accumulate(Iterator first, Iterator last, T init) {
    unsigned long const length = std::distance(first, last);
    if(!length) {
        return init;
    }

    unsigned long const block_size = 25;
    unsigned long const num_blocks = (length + block_size - 1)/block_size;
    std::vector<std::future<T>> futures(num_blocks - 1);
    thread_pool pool;

    Iterator block_start = first;
    for(unsigned long i=0; i<num_blocks - 1; ++i) {
        Iterator block_end = block_start;
        std::advance(block_end, block_size);
        // Capture by value here; capturing block_start/block_end by reference
        // would race with the next loop iteration.
        futures[i] = pool.submit([=]() {
            return accumulate_block<Iterator, T>()(block_start, block_end);
        });
        block_start = block_end;
    }
    // The calling thread handles the final (possibly short) block itself.
    T last_result = accumulate_block<Iterator, T>()(block_start, last);

    T result = init;
    for(unsigned long i=0; i<num_blocks - 1; ++i) {
        result += futures[i].get();
    }
    result += last_result;
    return result;
}

int main(int argc, char* argv[]) {
    std::vector<int> v(100);
    for(std::size_t i=0; i<100; ++i) {
        v[i] = i+1;
    }
    int res = parallel_accumulate(v.begin(), v.end(), 0);
    std::cout << "1 + 2 + ... + 100 = " << res << std::endl;
    return EXIT_SUCCESS;
}
This example uses the thread pool to run the computation 1 + 2 + 3 + ... + 100 concurrently.
You can run other workloads the same way: split the work into batches, then submit each batch to the pool for execution.
Each submitted task hands back a std::future that is used to retrieve its result once it has run; the sketch below applies the same pattern to a different task.
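For instance, here is a hypothetical per-block maximum using the same split-submit-collect approach (my own sketch, not from the book):

#include "thread_pool.hpp"
#include <algorithm>
#include <iostream>
#include <vector>

int main() {
    std::vector<int> v{3, 1, 4, 1, 5, 9, 2, 6, 5, 3, 5};
    thread_pool pool;
    std::size_t const block_size = 4;
    std::vector<std::future<int>> futures;

    // Split the range into blocks and submit one task per block.
    for(std::size_t start = 0; start < v.size(); start += block_size) {
        std::size_t end = std::min(start + block_size, v.size());
        futures.push_back(pool.submit([&v, start, end]() {
            return *std::max_element(v.begin() + start, v.begin() + end);
        }));
    }

    // Combine the per-block results through the futures.
    int result = v[0];
    for(auto& f: futures) {
        result = std::max(result, f.get());
    }
    std::cout << "max = " << result << std::endl;  // prints 9
    return 0;
}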
The output of running main.cpp is shown in the figure below.