Caffe2 - Python API
A deep learning, cross platform ML framework
common_utils.py
1 r"""Importing this file must **not** initialize CUDA context. test_distributed
2 relies on this assumption to properly run. This means that when this is imported
3 no CUDA calls shall be made, including torch.cuda.device_count(), etc.
4 
5 common_cuda.py can freely initialize CUDA context when imported.
6 """
7 
8 import sys
9 import os
10 import platform
11 import re
12 import gc
13 import types
14 import inspect
15 import argparse
16 import unittest
17 import warnings
18 import random
19 import contextlib
20 import socket
21 import time
22 from collections import OrderedDict
23 from functools import wraps
24 from itertools import product
25 from copy import deepcopy
26 from numbers import Number
27 
28 import __main__
29 import errno
30 
31 import expecttest
32 import hashlib
33 
34 import torch
35 import torch.cuda
36 from torch._utils_internal import get_writable_path
37 from torch._six import string_classes, inf
39 import torch.backends.mkl
40 
41 
# Make torch.DoubleTensor the default tensor type for every test module that
# imports this file.
torch.set_default_tensor_type('torch.DoubleTensor')


# Parse only the options this module owns. parse_known_args() leaves every
# other command-line argument in `remaining`, which is forwarded to unittest
# through UNITTEST_ARGS.
parser = argparse.ArgumentParser(add_help=False)
parser.add_argument('--seed', type=int, default=1234)
parser.add_argument('--accept', action='store_true')
args, remaining = parser.parse_known_args()
SEED = args.seed
# Only take --accept from the command line when expecttest has not already
# been switched into accept mode.
if not expecttest.ACCEPT:
    expecttest.ACCEPT = args.accept
UNITTEST_ARGS = [sys.argv[0]] + remaining
# Seed torch's RNG once at import time.
torch.manual_seed(SEED)
55 
56 
def run_tests(argv=UNITTEST_ARGS):
    """Hand control to ``unittest.main`` using ``argv`` as the command line.

    ``argv`` defaults to ``UNITTEST_ARGS``, i.e. the script name followed by
    every command-line argument this module's own parser did not consume.
    """
    unittest.main(argv=argv)
59 
# Interpreter version flags used to choose between Python 2 / 3 code paths.
PY3 = sys.version_info > (3, 0)
PY34 = sys.version_info >= (3, 4)

# Platform flags.
IS_WINDOWS = sys.platform == "win32"
IS_PPC = platform.machine() == "ppc64le"

# Environment variable `IS_PYTORCH_CI` is set in `.jenkins/common.sh`.
# NOTE(review): bool() of any non-empty string is True, so every non-empty
# value (including "0") turns this flag on; only an unset or empty variable
# yields False.
IS_PYTORCH_CI = bool(os.environ.get('IS_PYTORCH_CI', 0))
68 
69 
def _check_module_exists(name):
    r"""Report whether a top-level module called :attr:`name` can be found
    **without** importing it.

    Asking the import machinery is generally safer than a try-catch block
    around `import X`: it keeps third party libraries from breaking
    assumptions of some of our tests as a side effect of being imported,
    e.g., setting the multiprocessing start method
    (see librosa/#747, torchvision/#544).
    """
    if PY34:  # Python >= 3.4
        import importlib
        import importlib.util
        return importlib.util.find_spec(name) is not None

    if PY3:  # Python [3, 3.4)
        import importlib
        return importlib.find_loader(name) is not None

    # Python 2
    import imp
    try:
        imp.find_module(name)
    except ImportError:
        return False
    return True
93 
# Availability flags for optional third-party packages, probed without
# importing them.
TEST_NUMPY = _check_module_exists('numpy')
TEST_SCIPY = _check_module_exists('scipy')
TEST_NUMBA = _check_module_exists('numba')

# On Py2, importing librosa 0.6.1 triggers a TypeError (if using newest joblib)
# see librosa/librosa#729.
# TODO: allow Py2 when librosa 0.6.2 releases
TEST_LIBROSA = _check_module_exists('librosa') and PY3

# Python 2.7 doesn't have spawn
NO_MULTIPROCESSING_SPAWN = os.environ.get('NO_MULTIPROCESSING_SPAWN', '0') == '1' or sys.version_info[0] == 2
# Each flag below is on only when its environment variable is exactly '1'.
TEST_WITH_ASAN = os.getenv('PYTORCH_TEST_WITH_ASAN', '0') == '1'
TEST_WITH_UBSAN = os.getenv('PYTORCH_TEST_WITH_UBSAN', '0') == '1'
TEST_WITH_ROCM = os.getenv('PYTORCH_TEST_WITH_ROCM', '0') == '1'

# numpy is imported only when the probe above found it.
if TEST_NUMPY:
    import numpy
112 
113 
def skipIfRocm(fn):
    """Decorator that skips ``fn`` when the suite runs on the ROCm stack.

    At call time the wrapper raises ``unittest.SkipTest`` if
    ``TEST_WITH_ROCM`` is set; otherwise it calls ``fn`` with the given
    arguments and returns whatever ``fn`` returns.
    """
    @wraps(fn)
    def wrapper(*args, **kwargs):
        if TEST_WITH_ROCM:
            raise unittest.SkipTest("test doesn't currently work on the ROCm stack")
        # Propagate the result instead of silently dropping it.
        return fn(*args, **kwargs)
    return wrapper
122 
123 
def skipIfNoLapack(fn):
    """Decorator that skips ``fn`` when PyTorch was built without LAPACK.

    At call time the wrapper raises ``unittest.SkipTest`` if
    ``torch._C.has_lapack`` is false; otherwise it calls ``fn`` with the
    given arguments and returns whatever ``fn`` returns.
    """
    @wraps(fn)
    def wrapper(*args, **kwargs):
        if not torch._C.has_lapack:
            raise unittest.SkipTest('PyTorch compiled without Lapack')
        # Propagate the result instead of silently dropping it.
        return fn(*args, **kwargs)
    return wrapper
