# Caffe2 - Python API
# A deep learning, cross platform ML framework
# test_cuda_primary_ctx.py
import ctypes
import glob
import os
import unittest

import torch

from common_utils import TestCase, run_tests, skipIfRocm

# NOTE: this needs to be run in a brand new process

# We cannot import TEST_CUDA and TEST_MULTIGPU from common_cuda here,
# because if we do that, the TEST_CUDNN line from common_cuda will be executed
# multiple times as well during the execution of this test suite, and it will
# cause CUDA OOM error on Windows.
TEST_CUDA = torch.cuda.is_available()
TEST_MULTIGPU = TEST_CUDA and torch.cuda.device_count() >= 2

if not TEST_CUDA:
    print('CUDA not available, skipping tests')
    # Rebind the imported TestCase name to plain `object`.
    # NOTE(review): presumably so TestCase-derived classes in this file are
    # not collected as tests when CUDA is missing -- confirm.
    TestCase = object  # noqa: F811


# Handle to the libthnvrtc shared library; stays None until
# get_is_primary_context_created() loads it on first use.
_thnvrtc = None


def get_is_primary_context_created(device):
    """Return whether the CUDA primary context of ``device`` is active.

    Calls ``cuDevicePrimaryCtxGetState`` through the ``libthnvrtc`` shared
    library found in the ``lib`` directory next to the installed ``torch``
    package. The library handle is loaded once and cached in the module-level
    ``_thnvrtc`` global.

    Args:
        device: Integer CUDA device ordinal to query.

    Returns:
        bool: ``True`` if the driver reports a non-zero "active" state for the
        device's primary context, ``False`` otherwise.

    Raises:
        OSError: If no ``libthnvrtc.*`` file is found, or it cannot be loaded.
        AssertionError: If ``cuDevicePrimaryCtxGetState`` returns a non-zero
            status code.
    """
    global _thnvrtc
    if _thnvrtc is None:
        pattern = '{}/lib/libthnvrtc.*'.format(os.path.dirname(torch.__file__))
        candidates = glob.glob(pattern)
        if not candidates:
            # Fail with a descriptive error instead of an IndexError on [0].
            raise OSError(
                'libthnvrtc not found: no file matches {!r}'.format(pattern))
        _thnvrtc = ctypes.cdll.LoadLibrary(candidates[0])

    # Out-parameters filled in by the driver call.
    flags = ctypes.c_uint()
    active = ctypes.c_int()
    result = _thnvrtc.cuDevicePrimaryCtxGetState(
        ctypes.c_int(device), ctypes.byref(flags), ctypes.byref(active))
    if result != 0:
        # Raised explicitly (rather than via `assert`) so the check is not
        # stripped when Python runs with -O.
        raise AssertionError(
            'cuDevicePrimaryCtxGetState failed with error code {}'.format(result))
    return bool(active.value)


# NOTE(review): the class header was missing from the extracted listing (the
# embedded numbering skips from 36 to 38); the class name below is
# reconstructed -- confirm against the upstream file.
class TestCudaPrimaryCtx(TestCase):
    """Checks which devices get a CUDA primary context during tensor work."""

    @unittest.skipIf(not TEST_MULTIGPU, "only one GPU detected")
    @skipIfRocm
    def test_cuda_primary_ctx(self):
        """Using only 'cuda:1' must not create a primary context on device 0.

        The context state of devices 0 and 1 is checked before any CUDA
        tensor exists, after allocating on 'cuda:1', after printing the
        tensor, and after copying it into a CPU tensor.
        """
        # Ensure context has not been created beforehand
        self.assertFalse(get_is_primary_context_created(0))
        self.assertFalse(get_is_primary_context_created(1))

        x = torch.randn(1, device='cuda:1')

        # We should have only created context on 'cuda:1'
        self.assertFalse(get_is_primary_context_created(0))
        self.assertTrue(get_is_primary_context_created(1))

        print(x)

        # We should still have only created context on 'cuda:1'
        self.assertFalse(get_is_primary_context_created(0))
        self.assertTrue(get_is_primary_context_created(1))

        y = torch.randn(1, device='cpu')
        y.copy_(x)

        # We should still have only created context on 'cuda:1'
        self.assertFalse(get_is_primary_context_created(0))
        self.assertTrue(get_is_primary_context_created(1))

    # DO NOT ADD ANY OTHER TESTS HERE! ABOVE TEST REQUIRES FRESH PROCESS


# Run the test suite only when this file is executed as a script.
if __name__ == '__main__':
    run_tests()
# Cross-references from the documentation page:
#   is_available()  -- Definition: __init__.py:45
#   device_count()  -- Definition: __init__.py:341