Caffe2 - C++ API
A deep learning, cross-platform ML framework
main.cpp
1 #include <gtest/gtest.h>
2 
3 #include <torch/cuda.h>
4 
5 #include <iostream>
6 #include <string>
7 
8 std::string add_negative_flag(const std::string& flag) {
9  std::string filter = ::testing::GTEST_FLAG(filter);
10  if (filter.find('-') == std::string::npos) {
11  filter.push_back('-');
12  } else {
13  filter.push_back(':');
14  }
15  filter += flag;
16  return filter;
17 }
18 
19 int main(int argc, char* argv[]) {
20  ::testing::InitGoogleTest(&argc, argv);
21  if (!torch::cuda::is_available()) {
22  std::cout << "CUDA not available. Disabling CUDA and MultiCUDA tests"
23  << std::endl;
24  ::testing::GTEST_FLAG(filter) = add_negative_flag("*_CUDA:*_MultiCUDA");
25  } else if (torch::cuda::device_count() < 2) {
26  std::cout << "Only one CUDA device detected. Disabling MultiCUDA tests"
27  << std::endl;
28  ::testing::GTEST_FLAG(filter) = add_negative_flag("*_MultiCUDA");
29  }
30 
31  return RUN_ALL_TESTS();
32 }