import io
import math
import random
import re
import tempfile
import types
import unittest
import warnings

import torch

from torch._six import inf, nan, string_classes, istuple
from itertools import product, combinations, combinations_with_replacement
from functools import reduce
from torch import multiprocessing as mp
from common_methods_invocations import tri_tests_args, run_additional_tri_tests, \
    _compare_trilu_indices
from common_utils import TestCase, iter_indices, TEST_NUMPY, TEST_SCIPY, TEST_MKL, \
    TEST_LIBROSA, run_tests, download_file, skipIfNoLapack, suppress_warnings, \
    IS_WINDOWS, PY3, NO_MULTIPROCESSING_SPAWN, skipIfRocm, do_test_dtypes, do_test_empty_full, \
    IS_SANDCASTLE, load_tests, brute_pdist, brute_cdist
from multiprocessing.reduction import ForkingPickler

if TEST_NUMPY:
    import numpy as np
load_tests = load_tests
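# The self-assignment above keeps `load_tests` bound in this module's namespace
# so that unittest's load_tests protocol (used by common_utils for test
# sharding) picks it up, while flake8 does not flag the import as unused.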
if TEST_SCIPY:
    from scipy import signal

can_retrieve_source = True
with warnings.catch_warnings(record=True) as warns:
    with tempfile.NamedTemporaryFile() as checkpoint:
        x = torch.save(torch.nn.Module(), checkpoint)
        for warn in warns:
            if "Couldn't retrieve source code" in warn.message.args[0]:
                can_retrieve_source = False


class FilelikeMock(object):
    # Mocks a file-like object and records which of its methods were called.
    def __init__(self, data, has_fileno=True, has_readinto=False):
        if has_readinto:
            self.readinto = self.readinto_opt
        if has_fileno:
            self.fileno = self.fileno_opt

        self.calls = set()
        self.bytesio = io.BytesIO(data)
        def trace(fn, name):
            def result(*args, **kwargs):
                self.calls.add(name)
                return fn(*args, **kwargs)
            return result

        for attr in ['read', 'readline', 'seek', 'tell', 'write', 'flush']:
            traced_fn = trace(getattr(self.bytesio, attr), attr)
            setattr(self, attr, traced_fn)

    def fileno_opt(self):
        raise io.UnsupportedOperation('Not a real file')

    def readinto_opt(self, view):
        self.calls.add('readinto')
        return self.bytesio.readinto(view)

    def was_called(self, name):
        return name in self.calls

    def __enter__(self):
        return self

    def __exit__(self, *args):
        pass

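# _TestTorchMixin collects the actual test bodies; it is mixed into a CPU
# TestCase here and reused by the CUDA suite, which is why many of the
# _test_* helpers below are written as @staticmethod and take `self`
# explicitly.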
class _TestTorchMixin(object):
    def _check_sum_dim(self, tensors, dim):
        for tensor in tensors:
            expected = tensor.numpy().sum(dim)
            actual = tensor.sum(dim)
            self.assertEqual(expected.shape, actual.shape)
            if actual.dtype == torch.float:
                self.assertTrue(np.allclose(expected, actual.numpy(), rtol=1e-03, atol=1e-05))
            else:
                self.assertTrue(np.allclose(expected, actual.numpy()))

    def _make_tensors(self, shape, val_range=(-100, 100), use_floating=True, use_integral=True):
        # Returns one tensor per dtype in each of three layout categories:
        # "cont" (contiguous), "noncont" (strided), "slice" (contiguous slice).
        float_types = [torch.double,
                       torch.float]
        int_types = [torch.int64,
                     torch.int32,
                     torch.int16]

        def make_contiguous(shape, dtype):
            if dtype in float_types:
                val = torch.randn(shape, dtype=dtype)
                val = val * ((val_range[1] - val_range[0]) / (math.pi * 2.0))
                val = val + ((val_range[1] - val_range[0]) / 2.0)
                val = torch.clamp(val, min=val_range[0], max=val_range[1])
                return val
            result = torch.zeros(shape, dtype=dtype)
            result.apply_(lambda x: random.randint(val_range[0], val_range[1]))
            return result

        def make_non_contiguous(shape, dtype):
            contig = make_contiguous(shape, dtype)
            non_contig = torch.empty(shape + (2, 2), dtype=dtype)[..., 0]
            non_contig = non_contig.select(-1, -1)
            non_contig.copy_(contig)
            self.assertFalse(non_contig.is_contiguous())
            return non_contig

        def make_contiguous_slice(size, dtype):
            contig = make_contiguous((1, size), dtype)
            non_contig = contig[:1, 1:size - 1]
            self.assertTrue(non_contig.is_contiguous())
            return non_contig

        types = []
        if use_floating:
            types += float_types
        if use_integral:
            types += int_types
        tensors = {"cont": [], "noncont": [], "slice": []}
        for dtype in types:
            tensors["cont"].append(make_contiguous(shape, dtype))
            tensors["noncont"].append(make_non_contiguous(shape, dtype))
            tensors["slice"].append(make_contiguous_slice(sum(list(shape)), dtype))

        return tensors

    def test_doc(self):
        checked_types = (types.MethodType, types.FunctionType,
                         types.BuiltinFunctionType, types.BuiltinMethodType)

        def test_namespace(ns, *skips):
            if isinstance(ns, object):
                ns_name = ns.__class__.__name__
            if isinstance(ns, type):
                ns_name = ns.__name__
            skip_regexes = []
            for r in skips:
                if isinstance(r, string_classes):
                    skip_regexes.append(re.compile('^{}$'.format(re.escape(r))))
                else:
                    skip_regexes.append(r)
            for name in dir(ns):
                if name.startswith('_'):
                    continue
                var = getattr(ns, name)
                if not isinstance(var, checked_types):
                    continue
                doc = var.__doc__
                has_doc = doc is not None and len(doc.strip()) > 0
                full_name = ns_name + '.' + name
                if any(r.match(name) for r in skip_regexes):
                    self.assertFalse(has_doc,
                                     'New docs have been added for {}, please remove '
                                     'it from the skipped list in TestTorch.test_doc'.format(full_name))
                else:
                    self.assertTrue(has_doc, '{} is missing documentation'.format(full_name))

        test_namespace(torch.randn(1),
                       re.compile('^clamp_(min|max)_?$'),
                       'sparse_resize_and_clear_')
        test_namespace(torch.nn.functional,
                       'assert_int_or_pair',
                       'bilinear',
                       'feature_alpha_dropout')

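    # The BLAS-style tests below (dot, ger, addr, addmv, addmm) compare each
    # torch op against a naive Python-loop reference implementation, for both
    # plain inputs and 0-strided (expand()-ed) inputs.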
    def test_dot(self):
        types = {
            'torch.DoubleTensor': 1e-8,
            'torch.FloatTensor': 1e-4,
        }
        for tname, _prec in types.items():
            v1 = torch.randn(100).type(tname)
            v2 = torch.randn(100).type(tname)
            res1 = torch.dot(v1, v2)
            res2 = 0
            for i, j in zip(v1, v2):
                res2 += i * j
            self.assertEqual(res1, res2)
            out = torch.randn(()).type(tname)
            torch.dot(v1, v2, out=out)
            self.assertEqual(res1, out)

        # 0-strided: v1 is a single element expanded to length 100
        for tname, _prec in types.items():
            v1 = torch.randn(1).type(tname).expand(100)
            v2 = torch.randn(100).type(tname)
            res1 = torch.dot(v1, v2)
            res2 = 0
            for i, j in zip(v1, v2):
                res2 += i * j
            self.assertEqual(res1, res2)
            out = torch.randn(()).type(tname)
            torch.dot(v1, v2, out=out)
            self.assertEqual(res1, out)

    def test_ger(self):
        types = {
            'torch.DoubleTensor': 1e-8,
            'torch.FloatTensor': 1e-4,
        }
        for tname, _prec in types.items():
            v1 = torch.randn(100).type(tname)
            v2 = torch.randn(100).type(tname)
            res1 = torch.ger(v1, v2)
            res2 = torch.zeros(100, 100).type(tname)
            for i in range(100):
                for j in range(100):
                    res2[i, j] = v1[i] * v2[j]
            self.assertEqual(res1, res2)

        # 0-strided
        for tname, _prec in types.items():
            v1 = torch.randn(1).type(tname).expand(100)
            v2 = torch.randn(100).type(tname)
            res1 = torch.ger(v1, v2)
            res2 = torch.zeros(100, 100).type(tname)
            for i in range(100):
                for j in range(100):
                    res2[i, j] = v1[i] * v2[j]
            self.assertEqual(res1, res2)

    def test_addr(self):
        types = {
            'torch.DoubleTensor': 1e-8,
            'torch.FloatTensor': 1e-4,
        }

        def run_test(m, v1, v2, m_transform=lambda x: x):
            m = m_transform(m.clone())
            ref = m.clone()
            torch.addr(m, v1, v2, out=m)
            for i in range(m.size(0)):
                for j in range(m.size(1)):
                    ref[i, j] += v1[i] * v2[j]
            self.assertEqual(m, ref)

        for tname, _prec in types.items():
            for h, w in [(100, 110), (1, 20), (200, 2)]:
                m = torch.randn(h, w).type(tname)
                v1 = torch.randn(h).type(tname)
                v2 = torch.randn(w).type(tname)
                run_test(m, v1, v2)
                # transposed
                run_test(m, v2, v1, lambda x: x.transpose(0, 1))
                # 0-strided
                v1 = torch.randn(1).type(tname).expand(h)
                run_test(m, v1, v2)
                run_test(m, v2, v1, lambda x: x.transpose(0, 1))

    def test_addmv(self):
        types = {
            'torch.DoubleTensor': 1e-8,
            'torch.FloatTensor': 1e-4,
        }
        for tname, _prec in types.items():
            t = torch.randn(10).type(tname)
            m = torch.randn(10, 100).type(tname)
            v = torch.randn(100).type(tname)
            res1 = torch.addmv(t, m, v)
            res2 = torch.zeros(10).type(tname)
            res2 += t
            for i in range(10):
                for j in range(100):
                    res2[i] += m[i, j] * v[j]
            self.assertEqual(res1, res2)

        # 0-strided
        for tname, _prec in types.items():
            t = torch.randn(1).type(tname).expand(10)
            m = torch.randn(10, 1).type(tname).expand(10, 100)
            v = torch.randn(100).type(tname)
            res1 = torch.addmv(t, m, v)
            res2 = torch.zeros(10).type(tname)
            res2 += t
            for i in range(10):
                for j in range(100):
                    res2[i] += m[i, j] * v[j]
            self.assertEqual(res1, res2)

    def test_addmm(self):
        types = {
            'torch.DoubleTensor': 1e-8,
            'torch.FloatTensor': 1e-4,
        }
        for tname, _prec in types.items():
            M = torch.randn(10, 25).type(tname)
            m1 = torch.randn(10, 50).type(tname)
            m2 = torch.randn(50, 25).type(tname)
            res1 = torch.addmm(M, m1, m2)
            res2 = torch.zeros(10, 25).type(tname)
            res2 += M
            for i in range(10):
                for j in range(25):
                    for k in range(50):
                        res2[i, j] += m1[i, k] * m2[k, j]
            self.assertEqual(res1, res2)

        # 0-strided
        for tname, _prec in types.items():
            M = torch.randn(10, 1).type(tname).expand(10, 25)
            m1 = torch.randn(10, 1).type(tname).expand(10, 50)
            m2 = torch.randn(50, 25).type(tname)
            res1 = torch.addmm(M, m1, m2)
            res2 = torch.zeros(10, 25).type(tname)
            res2 += M
            for i in range(10):
                for j in range(25):
                    for k in range(50):
                        res2[i, j] += m1[i, k] * m2[k, j]
            self.assertEqual(res1, res2)

    def test_logical_any(self):
        devices = ['cpu'] if not torch.cuda.is_available() else ['cpu', 'cuda']
        for device in devices:
            x = torch.zeros([2, 3, 400], dtype=torch.uint8, device=device)

            self.assertEqual(
                torch.zeros([1, 3, 400], dtype=torch.uint8, device=device),
                x.any(0, keepdim=True))
            self.assertEqual(
                torch.zeros([2, 1, 400], dtype=torch.uint8, device=device),
                x.any(1, keepdim=True))
            self.assertEqual(
                torch.zeros([2, 3, 1], dtype=torch.uint8, device=device),
                x.any(2, keepdim=True))

            y = torch.zeros([1, 3, 400], dtype=torch.uint8, device=device)
            self.assertEqual(y, x.any(0, keepdim=True))

            y = torch.zeros([2, 1, 400], dtype=torch.uint8, device=device)
            self.assertEqual(y, x.any(1, keepdim=True))

            y = torch.zeros([2, 3, 1], dtype=torch.uint8, device=device)
            self.assertEqual(y, x.any(2, keepdim=True))

    def test_logical_all(self):
        devices = ['cpu'] if not torch.cuda.is_available() else ['cpu', 'cuda']
        for device in devices:
            x = torch.ones([2, 3, 400], dtype=torch.uint8, device=device)

            self.assertEqual(
                torch.ones([1, 3, 400], dtype=torch.uint8, device=device),
                x.all(0, keepdim=True))
            self.assertEqual(
                torch.ones([2, 1, 400], dtype=torch.uint8, device=device),
                x.all(1, keepdim=True))
            self.assertEqual(
                torch.ones([2, 3, 1], dtype=torch.uint8, device=device),
                x.all(2, keepdim=True))

            y = torch.ones([1, 3, 400], dtype=torch.uint8, device=device)
            self.assertEqual(y, x.all(0, keepdim=True))

            y = torch.ones([2, 1, 400], dtype=torch.uint8, device=device)
            self.assertEqual(y, x.all(1, keepdim=True))

            y = torch.ones([2, 3, 1], dtype=torch.uint8, device=device)
            self.assertEqual(y, x.all(2, keepdim=True))

    def test_allclose(self):
        x = torch.tensor([1.0, 2.0, 3.0])
        y = torch.tensor([1.01, 2.01, 3.01])
        self.assertTrue(torch.allclose(x, y, rtol=0, atol=0.02))
        self.assertTrue(torch.allclose(x, y, rtol=0.01, atol=0.0))
        self.assertFalse(torch.allclose(x, y))
        x = torch.tensor([2.0, 3.0, nan])
        y = torch.tensor([2.01, 3.01, nan])
        self.assertFalse(torch.allclose(x, y, rtol=1e-2))
        self.assertTrue(torch.allclose(x, y, rtol=1e-2, equal_nan=True))
        self.assertFalse(torch.allclose(x, y, rtol=1e-3, equal_nan=True))
        inf_t = torch.tensor([inf])
        self.assertTrue(torch.allclose(inf_t, inf_t))
        self.assertTrue(torch.allclose(-inf_t, -inf_t))
        self.assertFalse(torch.allclose(inf_t, -inf_t))
        self.assertFalse(torch.allclose(inf_t, torch.tensor([1e20])))
        self.assertFalse(torch.allclose(-inf_t, torch.tensor([-1e20])))

    def test_linear_algebra_scalar_raises(self):
        m = torch.randn(5, 5)
        v = torch.randn(5)
        s = torch.tensor(7)
        self.assertRaises(RuntimeError, lambda: torch.mv(m, s))
        self.assertRaises(RuntimeError, lambda: torch.addmv(v, m, s))
        self.assertRaises(RuntimeError, lambda: torch.ger(v, s))
        self.assertRaises(RuntimeError, lambda: torch.ger(s, v))
        self.assertRaises(RuntimeError, lambda: torch.addr(m, v, s))
        self.assertRaises(RuntimeError, lambda: torch.addr(m, s, v))

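    # _test_math is the shared harness for pointwise unary ops: it compares a
    # torch function against a Python/math reference over a set of inputs and
    # checks contiguous, non-contiguous, expanded, and large (parallelized)
    # tensor layouts.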
    def _test_math(self, torchfn, mathfn, input=None, test_expand=False):
        if input is None:
            input = []
            input.append(list(range(-5, 5)))
            input.append([0 for x in range(-5, 5)])
            input.append([x + 1e-6 for x in range(-5, 5)])
            # large magnitudes exercise different vectorized code paths
            input.append([x + 1e10 for x in range(-5, 5)])
            input.append([x - 1e10 for x in range(-5, 5)])
            input.append(torch.randn(10).tolist())
            input.append((torch.randn(10) + 1e6).tolist())
            input.append([math.pi * (x / 2) for x in range(-5, 5)])

        def compare_reference(input, dtype):
            input = torch.tensor(input, dtype=dtype)
            res1 = torchfn(input.clone())
            res2 = input.clone().apply_(mathfn)
            torch.testing.assert_allclose(res1, res2)

        compare_reference(input, torch.double)
        compare_reference(input, torch.float)

        def check_non_contiguous(shape, dtype):
            contig = torch.randn(shape, dtype=dtype)
            non_contig = torch.empty(shape + (2,), dtype=dtype)[..., 0]
            non_contig.copy_(contig)
            self.assertFalse(non_contig.is_contiguous())
            self.assertEqual(torchfn(contig), torchfn(non_contig), 'non-contiguous')

        check_non_contiguous((5, 7), torch.double)
        check_non_contiguous((1024,), torch.double)
        check_non_contiguous((5, 7), torch.float)
        check_non_contiguous((1024,), torch.float)

        def check_non_contiguous_index(dtype):
            contig = torch.randn((2, 2, 1, 2), dtype=dtype)
            non_contig = contig[:, 1, ...]
            contig = non_contig.clone()
            self.assertFalse(non_contig.is_contiguous())
            self.assertEqual(torchfn(contig), torchfn(non_contig), 'non-contiguous index')

        check_non_contiguous_index(torch.float)
        check_non_contiguous_index(torch.double)

        def check_non_contiguous_expand(shape, dtype):
            contig = torch.randn(shape, dtype=dtype)
            non_contig = contig.clone().expand(3, -1, -1)
            self.assertFalse(non_contig.is_contiguous())
            contig = torchfn(contig)
            non_contig = torchfn(non_contig)
            for i in range(3):
                self.assertEqual(contig, non_contig[i], 'non-contiguous expand[' + str(i) + ']')

        # in-place variants are not defined on expanded tensors
        if test_expand:
            check_non_contiguous_expand((1, 3), torch.double)
            check_non_contiguous_expand((1, 7), torch.double)
            check_non_contiguous_expand((5, 7), torch.float)

        def check_contiguous_size1(dtype):
            contig = torch.randn((5, 100), dtype=dtype)
            contig = contig[:1, :50]
            contig2 = torch.empty(contig.size(), dtype=dtype)
            contig2.copy_(contig)
            self.assertTrue(contig.is_contiguous())
            self.assertTrue(contig2.is_contiguous())
            self.assertEqual(torchfn(contig), torchfn(contig2), 'contiguous size1')

        check_contiguous_size1(torch.double)
        check_contiguous_size1(torch.float)

        def check_contiguous_size1_largedim(dtype):
            contig = torch.randn((5, 2, 3, 1, 4, 5, 3, 2, 1, 2, 3, 4), dtype=dtype)
            contig = contig[:1, :, :, :, :, :, :, :, :, :, :, :]
            contig2 = torch.empty(contig.size(), dtype=dtype)
            contig2.copy_(contig)
            self.assertTrue(contig.is_contiguous())
            self.assertTrue(contig2.is_contiguous())
            self.assertEqual(torchfn(contig), torchfn(contig2), 'contiguous size1')

        check_contiguous_size1_largedim(torch.double)
        check_contiguous_size1_largedim(torch.float)

        def check_large(dtype):
            # compare one large (parallelized) application with row-by-row results
            input = torch.randn(1024, 512, dtype=dtype)
            actual = torchfn(input)
            expected = torch.stack([torchfn(slice) for slice in input])
            self.assertEqual(actual, expected, 'large')

        check_large(torch.double)
        check_large(torch.float)

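    # __test_math_by_name resolves both the torch op and the math-module
    # reference from a single name; with selffn=True it exercises the Tensor
    # method form (here the in-place variant, e.g. x.sin_()) instead of the
    # torch.* function.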
    def __test_math_by_name(self, function_name, mathfn, selffn):
        mathfn = getattr(math, mathfn)
        if selffn:
            def torchfn(x):
                return getattr(x, function_name)()
        else:
            torchfn = getattr(torch, function_name)
        self._test_math(torchfn, mathfn, test_expand=(not selffn))

    def _test_math_by_name(self, function_name, test_self=True):
        if test_self:
            self.__test_math_by_name(function_name + "_", function_name, True)
        self.__test_math_by_name(function_name, function_name, False)

    def test_sin(self):
        self._test_math_by_name('sin')

    def test_sinh(self):
        def sinh(x):
            try:
                return math.sinh(x)
            except OverflowError:
                return inf if x > 0 else -inf
        self._test_math(torch.sinh, sinh)

    def test_lgamma(self):
        def lgamma(x):
            # math.lgamma raises at the poles where torch returns inf
            if x <= 0 and x == int(x):
                return inf
            return math.lgamma(x)
        self._test_math(torch.lgamma, lgamma)
    @unittest.skipIf(not TEST_SCIPY, "Scipy not found")
    def test_mvlgamma(self):
        from scipy.special import multigammaln
        for d in range(1, 5):
            input = torch.empty(10).uniform_(d, 10)
            res_torch = torch.mvlgamma(input, d)
            res_scipy = multigammaln(input.numpy(), d)
            self.assertEqual(res_torch.numpy(), res_scipy)

    def test_mvlgamma_argcheck(self):
        def run_test(d):
            input = torch.linspace((d - 2) / 2, 10, 10)
            torch.mvlgamma(input, d)

        with self.assertRaisesRegex(RuntimeError, "Condition for computing multivariate log-gamma not met"):
            run_test(3)

    def _digamma_input(self, test_poles=True):
        input = []
        input.append((torch.randn(10).abs() + 1e-4).tolist())
        input.append((torch.randn(10).abs() + 1e6).tolist())
        zeros = torch.linspace(-9.5, -0.5, 10)
        input.append(zeros.tolist())
        input.append((zeros - 0.49).tolist())
        input.append((zeros + 0.49).tolist())
        input.append((zeros + (torch.rand(10) * 0.99) - 0.5).tolist())

        if test_poles:
            input.append([-0.999999994, -1.999999994, -2.0000000111,
                          -100.99999994, -1931.99999994, 0.000000111,
                          -0.000000111, 0, -2, -329])
        return input
    @unittest.skipIf(not TEST_SCIPY, "Scipy not found")
    def test_digamma(self):
        from scipy.special import digamma

        # map infinities to nan on both sides so torch and scipy agree at the poles
        def torch_digamma_without_inf(inp):
            res = torch.digamma(inp)
            res[(res == -inf) | (res == inf)] = nan
            return res

        def scipy_digamma_without_inf(inp):
            res = digamma(inp)
            if np.isscalar(res):
                return res if np.isfinite(res) else nan
            res[np.isinf(res)] = nan
            return res

        self._test_math(torch_digamma_without_inf, scipy_digamma_without_inf, self._digamma_input())

    @unittest.skipIf(not TEST_SCIPY, "Scipy not found")
    def test_polygamma(self):
        from scipy.special import polygamma
        for n in [0, 1]:
            self._test_math(lambda x: torch.polygamma(n, x),
                            lambda x: polygamma(n, x).item(),
                            self._digamma_input(test_poles=False))
    def test_asin(self):
        self._test_math(torch.asin, lambda x: math.asin(x) if abs(x) <= 1 else nan)

    def test_cos(self):
        self._test_math_by_name('cos')

    def test_cosh(self):
        def cosh(x):
            try:
                return math.cosh(x)
            except OverflowError:
                # cosh overflows to +inf for large |x|
                return inf
        self._test_math(torch.cosh, cosh)

    def test_acos(self):
        self._test_math(torch.acos, lambda x: math.acos(x) if abs(x) <= 1 else nan)

    def test_tan(self):
        self._test_math_by_name('tan')

    def test_tanh(self):
        self._test_math_by_name('tanh')

    def test_atan(self):
        self._test_math_by_name('atan')

    def test_log(self):
        def log(x):
            if x == 0:
                return -inf
            elif x < 0:
                return nan
            return math.log(x)
        self._test_math(torch.log, log)

    def test_log10(self):
        def log10(x):
            if x == 0:
                return -inf
            elif x < 0:
                return nan
            return math.log10(x)
        self._test_math(torch.log10, log10)

    def test_log1p(self):
        def log1p(x):
            if x == -1:
                return -inf
            elif x < -1:
                return nan
            return math.log1p(x)
        self._test_math(torch.log1p, log1p)

    def test_log2(self):
        def log2(x):
            if x == 0:
                return -inf
            elif x < 0:
                return nan
            try:
                return math.log2(x)
            except AttributeError:
                # Python 2 has no math.log2
                return math.log(x, 2)
        self._test_math(torch.log2, log2)

    def test_sqrt(self):
        self._test_math(torch.sqrt, lambda x: math.sqrt(x) if x >= 0 else nan)

    def test_erf(self):
        self._test_math_by_name('erf')

    def test_erfc(self):
        self._test_math_by_name('erfc')

    def test_erfinv(self):
        def checkType(tensor):
            inputValues = torch.randn(4, 4, out=tensor()).clamp(-2., 2.)
            self.assertEqual(tensor(inputValues).erf().erfinv(), tensor(inputValues))
            # test inf
            self.assertTrue(torch.equal(tensor([-1, 1]).erfinv(), tensor([-inf, inf])))
            # test nan
            self.assertEqual(tensor([-2, 2]).erfinv(), tensor([nan, nan]))

        checkType(torch.FloatTensor)
        checkType(torch.DoubleTensor)

    def test_exp(self):
        def exp(x):
            try:
                return math.exp(x)
            except OverflowError:
                return inf
        self._test_math(torch.exp, exp)

    def test_expm1(self):
        def expm1(x):
            try:
                return math.expm1(x)
            except OverflowError:
                return inf
        self._test_math(torch.expm1, expm1)

    def test_floor(self):
        self._test_math_by_name('floor')

    def test_ceil(self):
        self._test_math_by_name('ceil')

    @unittest.skipIf(not torch.cuda.is_available(), 'no CUDA')
    def test_ceil_out_cpu_cuda(self):
        a = torch.randn(1)
        b = torch.randn(1, device="cuda")
        self.assertRaises(RuntimeError, lambda: torch.ceil(a, out=b))

    def test_rsqrt(self):
        def rsqrt(x):
            if x == 0:
                return inf
            elif x < 0:
                return nan
            return 1.0 / math.sqrt(x)

        self._test_math(torch.rsqrt, rsqrt)
    def test_sigmoid(self):
        # reference values to 4 decimal places
        inputValues = [-1000, -1, 0, 0.5, 1, 2, 1000]
        expectedOutput = [0.0000, 0.2689, 0.5, 0.6225, 0.7311, 0.8808, 1.000]
        precision_4dps = 0.0002

        def checkType(tensor):
            self.assertEqual(tensor(inputValues).sigmoid(), tensor(expectedOutput), precision_4dps)

        checkType(torch.FloatTensor)
        checkType(torch.DoubleTensor)

    def test_frac(self):
        self._test_math(torch.frac, lambda x: math.fmod(x, 1))

    def test_trunc(self):
        self._test_math(torch.trunc, lambda x: x - math.fmod(x, 1))

    def test_round(self):
        self._test_math(torch.round, round)

    def test_has_storage(self):
        self.assertIsNotNone(torch.Tensor().storage())
        self.assertIsNotNone(torch.Tensor(0).storage())
        self.assertIsNotNone(torch.Tensor([]).storage())
        self.assertIsNotNone(torch.Tensor().clone().storage())
        self.assertIsNotNone(torch.Tensor([0, 0, 0]).nonzero().storage())
        self.assertIsNotNone(torch.Tensor().new().storage())
    @unittest.skipIf(not TEST_NUMPY, "Numpy not found")
    def test_has_storage_numpy(self):
        for dtype in [np.float32, np.float64, np.int64,
                      np.int32, np.int16, np.uint8]:
            arr = np.array([1], dtype=dtype)
            self.assertIsNotNone(torch.FloatTensor(arr).storage())
            self.assertIsNotNone(torch.DoubleTensor(arr).storage())
            self.assertIsNotNone(torch.IntTensor(arr).storage())
            self.assertIsNotNone(torch.LongTensor(arr).storage())
            self.assertIsNotNone(torch.ByteTensor(arr).storage())
            if torch.cuda.is_available():
                self.assertIsNotNone(torch.cuda.FloatTensor(arr).storage())
                self.assertIsNotNone(torch.cuda.DoubleTensor(arr).storage())
                self.assertIsNotNone(torch.cuda.IntTensor(arr).storage())
                self.assertIsNotNone(torch.cuda.LongTensor(arr).storage())
                self.assertIsNotNone(torch.cuda.ByteTensor(arr).storage())
    def _testSelection(self, torchfn, mathfn):
        # contiguous
        m1 = torch.randn(100, 100)
        res1 = torchfn(m1)
        res2 = m1[0, 0]
        for i, j in iter_indices(m1):
            res2 = mathfn(res2, m1[i, j])
        self.assertEqual(res1, res2)

        # non-contiguous
        m1 = torch.randn(10, 10, 10)
        m2 = m1[:, 4]
        res1 = torchfn(m2)
        res2 = m2[0, 0]
        for i, j in iter_indices(m2):
            res2 = mathfn(res2, m2[i][j])
        self.assertEqual(res1, res2)

        # with indices
        m1 = torch.randn(100, 100)
        res1val, res1ind = torchfn(m1, 1, False)
        res2val = m1[:, 0:1].clone().squeeze()
        res2ind = res1ind.clone().fill_(0)
        for i, j in iter_indices(m1):
            if mathfn(res2val[i], m1[i, j]) != res2val[i]:
                res2val[i] = m1[i, j]
                res2ind[i] = j

        maxerr = 0
        for i in range(res1val.size(0)):
            maxerr = max(maxerr, abs(res1val[i] - res2val[i]))
            self.assertEqual(res1ind[i], res2ind[i])
        self.assertLessEqual(abs(maxerr), 1e-5)

        # NaNs propagate and win the selection
        for index in (0, 4, 99):
            m1 = torch.randn(100)
            m1[index] = nan
            res1val, res1ind = torch.max(m1, 0)
            self.assertTrue(math.isnan(res1val))
            self.assertEqual(res1ind, index)
            res1val = torchfn(m1)
            self.assertTrue(math.isnan(res1val))

    def test_max(self):
        self._testSelection(torch.max, max)

    @staticmethod
    def _test_max_with_inf(self, dtypes=(torch.float, torch.double), device='cpu'):
        for dtype in dtypes:
            a = torch.tensor([[-inf, -inf, inf, 3], [inf, inf, -inf, -1]], dtype=dtype, device=device)
            self.assertTrue(torch.all(torch.max(a, dim=1)[0] == inf).item())
            self.assertTrue(torch.max(a).item() == inf)

    def test_max_with_inf(self):
        self._test_max_with_inf(self)

    def test_min(self):
        self._testSelection(torch.min, min)

    @staticmethod
    def _test_min_with_inf(self, dtypes=(torch.float, torch.double), device='cpu'):
        for dtype in dtypes:
            a = torch.tensor([[-inf, -inf, inf, 3], [inf, inf, -inf, -1]], dtype=dtype, device=device)
            self.assertTrue(torch.all(torch.min(a, dim=1)[0] == (-inf)).item())
            self.assertTrue(torch.min(a).item() == -inf)

    def test_min_with_inf(self):
        self._test_min_with_inf(self)

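    # _test_norm validates vector p-norms and the 'fro'/'nuc' matrix norms
    # against np.linalg.norm; since ||1_n||_2 = sqrt(n), a large-tensor sanity
    # check also holds: norm(ones(40000)) == 200 == 2 * norm(ones(10000)).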
    @staticmethod
    def _test_norm(self, device):
        # full reduction
        x = torch.randn(25, device=device)
        xn = x.cpu().numpy()
        for p in [0, 1, 2, 3, 4, inf, -inf]:
            res = x.norm(p).item()
            expected = np.linalg.norm(xn, p)
            self.assertEqual(res, expected, "full reduction failed for {}-norm".format(p))

        # one dimension
        x = torch.randn(25, 25, device=device)
        xn = x.cpu().numpy()
        for p in [0, 1, 2, 3, 4, inf, -inf]:
            res = x.norm(p, 1).cpu().numpy()
            expected = np.linalg.norm(xn, p, 1)
            self.assertEqual(res.shape, expected.shape)
            self.assertTrue(np.allclose(res, expected),
                            "dim reduction failed for {}-norm".format(p))

        # matrix norms
        for p in ['fro', 'nuc']:
            res = x.norm(p).cpu().numpy()
            expected = np.linalg.norm(xn, p)
            self.assertEqual(res.shape, expected.shape)
            self.assertTrue(np.allclose(res, expected),
                            "dim reduction failed for {}-norm".format(p))

        # larger tensor sanity check
        self.assertEqual(2 * torch.norm(torch.ones(10000)), torch.norm(torch.ones(40000)))

    @unittest.skipIf(not TEST_NUMPY, "Numpy not found")
    @skipIfNoLapack
    def test_norm(self):
        self._test_norm(self, device='cpu')

    @staticmethod
    def _test_dist(self, device):
        def run_test(x, y):
            for p in [0, 1, 2, 3, 4, inf, -inf]:
                dist_xy = torch.dist(x, y, p)
                dist_xy_norm = torch.norm(x - y, p)
                self.assertEqual(dist_xy, dist_xy_norm)

        run_test(torch.randn(5, device=device), torch.randn(5, device=device))

        # dist between all-zero tensors must be 0 for every p
        x = torch.zeros(3, device=device)
        y = torch.zeros(3, device=device)
        run_test(x, y)

    def test_dist(self):
        self._test_dist(self, device='cpu')

    def test_dim_reduction_uint8_overflow(self):
        example = [[-1, 2, 1], [5, 3, 6]]
        x = torch.tensor(example, dtype=torch.uint8)
        # -1 wraps to 255 in uint8, so the total of 272 overflows to 16
        self.assertEqual(x.sum(dtype=torch.uint8).item(), 16)
        self.assertEqual(x.sum(0, dtype=torch.uint8), torch.FloatTensor([4, 5, 7]))
        self.assertEqual(x.sum(1, dtype=torch.uint8), torch.FloatTensor([2, 14]))
        y = torch.tensor(example, dtype=torch.uint8)
        torch.sum(x, 0, out=y)
        self.assertEqual(x.sum(0, dtype=torch.uint8), y)
    @staticmethod
    def _test_dim_reduction(self, cast):
        example = [[-1, 2, 1], [5, 3, 6]]

        types = [torch.double,
                 torch.float,
                 torch.int64,
                 torch.int32,
                 torch.int16]

        # sum
        for dtype in types:
            x = cast(torch.tensor(example, dtype=dtype))
            self.assertEqual(x.sum().item(), 16)
            self.assertEqual(x.sum(0), torch.FloatTensor([4, 5, 7]))
            self.assertEqual(x.sum(1), torch.FloatTensor([2, 14]))
            y = cast(torch.tensor(example, dtype=dtype))
            torch.sum(x, 0, out=y)
            self.assertEqual(x.sum(0), y)

        # mean is not supported for integral types
        for dtype in types[:2]:
            x = cast(torch.tensor(example, dtype=dtype))
            self.assertEqual(x.mean().item(), 16.0 / 6)
            self.assertEqual(x.mean(0), torch.FloatTensor([2.0, 2.5, 7.0 / 2]))
            self.assertEqual(x.mean(1), torch.FloatTensor([2.0 / 3, 14.0 / 3]))
            self.assertEqual(x.mean(), x.mean((0, 1)))

        # prod
        for dtype in types:
            x = cast(torch.tensor(example, dtype=dtype))
            self.assertEqual(x.prod().item(), -180)
            self.assertEqual(x.prod(0), torch.FloatTensor([-5, 6, 6]))
            self.assertEqual(x.prod(1), torch.FloatTensor([-2, 90]))

        # max
        for dtype in types:
            x = cast(torch.tensor(example, dtype=dtype))
            self.assertEqual(x.max().item(), 6)
            self.assertEqual(x.max(0), (torch.FloatTensor([5, 3, 6]), torch.FloatTensor([1, 1, 1])))
            self.assertEqual(x.max(1), (torch.FloatTensor([2, 6]), torch.FloatTensor([1, 2])))

        # min
        for dtype in types:
            x = cast(torch.tensor(example, dtype=dtype))
            self.assertEqual(x.min().item(), -1)
            self.assertEqual(x.min(0), (torch.FloatTensor([-1, 2, 1]), torch.FloatTensor([0, 0, 0])))
            self.assertEqual(x.min(1), (torch.FloatTensor([-1, 3]), torch.FloatTensor([0, 1])))

        # argmax
        for dtype in types:
            x = cast(torch.tensor(example, dtype=dtype))
            self.assertEqual(x.argmax().item(), 5)
            self.assertEqual(x.argmax(dim=None).item(), 5)
            self.assertEqual(x.argmax(dim=0), torch.FloatTensor([1, 1, 1]))
            self.assertEqual(x.argmax(dim=1), torch.FloatTensor([1, 2]))
            self.assertEqual(x.argmax(dim=0, keepdim=True), torch.FloatTensor([[1, 1, 1]]))
            # check on a subtensor too
            self.assertEqual(x[:, :2].argmax().item(), 2)

        # argmin
        for dtype in types:
            x = cast(torch.tensor(example, dtype=dtype))
            self.assertEqual(x.argmin().item(), 0)
            self.assertEqual(x.argmin(dim=None).item(), 0)
            self.assertEqual(x.argmin(dim=0), torch.FloatTensor([0, 0, 0]))
            self.assertEqual(x.argmin(dim=1), torch.FloatTensor([0, 1]))
            self.assertEqual(x.argmin(dim=1, keepdim=True), torch.FloatTensor([[0], [1]]))
            self.assertEqual(x[:, :2].argmin().item(), 0)

        dim_red_fns = [
            "mean", "median", "mode", "norm", "prod",
            "std", "sum", "var", "max", "min"]
        def normfn_attr(t, dim, keepdim=False, out=None):
            attr = getattr(torch, "norm")
            return attr(t, 2, dim, keepdim, out=out)

        for fn_name in dim_red_fns:
            fn_attr = getattr(torch, fn_name) if fn_name != "norm" else normfn_attr

            def fn(x, dim, keepdim=False, out=None):
                ans = fn_attr(x, dim, keepdim=keepdim, out=out)
                return ans if not istuple(ans) else ans[0]

            def fn_tuple(x, dim, keepdim=False, out=None):
                return fn_attr(x, dim, keepdim=keepdim, out=out)

            def test_multidim(x, dim):
                self.assertEqual(fn(x, dim).unsqueeze(dim), fn(x, dim, keepdim=True))
                self.assertEqual(x.ndimension() - 1, fn(x, dim).ndimension())
                self.assertEqual(x.ndimension(), fn(x, dim, keepdim=True).ndimension())

            # general case
            x = cast(torch.randn(3, 4, 5))
            dim = random.randint(0, 2)
            test_multidim(x, dim)

            # check 1-d behavior
            x = cast(torch.randn(1))
            dim = 0
            self.assertEqual(fn(x, dim).shape, ())
            self.assertEqual(fn(x, dim, keepdim=True).shape, (1,))

            # check reducing of a singleton dimension
            dims = [3, 4, 5]
            singleton_dim = random.randint(0, 2)
            dims[singleton_dim] = 1
            x = cast(torch.randn(dims))
            test_multidim(x, singleton_dim)

            # check reducing with output kwargs
            if fn_name in ['median', 'mode', 'max', 'min']:
                y = cast(torch.randn(5, 3))
                values = cast(torch.randn(5, 3))
                indices = cast(torch.zeros(5, 3).long() - 1)
                fn_tuple(y, 1, keepdim=False, out=(values[:, 1], indices[:, 1]))
                values_expected, indices_expected = fn_tuple(y, 1, keepdim=False)
                self.assertEqual(values[:, 1], values_expected,
                                 '{} values with out= kwarg'.format(fn_name))
                self.assertEqual(indices[:, 1], indices_expected,
                                 '{} indices with out= kwarg'.format(fn_name))
                continue

            x = cast(torch.randn(5, 3))
            y = cast(torch.randn(5, 3))
            fn(y, 1, keepdim=False, out=x[:, 1])
            expected = fn(y, 1, keepdim=False)
            self.assertEqual(x[:, 1], expected, '{} with out= kwarg'.format(fn_name))

    def test_dim_reduction(self):
        self._test_dim_reduction(self, lambda t: t)
    def test_reduction_empty(self):
        fns_to_test = [
            # name, function, identity
            ('max', torch.max, None),
            ('kthvalue', lambda *args, **kwargs: torch.kthvalue(*args, k=1, **kwargs), None),
            ('argmax', torch.argmax, None),
            ('min', torch.min, None),
            ('argmin', torch.argmin, None),
            ('mode', torch.mode, None),
            ('median', torch.median, None),

            ('prod', torch.prod, 1),
            ('sum', torch.sum, 0),
            ('norm', torch.norm, 0),
            ('mean', torch.mean, nan),
            ('var', torch.var, nan),
            ('std', torch.std, nan),
            ('logsumexp', torch.logsumexp, -inf),
        ]

        shape = (2, 0, 4)
        devices = ['cpu'] if not torch.cuda.is_available() else ['cpu', 'cuda']
        for device in devices:
            x = torch.randn(shape, device=device)

            for item in fns_to_test:
                name, fn, identity = item
                if identity is None:
                    ident_err = 'does not have an identity'
                    self.assertRaisesRegex(RuntimeError, ident_err, lambda: fn(x, dim=2))
                    self.assertRaisesRegex(RuntimeError, ident_err, lambda: fn(x, dim=2, keepdim=True))
                    self.assertRaisesRegex(RuntimeError, ident_err, lambda: fn(x, dim=1))
                    self.assertRaisesRegex(RuntimeError, ident_err, lambda: fn(x, dim=1, keepdim=True))
                else:
                    self.assertEqual(torch.empty((2, 0), device=device), fn(x, dim=2))
                    self.assertEqual(torch.empty((2, 0, 1), device=device), fn(x, dim=2, keepdim=True))
                    # assertEqual does not handle inf/nan-filled tensor pairs
                    check = (torch.testing.assert_allclose
                             if math.isnan(identity) or math.isinf(identity) else
                             self.assertEqual)
                    check(torch.full((2, 4), identity, device=device), fn(x, dim=1))
                    check(torch.full((2, 1, 4), identity, device=device), fn(x, dim=1, keepdim=True))
                    try:
                        check(torch.full((), identity, device=device), fn(x))
                    except TypeError as err:
                        # ops without a full reduction raise a TypeError
                        self.assertTrue('required positional arguments: "dim"' in str(err))

            # any
            xb = x.to(torch.uint8)
            yb = x.to(torch.uint8)
            self.assertEqual((2, 0), xb.any(2).shape)
            self.assertEqual((2, 0, 1), xb.any(2, keepdim=True).shape)
            self.assertEqual(torch.zeros((2, 4), device=device), xb.any(1))
            self.assertEqual(torch.zeros((2, 1, 4), device=device), xb.any(1, keepdim=True))
            self.assertEqual(torch.zeros((), device=device), xb.any())

            # all
            self.assertEqual((2, 0), xb.all(2).shape)
            self.assertEqual((2, 0, 1), xb.all(2, keepdim=True).shape)
            self.assertEqual(torch.ones((2, 4), device=device), xb.all(1))
            self.assertEqual(torch.ones((2, 1, 4), device=device), xb.all(1, keepdim=True))
            self.assertEqual(torch.ones((), device=device), xb.all())
    def test_pairwise_distance_empty(self):
        devices = ['cpu'] if not torch.cuda.is_available() else ['cpu', 'cuda']
        for device in devices:
            shape = (2, 0)
            x = torch.randn(shape, device=device)
            y = torch.randn(shape, device=device)

            self.assertEqual(torch.zeros(2, device=device), torch.pairwise_distance(x, y))
            self.assertEqual(torch.zeros((2, 1), device=device), torch.pairwise_distance(x, y, keepdim=True))

            shape = (0, 2)
            x = torch.randn(shape, device=device)
            y = torch.randn(shape, device=device)
            self.assertEqual(torch.zeros(0, device=device), torch.pairwise_distance(x, y))
            self.assertEqual(torch.zeros((0, 1), device=device), torch.pairwise_distance(x, y, keepdim=True))

    def test_pdist_empty(self):
        devices = ['cpu'] if not torch.cuda.is_available() else ['cpu', 'cuda']
        for device in devices:
            shape = (0, 2)
            x = torch.randn(shape, device=device)
            self.assertEqual(torch.empty(0, device=device), torch.pdist(x))

            shape = (1, 2)
            x = torch.randn(shape, device=device)
            self.assertEqual(torch.empty(0, device=device), torch.pdist(x))

            shape = (3, 0)
            x = torch.randn(shape, device=device)
            self.assertEqual(torch.zeros(3, device=device), torch.pdist(x))

    def test_pdist_norm(self):
        def test_pdist_single(shape, device, p, dtype, trans):
            x = torch.randn(shape, dtype=dtype, device=device)
            if trans:
                x.transpose_(-2, -1)
            actual = torch.pdist(x, p=p)
            expected = brute_pdist(x, p=p)
            self.assertEqual(expected.shape, actual.shape)
            self.assertTrue(torch.allclose(expected, actual))

        devices = ['cpu'] if not torch.cuda.is_available() else ['cpu', 'cuda']
        for device in devices:
            for shape in [(4, 5), (3, 2), (2, 1)]:
                for p in [0, 1, 2, 3, 1.5, 2.5, float('inf')]:
                    for trans in [False, True]:
                        for dtype in [torch.float32, torch.float64]:
                            test_pdist_single(shape, device, p, dtype, trans)

            # simplified comparison with big inputs
            for dtype in [torch.float32, torch.float64]:
                test_pdist_single((1000, 2), device, 2, dtype, False)

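    # torch.pdist returns the condensed vector of n*(n-1)/2 pairwise
    # distances; torch.cdist, tested next, returns the full r1 x r2 distance
    # matrix. Both are validated against the brute-force references
    # (brute_pdist / brute_cdist) imported from common_utils.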
    def test_cdist_empty(self):
        devices = ['cpu'] if not torch.cuda.is_available() else ['cpu', 'cuda']
        for device in devices:
            x = torch.randn((0, 5), device=device)
            y = torch.randn((4, 5), device=device)
            self.assertEqual(torch.empty(0, 4, device=device), torch.cdist(x, y))

            x = torch.randn((2, 5), device=device)
            y = torch.randn((0, 5), device=device)
            self.assertEqual(torch.empty(2, 0, device=device), torch.cdist(x, y))

            x = torch.randn((2, 0), device=device)
            y = torch.randn((3, 0), device=device)
            self.assertEqual(torch.zeros(2, 3, device=device), torch.cdist(x, y))

            x = torch.randn((2, 0), device=device)
            y = torch.randn((0, 0), device=device)
            self.assertEqual(torch.empty(2, 0, device=device), torch.cdist(x, y))

    def test_cdist_norm(self):
        devices = ['cpu'] if not torch.cuda.is_available() else ['cpu', 'cuda']
        for device in devices:
            for r1 in [3, 4, 5, 6]:
                for m in [2, 3, 4, 10]:
                    for r2 in [4, 6, 7, 8]:
                        for p in [0, 1, 2, 3, 1.5, 2.5, float('inf')]:
                            x = torch.randn(r1, m, device=device)
                            y = torch.randn(r2, m, device=device)
                            actual = torch.cdist(x, y, p=p)
                            expected = brute_cdist(x, y, p=p)
                            self.assertTrue(torch.allclose(expected, actual))

    def test_cdist_large(self):
        devices = ['cpu'] if not torch.cuda.is_available() else ['cpu', 'cuda']
        for device in devices:
            x = torch.randn(1000, 10, device=device)
            y = torch.randn(1000, 10, device=device)
            actual = torch.cdist(x, y, p=2)
            expected = brute_cdist(x, y, p=2)
            self.assertTrue(torch.allclose(expected, actual))

    def test_cdist_non_contiguous(self):
        devices = ['cpu'] if not torch.cuda.is_available() else ['cpu', 'cuda']
        for device in devices:
            x = torch.randn(5, 7, device=device).t()
            y = torch.randn(5, 3, device=device).t()
            actual = torch.cdist(x, y, p=2)
            expected = brute_cdist(x, y, p=2)
            self.assertFalse(x.is_contiguous())
            self.assertFalse(y.is_contiguous())
            self.assertTrue(torch.allclose(expected, actual))

            x = torch.randn(7, 5, device=device)
            y = torch.randn(5, 3, device=device).t()
            actual = torch.cdist(x, y, p=2)
            expected = brute_cdist(x, y, p=2)
            self.assertTrue(x.is_contiguous())
            self.assertFalse(y.is_contiguous())
            self.assertTrue(torch.allclose(expected, actual))

            x = torch.randn(5, 7, device=device).t()
            y = torch.randn(3, 5, device=device)
            actual = torch.cdist(x, y, p=2)
            expected = brute_cdist(x, y, p=2)
            self.assertFalse(x.is_contiguous())
            self.assertTrue(y.is_contiguous())
            self.assertTrue(torch.allclose(expected, actual))
    @unittest.skipIf(not TEST_SCIPY, "Scipy not found")
    def test_logsumexp(self):
        from scipy.special import logsumexp
        a = torch.randn(5, 4)
        a[0, 0] = inf
        a[1, :] = -inf
        actual = a.logsumexp(1)
        expected = logsumexp(a.numpy(), 1)
        self.assertEqual(expected.shape, actual.shape)
        self.assertTrue(np.allclose(expected, actual.numpy()))
        # check that out= writes through a view
        b = torch.zeros(5, 2)
        c = b[:, 0]
        torch.logsumexp(a, 1, out=c)
        self.assertTrue(np.allclose(expected, b[:, 0].numpy()))

    @unittest.skipIf(not TEST_NUMPY, "Numpy not found")
    def test_cpu_parallel(self):
        # reductions only parallelize on tensors that are large enough

        def _run_test(size):
            for dim in range(len(size) + 1):
                nv = np.round(np.random.rand(*size))
                tv = torch.from_numpy(nv)
                # inputs must exceed the parallelization cutoff (32768 elements)
                self.assertTrue(tv.numel() > 32768)
                if dim == len(size):
                    nvs = nv.sum()
                    tvs = tv.sum()
                else:
                    nvs = nv.sum(dim)
                    tvs = tv.sum(dim)
                diff = np.abs(nvs - tvs.numpy()).sum()
                self.assertEqual(diff, 0)

        _run_test([2, 3, 3, 3, 3, 2, 2, 3, 2, 3, 2, 3, 3])
        _run_test([4, 4, 4, 4, 4, 4, 4, 4, 4, 4])
        _run_test([1, 32 * 8 * 32 * 8])
        _run_test([1, 32770])

    def _testCSelection(self, torchfn, mathfn):
        # two tensors
        size = (100, 100)
        a = torch.rand(*size)
        b = torch.rand(*size)
        c = torchfn(a, b)
        expected_c = torch.zeros(*size)
        expected_c.map2_(a, b, lambda _, a, b: mathfn(a, b))
        self.assertEqual(expected_c, c, 0)

    def test_max_elementwise(self):
        self._testCSelection(torch.max, max)

    def test_min_elementwise(self):
        self._testCSelection(torch.min, min)

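    # torch.lerp(start, end, weight) computes start + weight * (end - start);
    # _test_lerp checks the functional form, the method form, and out= against
    # that formula for broadcastable shape combinations.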
    @staticmethod
    def _test_lerp(self, cast):
        start_end_shapes = [(), (5,), (5, 5), (5, 5, 5)]
        for shapes in product(start_end_shapes, start_end_shapes):
            start = cast(torch.randn(shapes[0]))
            end = cast(torch.randn(shapes[1]))

            # tensor and scalar weights
            for weight in [cast(torch.randn(shapes[0])), random.random()]:
                actual = torch.lerp(start, end, weight)
                actual_method = start.lerp(end, weight)
                self.assertEqual(actual, actual_method)
                actual_out = cast(torch.Tensor())
                torch.lerp(start, end, weight, out=actual_out)
                self.assertEqual(actual, actual_out)
                expected = start + weight * (end - start)
                self.assertEqual(expected, actual)

    def test_lerp(self):
        self._test_lerp(self, lambda t: t)

    def test_all_any(self):
        def test(size):
            x = torch.ones(*size).byte()
            self.assertTrue(x.all())
            self.assertTrue(x.any())

            x[3] = 0
            self.assertFalse(x.all())
            self.assertTrue(x.any())

            x.zero_()
            self.assertFalse(x.all())
            self.assertFalse(x.any())

            x.fill_(2)
            self.assertTrue(x.all())
            self.assertTrue(x.any())

        test((10,))
        test((5, 5))

    def test_all_any_empty(self):
        x = torch.ByteTensor()
        self.assertTrue(x.all())
        self.assertFalse(x.any())

    def test_all_any_with_dim(self):
        def test(x):
            r1 = x.prod(dim=0, keepdim=False).byte()
            r2 = x.all(dim=0, keepdim=False)
            self.assertEqual(r1.shape, r2.shape)
            self.assertTrue((r1 == r2).all())

            r3 = x.sum(dim=1, keepdim=True).clamp(0, 1).byte()
            r4 = x.any(dim=1, keepdim=True)
            self.assertEqual(r3.shape, r4.shape)
            self.assertTrue((r3 == r4).all())

        test(torch.ByteTensor([[0, 0, 0],
                               [0, 1, 1],
                               [1, 1, 1]]))

    @unittest.skipIf(not torch.cuda.is_available(), 'no CUDA')
    def test_all_any_empty_cuda(self):
        x = torch.cuda.ByteTensor()
        self.assertTrue(x.all())
        self.assertFalse(x.any())

    def test_mv(self):
        m1 = torch.randn(100, 100)
        v1 = torch.randn(100)

        res1 = torch.mv(m1, v1)
        res2 = res1.clone().zero_()
        for i, j in iter_indices(m1):
            res2[i] += m1[i][j] * v1[j]

        self.assertEqual(res1, res2)
    def test_add(self):
        # [res] torch.add([res,] tensor1, tensor2)
        m1 = torch.randn(100, 100)
        v1 = torch.randn(100)

        # contiguous
        res1 = torch.add(m1[4], v1)
        res2 = res1.clone().zero_()
        for i in range(m1.size(1)):
            res2[i] = m1[4, i] + v1[i]
        self.assertEqual(res1, res2)

        m1 = torch.randn(100, 100)
        v1 = torch.randn(100)

        # non-contiguous
        res1 = torch.add(m1[:, 4], v1)
        res2 = res1.clone().zero_()
        for i in range(m1.size(0)):
            res2[i] = m1[i, 4] + v1[i]
        self.assertEqual(res1, res2)

        # [res] torch.add([res,] tensor, value)
        m1 = torch.randn(10, 10)

        # contiguous
        res1 = m1.clone()
        res1[3].add_(2)
        res2 = m1.clone()
        for i in range(m1.size(1)):
            res2[3, i] = res2[3, i] + 2
        self.assertEqual(res1, res2)

        # non-contiguous
        m1 = torch.randn(10, 10)
        res1 = m1.clone()
        res1[:, 3].add_(2)
        res2 = m1.clone()
        for i in range(m1.size(0)):
            res2[i, 3] = res2[i, 3] + 2
        self.assertEqual(res1, res2)

        # inter-type
        m1 = torch.randn(10, 10)
        one = torch.tensor(1, dtype=torch.uint8)
        self.assertEqual(torch.add(one, 1), 2)
        self.assertEqual(torch.add(one, 1).dtype, torch.uint8)

        # contiguous + non-contiguous
        m1 = torch.randn(10, 10)
        m2 = torch.randn(10, 10).t()
        res = m1 + m2
        self.assertTrue(res.is_contiguous())
        self.assertEqual(res, m1 + m2.contiguous())

        # 1d + empty
        m1 = torch.arange(0, 10, 0.1)
        m2 = torch.tensor([])
        self.assertEqual(m1 + m2, [])

    def test_csub(self):
        # with a tensor
        a = torch.randn(100, 90)
        b = a.clone().normal_()

        res_add = torch.add(a, -1, b)
        res_csub = a.clone()
        res_csub.sub_(b)
        self.assertEqual(res_add, res_csub)

        # with a scalar
        a = torch.randn(100, 100)

        scalar = 123.5
        res_add = torch.add(a, -scalar)
        res_csub = a.clone()
        res_csub.sub_(scalar)
        self.assertEqual(res_add, res_csub)

    @staticmethod
    def _test_neg(self, cast):
        float_types = [torch.DoubleTensor, torch.FloatTensor, torch.LongTensor]
        int_types = [torch.IntTensor, torch.ShortTensor, torch.ByteTensor,
                     torch.CharTensor]

        for t in float_types + int_types:
            if t in float_types:
                a = cast(torch.randn(100, 90).type(t))
            else:
                a = cast(torch.randint(-128, 128, (100, 90), dtype=t.dtype))
            zeros = cast(torch.Tensor().type(t)).resize_as_(a).zero_()

            if t == torch.ByteTensor:
                # for uint8, alpha=255 acts as -1 modulo 256
                res_add = torch.add(zeros, a, alpha=255)
            else:
                res_add = torch.add(zeros, a, alpha=-1)

            res_neg = a.clone()
            res_neg.neg_()
            self.assertEqual(res_neg, res_add)

            # out of place
            res_neg_out_place = a.clone().neg()
            self.assertEqual(res_neg_out_place, res_add)

            # via __neg__ operator
            res_neg_op = -a.clone()
            self.assertEqual(res_neg_op, res_add)

    def test_neg(self):
        self._test_neg(self, lambda t: t)
    def test_threshold(self):
        for dtype in torch.testing.get_all_dtypes():
            if dtype != torch.uint8 and dtype != torch.float16:
                # signs give values on both sides of the threshold
                x = torch.randn(100).sign().to(dtype=dtype)
                y = torch.threshold(x, 0, 0)
                self.assertTrue(y.le(0).any())

    def test_reciprocal(self):
        a = torch.randn(100, 89)
        res_div = 1 / a
        res_reciprocal = a.clone()
        res_reciprocal.reciprocal_()
        self.assertEqual(res_reciprocal, res_div)

    def test_mul(self):
        m1 = torch.randn(10, 10)
        res1 = m1.clone()
        res1[:, 3].mul_(2)
        res2 = m1.clone()
        for i in range(res1.size(0)):
            res2[i, 3] = res2[i, 3] * 2
        self.assertEqual(res1, res2)

    def test_div(self):
        m1 = torch.randn(10, 10)
        res1 = m1.clone()
        res1[:, 3].div_(2)
        res2 = m1.clone()
        for i in range(m1.size(0)):
            res2[i, 3] = res2[i, 3] / 2
        self.assertEqual(res1, res2)

    def test_floordiv(self):
        for dtype in torch.testing.get_all_dtypes():
            if dtype is torch.float16:
                continue
            x = torch.randn(100).mul(10).to(dtype)
            y = x // 3
            self.assertEqual(y.dtype, x.dtype)
            z = torch.tensor([math.trunc(v.item() / 3.) for v in x], dtype=y.dtype)
            self.assertEqual(y, z)

    def test_rdiv(self):
        for dtype in torch.testing.get_all_dtypes():
            if dtype is torch.float16:
                continue
            x = torch.rand(100).add(1).mul(4).to(dtype)
            y = 30 / x
            if dtype.is_floating_point:
                z = torch.tensor([30 / v.item() for v in x], dtype=dtype)
            else:
                z = torch.tensor([math.trunc(30. / v.item()) for v in x], dtype=dtype)
            self.assertEqual(y, z)

    def test_fmod(self):
        m1 = torch.Tensor(10, 10).uniform_(-10., 10.)
        res1 = m1.clone()
        q = 2.1
        res1[:, 3].fmod_(q)
        res2 = m1.clone()
        for i in range(m1.size(1)):
            res2[i, 3] = math.fmod(res2[i, 3], q)
        self.assertEqual(res1, res2)
    def test_remainder(self):
        # check the float tensor
        for use_item in [True, False]:
            m1 = torch.Tensor(10, 10).uniform_(-10., 10.)
            res1 = m1.clone()
            res2 = m1.clone()
            qs = torch.arange(-5.1, 4.1)
            # the case where the divisor is a simple float
            for col_idx, q in enumerate(qs):
                # reference
                for i in range(m1.size(0)):
                    res2[i, col_idx] = res2[i, col_idx] % q
                # to test
                res1[:, col_idx].remainder_(q if not use_item else q.item())
            self.assertEqual(res1, res2)

            # the case where the divisor is a tensor
            res1 = m1.clone()
            res1.remainder_(qs.unsqueeze(0).expand_as(res1))
            self.assertEqual(res1, res2)

        # check the LongTensor
        for use_item in [True, False]:
            long_m1 = torch.LongTensor(10, 10).random_(-10, 10)
            long_res1 = long_m1.clone()
            long_res2 = long_m1.clone()
            long_qs = torch.arange(-5, 5)
            long_qs[5] = 5  # avoid a divisor of 0
            for col_idx, long_q in enumerate(long_qs):
                # reference
                for i in range(long_m1.size(0)):
                    long_res2[i, col_idx] = long_res2[i, col_idx] % long_q
                # to test
                long_res1[:, col_idx].remainder_(long_q if not use_item else long_q.item())
            self.assertEqual(long_res1, long_res2)

            # divisor-is-a-tensor case
            long_res1 = long_m1.clone()
            long_res1.remainder_(long_qs.unsqueeze(0).expand_as(long_res1))
            self.assertEqual(long_res1, long_res2)
    @staticmethod
    def _test_remainder_overflow(self, dtype, device):
        # check whether % overflows near the int64 boundary
        x = torch.tensor(23500, dtype=dtype, device=device)
        q = 392486996410368
        self.assertEqual(x % q, x)
        self.assertEqual(-x % q, q - x)
        self.assertEqual(x % -q, x - q)
        self.assertEqual(-x % -q, -x)

    def test_remainder_overflow(self):
        self._test_remainder_overflow(self, dtype=torch.int64, device='cpu')
    def test_mm(self):
        def _test_mm(n, m, p, dtype, genf):
            # naive reference implementation
            def matrixmultiply(mat1, mat2):
                n = mat1.size(0)
                m = mat1.size(1)
                p = mat2.size(1)
                res = torch.zeros(n, p, dtype=dtype)
                for i, j in iter_indices(res):
                    res[i, j] = sum(mat1[i, k] * mat2[k, j] for k in range(m))
                return res

            # contiguous case
            mat1 = genf(n, m)
            mat2 = genf(m, p)
            res = torch.mm(mat1, mat2)

            res2 = matrixmultiply(mat1, mat2)
            self.assertEqual(res, res2)

            # non-contiguous case 1
            mat1 = genf(n, m)
            mat2 = genf(p, m).t()
            res = torch.mm(mat1, mat2)

            res2 = matrixmultiply(mat1, mat2)
            self.assertEqual(res, res2)

            # non-contiguous case 2
            mat1 = genf(m, n).t()
            mat2 = genf(m, p)
            res = torch.mm(mat1, mat2)

            res2 = matrixmultiply(mat1, mat2)
            self.assertEqual(res, res2)

            # non-contiguous case 3
            mat1 = genf(m, n).t()
            mat2 = genf(p, m).t()
            res = torch.mm(mat1, mat2)

            res2 = matrixmultiply(mat1, mat2)
            self.assertEqual(res, res2)

            # test with zero stride
            mat1 = genf(n, m)
            mat2 = genf(m, 1).expand(m, p)
            res = torch.mm(mat1, mat2)

            res2 = matrixmultiply(mat1, mat2)
            self.assertEqual(res, res2)

            # explicitly exercise the out= overload
            mat1 = genf(n, m)
            mat2 = genf(m, p)
            res = genf(n, p)
            torch.mm(mat1, mat2, out=res)

            res2 = matrixmultiply(mat1, mat2)
            self.assertEqual(res, res2)

            # out= with transposed arguments
            mat1 = genf(m, n).t()
            mat2 = genf(p, m).t()
            res = genf(n, p)
            torch.mm(mat1, mat2, out=res)

            res2 = matrixmultiply(mat1, mat2)
            self.assertEqual(res, res2)

        for (n, m, p) in [(20, 10, 5), (15, 5, 10), (5, 18, 10)]:
            _test_mm(n, m, p, torch.float32, lambda x, y: torch.randn(x, y, dtype=torch.float32))
            _test_mm(n, m, p, torch.float64, lambda x, y: torch.randn(x, y, dtype=torch.float64))
            _test_mm(n, m, p, torch.int32, lambda x, y: torch.randint(0, 100, (x, y), dtype=torch.int32))
            _test_mm(n, m, p, torch.int64, lambda x, y: torch.randint(0, 100, (x, y), dtype=torch.int64))
    @staticmethod
    def _test_btrifact(self, cast):
        from common_utils import random_fullrank_matrix_distinct_singular_value as fullrank

        def run_test(matrix_size, batches, cast):
            a = cast(fullrank(matrix_size, *batches))
            a_LU_info, pivots_info, info_ = a.btrifact_with_info()
            self.assertEqual(a_LU_info.size(), torch.Size(batches + (matrix_size, matrix_size)))
            self.assertEqual(pivots_info.size(), torch.Size(batches + (matrix_size,)))
            self.assertEqual(info_.size(), torch.Size(batches))
            self.assertEqual(info_.abs().sum(), 0)
            a_LU, pivots = a.btrifact()
            self.assertEqual(a_LU, a_LU_info)
            self.assertEqual(pivots_info, pivots)
            if a.is_cuda:
                a_LU_info_nopiv, nopiv, info_nopiv = a.btrifact_with_info(pivot=False)
                self.assertIsNone(nopiv)
                self.assertEqual(info_, info_nopiv)
            P, L, U = torch.btriunpack(a_LU, pivots)
            self.assertEqual(P.matmul(L.matmul(U)), a)

        for ms, batch in product([3, 5, 7], [(2,), (3,), (3, 5)]):
            run_test(ms, batch, cast)

        # info should be positive for rank-deficient matrices
        a = cast(fullrank(3, 5))
        if not (a.is_cuda and any(x in torch.version.cuda for x in ['8.0', '9.2'])):
            # make the first matrix in the batch singular
            a[0, 1] = 2 * a[0, 0]
            self.assertGreater(a.btrifact_with_info()[2][0], 0)

        if not a.is_cuda:
            # pivoting is mandatory on the CPU
            with self.assertRaisesRegex(RuntimeError,
                                        'btrifact without pivoting is not implemented on the CPU'):
                torch.btrifact(torch.empty(1, 2, 2), pivot=False)

    @skipIfNoLapack
    def test_btrifact(self):
        self._test_btrifact(self, lambda t: t)
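    # btrifact/btriunpack/btrisolve are the batched LU APIs of this version
    # (later renamed to torch.lu and friends): btrifact returns the packed
    # factorization plus pivots, btriunpack expands them into P, L, U with
    # P @ L @ U == a, and btrisolve solves a @ x = b from that factorization.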
    @staticmethod
    def _test_btrisolve(self, cast):
        # fixed, well-conditioned 2x2 batch (values arbitrary)
        a = torch.FloatTensor((((1.3722, -0.9020),
                                (1.8849, 1.9169)),
                               ((0.7187, -1.1695),
                                (-0.0139, 1.3572)),
                               ((-1.6181, 0.7148),
                                (1.3728, 0.1319))))
        b = torch.FloatTensor(((4.02, 6.19),
                               (-1.56, 4.00),
                               (9.81, -4.09)))
        a, b = cast(a), cast(b)
        LU_data, pivots, info = a.btrifact_with_info()
        self.assertEqual(info.abs().sum(), 0)
        x = torch.btrisolve(b, LU_data, pivots)
        b_ = torch.bmm(a, x.unsqueeze(2)).squeeze()
        self.assertEqual(b_, b)

    @skipIfNoLapack
    def test_btrisolve(self):
        self._test_btrisolve(self, lambda t: t)

    @staticmethod
    def _test_btriunpack(self, cast):
        def run_test(shape, cast):
            a = cast(torch.randn(*shape))
            a_lu, p = torch.btrifact(a.reshape(-1, shape[-1], shape[-1]))
            a_lu = a_lu.reshape_as(a)
            p = p.reshape(a.shape[:-1])
            p_ref, l_ref, u_ref = torch.btriunpack(a_lu, p)
            self.assertEqual(p_ref.matmul(l_ref.matmul(u_ref)), a)

        run_test((5, 3, 3), cast)
        run_test((7, 3, 5, 5), cast)
        run_test((7, 5, 3, 3, 3), cast)

    @skipIfNoLapack
    def test_btriunpack(self):
        self._test_btriunpack(self, lambda t: t)

    def test_bmm(self):
        num_batches = 10
        M, N, O = 23, 8, 12
        b1 = torch.randn(num_batches, M, N)
        b2 = torch.randn(num_batches, N, O)
        res = torch.bmm(b1, b2)
        for i in range(num_batches):
            r = torch.mm(b1[i], b2[i])
            self.assertEqual(r, res[i])
        if torch.cuda.is_available():
            # device mismatch between operands should be rejected
            self.assertRaises(RuntimeError, lambda: torch.bmm(b1, b2.cuda()))
            self.assertRaises(RuntimeError, lambda: torch.bmm(b1.cuda(), b2))
    def test_addbmm(self):
        num_batches = 10
        M, N, O = 12, 8, 5
        b1 = torch.randn(num_batches, M, N)
        b2 = torch.randn(num_batches, N, O)
        res = torch.bmm(b1, b2)
        res2 = torch.Tensor().resize_as_(res[0]).zero_()

        res2.addbmm_(b1, b2)
        self.assertEqual(res2, res.sum(0, False))

        res2.addbmm_(1, b1, b2)
        self.assertEqual(res2, res.sum(0, False) * 2)

        res2.addbmm_(1., .5, b1, b2)
        self.assertEqual(res2, res.sum(0, False) * 2.5)

        res3 = torch.addbmm(1, res2, 0, b1, b2)
        self.assertEqual(res3, res2)

        res4 = torch.addbmm(1, res2, .5, b1, b2)
        self.assertEqual(res4, res.sum(0, False) * 3)

        res5 = torch.addbmm(0, res2, 1, b1, b2)
        self.assertEqual(res5, res.sum(0, False))

        res6 = torch.addbmm(.1, res2, .5, b1, b2)
        self.assertEqual(res6, res2 * .1 + (res.sum(0) * .5))

    def test_baddbmm(self):
        num_batches = 10
        M, N, O = 12, 8, 5
        b1 = torch.randn(num_batches, M, N)
        b2 = torch.randn(num_batches, N, O)
        res = torch.bmm(b1, b2)
        res2 = torch.Tensor().resize_as_(res).zero_()

        res2.baddbmm_(b1, b2)
        self.assertEqual(res2, res)

        res2.baddbmm_(1, b1, b2)
        self.assertEqual(res2, res * 2)

        res2.baddbmm_(1, .5, b1, b2)
        self.assertEqual(res2, res * 2.5)

        res3 = torch.baddbmm(1, res2, 0, b1, b2)
        self.assertEqual(res3, res2)

        res4 = torch.baddbmm(1, res2, .5, b1, b2)
        self.assertEqual(res4, res * 3)

        res5 = torch.baddbmm(0, res2, 1, b1, b2)
        self.assertEqual(res5, res)

        res6 = torch.baddbmm(.1, res2, .5, b1, b2)
        self.assertEqual(res6, res2 * .1 + res * .5)
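    # The positional forms above use the legacy (beta, alpha) signatures:
    # addbmm(beta, mat, alpha, b1, b2) computes beta * mat + alpha * sum over
    # i of (b1[i] @ b2[i]), while baddbmm keeps the batch dimension instead
    # of summing over it.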
    @staticmethod
    def _test_clamp(self, device='cpu'):
        m1 = torch.rand(100, device=device).mul(5).add(-2.5)  # uniform in [-2.5, 2.5]
        # just in case we're extremely lucky
        min_val = -1
        max_val = 1
        m1[1] = min_val
        m1[2] = max_val

        res1 = m1.clone()
        res1.clamp_(min_val, max_val)
        res2 = m1.clone()
        for i in iter_indices(res2):
            res2[i] = max(min_val, min(max_val, res2[i]))
        self.assertEqual(res1, res2)

        out = m1.clone()
        torch.clamp(m1, min=min_val, max=max_val, out=out)
        self.assertEqual(out, res1)

        res1 = torch.clamp(m1, min=min_val)
        res2 = m1.clone()
        for i in iter_indices(res2):
            res2[i] = max(min_val, res2[i])
        self.assertEqual(res1, res2)

        torch.clamp(m1, min=min_val, out=out)
        self.assertEqual(out, res1)

        res1 = torch.clamp(m1, max=max_val)
        res2 = m1.clone()
        for i in iter_indices(res2):
            res2[i] = min(max_val, res2[i])
        self.assertEqual(res1, res2)

        torch.clamp(m1, max=max_val, out=out)
        self.assertEqual(out, res1)

        # nan-containing case: clamp must propagate NaNs
        test_tens = torch.tensor([nan], device=device)

        res1 = test_tens.clone()
        res1.clamp_(min_val, max_val)
        res2 = test_tens.clone()
        for i in iter_indices(res2):
            res2[i] = max(min(res2[i], max_val), min_val)
        self.assertEqual(torch.isnan(res1), torch.isnan(res2))

        out = test_tens.clone()
        torch.clamp(test_tens, min=min_val, max=max_val, out=out)
        self.assertEqual(torch.isnan(out), torch.isnan(res1))

        res1 = torch.clamp(test_tens, min=min_val)
        res2 = test_tens.clone()
        for i in iter_indices(res2):
            res2[i] = max(res2[i], min_val)
        self.assertEqual(torch.isnan(res1), torch.isnan(res2))

        torch.clamp(test_tens, min=min_val, out=out)
        self.assertEqual(torch.isnan(out), torch.isnan(res1))

        res1 = torch.clamp(test_tens, max=max_val)
        res2 = test_tens.clone()
        for i in iter_indices(res2):
            res2[i] = min(res2[i], max_val)
        self.assertEqual(torch.isnan(res1), torch.isnan(res2))

        torch.clamp(test_tens, max=max_val, out=out)
        self.assertEqual(torch.isnan(out), torch.isnan(res1))

        error_msg = 'At least one of \'min\' or \'max\' must not be None'
        with self.assertRaisesRegex(RuntimeError, error_msg):
            m1.clamp()
        with self.assertRaisesRegex(RuntimeError, error_msg):
            m1.clamp_()

    def test_clamp(self):
        self._test_clamp(self)
    def test_pow(self):
        # pow has dedicated fast paths for several exponents
        for exponent in [-2, -1, -0.5, 0.5, 1, 2, 3, 4]:
            # base: tensor, exponent: number
            # contiguous
            m1 = torch.rand(100, 100) + 0.5
            res1 = torch.pow(m1[4], exponent)
            res2 = res1.clone().zero_()
            for i in range(res2.size(0)):
                res2[i] = math.pow(m1[4][i], exponent)
            self.assertEqual(res1, res2)

            # non-contiguous
            m1 = torch.rand(100, 100) + 0.5
            res1 = torch.pow(m1[:, 4], exponent)
            res2 = res1.clone().zero_()
            for i in range(res2.size(0)):
                res2[i] = math.pow(m1[i, 4], exponent)
            self.assertEqual(res1, res2)

        # base: number, exponent: tensor
        # contiguous
        m1 = torch.randn(100, 100)
        res1 = torch.pow(3, m1[4])
        res2 = res1.clone().zero_()
        for i in range(res2.size(0)):
            res2[i] = math.pow(3, m1[4, i])
        self.assertEqual(res1, res2)

        # non-contiguous
        m1 = torch.randn(100, 100)
        res1 = torch.pow(3, m1[:, 4])
        res2 = res1.clone().zero_()
        for i in range(res2.size(0)):
            res2[i] = math.pow(3, m1[i][4])
        self.assertEqual(res1, res2)

    @staticmethod
    def _test_rpow(self, cast):
        m = cast(torch.randn(10, 10))
        self.assertEqual(torch.pow(2, m), 2**m)

        # test with scalar
        m = cast(torch.randn(1).squeeze())
        assert m.dim() == 0, "m is intentionally a scalar"
        self.assertEqual(torch.pow(2, m), 2**m)

    def test_rpow(self):
        self._test_rpow(self, lambda x: x)

    @staticmethod
    def _test_int_pow(self, cast):
        if not TEST_NUMPY:
            return

        def check_against_np(tensor, exp):
            tensor_np = tensor.cpu().numpy()
            exp_np = exp if isinstance(exp, int) else exp.cpu().numpy()
            expected = torch.LongTensor(tensor_np ** exp_np).type_as(tensor)
            self.assertEqual(torch.pow(tensor, exp), expected)
            self.assertEqual(tensor.pow(exp), torch.pow(tensor, exp))

        typecasts = [
            lambda x: x.long(),
            lambda x: x.short(),
            lambda x: x.byte(),
        ]

        if not IS_WINDOWS:
            typecasts.append(lambda x: x.int())

        shape = (11, 5)
        tensor = cast(torch.LongTensor(shape).random_(-10, 10))
        exps = [0, 1, 2, 5, cast(torch.LongTensor(shape).random_(0, 20))]

        for typecast in typecasts:
            for exp in exps:
                t = typecast(tensor)
                e = exp if isinstance(exp, int) else typecast(exp)
                check_against_np(t, e)

    def test_int_pow(self):
        self._test_int_pow(self, lambda x: x)
    def _test_cop(self, torchfn, mathfn):
        def reference_implementation(res2):
            for i, j in iter_indices(sm1):
                idx1d = i * sm1.size(0) + j
                res2[i, j] = mathfn(sm1[i, j], sm2[idx1d])
            return res2

        # contiguous
        m1 = torch.randn(10, 10, 10)
        m2 = torch.randn(10, 10 * 10)
        sm1 = m1[4]
        sm2 = m2[4]

        res1 = torchfn(sm1, sm2.view(10, 10))
        res2 = reference_implementation(res1.clone())
        self.assertEqual(res1, res2)

        # non-contiguous
        m1 = torch.randn(10, 10, 10)
        m2 = torch.randn(10 * 10, 10 * 10)
        sm1 = m1[:, 4]
        sm2 = m2[:, 4]
        # view sm2 with the same sizes as sm1 (strided)
        sm2.set_(sm2.storage(), sm2.storage_offset(), sm1.size(), (sm2.stride()[0] * 10, sm2.stride()[0]))
        res1 = torchfn(sm1, sm2)
        # restore the 1-d view that reference_implementation assumes
        sm2.set_(sm2.storage(), sm2.storage_offset(), m2[:, 4].size(), m2[:, 4].stride())
        res2 = reference_implementation(res1.clone())
        self.assertEqual(res1, res2)

    def test_cdiv(self):
        self._test_cop(torch.div, lambda x, y: x / y)

    def test_cfmod(self):
        self._test_cop(torch.fmod, math.fmod)

    def test_cremainder(self):
        self._test_cop(torch.remainder, lambda x, y: x % y)

    def test_cmul(self):
        self._test_cop(torch.mul, lambda x, y: x * y)

    def test_cpow(self):
        self._test_cop(torch.pow, lambda x, y: nan if x < 0 else math.pow(x, y))
    @unittest.skipIf(not TEST_NUMPY, 'Numpy not found')
    def test_einsum(self):
        # cases modeled on the documentation of numpy.einsum
        A = torch.randn(3, 5)
        B = torch.randn(2, 5)
        C = torch.randn(2, 3, 5)
        D = torch.randn(2, 5, 7)
        E = torch.randn(7, 9)
        F = torch.randn(2, 3, 5, 7)
        G = torch.randn(7, 11, 13)
        H = torch.randn(4, 4)
        I = torch.randn(3, 4, 4)
        l = torch.randn(5, 10)
        r = torch.randn(5, 20)
        w = torch.randn(30, 10, 20)
        test_list = [
            # -- matrix products
            ("ij,ij->ij", A, A),
            ("ij,kj->ik", A, B),
            ("ij,ab->ijab", A, E),
            # -- batched / higher-order products
            ("aij,ajk->aik", C, D),
            ("ijk,jk->i", C, A),
            ("aij,jk->aik", D, E),
            ("abcd,dfg->abcfg", F, G),
            ("ijk,jk->ik", C, A),
            ("ijk,jk->ij", C, A),
            ("ijk,ik->j", C, B),
            ("ijk,ik->jk", C, B),
            # -- ellipsis
            ("ki,...k->i...", A.t(), B),
            ("k...,jk", A.t(), B),
            # -- other
            ("bn,anm,bm->ba", l, w, r),
            ("... ii->...i ", I),
        ]
        for test in test_list:
            actual = torch.einsum(test[0], test[1:])
            expected = np.einsum(test[0], *[t.numpy() for t in test[1:]])
            self.assertEqual(expected.shape, actual.shape, test[0])
            self.assertTrue(np.allclose(expected, actual.numpy()), test[0])
            # test vararg form
            actual2 = torch.einsum(test[0], *test[1:])
            self.assertEqual(expected.shape, actual2.shape, test[0])
            self.assertTrue(np.allclose(expected, actual2.numpy()), test[0])

            def do_einsum(*args):
                return torch.einsum(test[0], args)

            if test[0] not in {"i,i->", "i,i->i", "ij,ij->ij"}:
                gradcheck_inps = tuple(t.detach().requires_grad_() for t in test[1:])
                self.assertTrue(torch.autograd.gradcheck(do_einsum, gradcheck_inps))
            self.assertTrue(A._version == 0)  # check that the inputs were not modified in place
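    # test_sum_all below checks full-tensor sums against Python's sum() over
    # the flattened values, including inputs large enough to hit the
    # vectorized/parallel reduction paths.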
    def test_sum_all(self):
        def check_sum_all(tensor):
            pylist = tensor.reshape(-1).tolist()
            self.assertEqual(tensor.sum(), sum(pylist))

        check_sum_all(torch.randn(200000))
        check_sum_all(torch.randn(2000, 2)[:, 0])

    def _assert_matches_numpy(self, t, n):
        self.assertEqual(n.shape, t.shape)
        if t.dtype == torch.float:
            self.assertTrue(np.allclose(n, t.numpy(), rtol=1e-03, atol=1e-05,
                                        equal_nan=True))
        else:
            self.assertTrue(np.allclose(n, t.numpy(), equal_nan=True))

    def _test_dim_ops(self, pytorch_op, numpy_op,
                      use_floating=True, use_integral=True):
        def do_one(tensors_dict, dim):
            for category, tensors in tensors_dict.items():
                if category == "slice":
                    # slices are 1-d, so only dim 0 makes sense
                    dim = 0
                for tensor in tensors:
                    # numpy may warn on overflow etc.; silence it
                    with warnings.catch_warnings():
                        warnings.simplefilter("ignore")
                        expected = numpy_op(tensor.numpy(), dim)
                    actual = pytorch_op(tensor, dim)
                    self._assert_matches_numpy(actual, expected)
                    if torch.cuda.is_available():
                        self._assert_matches_numpy(pytorch_op(tensor.cuda(),
                                                              dim).cpu(), expected)

        do_one(self._make_tensors((5, 400000), use_floating=use_floating,
                                  use_integral=use_integral), 1)
        do_one(self._make_tensors((3, 5, 7), use_floating=use_floating,
                                  use_integral=use_integral), 0)
        do_one(self._make_tensors((3, 5, 7), use_floating=use_floating,
                                  use_integral=use_integral), 1)
        do_one(self._make_tensors((3, 5, 7), use_floating=use_floating,
                                  use_integral=use_integral), 2)
        do_one(self._make_tensors((100000, ), use_floating=use_floating,
                                  use_integral=use_integral), -1)
        do_one(self._make_tensors((50, 50, 50), use_floating=use_floating,
                                  use_integral=use_integral), 0)
        do_one(self._make_tensors((50, 50, 50), use_floating=use_floating,
                                  use_integral=use_integral), 1)
        do_one(self._make_tensors((50, 50, 50), use_floating=use_floating,
                                  use_integral=use_integral), 2)
        do_one(self._make_tensors((50, 50, 50), use_floating=use_floating,
                                  use_integral=use_integral), (1, 2))
        do_one(self._make_tensors((50, 50, 50), use_floating=use_floating,
                                  use_integral=use_integral), (1, -1))
        do_one(self._make_tensors((50, 50, 50), use_floating=use_floating,
                                  use_integral=use_integral), (0, 2))
        do_one(self._make_tensors((50, 50, 50), use_floating=use_floating,
                                  use_integral=use_integral), (0, 2, 1))
    @unittest.skipIf(not TEST_NUMPY, 'Numpy not found')
    def test_sum_dim(self):
        self._test_dim_ops(
            lambda t, d: t.sum(d),
            lambda n, d: n.sum(d))
    @unittest.skipIf(not TEST_NUMPY, 'Numpy not found')
    def test_mean_dim(self):
        self._test_dim_ops(
            lambda t, d: t.mean(d),
            lambda n, d: n.mean(d),
            use_integral=False)
    @unittest.skipIf(not TEST_NUMPY, 'Numpy not found')
    def test_std_dim(self):
        for unbiased in [False, True]:
            self._test_dim_ops(
                lambda t, d: t.std(d, unbiased=unbiased),
                lambda n, d: n.std(d, ddof=1 if unbiased else 0),
                use_integral=False)
    @unittest.skipIf(not TEST_NUMPY, 'Numpy not found')
    def test_var_dim(self):
        for unbiased in [False, True]:
            self._test_dim_ops(
                lambda t, d: t.var(d, unbiased=unbiased),
                lambda n, d: n.var(d, ddof=1 if unbiased else 0),
                use_integral=False)
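    # Note on the ddof mapping above: unbiased=True divides the sum of squared
    # deviations by N - 1 (Bessel's correction), which is NumPy's ddof=1;
    # unbiased=False divides by N, i.e. ddof=0.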
    @unittest.skipIf(not TEST_NUMPY, 'Numpy not found')
    @unittest.skipIf(not TEST_SCIPY, 'Scipy not found')
    def test_logsumexp_dim(self):
        from scipy.special import logsumexp
        self._test_dim_ops(
            lambda t, d: t.logsumexp(d),
            lambda n, d: logsumexp(n, d),
            use_integral=False)
    def test_sum_out(self):
        x = torch.rand(100, 100)
        res1 = torch.sum(x, 1)
        res2 = torch.Tensor()
        torch.sum(x, 1, out=res2)
        self.assertEqual(res1, res2)
        x = torch.rand(100, 100, 100)
        res1 = x.sum(2).sum(1)
        res2 = torch.Tensor()
        torch.sum(x, (2, 1), out=res2)
        self.assertEqual(res1, res2)
    def test_prod(self):
        x = torch.rand(100, 100)
        res1 = torch.prod(x, 1)
        res2 = torch.Tensor()
        torch.prod(x, 1, out=res2)
        self.assertEqual(res1, res2)

    def test_cumsum(self):
        x = torch.rand(100, 100)
        res1 = torch.cumsum(x, 1)
        res2 = torch.Tensor()
        torch.cumsum(x, 1, out=res2)
        self.assertEqual(res1, res2)

    def test_cumprod(self):
        x = torch.rand(100, 100)
        res1 = torch.cumprod(x, 1)
        res2 = torch.Tensor()
        torch.cumprod(x, 1, out=res2)
        self.assertEqual(res1, res2)
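    # The out= pattern above works with a fresh, empty torch.Tensor() because
    # out tensors are resized to the result shape as needed, so the tests only
    # have to compare against the functional form of each reduction.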
    def _test_reduce_integer_upcast(self, fn, has_out=True):
        shape = (3, 4, 5)
        reduced_shape = fn(torch.ones(shape)).shape

        def _test_out(dtype, other_dtype):
            out = torch.ones(reduced_shape, dtype=dtype)
            result = fn(x, out=out)
            self.assertIs(out.dtype, result.dtype)
            self.assertEqual(fn(x.type(dtype)), result)
            result = fn(x, out=out, dtype=dtype)
            self.assertIs(out.dtype, result.dtype)
            self.assertEqual(fn(x.type(dtype)), result)
            # 'out' is favored over dtype; check that mismatches error out
            self.assertRaises(RuntimeError, lambda: fn(x, out=out, dtype=other_dtype))

        for dtype in [dtype for dtype in torch.testing.get_all_dtypes() if dtype != torch.float16]:
            x = torch.ones(shape, dtype=dtype)
            expected_dtype = dtype if dtype.is_floating_point else torch.int64
            self.assertIs(expected_dtype, fn(x).dtype)
            self.assertEqual(fn(x.type(expected_dtype)), fn(x))

            if dtype.is_floating_point:
                other_dtype = torch.float32 if dtype == torch.float64 else torch.float64
            else:
                other_dtype = torch.int32 if dtype != torch.int32 else torch.int16
            self.assertIs(other_dtype, fn(x, dtype=other_dtype).dtype)
            self.assertEqual(fn(x.type(other_dtype)), fn(x, dtype=other_dtype))

            # test mixed int/float
            mixed_dtype = torch.int32 if dtype.is_floating_point else torch.float32
            self.assertIs(mixed_dtype, fn(x, dtype=mixed_dtype).dtype)
            self.assertEqual(fn(x.type(mixed_dtype)), fn(x, dtype=mixed_dtype))

            if has_out:
                _test_out(dtype, other_dtype)
                _test_out(dtype, mixed_dtype)
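    # Rationale for the upcast checks above: sum/prod/cumsum/cumprod promote
    # integral inputs to torch.int64 for accumulation (to reduce overflow),
    # while floating-point inputs keep their dtype; an explicit dtype= or out=
    # argument overrides that default.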
    def test_sum_integer_upcast(self):
        self._test_reduce_integer_upcast(lambda x, **kwargs: torch.sum(x, **kwargs), False)
        self._test_reduce_integer_upcast(lambda x, **kwargs: torch.sum(x, 0, **kwargs))

    def test_prod_integer_upcast(self):
        self._test_reduce_integer_upcast(lambda x, **kwargs: torch.prod(x, **kwargs), False)
        self._test_reduce_integer_upcast(lambda x, **kwargs: torch.prod(x, 0, **kwargs))

    def test_cumsum_integer_upcast(self):
        self._test_reduce_integer_upcast(lambda x, **kwargs: torch.cumsum(x, 0, **kwargs))

    def test_cumprod_integer_upcast(self):
        self._test_reduce_integer_upcast(lambda x, **kwargs: torch.cumprod(x, 0, **kwargs))
    def test_cross(self):
        x = torch.rand(100, 3, 100)
        y = torch.rand(100, 3, 100)
        res1 = torch.cross(x, y)
        res2 = torch.Tensor()
        torch.cross(x, y, out=res2)
        self.assertEqual(res1, res2)

    def test_zeros(self):
        res1 = torch.zeros(100, 100)
        res2 = torch.Tensor()
        torch.zeros(100, 100, out=res2)
        self.assertEqual(res1, res2)

        boolTensor = torch.zeros(2, 2, dtype=torch.bool)
        expected = torch.tensor([[False, False], [False, False]], dtype=torch.bool)
        self.assertEqual(boolTensor, expected)

        halfTensor = torch.zeros(1, 1, dtype=torch.half)
        expected = torch.tensor([[0.]], dtype=torch.half)
        self.assertEqual(halfTensor, expected)
    def test_zeros_like(self):
        expected = torch.zeros(100, 100)

        res1 = torch.zeros_like(expected)
        self.assertEqual(res1, expected)

    @unittest.skipIf(not torch.cuda.is_available(), 'CUDA not available')
    def test_zeros_like_cuda(self):
        expected = torch.zeros(100, 100).cuda()

        res1 = torch.zeros_like(expected)
        self.assertEqual(res1, expected)

    @unittest.skipIf(torch.cuda.device_count() < 2, 'fewer than 2 GPUs detected')
    def test_zeros_like_multiple_device(self):
        expected = torch.zeros(100, 100).cuda()
        x = torch.cuda.FloatTensor(100, 100, device=1)
        output = torch.zeros_like(x)
        self.assertEqual(output, expected)
    def test_zeros_out(self):
        shape = (3, 4)
        out = torch.zeros(shape)
        torch.zeros(shape, out=out)

        # changing the dtype, layout, or device of the out tensor is an error
        self.assertRaises(RuntimeError, lambda: torch.zeros(shape, dtype=torch.int64, out=out))
        self.assertRaises(RuntimeError, lambda: torch.zeros(shape, layout=torch.sparse_coo, out=out))
        if torch.cuda.is_available():
            self.assertRaises(RuntimeError, lambda: torch.zeros(shape, device='cuda', out=out))

        # leaving them the same is fine
        self.assertEqual(torch.zeros(shape), torch.zeros(shape, dtype=out.dtype, out=out))
        self.assertEqual(torch.zeros(shape), torch.zeros(shape, layout=torch.strided, out=out))
        self.assertEqual(torch.zeros(shape), torch.zeros(shape, device='cpu', out=out))
    @staticmethod
    def _test_histc(self, device):
        # negative nbins throws
        with self.assertRaisesRegex(RuntimeError, 'bins must be > 0'):
            torch.histc(torch.tensor([1], dtype=torch.float, device=device), bins=-1)

        # without nbins
        actual = torch.histc(
            torch.tensor([2, 5], dtype=torch.float, device=device))
        expected = torch.zeros(100, dtype=torch.float, device=device)
        expected.data[0] = 1
        expected.data[99] = 1
        self.assertEqual(expected, actual)
        # tensor with the same element repeated
        actual = torch.histc(torch.ones(5, dtype=torch.float, device=device), bins=5)
        self.assertEqual(
            torch.tensor([0, 0, 5, 0, 0], dtype=torch.float, device=device),
            actual)
        # elements outside [min, max] are ignored
        actual = torch.histc(
            torch.ones(5, dtype=torch.float, device=device), bins=5, min=2, max=3)
        self.assertEqual(
            torch.tensor([0, 0, 0, 0, 0], dtype=torch.float, device=device),
            actual)
        # ordinary histogram
        actual = torch.histc(
            torch.tensor([2, 4, 2, 2, 5, 4], dtype=torch.float, device=device),
            bins=5, min=1, max=5)
        self.assertEqual(
            torch.tensor([0, 3, 0, 2, 1], dtype=torch.float, device=device),
            actual)
        # non-integral bin edges
        actual = torch.histc(
            torch.tensor([1, 2, 1], dtype=torch.float, device=device),
            bins=4, min=0, max=3)
        self.assertEqual(
            torch.tensor([0, 2, 1, 0], dtype=torch.float, device=device),
            actual)
        # double dtype
        actual = torch.histc(
            torch.tensor([1, 2, 1], dtype=torch.double, device=device),
            bins=4, min=0, max=3)
        self.assertEqual(
            torch.tensor([0, 2, 1, 0], dtype=torch.double, device=device),
            actual)
        # mixed float literal input
        actual = torch.histc(
            torch.tensor([1., 2, 1], dtype=torch.float, device=device),
            bins=4, min=0, max=3)
        self.assertEqual(
            torch.tensor([0, 2, 1, 0], dtype=torch.float, device=device),
            actual)

        # test against numpy.histogram()
        def test_against_np(tensor, bins=100, min=0, max=0):
            if min == 0 and max == 0:
                min = tensor.min().item()
                max = tensor.max().item()
            nparr = tensor.cpu().numpy()
            actual = torch.histc(tensor, bins=bins, min=min, max=max)
            expected = torch.from_numpy(np.histogram(nparr, bins=bins, range=(min, max))[0])
            self.assertEqual(actual.cpu(), expected)

        if TEST_NUMPY:
            test_against_np(torch.tensor([1., 2, 1], device=device))
            test_against_np(torch.randn(5000, device=device))

            test_against_np(torch.randn(301, device=device), bins=10)

            test_against_np(torch.randn(201, device=device), min=0.1, max=1)

            noncontig = torch.randn(100, 3, device=device)[:, 2]
            test_against_np(noncontig)

            multidim = torch.randn(3, 5, 7, 2, device=device)
            test_against_np(multidim)

            expanded = torch.randn(1, 5, 1, 2, device=device).expand(3, 5, 7, 2)
            test_against_np(expanded)
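    # torch.histc semantics exercised above: with min == max == 0 the range
    # defaults to the min/max of the input, values are spread over `bins`
    # equal-width buckets covering [min, max], and elements outside the range
    # fall into no bucket at all.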
    def test_histc_cpu(self):
        self._test_histc(self, 'cpu')
    def test_ones(self):
        res1 = torch.ones(100, 100)
        res2 = torch.Tensor()
        torch.ones(100, 100, out=res2)
        self.assertEqual(res1, res2)

        # test boolean tensor
        res1 = torch.ones(1, 2, dtype=torch.bool)
        expected = torch.tensor([[True, True]], dtype=torch.bool)
        self.assertEqual(res1, expected)
    def test_ones_like(self):
        expected = torch.ones(100, 100)

        res1 = torch.ones_like(expected)
        self.assertEqual(res1, expected)

        # test boolean tensor
        expected = torch.tensor([True, True], dtype=torch.bool)
        res1 = torch.ones_like(expected)
        self.assertEqual(res1, expected)
    @unittest.skipIf(not torch.cuda.is_available(), 'CUDA not available')
    def test_ones_like_cuda(self):
        expected = torch.ones(100, 100).cuda()

        res1 = torch.ones_like(expected)
        self.assertEqual(res1, expected)

    @unittest.skipIf(torch.cuda.device_count() < 2, 'fewer than 2 GPUs detected')
    def test_ones_like_multiple_device(self):
        expected = torch.ones(100, 100).cuda()
        x = torch.cuda.FloatTensor(100, 100, device=1)
        output = torch.ones_like(x)
        self.assertEqual(output, expected)
    def test_dtypes(self):
        all_dtypes = torch.testing.get_all_dtypes()
        do_test_dtypes(self, all_dtypes, torch.strided, torch.device('cpu'))
        if torch.cuda.is_available():
            do_test_dtypes(self, all_dtypes, torch.strided, torch.device('cuda:0'))

    def test_copy_dtypes(self):
        all_dtypes = torch.testing.get_all_dtypes()
        for dtype in all_dtypes:
            copied_dtype = copy.deepcopy(dtype)
            self.assertIs(dtype, copied_dtype)
    def test_device(self):
        cpu = torch.device('cpu')
        self.assertEqual('cpu', str(cpu))
        self.assertEqual('cpu', cpu.type)
        self.assertEqual(None, cpu.index)

        cpu0 = torch.device('cpu:0')
        self.assertEqual('cpu:0', str(cpu0))
        self.assertEqual('cpu', cpu0.type)
        self.assertEqual(0, cpu0.index)

        cpu0 = torch.device('cpu', 0)
        self.assertEqual('cpu:0', str(cpu0))
        self.assertEqual('cpu', cpu0.type)
        self.assertEqual(0, cpu0.index)

        cuda = torch.device('cuda')
        self.assertEqual('cuda', str(cuda))
        self.assertEqual('cuda', cuda.type)
        self.assertEqual(None, cuda.index)

        cuda1 = torch.device('cuda:1')
        self.assertEqual('cuda:1', str(cuda1))
        self.assertEqual('cuda', cuda1.type)
        self.assertEqual(1, cuda1.index)

        cuda1 = torch.device('cuda', 1)
        self.assertEqual('cuda:1', str(cuda1))
        self.assertEqual('cuda', cuda1.type)
        self.assertEqual(1, cuda1.index)

        self.assertRaises(RuntimeError, lambda: torch.device('cpu:-1'))
        self.assertRaises(RuntimeError, lambda: torch.device('cpu:1'))
        self.assertRaises(RuntimeError, lambda: torch.device('cpu', -1))
        self.assertRaises(RuntimeError, lambda: torch.device('cpu', 1))
        self.assertRaises(RuntimeError, lambda: torch.device('cuda:-1'))
        self.assertRaises(RuntimeError, lambda: torch.device('cuda', -1))
        self.assertRaises(RuntimeError, lambda: torch.device(-1))

        self.assertRaises(RuntimeError, lambda: torch.device('other'))
        self.assertRaises(RuntimeError, lambda: torch.device('other:0'))

        device_set = {'cpu', 'cpu:0', 'cuda', 'cuda:0', 'cuda:1', 'cuda:10', 'cuda:100'}
        device_hash_set = set()
        for device in list(device_set):
            device_hash_set.add(hash(torch.device(device)))
        self.assertEqual(len(device_set), len(device_hash_set))
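    # torch.device is compared and hashed by its (type, index) pair, so
    # 'cuda' (index None) and 'cuda:0' are distinct devices, which is why all
    # seven strings above must produce distinct hashes.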
    def test_tensor_device(self):
        def assertEqual(device_str, fn):
            self.assertEqual(torch.device(device_str), fn().device)
            self.assertEqual(device_str, str(fn().device))

        assertEqual('cpu', lambda: torch.ones((2, 3), dtype=torch.float32, device='cpu'))
        assertEqual('cpu', lambda: torch.ones((2, 3), dtype=torch.float32, device='cpu:0'))
        assertEqual('cpu', lambda: torch.tensor(torch.ones((2, 3), dtype=torch.float32), device='cpu:0'))
        if TEST_NUMPY:
            assertEqual('cpu', lambda: torch.tensor(np.random.randn(2, 3), device='cpu'))

        if torch.cuda.is_available():
            assertEqual('cuda:0', lambda: torch.tensor(5).cuda('cuda:0'))
            self.assertRaises(RuntimeError, lambda: torch.tensor(5).cuda('cpu'))
            self.assertRaises(RuntimeError, lambda: torch.tensor(5).cuda('cpu:0'))
            assertEqual('cuda:0', lambda: torch.tensor(5, dtype=torch.int64, device=0))
            assertEqual('cuda:0', lambda: torch.tensor(5, dtype=torch.int64, device='cuda:0'))
            assertEqual('cuda:' + str(torch.cuda.current_device()),
                        lambda: torch.tensor(5, dtype=torch.int64, device='cuda'))
            assertEqual('cuda:0', lambda: torch.tensor(torch.ones((2, 3), dtype=torch.float32), device='cuda:0'))
            if TEST_NUMPY:
                assertEqual('cuda:0', lambda: torch.tensor(np.random.randn(2, 3), device='cuda:0'))

            if torch.cuda.device_count() > 1:
                assertEqual('cuda:1', lambda: torch.tensor(5).cuda('cuda:1'))
                assertEqual('cuda:1', lambda: torch.tensor(5, dtype=torch.int64, device=1))
                assertEqual('cuda:1', lambda: torch.tensor(5, dtype=torch.int64, device='cuda:1'))
                assertEqual('cuda:1', lambda: torch.tensor(torch.ones((2, 3), dtype=torch.float32), device='cuda:1'))
                if TEST_NUMPY:
                    assertEqual('cuda:1', lambda: torch.tensor(np.random.randn(2, 3), device='cuda:1'))
    def test_to(self):
        def test_copy_behavior(t, non_blocking=False):
            self.assertIs(t, t.to(t, non_blocking=non_blocking))
            self.assertIs(t, t.to(t.dtype, non_blocking=non_blocking))
            self.assertIs(t, t.to(torch.empty_like(t), non_blocking=non_blocking))
            self.assertIsNot(t, t.to(t, non_blocking=non_blocking, copy=True))
            self.assertIsNot(t, t.to(t.dtype, non_blocking=non_blocking, copy=True))
            self.assertIsNot(t, t.to(torch.empty_like(t), non_blocking=non_blocking, copy=True))

            devices = [t.device]
            if t.device.type == 'cuda':
                if t.device.index == -1:
                    devices.append('cuda:{}'.format(torch.cuda.current_device()))
                elif t.device.index == torch.cuda.current_device():
                    devices.append('cuda')
            for device in devices:
                self.assertIs(t, t.to(device, non_blocking=non_blocking))
                self.assertIs(t, t.to(device, t.dtype, non_blocking=non_blocking))
                self.assertIsNot(t, t.to(device, non_blocking=non_blocking, copy=True))
                self.assertIsNot(t, t.to(device, t.dtype, non_blocking=non_blocking, copy=True))

        a = torch.tensor(5)
        test_copy_behavior(a)
        self.assertEqual(a.device, a.to('cpu').device)
        self.assertEqual(a.device, a.to('cpu', dtype=torch.float32).device)
        self.assertIs(torch.float32, a.to('cpu', dtype=torch.float32).dtype)
        self.assertEqual(a.device, a.to(torch.float32).device)
        self.assertIs(torch.float32, a.to(dtype=torch.float32).dtype)
        self.assertEqual(a.data_ptr(), a.to('cpu').data_ptr())
        self.assertEqual(a.data_ptr(), a.to(dtype=a.dtype, device=a.device, copy=False).data_ptr())
        self.assertEqual(a.data_ptr(), a.to('cpu', copy=False).data_ptr())
        self.assertNotEqual(a.data_ptr(), a.to('cpu', copy=True).data_ptr())

        if torch.cuda.is_available():
            cuda = torch.device('cuda')
            for non_blocking in [True, False]:
                b = torch.tensor(5., device=cuda)
                test_copy_behavior(b, non_blocking)
                self.assertEqual(b.device, b.to(cuda, non_blocking=non_blocking).device)
                self.assertEqual(a.device, b.to('cpu', non_blocking=non_blocking).device)
                self.assertEqual(b.device, a.to(cuda, non_blocking=non_blocking).device)
                self.assertIs(torch.int32, b.to('cpu', dtype=torch.int32, non_blocking=non_blocking).dtype)
                self.assertEqual(a.device, b.to('cpu', dtype=torch.int32, non_blocking=non_blocking).device)
                self.assertIs(torch.int32, b.to(dtype=torch.int32).dtype)
                self.assertEqual(b.device, b.to(dtype=torch.int32).device)
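    # The identity checks above pin down the no-copy contract of Tensor.to():
    # when dtype and device already match it returns self (same data_ptr), and
    # copy=True is the only way to force a fresh allocation.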
    def test_to_with_tensor(self):
        a = torch.tensor(5)
        self.assertEqual(a.device, a.to(a).device)

        if torch.cuda.is_available():
            for non_blocking in [True, False]:
                b = torch.tensor(5., device='cuda')
                self.assertEqual(b.device, b.to(b, non_blocking=non_blocking).device)
                self.assertEqual(a.device, b.to(a, non_blocking=non_blocking).device)
                self.assertEqual(b.device, a.to(b, non_blocking=non_blocking).device)
    def test_empty_full(self):
        do_test_empty_full(self, torch.testing.get_all_dtypes(), torch.strided, torch.device('cpu'))
        if torch.cuda.is_available():
            do_test_empty_full(self, torch.testing.get_all_dtypes(), torch.strided, torch.device('cuda:0'))

    def test_dtype_out_match(self):
        d = torch.autograd.Variable(torch.DoubleTensor(2, 3))
        self.assertRaises(RuntimeError, lambda: torch.zeros((2, 3), out=d, dtype=torch.float32))
    def test_constructor_dtypes(self):
        default_type = torch.Tensor().type()
        self.assertIs(torch.Tensor().dtype, torch.get_default_dtype())

        self.assertIs(torch.uint8, torch.ByteTensor.dtype)
        self.assertIs(torch.float32, torch.FloatTensor.dtype)
        self.assertIs(torch.float64, torch.DoubleTensor.dtype)

        torch.set_default_tensor_type('torch.FloatTensor')
        self.assertIs(torch.float32, torch.get_default_dtype())

        torch.set_default_dtype(torch.float64)
        self.assertIs(torch.float64, torch.get_default_dtype())

        torch.set_default_tensor_type(torch.FloatTensor)
        self.assertIs(torch.float32, torch.get_default_dtype())

        if torch.cuda.is_available():
            torch.set_default_tensor_type(torch.cuda.FloatTensor)
            self.assertIs(torch.float32, torch.get_default_dtype())
            self.assertIs(torch.float32, torch.cuda.FloatTensor.dtype)

            torch.set_default_dtype(torch.float64)
            self.assertIs(torch.float64, torch.get_default_dtype())

        torch.set_default_tensor_type(default_type)
    def test_constructor_device_legacy(self):
        self.assertRaises(RuntimeError, lambda: torch.FloatTensor(device='cuda'))
        self.assertRaises(RuntimeError, lambda: torch.FloatTensor(torch.Size([2, 3, 4]), device='cuda'))
        self.assertRaises(RuntimeError, lambda: torch.FloatTensor((2.0, 3.0), device='cuda'))

        self.assertRaises(RuntimeError, lambda: torch.Tensor(device='cuda'))
        self.assertRaises(RuntimeError, lambda: torch.Tensor(torch.Size([2, 3, 4]), device='cuda'))
        self.assertRaises(RuntimeError, lambda: torch.Tensor((2.0, 3.0), device='cuda'))

        x = torch.randn((3,), device='cpu')
        self.assertRaises(RuntimeError, lambda: x.new(device='cuda'))
        self.assertRaises(RuntimeError, lambda: x.new(torch.Size([2, 3, 4]), device='cuda'))
        self.assertRaises(RuntimeError, lambda: x.new((2.0, 3.0), device='cuda'))

        if torch.cuda.is_available():
            self.assertRaises(RuntimeError, lambda: torch.cuda.FloatTensor(device='cpu'))
            self.assertRaises(RuntimeError, lambda: torch.cuda.FloatTensor(torch.Size([2, 3, 4]), device='cpu'))
            self.assertRaises(RuntimeError, lambda: torch.cuda.FloatTensor((2.0, 3.0), device='cpu'))

            default_type = torch.Tensor().type()
            torch.set_default_tensor_type(torch.cuda.FloatTensor)
            self.assertRaises(RuntimeError, lambda: torch.Tensor(device='cpu'))
            self.assertRaises(RuntimeError, lambda: torch.Tensor(torch.Size([2, 3, 4]), device='cpu'))
            self.assertRaises(RuntimeError, lambda: torch.Tensor((2.0, 3.0), device='cpu'))
            torch.set_default_tensor_type(default_type)

            x = torch.randn((3,), device='cuda')
            self.assertRaises(RuntimeError, lambda: x.new(device='cpu'))
            self.assertRaises(RuntimeError, lambda: x.new(torch.Size([2, 3, 4]), device='cpu'))
            self.assertRaises(RuntimeError, lambda: x.new((2.0, 3.0), device='cpu'))
    def test_type(self):
        x = torch.randn(3, 3).double()
        self.assertEqual(x.type('torch.FloatTensor').dtype, torch.float32)
        self.assertEqual(x.type(torch.FloatTensor).dtype, torch.float32)
        self.assertEqual(x.int().type(torch.Tensor).dtype, torch.get_default_dtype())
        self.assertEqual(x.type(torch.int32).dtype, torch.int32)
    def test_tensor_factory(self):
        expected = torch.Tensor([1, 1])
        # test data
        res1 = torch.tensor([1, 1])
        self.assertEqual(res1, expected)

        res1 = torch.tensor([1, 1], dtype=torch.int)
        self.assertEqual(res1, expected)
        self.assertIs(torch.int, res1.dtype)

        # test copy
        res2 = torch.tensor(expected)
        self.assertEqual(res2, expected)
        res2[1] = 2
        self.assertEqual(expected, torch.ones_like(expected))

        res2 = torch.tensor(expected, dtype=torch.int)
        self.assertEqual(res1, expected)
        self.assertIs(torch.int, res1.dtype)

        # test copy with numpy
        for dtype in [np.float64, np.int64, np.int8, np.uint8]:
            a = np.array([5.]).astype(dtype)
            res1 = torch.tensor(a)
            self.assertEqual(5., res1[0].item())
            a[0] = 7.
            self.assertEqual(5., res1[0].item())

        # test boolean tensor
        a = torch.tensor([True, True, False, True, True], dtype=torch.bool)
        b = torch.tensor([-1, -1.1, 0, 1, 1.1], dtype=torch.bool)
        self.assertEqual(a, b)
    def test_tensor_factory_copy_var(self):

        def check_copy(copy, is_leaf, requires_grad, data_ptr=None):
            if data_ptr is None:
                data_ptr = copy.data_ptr
            self.assertEqual(copy.data, source.data)
            self.assertTrue(copy.is_leaf == is_leaf)
            self.assertTrue(copy.requires_grad == requires_grad)
            self.assertTrue(copy.data_ptr == data_ptr)

        source = torch.randn(5, 5, dtype=torch.double, requires_grad=True)

        # test torch.tensor()
        check_copy(torch.tensor(source), True, False)
        check_copy(torch.tensor(source, requires_grad=False), True, False)
        check_copy(torch.tensor(source, requires_grad=True), True, True)

        # test tensor.new_tensor()
        copy = torch.randn(1)
        check_copy(copy.new_tensor(source), True, False)
        check_copy(copy.new_tensor(source, requires_grad=False), True, False)
        check_copy(copy.new_tensor(source, requires_grad=True), True, True)

        # test torch.as_tensor()
        check_copy(torch.as_tensor(source), source.is_leaf, source.requires_grad, source.data_ptr)
        check_copy(torch.as_tensor(source, dtype=torch.float), False, True)
    def test_tensor_factory_type_inference(self):
        def test_inference(default_dtype):
            saved_dtype = torch.get_default_dtype()
            torch.set_default_dtype(default_dtype)
            self.assertIs(torch.int32, torch.tensor(5, dtype=torch.int32).dtype)
            self.assertIs(default_dtype, torch.tensor(((7, 5), (9, 5.))).dtype)
            self.assertIs(default_dtype, torch.tensor(((5., 5), (3, 5))).dtype)
            self.assertIs(torch.int64, torch.tensor(((5, 3), (3, 5))).dtype)

            if TEST_NUMPY:
                self.assertIs(torch.float64, torch.tensor(np.array(())).dtype)
                self.assertIs(torch.float64, torch.tensor(np.array(5.)).dtype)
                if np.array(5).dtype == np.int64:  # np long, which can be 4 bytes (e.g. on Windows)
                    self.assertIs(torch.int64, torch.tensor(np.array(5)).dtype)
                else:
                    self.assertIs(torch.int32, torch.tensor(np.array(5)).dtype)
                self.assertIs(torch.uint8, torch.tensor(np.array(3, dtype=np.uint8)).dtype)
                self.assertIs(default_dtype, torch.tensor(((7, np.array(5)), (np.array(9), 5.))).dtype)
                self.assertIs(torch.float64, torch.tensor(((7, 5), (9, np.array(5.)))).dtype)
                self.assertIs(torch.int64, torch.tensor(((5, np.array(3)), (np.array(3), 5))).dtype)
            torch.set_default_dtype(saved_dtype)

        test_inference(torch.float64)
        test_inference(torch.float32)
    @unittest.skipIf(not torch.cuda.is_available(), 'CUDA not available')
    def test_tensor_factory_cuda_type_inference(self):
        saved_type = torch.Tensor().type()
        torch.set_default_tensor_type(torch.cuda.DoubleTensor)
        torch.set_default_dtype(torch.float32)
        self.assertIs(torch.float32, torch.tensor(0.).dtype)
        self.assertEqual(torch.device('cuda:0'), torch.tensor(0.).device)
        torch.set_default_dtype(torch.float64)
        self.assertIs(torch.float64, torch.tensor(0.).dtype)
        self.assertEqual(torch.device('cuda:0'), torch.tensor(0.).device)
        torch.set_default_tensor_type(saved_type)
    @unittest.skipIf(not torch.cuda.is_available(), 'CUDA not available')
    def test_tensor_factory_cuda_type(self):
        saved_type = torch.Tensor().type()
        torch.set_default_tensor_type(torch.cuda.FloatTensor)
        x = torch.zeros((5, 5))
        self.assertIs(torch.float32, x.dtype)
        self.assertTrue(x.is_cuda)
        torch.set_default_tensor_type(torch.cuda.DoubleTensor)
        x = torch.zeros((5, 5))
        self.assertIs(torch.float64, x.dtype)
        self.assertTrue(x.is_cuda)
        torch.set_default_tensor_type(saved_type)
    def test_tensor_factories_empty_bool(self):
        expectedShape = (1, 2)
        test = torch.empty(expectedShape, dtype=torch.bool)
        self.assertEqual(expectedShape, test.shape)
        self.assertEqual(expectedShape, torch.empty_like(test).shape)

        test = torch.full(expectedShape, True, dtype=torch.bool)
        self.assertEqual(test, torch.tensor([[True, True]], dtype=torch.bool))
        self.assertEqual(expectedShape, test.shape)
        self.assertEqual(expectedShape, torch.full_like(test, True).shape)
    def test_tensor_factories_empty(self):
        # ensure we can create empty tensors from each factory function
        shapes = [(5, 0, 1), (0,), (0, 0, 1, 0, 2, 0, 0)]
        devices = ['cpu']
        if torch.cuda.is_available():
            devices.append('cuda')
        for device in devices:
            for shape in shapes:
                self.assertEqual(shape, torch.zeros(shape, device=device).shape)
                self.assertEqual(shape, torch.zeros_like(torch.zeros(shape, device=device)).shape)
                self.assertEqual(shape, torch.empty(shape, device=device).shape)
                self.assertEqual(shape, torch.empty_like(torch.zeros(shape, device=device)).shape)
                self.assertEqual(shape, torch.empty_strided(shape, (0,) * len(shape), device=device).shape)
                self.assertEqual(shape, torch.full(shape, 3, device=device).shape)
                self.assertEqual(shape, torch.full_like(torch.zeros(shape, device=device), 3).shape)
                self.assertEqual(shape, torch.ones(shape, device=device).shape)
                self.assertEqual(shape, torch.ones_like(torch.zeros(shape, device=device)).shape)
                self.assertEqual(shape, torch.rand(shape, device=device).shape)
                self.assertEqual(shape, torch.rand_like(torch.zeros(shape, device=device)).shape)
                self.assertEqual(shape, torch.randn(shape, device=device).shape)
                self.assertEqual(shape, torch.randn_like(torch.zeros(shape, device=device)).shape)
                self.assertEqual(shape, torch.randint(6, shape, device=device).shape)
                self.assertEqual(shape, torch.randint_like(torch.zeros(shape, device=device), 6).shape)

            self.assertEqual((0,), torch.arange(0, device=device).shape)
            self.assertEqual((0, 0), torch.eye(0, device=device).shape)
            self.assertEqual((0, 0), torch.eye(0, 0, device=device).shape)
            self.assertEqual((5, 0), torch.eye(5, 0, device=device).shape)
            self.assertEqual((0, 5), torch.eye(0, 5, device=device).shape)
            self.assertEqual((0,), torch.linspace(1, 1, 0, device=device).shape)
            self.assertEqual((0,), torch.logspace(1, 1, 0, device=device).shape)
            self.assertEqual((0,), torch.randperm(0, device=device).shape)
            self.assertEqual((0,), torch.bartlett_window(0, device=device).shape)
            self.assertEqual((0,), torch.bartlett_window(0, periodic=False, device=device).shape)
            self.assertEqual((0,), torch.hamming_window(0, device=device).shape)
            self.assertEqual((0,), torch.hann_window(0, device=device).shape)
            self.assertEqual((1, 1, 0), torch.tensor([[[]]], device=device).shape)
            self.assertEqual((1, 1, 0), torch.as_tensor([[[]]], device=device).shape)
    def test_new_tensor(self):
        expected = torch.autograd.Variable(torch.ByteTensor([1, 1]))
        # test data
        res1 = expected.new_tensor([1, 1])
        self.assertEqual(res1, expected)
        res1 = expected.new_tensor([1, 1], dtype=torch.int)
        self.assertEqual(res1, expected)
        self.assertIs(torch.int, res1.dtype)

        # test copy
        res2 = expected.new_tensor(expected)
        self.assertEqual(res2, expected)
        res2[1] = 2
        self.assertEqual(expected, torch.ones_like(expected))
        res2 = expected.new_tensor(expected, dtype=torch.int)
        self.assertEqual(res2, expected)
        self.assertIs(torch.int, res2.dtype)

        # test copy with numpy
        if TEST_NUMPY:
            a = np.array([5.])
            res1 = torch.tensor(a)
            res1 = res1.new_tensor(a)
            self.assertEqual(5., res1[0].item())
            a[0] = 7.
            self.assertEqual(5., res1[0].item())

        if torch.cuda.device_count() >= 2:
            expected = expected.cuda(1)
            res1 = expected.new_tensor([1, 1])
            self.assertEqual(res1.get_device(), expected.get_device())
            res1 = expected.new_tensor([1, 1], dtype=torch.int)
            self.assertIs(torch.int, res1.dtype)
            self.assertEqual(res1.get_device(), expected.get_device())

            res2 = expected.new_tensor(expected)
            self.assertEqual(res2.get_device(), expected.get_device())
            res2 = expected.new_tensor(expected, dtype=torch.int)
            self.assertIs(torch.int, res1.dtype)
            self.assertEqual(res2.get_device(), expected.get_device())
            res2 = expected.new_tensor(expected, dtype=torch.int, device=0)
            self.assertIs(torch.int, res1.dtype)
            self.assertEqual(res2.get_device(), 0)

            res1 = expected.new_tensor(1)
            self.assertEqual(res1.get_device(), expected.get_device())
            res1 = expected.new_tensor(1, dtype=torch.int)
            self.assertIs(torch.int, res1.dtype)
            self.assertEqual(res1.get_device(), expected.get_device())
    def test_as_tensor(self):
        # from python data
        x = [[0, 1], [2, 3]]
        self.assertEqual(torch.tensor(x), torch.as_tensor(x))
        self.assertEqual(torch.tensor(x, dtype=torch.float32), torch.as_tensor(x, dtype=torch.float32))

        # python data with self-referential lists
        z = [0]
        z += [z]
        with self.assertRaisesRegex(TypeError, "invalid data type"):
            torch.as_tensor([[0, 'torch']])
        with self.assertRaisesRegex(TypeError, "self-referential lists are incompatible"):
            torch.tensor(z)
        with self.assertRaisesRegex(TypeError, "self-referential lists are incompatible"):
            torch.as_tensor(z)

        # from tensor (doesn't copy unless type is different)
        y = torch.tensor(x)
        self.assertIs(y, torch.as_tensor(y))
        self.assertIsNot(y, torch.as_tensor(y, dtype=torch.float32))
        if torch.cuda.is_available():
            self.assertIsNot(y, torch.as_tensor(y, device='cuda'))
            y_cuda = y.to('cuda')
            self.assertIs(y_cuda, torch.as_tensor(y_cuda))
            self.assertIs(y_cuda, torch.as_tensor(y_cuda, device='cuda'))

        if TEST_NUMPY:
            # doesn't copy
            for dtype in [np.float64, np.int64, np.int8, np.uint8]:
                n = np.random.rand(5, 6).astype(dtype)
                n_astensor = torch.as_tensor(n)
                self.assertEqual(torch.tensor(n), n_astensor)
                n_astensor[0][0] = 25.7
                self.assertEqual(torch.tensor(n), n_astensor)

            # changing dtype causes a copy
            n = np.random.rand(5, 6).astype(np.float32)
            n_astensor = torch.as_tensor(n, dtype=torch.float64)
            self.assertEqual(torch.tensor(n, dtype=torch.float64), n_astensor)
            n_astensor[0][1] = 250.8
            self.assertNotEqual(torch.tensor(n, dtype=torch.float64), n_astensor)

            # changing device causes a copy
            if torch.cuda.is_available():
                n = np.random.randn(5, 6)
                n_astensor = torch.as_tensor(n, device='cuda')
                self.assertEqual(torch.tensor(n, device='cuda'), n_astensor)
                n_astensor[0][2] = 250.9
                self.assertNotEqual(torch.tensor(n, device='cuda'), n_astensor)
    def test_diag(self):
        x = torch.rand(100, 100)
        res1 = torch.diag(x)
        res2 = torch.Tensor()
        torch.diag(x, out=res2)
        self.assertEqual(res1, res2)
    @staticmethod
    def _test_diagonal(self, dtype, device):
        x = torch.randn((100, 100), dtype=dtype, device=device)
        result = torch.diagonal(x)
        expected = torch.diag(x)
        self.assertEqual(result, expected)

        x = torch.randn((100, 100), dtype=dtype, device=device)
        result = torch.diagonal(x, 17)
        expected = torch.diag(x, 17)
        self.assertEqual(result, expected)

    def test_diagonal(self):
        self._test_diagonal(self, dtype=torch.float32, device='cpu')
    @unittest.skipIf(not TEST_NUMPY, 'Numpy not found')
    def test_diagonal_multidim(self):
        x = torch.randn(10, 11, 12, 13)
        xn = x.numpy()
        for args in [(2, 2, 3)]:
            result = torch.diagonal(x, *args)
            expected = xn.diagonal(*args)
            self.assertEqual(expected.shape, result.shape)
            self.assertTrue(np.allclose(expected, result.numpy()))
        # test non-contiguous
        xp = x.permute(1, 2, 3, 0)
        result = torch.diagonal(xp, 0, -2, -1)
        expected = xp.numpy().diagonal(0, -2, -1)
        self.assertEqual(expected.shape, result.shape)
        self.assertTrue(np.allclose(expected, result.numpy()))
    @staticmethod
    def _test_diag_embed(self, dtype, device):
        x = torch.arange(3 * 4, dtype=dtype, device=device).view(3, 4)
        result = torch.diag_embed(x)
        expected = torch.stack([torch.diag(r) for r in x], 0)
        self.assertEqual(result, expected)

        result = torch.diag_embed(x, offset=1, dim1=0, dim2=2)
        expected = torch.stack([torch.diag(r, 1) for r in x], 1)
        self.assertEqual(result, expected)

    def test_diag_embed(self):
        self._test_diag_embed(self, dtype=torch.float32, device='cpu')
    @staticmethod
    def _test_diagflat(self, dtype, device):
        # basic sanity test
        x = torch.randn((100,), dtype=dtype, device=device)
        result = torch.diagflat(x)
        expected = torch.diag(x)
        self.assertEqual(result, expected)

        # test offset
        x = torch.randn((100,), dtype=dtype, device=device)
        result = torch.diagflat(x, 17)
        expected = torch.diag(x, 17)
        self.assertEqual(result, expected)

        # test where input has more than one dimension
        x = torch.randn((2, 3, 4), dtype=dtype, device=device)
        result = torch.diagflat(x)
        expected = torch.diag(x.contiguous().view(-1))
        self.assertEqual(result, expected)

        # non-contiguous input
        x = torch.randn((2, 3, 4), dtype=dtype, device=device).transpose(2, 0)
        self.assertFalse(x.is_contiguous())
        result = torch.diagflat(x)
        expected = torch.diag(x.contiguous().view(-1))
        self.assertEqual(result, expected)

    def test_diagflat(self):
        self._test_diagflat(self, dtype=torch.float32, device='cpu')
    def test_eye(self):
        res1 = torch.eye(100, 100)
        res2 = torch.Tensor()
        torch.eye(100, 100, out=res2)
        self.assertEqual(res1, res2)
    def test_renorm(self):
        m1 = torch.randn(10, 5)
        res1 = torch.Tensor()

        def renorm(matrix, value, dim, max_norm):
            m1 = matrix.transpose(dim, 0).contiguous()
            # collapse non-dim dimensions
            m2 = m1.clone().resize_(m1.size(0), int(math.floor(m1.nelement() / m1.size(0))))
            norms = m2.norm(value, 1, True)
            # clip norms at max_norm
            new_norms = norms.clone()
            new_norms[torch.gt(norms, max_norm)] = max_norm
            new_norms.div_(norms.add_(1e-7))
            # renormalize
            m1.mul_(new_norms.expand_as(m1))
            return m1.transpose(dim, 0)

        maxnorm = m1.norm(2, 1).mean()
        m2 = renorm(m1, 2, 1, maxnorm)
        m1.renorm_(2, 1, maxnorm)
        self.assertEqual(m1, m2, 1e-5)
        self.assertEqual(m1.norm(2, 0), m2.norm(2, 0), 1e-5)

        m1 = torch.randn(3, 4, 5)
        m2 = m1.transpose(1, 2).contiguous().clone().resize_(15, 4)
        maxnorm = m2.norm(2, 0).mean()
        m2 = renorm(m2, 2, 1, maxnorm)
        m1.renorm_(2, 1, maxnorm)
        m3 = m1.transpose(1, 2).contiguous().clone().resize_(15, 4)
        self.assertEqual(m3, m2)
        self.assertEqual(m3.norm(2, 0), m2.norm(2, 0))
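    # renorm(p, dim, maxnorm) treats the tensor as a batch of slices along
    # `dim` and rescales every slice whose p-norm exceeds maxnorm so that its
    # norm becomes maxnorm; slices already within the bound are left alone,
    # which is what the reference implementation above reproduces.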
    @staticmethod
    def _test_renorm_ps(self, device):
        # full reduction
        x = torch.randn(5, 5)
        for p in [1, 2, 3, 4, inf]:
            res = x.renorm(p, 1, 1)
            expected = x / x.norm(p, 0, keepdim=True).clamp(min=1)
            self.assertEqual(res.numpy(), expected.numpy(), "renorm failed for {}-norm".format(p))

    def test_renorm_ps(self):
        self._test_renorm_ps(self, device='cpu')

    @unittest.skipIf(not torch.cuda.is_available(), 'CUDA not available')
    def test_renorm_ps_cuda(self):
        self._test_renorm_ps(self, device='cuda')
    @staticmethod
    def _test_multinomial(self, type):
        def make_prob_dist(shape, is_contiguous):
            if is_contiguous:
                return type(*shape).uniform_()
            elif len(shape) == 1:
                return type(*(shape + [5])).uniform_()[:, 2]
            else:
                # num dim = 2
                new_shape = [2, shape[1], 7, 1, shape[0], 1, 10]
                prob_dist = type(*new_shape).uniform_()
                prob_dist = prob_dist.transpose(1, 4)
                prob_dist = prob_dist[1, :, 5, 0, :, 0, 4]
                assert not prob_dist.is_contiguous()  # sanity check
                return prob_dist

        for is_contiguous in (True, False):
            # sample with replacement
            n_row = 3
            for n_col in range(4, 5 + 1):
                prob_dist = make_prob_dist([n_row, n_col], is_contiguous)
                # indices that shouldn't be sampled (<0 means none)
                zero_prob_indices = torch.LongTensor(n_row).random_(-2, n_col).tolist()
                for i, j in enumerate(zero_prob_indices):
                    if j >= 0:
                        prob_dist[i, j] = 0
                n_sample = n_col * 3
                sample_indices = torch.multinomial(prob_dist, n_sample, True)
                self.assertEqual(prob_dist.dim(), 2)
                self.assertEqual(sample_indices.size(1), n_sample)
                for i in range(n_row):
                    zero_prob_idx = zero_prob_indices[i]
                    if zero_prob_idx < 0:
                        continue
                    for j in range(n_sample):
                        self.assertNotEqual(sample_indices[i, j], zero_prob_idx,
                                            "sampled an index with zero probability")

            # sample without replacement
            n_row = 3
            for n_col in range(2, 10 + 1, 2):
                prob_dist = make_prob_dist([n_row, n_col], is_contiguous)
                # indices that shouldn't be sampled (<0 means none)
                zero_prob_indices = torch.LongTensor(n_row).random_(-1, n_col).tolist()
                for i, j in enumerate(zero_prob_indices):
                    if j >= 0:
                        prob_dist[i, j] = 0
                n_sample = max(1, n_col - 2)
                sample_indices = torch.multinomial(prob_dist, n_sample, False)
                self.assertEqual(prob_dist.dim(), 2)
                self.assertEqual(sample_indices.size(1), n_sample)
                for i in range(n_row):
                    row_samples = {}
                    zero_prob_idx = zero_prob_indices[i]
                    for j in range(n_sample):
                        sample_idx = sample_indices[i, j]
                        if zero_prob_idx >= 0:
                            self.assertNotEqual(sample_idx, zero_prob_idx,
                                                "sampled an index with zero probability")
                        self.assertNotIn(sample_idx, row_samples, "sampled an index twice")
                        row_samples[sample_idx] = True

            # vector input
            n_col = 4
            prob_dist = make_prob_dist([n_col], is_contiguous).fill_(1)
            zero_prob_idx = 1  # index that shouldn't be sampled
            prob_dist[zero_prob_idx] = 0
            n_sample = 20
            sample_indices = torch.multinomial(prob_dist, n_sample, True)
            for sample_index in sample_indices:
                self.assertNotEqual(sample_index, zero_prob_idx, "sampled an index with zero probability")
            s_dim = sample_indices.dim()
            self.assertEqual(sample_indices.dim(), 1, "wrong number of dimensions")
            self.assertEqual(prob_dist.dim(), 1, "wrong number of prob_dist dimensions")
            self.assertEqual(sample_indices.size(0), n_sample, "wrong number of samples")

    def test_multinomial(self):
        self._test_multinomial(self, torch.FloatTensor)
    def _spawn_method(self, method, arg):
        try:
            mp.set_start_method('spawn')
        except RuntimeError:
            pass
        with mp.Pool(1) as pool:
            self.assertTrue(pool.map(method, [arg]))

    @staticmethod
    def _test_multinomial_invalid_probs(probs):
        try:
            # n_sample = 1 is a special case; test n_sample = 2, which is the generic path
            torch.multinomial(probs.to('cpu'), 2)
            return False  # should not be reached
        except RuntimeError as e:
            return 'invalid multinomial distribution' in str(e)

    @unittest.skipIf(NO_MULTIPROCESSING_SPAWN, "Disabled for environments that "
                     "don't support multiprocessing with spawn start method")
    @unittest.skipIf(IS_WINDOWS, 'FIXME: CUDA OOM error on Windows')
    @unittest.skipIf(not PY3,
                     "spawn start method is not supported in Python 2, "
                     "but we need it for testing the failure case for CPU RNG on Windows")
    def test_multinomial_invalid_probs(self):
        test_method = _TestTorchMixin._test_multinomial_invalid_probs
        self._spawn_method(test_method, torch.Tensor([1, -1, 1]))
        self._spawn_method(test_method, torch.Tensor([1, inf, 1]))
        self._spawn_method(test_method, torch.Tensor([1, -inf, 1]))
        self._spawn_method(test_method, torch.Tensor([1, 1, nan]))
        self._spawn_method(test_method, torch.Tensor([0, 1, 0]))
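    # The invalid-probability cases run in a spawned worker process because
    # sampling from a malformed distribution can abort the process on some
    # builds; spawning isolates the failure so the main test process survives
    # and can inspect the returned flag.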
    def test_range(self):
        res1 = torch.range(0, 1)
        res2 = torch.Tensor()
        torch.range(0, 1, out=res2)
        self.assertEqual(res1, res2, 0)

        # Check range for non-contiguous tensors.
        x = torch.zeros(2, 3)
        torch.range(0, 3, out=x.narrow(1, 1, 2))
        res2 = torch.Tensor(((0, 0, 1), (0, 2, 3)))
        self.assertEqual(x, res2, 1e-16)

        # Check negative step
        res1 = torch.Tensor((1, 0))
        res2 = torch.Tensor()
        torch.range(1, 0, -1, out=res2)
        self.assertEqual(res1, res2, 0)

        # Equal bounds
        res1 = torch.ones(1)
        res2 = torch.Tensor()
        torch.range(1, 1, -1, out=res2)
        self.assertEqual(res1, res2, 0)
        torch.range(1, 1, 1, out=res2)
        self.assertEqual(res1, res2, 0)

        # FloatTensor
        res1 = torch.range(0.6, 0.9, 0.1, out=torch.FloatTensor())
        self.assertEqual(res1.size(0), 4)
        res1 = torch.range(1, 10, 0.3, out=torch.FloatTensor())
        self.assertEqual(res1.size(0), 31)

        # DoubleTensor
        res1 = torch.range(0.6, 0.9, 0.1, out=torch.DoubleTensor())
        self.assertEqual(res1.size(0), 4)
        res1 = torch.range(1, 10, 0.3, out=torch.DoubleTensor())
        self.assertEqual(res1.size(0), 31)

    def test_range_warning(self):
        with warnings.catch_warnings(record=True) as w:
            torch.range(0, 10)
            self.assertEqual(len(w), 1)
    def test_arange(self):
        res1 = torch.arange(0, 1)
        res2 = torch.Tensor()
        torch.arange(0, 1, out=res2)
        self.assertEqual(res1, res2, 0)

        # Check arange with only one argument
        res1 = torch.arange(10)
        res2 = torch.arange(0, 10)
        self.assertEqual(res1, res2, 0)

        # Check arange for non-contiguous tensors.
        x = torch.zeros(2, 3)
        torch.arange(0, 4, out=x.narrow(1, 1, 2))
        res2 = torch.Tensor(((0, 0, 1), (0, 2, 3)))
        self.assertEqual(x, res2, 1e-16)

        # Check negative step
        res1 = torch.Tensor((1, 0))
        res2 = torch.Tensor()
        torch.arange(1, -1, -1, out=res2)
        self.assertEqual(res1, res2, 0)

        # Equal bounds
        res1 = torch.ones(1)
        res2 = torch.Tensor()
        torch.arange(1, 0, -1, out=res2)
        self.assertEqual(res1, res2, 0)
        torch.arange(1, 2, 1, out=res2)
        self.assertEqual(res1, res2, 0)

        # FloatTensor
        res1 = torch.arange(0.6, 0.89, 0.1, out=torch.FloatTensor())
        self.assertEqual(res1, [0.6, 0.7, 0.8])
        res1 = torch.arange(1, 10, 0.3, out=torch.FloatTensor())
        self.assertEqual(res1.size(0), 30)
        self.assertEqual(res1[0], 1)
        self.assertEqual(res1[29], 9.7)

        # DoubleTensor
        res1 = torch.arange(0.6, 0.89, 0.1, out=torch.DoubleTensor())
        self.assertEqual(res1, [0.6, 0.7, 0.8])
        res1 = torch.arange(1, 10, 0.3, out=torch.DoubleTensor())
        self.assertEqual(res1.size(0), 30)
        self.assertEqual(res1[0], 1)
        self.assertEqual(res1[29], 9.7)

        # check default step
        r = torch.arange(0, 5)
        self.assertEqual(r.min(), 0)
        self.assertEqual(r.max(), 4)
        self.assertEqual(r.numel(), 5)

        r = torch.arange(0, 5, 2)
        self.assertEqual(r.min(), 0)
        self.assertEqual(r.max(), 4)
        self.assertEqual(r.numel(), 3)

        r1 = torch.arange(0, 5 + 1e-6)
        r2 = torch.arange(0, 5)
        r3 = torch.arange(0, 5 - 1e-6)
        self.assertEqual(r1[:-1], r2, 0)
        self.assertEqual(r2, r3, 0)

        r1 = torch.arange(10, -1 + 1e-6, -1)
        r2 = torch.arange(10, -1, -1)
        r3 = torch.arange(10, -1 - 1e-6, -1)
        self.assertEqual(r1, r2, 0)
        self.assertEqual(r2, r3[:-1], 0)

        msg = "unsupported range"
        self.assertRaisesRegex(RuntimeError, msg, lambda: torch.arange(0, float('inf')))
        self.assertRaisesRegex(RuntimeError, msg, lambda: torch.arange(float('inf')))

        devices = ['cpu']
        if torch.cuda.is_available():
            devices.append('cuda')
        for device in devices:
            self.assertRaisesRegex(RuntimeError, msg, lambda: torch.arange(-5, float('nan'), device=device))
            # check with step size
            self.assertRaisesRegex(RuntimeError, msg, lambda: torch.arange(0, float('-inf'), -1, device=device))
            self.assertRaisesRegex(RuntimeError, msg, lambda: torch.arange(0, float('inf'), device=device))
            self.assertRaisesRegex(RuntimeError, msg, lambda: torch.arange(float('-inf'), 10, device=device))
            self.assertRaisesRegex(RuntimeError, msg, lambda: torch.arange(float('nan'), 10, device=device))
            self.assertRaisesRegex(RuntimeError, msg, lambda: torch.arange(float('inf'), device=device))
            self.assertRaisesRegex(RuntimeError, msg, lambda: torch.arange(float('nan'), device=device))

            self.assertRaisesRegex(
                RuntimeError, "overflow",
                lambda: torch.arange(1.175494351e-38, 3.402823466e+38, device=device))
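    # The boundary checks above follow from arange's length formula,
    # numel = ceil((end - start) / step), evaluated in floating point: an end
    # of 5 + 1e-6 yields one extra element, while 5 - 1e-6 does not.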
    def test_arange_inference(self):
        saved_dtype = torch.get_default_dtype()
        torch.set_default_dtype(torch.float32)
        # end only
        self.assertIs(torch.float32, torch.arange(1.).dtype)
        self.assertIs(torch.float32, torch.arange(torch.tensor(1.)).dtype)
        self.assertIs(torch.float32, torch.arange(torch.tensor(1., dtype=torch.float64)).dtype)

        self.assertIs(torch.int64, torch.arange(1).dtype)
        self.assertIs(torch.int64, torch.arange(torch.tensor(1)).dtype)
        self.assertIs(torch.int64, torch.arange(torch.tensor(1, dtype=torch.int16)).dtype)

        # start, end
        self.assertIs(torch.float32, torch.arange(1., 3).dtype)
        self.assertIs(torch.float32, torch.arange(torch.tensor(1., dtype=torch.float64), 3).dtype)
        self.assertIs(torch.float32, torch.arange(1, 3.).dtype)
        # start, end, step
        self.assertIs(torch.float32, torch.arange(1, 3, 1.).dtype)
        self.assertIs(torch.float32,
                      torch.arange(torch.tensor(1),
                                   torch.tensor(3, dtype=torch.int16),
                                   torch.tensor(1., dtype=torch.float64)).dtype)

        self.assertIs(torch.int64, torch.arange(1, 3).dtype)
        self.assertIs(torch.int64, torch.arange(torch.tensor(1), 3).dtype)
        self.assertIs(torch.int64, torch.arange(1, 3, 1).dtype)
        self.assertIs(torch.int64,
                      torch.arange(torch.tensor(1),
                                   torch.tensor(3),
                                   torch.tensor(1, dtype=torch.int16)).dtype)
        torch.set_default_dtype(saved_dtype)
    def test_randint_inference(self):
        size = (2, 1)
        for args in [(3,), (1, 3)]:  # (low,) and (low, high)
            self.assertIs(torch.int64, torch.randint(*args, size=size).dtype)
            self.assertIs(torch.int64, torch.randint(*args, size=size, layout=torch.strided).dtype)
            self.assertIs(torch.int64, torch.randint(*args, size=size, generator=torch.default_generator).dtype)
            self.assertIs(torch.float32, torch.randint(*args, size=size, dtype=torch.float32).dtype)
            out = torch.empty(size, dtype=torch.float32)
            self.assertIs(torch.float32, torch.randint(*args, size=size, out=out).dtype)
            self.assertIs(torch.float32, torch.randint(*args, size=size, out=out, dtype=torch.float32).dtype)
            out = torch.empty(size, dtype=torch.int64)
            self.assertIs(torch.int64, torch.randint(*args, size=size, out=out).dtype)
            self.assertIs(torch.int64, torch.randint(*args, size=size, out=out, dtype=torch.int64).dtype)
    @staticmethod
    def _select_broadcastable_dims(dims_full=None):
        # select full dimensionality
        if dims_full is None:
            ndims = random.randint(1, 4)
            dims_full = [random.randint(1, 8) for _ in range(ndims)]
        else:
            ndims = len(dims_full)

        # select actual dimensions for ops:
        # larger: full ndims, individual sizes may be reduced
        # smaller: possibly reduced ndims, sizes may be reduced
        smaller_ndims = random.randint(1, ndims)
        dims_small = []
        dims_large = []
        for i in range(ndims - 1, -1, -1):
            j = random.randint(1, 3)
            if j == 1:  # no reduced singleton dimension
                ds = dims_full[i]
                dl = dims_full[i]
            elif j == 2:  # larger may have a reduced singleton dimension
                ds = dims_full[i]
                dl = 1 if len(dims_small) < smaller_ndims else dims_full[i]
            else:  # smaller may have a reduced singleton dimension
                ds = 1
                dl = dims_full[i]
            dims_large = [dl] + dims_large
            if len(dims_small) < smaller_ndims:
                dims_small = [ds] + dims_small
        return (dims_small, dims_large, dims_full)
    @staticmethod
    def _test_broadcast(self, cast):

        # all functions taking two tensors
        fns = [
            "dist", "atan2", "pow", "lerp", "add",
            "sub", "mul", "div", "fmod", "remainder",
            "eq", "ge", "gt", "le", "lt", "max", "min", "ne",
            "addcdiv", "addcmul", "masked_scatter", "masked_select", "masked_fill",
            "map", "map2", "copy",
        ]
        # functions with three tensor arguments
        fns_3_args = {"addcdiv", "addcmul", "map2"}

        for fn in fns:
            (dims_small, dims_large, dims_full) = self._select_broadcastable_dims()
            full1d = cast(torch.randn(*dims_full).flatten().float())
            small = cast(torch.randn(*dims_small).float())
            large = cast(torch.randn(*dims_large).float())
            small_expanded = small.expand(*dims_full)
            large_expanded = large.expand(*dims_full)
            small2 = None
            small2_expanded = None
            if fn in fns_3_args:
                # create another smaller tensor
                (dims_small2, _, _) = self._select_broadcastable_dims(dims_full)
                small2 = cast(torch.randn(*dims_small2).float())
                small2_expanded = small2.expand(*dims_full)
            if small.is_cuda and fn in ['map', 'map2']:
                # map and map2 are not implemented on CUDA tensors
                continue
            if hasattr(large_expanded, fn):
                # run through tensor versions of functions and verify fully
                # expanded inputs give the same results
                expanded = {large: large_expanded, small: small_expanded, small2: small2_expanded}

                def tensorfn(myfn, t1, t2):
                    if fn == "lerp":
                        return myfn(t1, 0.5)
                    elif fn == "masked_select":
                        return myfn(t1 < 0)
                    elif fn == "masked_scatter":
                        return myfn(t1 < 0.5, full1d)
                    elif fn == "masked_fill":
                        return myfn(t1 < 0.5, 1.0)
                    elif fn in fns_3_args:
                        return myfn(1, t1, t2)
                    else:
                        return myfn(t1)

                # test various orders
                for first, second, third in [(large, small, small2), (small, large, small2),
                                             (small2, small, large), (small2, large, small)]:
                    if first is None:
                        break  # ignore the iterations where small2 is unused
                    method_expanded = getattr(expanded[first], fn)
                    method = getattr(first, fn)
                    r1 = tensorfn(method_expanded, expanded[second], expanded[third])
                    r2 = tensorfn(method, second, third)
                    self.assertEqual(r1, r2)

            # now for torch.* versions of functions
            if hasattr(torch, fn):
                fntorch = getattr(torch, fn)
                expanded = {large: large_expanded, small: small_expanded, small2: small2_expanded}

                def torchfn(t1, t2, t3):
                    if fn == "lerp":
                        return fntorch(t1, t2, 0.5)
                    elif fn == "masked_select":
                        return fntorch(t1, t2 < 0)
                    elif fn == "masked_scatter":
                        return fntorch(t1, t2 < 0.5, full1d)
                    elif fn == "masked_fill":
                        return fntorch(t1, t2 < 0.5, 1.0)
                    elif fn in fns_3_args:
                        return fntorch(t1, 1.0, t2, t3)
                    else:
                        return fntorch(t1, t2)

                # test various orders
                for first, second, third in [(large, small, small2), (small, large, small2),
                                             (small2, small, large), (small2, large, small)]:
                    if first is None:
                        break  # ignore the iterations where small2 is unused
                    r1 = torchfn(expanded[first], expanded[second], expanded[third])
                    r2 = torchfn(first, second, third)
                    self.assertEqual(r1, r2)

            # now for in-place functions; the in-place tensor is not
            # broadcastable, so only the other argument(s) may broadcast
            if not hasattr(large_expanded, fn + "_"):
                continue

            # clone large_expanded so we can reuse it, since these functions are in-place
            large_expanded_clone = large_expanded.clone()

            def tensorfn_inplace(t0, t1, t2=None):
                t0_fn = getattr(t0, fn + "_")
                if fn == "lerp":
                    return t0_fn(t1, 0.5)
                elif fn == "masked_scatter":
                    return t0_fn(t1 < 0.5, full1d)
                elif fn == "masked_fill":
                    return t0_fn(t1 < 0.5, 1.0)
                elif fn == "map":
                    return t0_fn(t1, lambda x, y: x + y)
                elif fn == "map2":
                    return t0_fn(t1, t2, lambda x, y, z: x + y + z)
                elif fn in fns_3_args:
                    return t0_fn(1.0, t1, t2)
                else:
                    return t0_fn(t1)

            r1 = tensorfn_inplace(large_expanded, small_expanded, small2_expanded)
            r2 = tensorfn_inplace(large_expanded_clone, small, small2)
            # in-place pointwise operations don't actually work if the in-place
            # tensor is 0-strided (NumPy has the same issue)
            if (0 not in large_expanded.stride() and 0 not in large_expanded_clone.stride()):
                self.assertEqual(r1, r2)

            def broadcastable(t0, t1, t2=None):
                try:
                    t1.expand_as(t0)
                    if t2 is not None:
                        t2.expand_as(t0)
                except RuntimeError:
                    return False
                return True

            def _test_in_place_broadcastable(t0, t1, t2=None):
                if not broadcastable(t0, t1, t2):
                    same_size = t0.numel() == t1.numel() and (t0.numel() == t2.numel() if t2 is not None else True)
                    if not same_size:
                        self.assertRaises(RuntimeError, lambda: tensorfn_inplace(t0, t1, t2))
                else:
                    tensorfn_inplace(t0, t1, t2)

            if fn not in fns_3_args:
                _test_in_place_broadcastable(small, large_expanded)
                _test_in_place_broadcastable(small, large)
            else:
                _test_in_place_broadcastable(small2, small_expanded, large_expanded)
                _test_in_place_broadcastable(small2, small, large)
    def test_broadcast(self):
        self._test_broadcast(self, lambda t: t)

    def test_broadcast_empty(self):
        # empty + empty
        self.assertRaises(RuntimeError, lambda: torch.randn(5, 0) + torch.randn(0, 5))
        self.assertEqual(torch.randn(5, 0), torch.randn(0) + torch.randn(5, 0))
        self.assertEqual(torch.randn(5, 0, 0), torch.randn(0) + torch.randn(5, 0, 1))

        # scalar + empty
        self.assertEqual(torch.randn(5, 0, 6), torch.randn(()) + torch.randn(5, 0, 6))

        # non-empty + empty
        self.assertEqual(torch.randn(0), torch.randn(0) + torch.randn(1))
        self.assertEqual(torch.randn(0, 7, 0, 6, 5, 0, 7),
                         torch.randn(0, 7, 0, 6, 5, 0, 1) + torch.randn(1, 1, 5, 1, 7))
        self.assertRaises(RuntimeError, lambda: torch.randn(7, 0) + torch.randn(2, 1))
    def test_broadcast_tensors(self):
        x0 = torch.randn(2, 1, 3)
        x1 = torch.randn(3)
        x2 = torch.randn(3, 1)
        expected_size = (2, 3, 3)

        y0, y1, y2 = torch.broadcast_tensors(x0, x1, x2)
        self.assertTrue(y0.size() == expected_size)
        self.assertTrue(y1.size() == expected_size)
        self.assertTrue(y2.size() == expected_size)
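    # Broadcasting aligns shapes from the trailing dimension: two sizes are
    # compatible when they are equal or one of them is 1, and missing leading
    # dimensions count as 1. Hence the three inputs above all expand to the
    # common shape (2, 3, 3).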
    @staticmethod
    def _test_contiguous(self, cast):
        x = cast(torch.randn(1, 16, 5, 5))
        self.assertTrue(x.is_contiguous())
        stride = list(x.stride())
        stride[0] = 20
        # change the stride in dimension 0; the tensor is still contiguous
        # because size[0] is 1
        x.set_(x.storage(), 0, x.size(), stride)
        self.assertTrue(x.is_contiguous())

    def test_contiguous(self):
        return self._test_contiguous(self, lambda t: t)
    def test_empty_tensor_props(self):
        sizes = [(0,), (0, 3), (5, 0), (5, 0, 3, 0, 2), (0, 3, 0, 2), (0, 5, 0, 2, 0)]
        devices = ['cpu']
        if torch.cuda.is_available():
            devices.append('cuda')
        for size in sizes:
            for device in devices:
                x = torch.empty(tuple(size), device=device)
                self.assertEqual(size, x.shape)
                self.assertTrue(x.is_contiguous())
                size_ones_instead_of_zeros = (x if x != 0 else 1 for x in size)
                y = torch.empty(tuple(size_ones_instead_of_zeros), device=device)
                self.assertEqual(x.stride(), y.stride())
    def test_scalars_as_floats(self):
        "zero-dim variables that don't require grad should bind to scalar arguments"
        x = torch.tensor(2.)
        y = torch.tensor(3.)
        # 3 + (3 * 3) * 2
        self.assertEqual(y.addcmul(y, y, value=x), 21)

        x = torch.tensor(2., requires_grad=True)
        self.assertRaises(Exception, lambda: y.addcmul(y, y, value=x))
    @staticmethod
    def _test_broadcast_fused_matmul(self, cast):
        fns = ["baddbmm", "addbmm", "addmm", "addmv", "addr"]

        for fn in fns:
            batch_dim = random.randint(1, 8)
            n_dim = random.randint(1, 8)
            m_dim = random.randint(1, 8)
            p_dim = random.randint(1, 8)

            def dims_full_for_fn():
                if fn == "baddbmm":
                    return ([batch_dim, n_dim, p_dim], [batch_dim, n_dim, m_dim], [batch_dim, m_dim, p_dim])
                elif fn == "addbmm":
                    return ([n_dim, p_dim], [batch_dim, n_dim, m_dim], [batch_dim, m_dim, p_dim])
                elif fn == "addmm":
                    return ([n_dim, p_dim], [n_dim, m_dim], [m_dim, p_dim])
                elif fn == "addmv":
                    return ([n_dim], [n_dim, m_dim], [m_dim])
                elif fn == "addr":
                    return ([n_dim, m_dim], [n_dim], [m_dim])
                else:
                    raise AssertionError("unknown function")

            (t0_dims_full, t1_dims, t2_dims) = dims_full_for_fn()
            (t0_dims_small, _, _) = self._select_broadcastable_dims(t0_dims_full)

            t0_small = cast(torch.randn(*t0_dims_small).float())
            t1 = cast(torch.randn(*t1_dims).float())
            t2 = cast(torch.randn(*t2_dims).float())

            t0_full = cast(t0_small.expand(*t0_dims_full))

            fntorch = getattr(torch, fn)
            r0 = fntorch(t0_small, t1, t2)
            r1 = fntorch(t0_full, t1, t2)
            self.assertEqual(r0, r1)

    def test_broadcast_fused_matmul(self):
        self._test_broadcast_fused_matmul(self, lambda t: t)
    @staticmethod
    def _test_broadcast_batched_matmul(self, cast):
        n_dim = random.randint(1, 8)
        m_dim = random.randint(1, 8)
        p_dim = random.randint(1, 8)
        full_batch_dims = [random.randint(1, 3) for i in range(random.randint(1, 3))]
        (batch_dims_small, _, _) = self._select_broadcastable_dims(full_batch_dims)

        def verify_batched_matmul(full_lhs, one_dimensional):
            if not one_dimensional:
                lhs_dims = [n_dim, m_dim]
                rhs_dims = [m_dim, p_dim]
                result_dims = [n_dim, p_dim]
            else:
                lhs_dims = [n_dim, m_dim] if full_lhs else [m_dim]
                rhs_dims = [m_dim, p_dim] if not full_lhs else [m_dim]
                result_dims = [n_dim] if full_lhs else [p_dim]

            lhs_mat_dims = lhs_dims if len(lhs_dims) != 1 else [1, m_dim]
            rhs_mat_dims = rhs_dims if len(rhs_dims) != 1 else [m_dim, 1]
            full_mat_dims = lhs_mat_dims if full_lhs else rhs_mat_dims
            dim0_dims = rhs_dims if full_lhs else lhs_dims
            small_dims = batch_dims_small + (rhs_mat_dims if full_lhs else lhs_mat_dims)

            small = cast(torch.randn(*(small_dims)).float())
            dim0 = cast(torch.randn(*(dim0_dims)).float())
            full = cast(torch.randn(*(full_batch_dims + full_mat_dims)).float())
            if not one_dimensional:
                (lhsTensors, rhsTensors) = ((full,), (small, dim0)) if full_lhs else ((small, dim0), (full,))
            else:
                (lhsTensors, rhsTensors) = ((full,), (dim0,)) if full_lhs else ((dim0,), (full,))

            def maybe_squeeze_result(l, r, result):
                if len(lhs_dims) == 1 and l.dim() != 1:
                    return result.squeeze(-2)
                elif len(rhs_dims) == 1 and r.dim() != 1:
                    return result.squeeze(-1)
                else:
                    return result

            for lhs in lhsTensors:
                lhs_expanded = lhs.expand(*(torch.Size(full_batch_dims) + torch.Size(lhs_mat_dims)))
                lhs_expanded_matmul_fn = getattr(lhs_expanded, "matmul")
                for rhs in rhsTensors:
                    rhs_expanded = ((rhs if len(rhs_dims) != 1 else rhs.unsqueeze(-1)).
                                    expand(*(torch.Size(full_batch_dims) + torch.Size(rhs_mat_dims))))
                    truth = maybe_squeeze_result(lhs_expanded, rhs_expanded, lhs_expanded_matmul_fn(rhs_expanded))
                    for l in (lhs, lhs_expanded):
                        for r in (rhs, rhs_expanded):
                            l_matmul_fn = getattr(l, "matmul")
                            result = maybe_squeeze_result(l, r, l_matmul_fn(r))
                            self.assertEqual(truth, result)
                            # test torch.matmul function as well
                            torch_result = maybe_squeeze_result(l, r, torch.matmul(l, r))
                            self.assertEqual(truth, torch_result)
                            # test torch.matmul with out
                            out = torch.zeros_like(torch_result)
                            torch.matmul(l, r, out=out)
                            self.assertEqual(truth, maybe_squeeze_result(l, r, out))

                # compare to bmm
                bmm_result = (torch.bmm(lhs_expanded.contiguous().view(-1, *lhs_mat_dims),
                                        rhs_expanded.contiguous().view(-1, *rhs_mat_dims)))
                self.assertEqual(truth.view(-1, *result_dims), bmm_result.view(-1, *result_dims))

        for indices in product((True, False), repeat=2):
            verify_batched_matmul(*indices)

    def test_broadcast_batched_matmul(self):
        self._test_broadcast_batched_matmul(self, lambda t: t)
    def test_copy_broadcast(self):
        torch.zeros(5, 6).copy_(torch.zeros(6))
        self.assertRaises(RuntimeError, lambda: torch.zeros(5, 6).copy_(torch.zeros(30)))
    def test_randperm(self):
        _RNGState = torch.get_rng_state()
        res1 = torch.randperm(100)
        res2 = torch.LongTensor()
        torch.set_rng_state(_RNGState)
        torch.randperm(100, out=res2)
        self.assertEqual(res1, res2, 0)

        # randperm of 0 elements is an empty tensor
        res1 = torch.randperm(0)
        res2 = torch.LongTensor(5)
        torch.randperm(0, out=res2)
        self.assertEqual(res1.numel(), 0)
        self.assertEqual(res2.numel(), 0)
    def test_random(self):
        t = torch.FloatTensor(200)
        lb = 1
        ub = 4

        t.fill_(-1)
        t.random_(lb, ub)
        self.assertEqual(t.min(), lb)
        self.assertEqual(t.max(), ub - 1)

        t.fill_(-1)
        t.random_(ub)
        self.assertEqual(t.min(), 0)
        self.assertEqual(t.max(), ub - 1)
    @staticmethod
    def _test_random_neg_values(self, use_cuda=False):
        signed_types = ['torch.DoubleTensor', 'torch.FloatTensor', 'torch.LongTensor',
                        'torch.IntTensor', 'torch.ShortTensor']
        for tname in signed_types:
            res = torch.rand(SIZE, SIZE).type(tname)
            if use_cuda:
                res = res.cuda()
            res.random_(-10, -1)
            self.assertLessEqual(res.max().item(), 9)
            self.assertGreaterEqual(res.min().item(), -10)

    def test_random_neg_values(self):
        self._test_random_neg_values(self)
    def assertIsOrdered(self, order, x, mxx, ixx, task):
        if order == 'descending':
            def check_order(a, b):
                # `a != a` because we put NaNs at the end of ascending sorted
                # lists and at the beginning of descending ones
                return a != a or a >= b
        elif order == 'ascending':
            def check_order(a, b):
                # see above
                return b != b or a <= b
        else:
            error('unknown order "{}", must be "ascending" or "descending"'.format(order))

        for j, k in product(range(SIZE), range(1, SIZE)):
            self.assertTrue(check_order(mxx[j][k - 1], mxx[j][k]),
                            'torch.sort ({}) values unordered for {}'.format(order, task))

        seen = set()
        indicesCorrect = True
        size = x.size(x.dim() - 1)
        for k in range(size):
            seen.clear()
            for j in range(size):
                self.assertEqual(x[k][ixx[k][j]], mxx[k][j],
                                 'torch.sort ({}) indices wrong for {}'.format(order, task))
                seen.add(ixx[k][j])
            self.assertEqual(len(seen), size)
    def test_sort(self):
        x = torch.rand(SIZE, SIZE)
        res1val, res1ind = torch.sort(x)

        # test use of result tensors
        res2val = torch.Tensor()
        res2ind = torch.LongTensor()
        torch.sort(x, out=(res2val, res2ind))
        self.assertEqual(res1val, res2val, 0)
        self.assertEqual(res1ind, res2ind, 0)
        self.assertEqual(torch.argsort(x), res1ind)
        self.assertEqual(x.argsort(), res1ind)

        # test sorting of random numbers
        self.assertIsOrdered('ascending', x, res2val, res2ind, 'random')

        # test simple sort
        self.assertEqual(
            torch.sort(torch.Tensor((50, 40, 30, 20, 10)))[0],
            torch.Tensor((10, 20, 30, 40, 50)),
            0
        )

        # test that we still have proper sorting with duplicate keys
        x = torch.floor(torch.rand(SIZE, SIZE) * 10)
        torch.sort(x, out=(res2val, res2ind))
        self.assertIsOrdered('ascending', x, res2val, res2ind, 'random with duplicate keys')

        # DESCENDING SORT
        x = torch.rand(SIZE, SIZE)
        res1val, res1ind = torch.sort(x, x.dim() - 1, True)

        # test use of result tensors
        res2val = torch.Tensor()
        res2ind = torch.LongTensor()
        torch.sort(x, x.dim() - 1, True, out=(res2val, res2ind))
        self.assertEqual(res1val, res2val, 0)
        self.assertEqual(res1ind, res2ind, 0)
        self.assertEqual(torch.argsort(x, x.dim() - 1, True), res1ind)
        self.assertEqual(x.argsort(x.dim() - 1, True), res1ind)

        # test sorting of random numbers
        self.assertIsOrdered('descending', x, res2val, res2ind, 'random')

        # test simple sort task
        self.assertEqual(
            torch.sort(torch.Tensor((10, 20, 30, 40, 50)), 0, True)[0],
            torch.Tensor((50, 40, 30, 20, 10)),
            0
        )

        # test that we still have proper sorting with duplicate keys
        self.assertIsOrdered('descending', x, res2val, res2ind, 'random with duplicate keys')

        # test sorting with NaNs
        x = torch.rand(SIZE, SIZE)
        x[1][2] = float('NaN')
        x[3][0] = float('NaN')
        torch.sort(x, out=(res2val, res2ind))
        self.assertIsOrdered('ascending', x, res2val, res2ind, 'random with NaNs')
        torch.sort(x, out=(res2val, res2ind), descending=True)
        self.assertIsOrdered('descending', x, res2val, res2ind, 'random with NaNs')

    @unittest.skipIf(not TEST_NUMPY, 'Numpy not found')
    def test_tensordot(self):
        devices = ['cpu'] if not torch.cuda.is_available() else ['cpu', 'cuda']
        for d in devices:
            a = torch.arange(60., device=d).reshape(3, 4, 5)
            b = torch.arange(24., device=d).reshape(4, 3, 2)
            c = torch.tensordot(a, b, dims=([1, 0], [0, 1])).cpu()
            cn = torch.from_numpy(np.tensordot(a.cpu().numpy(), b.cpu().numpy(),
                                               axes=([1, 0], [0, 1])))
            self.assertEqual(c, cn)
            a = torch.randn(2, 3, 4, 5, device=d)
            b = torch.randn(4, 5, 6, 7, device=d)
            c = torch.tensordot(a, b, dims=2).cpu()
            cn = torch.from_numpy(np.tensordot(a.cpu().numpy(), b.cpu().numpy(),
                                               axes=2))
            self.assertEqual(c, cn)
            c = torch.tensordot(a, b).cpu()
            cn = torch.from_numpy(np.tensordot(a.cpu().numpy(), b.cpu().numpy()))
            self.assertEqual(c, cn)
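
    # Added illustration (not part of the original suite; assumes standard
    # torch.tensordot semantics): dims=([1, 2], [1, 0]) contracts dims 1 and 2
    # of the first argument against dims 1 and 0 of the second, while an
    # integer dims=k contracts the last k dims of a with the first k of b.
    def _example_tensordot(self):
        a = torch.randn(3, 4, 5)
        b = torch.randn(5, 4, 6)
        c = torch.tensordot(a, b, dims=([1, 2], [1, 0]))
        self.assertEqual(c.shape, torch.Size([3, 6]))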

    def test_topk(self):
        def topKViaSort(t, k, dim, dir):
            sorted, indices = t.sort(dim, dir)
            return sorted.narrow(dim, 0, k), indices.narrow(dim, 0, k)

        def compareTensors(t, res1, ind1, res2, ind2, dim):
            # values should be exactly equivalent
            self.assertEqual(res1, res2, 0)

            # indices might differ based on the implementation, since there is
            # no guarantee of the relative order of selection
            if not ind1.eq(ind2).all():
                # to verify that the indices represent equivalent elements,
                # gather from the input using the topk indices and compare
                # against the sort values
                vals = t.gather(dim, ind2)
                self.assertEqual(res1, vals, 0)

        def compare(t, k, dim, dir):
            topKVal, topKInd = t.topk(k, dim, dir, True)
            sortKVal, sortKInd = topKViaSort(t, k, dim, dir)
            compareTensors(t, sortKVal, sortKInd, topKVal, topKInd, dim)

        t = torch.rand(random.randint(1, SIZE),
                       random.randint(1, SIZE),
                       random.randint(1, SIZE))

        for _kTries in range(3):
            for _dimTries in range(3):
                for transpose in (True, False):
                    for dir in (True, False):
                        testTensor = t
                        if transpose:
                            dim1 = random.randrange(t.ndimension())
                            dim2 = dim1
                            while dim1 == dim2:
                                dim2 = random.randrange(t.ndimension())
                            testTensor = t.transpose(dim1, dim2)

                        dim = random.randrange(testTensor.ndimension())
                        k = random.randint(1, testTensor.size(dim))
                        compare(testTensor, k, dim, dir)

    def test_topk_arguments(self):
        q = torch.randn(10, 2, 10)
        # make sure True isn't mistakenly taken as the 2nd dimension (interpreted as 1)
        self.assertRaises(TypeError, lambda: q.topk(4, True))

    @unittest.skipIf(not torch.cuda.is_available(), 'no CUDA')
    def test_topk_noncontiguous_gpu(self):
        t = torch.randn(20, device="cuda")[::2]
        top1, idx1 = t.topk(5)
        top2, idx2 = t.contiguous().topk(5)
        self.assertEqual(top1, top2)
        self.assertEqual(idx1, idx2)
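
    # Added illustration (not part of the original suite; assumes standard
    # torch.topk semantics): topk(k) returns the k largest values (or the k
    # smallest with largest=False) plus their indices, sorted by default.
    def _example_topk(self):
        values, indices = torch.tensor([1., 3., 2.]).topk(2)
        self.assertEqual(values, torch.tensor([3., 2.]))
        self.assertEqual(indices, torch.tensor([1, 2]))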

    @staticmethod
    def _test_kthvalue(self, device='cpu'):
        SIZE = 50
        x = torch.rand(SIZE, SIZE, SIZE, device=device)
        x0 = x.clone()

        k = random.randint(1, SIZE)
        res1val, res1ind = torch.kthvalue(x, k, keepdim=False)
        res2val, res2ind = torch.sort(x)

        self.assertEqual(res1val[:, :], res2val[:, :, k - 1], 0)
        self.assertEqual(res1ind[:, :], res2ind[:, :, k - 1], 0)

        # test use of result tensors
        k = random.randint(1, SIZE)
        res1val = torch.tensor([], device=device)
        res1ind = torch.tensor([], dtype=torch.long, device=device)
        torch.kthvalue(x, k, keepdim=False, out=(res1val, res1ind))
        res2val, res2ind = torch.sort(x)
        self.assertEqual(res1val[:, :], res2val[:, :, k - 1], 0)
        self.assertEqual(res1ind[:, :], res2ind[:, :, k - 1], 0)

        # test non-default dim
        k = random.randint(1, SIZE)
        res1val, res1ind = torch.kthvalue(x, k, 0, keepdim=False)
        res2val, res2ind = torch.sort(x, 0)
        self.assertEqual(res1val, res2val[k - 1], 0)
        self.assertEqual(res1ind, res2ind[k - 1], 0)

        # non-contiguous
        y = x.narrow(1, 0, 1)
        y0 = y.contiguous()
        k = random.randint(1, SIZE)
        res1val, res1ind = torch.kthvalue(y, k)
        res2val, res2ind = torch.kthvalue(y0, k)
        self.assertEqual(res1val, res2val, 0)
        self.assertEqual(res1ind, res2ind, 0)

        # check that the input wasn't modified
        self.assertEqual(x, x0, 0)

        # simple test case (with repetitions)
        y = torch.tensor((3., 5, 4, 1, 1, 5), device=device)
        self.assertEqual(torch.kthvalue(y, 3)[0], 3, 0)
        self.assertEqual(torch.kthvalue(y, 2)[0], 1, 0)

        # simple test case (with NaN)
        x = torch.rand(SIZE, SIZE, SIZE, device=device)
        x[torch.arange(SIZE), :, torch.randint(50, (50,))] = nan
        ks = [random.randint(1, SIZE), 1, SIZE, SIZE - 1]
        res2val, res2ind = torch.sort(x)
        for k in ks:
            res1val, res1ind = torch.kthvalue(x, k, keepdim=False)
            self.assertEqual(res1val[:, :], res2val[:, :, k - 1], 0)
            self.assertEqual(res1ind[:, :], res2ind[:, :, k - 1], 0)

    def test_kthvalue(self):
        self._test_kthvalue(self)
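
    # Added illustration (not part of the original suite; assumes standard
    # torch.kthvalue semantics, where k is 1-based and selects the k-th
    # smallest element).
    def _example_kthvalue(self):
        value, index = torch.kthvalue(torch.tensor([1., 5., 3.]), 2)
        self.assertEqual(value.item(), 3.)
        self.assertEqual(index.item(), 2)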

    def test_median(self):
        for size in (155, 156):
            x = torch.rand(size, size)
            x0 = x.clone()

            nelem = x.nelement()
            res1val = torch.median(x)
            res2val, _ = torch.sort(x.view(nelem))
            ind = int(math.floor((nelem + 1) / 2) - 1)

            self.assertEqual(res2val[ind], res1val, 0)

            res1val, res1ind = torch.median(x, dim=1, keepdim=False)
            res2val, res2ind = torch.sort(x)
            ind = int(math.floor((size + 1) / 2) - 1)

            self.assertEqual(res2val.select(1, ind), res1val, 0)
            self.assertEqual(res2ind.select(1, ind), res1ind, 0)

            # test use of result tensor
            res2val = torch.Tensor()
            res2ind = torch.LongTensor()
            torch.median(x, dim=-1, keepdim=False, out=(res2val, res2ind))
            self.assertEqual(res2val, res1val, 0)
            self.assertEqual(res2ind, res1ind, 0)

            # test non-default dim
            res1val, res1ind = torch.median(x, 0, keepdim=False)
            res2val, res2ind = torch.sort(x, 0)
            self.assertEqual(res1val, res2val[ind], 0)
            self.assertEqual(res1ind, res2ind[ind], 0)

            # input unchanged
            self.assertEqual(x, x0, 0)
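
    # Added illustration (not part of the original suite; assumes standard
    # torch.median behavior): for an even number of elements the lower of the
    # two middle values is returned, not their mean.
    def _example_median(self):
        self.assertEqual(torch.median(torch.tensor([1., 2., 3., 4.])).item(), 2.)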

    def test_mode(self):
        x = torch.arange(1., SIZE * SIZE + 1).clone().resize_(SIZE, SIZE)
        x[:2] = 1
        x[:, :2] = 1
        x0 = x.clone()

        # pre-calculated results
        res1val = torch.Tensor(SIZE).fill_(1)
        # the indices are the position of the last appearance of the mode element
        res1ind = torch.LongTensor(SIZE).fill_(1)
        res1ind[0] = SIZE - 1
        res1ind[1] = SIZE - 1

        res2val, res2ind = torch.mode(x, keepdim=False)
        self.assertEqual(res1val, res2val, 0)
        self.assertEqual(res1ind, res2ind, 0)

        # test use of result tensor
        res2val = torch.Tensor()
        res2ind = torch.LongTensor()
        torch.mode(x, keepdim=False, out=(res2val, res2ind))
        self.assertEqual(res1val, res2val, 0)
        self.assertEqual(res1ind, res2ind, 0)

        # test non-default dim
        res2val, res2ind = torch.mode(x, 0, False)
        self.assertEqual(res1val, res2val, 0)
        self.assertEqual(res1ind, res2ind, 0)

        # input unchanged
        self.assertEqual(x, x0, 0)
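
    # Added illustration (not part of the original suite; assumes standard
    # torch.mode behavior): mode returns the most frequent value along a
    # dimension (the last one by default) together with an index at which it
    # occurs.
    def _example_mode(self):
        values, indices = torch.tensor([[0, 2, 2, 1]]).mode()
        self.assertEqual(values, torch.tensor([2]))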

    def test_trilu_indices(self):
        for test_args in tri_tests_args:
            _compare_trilu_indices(self, *test_args)
        run_additional_tri_tests(self, 'cpu')

        # test default options
        x = torch.ones(
            3, 3, dtype=torch.long, device='cpu', layout=torch.strided)
        self.assertEqual(
            x.tril(0).nonzero().transpose(0, 1), torch.tril_indices(3, 3))
        self.assertEqual(
            x.triu(0).nonzero().transpose(0, 1), torch.triu_indices(3, 3))

    @staticmethod
    def _test_triu_tril(self, cast):
        def gen_mask(shape, diagonal, cast, upper):
            mask = torch.zeros(*shape[-2:]).byte()
            for i in range(shape[-2]):
                for j in range(shape[-1]):
                    cond = j - i < diagonal if upper else j - i > diagonal
                    if cond:
                        mask[i, j] = 1
            return cast(mask.expand(*shape))

        torch_functions = {True: torch.triu, False: torch.tril}
        if TEST_NUMPY:
            numpy_functions = {True: np.triu, False: np.tril}

        def run_test(shape, cast, diagonal):
            x_cpu = torch.randn(*shape)
            x = cast(x_cpu)

            for upper in [True, False]:
                # normal test with mask
                torch_tri_func = torch_functions[upper]
                res1 = torch_tri_func(x, diagonal=diagonal)
                res2 = cast(torch.Tensor())
                torch_tri_func(x, diagonal=diagonal, out=res2)
                exp_mask = gen_mask(shape, diagonal, cast, upper)
                expected = torch.where(exp_mask, torch.tensor(0).type_as(x), x)
                self.assertEqual(res1, res2, 0)
                self.assertEqual(expected, res1, 0)

                # non-contiguous and expanded tensors test
                if not (0 in shape or 1 in shape):
                    for s in range(-len(shape), -1):
                        # non-contiguous tensors
                        x_nc = x.clone().transpose(s, s + 1)
                        exp_mask = gen_mask(x_nc.size(), diagonal, cast, upper)
                        assert not x_nc.is_contiguous(), "x is intentionally non-contiguous"
                        exp_nc = torch.where(exp_mask, torch.tensor(0).type_as(x), x_nc)
                        self.assertEqual(torch_tri_func(x_nc, diagonal), exp_nc, 0)
                        x_nc_is_contiguous = x_nc.is_contiguous()
                        if upper:
                            self.assertEqual(x_nc.triu_(diagonal), exp_nc, 0)
                        else:
                            self.assertEqual(x_nc.tril_(diagonal), exp_nc, 0)
                        self.assertTrue(x_nc.is_contiguous() == x_nc_is_contiguous,
                                        "contiguity of x_nc should not be changed")

                    # expanded tensors
                    expanded_size = (x.size(0),) + x.size()
                    x_expanded = x.clone().expand(*expanded_size)
                    assert 0 in x_expanded.stride(), "x intentionally has 0 in its stride"
                    output = torch_tri_func(x_expanded, diagonal)
                    self.assertEqual(output, expected.expand(expanded_size), 0)
                    self.assertTrue(0 in x_expanded.stride(),
                                    "geometry of x_expanded should be the same")
                    if upper:
                        self.assertEqual(output, x_expanded.triu_(diagonal), 0)
                    else:
                        self.assertEqual(output, x_expanded.tril_(diagonal), 0)

                if not TEST_NUMPY:
                    continue

                # numpy test
                numpy_tri_func = numpy_functions[upper]
                self.assertEqual(numpy_tri_func(x_cpu.numpy(), diagonal), res1.cpu().numpy())

        diagonals = [-2, -1, 0, 1, 2]
        shapes = [(3, 3), (5, 3, 3), (7, 5, 3, 3),
                  (7, 3), (5, 7, 3), (7, 5, 7, 3),
                  (3, 7), (5, 3, 7), (7, 5, 3, 7),
                  (3, 0), (0, 3, 3), (3, 3, 0, 0),
                  (3, 1), (5, 3, 1), (7, 5, 3, 1),
                  (1, 3), (5, 1, 3), (7, 5, 1, 3)]
        for s, d in product(shapes, diagonals):
            run_test(s, cast, d)

    def test_triu_tril(self):
        self._test_triu_tril(self, lambda t: t)

    def test_cat(self):
        for dtype in (torch.half, torch.double, torch.int):
            for dim in range(-3, 3):
                pos_dim = dim if dim >= 0 else 3 + dim
                x = torch.randint(low=-100, high=100, size=(13, SIZE, SIZE)).to(dtype).transpose(0, pos_dim)
                y = torch.randint(low=-100, high=100, size=(17, SIZE, SIZE)).to(dtype).transpose(0, pos_dim)
                z = torch.randint(low=-100, high=100, size=(19, SIZE, SIZE)).to(dtype).transpose(0, pos_dim)

                res1 = torch.cat((x, y, z), dim)
                self.assertEqual(res1.narrow(pos_dim, 0, 13), x, 0)
                self.assertEqual(res1.narrow(pos_dim, 13, 17), y, 0)
                self.assertEqual(res1.narrow(pos_dim, 30, 19), z, 0)

            x = torch.randint(low=-100, high=100, size=(20, SIZE, SIZE)).to(dtype)
            self.assertEqual(torch.cat(torch.split(x, 7)), x)
            self.assertEqual(torch.cat(torch.chunk(x, 7)), x)

            y = torch.randint(low=-100, high=100, size=(1, SIZE, SIZE)).to(dtype)
            z = torch.cat([x, y])
            self.assertEqual(z.size(), (21, SIZE, SIZE))

            self.assertRaises(RuntimeError, lambda: torch.cat([]))
            self.assertRaisesRegex(TypeError, 'got None', lambda: torch.cat([x, None]))
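
    # Added illustration (not part of the original suite; assumes standard
    # semantics): cat joins tensors along an existing dimension, while stack
    # (tested further below) inserts a new one.
    def _example_cat_vs_stack(self):
        x = torch.zeros(2, 3)
        self.assertEqual(torch.cat([x, x], dim=0).shape, torch.Size([4, 3]))
        self.assertEqual(torch.stack([x, x], dim=0).shape, torch.Size([2, 2, 3]))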

    def test_cat_bad_input_sizes(self):
        x = torch.randn(2, 1)
        y = torch.randn(2, 1, 1)
        z = torch.randn(2, 1, 1)
        self.assertRaises(RuntimeError, lambda: torch.cat([x, y, z]))

        x = torch.randn(2, 1, 2)
        y = torch.randn(2, 1, 1)
        z = torch.randn(2, 2, 1)
        self.assertRaises(RuntimeError, lambda: torch.cat([x, y, z], dim=1))

    def test_cat_scalars(self):
        x = torch.tensor(0)
        y = torch.tensor(1)
        with self.assertRaisesRegex(RuntimeError, 'zero-dimensional.*cannot be concatenated'):
            torch.cat([x, y])

    @staticmethod
    def _test_cat_empty_legacy(self, use_cuda=False):
        # legacy behavior: a 1-D empty tensor can be concatenated with anything
        dtype = torch.float32
        device = 'cuda' if use_cuda else 'cpu'

        x = torch.randn((4, 3, 32, 32), dtype=dtype, device=device)
        empty = torch.randn((0,), dtype=dtype, device=device)

        res1 = torch.cat([x, empty], dim=1)
        res2 = torch.cat([empty, x], dim=1)
        self.assertEqual(res1, res2)

        conv = torch.nn.Conv2d(3, 3, kernel_size=1).float()
        if use_cuda:
            conv = conv.cuda()
        res1 = torch.cat([conv(x), empty], dim=1)
        res2 = torch.cat([empty, conv(x)], dim=1)
        self.assertEqual(res1, res2)

        res1 = torch.cat([empty, empty], dim=1)
        self.assertEqual(res1, empty)

        with self.assertRaisesRegex(RuntimeError,
                                    'expected a non-empty list of Tensors'):
            torch.cat([], dim=1)

    def test_cat_empty_legacy(self):
        self._test_cat_empty_legacy(self)

    @staticmethod
    def _test_cat_empty(self, use_cuda=False):
        dtype = torch.float32
        device = 'cuda' if use_cuda else 'cpu'

        x = torch.randn((4, 3, 32, 32), dtype=dtype, device=device)
        empty = torch.randn((4, 0, 32, 32), dtype=dtype, device=device)

        res1 = torch.cat([x, empty], dim=1)
        res2 = torch.cat([empty, x], dim=1)
        self.assertEqual(res1, res2)

        conv = torch.nn.Conv2d(3, 3, kernel_size=1).float()
        if use_cuda:
            conv = conv.cuda()
        res1 = torch.cat([conv(x), empty], dim=1)
        res2 = torch.cat([empty, conv(x)], dim=1)
        self.assertEqual(res1, res2)

        res1 = torch.cat([empty, empty], dim=1)
        self.assertEqual(res1, empty)

        # non-matching sizes must be rejected
        empty = torch.randn((4, 0, 31, 32), dtype=dtype, device=device)
        self.assertRaises(RuntimeError, lambda: torch.cat([x, empty], dim=1))
        self.assertRaises(RuntimeError, lambda: torch.cat([empty, x], dim=1))

        # non-matching dimensionality must be rejected
        empty = torch.randn((4, 0), dtype=dtype, device=device)
        self.assertRaises(RuntimeError, lambda: torch.cat([x, empty], dim=1))
        self.assertRaises(RuntimeError, lambda: torch.cat([empty, x], dim=1))

    def test_cat_empty(self):
        self._test_cat_empty(self)

    def test_narrow(self):
        x = torch.Tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8]])
        self.assertEqual(x.narrow(0, 0, 1), torch.Tensor([[0, 1, 2]]))
        self.assertEqual(x.narrow(0, 0, 2), torch.Tensor([[0, 1, 2], [3, 4, 5]]))
        self.assertEqual(x.narrow(0, 1, 1), torch.Tensor([[3, 4, 5]]))
        self.assertEqual(x.narrow(0, -1, 1), torch.Tensor([[6, 7, 8]]))
        self.assertEqual(x.narrow(0, -2, 2), torch.Tensor([[3, 4, 5], [6, 7, 8]]))
        self.assertEqual(x.narrow(0, -3, 3), torch.Tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8]]))
        self.assertEqual(x.narrow(-1, -1, 1), torch.Tensor([[2], [5], [8]]))
        self.assertEqual(x.narrow(-2, -1, 1), torch.Tensor([[6, 7, 8]]))

    def test_narrow_empty(self):
        devices = ['cpu'] if not torch.cuda.is_available() else ['cpu', 'cuda']
        for device in devices:
            x = torch.randn(2, 3, 4, device=device)
            for d in range(x.dim()):
                y = x.narrow(d, x.size(d), 0)
                sz = list(x.size())
                sz[d] = 0
                self.assertEqual(sz, y.size())

    def test_stack(self):
        for dtype in (torch.half, torch.double, torch.int):
            x = torch.randint(low=-100, high=100, size=(2, 3, 4)).to(dtype)
            y = torch.randint(low=-100, high=100, size=(2, 3, 4)).to(dtype)
            z = torch.randint(low=-100, high=100, size=(2, 3, 4)).to(dtype)
            for dim in range(4):
                res = torch.stack((x, y, z), dim)
                res_neg = torch.stack((x, y, z), dim - 4)
                expected_size = x.size()[:dim] + (3,) + x.size()[dim:]
                self.assertEqual(res, res_neg)
                self.assertEqual(res.size(), expected_size)
                self.assertEqual(res.select(dim, 0), x, 0)
                self.assertEqual(res.select(dim, 1), y, 0)
                self.assertEqual(res.select(dim, 2), z, 0)

    def test_stack_out(self):
        for dtype in (torch.half, torch.double, torch.int):
            x = torch.randint(low=-100, high=100, size=(2, 3, 4)).to(dtype)
            y = torch.randint(low=-100, high=100, size=(2, 3, 4)).to(dtype)
            z = torch.randint(low=-100, high=100, size=(2, 3, 4)).to(dtype)
            for dim in range(4):
                expected_size = x.size()[:dim] + (3,) + x.size()[dim:]
                res_out = x.new(expected_size)
                res_neg_out = x.new(expected_size)
                res_out_dp = res_out.data_ptr()
                res_out_neg_dp = res_neg_out.data_ptr()
                torch.stack((x, y, z), dim, out=res_out)
                torch.stack((x, y, z), dim - 4, out=res_neg_out)
                self.assertEqual(res_out, res_neg_out)
                self.assertEqual(res_out.size(), expected_size)
                self.assertEqual(res_out_dp, res_out.data_ptr())
                self.assertEqual(res_out_neg_dp, res_neg_out.data_ptr())
                self.assertEqual(res_out.select(dim, 0), x, 0)
                self.assertEqual(res_out.select(dim, 1), y, 0)
                self.assertEqual(res_out.select(dim, 2), z, 0)

    def test_unbind(self):
        x = torch.rand(2, 3, 4, 5)
        for dim in range(4):
            res = torch.unbind(x, dim)
            res2 = x.unbind(dim)
            self.assertEqual(x.size(dim), len(res))
            self.assertEqual(x.size(dim), len(res2))
            # check every slice, not just the first `dim` of them
            for i in range(x.size(dim)):
                self.assertEqual(x.select(dim, i), res[i])
                self.assertEqual(x.select(dim, i), res2[i])

    def test_linspace(self):
        devices = ['cpu'] if not torch.cuda.is_available() else ['cpu', 'cuda']
        for device in devices:
            _from = random.random()
            to = _from + random.random()
            res1 = torch.linspace(_from, to, 137, device=device)
            res2 = torch.tensor((), device=device)
            torch.linspace(_from, to, 137, out=res2)
            self.assertEqual(res1, res2, 0)
            self.assertRaises(RuntimeError, lambda: torch.linspace(0, 1, -1, device=device))
            self.assertEqual(torch.linspace(0, 1, 1, device=device), torch.zeros(1, device=device), 0)

            # check linspace for generating with start > end
            self.assertEqual(torch.linspace(2, 0, 3, device=device),
                             torch.tensor((2, 1, 0), device=device), 0)

            # check linspace for non-contiguous tensors
            x = torch.zeros(2, 3, device=device)
            y = torch.linspace(0, 3, 4, out=x.narrow(1, 1, 2))
            self.assertEqual(x, torch.tensor(((0, 0, 1), (0, 2, 3)), device=device), 0)

    def test_logspace(self):
        _from = random.random()
        to = _from + random.random()
        res1 = torch.logspace(_from, to, 137)
        res2 = torch.Tensor()
        torch.logspace(_from, to, 137, out=res2)
        self.assertEqual(res1, res2, 0)
        self.assertRaises(RuntimeError, lambda: torch.logspace(0, 1, -1))
        self.assertEqual(torch.logspace(0, 1, 1), torch.ones(1), 0)

        # check logspace for generating with start > end
        self.assertEqual(torch.logspace(1, 0, 2), torch.Tensor((10, 1)), 0)

        # check logspace for non-contiguous tensors
        x = torch.zeros(2, 3)
        y = torch.logspace(0, 3, 4, out=x.narrow(1, 1, 2))
        self.assertEqual(x, torch.Tensor(((0, 1, 10), (0, 100, 1000))), 0)
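
    # Added illustration (not part of the original suite; assumes the standard
    # relationship): logspace(start, end, steps) equals
    # 10 ** linspace(start, end, steps).
    def _example_logspace(self):
        self.assertEqual(torch.logspace(0, 2, 3), torch.tensor([1., 10., 100.]))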

    def test_rand(self):
        torch.manual_seed(123456)
        res1 = torch.rand(SIZE, SIZE)
        res2 = torch.Tensor()
        torch.manual_seed(123456)
        torch.rand(SIZE, SIZE, out=res2)
        self.assertEqual(res1, res2)

    def test_randint(self):
        torch.manual_seed(123456)
        res1 = torch.randint(0, 6, (SIZE, SIZE))
        res2 = torch.Tensor()
        torch.manual_seed(123456)
        torch.randint(0, 6, (SIZE, SIZE), out=res2)
        torch.manual_seed(123456)
        res3 = torch.randint(6, (SIZE, SIZE))
        res4 = torch.Tensor()
        torch.manual_seed(123456)
        torch.randint(6, (SIZE, SIZE), out=res4)
        self.assertEqual(res1, res2)
        self.assertEqual(res1, res3)
        self.assertEqual(res1, res4)
        self.assertEqual(res2, res3)
        self.assertEqual(res2, res4)
        self.assertEqual(res3, res4)
        res1 = res1.view(-1)
        high = (res1 < 6).type(torch.LongTensor)
        low = (res1 >= 0).type(torch.LongTensor)
        tensorSize = res1.size()[0]
        assert(tensorSize == high.sum())
        assert(tensorSize == low.sum())
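
    # Added illustration (not part of the original suite; assumes standard
    # torch.randint semantics): values are drawn from the half-open interval
    # [low, high), so high itself never appears.
    def _example_randint(self):
        t = torch.randint(0, 6, (1000,))
        self.assertGreaterEqual(t.min().item(), 0)
        self.assertLessEqual(t.max().item(), 5)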

    def test_randn(self):
        torch.manual_seed(123456)
        res1 = torch.randn(SIZE, SIZE)
        res2 = torch.Tensor()
        torch.manual_seed(123456)
        torch.randn(SIZE, SIZE, out=res2)
        self.assertEqual(res1, res2)

    def test_slice(self):
        empty = torch.empty(0, 4)
        x = torch.arange(0., 16).view(4, 4)
        self.assertEqual(x[:], x)
        self.assertEqual(x[:4], x)
        # start and stop are clamped to the size of the dim
        self.assertEqual(x[:5], x)
        # if start >= stop then the result is empty
        self.assertEqual(x[2:1], empty)
        self.assertEqual(x[2:2], empty)
        # out-of-bounds slices are also empty
        self.assertEqual(x[10:12], empty)
        # additional correctness checks
        self.assertEqual(x[:1].data.tolist(), [[0, 1, 2, 3]])
        self.assertEqual(x[:-3].data.tolist(), [[0, 1, 2, 3]])
        self.assertEqual(x[:, -2:3].data.tolist(), [[2], [6], [10], [14]])
        self.assertEqual(x[0:-1:2].data.tolist(), [[0, 1, 2, 3], [8, 9, 10, 11]])

    def test_is_signed(self):
        self.assertEqual(torch.IntTensor(5).is_signed(), True)
        self.assertEqual(torch.ByteTensor(5).is_signed(), False)
        self.assertEqual(torch.CharTensor(5).is_signed(), True)
        self.assertEqual(torch.FloatTensor(5).is_signed(), True)
        self.assertEqual(torch.HalfTensor(10).is_signed(), True)

    @unittest.skipIf(not torch.cuda.is_available(), 'no CUDA')
    def test_is_signed_cuda(self):
        self.assertEqual(torch.cuda.IntTensor(5).is_signed(), True)
        self.assertEqual(torch.cuda.ByteTensor(5).is_signed(), False)
        self.assertEqual(torch.cuda.CharTensor(5).is_signed(), True)
        self.assertEqual(torch.cuda.FloatTensor(5).is_signed(), True)
        self.assertEqual(torch.cuda.HalfTensor(10).is_signed(), True)

    @staticmethod
    def _test_solve(self, cast):
        a = cast(torch.Tensor(((6.80, -2.11, 5.66, 5.97, 8.23),
                               (-6.05, -3.30, 5.36, -4.44, 1.08),
                               (-0.45, 2.58, -2.70, 0.27, 9.04),
                               (8.32, 2.71, 4.35, -7.17, 2.14),
                               (-9.67, -5.14, -7.26, 6.08, -6.87)))).t()
        b = cast(torch.Tensor(((4.02, 6.19, -8.22, -7.57, -3.03),
                               (-1.56, 4.00, -8.67, 1.75, 2.86),
                               (9.81, -4.09, -4.57, -8.61, 8.99)))).t()

        res1 = torch.solve(b, a)[0]
        self.assertLessEqual(b.dist(torch.mm(a, res1)), 1e-12)

        ta = cast(torch.Tensor())
        tb = cast(torch.Tensor())
        res2 = torch.solve(b, a, out=(tb, ta))[0]
        res3 = torch.solve(b, a, out=(b, a))[0]
        self.assertEqual(res1, tb)
        self.assertEqual(res1, b)
        self.assertEqual(res1, res2)
        self.assertEqual(res1, res3)

        # test reuse
        res1 = torch.solve(b, a)[0]
        ta = cast(torch.Tensor())
        tb = cast(torch.Tensor())
        torch.solve(b, a, out=(tb, ta))[0]
        self.assertEqual(res1, tb)
        torch.solve(b, a, out=(tb, ta))[0]
        self.assertEqual(res1, tb)

    @skipIfNoLapack
    def test_solve(self):
        self._test_solve(self, lambda t: t)
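
    # Added illustration (not part of the original suite; assumes the
    # torch.solve API exercised above): solve(B, A) returns the solution X of
    # A X = B together with the LU factorization of A; note the argument
    # order, with the right-hand side first.
    def _example_solve(self):
        A = torch.tensor([[2., 0.], [0., 4.]])
        b = torch.tensor([[2.], [8.]])
        x, LU = torch.solve(b, A)
        self.assertEqual(x, torch.tensor([[1.], [2.]]))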

    @staticmethod
    def _test_solve_batched(self, cast):
        from common_utils import random_fullrank_matrix_distinct_singular_value
        # test against solve: one batch
        A = cast(random_fullrank_matrix_distinct_singular_value(5, 1))
        b = cast(torch.randn(1, 5, 10))
        x_exp, LU_exp = torch.solve(b.squeeze(0), A.squeeze(0))
        x, LU = torch.solve(b, A)
        self.assertEqual(x, x_exp.unsqueeze(0))
        self.assertEqual(LU, LU_exp.unsqueeze(0))

        # test against solve in a loop: four batches
        A = cast(random_fullrank_matrix_distinct_singular_value(5, 4))
        b = cast(torch.randn(4, 5, 10))

        x_exp_list = []
        LU_exp_list = []
        for i in range(4):
            x_exp, LU_exp = torch.solve(b[i], A[i])
            x_exp_list.append(x_exp)
            LU_exp_list.append(LU_exp)
        x_exp = torch.stack(x_exp_list)
        LU_exp = torch.stack(LU_exp_list)

        x, LU = torch.solve(b, A)
        self.assertEqual(x, x_exp)
        self.assertEqual(LU, LU_exp)

        # basic correctness test
        A = cast(random_fullrank_matrix_distinct_singular_value(5, 3))
        b = cast(torch.randn(3, 5, 10))
        x, LU = torch.solve(b, A)
        self.assertEqual(torch.matmul(A, x), b)

        # test non-contiguous inputs, using numpy as reference
        if not TEST_NUMPY:
            return
        from numpy.linalg import solve
        A = cast(random_fullrank_matrix_distinct_singular_value(2, 2)).permute(1, 0, 2)
        b = cast(torch.randn(2, 2, 2)).permute(2, 1, 0)
        x, _ = torch.solve(b, A)
        x_exp = torch.Tensor(solve(A.cpu().numpy(), b.cpu().numpy()))
        self.assertEqual(x.data, cast(x_exp))

    @skipIfNoLapack
    def test_solve_batched(self):
        self._test_solve_batched(self, lambda t: t)

    @staticmethod
    def _test_solve_batched_dims(self, cast):
        if not TEST_NUMPY:
            return

        from numpy.linalg import solve
        from common_utils import random_fullrank_matrix_distinct_singular_value
        # test against numpy.linalg.solve
        A = cast(random_fullrank_matrix_distinct_singular_value(4, 2, 1, 3))
        b = cast(torch.randn(2, 1, 3, 4, 6))
        x, _ = torch.solve(b, A)
        x_exp = torch.Tensor(solve(A.cpu().numpy(), b.cpu().numpy()))
        self.assertEqual(x.data, cast(x_exp))

        # test column-major format
        A = cast(random_fullrank_matrix_distinct_singular_value(4, 2, 1, 3)).transpose(-2, -1)
        b = cast(torch.randn(2, 1, 3, 6, 4)).transpose(-2, -1)
        assert not A.is_contiguous()
        assert not b.is_contiguous()
        x, _ = torch.solve(b, A)
        x_exp = torch.Tensor(solve(A.cpu().numpy(), b.cpu().numpy()))
        self.assertEqual(x.data, cast(x_exp))

        # broadcasting b
        A = cast(random_fullrank_matrix_distinct_singular_value(4, 2, 1, 3))
        b = cast(torch.randn(4, 6))
        x, _ = torch.solve(b, A)
        x_exp = torch.Tensor(solve(A.cpu().numpy(), b.cpu().numpy()))
        self.assertEqual(x.data, cast(x_exp))

        # broadcasting A
        A = cast(random_fullrank_matrix_distinct_singular_value(4))
        b = cast(torch.randn(2, 1, 3, 4, 2))
        x, _ = torch.solve(b, A)
        x_exp = torch.Tensor(solve(A.cpu().numpy(), b.cpu().numpy()))
        self.assertEqual(x.data, cast(x_exp))

        # broadcasting both A and b
        A = cast(random_fullrank_matrix_distinct_singular_value(4, 1, 3, 1))
        b = cast(torch.randn(2, 1, 3, 4, 5))
        x, _ = torch.solve(b, A)
        x_exp = torch.Tensor(solve(A.cpu().numpy(), b.cpu().numpy()))
        self.assertEqual(x.data, cast(x_exp))

    @skipIfNoLapack
    def test_solve_batched_dims(self):
        self._test_solve_batched_dims(self, lambda t: t)

    @unittest.skipIf(not torch.cuda.is_available(), 'no CUDA')
    def test_solve_methods_arg_device(self):
        for b_device, A_device in product(['cpu', 'cuda'], repeat=2):
            if b_device == A_device:
                continue

            b = torch.randn(3, 1, device=b_device)
            A = torch.randn(3, 3, device=A_device)
            err_str = "Expected b and A to be on the same device"
            with self.assertRaisesRegex(RuntimeError, err_str):
                torch.solve(b, A)

            with self.assertRaisesRegex(RuntimeError, err_str):
                torch.cholesky_solve(b, A)

    @skipIfNoLapack
    def test_qr(self):
        def canonicalize(q, r):
            d = r.diag().sign().diag()
            return torch.mm(q, d), torch.mm(d, r)

        def canon_and_check(q, r, expected_q, expected_r):
            q_canon, r_canon = canonicalize(q, r)
            expected_q_canon, expected_r_canon = canonicalize(expected_q, expected_r)
            self.assertEqual(q_canon, expected_q_canon)
            self.assertEqual(r_canon, expected_r_canon)

        def check_qr(a, expected_q, expected_r):
            # standard invocation
            q, r = torch.qr(a)
            canon_and_check(q, r, expected_q, expected_r)

            # in-place
            q, r = torch.Tensor(), torch.Tensor()
            torch.qr(a, out=(q, r))
            canon_and_check(q, r, expected_q, expected_r)

            # manually calculate QR using geqrf and orgqr
            m = a.size(0)
            n = a.size(1)
            k = min(m, n)
            result, tau = torch.geqrf(a)
            self.assertEqual(result.size(0), m)
            self.assertEqual(result.size(1), n)
            self.assertEqual(tau.size(0), k)
            r = torch.triu(result.narrow(0, 0, k))
            q = torch.orgqr(result, tau)
            q, r = q.narrow(1, 0, k), r
            canon_and_check(q, r, expected_q, expected_r)

        # check square case
        a = torch.Tensor(((1, 2, 3), (4, 5, 6), (7, 8, 10)))

        expected_q = torch.Tensor((
            (-1.230914909793328e-01, 9.045340337332914e-01, 4.082482904638621e-01),
            (-4.923659639173310e-01, 3.015113445777629e-01, -8.164965809277264e-01),
            (-8.616404368553292e-01, -3.015113445777631e-01, 4.082482904638634e-01)))
        expected_r = torch.Tensor((
            (-8.124038404635959e+00, -9.601136296387955e+00, -1.193987e+01),
            (0.000000000000000e+00, 9.045340337332926e-01, 1.507557e+00),
            (0.000000000000000e+00, 0.000000000000000e+00, 4.082483e-01)))

        check_qr(a, expected_q, expected_r)

        # check rectangular thin case
        a = torch.Tensor(((1, 2, 3),
                          (4, 5, 6),
                          (7, 8, 9),
                          (10, 11, 13)))

        expected_q = torch.Tensor((
            (-0.0776150525706334, -0.833052161400748, 0.3651483716701106),
            (-0.3104602102825332, -0.4512365874254053, -0.1825741858350556),
            (-0.5433053679944331, -0.0694210134500621, -0.7302967433402217),
            (-0.7761505257063329, 0.3123945605252804, 0.5477225575051663)))
        expected_r = torch.Tensor((
            (-12.8840987267251261, -14.5916298832790581, -17.0753115655393231),
            (0, -1.0413152017509357, -1.770235842976589),
            (0, 0, 0.5477225575051664)))

        check_qr(a, expected_q, expected_r)

        # check rectangular fat case
        a = torch.Tensor(((1, 2, 3, 4),
                          (5, 6, 7, 8),
                          (9, 10, 11, 13)))

        expected_q = torch.Tensor((
            (-0.0966736489045663, 0.907737593658436, 0.4082482904638653),
            (-0.4833682445228317, 0.3157348151855452, -0.8164965809277254),
            (-0.870062840141097, -0.2762679632873518, 0.4082482904638621)))
        expected_r = torch.Tensor((
            (-1.0344080432788603e+01, -1.1794185166357092e+01,
             -1.3244289899925587e+01, -1.5564457473635180e+01),
            (0.0000000000000000e+00, 9.4720444555662542e-01,
             1.8944088911132546e+00, 2.5653453733825331e+00),
            (0.0000000000000000e+00, 0.0000000000000000e+00,
             1.5543122344752192e-15, 4.0824829046386757e-01)))

        check_qr(a, expected_q, expected_r)

        # check big matrix
        a = torch.randn(1000, 1000)
        q, r = torch.qr(a)
        a_qr = torch.mm(q, r)
        self.assertEqual(a, a_qr, prec=1e-3)

    @skipIfNoLapack
    def test_ormqr(self):
        mat1 = torch.randn(7, 7)
        mat2 = torch.randn(7, 7)
        q, r = torch.qr(mat1)
        m, tau = torch.geqrf(mat1)
        out_holder = torch.empty_like(mat1)

        res1 = torch.mm(q, mat2)
        res2 = torch.ormqr(m, tau, mat2, left=True, transpose=False)
        torch.ormqr(m, tau, mat2, out=out_holder)
        self.assertEqual(res1, res2)
        self.assertEqual(res2, out_holder)

        res1 = torch.mm(mat2, q)
        res2 = torch.ormqr(m, tau, mat2, left=False, transpose=False)
        torch.ormqr(m, tau, mat2, left=False, transpose=False, out=out_holder)
        self.assertEqual(res1, res2)
        self.assertEqual(res2, out_holder)

        res1 = torch.mm(q.t(), mat2)
        res2 = torch.ormqr(m, tau, mat2, left=True, transpose=True)
        torch.ormqr(m, tau, mat2, left=True, transpose=True, out=out_holder)
        self.assertEqual(res1, res2)
        self.assertEqual(res2, out_holder)

        res1 = torch.mm(mat2, q.t())
        res2 = torch.ormqr(m, tau, mat2, left=False, transpose=True)
        torch.ormqr(m, tau, mat2, left=False, transpose=True, out=out_holder)
        self.assertEqual(res1, res2)
        self.assertEqual(res2, out_holder)

    @staticmethod
    def _test_geqrf(self, cast):
        a = cast(torch.randn(5, 5))
        b, c = torch.geqrf(a)
        b_placeholder, c_placeholder = torch.empty_like(b), torch.empty_like(c)
        torch.geqrf(a, out=(b_placeholder, c_placeholder))
        self.assertEqual(b, b_placeholder)
        self.assertEqual(c, c_placeholder)

    @skipIfNoLapack
    def test_geqrf(self):
        self._test_geqrf(self, lambda t: t)

    @staticmethod
    def _test_trtrs(self, cast):
        a = torch.Tensor(((6.80, -2.11, 5.66, 5.97, 8.23),
                          (-6.05, -3.30, 5.36, -4.44, 1.08),
                          (-0.45, 2.58, -2.70, 0.27, 9.04),
                          (8.32, 2.71, 4.35, -7.17, 2.14),
                          (-9.67, -5.14, -7.26, 6.08, -6.87))).t()
        b = torch.Tensor(((4.02, 6.19, -8.22, -7.57, -3.03),
                          (-1.56, 4.00, -8.67, 1.75, 2.86),
                          (9.81, -4.09, -4.57, -8.61, 8.99))).t()

        a = cast(a)
        b = cast(b)

        U = torch.triu(a)
        L = torch.tril(a)

        # solve Ux = b
        x = torch.trtrs(b, U)[0]
        self.assertLessEqual(b.dist(torch.mm(U, x)), 1e-12)
        x = torch.trtrs(b, U, True, False, False)[0]
        self.assertLessEqual(b.dist(torch.mm(U, x)), 1e-12)

        # solve Lx = b
        x = torch.trtrs(b, L, False)[0]
        self.assertLessEqual(b.dist(torch.mm(L, x)), 1e-12)
        x = torch.trtrs(b, L, False, False, False)[0]
        self.assertLessEqual(b.dist(torch.mm(L, x)), 1e-12)

        # solve U'x = b
        x = torch.trtrs(b, U, True, True)[0]
        self.assertLessEqual(b.dist(torch.mm(U.t(), x)), 1e-12)
        x = torch.trtrs(b, U, True, True, False)[0]
        self.assertLessEqual(b.dist(torch.mm(U.t(), x)), 1e-12)

        # solve U'x = b by manual transposition
        y = torch.trtrs(b, U.t(), False, False)[0]
        self.assertLessEqual(x.dist(y), 1e-12)

        # solve L'x = b
        x = torch.trtrs(b, L, False, True)[0]
        self.assertLessEqual(b.dist(torch.mm(L.t(), x)), 1e-12)
        x = torch.trtrs(b, L, False, True, False)[0]
        self.assertLessEqual(b.dist(torch.mm(L.t(), x)), 1e-12)

        # solve L'x = b by manual transposition
        y = torch.trtrs(b, L.t(), True, False)[0]
        self.assertLessEqual(x.dist(y), 1e-12)

        # test reuse
        res1 = torch.trtrs(b, a)[0]
        ta = cast(torch.Tensor())
        tb = cast(torch.Tensor())
        torch.trtrs(b, a, out=(tb, ta))
        self.assertEqual(res1, tb, 0)
        tb.zero_()
        torch.trtrs(b, a, out=(tb, ta))
        self.assertEqual(res1, tb, 0)

    @skipIfNoLapack
    def test_trtrs(self):
        self._test_trtrs(self, lambda t: t)
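
    # Added illustration (not part of the original suite; assumes the trtrs
    # signature exercised above, trtrs(b, A, upper=True, transpose=False,
    # unitriangular=False)): it solves a triangular system and returns the
    # solution along with a copy of A.
    def _example_trtrs(self):
        A = torch.tensor([[2., 1.], [0., 4.]])  # upper triangular
        b = torch.tensor([[4.], [8.]])
        x = torch.trtrs(b, A)[0]
        self.assertEqual(x, torch.tensor([[1.], [2.]]))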

    @staticmethod
    def _test_trtrs_batched(self, cast):
        def trtrs_test_helper(A_dims, b_dims, cast, upper, unitriangular):
            A = cast(torch.randn(*A_dims))
            A = A.triu() if upper else A.tril()
            if unitriangular:
                A.diagonal(dim1=-2, dim2=-1).fill_(1.)
            b = cast(torch.randn(*b_dims))
            return A, b

        for upper, transpose, unitriangular in product([True, False], repeat=3):
            # test against trtrs: one batch with all possible arguments
            A, b = trtrs_test_helper((1, 5, 5), (1, 5, 10), cast, upper, unitriangular)
            x_exp = torch.trtrs(b.squeeze(0), A.squeeze(0),
                                upper=upper, unitriangular=unitriangular, transpose=transpose)[0]
            x = torch.trtrs(b, A,
                            upper=upper, unitriangular=unitriangular, transpose=transpose)[0]
            self.assertEqual(x, x_exp.unsqueeze(0))

            # test against trtrs in a loop: four batches with all possible arguments
            A, b = trtrs_test_helper((4, 5, 5), (4, 5, 10), cast, upper, unitriangular)
            x_exp_list = []
            for i in range(4):
                x_exp = torch.trtrs(b[i], A[i],
                                    upper=upper, unitriangular=unitriangular, transpose=transpose)[0]
                x_exp_list.append(x_exp)
            x_exp = torch.stack(x_exp_list)

            x = torch.trtrs(b, A, upper=upper, unitriangular=unitriangular, transpose=transpose)[0]
            self.assertEqual(x, x_exp)

            # basic correctness test
            A, b = trtrs_test_helper((3, 5, 5), (3, 5, 10), cast, upper, unitriangular)
            x = torch.trtrs(b, A, upper=upper, unitriangular=unitriangular, transpose=transpose)[0]
            if transpose:
                self.assertLessEqual(b.dist(torch.matmul(A.transpose(-1, -2), x)), 2e-12)
            else:
                self.assertLessEqual(b.dist(torch.matmul(A, x)), 2e-12)

    @skipIfNoLapack
    def test_trtrs_batched(self):
        _TestTorchMixin._test_trtrs_batched(self, lambda t: t)

    @staticmethod
    def _test_trtrs_batched_dims(self, cast):
        if not TEST_SCIPY:
            return

        from scipy.linalg import solve_triangular as tri_solve

        def scipy_tri_solve_batched(A, B, upper, trans, diag):
            batch_dims_A, batch_dims_B = A.shape[:-2], B.shape[:-2]
            single_dim_A, single_dim_B = A.shape[-2:], B.shape[-2:]
            expand_dims = tuple(torch._C._infer_size(torch.Size(batch_dims_A),
                                                     torch.Size(batch_dims_B)))
            expand_A = np.broadcast_to(A, expand_dims + single_dim_A)
            expand_B = np.broadcast_to(B, expand_dims + single_dim_B)
            flat_A = expand_A.reshape((-1,) + single_dim_A)
            flat_B = expand_B.reshape((-1,) + single_dim_B)
            flat_X = np.vstack([tri_solve(a, b, lower=(not upper), trans=int(trans), unit_diagonal=diag)
                                for a, b in zip(flat_A, flat_B)])
            return flat_X.reshape(expand_B.shape)

        def run_test(A_dims, b_dims, cast, upper, transpose, unitriangular):
            A = torch.randn(*A_dims)
            A = A.triu() if upper else A.tril()
            if unitriangular:
                A.diagonal(dim1=-2, dim2=-1).fill_(1.)
            b = torch.randn(*b_dims)
            x_exp = torch.Tensor(scipy_tri_solve_batched(A.numpy(), b.numpy(),
                                                         upper, transpose, unitriangular))
            A, b = cast(A), cast(b)
            x = torch.trtrs(b, A, upper=upper, transpose=transpose, unitriangular=unitriangular)[0]

            self.assertEqual(x, cast(x_exp))

        for upper, transpose, unitriangular in product([True, False], repeat=3):
            # test against scipy.linalg.solve_triangular
            run_test((2, 1, 3, 4, 4), (2, 1, 3, 4, 6), cast, upper, transpose, unitriangular)
            run_test((2, 1, 3, 4, 4), (4, 6), cast, upper, transpose, unitriangular)
            run_test((4, 4), (2, 1, 3, 4, 2), cast, upper, transpose, unitriangular)
            run_test((1, 3, 1, 4, 4), (2, 1, 3, 4, 5), cast, upper, transpose, unitriangular)

    @skipIfNoLapack
    def test_trtrs_batched_dims(self):
        self._test_trtrs_batched_dims(self, lambda t: t)

    @skipIfNoLapack
    def test_gels(self):
        def _test_underdetermined(a, b, expectedNorm):
            m = a.size()[0]
            n = a.size()[1]
            assert(m <= n)

            a_copy = a.clone()
            b_copy = b.clone()
            res1 = torch.gels(b, a)[0]
            self.assertEqual(a, a_copy, 0)
            self.assertEqual(b, b_copy, 0)
            self.assertEqual((torch.mm(a, res1) - b).norm(), expectedNorm, 1e-8)

            ta = torch.Tensor()
            tb = torch.Tensor()
            res2 = torch.gels(b, a, out=(tb, ta))[0]
            self.assertEqual(a, a_copy, 0)
            self.assertEqual(b, b_copy, 0)
            self.assertEqual((torch.mm(a, res1) - b).norm(), expectedNorm, 1e-8)

            res3 = torch.gels(b, a, out=(b, a))[0]
            self.assertEqual((torch.mm(a_copy, b) - b_copy).norm(), expectedNorm, 1e-8)
            self.assertEqual(res1, tb, 0)
            self.assertEqual(res1, b, 0)
            self.assertEqual(res1, res2, 0)
            self.assertEqual(res1, res3, 0)

        def _test_overdetermined(a, b, expectedNorm):
            m = a.size()[0]
            n = a.size()[1]
            assert(m > n)

            def check_norm(a, b, expected_norm, gels_result):
                # checks |ax - b| and the residual info from the result:
                # the first n rows hold the least-squares solution, rows
                # n to m-1 contain residual information
                x = gels_result[:n]
                resid_info = gels_result[n:]

                resid_norm = (torch.mm(a, x) - b).norm()
                self.assertEqual(resid_norm, expected_norm, 1e-8)
                self.assertEqual(resid_info.norm(), resid_norm, 1e-8)

            a_copy = a.clone()
            b_copy = b.clone()
            res1 = torch.gels(b, a)[0]
            self.assertEqual(a, a_copy, 0)
            self.assertEqual(b, b_copy, 0)
            check_norm(a, b, expectedNorm, res1)

            ta = torch.Tensor()
            tb = torch.Tensor()
            res2 = torch.gels(b, a, out=(tb, ta))[0]
            self.assertEqual(a, a_copy, 0)
            self.assertEqual(b, b_copy, 0)
            check_norm(a, b, expectedNorm, res2)

            res3 = torch.gels(b, a, out=(b, a))[0]
            check_norm(a_copy, b_copy, expectedNorm, res3)

            self.assertEqual(res1, tb, 0)
            self.assertEqual(res1, b, 0)
            self.assertEqual(res1, res2, 0)
            self.assertEqual(res1, res3, 0)

        # basic test
        expectedNorm = 0
        a = torch.Tensor(((1.44, -9.96, -7.55, 8.34),
                          (-7.84, -0.28, 3.24, 8.09),
                          (-4.39, -3.24, 6.27, 5.28),
                          (4.53, 3.83, -6.64, 2.06))).t()
        b = torch.Tensor(((8.58, 8.26, 8.48, -5.28),
                          (9.35, -4.43, -0.70, -0.26))).t()
        _test_underdetermined(a, b, expectedNorm)

        # test overdetermined
        expectedNorm = 17.390200628863
        a = torch.Tensor(((1.44, -9.96, -7.55, 8.34, 7.08, -5.45),
                          (-7.84, -0.28, 3.24, 8.09, 2.52, -5.70),
                          (-4.39, -3.24, 6.27, 5.28, 0.74, -1.19),
                          (4.53, 3.83, -6.64, 2.06, -2.47, 4.70))).t()
        b = torch.Tensor(((8.58, 8.26, 8.48, -5.28, 5.72, 8.93),
                          (9.35, -4.43, -0.70, -0.26, -7.36, -2.52))).t()
        _test_overdetermined(a, b, expectedNorm)

        # test underdetermined
        expectedNorm = 0
        a = torch.Tensor(((1.44, -9.96, -7.55),
                          (-7.84, -0.28, 3.24),
                          (-4.39, -3.24, 6.27),
                          (4.53, 3.83, -6.64))).t()
        b = torch.Tensor(((8.58, 8.26, 8.48),
                          (9.35, -4.43, -0.70))).t()
        _test_underdetermined(a, b, expectedNorm)

        # test reuse
        expectedNorm = 0
        a = torch.Tensor(((1.44, -9.96, -7.55, 8.34),
                          (-7.84, -0.28, 3.24, 8.09),
                          (-4.39, -3.24, 6.27, 5.28),
                          (4.53, 3.83, -6.64, 2.06))).t()
        b = torch.Tensor(((8.58, 8.26, 8.48, -5.28),
                          (9.35, -4.43, -0.70, -0.26))).t()
        ta = torch.Tensor()
        tb = torch.Tensor()
        torch.gels(b, a, out=(tb, ta))
        self.assertEqual((torch.mm(a, tb) - b).norm(), expectedNorm, 1e-8)
        torch.gels(b, a, out=(tb, ta))
        self.assertEqual((torch.mm(a, tb) - b).norm(), expectedNorm, 1e-8)
        torch.gels(b, a, out=(tb, ta))
        self.assertEqual((torch.mm(a, tb) - b).norm(), expectedNorm, 1e-8)
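
    # Added illustration (not part of the original suite; assumes the gels
    # result layout exercised above): for an overdetermined m x n system the
    # first n rows of the result hold the least-squares solution and the
    # remaining rows carry residual information.
    def _example_gels(self):
        a = torch.tensor([[1., 0.], [0., 1.], [0., 0.]])
        b = torch.tensor([[2.], [3.], [1.]])
        x = torch.gels(b, a)[0]
        self.assertEqual(x[:2], torch.tensor([[2.], [3.]]))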

    @skipIfNoLapack
    def test_eig(self):
        a = torch.Tensor(((1.96, 0.00, 0.00, 0.00, 0.00),
                          (-6.49, 3.80, 0.00, 0.00, 0.00),
                          (-0.47, -6.39, 4.17, 0.00, 0.00),
                          (-7.20, 1.50, -1.51, 5.70, 0.00),
                          (-0.65, -6.34, 2.67, 1.80, -7.10))).t().contiguous()
        e = torch.eig(a)[0]
        ee, vv = torch.eig(a, True)
        te = torch.Tensor()
        tv = torch.Tensor()
        eee, vvv = torch.eig(a, True, out=(te, tv))
        self.assertEqual(e, ee, 1e-12)
        self.assertEqual(ee, eee, 1e-12)
        self.assertEqual(ee, te, 1e-12)
        self.assertEqual(vv, vvv, 1e-12)
        self.assertEqual(vv, tv, 1e-12)

        # test reuse
        X = torch.randn(4, 4)
        X = torch.mm(X.t(), X)
        e, v = torch.zeros(4, 2), torch.zeros(4, 4)
        torch.eig(X, True, out=(e, v))
        Xhat = torch.mm(torch.mm(v, torch.diag(e.select(1, 0))), v.t())
        self.assertEqual(X, Xhat, 1e-8, 'VeV\' wrong')
        self.assertFalse(v.is_contiguous(), 'V is contiguous')

        torch.eig(X, True, out=(e, v))
        Xhat = torch.mm(v, torch.mm(e.select(1, 0).diag(), v.t()))
        self.assertEqual(X, Xhat, 1e-8, 'VeV\' wrong')
        self.assertFalse(v.is_contiguous(), 'V is contiguous')

        # test non-contiguous
        X = torch.randn(4, 4)
        X = torch.mm(X.t(), X)
        e = torch.zeros(4, 2, 2)[:, 1]
        v = torch.zeros(4, 2, 4)[:, 1]
        self.assertFalse(v.is_contiguous(), 'V is contiguous')
        self.assertFalse(e.is_contiguous(), 'E is contiguous')
        torch.eig(X, True, out=(e, v))
        Xhat = torch.mm(torch.mm(v, torch.diag(e.select(1, 0))), v.t())
        self.assertEqual(X, Xhat, 1e-8, 'VeV\' wrong')
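
    # Added illustration (not part of the original suite; assumes the
    # torch.eig output layout exercised above): eigenvalues come back as an
    # (n, 2) tensor of (real, imaginary) pairs, and eigenvectors are only
    # computed when eigenvectors=True is passed.
    def _example_eig(self):
        e, v = torch.eig(torch.eye(3), True)
        self.assertEqual(e, torch.tensor([[1., 0.], [1., 0.], [1., 0.]]))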

    @staticmethod
    def _test_symeig(self, conv_fn):
        xval = conv_fn(torch.rand(100, 3))
        cov = torch.mm(xval.t(), xval)
        rese = conv_fn(torch.zeros(3))
        resv = conv_fn(torch.zeros(3, 3))

        # first call to symeig
        self.assertTrue(resv.is_contiguous(), 'resv is not contiguous')
        torch.symeig(cov.clone(), True, out=(rese, resv))
        ahat = torch.mm(torch.mm(resv, torch.diag(rese)), resv.t())
        self.assertEqual(cov, ahat, 1e-8, 'VeV\' wrong')

        # second call to symeig
        self.assertFalse(resv.is_contiguous(), 'resv is contiguous')
        torch.symeig(cov.clone(), True, out=(rese, resv))
        ahat = torch.mm(torch.mm(resv, torch.diag(rese)), resv.t())
        self.assertEqual(cov, ahat, 1e-8, 'VeV\' wrong')

        # test eigenvectors=False
        rese2 = conv_fn(torch.zeros(3))
        resv2 = conv_fn(torch.randn(3, 3))
        expected_resv2 = conv_fn(torch.zeros(3, 3))
        torch.symeig(cov.clone(), False, out=(rese2, resv2))
        self.assertEqual(rese, rese2)
        self.assertEqual(resv2, expected_resv2)

        # test non-contiguous
        X = conv_fn(torch.rand(5, 5))
        X = X.t() * X
        e = conv_fn(torch.zeros(4, 2)).select(1, 1)
        v = conv_fn(torch.zeros(4, 2, 4))[:, 1]
        self.assertFalse(v.is_contiguous(), 'V is contiguous')
        self.assertFalse(e.is_contiguous(), 'E is contiguous')
        torch.symeig(X, True, out=(e, v))
        Xhat = torch.mm(torch.mm(v, torch.diag(e)), v.t())
        self.assertEqual(X, Xhat, 1e-8, 'VeV\' wrong')

    @skipIfNoLapack
    def test_symeig(self):
        self._test_symeig(self, lambda x: x)

    @skipIfNoLapack
    def test_svd(self):
        a = torch.Tensor(((8.79, 6.11, -9.15, 9.57, -3.49, 9.84),
                          (9.93, 6.91, -7.93, 1.64, 4.02, 0.15),
                          (9.83, 5.04, 4.86, 8.83, 9.80, -8.99),
                          (5.45, -0.27, 4.85, 0.74, 10.00, -6.02),
                          (3.16, 7.98, 3.01, 5.80, 4.27, -5.31))).t().clone()
        u, s, v = torch.svd(a)
        uu = torch.Tensor()
        ss = torch.Tensor()
        vv = torch.Tensor()
        uuu, sss, vvv = torch.svd(a, out=(uu, ss, vv))
        self.assertEqual(u, uu, 0, 'torch.svd')
        self.assertEqual(u, uuu, 0, 'torch.svd')
        self.assertEqual(s, ss, 0, 'torch.svd')
        self.assertEqual(s, sss, 0, 'torch.svd')
        self.assertEqual(v, vv, 0, 'torch.svd')
        self.assertEqual(v, vvv, 0, 'torch.svd')

        # test reuse
        X = torch.randn(4, 4)
        U, S, V = torch.svd(X)
        Xhat = torch.mm(U, torch.mm(S.diag(), V.t()))
        self.assertEqual(X, Xhat, 1e-8, 'USV\' wrong')

        self.assertFalse(U.is_contiguous(), 'U is contiguous')
        torch.svd(X, out=(U, S, V))
        Xhat = torch.mm(U, torch.mm(S.diag(), V.t()))
        self.assertEqual(X, Xhat, 1e-8, 'USV\' wrong')

        # test non-contiguous
        X = torch.randn(5, 5)
        U = torch.zeros(5, 2, 5)[:, 1]
        S = torch.zeros(5, 2)[:, 1]
        V = torch.zeros(5, 2, 5)[:, 1]

        self.assertFalse(U.is_contiguous(), 'U is contiguous')
        self.assertFalse(S.is_contiguous(), 'S is contiguous')
        self.assertFalse(V.is_contiguous(), 'V is contiguous')
        torch.svd(X, out=(U, S, V))
        Xhat = torch.mm(U, torch.mm(S.diag(), V.t()))
        self.assertEqual(X, Xhat, 1e-8, 'USV\' wrong')

    @staticmethod
    def _test_svd_no_singularvectors(self, cast):
        for size in [(5, 5), (5, 20), (20, 5)]:
            a = cast(torch.randn(*size))
            u, s_expect, v = torch.svd(a)
            u, s_actual, v = torch.svd(a, compute_uv=False)
            self.assertEqual(s_expect, s_actual, "Singular values don't match")

    @skipIfNoLapack
    def test_svd_no_singularvectors(self):
        self._test_svd_no_singularvectors(self, lambda t: t)

    @staticmethod
    def _test_matrix_rank(self, conv_fn):
        a = conv_fn(torch.eye(10))
        self.assertEqual(torch.matrix_rank(a).item(), 10)
        self.assertEqual(torch.matrix_rank(a, True).item(), 10)

        a[5, 5] = 0
        self.assertEqual(torch.matrix_rank(a).item(), 9)
        self.assertEqual(torch.matrix_rank(a, True).item(), 9)

        a = conv_fn(torch.randn(24, 42))
        self.assertEqual(torch.matrix_rank(a), torch.matrix_rank(a.t()))
        aaT = torch.mm(a, a.t())
        self.assertEqual(torch.matrix_rank(aaT), torch.matrix_rank(aaT, True))
        aTa = torch.mm(a.t(), a)
        self.assertEqual(torch.matrix_rank(aTa), torch.matrix_rank(aTa, True))

        if TEST_NUMPY:
            from numpy.linalg import matrix_rank
            a = conv_fn(torch.randn(35, 75))
            self.assertEqual(torch.matrix_rank(a).item(), matrix_rank(a.cpu().numpy()))
            self.assertEqual(torch.matrix_rank(a, 0.01).item(), matrix_rank(a.cpu().numpy(), 0.01))

            aaT = torch.mm(a, a.t())
            self.assertEqual(torch.matrix_rank(aaT).item(), matrix_rank(aaT.cpu().numpy()))
            self.assertEqual(torch.matrix_rank(aaT, 0.01).item(), matrix_rank(aaT.cpu().numpy(), 0.01))

            # the symmetric flag for numpy's matrix_rank was added in 1.14.0
            if np.lib.NumpyVersion(np.__version__) >= '1.14.0':
                self.assertEqual(torch.matrix_rank(aaT, True).item(),
                                 matrix_rank(aaT.cpu().numpy(), True))
                self.assertEqual(torch.matrix_rank(aaT, 0.01, True).item(),
                                 matrix_rank(aaT.cpu().numpy(), 0.01, True))

    @skipIfNoLapack
    def test_matrix_rank(self):
        self._test_matrix_rank(self, lambda x: x)

    @staticmethod
    def _test_signal_window_functions(self, device='cpu'):
        if not TEST_SCIPY:
            raise unittest.SkipTest('Scipy not found')

        def test(name):
            torch_method = getattr(torch, name + '_window')
            for size in [1, 2, 5, 10, 50, 100, 1024, 2048]:
                for periodic in [True, False]:
                    res = torch_method(size, periodic=periodic, device=device)
                    ref = torch.from_numpy(signal.get_window(name, size, fftbins=periodic))
                    self.assertEqual(res, ref)
            with self.assertRaisesRegex(RuntimeError, r'not implemented for sparse types'):
                torch_method(3, layout=torch.sparse_coo)
            with self.assertRaisesRegex(RuntimeError, r'floating point'):
                torch_method(3, dtype=torch.long)
            self.assertTrue(torch_method(3, requires_grad=True).requires_grad)
            self.assertFalse(torch_method(3).requires_grad)

        for window in ['hann', 'hamming', 'bartlett', 'blackman']:
            test(window)

    def test_signal_window_functions(self):
        self._test_signal_window_functions(self)

    @staticmethod
    def _test_inverse(self, conv_fn):
        from common_utils import random_fullrank_matrix_distinct_singular_value

        # no batches: 2-D tensors
        matrix = conv_fn(random_fullrank_matrix_distinct_singular_value(5))
        matrix_inverse = torch.inverse(matrix)
        identity = conv_fn(torch.eye(5))
        self.assertEqual(identity, torch.mm(matrix, matrix_inverse), 1e-8, 'inverse value')
        self.assertEqual(identity, torch.mm(matrix_inverse, matrix), 1e-8, 'inverse value')

        matrix_inverse_out = conv_fn(torch.empty(5, 5))
        torch.inverse(matrix, out=matrix_inverse_out)
        self.assertEqual(matrix_inverse_out, matrix_inverse, 0, 'inverse value in-place')
        # second call with the same out tensor
        torch.inverse(matrix, out=matrix_inverse_out)
        self.assertEqual(matrix_inverse_out, matrix_inverse, 0, 'inverse value in-place')

        # one batch
        matrix = conv_fn(random_fullrank_matrix_distinct_singular_value(5, 1))
        matrix_inverse = torch.inverse(matrix)
        expected_inv = matrix.squeeze(0).inverse()
        self.assertEqual(matrix_inverse, expected_inv.unsqueeze(0))

        # four batches
        matrices = conv_fn(random_fullrank_matrix_distinct_singular_value(5, 4))
        expected_inv_list = []
        for i in range(0, 4):
            expected_inv_list.append(torch.inverse(matrices[i]))
        expected_inv = torch.stack(expected_inv_list)
        matrices_inverse = torch.inverse(matrices)
        self.assertEqual(matrices_inverse, expected_inv)

        # six batches (2 x 3)
        matrices = conv_fn(random_fullrank_matrix_distinct_singular_value(5, 2, 3))
        expected_inv_list = []
        for mat in matrices.view(-1, 5, 5):
            expected_inv_list.append(torch.inverse(mat))
        expected_inv = torch.stack(expected_inv_list).view(2, 3, 5, 5)
        matrices_inverse = torch.inverse(matrices)
        self.assertEqual(matrices_inverse, expected_inv)

        # incorrect input test
        with self.assertRaisesRegex(RuntimeError, "must be batches of square matrices"):
            torch.inverse(torch.randn(2, 3, 4, 3))

        # correctness test
        matrices = conv_fn(random_fullrank_matrix_distinct_singular_value(5, 3))
        matrices_inverse = torch.inverse(matrices)
        self.assertEqual(torch.matmul(matrices, matrices_inverse), identity.expand_as(matrices))
        self.assertEqual(torch.matmul(matrices_inverse, matrices), identity.expand_as(matrices))

        # torch.inverse with out= and batches
        matrices = conv_fn(random_fullrank_matrix_distinct_singular_value(5, 3))
        matrices_inverse = conv_fn(torch.empty(3, 5, 5))
        torch.inverse(matrices, out=matrices_inverse)
        self.assertEqual(torch.inverse(matrices), matrices_inverse)

        # non-contiguous inputs, using numpy as reference
        if not TEST_NUMPY:
            return

        from numpy.linalg import inv
        matrices = conv_fn(random_fullrank_matrix_distinct_singular_value(3, 2)).permute(0, 2, 1)
        assert not matrices.is_contiguous()
        matrices_inverse = torch.inverse(matrices)
        expected_inv = torch.as_tensor(inv(matrices.cpu().numpy()))
        self.assertEqual(matrices_inverse, conv_fn(expected_inv))

    @skipIfNoLapack
    def test_inverse(self):
        self._test_inverse(self, lambda t: t)

    @staticmethod
    def _test_pinverse(self, conv_fn):
        def run_test(M):
            # testing against the four Moore-Penrose conditions
            MPI = torch.pinverse(M)
            self.assertEqual(M, M.mm(MPI).mm(M), 1e-8, 'pseudo-inverse condition 1')
            self.assertEqual(MPI, MPI.mm(M).mm(MPI), 1e-8, 'pseudo-inverse condition 2')
            self.assertEqual(M.mm(MPI), (M.mm(MPI)).t(), 1e-8, 'pseudo-inverse condition 3')
            self.assertEqual(MPI.mm(M), (MPI.mm(M)).t(), 1e-8, 'pseudo-inverse condition 4')

        # square matrix
        M = conv_fn(torch.randn(5, 5))
        run_test(M)

        # rectangular matrix
        M = conv_fn(torch.randn(3, 4))
        run_test(M)

        # the pseudo-inverse of an invertible matrix is its inverse
        M = torch.randn(5, 5)
        M = conv_fn(M.mm(M.t()))
        self.assertEqual(conv_fn(torch.eye(5)), M.pinverse().mm(M), 1e-7,
                         'pseudo-inverse for invertible matrix')

    @skipIfNoLapack
    def test_pinverse(self):
        self._test_pinverse(self, conv_fn=lambda x: x)

    @staticmethod
    def _test_matrix_power(self, conv_fn):
        def run_test(M, sign=1):
            if sign == -1:
                M = M.inverse()
            MP2 = torch.matrix_power(M, 2)
            self.assertEqual(MP2, torch.matmul(M, M))

            MP3 = torch.matrix_power(M, 3)
            self.assertEqual(MP3, torch.matmul(MP2, M))

            MP4 = torch.matrix_power(M, 4)
            self.assertEqual(MP4, torch.matmul(MP2, MP2))

            MP6 = torch.matrix_power(M, 6)
            self.assertEqual(MP6, torch.matmul(MP3, MP3))

            MP0 = torch.matrix_power(M, 0)
            self.assertEqual(MP0, torch.eye(M.size(-2)).expand_as(M))

        # single matrix
        M = conv_fn(torch.randn(5, 5))
        run_test(M)

        # batch matrices
        M = conv_fn(torch.randn(3, 3, 3))
        run_test(M)

        # many batch matrices
        M = conv_fn(torch.randn(2, 3, 3, 3))
        run_test(M)

        # negative powers need invertible matrices
        from common_utils import random_fullrank_matrix_distinct_singular_value
        M = conv_fn(random_fullrank_matrix_distinct_singular_value(5))
        run_test(M, sign=-1)

        M = conv_fn(random_fullrank_matrix_distinct_singular_value(3, 3))
        run_test(M, sign=-1)

        M = conv_fn(random_fullrank_matrix_distinct_singular_value(3, 2, 3))
        run_test(M, sign=-1)

    @skipIfNoLapack
    def test_matrix_power(self):
        self._test_matrix_power(self, conv_fn=lambda x: x)

    @staticmethod
    def _test_chain_matmul(self, cast):
        def product(matrices):
            for mat in matrices[1:]:
                matrices[0] = matrices[0].mm(mat)
            return matrices[0]

        def run_test(p, cast):
            matrices = []
            for (pi, pi_1) in zip(p[:-1], p[1:]):
                matrices.append(cast(torch.randn(pi, pi_1)))
            self.assertEqual(torch.chain_matmul(*matrices), product(matrices))

        run_test([10, 20, 30, 5], cast)
        run_test([15, 5, 10, 20, 25], cast)

    def test_chain_matmul(self):
        self._test_chain_matmul(self, cast=lambda x: x)

    @staticmethod
    def _test_det_logdet_slogdet(self, conv_fn):
        def reference_det(M):
            # naive row reduction
            M = M.clone()
            l = M.size(0)
            multiplier = 1
            for i in range(l):
                if M[i, 0] != 0:
                    if i != 0:
                        # clone to avoid aliasing when swapping rows in place
                        M[0], M[i] = M[i].clone(), M[0].clone()
                        multiplier = -1
                    break
            else:
                return 0
            for i in range(1, l):
                row = M[i]
                for j in range(i):
                    row -= row[j] / M[j, j] * M[j]
                M[i] = row
            return M.diag().prod() * multiplier

        def test_single_det(M, target, desc):
            det = M.det()
            logdet = M.logdet()
            sdet, logabsdet = M.slogdet()
            self.assertEqual(det, target, 1e-7, '{} (det)'.format(desc))
            if det.item() < 0:
                # logdet of a negative-determinant matrix is NaN (NaN != NaN)
                self.assertTrue(logdet.item() != logdet.item(), '{} (logdet negative case)'.format(desc))
                self.assertTrue(sdet.item() == -1, '{} (slogdet sign negative case)'.format(desc))
                self.assertEqual(logabsdet.exp(), det.abs(), 1e-7, '{} (slogdet logabsdet negative case)'.format(desc))
            elif det.item() == 0:
                self.assertEqual(logdet.exp().item(), 0, 1e-7, '{} (logdet zero case)'.format(desc))
                self.assertTrue(sdet.item() == 0, '{} (slogdet sign zero case)'.format(desc))
                self.assertEqual(logabsdet.exp().item(), 0, 1e-7, '{} (slogdet logabsdet zero case)'.format(desc))
            else:
                self.assertEqual(logdet.exp(), det, 1e-7, '{} (logdet positive case)'.format(desc))
                self.assertTrue(sdet.item() == 1, '{} (slogdet sign positive case)'.format(desc))
                self.assertEqual(logabsdet.exp(), det, 1e-7, '{} (slogdet logabsdet positive case)'.format(desc))

        eye = conv_fn(torch.eye(5))
        test_single_det(eye, torch.tensor(1, dtype=eye.dtype), 'identity')

        # det is known to be broken on some CUDA toolkit versions
        is_cuda_8_92 = False
        if torch.cuda.is_available() and torch.version.cuda is not None:
            is_cuda_8_92 = any(x in torch.version.cuda for x in ['8.0', '9.2'])

        def test(M):
            assert M.size(0) >= 5, 'this helper fn assumes M to be at least 5x5'
            M = conv_fn(M)

            if M.is_cuda and is_cuda_8_92:
                return

            M_det = M.det()
            ref_M_det = reference_det(M)

            test_single_det(M, ref_M_det, 'basic')
            if abs(ref_M_det.item()) >= 1e-10:  # skip singular matrices
                test_single_det(M, M.inverse().det().pow_(-1), 'inverse')
            test_single_det(M, M.t().det(), 'transpose')

            for x in [0, 2, 4]:
                for scale in [-2, -0.1, 0, 10]:
                    target = M_det * scale
                    # dim 0
                    M_clone = M.clone()
                    M_clone[:, x] *= scale
                    test_single_det(M_clone, target, 'scale a row')
                    # dim 1
                    M_clone = M.clone()
                    M_clone[x, :] *= scale
                    test_single_det(M_clone, target, 'scale a column')

            for x1, x2 in [(0, 3), (4, 1), (3, 2)]:
                assert x1 != x2, 'x1 and x2 needs to be different for this test'
                target = M_det.clone().zero_()
                # dim 0
                M_clone = M.clone()
                M_clone[:, x2] = M_clone[:, x1]
                test_single_det(M_clone, target, 'two rows are same')
                # dim 1
                M_clone = M.clone()
                M_clone[x2, :] = M_clone[x1, :]
                test_single_det(M_clone, target, 'two columns are same')

                for scale1, scale2 in [(0.3, -1), (0, 2), (10, 0.1)]:
                    target = -M_det * scale1 * scale2
                    # dim 0
                    M_clone = M.clone()
                    t = M_clone[:, x1] * scale1
                    M_clone[:, x1] += M_clone[:, x2] * scale2
                    M_clone[:, x2] = t
                    test_single_det(M_clone, target, 'exchanging rows')
                    # dim 1
                    M_clone = M.clone()
                    t = M_clone[x1, :] * scale1
                    M_clone[x1, :] += M_clone[x2, :] * scale2
                    M_clone[x2, :] = t
                    test_single_det(M_clone, target, 'exchanging columns')

        def get_random_mat_scale(n):
            # scale random n x n matrices so their determinants neither
            # overflow nor underflow: for i.i.d. standard normal entries the
            # determinant grows roughly like sqrt((n-1)!), so divide it out
            return math.factorial(n - 1) ** (-1.0 / (2 * n))

        for n in [5, 10, 25]:
            scale = get_random_mat_scale(n)
            test(torch.randn(n, n) * scale)
            r = torch.randn(n, n) * scale
            # symmetric psd
            test(r.mm(r.t()))
            r = torch.randn(n, n) * scale
            # symmetric pd
            test(r.mm(r.t()) + torch.eye(n) * 1e-6)
            # symmetric
            r = torch.randn(n, n) * scale
            for i in range(n):
                for j in range(i):
                    r[i, j] = r[j, i]
            test(r)
            # non-contiguous
            test((torch.randn(n, n, n + 1) * scale)[:, 2, 1:])
            # det = 0
            r = torch.randn(n, n) * scale
            u, s, v = r.svd()
            if reference_det(u) < 0:
                u = -u
            if reference_det(v) < 0:
                v = -v
            s[0] *= -1
            s[-1] = 0
            test(u.mm(s.diag()).mm(v))

    @skipIfNoLapack
    def test_det_logdet_slogdet(self):
        self._test_det_logdet_slogdet(self, lambda x: x)

    @staticmethod
    def _test_fft_ifft_rfft_irfft(self, device='cpu'):
        def _test_complex(sizes, signal_ndim, prepro_fn=lambda x: x):
            x = prepro_fn(torch.randn(*sizes, device=device))
            for normalized in (True, False):
                res = x.fft(signal_ndim, normalized=normalized)
                rec = res.ifft(signal_ndim, normalized=normalized)
                self.assertEqual(x, rec, 1e-8, 'fft and ifft')
                res = x.ifft(signal_ndim, normalized=normalized)
                rec = res.fft(signal_ndim, normalized=normalized)
                self.assertEqual(x, rec, 1e-8, 'ifft and fft')

        def _test_real(sizes, signal_ndim, prepro_fn=lambda x: x):
            x = prepro_fn(torch.randn(*sizes, device=device))
            signal_sizes = x.size()[-signal_ndim:]
            for normalized, onesided in product((True, False), repeat=2):
                res = x.rfft(signal_ndim, normalized=normalized, onesided=onesided)
                if not onesided:  # check Hermitian symmetry
                    def test_one_sample(res, test_num=10):
                        idxs_per_dim = [torch.LongTensor(test_num).random_(s).tolist() for s in signal_sizes]
                        for idx in zip(*idxs_per_dim):
                            reflected_idx = tuple((s - i) % s for i, s in zip(idx, res.size()))
                            idx_val = res.__getitem__(idx)
                            reflected_val = res.__getitem__(reflected_idx)
                            self.assertEqual(idx_val[0], reflected_val[0], 'rfft hermitian symmetry on real part')
                            self.assertEqual(idx_val[1], -reflected_val[1], 'rfft hermitian symmetry on imaginary part')
                    if len(sizes) == signal_ndim:
                        test_one_sample(res)
                    else:
                        output_non_batch_shape = res.size()[-(signal_ndim + 1):]
                        flatten_batch_res = res.view(-1, *output_non_batch_shape)
                        nb = flatten_batch_res.size(0)
                        test_idxs = torch.LongTensor(min(nb, 4)).random_(nb)
                        for test_idx in test_idxs.tolist():
                            test_one_sample(flatten_batch_res[test_idx])
                    # compare against the C2C fft of the complexified signal
                    xc = torch.stack([x, torch.zeros_like(x)], -1)
                    xc_res = xc.fft(signal_ndim, normalized=normalized)
                    self.assertEqual(res, xc_res)
                rec = res.irfft(signal_ndim, normalized=normalized,
                                onesided=onesided, signal_sizes=signal_sizes)
                self.assertEqual(x, rec, 1e-8, 'rfft and irfft')
                if not onesided:  # check that we can use C2C ifft
                    rec = res.ifft(signal_ndim, normalized=normalized)
                    self.assertEqual(x, rec.select(-1, 0), 1e-8, 'twosided rfft and ifft real')
                    self.assertEqual(rec.select(-1, 1).data.abs().mean(), 0, 1e-8, 'twosided rfft and ifft imaginary')

        # contiguous case
        _test_real((100,), 1)
        _test_real((10, 1, 10, 100), 1)
        _test_real((100, 100), 2)
        _test_real((2, 2, 5, 80, 60), 2)
        _test_real((50, 40, 70), 3)
        _test_real((30, 1, 50, 25, 20), 3)

        _test_complex((100, 2), 1)
        _test_complex((100, 100, 2), 1)
        _test_complex((100, 100, 2), 2)
        _test_complex((1, 20, 80, 60, 2), 2)
        _test_complex((50, 40, 70, 2), 3)
        _test_complex((6, 5, 50, 25, 20, 2), 3)

        # non-contiguous case
        _test_real((165,), 1, lambda x: x.narrow(0, 25, 100))
        _test_real((100, 100, 3), 1, lambda x: x[:, :, 0])
        _test_real((100, 100), 2, lambda x: x.t())
        _test_real((20, 100, 10, 10), 2, lambda x: x.view(20, 100, 100)[:, :60])
        _test_real((65, 80, 115), 3, lambda x: x[10:60, 13:53, 10:80])
        _test_real((30, 20, 50, 25), 3, lambda x: x.transpose(1, 2).transpose(2, 3))

        _test_complex((2, 100), 1, lambda x: x.t())
        _test_complex((100, 2), 1, lambda x: x.expand(100, 100, 2))
        _test_complex((300, 200, 3), 2, lambda x: x[:100, :100, 1:])
        _test_complex((20, 90, 110, 2), 2, lambda x: x[:, 5:85].narrow(2, 5, 100))
        _test_complex((40, 60, 3, 80, 2), 3, lambda x: x.transpose(2, 0).select(0, 2)[5:55, :, 10:])
        _test_complex((30, 55, 50, 22, 2), 3, lambda x: x[:, 3:53, 15:40, 1:21])

        # non-contiguous with strided last dimension
        _test_complex((50,), 1, lambda x: x.as_strided([5, 5, 2], [3, 2, 1]))
        _test_complex((50,), 1, lambda x: x.as_strided([5, 5, 2], [4, 2, 2]))
        _test_complex((50,), 1, lambda x: x.as_strided([5, 5, 2], [4, 3, 1]))
        _test_complex((50,), 2, lambda x: x.as_strided([5, 5, 2], [3, 3, 1]))
        _test_complex((50,), 2, lambda x: x.as_strided([5, 5, 2], [4, 2, 2]))
        _test_complex((50,), 2, lambda x: x.as_strided([5, 5, 2], [4, 3, 1]))

    @unittest.skipIf(not TEST_MKL, "PyTorch is built without MKL support")
    def test_fft_ifft_rfft_irfft(self):
        self._test_fft_ifft_rfft_irfft(self)
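
    # Added note (not part of the original suite; assumes standard torch.rfft
    # semantics): with onesided=True an n-point real signal keeps only
    # floor(n/2) + 1 frequency bins in the last signal dimension, e.g. a
    # (100,) input yields a (51, 2) output of (real, imag) pairs; irfft then
    # needs signal_sizes to disambiguate even from odd original lengths.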
5754 def _test_stft(self, device='cpu'):
5755 if not TEST_LIBROSA:
5756 raise unittest.SkipTest(
'librosa not found')
5758 def librosa_stft(x, n_fft, hop_length, win_length, window, center):
5760 window = np.ones(n_fft
if win_length
is None else win_length)
5762 window = window.cpu().numpy()
5763 input_1d = x.dim() == 1
5768 ri = librosa.stft(xi.cpu().numpy(), n_fft, hop_length, win_length, window, center=center)
5769 result.append(torch.from_numpy(np.stack([ri.real, ri.imag], -1)))
5770 result = torch.stack(result, 0)
5775 def _test(sizes, n_fft, hop_length=None, win_length=None, win_sizes=None,
5776 center=
True, expected_error=
None):
5777 x = torch.randn(*sizes, device=device)
5778 if win_sizes
is not None:
5779 window = torch.randn(*win_sizes, device=device)
5782 if expected_error
is None:
5783 result = x.stft(n_fft, hop_length, win_length, window, center=center)
5784 ref_result = librosa_stft(x, n_fft, hop_length, win_length, window, center)
5785 self.assertEqual(result, ref_result, 7e-6,
'stft comparison against librosa')
5787 self.assertRaises(expected_error,
5788 lambda: x.stft(n_fft, hop_length, win_length, window, center=center))
        for center in [True, False]:
            _test((10,), 7, center=center)
            _test((10, 4000), 1024, center=center)

            _test((10,), 7, 2, center=center)
            _test((10, 4000), 1024, 512, center=center)

            _test((10,), 7, 2, win_sizes=(7,), center=center)
            _test((10, 4000), 1024, 512, win_sizes=(1024,), center=center)

            _test((10,), 7, 2, win_length=5, center=center)
            _test((10, 4000), 1024, 512, win_length=100, center=center)
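
        # Invalid argument combinations (batched 3D input, n_fft larger than
        # the un-padded signal, negative n_fft, win_length > n_fft, window
        # sizes that do not match) should all raise.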
        _test((10, 4, 2), 1, 1, expected_error=RuntimeError)
        _test((10,), 11, 1, center=False, expected_error=RuntimeError)
        _test((10,), -1, 1, expected_error=RuntimeError)
        _test((10,), 3, win_length=5, expected_error=RuntimeError)
        _test((10,), 5, 4, win_sizes=(11,), expected_error=RuntimeError)
        _test((10,), 5, 4, win_sizes=(1, 1), expected_error=RuntimeError)

    def test_stft(self):
        self._test_stft(self)

    @unittest.skip("Not implemented yet")
    def test_conv2(self):
        x = torch.rand(math.floor(torch.uniform(50, 100)), math.floor(torch.uniform(50, 100)))
        k = torch.rand(math.floor(torch.uniform(10, 20)), math.floor(torch.uniform(10, 20)))
        imvc = torch.conv2(x, k)
        imvc2 = torch.conv2(x, k, 'V')
        imfc = torch.conv2(x, k, 'F')
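
        # Cross-correlating with a flipped copy of the kernel should match
        # convolving with the original kernel; the loop below builds the
        # flipped copy ki through the raw storages.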
        ki = k.clone()
        ks = k.storage()
        kis = ki.storage()
        for i in range(ks.size() - 1, 0, -1):
            kis[ks.size() - i + 1] = ks[i]
        imvx = torch.xcorr2(x, ki)
        imvx2 = torch.xcorr2(x, ki, 'V')
        imfx = torch.xcorr2(x, ki, 'F')

        self.assertEqual(imvc, imvc2, 0, 'torch.conv2')
        self.assertEqual(imvc, imvx, 0, 'torch.conv2')
        self.assertEqual(imvc, imvx2, 0, 'torch.conv2')
        self.assertEqual(imfc, imfx, 0, 'torch.conv2')
        self.assertLessEqual(math.abs(x.dot(x) - torch.xcorr2(x, x)[0][0]), 1e-10,
                             'torch.conv2')

        xx = torch.Tensor(2, x.size(1), x.size(2))
        xx[1].copy_(x)
        xx[2].copy_(x)
        kk = torch.Tensor(2, k.size(1), k.size(2))
        kk[1].copy_(k)
        kk[2].copy_(k)

        immvc = torch.conv2(xx, kk)
        immvc2 = torch.conv2(xx, kk, 'V')
        immfc = torch.conv2(xx, kk, 'F')

        self.assertEqual(immvc[0], immvc[1], 0, 'torch.conv2')
        self.assertEqual(immvc[0], imvc, 0, 'torch.conv2')
        self.assertEqual(immvc2[0], imvc2, 0, 'torch.conv2')
        self.assertEqual(immfc[0], immfc[1], 0, 'torch.conv2')
        self.assertEqual(immfc[0], imfc, 0, 'torch.conv2')

    @unittest.skip("Not implemented yet")
    def test_conv3(self):
        x = torch.rand(math.floor(torch.uniform(20, 40)),
                       math.floor(torch.uniform(20, 40)),
                       math.floor(torch.uniform(20, 40)))
        k = torch.rand(math.floor(torch.uniform(5, 10)),
                       math.floor(torch.uniform(5, 10)),
                       math.floor(torch.uniform(5, 10)))
        imvc = torch.conv3(x, k)
        imvc2 = torch.conv3(x, k, 'V')
        imfc = torch.conv3(x, k, 'F')
        ki = k.clone()
        ks = k.storage()
        kis = ki.storage()
        for i in range(ks.size() - 1, 0, -1):
            kis[ks.size() - i + 1] = ks[i]
        imvx = torch.xcorr3(x, ki)
        imvx2 = torch.xcorr3(x, ki, 'V')
        imfx = torch.xcorr3(x, ki, 'F')

        self.assertEqual(imvc, imvc2, 0, 'torch.conv3')
        self.assertEqual(imvc, imvx, 0, 'torch.conv3')
        self.assertEqual(imvc, imvx2, 0, 'torch.conv3')
        self.assertEqual(imfc, imfx, 0, 'torch.conv3')
        self.assertLessEqual(math.abs(x.dot(x) - torch.xcorr3(x, x)[0][0][0]), 4e-10,
                             'torch.conv3')

        xx = torch.Tensor(2, x.size(1), x.size(2), x.size(3))
        xx[1].copy_(x)
        xx[2].copy_(x)
        kk = torch.Tensor(2, k.size(1), k.size(2), k.size(3))
        kk[1].copy_(k)
        kk[2].copy_(k)

        immvc = torch.conv3(xx, kk)
        immvc2 = torch.conv3(xx, kk, 'V')
        immfc = torch.conv3(xx, kk, 'F')

        self.assertEqual(immvc[0], immvc[1], 0, 'torch.conv3')
        self.assertEqual(immvc[0], imvc, 0, 'torch.conv3')
        self.assertEqual(immvc2[0], imvc2, 0, 'torch.conv3')
        self.assertEqual(immfc[0], immfc[1], 0, 'torch.conv3')
        self.assertEqual(immfc[0], imfc, 0, 'torch.conv3')

    @unittest.skip("Not implemented yet")
    def _test_conv_corr_eq(self, fn, fn_2_to_3):
        ix = math.floor(random.randint(20, 40))
        iy = math.floor(random.randint(20, 40))
        iz = math.floor(random.randint(20, 40))
        kx = math.floor(random.randint(5, 10))
        ky = math.floor(random.randint(5, 10))
        kz = math.floor(random.randint(5, 10))

        x = torch.rand(ix, iy, iz)
        k = torch.rand(kx, ky, kz)

        o3 = fn(x, k)
        o32 = torch.zeros(o3.size())
        fn_2_to_3(x, k, o3, o32)
        self.assertEqual(o3, o32)
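
    # The next four (skipped) tests check that a 3D convolution/correlation
    # can be rebuilt from 2D slices: fn computes the 3D result directly,
    # while the reference accumulates torch.xcorr2/torch.conv2 results
    # slice by slice into o32.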
    @unittest.skip("Not implemented yet")
    def test_xcorr3_xcorr2_eq(self):
        def reference(x, k, o3, o32):
            for i in range(o3.size(1)):
                for j in range(k.size(1)):
                    o32[i].add(torch.xcorr2(x[i + j - 1], k[j]))
        self._test_conv_corr_eq(torch.xcorr3, reference)

    @unittest.skip("Not implemented yet")
    def test_xcorr3_xcorr2_eq_full(self):
        def reference(x, k, o3, o32):
            for i in range(x.size(1)):
                for j in range(k.size(1)):
                    o32[i].add(torch.xcorr2(x[i], k[k.size(1) - j + 1], 'F'))
        self._test_conv_corr_eq(lambda x, k: torch.xcorr3(x, k, 'F'), reference)

    @unittest.skip("Not implemented yet")
    def test_conv3_conv2_eq_valid(self):
        def reference(x, k, o3, o32):
            for i in range(o3.size(1)):
                for j in range(k.size(1)):
                    o32[i].add(torch.conv2(x[i + j - 1], k[k.size(1) - j + 1]))
        self._test_conv_corr_eq(torch.conv3, reference)

    @unittest.skip("Not implemented yet")
    def test_fconv3_fconv2_eq(self):
        def reference(x, k, o3, o32):
            for i in range(o3.size(1)):
                for j in range(k.size(1)):
                    o32[i + j - 1].add(torch.conv2(x[i], k[j], 'F'))
        self._test_conv_corr_eq(lambda x, k: torch.conv3(x, k, 'F'), reference)
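
    # In test_logical, (x > 1) and (x < 1) partition exactly the elements
    # with x != 1, so their combined count must equal ne()'s; adding the
    # x == 1 elements then accounts for every element of x.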
    def test_logical(self):
        x = torch.rand(100, 100) * 2 - 1

        xgt = torch.gt(x, 1)
        xlt = torch.lt(x, 1)

        xeq = torch.eq(x, 1)
        xne = torch.ne(x, 1)

        neqs = xgt + xlt
        all = neqs + xeq
        self.assertEqual(neqs.long().sum(), xne.long().sum(), 0)
        self.assertEqual(x.nelement(), all.long().sum())

    def test_isfinite(self):
        x = torch.Tensor([1, inf, 2, -inf, nan, -10])
        self.assertEqual(torch.isfinite(x), torch.ByteTensor([1, 0, 1, 0, 0, 1]))

    def test_isfinite_int(self):
        x = torch.tensor([1, 2, 3], dtype=torch.int32)
        self.assertEqual(torch.isfinite(x), torch.ByteTensor([1, 1, 1]))

    @staticmethod
    def _test_isinf(self, cast):
        t1 = cast(torch.Tensor([1, inf, 2, -inf, nan]))
        t2 = cast(torch.ByteTensor([1, 2, 3]))
        t3 = cast(torch.CharTensor([1, 2, 3]))
        t4 = cast(torch.ShortTensor([1, 2, 3]))
        t5 = cast(torch.IntTensor([1, 2, 3]))
        t6 = cast(torch.LongTensor([1, 2, 3]))
        self.assertEqual(torch.isinf(t1), cast(torch.ByteTensor([0, 1, 0, 1, 0])))
        self.assertEqual(torch.isinf(t2), cast(torch.ByteTensor([0, 0, 0])))
        self.assertEqual(torch.isinf(t3), cast(torch.ByteTensor([0, 0, 0])))
        self.assertEqual(torch.isinf(t4), cast(torch.ByteTensor([0, 0, 0])))
        self.assertEqual(torch.isinf(t5), cast(torch.ByteTensor([0, 0, 0])))
        self.assertEqual(torch.isinf(t6), cast(torch.ByteTensor([0, 0, 0])))

    def test_isinf(self):
        self._test_isinf(self, lambda t: t)

    def test_isnan(self):
        x = torch.Tensor([1, nan, 2])
        self.assertEqual(torch.isnan(x), torch.ByteTensor([0, 1, 0]))

    def test_RNGState(self):
        state = torch.get_rng_state()
        stateCloned = state.clone()
        before = torch.rand(1000)

        self.assertEqual(state.ne(stateCloned).long().sum(), 0, 0)

        torch.set_rng_state(state)
        after = torch.rand(1000)
        self.assertEqual(before, after, 0)
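
    # test_RNGStateAliasing seeds a fresh torch.Generator from the global
    # RNG state, burns through the global stream, and checks that the forked
    # generator still reproduces the original draw.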
    def test_RNGStateAliasing(self):
        gen = torch.Generator()
        gen.set_state(torch.get_rng_state())
        self.assertEqual(gen.get_state(), torch.get_rng_state())

        target_value = torch.rand(1000)
        _ = torch.rand(100000)
        forked_value = torch.rand(1000, generator=gen)
        self.assertEqual(target_value, forked_value, 0, "RNG has not forked correctly.")

    def test_RNG_after_pickle(self):
        torch.manual_seed(100)
        before = torch.rand(10)

        torch.manual_seed(100)
        buf = io.BytesIO()
        tensor = torch.Tensor([1, 2, 3])
        ForkingPickler(buf, pickle.HIGHEST_PROTOCOL).dump(tensor)
        after = torch.rand(10)

        self.assertEqual(before, after, 0)
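
    # Box-Muller produces normal variates in pairs, so drawing an odd number
    # of samples leaves one cached value in the generator; the test below
    # checks that get_rng_state/set_rng_state round-trip that cached value.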
    def test_boxMullerState(self):
        torch.manual_seed(123)
        odd_number = 101
        seeded = torch.randn(odd_number)
        state = torch.get_rng_state()
        midstream = torch.randn(odd_number)
        torch.set_rng_state(state)
        repeat_midstream = torch.randn(odd_number)
        torch.manual_seed(123)
        reseeded = torch.randn(odd_number)
        self.assertEqual(midstream, repeat_midstream, 0,
                         'get_rng_state/set_rng_state not generating same sequence of normally distributed numbers')
        self.assertEqual(seeded, reseeded, 0,
                         'repeated calls to manual_seed not generating same sequence of normally distributed numbers')

    def test_manual_seed(self):
        rng_state = torch.get_rng_state()
        torch.manual_seed(2)
        x = torch.randn(100)
        self.assertEqual(torch.initial_seed(), 2)
        torch.manual_seed(2)
        y = torch.randn(100)
        self.assertEqual(x, y)
        torch.set_rng_state(rng_state)
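
    # Cholesky round-trip: A = x @ x.T is symmetric positive definite
    # (almost surely, given random x), so A should be rebuilt exactly from
    # its factor: A == C @ C.T for the lower factor, A == U.T @ U for the
    # upper one.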
    @staticmethod
    def _test_cholesky(self, cast):
        x = cast(torch.rand(10, 10) + 1e-1)
        A = torch.mm(x, x.t())

        # default case
        C = torch.cholesky(A)
        B = torch.mm(C, C.t())
        self.assertEqual(A, B, 1e-14)

        # test upper triangular
        U = torch.cholesky(A, True)
        B = torch.mm(U.t(), U)
        self.assertEqual(A, B, 1e-14,
                         'cholesky (upper) did not allow rebuilding the original matrix')

        # test lower triangular
        L = torch.cholesky(A, False)
        B = torch.mm(L, L.t())
        self.assertEqual(A, B, 1e-14,
                         'cholesky (lower) did not allow rebuilding the original matrix')

    @skipIfNoLapack
    def test_cholesky(self):
        self._test_cholesky(self, lambda t: t)

    @staticmethod
    def _test_cholesky_batched(self, cast):
        from common_utils import random_symmetric_pd_matrix

        def cholesky_test_helper(n, batch_dims, cast, upper):
            A = cast(random_symmetric_pd_matrix(n, *batch_dims))
            cholesky_exp = torch.stack([m.cholesky(upper=upper)
                                        for m in A.reshape(-1, n, n)])
            cholesky_exp = cholesky_exp.reshape_as(A)
            self.assertEqual(cholesky_exp, torch.cholesky(A, upper=upper))

        for upper, batchsize in product([True, False], [(3,), (3, 4), (2, 3, 4)]):
            cholesky_test_helper(3, batchsize, cast, upper)

    @skipIfNoLapack
    def test_cholesky_batched(self):
        self._test_cholesky_batched(self, lambda t: t)
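
    # cholesky_solve solves a @ x = b given a Cholesky factor of a rather
    # than a itself; each variant below is validated through the residual
    # b - a @ x.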
    @staticmethod
    def _test_cholesky_solve(self, cast):
        a = torch.Tensor(((6.80, -2.11, 5.66, 5.97, 8.23),
                          (-6.05, -3.30, 5.36, -4.44, 1.08),
                          (-0.45, 2.58, -2.70, 0.27, 9.04),
                          (8.32, 2.71, 4.35, -7.17, 2.14),
                          (-9.67, -5.14, -7.26, 6.08, -6.87))).t()
        b = torch.Tensor(((4.02, 6.19, -8.22, -7.57, -3.03),
                          (-1.56, 4.00, -8.67, 1.75, 2.86),
                          (9.81, -4.09, -4.57, -8.61, 8.99))).t()

        # make sure 'a' is symmetric positive definite
        a = torch.mm(a, a.t())
        a, b = cast(a), cast(b)

        # upper Triangular Test
        U = torch.cholesky(a, True)
        x = torch.cholesky_solve(b, U, True)
        self.assertLessEqual(b.dist(torch.mm(a, x)), 1e-12)

        # lower Triangular Test
        L = torch.cholesky(a, False)
        x = torch.cholesky_solve(b, L, False)
        self.assertLessEqual(b.dist(torch.mm(a, x)), 1e-12)

        # default arg Test
        L_def = torch.cholesky(a)
        x_def = torch.cholesky_solve(b, L_def)
        self.assertLessEqual(b.dist(torch.mm(a, x_def)), 1e-12)

    @skipIfNoLapack
    def test_cholesky_solve(self):
        self._test_cholesky_solve(self, lambda t: t)

    @staticmethod
    def _test_cholesky_solve_batched(self, cast):
        from common_utils import random_symmetric_pd_matrix

        def cholesky_solve_test_helper(A_dims, b_dims, cast, upper):
            A = cast(random_symmetric_pd_matrix(*A_dims))
            L = torch.cholesky(A, upper)
            b = cast(torch.randn(*b_dims))
            return A, L, b

        for upper in [True, False]:
            # test against the un-batched call
            A, L, b = cholesky_solve_test_helper((5, 1), (1, 5, 10), cast, upper)
            x_exp = torch.cholesky_solve(b.squeeze(0), L.squeeze(0), upper=upper)
            x = torch.cholesky_solve(b, L, upper=upper)
            self.assertEqual(x, x_exp.unsqueeze(0))

            # test against cholesky_solve in a loop: four batches
            A, L, b = cholesky_solve_test_helper((5, 4), (4, 5, 10), cast, upper)
            x_exp_list = []
            for i in range(4):
                x_exp = torch.cholesky_solve(b[i], L[i], upper=upper)
                x_exp_list.append(x_exp)
            x_exp = torch.stack(x_exp_list)

            x = torch.cholesky_solve(b, L, upper=upper)
            self.assertEqual(x, x_exp)

            # basic correctness test
            A, L, b = cholesky_solve_test_helper((5, 3), (3, 5, 10), cast, upper)
            x = torch.cholesky_solve(b, L, upper)
            self.assertLessEqual(b.dist(torch.matmul(A, x)), 1e-12)

            # test non-contiguous inputs
            if not TEST_NUMPY:
                return
            from numpy.linalg import solve
            A = random_symmetric_pd_matrix(2, 2)
            b = torch.randn(2, 2, 2)
            x_exp = torch.Tensor(solve(A.permute(0, 2, 1).numpy(), b.permute(2, 1, 0).numpy()))
            A = cast(A).permute(0, 2, 1)
            b = cast(b).permute(2, 1, 0)
            assert not A.is_contiguous() and not b.is_contiguous(), "contiguous inputs"
            L = torch.cholesky(A, upper)
            x = torch.cholesky_solve(b, L, upper=upper)
            self.assertEqual(x, cast(x_exp))

    @skipIfNoLapack
    def test_cholesky_solve_batched(self):
        self._test_cholesky_solve_batched(self, lambda t: t)

    @staticmethod
    def _test_cholesky_solve_batched_dims(self, cast):
        if not TEST_NUMPY:
            return

        from numpy.linalg import solve
        from common_utils import random_symmetric_pd_matrix

        def run_test(A_dims, b_dims, cast, upper):
            A = random_symmetric_pd_matrix(*A_dims)
            b = torch.randn(*b_dims)
            x_exp = torch.Tensor(solve(A.numpy(), b.numpy()))
            A, b = cast(A), cast(b)
            L = torch.cholesky(A, upper)
            x = torch.cholesky_solve(b, L, upper=upper)
            self.assertEqual(x, cast(x_exp))

        for upper in [True, False]:
            # test against numpy.linalg.solve: A and b broadcast against
            # each other over the leading batch dimensions
            run_test((4, 2, 1, 3), (2, 1, 3, 4, 6), cast, upper)
            run_test((4, 2, 1, 3), (4, 6), cast, upper)
            run_test((4,), (2, 1, 3, 4, 2), cast, upper)
            run_test((4, 1, 3, 1), (2, 1, 3, 4, 5), cast, upper)

    @skipIfNoLapack
    def test_cholesky_solve_batched_dims(self):
        self._test_cholesky_solve_batched_dims(self, lambda t: t)
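
    # torch.potri computes the inverse of a symmetric positive definite
    # matrix from its Cholesky factor; every uplo variant should agree with
    # torch.inverse applied to the full matrix.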
    @skipIfNoLapack
    def test_potri(self):
        a = torch.Tensor(((6.80, -2.11, 5.66, 5.97, 8.23),
                          (-6.05, -3.30, 5.36, -4.44, 1.08),
                          (-0.45, 2.58, -2.70, 0.27, 9.04),
                          (8.32, 2.71, 4.35, -7.17, 2.14),
                          (-9.67, -5.14, -7.26, 6.08, -6.87))).t()

        # make sure 'a' is symmetric positive definite
        a = torch.mm(a, a.t())

        # compute inverse directly
        inv0 = torch.inverse(a)

        # default case
        chol = torch.cholesky(a)
        inv1 = torch.potri(chol, False)
        self.assertLessEqual(inv0.dist(inv1), 1e-12)

        # upper triangular test
        chol = torch.cholesky(a, True)
        inv1 = torch.potri(chol, True)
        self.assertLessEqual(inv0.dist(inv1), 1e-12)

        # lower triangular test
        chol = torch.cholesky(a, False)
        inv1 = torch.potri(chol, False)
        self.assertLessEqual(inv0.dist(inv1), 1e-12)
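
    # torch.pstrf is the pivoted Cholesky factorization: it returns a factor
    # u together with a pivot vector piv, so the reconstruction from u must
    # be compared against the row/column-permuted input.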
    @skipIfNoLapack
    def test_pstrf(self):
        def checkPsdCholesky(a, uplo, inplace):
            if inplace:
                u = torch.empty_like(a)
                piv = a.new(a.size(0)).int()
                kwargs = {'out': (u, piv)}
            else:
                kwargs = {}
            args = [a]

            if uplo is not None:
                args += [uplo]

            u, piv = torch.pstrf(*args, **kwargs)

            if uplo is False:
                a_reconstructed = torch.mm(u, u.t())
            else:
                a_reconstructed = torch.mm(u.t(), u)

            piv = piv.long()
            a_permuted = a.index_select(0, piv).index_select(1, piv)
            self.assertEqual(a_permuted, a_reconstructed, 1e-14)

        dimensions = ((5, 1), (5, 3), (5, 5), (10, 10))
        for dim in dimensions:
            m = torch.Tensor(*dim).uniform_()
            a = torch.mm(m, m.t())
            # add a small number to the diagonal to make the matrix numerically positive semidefinite
            for i in range(m.size(0)):
                a[i][i] = a[i][i] + 1e-7
            for inplace in (True, False):
                for uplo in (None, True, False):
                    checkPsdCholesky(a, uplo, inplace)

    def test_numel(self):
        b = torch.ByteTensor(3, 100, 100)
        self.assertEqual(b.nelement(), 3 * 100 * 100)
        self.assertEqual(b.numel(), 3 * 100 * 100)

    def _consecutive(self, size, start=1):
        sequence = torch.ones(int(torch.Tensor(size).prod(0))).cumsum(0)
        sequence.add_(start - 1)
        return sequence.resize_(*size)

    @staticmethod
    def _test_index(self, conv_fn):

        def consec(size, start=1):
            sequence = torch.ones(int(torch.Tensor(size).prod(0))).cumsum(0)
            sequence.add_(start - 1)
            return sequence.view(*size)

        reference = conv_fn(consec((3, 3, 3)))

        # empty tensor indexing
        self.assertEqual(reference[conv_fn(torch.LongTensor())], reference.new(0, 3, 3))

        self.assertEqual(reference[0], consec((3, 3)), 0)
        self.assertEqual(reference[1], consec((3, 3), 10), 0)
        self.assertEqual(reference[2], consec((3, 3), 19), 0)
        self.assertEqual(reference[0, 1], consec((3,), 4), 0)
        self.assertEqual(reference[0:2], consec((2, 3, 3)), 0)
        self.assertEqual(reference[2, 2, 2], 27, 0)
        self.assertEqual(reference[:], consec((3, 3, 3)), 0)

        # indexing with Ellipsis
        self.assertEqual(reference[..., 2], torch.Tensor([[3, 6, 9],
                                                          [12, 15, 18],
                                                          [21, 24, 27]]), 0)
        self.assertEqual(reference[0, ..., 2], torch.Tensor([3, 6, 9]), 0)
        self.assertEqual(reference[..., 2], reference[:, :, 2], 0)
        self.assertEqual(reference[0, ..., 2], reference[0, :, 2], 0)
        self.assertEqual(reference[0, 2, ...], reference[0, 2], 0)
        self.assertEqual(reference[..., 2, 2, 2], 27, 0)
        self.assertEqual(reference[2, ..., 2, 2], 27, 0)
        self.assertEqual(reference[2, 2, ..., 2], 27, 0)
        self.assertEqual(reference[2, 2, 2, ...], 27, 0)
        self.assertEqual(reference[...], reference, 0)

        reference_5d = conv_fn(consec((3, 3, 3, 3, 3)))
        self.assertEqual(reference_5d[..., 1, 0], reference_5d[:, :, :, 1, 0], 0)
        self.assertEqual(reference_5d[2, ..., 1, 0], reference_5d[2, :, :, 1, 0], 0)
        self.assertEqual(reference_5d[2, 1, 0, ..., 1], reference_5d[2, 1, 0, :, 1], 0)
        self.assertEqual(reference_5d[...], reference_5d, 0)

        # LongTensor indexing
        reference = conv_fn(consec((5, 5, 5)))
        idx = conv_fn(torch.LongTensor([2, 4]))
        self.assertEqual(reference[idx], torch.stack([reference[2], reference[4]]))

        # None indexing inserts size-1 dimensions
        self.assertEqual(reference[2, None], reference[2].unsqueeze(0))
        self.assertEqual(reference[2, None, None], reference[2].unsqueeze(0).unsqueeze(0))
        self.assertEqual(reference[2:4, None], reference[2:4].unsqueeze(1))
        self.assertEqual(reference[None, 2, None, None],
                         reference.unsqueeze(0)[:, 2].unsqueeze(0).unsqueeze(0))
        self.assertEqual(reference[None, 2:5, None, None],
                         reference.unsqueeze(0)[:, 2:5].unsqueeze(2).unsqueeze(2))

        # indexing 0-length slices
        self.assertEqual(torch.empty(0, 5, 5), reference[slice(0)])
        self.assertEqual(torch.empty(0, 5), reference[slice(0), 2])
        self.assertEqual(torch.empty(0, 5), reference[2, slice(0)])
        self.assertEqual(torch.tensor([]), reference[2, 1:1, 2])

        # indexing with step
        reference = consec((10, 10, 10))
        self.assertEqual(reference[1:5:2], torch.stack([reference[1], reference[3]], 0))
        self.assertEqual(reference[1:6:2], torch.stack([reference[1], reference[3], reference[5]], 0))
        self.assertEqual(reference[1:9:4], torch.stack([reference[1], reference[5]], 0))
        self.assertEqual(reference[2:4, 1:5:2], torch.stack([reference[2:4, 1], reference[2:4, 3]], 1))
        self.assertEqual(reference[3, 1:6:2], torch.stack([reference[3, 1], reference[3, 3], reference[3, 5]], 0))
        self.assertEqual(reference[None, 2, 1:9:4], torch.stack([reference[2, 1], reference[2, 5]], 0).unsqueeze(0))
        self.assertEqual(reference[:, 2, 1:6:2],
                         torch.stack([reference[:, 2, 1], reference[:, 2, 3], reference[:, 2, 5]], 1))

        lst = [list(range(i, i + 10)) for i in range(0, 100, 10)]
        tensor = conv_fn(torch.DoubleTensor(lst))
        for _i in range(100):
            idx1_start = random.randrange(10)
            idx1_end = idx1_start + random.randrange(1, 10 - idx1_start + 1)
            idx1_step = random.randrange(1, 8)
            idx1 = slice(idx1_start, idx1_end, idx1_step)
            if random.randrange(2) == 0:
                idx2_start = random.randrange(10)
                idx2_end = idx2_start + random.randrange(1, 10 - idx2_start + 1)
                idx2_step = random.randrange(1, 8)
                idx2 = slice(idx2_start, idx2_end, idx2_step)
                lst_indexed = [l[idx2] for l in lst[idx1]]
                tensor_indexed = tensor[idx1, idx2]
            else:
                lst_indexed = lst[idx1]
                tensor_indexed = tensor[idx1]
            self.assertEqual(torch.DoubleTensor(lst_indexed), tensor_indexed)

        self.assertRaises(ValueError, lambda: reference[1:9:0])
        self.assertRaises(ValueError, lambda: reference[1:9:-1])

        self.assertRaises(IndexError, lambda: reference[1, 1, 1, 1])
        self.assertRaises(IndexError, lambda: reference[1, 1, 1, 1:1])
        self.assertRaises(IndexError, lambda: reference[3, 3, 3, 3, 3, 3, 3, 3])

        self.assertRaises(IndexError, lambda: reference[0.0])
        self.assertRaises(TypeError, lambda: reference[0.0:2.0])
        self.assertRaises(IndexError, lambda: reference[0.0, 0.0:2.0])
        self.assertRaises(IndexError, lambda: reference[0.0, :, 0.0:2.0])
        self.assertRaises(IndexError, lambda: reference[0.0, ..., 0.0:2.0])
        self.assertRaises(IndexError, lambda: reference[0.0, :, 0.0])

        def delitem():
            del reference[0]

        self.assertRaises(TypeError, delitem)

    def test_index(self):
        self._test_index(self, lambda x: x)

    @staticmethod
    def _test_advancedindex(self, conv_fn):
        # integer array indexing, compared against a numpy reference below

        def consec(size, start=1):
            numel = reduce(lambda x, y: x * y, size, 1)
            sequence = torch.ones(numel).cumsum(0)
            sequence.add_(start - 1)
            return sequence.view(*size)

        # pick a random indexer type: LongTensor, Python list, or tuple
        def ri(indices):
            choice = random.randint(0, 2)
            if choice == 0:
                return conv_fn(torch.LongTensor(indices))
            elif choice == 1:
                return list(indices)
            else:
                return tuple(indices)

        def validate_indexing(x):
            self.assertEqual(x[[0]], consec((1,)))
            self.assertEqual(x[ri([0]), ], consec((1,)))
            self.assertEqual(x[ri([3]), ], consec((1,), 4))
            self.assertEqual(x[[2, 3, 4]], consec((3,), 3))
            self.assertEqual(x[ri([2, 3, 4]), ], consec((3,), 3))
            self.assertEqual(x[ri([0, 2, 4]), ], torch.Tensor([1, 3, 5]))

        def validate_setting(x):
            dtype = x.type()
            x[[0]] = -2
            self.assertEqual(x[[0]], torch.Tensor([-2]).type(dtype))
            x[ri([0]), ] = -1
            self.assertEqual(x[ri([0]), ], torch.Tensor([-1]).type(dtype))
            x[[2, 3, 4]] = 4
            self.assertEqual(x[[2, 3, 4]], torch.Tensor([4, 4, 4]).type(dtype))
            x[ri([2, 3, 4]), ] = 3
            self.assertEqual(x[ri([2, 3, 4]), ], torch.Tensor([3, 3, 3]).type(dtype))
            x[ri([0, 2, 4]), ] = conv_fn(torch.Tensor([5, 4, 3])).type(dtype)
            self.assertEqual(x[ri([0, 2, 4]), ], torch.Tensor([5, 4, 3]).type(dtype))

        reference = conv_fn(consec((10,)))
        validate_indexing(reference)
        validate_indexing(reference.type(torch.half))

        # setting values
        validate_setting(reference)
        validate_setting(reference.type(torch.half))

        # tensor with stride != 1
        reference = conv_fn(consec((10,)))
        strided = conv_fn(torch.Tensor())
        strided.set_(reference.storage(), storage_offset=0,
                     size=torch.Size([4]), stride=[2])

        self.assertEqual(strided[[0]], torch.Tensor([1]))
        self.assertEqual(strided[ri([0]), ], torch.Tensor([1]))
        self.assertEqual(strided[ri([3]), ], torch.Tensor([7]))
        self.assertEqual(strided[[1, 2]], torch.Tensor([3, 5]))
        self.assertEqual(strided[ri([1, 2]), ], torch.Tensor([3, 5]))
        self.assertEqual(strided[ri([[2, 1], [0, 3]]), ],
                         torch.Tensor([[5, 3], [1, 7]]))

        strided = conv_fn(torch.Tensor())
        strided.set_(reference.storage(), storage_offset=4,
                     size=torch.Size([2]), stride=[4])
        self.assertEqual(strided[[0]], torch.Tensor([5]))
        self.assertEqual(strided[ri([0]), ], torch.Tensor([5]))
        self.assertEqual(strided[ri([1]), ], torch.Tensor([9]))
        self.assertEqual(strided[[0, 1]], torch.Tensor([5, 9]))
        self.assertEqual(strided[ri([0, 1]), ], torch.Tensor([5, 9]))
        self.assertEqual(strided[ri([[0, 1], [1, 0]]), ],
                         torch.Tensor([[5, 9], [9, 5]]))

        # 2D tensor
        reference = conv_fn(consec((3, 2)))
        self.assertEqual(reference[ri([0, 1, 2]), ri([0])], torch.Tensor([1, 3, 5]))
        self.assertEqual(reference[ri([0, 1, 2]), ri([1])], torch.Tensor([2, 4, 6]))
        self.assertEqual(reference[ri([0]), ri([0])], consec((1,)))
        self.assertEqual(reference[ri([2]), ri([1])], consec((1,), 6))
        self.assertEqual(reference[[ri([0, 0]), ri([0, 1])]], torch.Tensor([1, 2]))
        self.assertEqual(reference[[ri([0, 1, 1, 0, 2]), ri([1])]],
                         torch.Tensor([2, 4, 4, 2, 6]))
        self.assertEqual(reference[[ri([0, 0, 1, 1]), ri([0, 1, 0, 0])]],
                         torch.Tensor([1, 2, 3, 3]))

        rows = ri([[0, 0],
                   [1, 2]])
        columns = [0]
        self.assertEqual(reference[rows, columns], torch.Tensor([[1, 1],
                                                                 [3, 5]]))

        rows = ri([[0, 0],
                   [1, 2]])
        columns = ri([1, 0])
        self.assertEqual(reference[rows, columns], torch.Tensor([[2, 1],
                                                                 [4, 5]]))

        rows = ri([[0, 0],
                   [1, 2]])
        columns = ri([[0, 1],
                      [1, 0]])
        self.assertEqual(reference[rows, columns], torch.Tensor([[1, 2],
                                                                 [4, 5]]))

        # setting values
        reference[ri([0]), ri([1])] = -1
        self.assertEqual(reference[ri([0]), ri([1])], torch.Tensor([-1]))
        reference[ri([0, 1, 2]), ri([0])] = conv_fn(torch.Tensor([-1, 2, -4]))
        self.assertEqual(reference[ri([0, 1, 2]), ri([0])],
                         torch.Tensor([-1, 2, -4]))
        reference[rows, columns] = conv_fn(torch.Tensor([[4, 6], [2, 3]]))
        self.assertEqual(reference[rows, columns],
                         torch.Tensor([[4, 6], [2, 3]]))

        # verify the same works on a transposed (non-contiguous) tensor;
        # after t_(), reference is:
        #   0  4  8
        #   1  5  9
        #   2  6 10
        #   3  7 11
        reference = conv_fn(torch.Tensor([[0, 1, 2, 3],
                                          [4, 5, 6, 7],
                                          [8, 9, 10, 11]])).t_()

        self.assertEqual(reference[ri([0, 1, 2]), ri([0])],
                         torch.Tensor([0, 1, 2]))
        self.assertEqual(reference[ri([0, 1, 2]), ri([1])],
                         torch.Tensor([4, 5, 6]))
        self.assertEqual(reference[ri([0]), ri([0])], torch.Tensor([0]))
        self.assertEqual(reference[ri([2]), ri([1])], torch.Tensor([6]))
        self.assertEqual(reference[[ri([0, 0]), ri([0, 1])]], torch.Tensor([0, 4]))
        self.assertEqual(reference[[ri([0, 1, 1, 0, 3]), ri([1])]],
                         torch.Tensor([4, 5, 5, 4, 7]))
        self.assertEqual(reference[[ri([0, 0, 1, 1]), ri([0, 1, 0, 0])]],
                         torch.Tensor([0, 4, 1, 1]))

        rows = ri([[0, 0],
                   [1, 2]])
        columns = [0]
        self.assertEqual(reference[rows, columns], torch.Tensor([[0, 0],
                                                                 [1, 2]]))

        rows = ri([[0, 0],
                   [1, 2]])
        columns = ri([1, 0])
        self.assertEqual(reference[rows, columns], torch.Tensor([[4, 0],
                                                                 [5, 2]]))

        rows = ri([[0, 0],
                   [1, 3]])
        columns = ri([[0, 1],
                      [1, 2]])
        self.assertEqual(reference[rows, columns], torch.Tensor([[0, 4],
                                                                 [5, 11]]))

        # setting values
        reference[ri([0]), ri([1])] = -1
        self.assertEqual(reference[ri([0]), ri([1])], torch.Tensor([-1]))
        reference[ri([0, 1, 2]), ri([0])] = conv_fn(torch.Tensor([-1, 2, -4]))
        self.assertEqual(reference[ri([0, 1, 2]), ri([0])],
                         torch.Tensor([-1, 2, -4]))
        reference[rows, columns] = conv_fn(torch.Tensor([[4, 6], [2, 3]]))
        self.assertEqual(reference[rows, columns],
                         torch.Tensor([[4, 6], [2, 3]]))

        # stride != 1 over the same storage; strided is:
        #   1  3  5  7
        #   9 11 13 15
        reference = conv_fn(torch.arange(0., 24).view(3, 8))
        strided = conv_fn(torch.Tensor())
        strided.set_(reference.storage(), 1, size=torch.Size([2, 4]),
                     stride=[8, 2])

        self.assertEqual(strided[ri([0, 1]), ri([0])], torch.Tensor([1, 9]))
        self.assertEqual(strided[ri([0, 1]), ri([1])], torch.Tensor([3, 11]))
        self.assertEqual(strided[ri([0]), ri([0])], torch.Tensor([1]))
        self.assertEqual(strided[ri([1]), ri([3])], torch.Tensor([15]))
        self.assertEqual(strided[[ri([0, 0]), ri([0, 3])]], torch.Tensor([1, 7]))
        self.assertEqual(strided[[ri([1]), ri([0, 1, 1, 0, 3])]],
                         torch.Tensor([9, 11, 11, 9, 15]))
        self.assertEqual(strided[[ri([0, 0, 1, 1]), ri([0, 1, 0, 0])]],
                         torch.Tensor([1, 3, 9, 9]))

        rows = ri([[0, 0],
                   [1, 1]])
        columns = [0]
        self.assertEqual(strided[rows, columns], torch.Tensor([[1, 1],
                                                               [9, 9]]))

        rows = ri([[0, 1],
                   [1, 0]])
        columns = ri([1, 2])
        self.assertEqual(strided[rows, columns], torch.Tensor([[3, 13],
                                                               [11, 5]]))

        rows = ri([[0, 0],
                   [1, 1]])
        columns = ri([[0, 1],
                      [1, 2]])
        self.assertEqual(strided[rows, columns], torch.Tensor([[1, 3],
                                                               [11, 13]]))

        # setting values on a strided view; strided is:
        #   10 11
        #   17 18
        reference = conv_fn(torch.arange(0., 24).view(3, 8))
        strided = conv_fn(torch.Tensor())
        strided.set_(reference.storage(), 10, size=torch.Size([2, 2]),
                     stride=[7, 1])
        self.assertEqual(strided[ri([0]), ri([1])], torch.Tensor([11]))
        strided[ri([0]), ri([1])] = -1
        self.assertEqual(strided[ri([0]), ri([1])], torch.Tensor([-1]))

        reference = conv_fn(torch.arange(0., 24).view(3, 8))
        strided = conv_fn(torch.Tensor())
        strided.set_(reference.storage(), 10, size=torch.Size([2, 2]),
                     stride=[7, 1])
        self.assertEqual(strided[ri([0, 1]), ri([1, 0])], torch.Tensor([11, 17]))
        strided[ri([0, 1]), ri([1, 0])] = conv_fn(torch.Tensor([-1, 2]))
        self.assertEqual(strided[ri([0, 1]), ri([1, 0])], torch.Tensor([-1, 2]))

        reference = conv_fn(torch.arange(0., 24).view(3, 8))
        strided = conv_fn(torch.Tensor())
        strided.set_(reference.storage(), 10, size=torch.Size([2, 2]),
                     stride=[7, 1])

        rows = ri([[0],
                   [1]])
        columns = ri([[0, 1],
                      [0, 1]])
        self.assertEqual(strided[rows, columns],
                         torch.Tensor([[10, 11], [17, 18]]))
        strided[rows, columns] = conv_fn(torch.Tensor([[4, 6], [2, 3]]))
        self.assertEqual(strided[rows, columns],
                         torch.Tensor([[4, 6], [2, 3]]))

        # fewer indices than dimensions, and ellipsis
        reference = conv_fn(consec((3, 2)))
        self.assertEqual(reference[ri([0, 2]), ], torch.Tensor([[1, 2], [5, 6]]))
        self.assertEqual(reference[ri([1]), ...], torch.Tensor([[3, 4]]))
        self.assertEqual(reference[..., ri([1])], torch.Tensor([[2], [4], [6]]))

        # too many indices must fail
        with self.assertRaises(IndexError):
            reference[ri([1]), ri([0, 2]), ri([3])]

        # out-of-bound indices must fail
        reference = conv_fn(torch.empty(10))
        # can't test cuda because it is a device assert
        if not reference.is_cuda:
            for err_idx in (10, -11):
                with self.assertRaisesRegex(IndexError, r'out of'):
                    reference[err_idx]
                with self.assertRaisesRegex(IndexError, r'out of'):
                    reference[conv_fn(torch.LongTensor([err_idx]))]
                with self.assertRaisesRegex(IndexError, r'out of'):
                    reference[[err_idx]]

        def tensor_indices_to_np(tensor, indices):
            # convert the torch Tensor to a numpy array
            if tensor.is_cuda:
                tensor = tensor.cpu()
            npt = tensor.numpy()

            # convert indices: LongTensors become nested lists, everything
            # else passes through unchanged
            idxs = tuple(i.tolist() if isinstance(i, torch.LongTensor) else i
                         for i in indices)

            return npt, idxs

        def get_numpy(tensor, indices):
            npt, idxs = tensor_indices_to_np(tensor, indices)

            # index and return as a torch Tensor
            return torch.Tensor(npt[idxs])

        def set_numpy(tensor, indices, value):
            if not isinstance(value, int):
                if value.is_cuda:
                    value = value.cpu()
                value = value.numpy()

            npt, idxs = tensor_indices_to_np(tensor, indices)
            npt[idxs] = value
            return npt

        def assert_get_eq(tensor, indexer):
            self.assertEqual(tensor[indexer],
                             conv_fn(get_numpy(tensor, indexer)))

        def assert_set_eq(tensor, indexer, val):
            pyt = tensor.clone()
            numt = tensor.clone()
            pyt[indexer] = val
            numt = conv_fn(torch.Tensor(set_numpy(numt, indexer, val)))
            self.assertEqual(pyt, numt)

        def get_set_tensor(indexed, indexer):
            set_size = indexed[indexer].size()
            set_count = indexed[indexer].numel()
            set_tensor = conv_fn(torch.randperm(set_count).view(set_size).double())
            return set_tensor
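
        # Every indexer below is applied both to the torch tensor and, via
        # the numpy helpers above, to an equivalent ndarray; gets and sets
        # must agree element for element.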
        reference = conv_fn(torch.arange(0., 20).view(4, 5))

        indices_to_test = [
            # grab the second, fourth columns
            [slice(None), [1, 3]],

            # first, third rows
            [[0, 2], slice(None)],

            # weird shape
            [slice(None), [[0, 1],
                           [2, 3]]],

            # negative column
            [slice(None), [-1]],
        ]

        # only test dupes on gets
        get_indices_to_test = indices_to_test + [[slice(None), [0, 1, 1, 2, 2]]]

        for indexer in get_indices_to_test:
            assert_get_eq(reference, indexer)

        for indexer in indices_to_test:
            assert_set_eq(reference, indexer, 44)
            assert_set_eq(reference,
                          indexer,
                          get_set_tensor(reference, indexer))

        reference = conv_fn(torch.arange(0., 160).view(4, 8, 5))

        indices_to_test = [
            [slice(None), slice(None), [0, 3, 4]],
            [slice(None), [2, 4, 5, 7], slice(None)],
            [[2, 3], slice(None), slice(None)],
            [slice(None), [0, 2, 3], [1, 3, 4]],
            [slice(None), [0], [1, 2, 4]],
            [slice(None), [0, 1, 3], [4]],
            [slice(None), [[0, 1], [1, 0]], [[2, 3]]],
            [slice(None), [[0, 1], [2, 3]], [[0]]],
            [slice(None), [[5, 6]], [[0, 3], [4, 4]]],
            [[0, 2, 3], [1, 3, 4], slice(None)],
            [[0], [1, 2, 4], slice(None)],
            [[0, 1, 3], [4], slice(None)],
            [[[0, 1], [1, 0]], [[2, 1], [3, 5]], slice(None)],
            [[[0, 1], [1, 0]], [[2, 3]], slice(None)],
            [[[0, 1], [2, 3]], [[0]], slice(None)],
            [[[2, 1]], [[0, 3], [4, 4]], slice(None)],
            [[[2]], [[0, 3], [4, 1]], slice(None)],

            # fewer dims, and ellipsis
            [[0, 2], slice(None)],
            [[0, 2], slice(None), Ellipsis],
            [[0, 2], Ellipsis, slice(None)],
            [[0, 2], [1, 3], Ellipsis],
            [Ellipsis, [1, 3], [2, 3]],
            [Ellipsis, [2, 3, 4]],
            [Ellipsis, slice(None), [2, 3, 4]],
            [slice(None), Ellipsis, [2, 3, 4]],

            # ellipsis counts for nothing
            [Ellipsis, slice(None), slice(None), [0, 3, 4]],
            [slice(None), Ellipsis, slice(None), [0, 3, 4]],
            [slice(None), slice(None), Ellipsis, [0, 3, 4]],
            [slice(None), slice(None), [0, 3, 4], Ellipsis],
            [Ellipsis, [[0, 1], [1, 0]], [[2, 1], [3, 5]], slice(None)],
            [[[0, 1], [1, 0]], [[2, 1], [3, 5]], Ellipsis, slice(None)],
            [[[0, 1], [1, 0]], [[2, 1], [3, 5]], slice(None), Ellipsis],
        ]

        for indexer in indices_to_test:
            assert_get_eq(reference, indexer)
            assert_set_eq(reference, indexer, 212)
            assert_set_eq(reference,
                          indexer,
                          get_set_tensor(reference, indexer))

        reference = conv_fn(torch.arange(0., 1296).view(3, 9, 8, 6))

        indices_to_test = [
            [slice(None), slice(None), slice(None), [0, 3, 4]],
            [slice(None), slice(None), [2, 4, 5, 7], slice(None)],
            [slice(None), [2, 3], slice(None), slice(None)],
            [[1, 2], slice(None), slice(None), slice(None)],
            [slice(None), slice(None), [0, 2, 3], [1, 3, 4]],
            [slice(None), slice(None), [0], [1, 2, 4]],
            [slice(None), slice(None), [0, 1, 3], [4]],
            [slice(None), slice(None), [[0, 1], [1, 0]], [[2, 3]]],
            [slice(None), slice(None), [[0, 1], [2, 3]], [[0]]],
            [slice(None), slice(None), [[5, 6]], [[0, 3], [4, 4]]],
            [slice(None), [0, 2, 3], [1, 3, 4], slice(None)],
            [slice(None), [0], [1, 2, 4], slice(None)],
            [slice(None), [0, 1, 3], [4], slice(None)],
            [slice(None), [[0, 1], [3, 4]], [[2, 3], [0, 1]], slice(None)],
            [slice(None), [[0, 1], [3, 4]], [[2, 3]], slice(None)],
            [slice(None), [[0, 1], [3, 2]], [[0]], slice(None)],
            [slice(None), [[2, 1]], [[0, 3], [6, 4]], slice(None)],
            [slice(None), [[2]], [[0, 3], [4, 2]], slice(None)],
            [[0, 1, 2], [1, 3, 4], slice(None), slice(None)],
            [[0], [1, 2, 4], slice(None), slice(None)],
            [[0, 1, 2], [4], slice(None), slice(None)],
            [[[0, 1], [0, 2]], [[2, 4], [1, 5]], slice(None), slice(None)],
            [[[0, 1], [1, 2]], [[2, 0]], slice(None), slice(None)],
            [[[2, 2]], [[0, 3], [4, 5]], slice(None), slice(None)],
            [[[2]], [[0, 3], [4, 5]], slice(None), slice(None)],
            [slice(None), [3, 4, 6], [0, 2, 3], [1, 3, 4]],
            [slice(None), [2, 3, 4], [1, 3, 4], [4]],
            [slice(None), [0, 1, 3], [4], [1, 3, 4]],
            [slice(None), [6], [0, 2, 3], [1, 3, 4]],
            [slice(None), [2, 3, 5], [3], [4]],
            [slice(None), [0], [4], [1, 3, 4]],
            [slice(None), [6], [0, 2, 3], [1]],
            [slice(None), [[0, 3], [3, 6]], [[0, 1], [1, 3]], [[5, 3], [1, 2]]],
            [[2, 2, 1], [0, 2, 3], [1, 3, 4], slice(None)],
            [[2, 0, 1], [1, 2, 3], [4], slice(None)],
            [[0, 1, 2], [4], [1, 3, 4], slice(None)],
            [[0], [0, 2, 3], [1, 3, 4], slice(None)],
            [[0, 2, 1], [3], [4], slice(None)],
            [[0], [4], [1, 3, 4], slice(None)],
            [[1], [0, 2, 3], [1], slice(None)],
            [[[1, 2], [1, 2]], [[0, 1], [2, 3]], [[2, 3], [3, 5]], slice(None)],

            # fewer dims, and ellipsis
            [Ellipsis, [0, 3, 4]],
            [Ellipsis, slice(None), [0, 3, 4]],
            [Ellipsis, slice(None), slice(None), [0, 3, 4]],
            [slice(None), Ellipsis, [0, 3, 4]],
            [slice(None), slice(None), Ellipsis, [0, 3, 4]],
            [slice(None), [0, 2, 3], [1, 3, 4]],
            [slice(None), [0, 2, 3], [1, 3, 4], Ellipsis],
            [Ellipsis, [0, 2, 3], [1, 3, 4], slice(None)],
            [[0], [1, 2, 4], slice(None)],
            [[0], [1, 2, 4], Ellipsis],
            [[0], [1, 2, 4], Ellipsis, slice(None)],
            [[0, 2, 1], [3], [4]],
            [[0, 2, 1], [3], [4], slice(None)],
            [[0, 2, 1], [3], [4], Ellipsis],
            [Ellipsis, [0, 2, 1], [3], [4]],
        ]

        for indexer in indices_to_test:
            assert_get_eq(reference, indexer)
            assert_set_eq(reference, indexer, 1333)
            assert_set_eq(reference,
                          indexer,
                          get_set_tensor(reference, indexer))

        indices_to_test += [
            [slice(None), slice(None), [[0, 1], [1, 0]], [[2, 3], [3, 0]]],
            [slice(None), slice(None), [[2]], [[0, 3], [4, 4]]],
        ]
        for indexer in indices_to_test:
            assert_get_eq(reference, indexer)
            assert_set_eq(reference, indexer, 1333)

    def test_advancedindex(self):
        self._test_advancedindex(self, lambda x: x)

    @staticmethod
    def _test_advancedindex_big(self, conv_fn):
        reference = conv_fn(torch.arange(0, 123344).int())

        self.assertEqual(reference[[0, 123, 44488, 68807, 123343], ],
                         torch.LongTensor([0, 123, 44488, 68807, 123343]))

    def test_advancedindex_big(self):
        self._test_advancedindex_big(self, lambda x: x)

    @unittest.skipIf(not TEST_NUMPY, "Numpy not found")
    def test_newaxis_numpy_comparison(self):
        def run_test(tensor, *idx):
            npt = tensor.numpy()
            self.assertEqual(tensor[idx], npt[idx])

        # 1D tensor tests
        x = torch.arange(0, 10)
        cases = [
            [None],
            [None, None],
            [Ellipsis, None],
            [None, Ellipsis],
            [2, None],
            [None, 2],
            [Ellipsis, None, 2],
            [Ellipsis, 2, None],
            [2, Ellipsis, None],
            [2, None, Ellipsis],
            [None, 2, Ellipsis],
            [None, Ellipsis, 2],
        ]
        for case in cases:
            run_test(x, *case)

        # 2D tensor tests
        x = torch.arange(0, 12).view(3, 4)
        cases = [
            [Ellipsis, None, None],
            [None, Ellipsis, None],
            [None, None, Ellipsis],
            [2, None, Ellipsis],
            [2, Ellipsis, None],
            [None, 2, Ellipsis],
            [Ellipsis, 2, None],
            [Ellipsis, None, 2],
            [None, Ellipsis, 2],
            [1, 2, Ellipsis, None],
            [1, Ellipsis, 2, None],
            [Ellipsis, 1, None, 2],
            [Ellipsis, 1, 2, None],
            [1, None, 2, Ellipsis],
            [None, 1, Ellipsis, 2],
            [None, 1, 2, Ellipsis],
        ]
        for case in cases:
            run_test(x, *case)

    def test_newindex(self):
        reference = self._consecutive((3, 3, 3))

        def checkPartialAssign(index):
            reference = torch.zeros(3, 3, 3)
            reference[index] = self._consecutive((3, 3, 3))[index]
            self.assertEqual(reference[index], self._consecutive((3, 3, 3))[index], 0)
            reference[index] = 0
            self.assertEqual(reference, torch.zeros(3, 3, 3), 0)

        checkPartialAssign(0)
        checkPartialAssign(1)
        checkPartialAssign(2)
        checkPartialAssign((0, 1))
        checkPartialAssign((1, 2))
        checkPartialAssign((0, 2))
        checkPartialAssign(torch.LongTensor((0, 2)))

        with self.assertRaises(IndexError):
            reference[1, 1, 1, 1] = 1
        with self.assertRaises(IndexError):
            reference[1, 1, 1, (1, 1)] = 1
        with self.assertRaises(IndexError):
            reference[3, 3, 3, 3, 3, 3, 3, 3] = 1
        with self.assertRaises(IndexError):
            reference[0.0] = 1
        with self.assertRaises(TypeError):
            reference[0.0:2.0] = 1
        with self.assertRaises(IndexError):
            reference[0.0, 0.0:2.0] = 1
        with self.assertRaises(IndexError):
            reference[0.0, :, 0.0:2.0] = 1
        with self.assertRaises(IndexError):
            reference[0.0, ..., 0.0:2.0] = 1
        with self.assertRaises(IndexError):
            reference[0.0, :, 0.0] = 1

    def test_index_copy(self):
        num_copy, num_dest = 3, 20
        dest = torch.randn(num_dest, 4, 5)
        src = torch.randn(num_copy, 4, 5)
        idx = torch.randperm(num_dest).narrow(0, 0, num_copy)
        dest2 = dest.clone()
        dest.index_copy_(0, idx, src)
        for i in range(idx.size(0)):
            dest2[idx[i]] = src[i]
        self.assertEqual(dest, dest2, 0)

        dest = torch.randn(num_dest)
        src = torch.randn(num_copy)
        idx = torch.randperm(num_dest).narrow(0, 0, num_copy)
        dest2 = dest.clone()
        dest.index_copy_(0, idx, src)
        for i in range(idx.size(0)):
            dest2[idx[i]] = src[i]
        self.assertEqual(dest, dest2, 0)
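
    # index_add_ differs from index_copy_ only in accumulating into the
    # destination instead of overwriting it, hence the += in the reference
    # loop below.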
    def test_index_add(self):
        num_copy, num_dest = 3, 3
        dest = torch.randn(num_dest, 4, 5)
        src = torch.randn(num_copy, 4, 5)
        idx = torch.randperm(num_dest).narrow(0, 0, num_copy)
        dest2 = dest.clone()
        dest.index_add_(0, idx, src)
        for i in range(idx.size(0)):
            dest2[idx[i]] += src[i]
        self.assertEqual(dest, dest2)

        dest = torch.randn(num_dest)
        src = torch.randn(num_copy)
        idx = torch.randperm(num_dest).narrow(0, 0, num_copy)
        dest2 = dest.clone()
        dest.index_add_(0, idx, src)
        for i in range(idx.size(0)):
            dest2[idx[i]] = dest2[idx[i]] + src[i]
        self.assertEqual(dest, dest2)

    def test_index_select(self):
        src = torch.randn(3, 4, 5)
        # indices may contain duplicates
        idx = torch.LongTensor([2, 1, 0, 1, 2])
        dest = torch.index_select(src, 0, idx)
        self.assertEqual(dest.shape, (5, 4, 5))
        for i in range(idx.size(0)):
            self.assertEqual(dest[i], src[idx[i]])

        # check that 'out' is filled and shares storage with the result
        out = torch.randn(5 * 4 * 5)
        dest = torch.index_select(src, 0, idx, out=out.view(5, 4, 5))
        self.assertEqual(dest.shape, (5, 4, 5))
        for i in range(idx.size(0)):
            self.assertEqual(dest[i], src[idx[i]])
        self.assertEqual(out, dest.view(-1))

    def test_t(self):
        # 0D tensors
        x = torch.randn(())
        self.assertEqual(x, x.t())
        x = x.to_sparse()
        self.assertEqual(x, x.t())

        # 1D tensors
        x = torch.arange(4)
        self.assertEqual(x, x.t())
        x = x.to_sparse()
        self.assertEqual(x, x.t())

        # 2D tensors
        x = torch.rand((2, 2))
        self.assertEqual(x.t(), x.transpose(0, 1))
        x = x.to_sparse()
        self.assertEqual(x.t(), x.transpose(0, 1))

        # 3D tensors must fail
        x = torch.rand((2, 2, 2))
        with self.assertRaisesRegex(RuntimeError,
                                    'expects a tensor with <= 2 dimensions, but self is 3D'):
            x.t()
        x = x.to_sparse()
        with self.assertRaisesRegex(RuntimeError,
                                    'expects a tensor with <= 2 sparse and 0 dense dimensions'):
            x.t()

    def test_take(self):
        def check(src, idx):
            expected = src.contiguous().view(-1).index_select(
                0, idx.contiguous().view(-1)).view_as(idx)
            actual = src.take(idx)
            self.assertEqual(actual.size(), idx.size())
            self.assertEqual(expected, actual)

        src = torch.randn(2, 3, 5)
        idx = torch.LongTensor([[0, 2], [3, 4]])
        check(src, idx)
        check(src.transpose(1, 2), idx)
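
    # torch.take treats its input as if it were flattened to 1D, so the
    # reference in check() above is index_select on the contiguous
    # flattened view.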
    def test_take_empty(self):
        devices = ['cpu'] if not torch.cuda.is_available() else ['cpu', 'cuda']
        for device in devices:
            for input_shape in [(0,), (0, 1, 2, 0), (1, 2, 3)]:
                for indices_shape in [(0,), (0, 1, 2, 0)]:
                    input = torch.empty(input_shape, device=device)
                    indices = torch.empty(indices_shape, dtype=torch.int64, device=device)
                    self.assertEqual(indices, torch.take(input, indices))

    def test_put_(self):
        def check(dst, idx, value):
            expected = dst.clone().view(-1).index_copy_(
                0, idx.contiguous().view(-1), value.contiguous().view(-1))
            expected = expected.view_as(dst)
            dst.put_(idx, value)
            self.assertEqual(expected, dst)

        dst = torch.randn(2, 3, 5)
        idx = torch.LongTensor([[0, 2], [3, 4]])
        values = torch.randn(2, 2)
        check(dst, idx, values)
        check(dst.transpose(1, 2), idx, values)

    def test_put_accumulate(self):
        dst = torch.ones(2, 2)
        idx = torch.LongTensor([[0, 1], [0, 1]])
        src = torch.Tensor([1, 2, 3, 4])
        dst.put_(idx, src, accumulate=True)
        self.assertEqual(dst.tolist(), [[5, 7], [1, 1]])
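
    # With accumulate=True, repeated indices add up: flat positions 0 and 1
    # each receive two of the four source values, giving 1 + 1 + 3 = 5 and
    # 1 + 2 + 4 = 7 above.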
    def test_put_empty(self):
        devices = ['cpu'] if not torch.cuda.is_available() else ['cpu', 'cuda']
        for device in devices:
            for dst_shape in [(0,), (0, 1, 2, 0), (1, 2, 3)]:
                for indices_shape in [(0,), (0, 1, 2, 0)]:
                    for accumulate in [False, True]:
                        dst = torch.randn(dst_shape, device=device)
                        indices = torch.empty(indices_shape, dtype=torch.int64, device=device)
                        src = torch.randn(indices_shape, device=device)
                        self.assertEqual(dst, dst.put_(indices, src, accumulate=accumulate))

    def _fill_indices(self, idx, dim, dim_size, elems_per_row, m, n, o):
        for i in range(1 if dim == 0 else m):
            for j in range(1 if dim == 1 else n):
                for k in range(1 if dim == 2 else o):
                    ii = [i, j, k]
                    ii[dim] = slice(0, idx.size(dim) + 1)
                    idx[tuple(ii)] = torch.randperm(dim_size)[0:elems_per_row]

    def test_flatten(self):
        src = torch.randn(5, 5, 5, 5)
        flat = src.flatten(0, -1)
        self.assertEqual(flat.shape, torch.Size([625]))
        self.assertEqual(src.view(-1), flat.view(-1))

        flat = src.flatten(0, 2)
        self.assertEqual(flat.shape, torch.Size([125, 5]))
        self.assertEqual(src.view(-1), flat.view(-1))

        flat = src.flatten(0, 1)
        self.assertEqual(flat.shape, torch.Size([25, 5, 5]))
        self.assertEqual(src.view(-1), flat.view(-1))

        flat = src.flatten(1, 2)
        self.assertEqual(flat.shape, torch.Size([5, 25, 5]))
        self.assertEqual(src.view(-1), flat.view(-1))

        flat = src.flatten(2, 3)
        self.assertEqual(flat.shape, torch.Size([5, 5, 25]))
        self.assertEqual(src.view(-1), flat.view(-1))

        flat = src.flatten(-2, -1)
        self.assertEqual(flat.shape, torch.Size([5, 5, 25]))
        self.assertEqual(src.view(-1), flat.view(-1))

        flat = src.flatten(2, 2)
        self.assertEqual(flat, src)

        # out of bounds index
        with self.assertRaisesRegex(IndexError, 'Dimension out of range'):
            src.flatten(5, 10)

        # invalid start and end
        with self.assertRaisesRegex(RuntimeError, 'start_dim cannot come after end_dim'):
            src.flatten(2, 0)

    @staticmethod
    def _test_gather(self, cast, test_bounds=True):
        m, n, o = random.randint(10, 20), random.randint(10, 20), random.randint(10, 20)
        elems_per_row = random.randint(1, 10)
        dim = random.randrange(3)

        src = torch.randn(m, n, o)
        idx_size = [m, n, o]
        idx_size[dim] = elems_per_row
        idx = torch.LongTensor().resize_(*idx_size)
        _TestTorchMixin._fill_indices(self, idx, dim, src.size(dim), elems_per_row, m, n, o)

        src = cast(src)
        idx = cast(idx)

        actual = torch.gather(src, dim, idx)
        expected = cast(torch.Tensor().resize_(*idx_size))
        for i in range(idx_size[0]):
            for j in range(idx_size[1]):
                for k in range(idx_size[2]):
                    ii = [i, j, k]
                    ii[dim] = idx[i, j, k]
                    expected[i, j, k] = src[tuple(ii)]
        self.assertEqual(actual, expected, 0)

        if test_bounds:
            idx[0][0][0] = 23
            self.assertRaises(RuntimeError, lambda: torch.gather(src, dim, idx))

        src = cast(torch.randn(3, 4, 5))
        expected, idx = src.max(2, True)
        expected = cast(expected)
        idx = cast(idx)
        actual = torch.gather(src, 2, idx)
        self.assertEqual(actual, expected, 0)
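
    # gather along dim picks expected[i][j][k] = src[i][j][idx[i][j][k]]
    # when dim == 2 (and analogously for dims 0 and 1); the nested loop in
    # _test_gather spells this out as the reference.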
    def test_gather(self):
        self._test_gather(self, lambda t: t)

    @staticmethod
    def _test_scatter_base(self, cast, method, is_scalar=False, test_bounds=True):
        m, n, o = random.randint(10, 20), random.randint(10, 20), random.randint(10, 20)
        elems_per_row = random.randint(1, 10)