132 
133 
def skipCUDAMemoryLeakCheckIf(condition):
    """Build a decorator that turns a test's CUDA memory leak check off when
    ``condition`` is true.

    The decorator stores ``not condition`` in ``fn._do_cuda_memory_leak_check``
    but only while that attribute is currently true (or missing), so a check
    that is already disabled is never re-enabled. ``fn`` itself is returned,
    not a wrapper.
    """
    def decorator(fn):
        check_enabled = getattr(fn, '_do_cuda_memory_leak_check', True)
        if check_enabled:
            fn._do_cuda_memory_leak_check = not condition
        return fn
    return decorator
140 
141 
def suppress_warnings(fn):
    """Decorator that runs ``fn`` with all warnings ignored.

    The wrapper installs an "ignore" filter inside ``warnings.catch_warnings``
    for the duration of the call, so the previous filter state is restored
    afterwards. It returns whatever ``fn`` returns.
    """
    @wraps(fn)
    def wrapper(*args, **kwargs):
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            # Propagate the result instead of silently dropping it.
            return fn(*args, **kwargs)
    return wrapper
149 
150 
def get_cpu_type(type_name):
    """Translate a ``'torch.cuda.<Name>'`` type string into the attribute
    ``torch.<Name>``.

    Asserts that the part before the last dot is exactly ``'torch.cuda'``.
    """
    parts = type_name.rsplit('.', 1)
    namespace, short_name = parts
    assert namespace == 'torch.cuda'
    return getattr(torch, short_name)
155 
156 
def get_gpu_type(type_name):
    """Translate a ``'torch.<Name>'`` type string (or a type object whose
    module is ``torch``) into the attribute ``torch.cuda.<Name>``.

    Asserts that the part before the last dot is exactly ``'torch'``.
    """
    if isinstance(type_name, type):
        # Rebuild the dotted name from the type object itself.
        type_name = '{}.{}'.format(type_name.__module__, type_name.__name__)
    parts = type_name.rsplit('.', 1)
    namespace, short_name = parts
    assert namespace == 'torch'
    return getattr(torch.cuda, short_name)
163 
164 
def to_gpu(obj, type_map=None):
    """Return a GPU-typed copy of ``obj``.

    * Tensors must be leaves. They are cloned and converted to the type found
      in ``type_map`` for their current type string, falling back to
      ``get_gpu_type`` of that string; ``requires_grad`` is carried over.
    * Storages are copied into a fresh storage of the same size.
    * Lists and tuples are converted element by element, keeping their
      container type.
    * Anything else is deep-copied.
    """
    mapping = {} if type_map is None else type_map

    if isinstance(obj, torch.Tensor):
        assert obj.is_leaf
        source_type = obj.type()
        # The fallback is computed eagerly, exactly like a dict.get default.
        target_type = mapping.get(source_type, get_gpu_type(source_type))
        with torch.no_grad():
            converted = obj.clone().type(target_type)
            converted.requires_grad = obj.requires_grad
        return converted

    if torch.is_storage(obj):
        return obj.new().resize_(obj.size()).copy_(obj)

    if isinstance(obj, list):
        return [to_gpu(item, mapping) for item in obj]

    if isinstance(obj, tuple):
        return tuple(to_gpu(item, mapping) for item in obj)

    return deepcopy(obj)
183 
184 
def get_function_arglist(func):
    """Return the list of positional argument names of ``func``.

    Uses ``inspect.getargspec`` on Python 2 and ``inspect.getfullargspec``
    on Python 3.
    """
    if sys.version_info <= (3,):
        return inspect.getargspec(func).args
    return inspect.getfullargspec(func).args
190 
191 
def set_rng_seed(seed):
    """Seed torch, the stdlib ``random`` module and, when ``TEST_NUMPY`` is
    set, ``numpy.random`` with the same ``seed``."""
    seeders = [torch.manual_seed, random.seed]
    if TEST_NUMPY:
        seeders.append(numpy.random.seed)
    for apply_seed in seeders:
        apply_seed(seed)
197 
198 
@contextlib.contextmanager
def freeze_rng_state():
    """Context manager that restores torch's RNG state on exit.

    The CPU generator state is always captured. The CUDA generator state is
    captured and restored only when ``torch.cuda.is_available()`` is true, so
    the manager is usable on machines without CUDA. Restoration happens in a
    ``finally`` clause, so the state is put back even when the body raises.
    """
    rng_state = torch.get_rng_state()
    has_cuda = torch.cuda.is_available()
    if has_cuda:
        cuda_rng_state = torch.cuda.get_rng_state()
    try:
        yield
    finally:
        if has_cuda:
            torch.cuda.set_rng_state(cuda_rng_state)
        torch.set_rng_state(rng_state)
208 
209 
def iter_indices(tensor):
    """Return an iterable over every index of ``tensor``.

    A 0-dim tensor gives an empty range, a 1-dim tensor gives plain integer
    indices, and higher ranks give index tuples from the cartesian product of
    the per-dimension ranges.
    """
    rank = tensor.dim()
    if rank == 0:
        return range(0)
    if rank == 1:
        return range(tensor.size(0))
    return product(*map(range, tensor.size()))
216 
217 
def is_iterable(obj):
    """Return True when ``iter(obj)`` succeeds and False when it raises
    ``TypeError``."""
    try:
        iter(obj)
    except TypeError:
        return False
    else:
        return True
224 
225 
227  def __init__(self, testcase, name=None):
228  self.name = testcase.id() if name is None else name
229  self.testcase = testcase
230 
231  # initialize context & RNG to prevent false positive detections
232  # when the test is the first to initialize those
233  from common_cuda import initialize_cuda_context_rng
234  initialize_cuda_context_rng()
235 
236  @staticmethod
237  def get_cuda_memory_usage():
238  # we don't need CUDA synchronize because the statistics are not tracked at
239  # actual freeing, but at when marking the block as free.
240  num_devices = torch.cuda.device_count()
241  gc.collect()
242  return tuple(torch.cuda.memory_allocated(i) for i in range(num_devices))
243 
244  def __enter__(self):
245  self.befores = self.get_cuda_memory_usage()
246 
247  def __exit__(self, exec_type, exec_value, traceback):
248  # Don't check for leaks if an exception was thrown
249  if exec_type is not None:
250  return
251  afters = self.get_cuda_memory_usage()
252 
253  for i, (before, after) in enumerate(zip(self.befores, afters)):
254  if not TEST_WITH_ROCM:
255  self.testcase.assertEqual(
256  before, after, '{} leaked {} bytes CUDA memory on device {}'.format(
257  self.name, after - before, i))
258  else:
259  # TODO: Investigate ROCm memory leaking.
260  if before != after:
261  warnings.warn('{} leaked {} bytes ROCm memory on device {}'.format(
262  self.name, after - before, i), RuntimeWarning)
263 
264 
    # Default tolerance used by assertEqual / assertNotEqual when the caller
    # does not pass `prec`.
    precision = 1e-5
    # presumably the unittest.TestCase attribute of the same name (None means
    # no limit on the length of diffs in failure output) — the class header is
    # not visible here, confirm against the base class.
    maxDiff = None
    # Class-level switch for the CUDA memory leak check; __init__ combines it
    # with the per-method flag of the same name.
    _do_cuda_memory_leak_check = False
