3 from common_utils
import TestCase, run_tests, skipIfRocm
18 print(
'CUDA not available, skipping tests')
def get_is_primary_context_created(device):
    """Return True if the CUDA primary context of ``device`` is active.

    The state is read with the CUDA driver call
    ``cuDevicePrimaryCtxGetState``, invoked through the ``libthnvrtc``
    shared library found in the ``lib`` directory next to ``torch``.

    Args:
        device: Integer device ordinal; it is passed to the driver as a
            ``ctypes.c_int``.

    Returns:
        bool: the ``active`` out-parameter of the driver call, as a bool.

    Raises:
        IndexError: if no ``libthnvrtc.*`` file is found in torch's ``lib``
            directory.
        AssertionError: if the driver call returns a non-zero status.
    """
    # Out-parameters filled in by the driver call.
    flags = ctypes.c_uint()
    active = ctypes.c_int()

    lib_dir = os.path.join(os.path.dirname(torch.__file__), 'lib')
    matches = glob.glob('{}/libthnvrtc.*'.format(lib_dir))
    if not matches:
        # Still an IndexError (what ``glob(...)[0]`` raised before), but with
        # a message that says what is actually missing.
        raise IndexError(
            'no libthnvrtc.* library found in {}'.format(lib_dir))
    _thnvrtc = ctypes.cdll.LoadLibrary(matches[0])

    result = _thnvrtc.cuDevicePrimaryCtxGetState(
        ctypes.c_int(device), ctypes.byref(flags), ctypes.byref(active))
    # Raised explicitly rather than via ``assert`` so that the check is not
    # stripped when Python runs with -O.
    if result != 0:
        raise AssertionError('cuDevicePrimaryCtxGetState failed')
    return bool(active.value)
38 @unittest.skipIf(
not TEST_MULTIGPU,
"only one GPU detected")
42 self.assertFalse(get_is_primary_context_created(0))
43 self.assertFalse(get_is_primary_context_created(1))
45 x = torch.randn(1, device=
'cuda:1')
48 self.assertFalse(get_is_primary_context_created(0))
49 self.assertTrue(get_is_primary_context_created(1))
54 self.assertFalse(get_is_primary_context_created(0))
55 self.assertTrue(get_is_primary_context_created(1))
57 y = torch.randn(1, device=
'cpu')
61 self.assertFalse(get_is_primary_context_created(0))
62 self.assertTrue(get_is_primary_context_created(1))
66 if __name__ ==
'__main__':