1 #include <gtest/gtest.h> 3 #include <torch/cuda.h> 8 std::string add_negative_flag(
const std::string& flag) {
9 std::string filter = ::testing::GTEST_FLAG(filter);
10 if (filter.find(
'-') == std::string::npos) {
11 filter.push_back(
'-');
13 filter.push_back(
':');
19 int main(
int argc,
char* argv[]) {
20 ::testing::InitGoogleTest(&argc, argv);
21 if (!torch::cuda::is_available()) {
22 std::cout <<
"CUDA not available. Disabling CUDA and MultiCUDA tests" 24 ::testing::GTEST_FLAG(filter) = add_negative_flag(
"*_CUDA:*_MultiCUDA");
25 }
else if (torch::cuda::device_count() < 2) {
26 std::cout <<
"Only one CUDA device detected. Disabling MultiCUDA tests" 28 ::testing::GTEST_FLAG(filter) = add_negative_flag(
"*_MultiCUDA");
31 return RUN_ALL_TESTS();