269 
    def __init__(self, method_name='runTest'):
        """Create the test case and, when enabled, wrap the test method in a
        CUDA memory leak check.

        The wrap happens only if the leak-check flag is on for both the class
        and the method, the platform is not Windows, ``TEST_CUDA`` is true and
        the lower-cased test id contains ``'gpu'`` or ``'cuda'``.
        """
        super(TestCase, self).__init__(method_name)
        # Wraps the tested method if we should do CUDA memory check.
        test_method = getattr(self, method_name)
        # `&=` reads the class-level flag and stores the combined result on the
        # instance; a method without the attribute counts as True.
        self._do_cuda_memory_leak_check &= getattr(test_method, '_do_cuda_memory_leak_check', True)
        # FIXME: figure out the flaky -1024 anti-leaks on windows. See #8044
        if self._do_cuda_memory_leak_check and not IS_WINDOWS:
            # the import below may initialize CUDA context, so we do it only if
            # self._do_cuda_memory_leak_check is True.
            from common_cuda import TEST_CUDA
            fullname = self.id().lower()  # class_name.method_name
            if TEST_CUDA and ('gpu' in fullname or 'cuda' in fullname):
                # Replace the bound test method on this instance with the
                # leak-checking wrapper.
                setattr(self, method_name, self.wrap_with_cuda_memory_check(test_method))
283 
284  def assertLeaksNoCudaTensors(self, name=None):
285  name = self.id() if name is None else name
286  return CudaMemoryLeakCheck(self, name)
287 
288  def wrap_with_cuda_memory_check(self, method):
289  # Assumes that `method` is the tested function in `self`.
290  # NOTE: Python Exceptions (e.g., unittest.Skip) keeps objects in scope
291  # alive, so this cannot be done in setUp and tearDown because
292  # tearDown is run unconditionally no matter whether the test
293  # passes or not. For the same reason, we can't wrap the `method`
294  # call in try-finally and always do the check.
295  @wraps(method)
296  def wrapper(self, *args, **kwargs):
297  with self.assertLeaksNoCudaTensors():
298  method(*args, **kwargs)
299  return types.MethodType(wrapper, self)
300 
301  def setUp(self):
302  set_rng_seed(SEED)
303 
304  def assertTensorsSlowEqual(self, x, y, prec=None, message=''):
305  max_err = 0
306  self.assertEqual(x.size(), y.size())
307  for index in iter_indices(x):
308  max_err = max(max_err, abs(x[index] - y[index]))
309  self.assertLessEqual(max_err, prec, message)
310 
    def genSparseTensor(self, size, sparse_dim, nnz, is_uncoalesced, device='cpu'):
        """Generate a random sparse COO tensor of shape ``size``.

        The first ``sparse_dim`` dimensions are sparse and ``nnz`` entries are
        drawn. With ``is_uncoalesced`` the indices are duplicated (with fresh
        random values) and the result is left uncoalesced; otherwise it is
        coalesced.

        Returns a tuple ``(x, indices, values)`` where the last two are clones
        of ``x._indices()`` and ``x._values()``.
        """
        # Assert not given impossible combination, where the sparse dims have
        # empty numel, but nnz > 0 makes the indices containing values.
        assert all(size[d] > 0 for d in range(sparse_dim)) or nnz == 0, 'invalid arguments'

        # Values have one row per entry followed by the dense dimensions.
        v_size = [nnz] + list(size[sparse_dim:])
        v = torch.randn(*v_size, device=device)
        # Draw uniform [0, 1) numbers, scale each row by the size of its sparse
        # dimension, then convert to integer indices.
        i = torch.rand(sparse_dim, nnz, device=device)
        i.mul_(torch.tensor(size[:sparse_dim]).unsqueeze(1).to(i))
        i = i.to(torch.long)
        if is_uncoalesced:
            # Repeat every index once so the tensor contains duplicates.
            v = torch.cat([v, torch.randn_like(v)], 0)
            i = torch.cat([i, i], 1)

        x = torch.sparse_coo_tensor(i, v, torch.Size(size))

        if not is_uncoalesced:
            x = x.coalesce()
        else:
            # FIXME: `x` is a sparse view of `v`. Currently rebase_history for
            #        sparse views is not implemented, so this workaround is
            #        needed for inplace operations done on `x`, e.g., copy_().
            #        Remove after implementing something equivalent to CopySlice
            #        for sparse views.
            # NOTE: We do clone() after detach() here because we need to be able to change size/storage of x afterwards
            x = x.detach().clone()
        return x, x._indices().clone(), x._values().clone()
