1 #include <c10/core/thread_pool.h> 5 ThreadPool::ThreadPool(std::size_t pool_size,
int numa_node_id)
9 available_(threads_.size()),
10 total_(threads_.size()),
11 numa_node_id_(numa_node_id) {
12 for (std::size_t i = 0; i < threads_.size(); ++i) {
13 threads_[i] = std::thread(std::bind(&ThreadPool::main_loop,
this, i));
17 ThreadPool::~ThreadPool() {
20 std::unique_lock<std::mutex> lock(mutex_);
22 condition_.notify_all();
25 for (
auto& t : threads_) {
28 }
catch (
const std::exception&) {
33 size_t ThreadPool::size()
const {
34 return threads_.size();
42 for (
auto& thread : threads_) {
43 if (thread.get_id() == std::this_thread::get_id()) {
50 void ThreadPool::run(
const std::function<
void()>& func) {
51 std::unique_lock<std::mutex> lock(mutex_);
57 condition_.notify_one();
61 std::unique_lock<std::mutex> lock(mutex_);
63 completed_.wait(lock);
67 void ThreadPool::main_loop(std::size_t index) {
70 std::unique_lock<std::mutex> lock(mutex_);
74 while (tasks_.empty() && running_) {
75 condition_.wait(lock);
88 auto tasks = tasks_.front();
97 if (tasks.run_with_id) {
102 }
catch (
const std::exception&) {
111 if (tasks_.empty() && available_ == total_) {
113 completed_.notify_one();
// Process-wide setting (initially 1) that the static `pool` initializer below
// hands to ThreadPoolRegistry()->Create(...) -- presumably as the pool size;
// confirm against the creator's signature. That initializer swaps in -1 via
// exchange(-1), and setNumThreads() treats a previous value of -1 as "pool has
// already started" and throws.
// NOTE(review): setNumThreads() stores its argument before testing for -1, so
// a rejected call also overwrites the -1 marker -- confirm this is intended.
124 std::atomic<int> num_threads{1};
125 void setNumThreads(
size_t v) {
126 if(-1 == num_threads.exchange(v)) {
127 throw std::runtime_error(
"Error: cannot set num threads after pool has started");
132 static std::shared_ptr<TaskThreadPoolBase> pool =
133 ThreadPoolRegistry()->Create(
"C10", 0, num_threads.exchange(-1),
false);
137 C10_DEFINE_SHARED_REGISTRY(
146 std::shared_ptr<TaskThreadPoolBase> createC10ThreadPool(
150 static std::shared_ptr<TaskThreadPoolBase> pool =
151 std::make_shared<ThreadPool>(pool_size);
157 C10_REGISTER_CREATOR(ThreadPoolRegistry, C10, createC10ThreadPool);
size_t numAvailable() const override
The number of available (i.e. idle) threads in this thread pool.
void waitWorkComplete()
Wait for the queue to be empty.
bool inThreadPool() const override
Check if the current thread is from the thread pool.
To register your own kernel for an operator, add the following to exactly one (!) cpp file: C10_REGISTER_KERNEL(OperatorHand...