1 r"""Importing this file must **not** initialize CUDA context. test_distributed     2 relies on this assumption to properly run. This means that when this is imported     3 no CUDA calls shall be made, including torch.cuda.device_count(), etc.     5 common_cuda.py can freely initialize CUDA context when imported.    22 from collections 
import OrderedDict
    23 from functools 
import wraps
    24 from itertools 
import product
    25 from copy 
import deepcopy
    26 from numbers 
import Number

parser = argparse.ArgumentParser(add_help=False)
parser.add_argument('--seed', type=int, default=1234)
parser.add_argument('--accept', action='store_true')
args, remaining = parser.parse_known_args()
SEED = args.seed
if not expecttest.ACCEPT:
    expecttest.ACCEPT = args.accept
UNITTEST_ARGS = [sys.argv[0]] + remaining
torch.manual_seed(SEED)


def run_tests(argv=UNITTEST_ARGS):
    unittest.main(argv=argv)
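

# Illustrative sketch (not part of the original module): the pattern test
# scripts built on this file follow. The class and test names here are
# hypothetical.
#
#     from common_utils import TestCase, run_tests
#
#     class TestExample(TestCase):
#         def test_add(self):
#             self.assertEqual(torch.ones(2) + 1, torch.full((2,), 2.))
#
#     if __name__ == '__main__':
#         run_tests()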

PY3 = sys.version_info > (3, 0)
PY34 = sys.version_info >= (3, 4)
IS_WINDOWS = sys.platform == "win32"
IS_PPC = platform.machine() == "ppc64le"
IS_PYTORCH_CI = bool(os.environ.get('IS_PYTORCH_CI', 0))


def _check_module_exists(name):
    r"""Returns if a top-level module with :attr:`name` exists *without*
    importing it. This is generally safer than try-catch block around a
    `import X`. It avoids third party libraries breaking assumptions of some of
    our tests, e.g., setting multiprocessing start method when imported
    (see librosa/#747, torchvision/#544).
    """
    if not PY34:  # Python [3, 3.4)
        import importlib
        loader = importlib.find_loader(name)
        return loader is not None
    else:  # Python >= 3.4
        import importlib
        import importlib.util
        spec = importlib.util.find_spec(name)
        return spec is not None


TEST_NUMPY = _check_module_exists('numpy')
TEST_SCIPY = _check_module_exists('scipy')
TEST_NUMBA = _check_module_exists('numba')
TEST_LIBROSA = _check_module_exists('librosa') and PY3

if TEST_NUMPY:
    import numpy

NO_MULTIPROCESSING_SPAWN = os.environ.get('NO_MULTIPROCESSING_SPAWN', '0') == '1' or sys.version_info[0] == 2
TEST_WITH_ASAN = os.getenv('PYTORCH_TEST_WITH_ASAN', '0') == '1'
TEST_WITH_UBSAN = os.getenv('PYTORCH_TEST_WITH_UBSAN', '0') == '1'
TEST_WITH_ROCM = os.getenv('PYTORCH_TEST_WITH_ROCM', '0') == '1'


def skipIfRocm(fn):
    @wraps(fn)
    def wrapper(*args, **kwargs):
        if TEST_WITH_ROCM:
            raise unittest.SkipTest("test doesn't currently work on the ROCm stack")
        else:
            fn(*args, **kwargs)
    return wrapper


def skipIfNoLapack(fn):
    @wraps(fn)
    def wrapper(*args, **kwargs):
        if not torch._C.has_lapack:
            raise unittest.SkipTest('PyTorch compiled without Lapack')
        else:
            fn(*args, **kwargs)
    return wrapper
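

# Illustrative sketch (hypothetical test names, not part of the original
# module): the skip decorators above wrap individual test methods and raise
# unittest.SkipTest before the body runs on unsupported configurations.
def _example_skip_decorators():
    class _Example(TestCase):
        @skipIfRocm
        @skipIfNoLapack
        def test_needs_lapack(self):
            torch.randn(4, 4).inverse()
    return _Example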


def skipCUDAMemoryLeakCheckIf(condition):
    def dec(fn):
        if getattr(fn, '_do_cuda_memory_leak_check', True):  # if current True
            fn._do_cuda_memory_leak_check = not condition
        return fn
    return dec
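

# Illustrative sketch (hypothetical test name): marking a test whose CUDA
# caching behavior is known to trip the leak checker, so the per-test memory
# comparison in CudaMemoryLeakCheck below is skipped for it.
def _example_skip_leak_check():
    class _Example(TestCase):
        @skipCUDAMemoryLeakCheckIf(TEST_WITH_ROCM)
        def test_caches_cuda_memory(self):
            pass
    return _Example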


def suppress_warnings(fn):
    @wraps(fn)
    def wrapper(*args, **kwargs):
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            fn(*args, **kwargs)
    return wrapper


def get_cpu_type(type_name):
    module, name = type_name.rsplit('.', 1)
    assert module == 'torch.cuda'
    return getattr(torch, name)


def get_gpu_type(type_name):
    if isinstance(type_name, type):
        type_name = '{}.{}'.format(type_name.__module__, type_name.__name__)
    module, name = type_name.rsplit('.', 1)
    assert module == 'torch'
    return getattr(torch.cuda, name)


def to_gpu(obj, type_map=None):
    if type_map is None:
        type_map = {}
    if isinstance(obj, torch.Tensor):
        assert obj.is_leaf
        t = type_map.get(obj.type(), get_gpu_type(obj.type()))
        with torch.no_grad():
            res = obj.clone().type(t)
            res.requires_grad = obj.requires_grad
        return res
    elif torch.is_storage(obj):
        return obj.new().resize_(obj.size()).copy_(obj)
    elif isinstance(obj, list):
        return [to_gpu(o, type_map) for o in obj]
    elif isinstance(obj, tuple):
        return tuple(to_gpu(o, type_map) for o in obj)
    else:
        return deepcopy(obj)
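

# Illustrative sketch (hypothetical helper): to_gpu deep-copies an arbitrary
# nest of tensors, lists, and tuples onto the GPU, preserving requires_grad.
# Only meaningful to call from tests that already require CUDA.
def _example_to_gpu():
    batch = [torch.randn(3, requires_grad=True), (torch.zeros(2), 1.0)]
    gpu_batch = to_gpu(batch)
    assert gpu_batch[0].is_cuda and gpu_batch[0].requires_grad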


def get_function_arglist(func):
    if sys.version_info > (3,):
        return inspect.getfullargspec(func).args
    else:
        return inspect.getargspec(func).args


def set_rng_seed(seed):
    torch.manual_seed(seed)
    random.seed(seed)
    if TEST_NUMPY:
        numpy.random.seed(seed)


@contextlib.contextmanager
def freeze_rng_state():
    rng_state = torch.get_rng_state()
    if torch.cuda.is_available():
        cuda_rng_state = torch.cuda.get_rng_state()
    yield
    if torch.cuda.is_available():
        torch.cuda.set_rng_state(cuda_rng_state)
    torch.set_rng_state(rng_state)
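

# Illustrative sketch (hypothetical helper): draws made inside a
# freeze_rng_state() block do not perturb the draws that follow it, so
# `inside` and `outside` below see the same generator state.
def _example_freeze_rng_state():
    with freeze_rng_state():
        inside = torch.randn(3)
    outside = torch.randn(3)  # equal to `inside`: the state was restored
    assert torch.equal(inside, outside)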


def iter_indices(tensor):
    if tensor.dim() == 0:
        return range(0)
    if tensor.dim() == 1:
        return range(tensor.size(0))
    return product(*(range(s) for s in tensor.size()))
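

# Illustrative sketch (hypothetical helper): iter_indices yields every
# coordinate tuple of a tensor, e.g. (0, 0), (0, 1), (1, 0), (1, 1) for a
# 2 x 2 tensor, which is what assertTensorsSlowEqual below iterates over.
def _example_iter_indices():
    return list(iter_indices(torch.zeros(2, 2)))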


def is_iterable(obj):
    try:
        iter(obj)
        return True
    except TypeError:
        return False


class CudaMemoryLeakCheck():
    def __init__(self, testcase, name=None):
        self.name = testcase.id() if name is None else name
        self.testcase = testcase
        # Initialize the CUDA context and RNG up front so a test that is the
        # first to touch CUDA is not flagged as leaking that setup memory.
        from common_cuda import initialize_cuda_context_rng
        initialize_cuda_context_rng()

    @staticmethod
    def get_cuda_memory_usage():
        num_devices = torch.cuda.device_count()
        gc.collect()
        return tuple(torch.cuda.memory_allocated(i) for i in range(num_devices))

    def __enter__(self):
        self.befores = self.get_cuda_memory_usage()

    def __exit__(self, exec_type, exec_value, traceback):
        # Don't check for leaks if an exception was thrown while the test ran.
        if exec_type is not None:
            return
        afters = self.get_cuda_memory_usage()
        for i, (before, after) in enumerate(zip(self.befores, afters)):
            if not TEST_WITH_ROCM:
                self.testcase.assertEqual(
                    before, after, '{} leaked {} bytes CUDA memory on device {}'.format(
                        self.name, after - before, i))
            elif before != after:
                warnings.warn('{} leaked {} bytes ROCm memory on device {}'.format(
                    self.name, after - before, i), RuntimeWarning)


class TestCase(expecttest.TestCase):
    precision = 1e-5
    _do_cuda_memory_leak_check = False

    def __init__(self, method_name='runTest'):
        super(TestCase, self).__init__(method_name)
        # Wraps the tested method if we should do CUDA memory check.
        test_method = getattr(self, method_name)
        self._do_cuda_memory_leak_check &= getattr(test_method, '_do_cuda_memory_leak_check', True)
        if self._do_cuda_memory_leak_check and not IS_WINDOWS:
            # The import below may initialize CUDA context, so we only do it
            # when the leak check is actually requested.
            from common_cuda import TEST_CUDA
            fullname = self.id().lower()  # class_name.method_name
            if TEST_CUDA and ('gpu' in fullname or 'cuda' in fullname):
                setattr(self, method_name, self.wrap_with_cuda_memory_check(test_method))

    def assertLeaksNoCudaTensors(self, name=None):
        name = self.id() if name is None else name
        return CudaMemoryLeakCheck(self, name)
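
    # Illustrative usage (hypothetical test body, not part of the original
    # module): the returned context manager snapshots per-device
    # torch.cuda.memory_allocated() on entry and compares on exit.
    #
    #     with self.assertLeaksNoCudaTensors():
    #         x = torch.randn(8, device='cuda')
    #         del x  # everything allocated inside must be freed by exit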

    def wrap_with_cuda_memory_check(self, method):
        # Assumes the given method is bound to this TestCase instance; runs it
        # under the CUDA memory leak checker.
        def wrapper(self, *args, **kwargs):
            with self.assertLeaksNoCudaTensors():
                method(*args, **kwargs)
        return types.MethodType(wrapper, self)

    def assertTensorsSlowEqual(self, x, y, prec=None, message=''):
        max_err = 0
        self.assertEqual(x.size(), y.size())
        for index in iter_indices(x):
            max_err = max(max_err, abs(x[index] - y[index]))
        self.assertLessEqual(max_err, prec, message)

    def genSparseTensor(self, size, sparse_dim, nnz, is_uncoalesced, device='cpu'):
        # The sparse dims must all be nonempty unless nnz is 0.
        assert all(size[d] > 0 for d in range(sparse_dim)) or nnz == 0, 'invalid arguments'

        v_size = [nnz] + list(size[sparse_dim:])
        v = torch.randn(*v_size, device=device)
        i = torch.rand(sparse_dim, nnz, device=device)
        i.mul_(torch.tensor(size[:sparse_dim]).unsqueeze(1).to(i))
        i = i.to(torch.long)
        if is_uncoalesced:
            v = torch.cat([v, torch.randn_like(v)], 0)
            i = torch.cat([i, i], 1)

        x = torch.sparse_coo_tensor(i, v, torch.Size(size))

        if not is_uncoalesced:
            x = x.coalesce()
        else:
            # detach-clone so the returned uncoalesced tensor owns its memory
            x = x.detach().clone()
        return x, x._indices().clone(), x._values().clone()

    def safeToDense(self, t):
        r = self.safeCoalesce(t)
        return r.to_dense()

    def safeCoalesce(self, t):
        tc = t.coalesce()
        self.assertEqual(tc.to_dense(), t.to_dense())
        self.assertTrue(tc.is_coalesced())

        # Reconstruct the coalesced tensor by hand and compare against tc.
        value_map = {}
        for idx, val in zip(t._indices().t(), t._values()):
            idx_tup = tuple(idx.tolist())
            if idx_tup in value_map:
                value_map[idx_tup] += val
            else:
                value_map[idx_tup] = val.clone() if isinstance(val, torch.Tensor) else val

        new_indices = sorted(list(value_map.keys()))
        new_values = [value_map[idx] for idx in new_indices]
        if t._values().ndimension() < 2:
            new_values = t._values().new(new_values)
        else:
            new_values = torch.stack(new_values)

        new_indices = t._indices().new(new_indices).t()
        tg = t.new(new_indices, new_values, t.size())

        self.assertEqual(tc._indices(), tg._indices())
        self.assertEqual(tc._values(), tg._values())
        return tc

    def assertEqual(self, x, y, prec=None, message='', allow_inf=False):
        if isinstance(prec, str) and message == '':
            message = prec
            prec = None
        if prec is None:
            prec = self.precision

        if isinstance(x, torch.Tensor) and isinstance(y, Number):
            self.assertEqual(x.item(), y, prec, message, allow_inf)
        elif isinstance(y, torch.Tensor) and isinstance(x, Number):
            self.assertEqual(x, y.item(), prec, message, allow_inf)
        elif isinstance(x, torch.Tensor) and isinstance(y, torch.Tensor):
            def assertTensorsEqual(a, b):
                super(TestCase, self).assertEqual(a.size(), b.size(), message)
                if a.numel() > 0:
                    if a.device.type == 'cpu' and a.dtype == torch.float16:
                        # CPU half tensors don't have the methods we need below
                        a = a.to(torch.float32)
                    b = b.to(a.dtype).to(a.device)

                    if x.dtype == torch.bool and y.dtype == torch.bool:
                        self.assertEqual(x.tolist(), y.tolist())
                    elif x.dtype == torch.bool or y.dtype == torch.bool:
                        raise TypeError("Was expecting both tensors to be bool type.")
                    else:
                        diff = a - b
                        if a.is_floating_point():
                            # check that NaNs are in the same locations
                            nan_mask = torch.isnan(a)
                            self.assertTrue(torch.equal(nan_mask, torch.isnan(b)), message)
                            diff[nan_mask] = 0
                            # inf check if allow_inf=True
                            if allow_inf:
                                inf_mask = torch.isinf(a)
                                inf_sign = inf_mask.sign()
                                self.assertTrue(torch.equal(inf_sign, torch.isinf(b).sign()), message)
                                diff[inf_mask] = 0
                        # TODO: implement abs on CharTensor (int8)
                        if diff.is_signed() and diff.dtype != torch.int8:
                            diff = diff.abs()
                        max_err = diff.max()
                        self.assertLessEqual(max_err, prec, message)
            super(TestCase, self).assertEqual(x.is_sparse, y.is_sparse, message)
            if x.is_sparse:
                x = self.safeCoalesce(x)
                y = self.safeCoalesce(y)
                assertTensorsEqual(x._indices(), y._indices())
                assertTensorsEqual(x._values(), y._values())
            else:
                assertTensorsEqual(x, y)
        elif isinstance(x, string_classes) and isinstance(y, string_classes):
            super(TestCase, self).assertEqual(x, y, message)
        elif type(x) == set and type(y) == set:
            super(TestCase, self).assertEqual(x, y, message)
        elif isinstance(x, dict) and isinstance(y, dict):
            if isinstance(x, OrderedDict) and isinstance(y, OrderedDict):
                self.assertEqual(x.items(), y.items())
            else:
                self.assertEqual(set(x.keys()), set(y.keys()))
                key_list = list(x.keys())
                self.assertEqual([x[k] for k in key_list], [y[k] for k in key_list])
        elif is_iterable(x) and is_iterable(y):
            super(TestCase, self).assertEqual(len(x), len(y), message)
            for x_, y_ in zip(x, y):
                self.assertEqual(x_, y_, prec, message)
        elif isinstance(x, bool) and isinstance(y, bool):
            super(TestCase, self).assertEqual(x, y, message)
        elif isinstance(x, Number) and isinstance(y, Number):
            if abs(x) == inf or abs(y) == inf:
                if allow_inf:
                    super(TestCase, self).assertEqual(x, y, message)
                else:
                    self.fail("Expected finite numeric values - x={}, y={}".format(x, y))
                return
            super(TestCase, self).assertLessEqual(abs(x - y), prec, message)
        else:
            super(TestCase, self).assertEqual(x, y, message)

    def assertAlmostEqual(self, x, y, places=None, msg=None, delta=None, allow_inf=None):
        prec = delta
        if places:
            prec = 10 ** (-places)
        self.assertEqual(x, y, prec, msg, allow_inf)

    def assertNotEqual(self, x, y, prec=None, message=''):
        if isinstance(prec, str) and message == '':
            message = prec
            prec = None
        if prec is None:
            prec = self.precision

        if isinstance(x, torch.Tensor) and isinstance(y, torch.Tensor):
            if x.size() != y.size():
                super(TestCase, self).assertNotEqual(x.size(), y.size())
            self.assertGreater(x.numel(), 0)
            y = y.type_as(x)
            y = y.cuda(device=x.get_device()) if x.is_cuda else y.cpu()
            nan_mask = x != x
            if torch.equal(nan_mask, y != y):
                diff = x - y
                if diff.is_signed():
                    diff = diff.abs()
                diff[nan_mask] = 0
                max_err = diff.max()
                self.assertGreaterEqual(max_err, prec, message)
        elif type(x) == str and type(y) == str:
            super(TestCase, self).assertNotEqual(x, y)
        elif is_iterable(x) and is_iterable(y):
            super(TestCase, self).assertNotEqual(x, y)
        else:
            try:
                self.assertGreaterEqual(abs(x - y), prec, message)
                return
            except (TypeError, AssertionError):
                pass
            super(TestCase, self).assertNotEqual(x, y, message)

    def assertObjectIn(self, obj, iterable):
        for elem in iterable:
            if id(obj) == id(elem):
                return
        raise AssertionError("object not found in iterable")

    def assertExpectedRaises(self, exc_type, callable, *args, **kwargs):
        subname = None
        if 'subname' in kwargs:
            subname = kwargs['subname']
            del kwargs['subname']
        try:
            callable(*args, **kwargs)
        except exc_type as e:
            self.assertExpected(str(e), subname)
            return
        # Don't put this in the try block; the AssertionError from self.fail
        # would be caught by the except clause above otherwise.
        self.fail(msg="Did not raise when expected to")

    def assertWarns(self, callable, msg=''):
        r"""
        Test if :attr:`callable` raises a warning.
        """
        with warnings.catch_warnings(record=True) as ws:
            warnings.simplefilter("always")  # allow any warning to be raised
            callable()
            self.assertTrue(len(ws) > 0, msg)

    def assertWarnsRegex(self, callable, regex, msg=''):
        r"""
        Test if :attr:`callable` raises any warning with message that contains
        the regex pattern :attr:`regex`.
        """
        with warnings.catch_warnings(record=True) as ws:
            warnings.simplefilter("always")  # allow any warning to be raised
            callable()
            self.assertTrue(len(ws) > 0, msg)
            found = any(re.search(regex, str(w.message)) is not None for w in ws)
            self.assertTrue(found, msg)
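
    # Illustrative usage (hypothetical call, not part of the original module):
    #
    #     self.assertWarnsRegex(
    #         lambda: warnings.warn("size mismatch in dim 0"), "mismatch")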

    def assertExpected(self, s, subname=None):
        r"""
        Test that a string matches the recorded contents of a file
        derived from the name of this test and subname.  This file
        is placed in the 'expect' directory in the same directory
        as the test script. You can automatically update the recorded test
        output using --accept.

        If you call this multiple times in a single function, you must
        give a unique subname each time.
        """
        if not (isinstance(s, str) or (sys.version_info[0] == 2 and isinstance(s, unicode))):
            raise TypeError("assertExpected is strings only")

        def remove_prefix(text, prefix):
            if text.startswith(prefix):
                return text[len(prefix):]
            return text
        # NB: we take __file__ from the module that defined the test class, so
        # the expect directory sits next to the test script, NOT next to this
        # file.
        module_id = self.__class__.__module__
        munged_id = remove_prefix(self.id(), module_id + ".")
        test_file = os.path.realpath(sys.modules[module_id].__file__)
        expected_file = os.path.join(os.path.dirname(test_file),
                                     'expect',
                                     munged_id)

        subname_output = ""
        if subname:
            expected_file += "-" + subname
            subname_output = " ({})".format(subname)
        expected_file += ".expect"
        expected = None

        def accept_output(update_type):
            print("Accepting {} for {}{}:\n\n{}".format(update_type, munged_id, subname_output, s))
            with open(expected_file, 'w') as f:
                f.write(s)

        try:
            with open(expected_file) as f:
                expected = f.read()
        except IOError as e:
            if e.errno != errno.ENOENT:
                raise
            elif expecttest.ACCEPT:
                return accept_output("output")
            else:
                raise RuntimeError(
                    ("I got this output for {}{}:\n\n{}\n\n"
                     "No expect file exists; to accept the current output, run:\n"
                     "python {} {} --accept").format(munged_id, subname_output, s, __main__.__file__, munged_id))

        # a hack for JIT tests
        if IS_WINDOWS:
            expected = re.sub(r'CppOp\[(.+?)\]', 'CppOp[]', expected)
            s = re.sub(r'CppOp\[(.+?)\]', 'CppOp[]', s)

        if expecttest.ACCEPT:
            if expected != s:
                return accept_output("updated output")
        else:
            if hasattr(self, "assertMultiLineEqual"):
                self.assertMultiLineEqual(expected, s)
            else:
                self.assertEqual(expected, s)

    if sys.version_info < (3, 2):
        # assertRegexpMatches renamed to assertRegex in 3.2
        assertRegex = unittest.TestCase.assertRegexpMatches
        # assertRaisesRegexp renamed to assertRaisesRegex in 3.2
        assertRaisesRegex = unittest.TestCase.assertRaisesRegexp


def download_file(url, binary=True):
    if sys.version_info < (3,):
        from urlparse import urlsplit
        import urllib2
        request = urllib2
        error = urllib2
    else:
        from urllib.parse import urlsplit
        from urllib import request, error

    filename = os.path.basename(urlsplit(url)[2])
    data_dir = get_writable_path(os.path.join(os.path.dirname(__file__), 'data'))
    path = os.path.join(data_dir, filename)

    if os.path.exists(path):
        return path
    try:
        data = request.urlopen(url, timeout=15).read()
        with open(path, 'wb' if binary else 'w') as f:
            f.write(data)
        return path
    except error.URLError:
        msg = "could not download test file '{}'".format(url)
        warnings.warn(msg, RuntimeWarning)
        raise unittest.SkipTest(msg)
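

# Illustrative sketch (hypothetical URL and file, not part of the original
# module): fixture data is fetched lazily and cached under ./data; a test
# using this skips itself when the network is unavailable, since
# download_file raises unittest.SkipTest on URLError.
def _example_download():
    path = download_file('https://example.com/fixtures/linear.pt')
    return torch.load(path)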


def find_free_port():
    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
    sock.bind(('localhost', 0))
    sockname = sock.getsockname()
    sock.close()
    return sockname[1]
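

# Illustrative sketch (hypothetical usage): distributed tests probe for a
# free rendezvous port before spawning workers. Note the inherent race: the
# port is only held while the probe socket above is open, which is why the
# retry decorator below exists.
def _example_find_free_port():
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = str(find_free_port())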


def retry_on_address_already_in_use_error(func):
    """Reruns a test if it sees "Address already in use" error."""
    @wraps(func)
    def wrapper(*args, **kwargs):
        tries_remaining = 10
        while True:
            try:
                return func(*args, **kwargs)
            except RuntimeError as error:
                if str(error) == "Address already in use":
                    tries_remaining -= 1
                    if tries_remaining == 0:
                        raise
                    time.sleep(random.random())
                    continue
                raise
    return wrapper
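

# Illustrative usage (hypothetical test, not part of the original module):
# the decorator retries only a RuntimeError whose message is exactly
# "Address already in use", i.e. a port collision during process-group setup.
#
#     class _ExampleGroupTest(TestCase):
#         @retry_on_address_already_in_use_error
#         def test_init(self):
#             ...  # binds the port from find_free_port(); may collide and retry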


def prod_single_zero(dim_size):
    result = torch.randn(dim_size, dim_size)
    result[0, 1] = 0
    return result


def random_square_matrix_of_rank(l, rank):
    assert rank <= l
    A = torch.randn(l, l)
    u, s, v = A.svd()
    for i in range(l):
        if i >= rank:
            s[i] = 0
        elif s[i] == 0:
            s[i] = 1
    return u.mm(torch.diag(s)).mm(v.transpose(0, 1))


def random_symmetric_matrix(l):
    A = torch.randn(l, l)
    for i in range(l):
        for j in range(i):
            A[i, j] = A[j, i]
    return A


def random_symmetric_psd_matrix(l):
    A = torch.randn(l, l)
    return A.mm(A.transpose(0, 1))


def random_symmetric_pd_matrix(l, *batches):
    A = torch.randn(*(batches + (l, l)))
    return A.matmul(A.transpose(-2, -1)) + torch.eye(l) * 1e-5
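

# Illustrative sketch (hypothetical helper): sanity checks for the generators
# above. A PSD matrix has no negative eigenvalues; the PD variant adds
# 1e-5 * I, so it is also safely invertible.
def _example_matrix_generators():
    psd = random_symmetric_psd_matrix(4)
    assert (torch.symeig(psd)[0] >= -1e-6).all()
    random_symmetric_pd_matrix(4).inverse()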


def make_nonzero_det(A, sign=None, min_singular_value=0.1):
    u, s, v = A.svd()
    s[s < min_singular_value] = min_singular_value
    A = u.mm(torch.diag(s)).mm(v.t())
    det = A.det().item()
    if sign is not None:
        if (det < 0) ^ (sign < 0):
            A[0, :].neg_()
    return A


def random_fullrank_matrix_distinct_singular_value(l, *batches, **kwargs):
    silent = kwargs.get("silent", False)
    if silent and not torch._C.has_lapack:
        return torch.ones(l, l)

    if len(batches) == 0:
        A = torch.randn(l, l)
        u, _, v = A.svd()
        s = torch.arange(1., l + 1).mul_(1.0 / (l + 1))
        return u.mm(torch.diag(s)).mm(v.t())
    else:
        all_matrices = []
        for _ in range(0, torch.prod(torch.as_tensor(batches)).item()):
            A = torch.randn(l, l)
            u, _, v = A.svd()
            s = torch.arange(1., l + 1).mul_(1.0 / (l + 1))
            all_matrices.append(u.mm(torch.diag(s)).mm(v.t()))
        return torch.stack(all_matrices).reshape(*(batches + (l, l)))


def brute_pdist(inp, p=2):
    """Computes the same as torch.pdist using primitives"""
    n = inp.shape[-2]
    k = n * (n - 1) // 2
    if k == 0:
        # torch allows empty inputs, so we don't fail here
        return torch.empty(inp.shape[:-2] + (0,), dtype=inp.dtype, device=inp.device)
    square = torch.norm(inp[..., None, :] - inp[..., None, :, :], p=p, dim=-1)
    unroll = square.view(square.shape[:-2] + (n * n,))
    inds = torch.ones(k, dtype=torch.int)
    inds[torch.arange(n - 1, 1, -1, dtype=torch.int).cumsum(0)] += torch.arange(2, n, dtype=torch.int)
    return unroll[..., inds.cumsum(0)]
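

# Illustrative sketch (hypothetical helper): brute_pdist is the O(n^2)
# reference that tests compare the fused torch.pdist kernel against.
def _example_brute_pdist():
    x = torch.randn(5, 3)
    assert torch.allclose(torch.pdist(x), brute_pdist(x), atol=1e-6)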


def brute_cdist(x, y, p=2):
    r1 = x.shape[-2]
    r2 = y.shape[-2]
    if r1 == 0 or r2 == 0:
        return torch.empty(r1, r2, device=x.device)
    return torch.norm(x[..., None, :] - y[..., None, :, :], p=p, dim=-1)


def do_test_dtypes(self, dtypes, layout, device):
    for dtype in dtypes:
        if dtype != torch.float16:
            out = torch.zeros((2, 3), dtype=dtype, layout=layout, device=device)
            self.assertIs(dtype, out.dtype)
            self.assertIs(layout, out.layout)
            self.assertEqual(device, out.device)


def do_test_empty_full(self, dtypes, layout, device):
    shape = torch.Size([2, 3])

    def check_value(tensor, dtype, layout, device, value, requires_grad):
        self.assertEqual(shape, tensor.shape)
        self.assertIs(dtype, tensor.dtype)
        self.assertIs(layout, tensor.layout)
        self.assertEqual(tensor.requires_grad, requires_grad)
        if tensor.is_cuda and device is not None:
            self.assertEqual(device, tensor.device)
        if value is not None:
            fill = tensor.new(shape).fill_(value)
            self.assertEqual(tensor, fill)

    def get_int64_dtype(dtype):
        module = '.'.join(str(dtype).split('.')[1:-1])
        if not module:
            return torch.int64
        return operator.attrgetter(module)(torch).int64

    default_dtype = torch.get_default_dtype()
    check_value(torch.empty(shape), default_dtype, torch.strided, -1, None, False)
    check_value(torch.full(shape, -5), default_dtype, torch.strided, -1, None, False)
    for dtype in dtypes:
        for rg in {dtype.is_floating_point, False}:
            int64_dtype = get_int64_dtype(dtype)
            v = torch.empty(shape, dtype=dtype, device=device, layout=layout, requires_grad=rg)
            check_value(v, dtype, layout, device, None, rg)
            out = v.new()
            check_value(torch.empty(shape, out=out, device=device, layout=layout, requires_grad=rg),
                        dtype, layout, device, None, rg)
            check_value(v.new_empty(shape), dtype, layout, device, None, False)
            check_value(v.new_empty(shape, dtype=int64_dtype, device=device, requires_grad=False),
                        int64_dtype, layout, device, None, False)
            check_value(torch.empty_like(v), dtype, layout, device, None, False)
            check_value(torch.empty_like(v, dtype=int64_dtype, layout=layout, device=device, requires_grad=False),
                        int64_dtype, layout, device, None, False)

            if dtype is not torch.float16 and layout != torch.sparse_coo:
                fv = 3
                v = torch.full(shape, fv, dtype=dtype, layout=layout, device=device, requires_grad=rg)
                check_value(v, dtype, layout, device, fv, rg)
                check_value(v.new_full(shape, fv + 1), dtype, layout, device, fv + 1, False)
                out = v.new()
                check_value(torch.full(shape, fv + 2, out=out, device=device, layout=layout, requires_grad=rg),
                            dtype, layout, device, fv + 2, rg)
                check_value(v.new_full(shape, fv + 3, dtype=int64_dtype, device=device, requires_grad=False),
                            int64_dtype, layout, device, fv + 3, False)
                check_value(torch.full_like(v, fv + 4), dtype, layout, device, fv + 4, False)
                check_value(torch.full_like(v, fv + 5,
                                            dtype=int64_dtype, layout=layout, device=device, requires_grad=False),
                            int64_dtype, layout, device, fv + 5, False)


IS_SANDCASTLE = os.getenv('SANDCASTLE') == '1' or os.getenv('TW_JOB_USER') == 'sandcastle'

THESE_TAKE_WAY_TOO_LONG = {
    'test_Conv3d_groups',
    'test_conv_double_backward',
    'test_conv_double_backward_groups',
    'test_Conv3d_dilated',
    'test_Conv3d_stride_padding',
    'test_Conv3d_dilated_strided',
    'test_Conv2d_dilated',
    'test_ConvTranspose3d_dilated',
    'test_ConvTranspose2d_dilated',
    'test_Conv2d_padding',
    'test_ConvTranspose2d_no_bias',
    'test_ConvTranspose2d',
    'test_ConvTranspose3d',
    'test_Conv2d_no_bias',
    'test_multinomial_invalid_probs',
}

running_script_path = None


def set_running_script_path():
    global running_script_path
    try:
        running_file = os.path.abspath(os.path.realpath(sys.argv[0]))
        if running_file.endswith('.py'):  # skip if the running file is not a script
            running_script_path = running_file
    except Exception:
        pass


def check_test_defined_in_running_script(test_case):
    if running_script_path is None:
        return
    test_case_class_file = os.path.abspath(os.path.realpath(inspect.getfile(test_case.__class__)))
    assert test_case_class_file == running_script_path, "Class of loaded TestCase \"{}\" " \
        "is not defined in the running script \"{}\", but in \"{}\". Did you " \
        "accidentally import a unittest.TestCase from another file?".format(
            test_case.id(), running_script_path, test_case_class_file)


num_shards = os.environ.get('TEST_NUM_SHARDS', None)
shard = os.environ.get('TEST_SHARD', None)
if num_shards is not None and shard is not None:
    num_shards = int(num_shards)
    shard = int(shard)

    def load_tests(loader, tests, pattern):
        set_running_script_path()
        test_suite = unittest.TestSuite()
        for test_group in tests:
            for test in test_group:
                check_test_defined_in_running_script(test)
                name = test.id().split('.')[-1]
                if name in THESE_TAKE_WAY_TOO_LONG:
                    continue
                hash_id = int(hashlib.sha256(str(test).encode('utf-8')).hexdigest(), 16)
                if hash_id % num_shards == shard:
                    test_suite.addTest(test)
        return test_suite
else:

    def load_tests(loader, tests, pattern):
        set_running_script_path()
        test_suite = unittest.TestSuite()
        for test_group in tests:
            for test in test_group:
                check_test_defined_in_running_script(test)
                test_suite.addTest(test)
        return test_suite
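

# Illustrative usage (hypothetical command line): sharding is driven entirely
# by the environment, e.g.
#
#     TEST_NUM_SHARDS=4 TEST_SHARD=0 python test_nn.py
#
# runs roughly a quarter of the tests; the sha256-based assignment keeps each
# test on the same shard across machines.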