338 
339  def safeToDense(self, t):
340  r = self.safeCoalesce(t)
341  return r.to_dense()
342 
    def safeCoalesce(self, t):
        """Coalesce sparse tensor ``t`` and cross-check the result against a
        coalesced tensor rebuilt by hand in Python.

        Returns ``t.coalesce()`` when ``t`` has no entries, otherwise the
        hand-built tensor (asserted to match ``t.coalesce()``).
        """
        tc = t.coalesce()
        self.assertEqual(tc.to_dense(), t.to_dense())
        self.assertTrue(tc.is_coalesced())

        # Our code below doesn't work when nnz is 0, because
        # then it's a 0D tensor, not a 2D tensor.
        if t._nnz() == 0:
            self.assertEqual(t._indices(), tc._indices())
            self.assertEqual(t._values(), tc._values())
            return tc

        # Sum the values of duplicate indices, keyed by the index tuple.
        value_map = {}
        for idx, val in zip(t._indices().t(), t._values()):
            idx_tup = tuple(idx.tolist())
            if idx_tup in value_map:
                value_map[idx_tup] += val
            else:
                # Clone tensor values so the in-place += above never writes
                # into t's own storage.
                value_map[idx_tup] = val.clone() if isinstance(val, torch.Tensor) else val

        # Rebuild indices in sorted order with their accumulated values.
        new_indices = sorted(list(value_map.keys()))
        new_values = [value_map[idx] for idx in new_indices]
        if t._values().ndimension() < 2:
            new_values = t._values().new(new_values)
        else:
            new_values = torch.stack(new_values)

        new_indices = t._indices().new(new_indices).t()
        tg = t.new(new_indices, new_values, t.size())

        self.assertEqual(tc._indices(), tg._indices())
        self.assertEqual(tc._values(), tg._values())

        # An already-coalesced input must come back unchanged.
        if t.is_coalesced():
            self.assertEqual(tc._indices(), t._indices())
            self.assertEqual(tc._values(), t._values())

        return tg
381 
    def assertEqual(self, x, y, prec=None, message='', allow_inf=False):
        """Assert ``x`` and ``y`` are equal within tolerance ``prec``.

        Dispatches on the argument types: tensor vs number, tensor vs tensor
        (dense or sparse), strings, sets, dicts, generic iterables, bools and
        numbers; anything else falls back to the base class ``assertEqual``.

        Args:
            prec: tolerance; defaults to ``self.precision``. A ``str`` passed
                here while ``message`` is empty is treated as the message.
            message: failure message forwarded to the underlying assertions.
            allow_inf: when true, infinities are accepted if they match.
        """
        # Allow assertEqual(x, y, "some message") as a positional shorthand.
        if isinstance(prec, str) and message == '':
            message = prec
            prec = None
        if prec is None:
            prec = self.precision

        if isinstance(x, torch.Tensor) and isinstance(y, Number):
            self.assertEqual(x.item(), y, prec, message, allow_inf)
        elif isinstance(y, torch.Tensor) and isinstance(x, Number):
            self.assertEqual(x, y.item(), prec, message, allow_inf)
        elif isinstance(x, torch.Tensor) and isinstance(y, torch.Tensor):
            def assertTensorsEqual(a, b):
                super(TestCase, self).assertEqual(a.size(), b.size(), message)
                if a.numel() > 0:
                    if a.device.type == 'cpu' and a.dtype == torch.float16:
                        # CPU half tensors don't have the methods we need below
                        a = a.to(torch.float32)
                    if TEST_WITH_ROCM:
                        # Workaround for bug https://github.com/pytorch/pytorch/issues/16448
                        # TODO: remove after the bug is resolved.
                        b = b.to(a.dtype).to(a.device)
                    else:
                        b = b.to(a)

                    # NOTE(review): the bool checks read the outer `x` / `y`,
                    # not the local `a` / `b` (which are the indices/values
                    # tensors in the sparse case) — confirm this is intended.
                    if x.dtype == torch.bool and y.dtype == torch.bool:
                        self.assertEqual(x.tolist(), y.tolist())
                    elif x.dtype == torch.bool or y.dtype == torch.bool:
                        raise TypeError("Was expecting both tensors to be bool type.")
                    else:
                        diff = a - b
                        if a.is_floating_point():
                            # check that NaNs are in the same locations
                            nan_mask = torch.isnan(a)
                            self.assertTrue(torch.equal(nan_mask, torch.isnan(b)), message)
                            diff[nan_mask] = 0
                            # inf check if allow_inf=True
                            if allow_inf:
                                inf_mask = torch.isinf(a)
                                inf_sign = inf_mask.sign()
                                self.assertTrue(torch.equal(inf_sign, torch.isinf(b).sign()), message)
                                diff[inf_mask] = 0
                        # TODO: implement abs on CharTensor (int8)
                        if diff.is_signed() and diff.dtype != torch.int8:
                            diff = diff.abs()
                        max_err = diff.max()
                        self.assertLessEqual(max_err, prec, message)
            super(TestCase, self).assertEqual(x.is_sparse, y.is_sparse, message)
            if x.is_sparse:
                # Sparse tensors are compared through their coalesced indices
                # and values.
                x = self.safeCoalesce(x)
                y = self.safeCoalesce(y)
                assertTensorsEqual(x._indices(), y._indices())
                assertTensorsEqual(x._values(), y._values())
            else:
                assertTensorsEqual(x, y)
        elif isinstance(x, string_classes) and isinstance(y, string_classes):
            super(TestCase, self).assertEqual(x, y, message)
        elif type(x) == set and type(y) == set:
            super(TestCase, self).assertEqual(x, y, message)
        elif isinstance(x, dict) and isinstance(y, dict):
            # NOTE(review): the dict branches recurse without forwarding
            # prec / message / allow_inf, so nested comparisons use defaults.
            if isinstance(x, OrderedDict) and isinstance(y, OrderedDict):
                self.assertEqual(x.items(), y.items())
            else:
                self.assertEqual(set(x.keys()), set(y.keys()))
                key_list = list(x.keys())
                self.assertEqual([x[k] for k in key_list], [y[k] for k in key_list])
        elif is_iterable(x) and is_iterable(y):
            super(TestCase, self).assertEqual(len(x), len(y), message)
            # NOTE(review): allow_inf is not forwarded to the element-wise
            # comparison below.
            for x_, y_ in zip(x, y):
                self.assertEqual(x_, y_, prec, message)
        elif isinstance(x, bool) and isinstance(y, bool):
            super(TestCase, self).assertEqual(x, y, message)
        elif isinstance(x, Number) and isinstance(y, Number):
            if abs(x) == inf or abs(y) == inf:
                # Infinite values are compared exactly, and only when allowed.
                if allow_inf:
                    super(TestCase, self).assertEqual(x, y, message)
                else:
                    self.fail("Expected finite numeric values - x={}, y={}".format(x, y))
                return
            super(TestCase, self).assertLessEqual(abs(x - y), prec, message)
        else:
            super(TestCase, self).assertEqual(x, y, message)
