Caffe2 - C++ API
A deep learning, cross platform ML framework
thread_pool.h
1 #pragma once
2 
#include <atomic>
#include <condition_variable>
#include <cstddef>
#include <functional>
#include <mutex>
#include <queue>
#include <thread>
#include <utility>
#include <vector>
9 
10 #include <c10/util/Optional.h>
11 #include <c10/util/intrusive_ptr.h>
12 #include <c10/util/numa.h>
13 #include <c10/util/thread_name.h>
14 
15 namespace c10 {
16 
17 namespace ivalue {
18 struct Future;
19 } // namespace ivalue
20 
21 // TODO: move this to C10 and make it C10_API
22 class C10_API TaskThreadPoolBase {
23  public:
24  virtual void run(const std::function<void()>& func) = 0;
25 
26  virtual size_t size() const = 0;
27 
31  virtual size_t numAvailable() const = 0;
32 
36  virtual bool inThreadPool() const = 0;
37 
38  virtual ~TaskThreadPoolBase() noexcept {}
39 };
40 
41 class C10_API ThreadPool : public c10::TaskThreadPoolBase {
42  protected:
43  struct task_element_t {
44  bool run_with_id;
45  const std::function<void()> no_id;
46  const std::function<void(std::size_t)> with_id;
47 
48  explicit task_element_t(const std::function<void()>& f)
49  : run_with_id(false), no_id(f), with_id(nullptr) {}
50  explicit task_element_t(const std::function<void(std::size_t)>& f)
51  : run_with_id(true), no_id(nullptr), with_id(f) {}
52  };
53 
54  std::queue<task_element_t> tasks_;
55  std::vector<std::thread> threads_;
56  std::mutex mutex_;
57  std::condition_variable condition_;
58  std::condition_variable completed_;
59  std::atomic_bool running_;
60  bool complete_;
61  std::size_t available_;
62  std::size_t total_;
63  int numa_node_id_;
64 
65  public:
66  ThreadPool() = delete;
67 
68  explicit ThreadPool(
69  std::size_t pool_size,
70  int numa_node_id = -1);
71 
72  ~ThreadPool();
73 
74  size_t size() const override;
75 
76  size_t numAvailable() const override;
77 
78  bool inThreadPool() const override;
79 
80  void run(const std::function<void()>& func) override;
81 
82  template <typename Task>
83  void runTaskWithID(Task task) {
84  std::unique_lock<std::mutex> lock(mutex_);
85 
86  // Set task and signal condition variable so that a worker thread will
87  // wake up and use the task.
88  tasks_.push(
89  task_element_t(static_cast<std::function<void(std::size_t)>>(task)));
90  complete_ = false;
91  condition_.notify_one();
92  }
93 
95  void waitWorkComplete();
96 
97  protected:
98  virtual void init_thread() {}
99 
100  private:
101  // @brief Entry point for pool threads.
102  void main_loop(std::size_t index);
103 };
104 
// Declared here, defined out of line.
// NOTE(review): presumably sets the number of threads used by the pool behind
// global_work_queue() - confirm against the definition.
105 C10_API void setNumThreads(size_t v);
106 
// Returns a reference to a TaskThreadPoolBase; defined out of line.
// NOTE(review): presumably a process-wide shared pool - confirm its lifetime
// and thread-safety against the definition.
107 C10_API TaskThreadPoolBase& global_work_queue();
108 
109 class C10_API TaskThreadPool : public c10::ThreadPool {
110  public:
111  explicit TaskThreadPool(
112  std::size_t pool_size,
113  int numa_node_id = -1)
114  : ThreadPool(pool_size, numa_node_id) {}
115 
116  // TODO move this to ATen/core/thread_pool.h
117  void init_thread() override {
118  setThreadName("CaffeTaskThread");
119  NUMABind(numa_node_id_);
120  }
121 };
122 
123 C10_DECLARE_SHARED_REGISTRY(
124  ThreadPoolRegistry,
126  int,
127  int,
128  bool);
129 
130 } // namespace c10
To register your own kernel for an operator, do in one (!) cpp file: C10_REGISTER_KERNEL(OperatorHand...
Definition: alias_info.h:7
void NUMABind(int numa_node_id)
Bind to a given NUMA node.
Definition: numa.cpp:112