Caffe2 - C++ API
A deep learning, cross platform ML framework
thread_pool.h
1 
17 #ifndef CAFFE2_UTILS_THREAD_POOL_H_
18 #define CAFFE2_UTILS_THREAD_POOL_H_
19 
20 #include <condition_variable>
21 #include <functional>
22 #include <mutex>
23 #include <queue>
24 #include <thread>
25 #include <utility>
26 
27 namespace caffe2 {
28 
30  private:
// One queued unit of work. Exactly one of the two callables is set by
// the constructors; run_with_id tells the worker which one to invoke.
struct task_element_t {
  // True when with_id holds the task, false when no_id holds it.
  bool run_with_id;
  // The callables are intentionally non-const: const members would
  // delete move assignment and turn every move of a queue element into
  // a full copy of both std::function objects.
  std::function<void()> no_id;
  std::function<void(std::size_t)> with_id;

  // Wraps a task that takes no arguments.
  explicit task_element_t(std::function<void()> f)
      : run_with_id(false), no_id(std::move(f)), with_id(nullptr) {}

  // Wraps a task that receives a std::size_t argument when invoked.
  explicit task_element_t(std::function<void(std::size_t)> f)
      : run_with_id(true), no_id(nullptr), with_id(std::move(f)) {}
};
41  std::queue<task_element_t> tasks_;
42  std::vector<std::thread> threads_;
43  std::mutex mutex_;
44  std::condition_variable condition_;
45  std::condition_variable completed_;
46  bool running_;
47  bool complete_;
48  std::size_t available_;
49  std::size_t total_;
50 
51  public:
53  explicit TaskThreadPool(std::size_t pool_size)
54  : threads_(pool_size), running_(true), complete_(true),
55  available_(pool_size), total_(pool_size) {
56  for ( std::size_t i = 0; i < pool_size; ++i ) {
57  threads_[i] = std::thread(
58  std::bind(&TaskThreadPool::main_loop, this, i));
59  }
60  }
61 
64  // Set running flag to false then notify all threads.
65  {
66  std::unique_lock< std::mutex > lock(mutex_);
67  running_ = false;
68  condition_.notify_all();
69  }
70 
71  try {
72  for (auto& t : threads_) {
73  t.join();
74  }
75  }
76  // Suppress all exceptions.
77  catch (const std::exception&) {}
78  }
79 
81  template <typename Task>
82  void runTask(Task task) {
83  std::unique_lock<std::mutex> lock(mutex_);
84 
85  // Set task and signal condition variable so that a worker thread will
86  // wake up and use the task.
87  tasks_.push(task_element_t(static_cast<std::function< void() >>(task)));
88  complete_ = false;
89  condition_.notify_one();
90  }
91 
92  void run(const std::function<void()>& func) {
93  runTask(func);
94  }
95 
96  template <typename Task>
97  void runTaskWithID(Task task) {
98  std::unique_lock<std::mutex> lock(mutex_);
99 
100  // Set task and signal condition variable so that a worker thread will
101  // wake up and use the task.
102  tasks_.push(task_element_t(static_cast<std::function< void(std::size_t) >>(
103  task)));
104  complete_ = false;
105  condition_.notify_one();
106  }
107 
110  std::unique_lock<std::mutex> lock(mutex_);
111  while (!complete_)
112  completed_.wait(lock);
113  }
114 
115  private:
117  void main_loop(std::size_t index) {
118  while (running_) {
119  // Wait on condition variable while the task is empty and
120  // the pool is still running.
121  std::unique_lock<std::mutex> lock(mutex_);
122  while (tasks_.empty() && running_) {
123  condition_.wait(lock);
124  }
125  // If pool is no longer running, break out of loop.
126  if (!running_) break;
127 
128  // Copy task locally and remove from the queue. This is
129  // done within its own scope so that the task object is
130  // destructed immediately after running the task. This is
131  // useful in the event that the function contains
132  // shared_ptr arguments bound via bind.
133  {
134  auto tasks = tasks_.front();
135  tasks_.pop();
136  // Decrement count, indicating thread is no longer available.
137  --available_;
138 
139  lock.unlock();
140 
141  // Run the task.
142  try {
143  if (tasks.run_with_id) {
144  tasks.with_id(index);
145  } else {
146  tasks.no_id();
147  }
148  }
149  // Suppress all exceptions.
150  catch ( const std::exception& ) {}
151 
152  // Update status of empty, maybe
153  // Need to recover the lock first
154  lock.lock();
155 
156  // Increment count, indicating thread is available.
157  ++available_;
158  if (tasks_.empty() && available_ == total_) {
159  complete_ = true;
160  completed_.notify_one();
161  }
162  }
163  } // while running_
164  }
165 };
166 
167 } // namespace caffe2
168 
169 #endif
~TaskThreadPool()
Destructor.
Definition: thread_pool.h:63
void runTask(Task task)
Add a task to the thread pool's queue and wake one worker thread to run it.
Definition: thread_pool.h:82
TaskThreadPool(std::size_t pool_size)
Constructor.
Definition: thread_pool.h:53
Copyright (c) 2016-present, Facebook, Inc.
void waitWorkComplete()
Wait until the queue is empty and all worker threads are idle.
Definition: thread_pool.h:109