464 
465  def assertAlmostEqual(self, x, y, places=None, msg=None, delta=None, allow_inf=None):
466  prec = delta
467  if places:
468  prec = 10**(-places)
469  self.assertEqual(x, y, prec, msg, allow_inf)
470 
    def assertNotEqual(self, x, y, prec=None, message=''):
        """Assert ``x`` and ``y`` differ by at least ``prec``.

        Args:
            prec: tolerance; defaults to ``self.precision``. A ``str`` passed
                here while ``message`` is empty is treated as the message.
            message: failure message forwarded to the underlying assertions.
        """
        # Allow assertNotEqual(x, y, "some message") as a positional shorthand.
        if isinstance(prec, str) and message == '':
            message = prec
            prec = None
        if prec is None:
            prec = self.precision

        if isinstance(x, torch.Tensor) and isinstance(y, torch.Tensor):
            if x.size() != y.size():
                super(TestCase, self).assertNotEqual(x.size(), y.size())
            self.assertGreater(x.numel(), 0)
            # Bring y to x's type and device before comparing.
            y = y.type_as(x)
            y = y.cuda(device=x.get_device()) if x.is_cuda else y.cpu()
            # x != x is true exactly where x is NaN.
            nan_mask = x != x
            # Only when the NaN patterns match is a numeric difference
            # required; otherwise nothing further is asserted.
            if torch.equal(nan_mask, y != y):
                diff = x - y
                if diff.is_signed():
                    diff = diff.abs()
                diff[nan_mask] = 0
                max_err = diff.max()
                self.assertGreaterEqual(max_err, prec, message)
        elif type(x) == str and type(y) == str:
            super(TestCase, self).assertNotEqual(x, y)
        elif is_iterable(x) and is_iterable(y):
            super(TestCase, self).assertNotEqual(x, y)
        else:
            # Try a numeric comparison first; if the operands do not support
            # it, or the difference is below prec, fall back to the base class
            # inequality check.
            try:
                self.assertGreaterEqual(abs(x - y), prec, message)
                return
            except (TypeError, AssertionError):
                pass
            super(TestCase, self).assertNotEqual(x, y, message)
503 
504  def assertObjectIn(self, obj, iterable):
505  for elem in iterable:
506  if id(obj) == id(elem):
507  return
508  raise AssertionError("object not found in iterable")
509 
510  # TODO: Support context manager interface
511  # NB: The kwargs forwarding to callable robs the 'subname' parameter.
512  # If you need it, manually apply your callable in a lambda instead.
513  def assertExpectedRaises(self, exc_type, callable, *args, **kwargs):
514  subname = None
515  if 'subname' in kwargs:
516  subname = kwargs['subname']
517  del kwargs['subname']
518  try:
519  callable(*args, **kwargs)
520  except exc_type as e:
521  self.assertExpected(str(e), subname)
522  return
523  # Don't put this in the try block; the AssertionError will catch it
524  self.fail(msg="Did not raise when expected to")
525 
526  def assertWarns(self, callable, msg=''):
527  r"""
528  Test if :attr:`callable` raises a warning.
529  """
530  with warnings.catch_warnings(record=True) as ws:
531  warnings.simplefilter("always") # allow any warning to be raised
532  callable()
533  self.assertTrue(len(ws) > 0, msg)
534 
535  def assertWarnsRegex(self, callable, regex, msg=''):
536  r"""
537  Test if :attr:`callable` raises any warning with message that contains
538  the regex pattern :attr:`regex`.
539  """
540  with warnings.catch_warnings(record=True) as ws:
541  warnings.simplefilter("always") # allow any warning to be raised
542  callable()
543  self.assertTrue(len(ws) > 0, msg)
544  found = any(re.search(regex, str(w.message)) is not None for w in ws)
545  self.assertTrue(found, msg)
546 
    def assertExpected(self, s, subname=None):
        r"""
        Test that a string matches the recorded contents of a file
        derived from the name of this test and subname. This file
        is placed in the 'expect' directory in the same directory
        as the test script. You can automatically update the recorded test
        output using --accept.

        If you call this multiple times in a single function, you must
        give a unique subname each time.

        Raises ``TypeError`` if ``s`` is not a string, and ``RuntimeError`` if
        no expect file exists and accept mode is off.
        """
        # `unicode` is only evaluated on Python 2 thanks to the short-circuit.
        if not (isinstance(s, str) or (sys.version_info[0] == 2 and isinstance(s, unicode))):
            raise TypeError("assertExpected is strings only")

        def remove_prefix(text, prefix):
            if text.startswith(prefix):
                return text[len(prefix):]
            return text
        # NB: we take __file__ from the module that defined the test
        # class, so we place the expect directory where the test script
        # lives, NOT where test/common_utils.py lives.  This doesn't matter in
        # PyTorch where all test scripts are in the same directory as
        # test/common_utils.py, but it matters in onnx-pytorch
        module_id = self.__class__.__module__
        munged_id = remove_prefix(self.id(), module_id + ".")
        test_file = os.path.realpath(sys.modules[module_id].__file__)
        expected_file = os.path.join(os.path.dirname(test_file),
                                     "expect",
                                     munged_id)

        # File name is "<munged_id>[-<subname>].expect".
        subname_output = ""
        if subname:
            expected_file += "-" + subname
            subname_output = " ({})".format(subname)
        expected_file += ".expect"
        expected = None

        def accept_output(update_type):
            # Overwrite the expect file with the current output.
            print("Accepting {} for {}{}:\n\n{}".format(update_type, munged_id, subname_output, s))
            with open(expected_file, 'w') as f:
                f.write(s)

        try:
            with open(expected_file) as f:
                expected = f.read()
        except IOError as e:
            # Only a missing file is handled here; any other I/O error
            # propagates unchanged.
            if e.errno != errno.ENOENT:
                raise
            elif expecttest.ACCEPT:
                return accept_output("output")
            else:
                raise RuntimeError(
                    ("I got this output for {}{}:\n\n{}\n\n"
                     "No expect file exists; to accept the current output, run:\n"
                     "python {} {} --accept").format(munged_id, subname_output, s, __main__.__file__, munged_id))

        # a hack for JIT tests
        if IS_WINDOWS:
            # Blank out the contents of CppOp[...] on both sides before
            # comparing.
            expected = re.sub(r'CppOp\[(.+?)\]', 'CppOp[]', expected)
            s = re.sub(r'CppOp\[(.+?)\]', 'CppOp[]', s)

        if expecttest.ACCEPT:
            if expected != s:
                return accept_output("updated output")
        else:
            if hasattr(self, "assertMultiLineEqual"):
                # Available from Python 2.7 on
                # NB: Python considers lhs "old" and rhs "new".
                self.assertMultiLineEqual(expected, s)
            else:
                self.assertEqual(s, expected)
