Caffe2 - C++ API
A deep-learning, cross-platform ML framework
tbb_init_test.cpp
1 #include <ATen/ATen.h>
2 #include <ATen/Parallel.h>
3 #include <ATen/test/test_assert.h>
4 #include <thread>
5 
6 
7 // This checks whether threads can see the global
8 // numbers of threads set and also whether the scheduler
9 // will throw an exception when multiple threads call
10 // their first parallel construct.
11 void test(int given_num_threads) {
12  auto t = at::ones({1000 * 1000}, at::CPU(at::kFloat));
13  if (given_num_threads >= 0) {
14  ASSERT(at::get_num_threads() == given_num_threads);
15  } else {
16  ASSERT(at::get_num_threads() == -1);
17  }
18  auto t_sum = t.sum();
19  for (int i = 0; i < 1000; i ++) {
20  t_sum = t_sum + t.sum();
21  }
22 }
23 
24 int main() {
25  at::manual_seed(123);
26 
27  test(-1);
28  std::thread t1(test, -1);
29  t1.join();
30  at::set_num_threads(4);
31  std::thread t2(test, 4);
32  std::thread t3(test, 4);
33  std::thread t4(test, 4);
34  t4.join();
35  t3.join();
36  t2.join();
37  at::set_num_threads(5);
38  test(5);
39 
40  return 0;
41 }
Definition: module.cpp:17