1 r"""Importing this file must **not** initialize CUDA context. test_distributed 2 relies on this assumption to properly run. This means that when this is imported 3 no CUDA calls shall be made, including torch.cuda.device_count(), etc. 5 common_cuda.py can freely initialize CUDA context when imported. 22 from collections
import OrderedDict
23 from functools
import wraps
24 from itertools
import product
25 from copy
import deepcopy
26 from numbers
import Number
46 parser = argparse.ArgumentParser(add_help=
False)
47 parser.add_argument(
'--seed', type=int, default=1234)
48 parser.add_argument(
'--accept', action=
'store_true')
49 args, remaining = parser.parse_known_args()
51 if not expecttest.ACCEPT:
52 expecttest.ACCEPT = args.accept
53 UNITTEST_ARGS = [sys.argv[0]] + remaining
54 torch.manual_seed(SEED)
57 def run_tests(argv=UNITTEST_ARGS):
58 unittest.main(argv=argv)
60 PY3 = sys.version_info > (3, 0)
61 PY34 = sys.version_info >= (3, 4)
63 IS_WINDOWS = sys.platform ==
"win32" 64 IS_PPC = platform.machine() ==
"ppc64le" 67 IS_PYTORCH_CI = bool(os.environ.get(
'IS_PYTORCH_CI', 0))
70 def _check_module_exists(name):
71 r"""Returns if a top-level module with :attr:`name` exists *without** 72 importing it. This is generally safer than try-catch block around a 73 `import X`. It avoids third party libraries breaking assumptions of some of 74 our tests, e.g., setting multiprocessing start method when imported 75 (see librosa/#747, torchvision/#544). 86 loader = importlib.find_loader(name)
87 return loader
is not None 91 spec = importlib.util.find_spec(name)
92 return spec
is not None 94 TEST_NUMPY = _check_module_exists(
'numpy')
95 TEST_SCIPY = _check_module_exists(
'scipy')
97 TEST_NUMBA = _check_module_exists(
'numba')
102 TEST_LIBROSA = _check_module_exists(
'librosa')
and PY3
105 NO_MULTIPROCESSING_SPAWN = os.environ.get(
'NO_MULTIPROCESSING_SPAWN',
'0') ==
'1' or sys.version_info[0] == 2
106 TEST_WITH_ASAN = os.getenv(
'PYTORCH_TEST_WITH_ASAN',
'0') ==
'1' 107 TEST_WITH_UBSAN = os.getenv(
'PYTORCH_TEST_WITH_UBSAN',
'0') ==
'1' 108 TEST_WITH_ROCM = os.getenv(
'PYTORCH_TEST_WITH_ROCM',
'0') ==
'1' 116 def wrapper(*args, **kwargs):
118 raise unittest.SkipTest(
"test doesn't currently work on the ROCm stack")
124 def skipIfNoLapack(fn):
126 def wrapper(*args, **kwargs):
127 if not torch._C.has_lapack:
128 raise unittest.SkipTest(
'PyTorch compiled without Lapack')
134 def skipCUDAMemoryLeakCheckIf(condition):
136 if getattr(fn,
'_do_cuda_memory_leak_check',
True):
137 fn._do_cuda_memory_leak_check =
not condition
142 def suppress_warnings(fn):
144 def wrapper(*args, **kwargs):
145 with warnings.catch_warnings():
146 warnings.simplefilter(
"ignore")
151 def get_cpu_type(type_name):
152 module, name = type_name.rsplit(
'.', 1)
153 assert module ==
'torch.cuda' 154 return getattr(torch, name)
157 def get_gpu_type(type_name):
158 if isinstance(type_name, type):
159 type_name =
'{}.{}'.format(type_name.__module__, type_name.__name__)
160 module, name = type_name.rsplit(
'.', 1)
161 assert module ==
'torch' 165 def to_gpu(obj, type_map=None):
168 if isinstance(obj, torch.Tensor):
170 t = type_map.get(obj.type(), get_gpu_type(obj.type()))
171 with torch.no_grad():
172 res = obj.clone().type(t)
173 res.requires_grad = obj.requires_grad
176 return obj.new().resize_(obj.size()).copy_(obj)
177 elif isinstance(obj, list):
178 return [to_gpu(o, type_map)
for o
in obj]
179 elif isinstance(obj, tuple):
180 return tuple(to_gpu(o, type_map)
for o
in obj)
185 def get_function_arglist(func):
186 if sys.version_info > (3,):
187 return inspect.getfullargspec(func).args
189 return inspect.getargspec(func).args
192 def set_rng_seed(seed):
193 torch.manual_seed(seed)
196 numpy.random.seed(seed)
199 @contextlib.contextmanager
200 def freeze_rng_state():
201 rng_state = torch.get_rng_state()
203 cuda_rng_state = torch.cuda.get_rng_state()
206 torch.cuda.set_rng_state(cuda_rng_state)
207 torch.set_rng_state(rng_state)
210 def iter_indices(tensor):
211 if tensor.dim() == 0:
213 if tensor.dim() == 1:
214 return range(tensor.size(0))
215 return product(*(range(s)
for s
in tensor.size()))
218 def is_iterable(obj):
227 def __init__(self, testcase, name=None):
228 self.
name = testcase.id()
if name
is None else name
233 from common_cuda
import initialize_cuda_context_rng
234 initialize_cuda_context_rng()
237 def get_cuda_memory_usage():
247 def __exit__(self, exec_type, exec_value, traceback):
249 if exec_type
is not None:
253 for i, (before, after)
in enumerate(zip(self.
befores, afters)):
254 if not TEST_WITH_ROCM:
255 self.testcase.assertEqual(
256 before, after,
'{} leaked {} bytes CUDA memory on device {}'.format(
257 self.
name, after - before, i))
261 warnings.warn(
'{} leaked {} bytes ROCm memory on device {}'.format(
262 self.
name, after - before, i), RuntimeWarning)
268 _do_cuda_memory_leak_check =
False 270 def __init__(self, method_name='runTest'):
271 super(TestCase, self).__init__(method_name)
273 test_method = getattr(self, method_name)
279 from common_cuda
import TEST_CUDA
280 fullname = self.id().lower()
281 if TEST_CUDA
and (
'gpu' in fullname
or 'cuda' in fullname):
284 def assertLeaksNoCudaTensors(self, name=None):
285 name = self.id()
if name
is None else name
288 def wrap_with_cuda_memory_check(self, method):
296 def wrapper(self, *args, **kwargs):
298 method(*args, **kwargs)
299 return types.MethodType(wrapper, self)
304 def assertTensorsSlowEqual(self, x, y, prec=None, message=''):
307 for index
in iter_indices(x):
308 max_err = max(max_err, abs(x[index] - y[index]))
309 self.assertLessEqual(max_err, prec, message)
311 def genSparseTensor(self, size, sparse_dim, nnz, is_uncoalesced, device='cpu'):
314 assert all(size[d] > 0
for d
in range(sparse_dim))
or nnz == 0,
'invalid arguments' 316 v_size = [nnz] + list(size[sparse_dim:])
317 v = torch.randn(*v_size, device=device)
318 i = torch.rand(sparse_dim, nnz, device=device)
319 i.mul_(
torch.tensor(size[:sparse_dim]).unsqueeze(1).to(i))
322 v = torch.cat([v, torch.randn_like(v)], 0)
323 i = torch.cat([i, i], 1)
325 x = torch.sparse_coo_tensor(i, v, torch.Size(size))
327 if not is_uncoalesced:
336 x = x.detach().clone()
337 return x, x._indices().clone(), x._values().clone()
339 def safeToDense(self, t):
343 def safeCoalesce(self, t):
346 self.assertTrue(tc.is_coalesced())
356 for idx, val
in zip(t._indices().t(), t._values()):
357 idx_tup = tuple(idx.tolist())
358 if idx_tup
in value_map:
359 value_map[idx_tup] += val
361 value_map[idx_tup] = val.clone()
if isinstance(val, torch.Tensor)
else val
363 new_indices = sorted(list(value_map.keys()))
364 new_values = [value_map[idx]
for idx
in new_indices]
365 if t._values().ndimension() < 2:
366 new_values = t._values().new(new_values)
368 new_values = torch.stack(new_values)
370 new_indices = t._indices().new(new_indices).t()
371 tg = t.new(new_indices, new_values, t.size())
382 def assertEqual(self, x, y, prec=None, message='', allow_inf=False):
383 if isinstance(prec, str)
and message ==
'':
389 if isinstance(x, torch.Tensor)
and isinstance(y, Number):
390 self.
assertEqual(x.item(), y, prec, message, allow_inf)
391 elif isinstance(y, torch.Tensor)
and isinstance(x, Number):
392 self.
assertEqual(x, y.item(), prec, message, allow_inf)
393 elif isinstance(x, torch.Tensor)
and isinstance(y, torch.Tensor):
394 def assertTensorsEqual(a, b):
395 super(TestCase, self).assertEqual(a.size(), b.size(), message)
397 if a.device.type ==
'cpu' and a.dtype == torch.float16:
399 a = a.to(torch.float32)
403 b = b.to(a.dtype).to(a.device)
407 if x.dtype == torch.bool
and y.dtype == torch.bool:
409 elif x.dtype == torch.bool
or y.dtype == torch.bool:
410 raise TypeError(
"Was expecting both tensors to be bool type.")
413 if a.is_floating_point():
415 nan_mask = torch.isnan(a)
416 self.assertTrue(torch.equal(nan_mask, torch.isnan(b)), message)
420 inf_mask = torch.isinf(a)
421 inf_sign = inf_mask.sign()
422 self.assertTrue(torch.equal(inf_sign, torch.isinf(b).sign()), message)
425 if diff.is_signed()
and diff.dtype != torch.int8:
428 self.assertLessEqual(max_err, prec, message)
429 super(TestCase, self).assertEqual(x.is_sparse, y.is_sparse, message)
433 assertTensorsEqual(x._indices(), y._indices())
434 assertTensorsEqual(x._values(), y._values())
436 assertTensorsEqual(x, y)
437 elif isinstance(x, string_classes)
and isinstance(y, string_classes):
438 super(TestCase, self).assertEqual(x, y, message)
439 elif type(x) == set
and type(y) == set:
440 super(TestCase, self).assertEqual(x, y, message)
441 elif isinstance(x, dict)
and isinstance(y, dict):
442 if isinstance(x, OrderedDict)
and isinstance(y, OrderedDict):
446 key_list = list(x.keys())
447 self.
assertEqual([x[k]
for k
in key_list], [y[k]
for k
in key_list])
448 elif is_iterable(x)
and is_iterable(y):
449 super(TestCase, self).assertEqual(len(x), len(y), message)
450 for x_, y_
in zip(x, y):
452 elif isinstance(x, bool)
and isinstance(y, bool):
453 super(TestCase, self).assertEqual(x, y, message)
454 elif isinstance(x, Number)
and isinstance(y, Number):
455 if abs(x) == inf
or abs(y) == inf:
457 super(TestCase, self).assertEqual(x, y, message)
459 self.fail(
"Expected finite numeric values - x={}, y={}".format(x, y))
461 super(TestCase, self).assertLessEqual(abs(x - y), prec, message)
463 super(TestCase, self).assertEqual(x, y, message)
465 def assertAlmostEqual(self, x, y, places=None, msg=None, delta=None, allow_inf=None):
471 def assertNotEqual(self, x, y, prec=None, message=''):
472 if isinstance(prec, str)
and message ==
'':
478 if isinstance(x, torch.Tensor)
and isinstance(y, torch.Tensor):
479 if x.size() != y.size():
480 super(TestCase, self).assertNotEqual(x.size(), y.size())
481 self.assertGreater(x.numel(), 0)
483 y = y.cuda(device=x.get_device())
if x.is_cuda
else y.cpu()
485 if torch.equal(nan_mask, y != y):
491 self.assertGreaterEqual(max_err, prec, message)
492 elif type(x) == str
and type(y) == str:
493 super(TestCase, self).assertNotEqual(x, y)
494 elif is_iterable(x)
and is_iterable(y):
495 super(TestCase, self).assertNotEqual(x, y)
498 self.assertGreaterEqual(abs(x - y), prec, message)
500 except (TypeError, AssertionError):
502 super(TestCase, self).assertNotEqual(x, y, message)
504 def assertObjectIn(self, obj, iterable):
505 for elem
in iterable:
506 if id(obj) == id(elem):
508 raise AssertionError(
"object not found in iterable")
513 def assertExpectedRaises(self, exc_type, callable, *args, **kwargs):
515 if 'subname' in kwargs:
516 subname = kwargs[
'subname']
517 del kwargs[
'subname']
519 callable(*args, **kwargs)
520 except exc_type
as e:
524 self.fail(msg=
"Did not raise when expected to")
526 def assertWarns(self, callable, msg=''):
528 Test if :attr:`callable` raises a warning. 530 with warnings.catch_warnings(record=
True)
as ws:
531 warnings.simplefilter(
"always")
533 self.assertTrue(len(ws) > 0, msg)
535 def assertWarnsRegex(self, callable, regex, msg=''):
537 Test if :attr:`callable` raises any warning with message that contains 538 the regex pattern :attr:`regex`. 540 with warnings.catch_warnings(record=
True)
as ws:
541 warnings.simplefilter(
"always")
543 self.assertTrue(len(ws) > 0, msg)
544 found = any(re.search(regex, str(w.message))
is not None for w
in ws)
545 self.assertTrue(found, msg)
547 def assertExpected(self, s, subname=None):
549 Test that a string matches the recorded contents of a file 550 derived from the name of this test and subname. This file 551 is placed in the 'expect' directory in the same directory 552 as the test script. You can automatically update the recorded test 553 output using --accept. 555 If you call this multiple times in a single function, you must 556 give a unique subname each time. 558 if not (isinstance(s, str)
or (sys.version_info[0] == 2
and isinstance(s, unicode))):
559 raise TypeError(
"assertExpected is strings only")
561 def remove_prefix(text, prefix):
562 if text.startswith(prefix):
563 return text[len(prefix):]
570 module_id = self.__class__.__module__
571 munged_id = remove_prefix(self.id(), module_id +
".")
572 test_file = os.path.realpath(sys.modules[module_id].__file__)
573 expected_file = os.path.join(os.path.dirname(test_file),
579 expected_file +=
"-" + subname
580 subname_output =
" ({})".format(subname)
581 expected_file +=
".expect" 584 def accept_output(update_type):
585 print(
"Accepting {} for {}{}:\n\n{}".format(update_type, munged_id, subname_output, s))
586 with open(expected_file,
'w')
as f:
590 with open(expected_file)
as f:
593 if e.errno != errno.ENOENT:
595 elif expecttest.ACCEPT:
596 return accept_output(
"output")
599 (
"I got this output for {}{}:\n\n{}\n\n" 600 "No expect file exists; to accept the current output, run:\n" 601 "python {} {} --accept").format(munged_id, subname_output, s, __main__.__file__, munged_id))
605 expected = re.sub(
r'CppOp\[(.+?)\]',
'CppOp[]', expected)
606 s = re.sub(
r'CppOp\[(.+?)\]',
'CppOp[]', s)
608 if expecttest.ACCEPT:
610 return accept_output(
"updated output")
612 if hasattr(self,
"assertMultiLineEqual"):
615 self.assertMultiLineEqual(expected, s)
619 if sys.version_info < (3, 2):
621 assertRegex = unittest.TestCase.assertRegexpMatches
623 assertRaisesRegex = unittest.TestCase.assertRaisesRegexp
626 def download_file(url, binary=True):
627 if sys.version_info < (3,):
628 from urlparse
import urlsplit
633 from urllib.parse
import urlsplit
634 from urllib
import request, error
636 filename = os.path.basename(urlsplit(url)[2])
637 data_dir = get_writable_path(os.path.join(os.path.dirname(__file__),
'data'))
638 path = os.path.join(data_dir, filename)
640 if os.path.exists(path):
643 data = request.urlopen(url, timeout=15).read()
644 with open(path,
'wb' if binary
else 'w')
as f:
647 except error.URLError:
648 msg =
"could not download test file '{}'".format(url)
649 warnings.warn(msg, RuntimeWarning)
650 raise unittest.SkipTest(msg)
653 def find_free_port():
654 sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
655 sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
656 sock.bind((
'localhost', 0))
657 sockname = sock.getsockname()
662 def retry_on_address_already_in_use_error(func):
663 """Reruns a test if it sees "Address already in use" error.""" 665 def wrapper(*args, **kwargs):
669 return func(*args, **kwargs)
670 except RuntimeError
as error:
671 if str(error) ==
"Address already in use":
673 if tries_remaining == 0:
675 time.sleep(random.random())
683 def prod_single_zero(dim_size):
684 result = torch.randn(dim_size, dim_size)
689 def random_square_matrix_of_rank(l, rank):
691 A = torch.randn(l, l)
698 return u.mm(torch.diag(s)).mm(v.transpose(0, 1))
701 def random_symmetric_matrix(l):
702 A = torch.randn(l, l)
709 def random_symmetric_psd_matrix(l):
710 A = torch.randn(l, l)
711 return A.mm(A.transpose(0, 1))
714 def random_symmetric_pd_matrix(l, *batches):
715 A = torch.randn(*(batches + (l, l)))
716 return A.matmul(A.transpose(-2, -1)) + torch.eye(l) * 1e-5
719 def make_nonzero_det(A, sign=None, min_singular_value=0.1):
721 s[s < min_singular_value] = min_singular_value
722 A = u.mm(torch.diag(s)).mm(v.t())
725 if (det < 0) ^ (sign < 0):
730 def random_fullrank_matrix_distinct_singular_value(l, *batches, **kwargs):
731 silent = kwargs.get(
"silent",
False)
732 if silent
and not torch._C.has_lapack:
733 return torch.ones(l, l)
735 if len(batches) == 0:
736 A = torch.randn(l, l)
738 s = torch.arange(1., l + 1).mul_(1.0 / (l + 1))
739 return u.mm(torch.diag(s)).mm(v.t())
742 for _
in range(0, torch.prod(torch.as_tensor(batches)).item()):
743 A = torch.randn(l, l)
745 s = torch.arange(1., l + 1).mul_(1.0 / (l + 1))
746 all_matrices.append(u.mm(torch.diag(s)).mm(v.t()))
747 return torch.stack(all_matrices).reshape(*(batches + (l, l)))
750 def brute_pdist(inp, p=2):
751 """Computes the same as torch.pdist using primitives""" 756 return torch.empty(inp.shape[:-2] + (0,), dtype=inp.dtype, device=inp.device)
757 square = torch.norm(inp[...,
None, :] - inp[...,
None, :, :], p=p, dim=-1)
758 unroll = square.view(square.shape[:-2] + (n * n,))
759 inds = torch.ones(k, dtype=torch.int)
760 inds[torch.arange(n - 1, 1, -1, dtype=torch.int).cumsum(0)] += torch.arange(2, n, dtype=torch.int)
761 return unroll[..., inds.cumsum(0)]
764 def brute_cdist(x, y, p=2):
767 if r1 == 0
or r2 == 0:
768 return torch.empty(r1, r2, device=x.device)
769 return torch.norm(x[...,
None, :] - y[...,
None, :, :], p=p, dim=-1)
772 def do_test_dtypes(self, dtypes, layout, device):
774 if dtype != torch.float16:
775 out = torch.zeros((2, 3), dtype=dtype, layout=layout, device=device)
776 self.assertIs(dtype, out.dtype)
777 self.assertIs(layout, out.layout)
781 def do_test_empty_full(self, dtypes, layout, device):
782 shape = torch.Size([2, 3])
784 def check_value(tensor, dtype, layout, device, value, requires_grad):
786 self.assertIs(dtype, tensor.dtype)
787 self.assertIs(layout, tensor.layout)
788 self.
assertEqual(tensor.requires_grad, requires_grad)
789 if tensor.is_cuda
and device
is not None:
791 if value
is not None:
792 fill = tensor.new(shape).fill_(value)
795 def get_int64_dtype(dtype):
796 module =
'.'.join(str(dtype).
split(
'.')[1:-1])
799 return operator.attrgetter(module)(torch).int64
801 default_dtype = torch.get_default_dtype()
802 check_value(torch.empty(shape), default_dtype, torch.strided, -1,
None,
False)
803 check_value(torch.full(shape, -5), default_dtype, torch.strided, -1,
None,
False)
805 for rg
in {dtype.is_floating_point,
False}:
806 int64_dtype = get_int64_dtype(dtype)
807 v = torch.empty(shape, dtype=dtype, device=device, layout=layout, requires_grad=rg)
808 check_value(v, dtype, layout, device,
None, rg)
810 check_value(torch.empty(shape, out=out, device=device, layout=layout, requires_grad=rg),
811 dtype, layout, device,
None, rg)
812 check_value(v.new_empty(shape), dtype, layout, device,
None,
False)
813 check_value(v.new_empty(shape, dtype=int64_dtype, device=device, requires_grad=
False),
814 int64_dtype, layout, device,
None,
False)
815 check_value(torch.empty_like(v), dtype, layout, device,
None,
False)
816 check_value(torch.empty_like(v, dtype=int64_dtype, layout=layout, device=device, requires_grad=
False),
817 int64_dtype, layout, device,
None,
False)
819 if dtype
is not torch.float16
and layout != torch.sparse_coo:
821 v = torch.full(shape, fv, dtype=dtype, layout=layout, device=device, requires_grad=rg)
822 check_value(v, dtype, layout, device, fv, rg)
823 check_value(v.new_full(shape, fv + 1), dtype, layout, device, fv + 1,
False)
825 check_value(torch.full(shape, fv + 2, out=out, device=device, layout=layout, requires_grad=rg),
826 dtype, layout, device, fv + 2, rg)
827 check_value(v.new_full(shape, fv + 3, dtype=int64_dtype, device=device, requires_grad=
False),
828 int64_dtype, layout, device, fv + 3,
False)
829 check_value(torch.full_like(v, fv + 4), dtype, layout, device, fv + 4,
False)
830 check_value(torch.full_like(v, fv + 5,
831 dtype=int64_dtype, layout=layout, device=device, requires_grad=
False),
832 int64_dtype, layout, device, fv + 5,
False)
835 IS_SANDCASTLE = os.getenv(
'SANDCASTLE') ==
'1' or os.getenv(
'TW_JOB_USER') ==
'sandcastle' 837 THESE_TAKE_WAY_TOO_LONG = {
838 'test_Conv3d_groups',
839 'test_conv_double_backward',
840 'test_conv_double_backward_groups',
841 'test_Conv3d_dilated',
842 'test_Conv3d_stride_padding',
843 'test_Conv3d_dilated_strided',
845 'test_Conv2d_dilated',
846 'test_ConvTranspose3d_dilated',
847 'test_ConvTranspose2d_dilated',
850 'test_Conv2d_padding',
851 'test_ConvTranspose2d_no_bias',
852 'test_ConvTranspose2d',
853 'test_ConvTranspose3d',
854 'test_Conv2d_no_bias',
856 'test_multinomial_invalid_probs',
860 running_script_path =
None 863 def set_running_script_path():
864 global running_script_path
866 running_file = os.path.abspath(os.path.realpath(sys.argv[0]))
867 if running_file.endswith(
'.py'):
868 running_script_path = running_file
873 def check_test_defined_in_running_script(test_case):
874 if running_script_path
is None:
881 test_case_class_file = os.path.abspath(os.path.realpath(inspect.getfile(test_case.__class__)))
882 assert test_case_class_file == running_script_path,
"Class of loaded TestCase \"{}\" " \
883 "is not defined in the running script \"{}\", but in \"{}\". Did you " \
884 "accidentally import a unittest.TestCase from another file?".format(
885 test_case.id(), running_script_path, test_case_class_file)
888 num_shards = os.environ.get(
'TEST_NUM_SHARDS',
None)
889 shard = os.environ.get(
'TEST_SHARD',
None)
890 if num_shards
is not None and shard
is not None:
891 num_shards = int(num_shards)
894 def load_tests(loader, tests, pattern):
895 set_running_script_path()
896 test_suite = unittest.TestSuite()
897 for test_group
in tests:
898 for test
in test_group:
899 check_test_defined_in_running_script(test)
900 name = test.id().
split(
'.')[-1]
901 if name
in THESE_TAKE_WAY_TOO_LONG:
903 hash_id = int(hashlib.sha256(str(test).encode(
'utf-8')).hexdigest(), 16)
904 if hash_id % num_shards == shard:
905 test_suite.addTest(test)
909 def load_tests(loader, tests, pattern):
910 set_running_script_path()
911 test_suite = unittest.TestSuite()
912 for test_group
in tests:
913 for test
in test_group:
914 check_test_defined_in_running_script(test)
915 test_suite.addTest(test)
def assertEqual(self, x, y, prec=None, message='', allow_inf=False)
def wrap_with_cuda_memory_check(self, method)
Module caffe2.python.layers.split.
bool _do_cuda_memory_leak_check
def assertExpected(self, s, subname=None)
def disable_global_flags()
def memory_allocated(device=None)
def safeCoalesce(self, t)
def assertLeaksNoCudaTensors(self, name=None)
def get_cuda_memory_usage()
def set_default_tensor_type(t)