618 
    # Provide the Python 3.2+ assertion names on older interpreters by
    # aliasing the pre-3.2 unittest methods.
    if sys.version_info < (3, 2):
        # assertRegexpMatches renamed to assertRegex in 3.2
        assertRegex = unittest.TestCase.assertRegexpMatches
        # assertRaisesRegexp renamed to assertRaisesRegex in 3.2
        assertRaisesRegex = unittest.TestCase.assertRaisesRegexp
624 
625 
def download_file(url, binary=True):
    """Download ``url`` into this module's writable ``data`` directory and
    return the local path.

    The file name is the basename of the URL path. If a file of that name
    already exists, its path is returned without downloading.

    Args:
        url: address to fetch.
        binary: write the payload in binary mode when true; otherwise the
            payload is written in text mode (decoded as UTF-8 first when it
            arrives as bytes).

    Raises:
        unittest.SkipTest: if the download fails with ``URLError``; a
            ``RuntimeWarning`` is emitted as well.
    """
    if sys.version_info < (3,):
        from urlparse import urlsplit
        import urllib2
        request = urllib2
        error = urllib2
    else:
        from urllib.parse import urlsplit
        from urllib import request, error

    filename = os.path.basename(urlsplit(url)[2])
    data_dir = get_writable_path(os.path.join(os.path.dirname(__file__), 'data'))
    path = os.path.join(data_dir, filename)

    if os.path.exists(path):
        return path
    try:
        # closing() releases the connection on both Python 2 (urllib2) and
        # Python 3 (urllib.request) response objects.
        with contextlib.closing(request.urlopen(url, timeout=15)) as response:
            data = response.read()
        # A text-mode file only accepts str; on Python 3 the payload is bytes.
        if not binary and not isinstance(data, str):
            data = data.decode('utf-8')
        with open(path, 'wb' if binary else 'w') as f:
            f.write(data)
        return path
    except error.URLError:
        msg = "could not download test file '{}'".format(url)
        warnings.warn(msg, RuntimeWarning)
        raise unittest.SkipTest(msg)
651 
652 
def find_free_port():
    """Return a TCP port number that was free on ``localhost`` at call time.

    A socket is bound to port 0 so the OS picks the port; the socket is then
    closed and the port number returned. The socket is closed even if
    configuring or binding it raises.
    """
    with contextlib.closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as sock:
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        sock.bind(('localhost', 0))
        sockname = sock.getsockname()
    return sockname[1]
660 
661 
def retry_on_address_already_in_use_error(func):
    """Reruns a test if it sees "Address already in use" error."""
    max_attempts = 10

    @wraps(func)
    def wrapper(*args, **kwargs):
        for attempt in range(1, max_attempts + 1):
            try:
                return func(*args, **kwargs)
            except RuntimeError as error:
                retryable = str(error) == "Address already in use"
                # Any other RuntimeError, or the final failed attempt, is
                # re-raised to the caller.
                if not retryable or attempt == max_attempts:
                    raise
            # Back off for a random fraction of a second before retrying.
            time.sleep(random.random())
    return wrapper
679 
680 
681 # Methods for matrix generation
682 # Used in test_autograd.py and test_torch.py
def prod_single_zero(dim_size):
    """Return a random ``dim_size`` x ``dim_size`` matrix whose entry at row 0,
    column 1 is set to zero."""
    matrix = torch.randn(dim_size, dim_size)
    matrix[0, 1] = 0
    return matrix
687 
688 
def random_square_matrix_of_rank(l, rank):
    """Return a random ``l`` x ``l`` matrix built to have the given ``rank``.

    A random matrix is decomposed with SVD; singular values from position
    ``rank`` onwards are zeroed, any exactly-zero value before that is set to
    one, and the matrix is recomposed.
    """
    assert rank <= l
    left, singular, right = torch.randn(l, l).svd()
    # Keep the leading `rank` singular values non-zero...
    for pos in range(rank):
        if singular[pos] == 0:
            singular[pos] = 1
    # ...and drop the remaining ones.
    for pos in range(rank, l):
        singular[pos] = 0
    return left.mm(torch.diag(singular)).mm(right.transpose(0, 1))
699 
700 
def random_symmetric_matrix(l):
    """Return a random symmetric ``l`` x ``l`` matrix, made by copying the
    upper triangle of a random matrix onto its lower triangle."""
    matrix = torch.randn(l, l)
    for row in range(1, l):
        for col in range(row):
            # Mirror the entry above the diagonal to the one below it.
            matrix[row, col] = matrix[col, row]
    return matrix
707 
708 
def random_symmetric_psd_matrix(l):
    """Return ``A @ A^T`` for a random ``l`` x ``l`` matrix ``A``, which is
    symmetric positive semi-definite."""
    factor = torch.randn(l, l)
    factor_t = factor.transpose(0, 1)
    return factor.mm(factor_t)
712 
713 
def random_symmetric_pd_matrix(l, *batches):
    """Return ``A @ A^T + 1e-5 * I`` for a random ``A`` of shape
    ``batches + (l, l)``; the small multiple of the identity is added to the
    trailing ``l`` x ``l`` matrices."""
    shape = batches + (l, l)
    factor = torch.randn(*shape)
    gram = factor.matmul(factor.transpose(-2, -1))
    jitter = torch.eye(l) * 1e-5
    return gram + jitter
