2 #include <ATen/Parallel.h> 3 #include <ATen/test/test_assert.h> 11 void test(
int given_num_threads) {
12 auto t = at::ones({1000 * 1000}, at::CPU(at::kFloat));
13 if (given_num_threads >= 0) {
14 ASSERT(at::get_num_threads() == given_num_threads);
16 ASSERT(at::get_num_threads() == -1);
19 for (
int i = 0; i < 1000; i ++) {
20 t_sum = t_sum + t.sum();
28 std::thread t1(
test, -1);
30 at::set_num_threads(4);
31 std::thread t2(
test, 4);
32 std::thread t3(
test, 4);
33 std::thread t4(
test, 4);
37 at::set_num_threads(5);