717 
718 
def make_nonzero_det(A, sign=None, min_singular_value=0.1):
    """Return a matrix derived from ``A`` whose singular values are all at
    least ``min_singular_value``.

    When ``sign`` is given, the first row of the result is negated if needed
    so that the determinant's sign matches the sign of ``sign``.
    """
    left, singular, right = A.svd()
    too_small = singular < min_singular_value
    singular[too_small] = min_singular_value
    A = left.mm(torch.diag(singular)).mm(right.t())
    det = A.det().item()
    # Negating one row flips the sign of the determinant.
    if sign is not None and ((det < 0) ^ (sign < 0)):
        A[0, :].neg_()
    return A
728 
729 
def random_fullrank_matrix_distinct_singular_value(l, *batches, **kwargs):
    """Return random ``l`` x ``l`` matrices with singular values
    ``1/(l+1), 2/(l+1), ..., l/(l+1)``.

    With ``batches`` the result has shape ``batches + (l, l)``; without it a
    single matrix is returned. If the keyword ``silent`` is true and PyTorch
    has no LAPACK, an ``l`` x ``l`` matrix of ones is returned instead.
    """
    silent = kwargs.get("silent", False)
    if silent and not torch._C.has_lapack:
        return torch.ones(l, l)

    def single_matrix():
        # Take random orthogonal factors from an SVD and pair them with the
        # fixed, distinct singular values.
        u, _, v = torch.randn(l, l).svd()
        s = torch.arange(1., l + 1).mul_(1.0 / (l + 1))
        return u.mm(torch.diag(s)).mm(v.t())

    if not batches:
        return single_matrix()

    count = torch.prod(torch.as_tensor(batches)).item()
    stacked = torch.stack([single_matrix() for _ in range(count)])
    return stacked.reshape(*(batches + (l, l)))
748 
749 
def brute_pdist(inp, p=2):
    """Computes the same as torch.pdist using primitives"""
    # n rows along the second-to-last dimension give k = n*(n-1)/2 unordered pairs.
    n = inp.shape[-2]
    k = n * (n - 1) // 2
    if k == 0:
        # torch complains about empty indices
        return torch.empty(inp.shape[:-2] + (0,), dtype=inp.dtype, device=inp.device)
    # Full n x n table of p-norm distances between all pairs of rows.
    square = torch.norm(inp[..., None, :] - inp[..., None, :, :], p=p, dim=-1)
    # Flatten the last two dimensions so each pair has a single flat index.
    unroll = square.view(square.shape[:-2] + (n * n,))
    # Build per-step increments whose running sum visits the flat indices of
    # the strict upper triangle in row-major order: steps of 1 within a row,
    # and a larger jump when moving on to the next row.
    inds = torch.ones(k, dtype=torch.int)
    inds[torch.arange(n - 1, 1, -1, dtype=torch.int).cumsum(0)] += torch.arange(2, n, dtype=torch.int)
    return unroll[..., inds.cumsum(0)]
762 
763 
def brute_cdist(x, y, p=2):
    """Return the p-norm distance between every row of ``x`` and every row of
    ``y`` (rows taken along the second-to-last dimension).

    If either input has no rows, an uninitialised ``(r1, r2)`` tensor on
    ``x``'s device is returned instead.
    """
    rows_x, rows_y = x.shape[-2], y.shape[-2]
    if 0 in (rows_x, rows_y):
        return torch.empty(rows_x, rows_y, device=x.device)
    pairwise_diff = x[..., None, :] - y[..., None, :, :]
    return torch.norm(pairwise_diff, p=p, dim=-1)
770 
771 
def do_test_dtypes(self, dtypes, layout, device):
    """For every dtype except ``torch.float16``, create a ``(2, 3)`` zeros
    tensor and assert, through test case ``self``, that its dtype, layout and
    device are the requested ones."""
    for dtype in dtypes:
        if dtype == torch.float16:
            continue
        out = torch.zeros((2, 3), dtype=dtype, layout=layout, device=device)
        self.assertIs(dtype, out.dtype)
        self.assertIs(layout, out.layout)
        self.assertEqual(device, out.device)
779 
780 
def do_test_empty_full(self, dtypes, layout, device):
    """Exercise the ``empty`` / ``full`` family of tensor factories.

    For each dtype in ``dtypes`` (and for ``requires_grad`` both off and, for
    floating point dtypes, on) this creates tensors with ``torch.empty``,
    ``new_empty``, ``empty_like``, ``torch.full``, ``new_full`` and
    ``full_like`` and asserts, through test case ``self``, their shape, dtype,
    layout, ``requires_grad``, device (CUDA tensors only) and fill value.
    The ``full`` variants are skipped for ``torch.float16`` and for the sparse
    COO layout.
    """
    # Imported here because get_int64_dtype() needs `operator` and the import
    # block visible at the top of this file does not include it.
    import operator

    shape = torch.Size([2, 3])

    def check_value(tensor, dtype, layout, device, value, requires_grad):
        self.assertEqual(shape, tensor.shape)
        self.assertIs(dtype, tensor.dtype)
        self.assertIs(layout, tensor.layout)
        self.assertEqual(tensor.requires_grad, requires_grad)
        if tensor.is_cuda and device is not None:
            self.assertEqual(device, tensor.device)
        if value is not None:
            fill = tensor.new(shape).fill_(value)
            self.assertEqual(tensor, fill)

    def get_int64_dtype(dtype):
        # Pick the int64 dtype from the same sub-namespace of torch that
        # str(dtype) names, or torch.int64 when there is no sub-namespace.
        module = '.'.join(str(dtype).split('.')[1:-1])
        if not module:
            return torch.int64
        return operator.attrgetter(module)(torch).int64

    default_dtype = torch.get_default_dtype()
    check_value(torch.empty(shape), default_dtype, torch.strided, -1, None, False)
    check_value(torch.full(shape, -5), default_dtype, torch.strided, -1, None, False)
    for dtype in dtypes:
        # The set collapses to {False} for non floating point dtypes.
        for rg in {dtype.is_floating_point, False}:
            int64_dtype = get_int64_dtype(dtype)
            v = torch.empty(shape, dtype=dtype, device=device, layout=layout, requires_grad=rg)
            check_value(v, dtype, layout, device, None, rg)
            out = v.new()
            check_value(torch.empty(shape, out=out, device=device, layout=layout, requires_grad=rg),
                        dtype, layout, device, None, rg)
            check_value(v.new_empty(shape), dtype, layout, device, None, False)
            check_value(v.new_empty(shape, dtype=int64_dtype, device=device, requires_grad=False),
                        int64_dtype, layout, device, None, False)
            check_value(torch.empty_like(v), dtype, layout, device, None, False)
            check_value(torch.empty_like(v, dtype=int64_dtype, layout=layout, device=device, requires_grad=False),
                        int64_dtype, layout, device, None, False)

            if dtype is not torch.float16 and layout != torch.sparse_coo:
                fv = 3
                v = torch.full(shape, fv, dtype=dtype, layout=layout, device=device, requires_grad=rg)
                check_value(v, dtype, layout, device, fv, rg)
                check_value(v.new_full(shape, fv + 1), dtype, layout, device, fv + 1, False)
                out = v.new()
                check_value(torch.full(shape, fv + 2, out=out, device=device, layout=layout, requires_grad=rg),
                            dtype, layout, device, fv + 2, rg)
                check_value(v.new_full(shape, fv + 3, dtype=int64_dtype, device=device, requires_grad=False),
                            int64_dtype, layout, device, fv + 3, False)
                check_value(torch.full_like(v, fv + 4), dtype, layout, device, fv + 4, False)
                check_value(torch.full_like(v, fv + 5,
                                            dtype=int64_dtype, layout=layout, device=device, requires_grad=False),
                            int64_dtype, layout, device, fv + 5, False)
833 
834 
# True when the SANDCASTLE environment variable is '1' or TW_JOB_USER is
# 'sandcastle'.
IS_SANDCASTLE = os.getenv('SANDCASTLE') == '1' or os.getenv('TW_JOB_USER') == 'sandcastle'

# Test method names (the last component of the test id) that load_tests
# leaves out of the suite when sharding is enabled.
THESE_TAKE_WAY_TOO_LONG = {
    'test_Conv3d_groups',
    'test_conv_double_backward',
    'test_conv_double_backward_groups',
    'test_Conv3d_dilated',
    'test_Conv3d_stride_padding',
    'test_Conv3d_dilated_strided',
    'test_Conv3d',
    'test_Conv2d_dilated',
    'test_ConvTranspose3d_dilated',
    'test_ConvTranspose2d_dilated',
    'test_snli',
    'test_Conv2d',
    'test_Conv2d_padding',
    'test_ConvTranspose2d_no_bias',
    'test_ConvTranspose2d',
    'test_ConvTranspose3d',
    'test_Conv2d_no_bias',
    'test_matmul_4d_4d',
    'test_multinomial_invalid_probs',
}


# Absolute path of the running test script. Filled in by
# set_running_script_path(); stays None when sys.argv[0] is not a .py file
# or cannot be resolved.
running_script_path = None
861 
862 
def set_running_script_path():
    """Store the resolved absolute path of ``sys.argv[0]`` in the module
    global ``running_script_path`` when it names a ``.py`` file.

    This is best effort: any exception while resolving the path is swallowed
    and the global is left untouched.
    """
    global running_script_path
    try:
        candidate = os.path.abspath(os.path.realpath(sys.argv[0]))
        is_script = candidate.endswith('.py')
    except Exception:
        return
    if is_script:  # skip if the running file is not a script
        running_script_path = candidate
871 
872 
def check_test_defined_in_running_script(test_case):
    """Assert that the class of ``test_case`` is defined in the running script.

    The check is skipped when ``running_script_path`` is unknown, and under
    ROCm.
    """
    # In ROCm CI, to avoid forking after HIP is initialized, we
    # indeed load test module from test/run_test.py and run all
    # tests in the same process.
    if running_script_path is None or TEST_WITH_ROCM:
        return
    defining_file = os.path.abspath(os.path.realpath(inspect.getfile(test_case.__class__)))
    assert defining_file == running_script_path, (
        "Class of loaded TestCase \"{}\" "
        "is not defined in the running script \"{}\", but in \"{}\". Did you "
        "accidentally import a unittest.TestCase from another file?"
    ).format(test_case.id(), running_script_path, defining_file)
886 
887 
# Sharding is enabled only when both TEST_NUM_SHARDS and TEST_SHARD are set.
num_shards = os.environ.get('TEST_NUM_SHARDS', None)
shard = os.environ.get('TEST_SHARD', None)
if num_shards is None or shard is None:

    def load_tests(loader, tests, pattern):
        """unittest ``load_tests`` hook: keep every test, after checking that
        each one is defined in the running script."""
        set_running_script_path()
        selected = unittest.TestSuite()
        for test in (case for group in tests for case in group):
            check_test_defined_in_running_script(test)
            selected.addTest(test)
        return selected
else:
    num_shards = int(num_shards)
    shard = int(shard)

    def load_tests(loader, tests, pattern):
        """unittest ``load_tests`` hook: keep only the tests of this shard.

        Tests named in ``THESE_TAKE_WAY_TOO_LONG`` are dropped; the rest are
        assigned to a shard by the SHA-256 hash of ``str(test)`` modulo
        ``num_shards``.
        """
        set_running_script_path()
        selected = unittest.TestSuite()
        for test in (case for group in tests for case in group):
            check_test_defined_in_running_script(test)
            short_name = test.id().split('.')[-1]
            if short_name in THESE_TAKE_WAY_TOO_LONG:
                continue
            digest = hashlib.sha256(str(test).encode('utf-8')).hexdigest()
            if int(digest, 16) % num_shards == shard:
                selected.addTest(test)
        return selected
def assertEqual(self, x, y, prec=None, message='', allow_inf=False)
def wrap_with_cuda_memory_check(self, method)
Module caffe2.python.layers.split.
def assertExpected(self, s, subname=None)
def is_available()
Definition: __init__.py:45
def device_count()
Definition: __init__.py:341
def disable_global_flags()
Definition: __init__.py:155
def memory_allocated(device=None)
Definition: __init__.py:409
def safeCoalesce(self, t)
def is_available()
Definition: __init__.py:4
def assertLeaksNoCudaTensors(self, name=None)
def set_default_tensor_type(t)
Definition: __init__.py:132
def is_storage(obj)
Definition: __init__.py:123