22 from torch._six import inf, nan, string_classes, istuple
23 from itertools
import product, combinations, combinations_with_replacement
24 from functools
import reduce
25 from torch
import multiprocessing
as mp
26 from common_methods_invocations
import tri_tests_args, run_additional_tri_tests, \
27 _compare_trilu_indices
28 from common_utils
import TestCase, iter_indices, TEST_NUMPY, TEST_SCIPY, TEST_MKL, \
29 TEST_LIBROSA, run_tests, download_file, skipIfNoLapack, suppress_warnings, \
30 IS_WINDOWS, PY3, NO_MULTIPROCESSING_SPAWN, skipIfRocm, do_test_dtypes, do_test_empty_full, \
31 IS_SANDCASTLE, load_tests, brute_pdist, brute_cdist
32 from multiprocessing.reduction
import ForkingPickler
# ``load_tests`` is imported from common_utils. unittest calls a module-level
# ``load_tests`` function, when one exists, while loading this module's tests.
# Re-binding the name here also references the import so it is not reported as unused.
# NOTE(review): presumably the helper filters/shards tests — confirm in common_utils.
load_tests = load_tests
42 from scipy
import signal
49 can_retrieve_source =
True 50 with warnings.catch_warnings(record=
True)
as warns:
51 with tempfile.NamedTemporaryFile()
as checkpoint:
52 x = torch.save(torch.nn.Module(), checkpoint)
54 if "Couldn't retrieve source code" in warn.message.args[0]:
55 can_retrieve_source =
False 60 def __init__(self, data, has_fileno=True, has_readinto=False):
72 def result(*args, **kwargs):
74 return fn(*args, **kwargs)
77 for attr
in [
'read',
'readline',
'seek',
'tell',
'write',
'flush']:
78 traced_fn = trace(getattr(self.
bytesio, attr), attr)
79 setattr(self, attr, traced_fn)
82 raise io.UnsupportedOperation(
'Not a real file')
84 def readinto_opt(self, view):
85 self.calls.add(
'readinto')
86 return self.bytesio.readinto(view)
88 def was_called(self, name):
89 return name
in self.
calls 96 def __exit__(self, *args):
102 class _TestTorchMixin(object):
103 def _check_sum_dim(tensors, dim):
104 for tensor
in tensors:
105 expected = tensor.numpy().sum(dim)
106 actual = tensor.sum(dim)
107 self.assertEqual(expected.shape, actual.shape)
108 if actual.dtype == torch.float:
109 self.assertTrue(np.allclose(expected, actual.numpy(), rtol=1e-03, atol=1e-05))
111 self.assertTrue(np.allclose(expected, actual.numpy()))
113 def _make_tensors(self, shape, val_range=(-100, 100), use_floating=
True, use_integral=
True):
114 float_types = [torch.double,
116 int_types = [torch.int64,
120 def make_contiguous(shape, dtype):
121 if dtype
in float_types:
122 val = torch.randn(shape, dtype=dtype)
123 val = val * ((val_range[1] - val_range[0]) / (math.pi * 2.0))
124 val = val + ((val_range[1] - val_range[0]) / 2.0)
125 val = torch.clamp(val, min=val_range[0], max=val_range[1])
127 result = torch.zeros(shape, dtype=dtype)
128 result.apply_(
lambda x: random.randint(val_range[0], val_range[1]))
131 def make_non_contiguous(shape, dtype):
132 contig = make_contiguous(shape, dtype)
133 non_contig = torch.empty(shape + (2, 2), dtype=dtype)[..., 0]
134 non_contig = non_contig.select(-1, -1)
135 non_contig.copy_(contig)
136 self.assertFalse(non_contig.is_contiguous())
139 def make_contiguous_slice(size, dtype):
140 contig = make_contiguous((1, size), dtype)
141 non_contig = contig[:1, 1:size - 1]
142 self.assertTrue(non_contig.is_contiguous())
150 tensors = {
"cont": [],
"noncont": [],
"slice": []}
152 tensors[
"cont"].append(make_contiguous(shape, dtype))
153 tensors[
"noncont"].append(make_non_contiguous(shape, dtype))
154 tensors[
"slice"].append(make_contiguous_slice(sum(list(shape)), dtype))
162 checked_types = (types.MethodType, types.FunctionType,
163 types.BuiltinFunctionType, types.BuiltinMethodType)
165 def test_namespace(ns, *skips):
166 if isinstance(ns, object):
167 ns_name = ns.__class__.__name__
169 ns_name = ns.__name__
172 if isinstance(r, string_classes):
173 skip_regexes.append(re.compile(
'^{}$'.format(re.escape(r))))
175 skip_regexes.append(r)
177 if name.startswith(
'_'):
179 var = getattr(ns, name)
180 if not isinstance(var, checked_types):
183 has_doc = doc
is not None and len(doc.strip()) > 0
184 full_name = ns_name +
'.' + name
185 if any(r.match(name)
for r
in skip_regexes):
186 self.assertFalse(has_doc,
187 'New docs have been added for {}, please remove ' 188 'it from the skipped list in TestTorch.test_doc'.format(full_name))
190 self.assertTrue(has_doc,
'{} is missing documentation'.format(full_name))
193 test_namespace(torch.randn(1),
196 re.compile(
'^clamp_(min|max)_?$'),
229 'sparse_resize_and_clear_',
232 test_namespace(
torch.nn.functional,
'assert_int_or_pair',
'bilinear',
'feature_alpha_dropout')
238 'torch.DoubleTensor': 1e-8,
239 'torch.FloatTensor': 1e-4,
241 for tname, _prec
in types.items():
242 v1 = torch.randn(100).type(tname)
243 v2 = torch.randn(100).type(tname)
244 res1 = torch.dot(v1, v2)
246 for i, j
in zip(v1, v2):
248 self.assertEqual(res1, res2)
249 out = torch.randn(()).type(tname)
250 torch.dot(v1, v2, out=out)
251 self.assertEqual(res1, out)
254 for tname, _prec
in types.items():
255 v1 = torch.randn(1).type(tname).expand(100)
256 v2 = torch.randn(100).type(tname)
257 res1 = torch.dot(v1, v2)
259 for i, j
in zip(v1, v2):
261 self.assertEqual(res1, res2)
262 out = torch.randn(()).type(tname)
263 torch.dot(v1, v2, out=out)
264 self.assertEqual(res1, out)
268 'torch.DoubleTensor': 1e-8,
269 'torch.FloatTensor': 1e-4,
271 for tname, _prec
in types.items():
272 v1 = torch.randn(100).type(tname)
273 v2 = torch.randn(100).type(tname)
274 res1 = torch.ger(v1, v2)
275 res2 = torch.zeros(100, 100).type(tname)
278 res2[i, j] = v1[i] * v2[j]
279 self.assertEqual(res1, res2)
282 for tname, _prec
in types.items():
283 v1 = torch.randn(1).type(tname).expand(100)
284 v2 = torch.randn(100).type(tname)
285 res1 = torch.ger(v1, v2)
286 res2 = torch.zeros(100, 100).type(tname)
289 res2[i, j] = v1[i] * v2[j]
290 self.assertEqual(res1, res2)
294 'torch.DoubleTensor': 1e-8,
295 'torch.FloatTensor': 1e-4,
298 def run_test(m, v1, v2, m_transform=lambda x: x):
299 m = m_transform(m.clone())
301 torch.addr(m, v1, v2, out=m)
302 for i
in range(m.size(0)):
303 for j
in range(m.size(1)):
304 ref[i, j] += v1[i] * v2[j]
305 self.assertEqual(m, ref)
307 for tname, _prec
in types.items():
308 for h, w
in [(100, 110), (1, 20), (200, 2)]:
309 m = torch.randn(h, w).type(tname)
310 v1 = torch.randn(h).type(tname)
311 v2 = torch.randn(w).type(tname)
314 run_test(m, v2, v1,
lambda x: x.transpose(0, 1))
316 v1 = torch.randn(1).type(tname).expand(h)
318 run_test(m, v2, v1,
lambda x: x.transpose(0, 1))
320 def test_addmv(self):
322 'torch.DoubleTensor': 1e-8,
323 'torch.FloatTensor': 1e-4,
325 for tname, _prec
in types.items():
326 t = torch.randn(10).type(tname)
327 m = torch.randn(10, 100).type(tname)
328 v = torch.randn(100).type(tname)
329 res1 = torch.addmv(t, m, v)
330 res2 = torch.zeros(10).type(tname)
334 res2[i] += m[i, j] * v[j]
335 self.assertEqual(res1, res2)
338 for tname, _prec
in types.items():
339 t = torch.randn(1).type(tname).expand(10)
340 m = torch.randn(10, 1).type(tname).expand(10, 100)
341 v = torch.randn(100).type(tname)
342 res1 = torch.addmv(t, m, v)
343 res2 = torch.zeros(10).type(tname)
347 res2[i] += m[i, j] * v[j]
348 self.assertEqual(res1, res2)
350 def test_addmm(self):
352 'torch.DoubleTensor': 1e-8,
353 'torch.FloatTensor': 1e-4,
355 for tname, _prec
in types.items():
356 M = torch.randn(10, 25).type(tname)
357 m1 = torch.randn(10, 50).type(tname)
358 m2 = torch.randn(50, 25).type(tname)
359 res1 = torch.addmm(M, m1, m2)
360 res2 = torch.zeros(10, 25).type(tname)
365 res2[i, j] += m1[i, k] * m2[k, j]
366 self.assertEqual(res1, res2)
369 for tname, _prec
in types.items():
370 M = torch.randn(10, 1).type(tname).expand(10, 25)
371 m1 = torch.randn(10, 1).type(tname).expand(10, 50)
372 m2 = torch.randn(50, 25).type(tname)
373 res1 = torch.addmm(M, m1, m2)
374 res2 = torch.zeros(10, 25).type(tname)
379 res2[i, j] += m1[i, k] * m2[k, j]
380 self.assertEqual(res1, res2)
382 def test_logical_any(self):
384 for device
in devices:
385 x = torch.zeros([2, 3, 400], dtype=torch.uint8, device=device)
392 torch.zeros([1, 3, 400], dtype=torch.uint8, device=device),
393 x.any(0, keepdim=
True))
396 torch.zeros([2, 1, 400], dtype=torch.uint8, device=device),
397 x.any(1, keepdim=
True))
400 torch.zeros([2, 3, 1], dtype=torch.uint8, device=device),
401 x.any(2, keepdim=
True))
410 y = torch.zeros([1, 3, 400], dtype=torch.uint8, device=device)
412 self.assertEqual(y, x.any(0, keepdim=
True))
414 y = torch.zeros([2, 1, 400], dtype=torch.uint8, device=device)
416 self.assertEqual(y, x.any(1, keepdim=
True))
418 y = torch.zeros([2, 3, 1], dtype=torch.uint8, device=device)
420 self.assertEqual(y, x.any(2, keepdim=
True))
422 def test_logical_all(self):
424 for device
in devices:
425 x = torch.ones([2, 3, 400], dtype=torch.uint8, device=device)
432 torch.ones([1, 3, 400], dtype=torch.uint8, device=device),
433 x.all(0, keepdim=
True))
436 torch.ones([2, 1, 400], dtype=torch.uint8, device=device),
437 x.all(1, keepdim=
True))
440 torch.ones([2, 3, 1], dtype=torch.uint8, device=device),
441 x.all(2, keepdim=
True))
450 y = torch.ones([1, 3, 400], dtype=torch.uint8, device=device)
452 self.assertEqual(y, x.all(0, keepdim=
True))
454 y = torch.ones([2, 1, 400], dtype=torch.uint8, device=device)
456 self.assertEqual(y, x.all(1, keepdim=
True))
458 y = torch.ones([2, 3, 1], dtype=torch.uint8, device=device)
460 self.assertEqual(y, x.all(2, keepdim=
True))
462 def test_allclose(self):
465 self.assertTrue(torch.allclose(x, y, rtol=0, atol=0.02))
466 self.assertTrue(torch.allclose(x, y, rtol=0.01, atol=0.0))
467 self.assertFalse(torch.allclose(x, y))
471 self.assertFalse(torch.allclose(x, y, rtol=1e-2))
472 self.assertTrue(torch.allclose(x, y, rtol=1e-2, equal_nan=
True))
473 self.assertFalse(torch.allclose(x, y, rtol=1e-3, equal_nan=
True))
475 self.assertTrue(torch.allclose(inf_t, inf_t))
476 self.assertTrue(torch.allclose(-inf_t, -inf_t))
477 self.assertFalse(torch.allclose(inf_t, -inf_t))
478 self.assertFalse(torch.allclose(inf_t,
torch.tensor([1e20])))
479 self.assertFalse(torch.allclose(-inf_t,
torch.tensor([-1e20])))
481 def test_linear_algebra_scalar_raises(self):
482 m = torch.randn(5, 5)
485 self.assertRaises(RuntimeError,
lambda: torch.mv(m, s))
486 self.assertRaises(RuntimeError,
lambda: torch.addmv(v, m, s))
487 self.assertRaises(RuntimeError,
lambda: torch.ger(v, s))
488 self.assertRaises(RuntimeError,
lambda: torch.ger(s, v))
489 self.assertRaises(RuntimeError,
lambda: torch.addr(m, v, s))
490 self.assertRaises(RuntimeError,
lambda: torch.addr(m, s, v))
492 def _test_math(self, torchfn, mathfn, input=None, test_expand=False):
495 input.append(list(range(-5, 5)))
496 input.append([0
for x
in range(-5, 5)])
497 input.append([x + 1e-6
for x
in range(-5, 5)])
499 input.append([x + 1e10
for x
in range(-5, 5)])
500 input.append([x - 1e10
for x
in range(-5, 5)])
501 input.append(torch.randn(10).tolist())
502 input.append((torch.randn(10) + 1e6).tolist())
503 input.append([math.pi * (x / 2)
for x
in range(-5, 5)])
505 def compare_reference(input, dtype):
507 res1 = torchfn(input.clone())
508 res2 = input.clone().apply_(mathfn)
512 compare_reference(input, torch.double)
513 compare_reference(input, torch.float)
515 def check_non_contiguous(shape, dtype):
516 contig = torch.randn(shape, dtype=dtype)
517 non_contig = torch.empty(shape + (2,), dtype=dtype)[..., 0]
518 non_contig.copy_(contig)
519 self.assertFalse(non_contig.is_contiguous())
520 self.assertEqual(torchfn(contig), torchfn(non_contig),
'non-contiguous')
523 check_non_contiguous((5, 7), torch.double)
524 check_non_contiguous((1024,), torch.double)
525 check_non_contiguous((5, 7), torch.float)
526 check_non_contiguous((1024,), torch.float)
528 def check_non_contiguous_index(dtype):
529 contig = torch.randn((2, 2, 1, 2), dtype=dtype)
530 non_contig = contig[:, 1, ...]
531 contig = non_contig.clone()
532 self.assertFalse(non_contig.is_contiguous())
533 self.assertEqual(torchfn(contig), torchfn(non_contig),
'non-contiguous index')
535 check_non_contiguous_index(torch.float)
536 check_non_contiguous_index(torch.double)
538 def check_non_contiguous_expand(shape, dtype):
539 contig = torch.randn(shape, dtype=dtype)
540 non_contig = contig.clone().expand(3, -1, -1)
541 self.assertFalse(non_contig.is_contiguous())
542 contig = torchfn(contig)
543 non_contig = torchfn(non_contig)
545 self.assertEqual(contig, non_contig[i],
'non-contiguous expand[' + str(i) +
']')
550 check_non_contiguous_expand((1, 3), torch.double)
551 check_non_contiguous_expand((1, 7), torch.double)
552 check_non_contiguous_expand((5, 7), torch.float)
556 def check_contiguous_size1(dtype):
557 contig = torch.randn((5, 100), dtype=dtype)
558 contig = contig[:1, :50]
559 contig2 = torch.empty(contig.size(), dtype=dtype)
560 contig2.copy_(contig)
561 self.assertTrue(contig.is_contiguous())
562 self.assertTrue(contig2.is_contiguous())
563 self.assertEqual(torchfn(contig), torchfn(contig2),
'contiguous size1')
565 check_contiguous_size1(torch.double)
566 check_contiguous_size1(torch.float)
568 def check_contiguous_size1_largedim(dtype):
569 contig = torch.randn((5, 2, 3, 1, 4, 5, 3, 2, 1, 2, 3, 4), dtype=dtype)
570 contig = contig[:1, :, :, :, :, :, :, :, :, :, :, :]
571 contig2 = torch.empty(contig.size(), dtype=dtype)
572 contig2.copy_(contig)
573 self.assertTrue(contig.is_contiguous())
574 self.assertTrue(contig2.is_contiguous())
575 self.assertEqual(torchfn(contig), torchfn(contig2),
'contiguous size1')
577 check_contiguous_size1_largedim(torch.double)
578 check_contiguous_size1_largedim(torch.float)
580 def check_large(dtype):
581 input = torch.randn(1024, 512, dtype=dtype)
582 actual = torchfn(input)
583 expected = torch.stack([torchfn(slice)
for slice
in input])
584 self.assertEqual(actual, expected,
'large')
588 check_large(torch.double)
589 check_large(torch.float)
591 def __test_math_by_name(self, function_name, mathfn, selffn):
592 mathfn = getattr(math, mathfn)
595 return getattr(x, function_name)()
597 torchfn = getattr(torch, function_name)
598 self._test_math(torchfn, mathfn, test_expand=(
not selffn))
600 def _test_math_by_name(self, function_name, test_self=True):
602 self.__test_math_by_name(function_name +
"_", function_name,
True)
603 self.__test_math_by_name(function_name, function_name,
False)
606 self._test_math_by_name(
'sin')
612 except OverflowError:
613 return inf
if x > 0
else -inf
614 self._test_math(torch.sinh, sinh)
616 def test_lgamma(self):
618 if x <= 0
and x == int(x):
620 return math.lgamma(x)
621 self._test_math(torch.lgamma, lgamma)
@unittest.skipIf(not TEST_SCIPY, "Scipy not found")
def test_mvlgamma(self):
    """Compare torch.mvlgamma with scipy.special.multigammaln for d = 1..4."""
    from scipy.special import multigammaln
    for dim_count in range(1, 5):
        # Draw inputs from [d, 10). Presumably this keeps them inside
        # mvlgamma's valid domain for this d.
        samples = torch.empty(10).uniform_(dim_count, 10)
        actual = torch.mvlgamma(samples, dim_count)
        expected = multigammaln(samples.numpy(), dim_count)
        self.assertEqual(actual.numpy(), expected)
632 def test_mvlgamma_argcheck(self):
634 input = torch.linspace((d - 2) / 2, 10, 10)
635 torch.mvlgamma(input, d)
637 with self.assertRaisesRegex(RuntimeError,
"Condition for computing multivariate log-gamma not met"):
640 def _digamma_input(self, test_poles=True):
642 input.append((torch.randn(10).abs() + 1e-4).tolist())
643 input.append((torch.randn(10).abs() + 1e6).tolist())
644 zeros = torch.linspace(-9.5, -0.5, 10)
645 input.append(zeros.tolist())
646 input.append((zeros - 0.49).tolist())
647 input.append((zeros + 0.49).tolist())
648 input.append((zeros + (torch.rand(10) * 0.99) - 0.5).tolist())
651 input.append([-0.999999994, -1.999999994, -2.0000000111,
652 -100.99999994, -1931.99999994, 0.000000111,
653 -0.000000111, 0, -2, -329])
656 @unittest.skipIf(
not TEST_SCIPY,
"Scipy not found")
657 def test_digamma(self):
658 from scipy.special
import digamma
661 def torch_digamma_without_inf(inp):
662 res = torch.digamma(inp)
663 res[(res == -inf) | (res == inf)] = nan
666 def scipy_digamma_without_inf(inp):
669 return res
if np.isfinite(res)
else nan
670 res[np.isinf(res)] = nan
673 self._test_math(torch_digamma_without_inf, scipy_digamma_without_inf, self._digamma_input())
675 @unittest.skipIf(
not TEST_SCIPY,
"Scipy not found")
676 def test_polygamma(self):
677 from scipy.special
import polygamma
679 self._test_math(
lambda x: torch.polygamma(n, x),
680 lambda x: polygamma(n, x).item(),
681 self._digamma_input(test_poles=
False))
684 self._test_math(torch.asin,
lambda x: math.asin(x)
if abs(x) <= 1
else nan)
687 self._test_math_by_name(
'cos')
693 except OverflowError:
697 self._test_math(torch.cosh, cosh)
700 self._test_math(torch.acos,
lambda x: math.acos(x)
if abs(x) <= 1
else nan)
703 self._test_math_by_name(
'tan')
706 self._test_math_by_name(
'tanh')
709 self._test_math_by_name(
'atan')
718 self._test_math(torch.log, log)
720 def test_log10(self):
727 self._test_math(torch.log10, log10)
729 def test_log1p(self):
736 self._test_math(torch.log1p, log1p)
746 except AttributeError:
747 return math.log(x, 2)
748 self._test_math(torch.log2, log2)
751 self._test_math(torch.sqrt,
lambda x: math.sqrt(x)
if x >= 0
else nan)
754 self._test_math_by_name(
'erf')
757 self._test_math_by_name(
'erfc')
def test_erfinv(self):
    """Check that erfinv inverts erf, maps -1/1 to -inf/inf and yields NaN at +/-2."""
    def check_type(tensor_type):
        # Random values clamped to [-2, 2], round-tripped through erf then erfinv.
        samples = torch.randn(4, 4, out=tensor_type()).clamp(-2., 2.)
        round_trip = tensor_type(samples).erf().erfinv()
        self.assertEqual(round_trip, tensor_type(samples))
        # Boundary points must map exactly to infinities.
        boundary = tensor_type([-1, 1]).erfinv()
        self.assertTrue(torch.equal(boundary, tensor_type([-inf, inf])))
        # Inputs outside [-1, 1] (here +/-2) are expected to give NaN.
        self.assertEqual(tensor_type([-2, 2]).erfinv(), tensor_type([nan, nan]))

    for tensor_type in (torch.FloatTensor, torch.DoubleTensor):
        check_type(tensor_type)
775 except OverflowError:
777 self._test_math(torch.exp, exp)
779 def test_expm1(self):
783 except OverflowError:
785 self._test_math(torch.expm1, expm1)
def test_floor(self):
    """Delegate to the shared name-based math test for ``floor``."""
    op_name = 'floor'
    self._test_math_by_name(op_name)
791 self._test_math_by_name(
'ceil')
794 def test_ceil_out_cpu_cuda(self):
796 b = torch.randn(1, device=
"cuda")
797 self.assertRaises(RuntimeError,
lambda: torch.ceil(a, out=b))
799 def test_rsqrt(self):
805 return 1.0 / math.sqrt(x)
807 self._test_math(torch.rsqrt, rsqrt)
def test_sigmoid(self):
    """Compare sigmoid against reference values rounded to four decimals."""
    inputs = [-1000, -1, 0, 0.5, 1, 2, 1000]
    # Reference values of 1 / (1 + exp(-x)), rounded to four decimal places.
    expected = [0.0000, 0.2689, 0.5, 0.6225, 0.7311, 0.8808, 1.000]
    # The tolerance absorbs the rounding of the reference table.
    tolerance = 0.0002

    for tensor_type in (torch.FloatTensor, torch.DoubleTensor):
        self.assertEqual(tensor_type(inputs).sigmoid(), tensor_type(expected), tolerance)
822 self._test_math(torch.frac,
lambda x: math.fmod(x, 1))
def test_trunc(self):
    """Check torch.trunc against removing the fractional part via math.fmod."""
    def reference(value):
        # fmod keeps the sign of ``value``, so subtracting it truncates toward zero.
        fractional = math.fmod(value, 1)
        return value - fractional

    self._test_math(torch.trunc, reference)
def test_round(self):
    """Check torch.round against Python's built-in ``round``."""
    reference = round
    self._test_math(torch.round, reference)
def test_has_storage(self):
    """Every tensor built here, including empty ones, must expose a storage()."""
    # Each tensor is created lazily and checked immediately, in the same order
    # as one-assertion-per-line code would do it.
    constructors = (
        lambda: torch.Tensor(),
        lambda: torch.Tensor(0),
        lambda: torch.Tensor([]),
        lambda: torch.Tensor().clone(),
        lambda: torch.Tensor([0, 0, 0]).nonzero(),
        lambda: torch.Tensor().new(),
    )
    for build in constructors:
        self.assertIsNotNone(build().storage())
838 @unittest.skipIf(
not TEST_NUMPY,
"Numpy not found")
839 def test_has_storage_numpy(self):
840 for dtype
in [np.float32, np.float64, np.int64,
841 np.int32, np.int16, np.uint8]:
842 arr = np.array([1], dtype=dtype)
843 self.assertIsNotNone(torch.FloatTensor(arr).
storage())
844 self.assertIsNotNone(torch.DoubleTensor(arr).
storage())
845 self.assertIsNotNone(torch.IntTensor(arr).
storage())
846 self.assertIsNotNone(torch.LongTensor(arr).
storage())
847 self.assertIsNotNone(torch.ByteTensor(arr).
storage())
849 self.assertIsNotNone(torch.cuda.FloatTensor(arr).
storage())
850 self.assertIsNotNone(torch.cuda.DoubleTensor(arr).
storage())
851 self.assertIsNotNone(torch.cuda.IntTensor(arr).
storage())
852 self.assertIsNotNone(torch.cuda.LongTensor(arr).
storage())
853 self.assertIsNotNone(torch.cuda.ByteTensor(arr).
storage())
855 def _testSelection(self, torchfn, mathfn):
857 m1 = torch.randn(100, 100)
860 for i, j
in iter_indices(m1):
861 res2 = mathfn(res2, m1[i, j])
862 self.assertEqual(res1, res2)
865 m1 = torch.randn(10, 10, 10)
869 for i, j
in iter_indices(m2):
870 res2 = mathfn(res2, m2[i][j])
871 self.assertEqual(res1, res2)
874 m1 = torch.randn(100, 100)
875 res1val, res1ind = torchfn(m1, 1,
False)
876 res2val = m1[:, 0:1].clone().squeeze()
877 res2ind = res1ind.clone().fill_(0)
878 for i, j
in iter_indices(m1):
879 if mathfn(res2val[i], m1[i, j]) != res2val[i]:
880 res2val[i] = m1[i, j]
884 for i
in range(res1val.size(0)):
885 maxerr = max(maxerr, abs(res1val[i] - res2val[i]))
886 self.assertEqual(res1ind[i], res2ind[i])
887 self.assertLessEqual(abs(maxerr), 1e-5)
890 for index
in (0, 4, 99):
891 m1 = torch.randn(100)
893 res1val, res1ind = torch.max(m1, 0)
894 self.assertTrue(math.isnan(res1val))
895 self.assertEqual(res1ind, index)
896 res1val = torchfn(m1)
897 self.assertTrue(math.isnan(res1val))
900 self._testSelection(torch.max, max)
903 def _test_max_with_inf(self, dtypes=(torch.float, torch.double), device=
'cpu'):
905 a =
torch.tensor([[-inf, -inf, inf, 3], [inf, inf, -inf, -1]], dtype=dtype, device=device)
906 self.assertTrue(torch.all(torch.max(a, dim=1)[0] == inf).item())
907 self.assertTrue(torch.max(a).item() == inf)
def test_max_with_inf(self):
    """Run the shared +/-inf ``max`` checks with the helper's default dtypes and device."""
    # The helper takes ``self`` as an explicit argument. It is presumably a
    # @staticmethod so that other suites can reuse it (decorator not visible here).
    run_check = self._test_max_with_inf
    run_check(self)
913 self._testSelection(torch.min, min)
916 def _test_min_with_inf(self, dtypes=(torch.float, torch.double), device=
'cpu'):
918 a =
torch.tensor([[-inf, -inf, inf, 3], [inf, inf, -inf, -1]], dtype=dtype, device=device)
919 self.assertTrue(torch.all(torch.min(a, dim=1)[0] == (-inf)).item())
920 self.assertTrue(torch.min(a).item() == -inf)
def test_min_with_inf(self):
    """Run the shared +/-inf ``min`` checks with the helper's default dtypes and device."""
    # The helper takes ``self`` as an explicit argument. It is presumably a
    # @staticmethod so that other suites can reuse it (decorator not visible here).
    run_check = self._test_min_with_inf
    run_check(self)
926 def _test_norm(self, device):
928 x = torch.randn(25, device=device)
930 for p
in [0, 1, 2, 3, 4, inf, -inf]:
931 res = x.norm(p).item()
932 expected = np.linalg.norm(xn, p)
933 self.assertEqual(res, expected,
"full reduction failed for {}-norm".format(p))
936 x = torch.randn(25, 25, device=device)
938 for p
in [0, 1, 2, 3, 4, inf, -inf]:
939 res = x.norm(p, 1).cpu().numpy()
940 expected = np.linalg.norm(xn, p, 1)
941 self.assertEqual(res.shape, expected.shape)
942 self.assertTrue(np.allclose(res, expected),
"dim reduction failed for {}-norm".format(p))
945 for p
in [
'fro',
'nuc']:
946 res = x.norm(p).cpu().numpy()
947 expected = np.linalg.norm(xn, p)
948 self.assertEqual(res.shape, expected.shape)
949 self.assertTrue(np.allclose(res, expected),
"dim reduction failed for {}-norm".format(p))
952 self.assertEqual(2 * torch.norm(torch.ones(10000)), torch.norm(torch.ones(40000)))
954 @unittest.skipIf(
not TEST_NUMPY,
"Numpy not found")
957 self._test_norm(self, device=
'cpu')
960 def _test_dist(self, device):
962 for p
in [0, 1, 2, 3, 4, inf, -inf]:
963 dist_xy = torch.dist(x, y, p)
964 dist_xy_norm = torch.norm(x - y, p)
965 self.assertEqual(dist_xy, dist_xy_norm)
967 run_test(torch.randn(5, device=device), torch.randn(5, device=device))
969 x = torch.zeros(3, device=device)
970 y = torch.zeros(3, device=device)
975 self._test_dist(self, device=
'cpu')
977 def test_dim_reduction_uint8_overflow(self):
978 example = [[-1, 2, 1], [5, 3, 6]]
980 self.assertEqual(x.sum(dtype=torch.uint8).item(), 16)
981 self.assertEqual(x.sum(0, dtype=torch.uint8), torch.FloatTensor([4, 5, 7]))
982 self.assertEqual(x.sum(1, dtype=torch.uint8), torch.FloatTensor([2, 14]))
984 torch.sum(x, 0, out=y)
985 self.assertEqual(x.sum(0, dtype=torch.uint8), y)
988 def _test_dim_reduction(self, cast):
989 example = [[-1, 2, 1], [5, 3, 6]]
991 types = [torch.double,
1002 self.assertEqual(x.sum().item(), 16)
1003 self.assertEqual(x.sum(0), torch.FloatTensor([4, 5, 7]))
1004 self.assertEqual(x.sum(1), torch.FloatTensor([2, 14]))
1006 torch.sum(x, 0, out=y)
1007 self.assertEqual(x.sum(0), y)
1010 for dtype
in types[:2]:
1012 self.assertEqual(x.mean().item(), 16.0 / 6)
1013 self.assertEqual(x.mean(0), torch.FloatTensor([2.0, 2.5, 7.0 / 2]))
1014 self.assertEqual(x.mean(1), torch.FloatTensor([2.0 / 3, 14.0 / 3]))
1015 self.assertEqual(x.mean(), x.mean((0, 1)))
1019 self.assertEqual(x.prod().item(), -180)
1020 self.assertEqual(x.prod(0), torch.FloatTensor([-5, 6, 6]))
1021 self.assertEqual(x.prod(1), torch.FloatTensor([-2, 90]))
1025 self.assertEqual(x.max().item(), 6)
1026 self.assertEqual(x.max(0), (torch.FloatTensor([5, 3, 6]), torch.FloatTensor([1, 1, 1])))
1027 self.assertEqual(x.max(1), (torch.FloatTensor([2, 6]), torch.FloatTensor([1, 2])))
1031 self.assertEqual(x.min().item(), -1)
1032 self.assertEqual(x.min(0), (torch.FloatTensor([-1, 2, 1]), torch.FloatTensor([0, 0, 0])))
1033 self.assertEqual(x.min(1), (torch.FloatTensor([-1, 3]), torch.FloatTensor([0, 1])))
1037 self.assertEqual(x.argmax().item(), 5)
1038 self.assertEqual(x.argmax(dim=
None).item(), 5)
1039 self.assertEqual(x.argmax(dim=0), torch.FloatTensor([1, 1, 1]))
1040 self.assertEqual(x.argmax(dim=1), torch.FloatTensor([1, 2]))
1041 self.assertEqual(x.argmax(dim=0, keepdim=
True), torch.FloatTensor([[1, 1, 1]]))
1043 self.assertEqual(x[:, :2].argmax().item(), 2)
1047 self.assertEqual(x.argmin().item(), 0)
1048 self.assertEqual(x.argmin(dim=
None).item(), 0)
1049 self.assertEqual(x.argmin(dim=0), torch.FloatTensor([0, 0, 0]))
1050 self.assertEqual(x.argmin(dim=1), torch.FloatTensor([0, 1]))
1051 self.assertEqual(x.argmin(dim=1, keepdim=
True), torch.FloatTensor([[0], [1]]))
1053 self.assertEqual(x[:, :2].argmin().item(), 0)
1056 "mean",
"median",
"mode",
"norm",
"prod",
1057 "std",
"sum",
"var",
"max",
"min"]
1059 def normfn_attr(t, dim, keepdim=False, out=None):
1060 attr = getattr(torch,
"norm")
1061 return attr(t, 2, dim, keepdim, out=out)
1063 for fn_name
in dim_red_fns:
1064 fn_attr = getattr(torch, fn_name)
if fn_name !=
"norm" else normfn_attr
1066 def fn(x, dim, keepdim=False, out=None):
1067 ans = fn_attr(x, dim, keepdim=keepdim, out=out)
1068 return ans
if not istuple(ans)
else ans[0]
1070 def fn_tuple(x, dim, keepdim=False, out=None):
1071 return fn_attr(x, dim, keepdim=keepdim, out=out)
1073 def test_multidim(x, dim):
1074 self.assertEqual(fn(x, dim).unsqueeze(dim), fn(x, dim, keepdim=
True))
1075 self.assertEqual(x.ndimension() - 1, fn(x, dim).ndimension())
1076 self.assertEqual(x.ndimension(), fn(x, dim, keepdim=
True).ndimension())
1079 x = cast(torch.randn(3, 4, 5))
1080 dim = random.randint(0, 2)
1081 test_multidim(x, dim)
1084 x = cast(torch.randn(1))
1086 self.assertEqual(fn(x, dim).shape, ())
1087 self.assertEqual(fn(x, dim, keepdim=
True).shape, (1,))
1091 singleton_dim = random.randint(0, 2)
1092 dims[singleton_dim] = 1
1093 x = cast(torch.randn(dims))
1094 test_multidim(x, singleton_dim)
1097 if fn_name
in [
'median',
'mode',
'max',
'min']:
1098 y = cast(torch.randn(5, 3))
1099 values = cast(torch.randn(5, 3))
1100 indices = cast(torch.zeros(5, 3).long() - 1)
1101 fn_tuple(y, 1, keepdim=
False, out=(values[:, 1], indices[:, 1]))
1102 values_expected, indices_expected = fn_tuple(y, 1, keepdim=
False)
1103 self.assertEqual(values[:, 1], values_expected,
1104 '{} values with out= kwarg'.format(fn_name))
1105 self.assertEqual(indices[:, 1], indices_expected,
1106 '{} indices with out= kwarg'.format(fn_name))
1109 x = cast(torch.randn(5, 3))
1110 y = cast(torch.randn(5, 3))
1111 fn(y, 1, keepdim=
False, out=x[:, 1])
1112 expected = fn(y, 1, keepdim=
False)
1113 self.assertEqual(x[:, 1], expected,
'{} with out= kwarg'.format(fn_name))
def test_dim_reduction(self):
    """Run the shared dim-reduction checks with an identity cast (tensors stay as created)."""
    def identity(tensor):
        return tensor

    # ``self`` is passed explicitly; the helper presumably is a @staticmethod.
    self._test_dim_reduction(self, identity)
1118 def test_reduction_empty(self):
1121 (
'max', torch.max,
None),
1122 (
'kthvalue',
lambda *args, **kwargs: torch.kthvalue(*args, k=1, **kwargs),
None),
1123 (
'argmax', torch.argmax,
None),
1124 (
'min', torch.min,
None),
1125 (
'argmin', torch.argmin,
None),
1126 (
'mode', torch.mode,
None),
1127 (
'median', torch.median,
None),
1129 (
'prod', torch.prod, 1),
1130 (
'sum', torch.sum, 0),
1131 (
'norm', torch.norm, 0),
1132 (
'mean', torch.mean, nan),
1133 (
'var', torch.var, nan),
1134 (
'std', torch.std, nan),
1135 (
'logsumexp', torch.logsumexp, -inf),
1140 for device
in devices:
1141 x = torch.randn(shape, device=device)
1143 for item
in fns_to_test:
1144 name, fn, identity = item
1145 if identity
is None:
1146 ident_err =
'does not have an identity' 1147 self.assertRaisesRegex(RuntimeError, ident_err,
lambda: fn(x, dim=2))
1148 self.assertRaisesRegex(RuntimeError, ident_err,
lambda: fn(x, dim=2, keepdim=
True))
1149 self.assertRaisesRegex(RuntimeError, ident_err,
lambda: fn(x, dim=1))
1150 self.assertRaisesRegex(RuntimeError, ident_err,
lambda: fn(x, dim=1, keepdim=
True))
1152 self.assertEqual(torch.empty((2, 0), device=device), fn(x, dim=2))
1153 self.assertEqual(torch.empty((2, 0, 1), device=device), fn(x, dim=2, keepdim=
True))
1155 check = (torch.testing.assert_allclose
if math.isnan(identity)
or math.isinf(identity)
else 1157 check(torch.full((2, 4), identity, device=device), fn(x, dim=1))
1158 check(torch.full((2, 1, 4), identity, device=device), fn(x, dim=1, keepdim=
True))
1160 check(torch.full((), identity, device=device), fn(x))
1161 except TypeError
as err:
1163 self.assertTrue(
'required positional arguments: "dim"' in str(err))
1166 xb = x.to(torch.uint8)
1167 yb = x.to(torch.uint8)
1168 self.assertEqual((2, 0), xb.any(2).shape)
1169 self.assertEqual((2, 0, 1), xb.any(2, keepdim=
True).shape)
1170 self.assertEqual(torch.zeros((2, 4), device=device), xb.any(1))
1171 self.assertEqual(torch.zeros((2, 1, 4), device=device), xb.any(1, keepdim=
True))
1172 self.assertEqual(torch.zeros((), device=device), xb.any())
1175 self.assertEqual((2, 0), xb.all(2).shape)
1176 self.assertEqual((2, 0, 1), xb.all(2, keepdim=
True).shape)
1177 self.assertEqual(torch.ones((2, 4), device=device), xb.all(1))
1178 self.assertEqual(torch.ones((2, 1, 4), device=device), xb.all(1, keepdim=
True))
1179 self.assertEqual(torch.ones((), device=device), xb.all())
1181 def test_pairwise_distance_empty(self):
1183 for device
in devices:
1185 x = torch.randn(shape, device=device)
1186 y = torch.randn(shape, device=device)
1188 self.assertEqual(torch.zeros(2, device=device), torch.pairwise_distance(x, y))
1189 self.assertEqual(torch.zeros((2, 1), device=device), torch.pairwise_distance(x, y, keepdim=
True))
1192 x = torch.randn(shape, device=device)
1193 y = torch.randn(shape, device=device)
1194 self.assertEqual(torch.zeros(0, device=device), torch.pairwise_distance(x, y))
1195 self.assertEqual(torch.zeros((0, 1), device=device), torch.pairwise_distance(x, y, keepdim=
True))
1197 def test_pdist_empty(self):
1199 for device
in devices:
1201 x = torch.randn(shape, device=device)
1202 self.assertEqual(torch.empty(0, device=device), torch.pdist(x))
1205 x = torch.randn(shape, device=device)
1206 self.assertEqual(torch.empty(0, device=device), torch.pdist(x))
1209 x = torch.randn(shape, device=device)
1210 self.assertEqual(torch.zeros(3, device=device), torch.pdist(x))
1212 def test_pdist_norm(self):
1213 def test_pdist_single(shape, device, p, dtype, trans):
1214 x = torch.randn(shape, dtype=dtype, device=device)
1216 x.transpose_(-2, -1)
1217 actual = torch.pdist(x, p=p)
1218 expected = brute_pdist(x, p=p)
1219 self.assertEqual(expected.shape, actual.shape)
1220 self.assertTrue(torch.allclose(expected, actual))
1223 for device
in devices:
1224 for shape
in [(4, 5), (3, 2), (2, 1)]:
1225 for p
in [0, 1, 2, 3, 1.5, 2.5, float(
'inf')]:
1226 for trans
in [
False,
True]:
1227 for dtype
in [torch.float32, torch.float64]:
1228 test_pdist_single(shape, device, p, dtype, trans)
1232 for dtype
in [torch.float32, torch.float64]:
1233 test_pdist_single((1000, 2), device, 2, dtype,
False)
1235 def test_cdist_empty(self):
1237 for device
in devices:
1238 x = torch.randn((0, 5), device=device)
1239 y = torch.randn((4, 5), device=device)
1240 self.assertEqual(torch.empty(0, 4, device=device), torch.cdist(x, y))
1242 x = torch.randn((2, 5), device=device)
1243 y = torch.randn((0, 5), device=device)
1244 self.assertEqual(torch.empty(2, 0, device=device), torch.cdist(x, y))
1246 x = torch.randn((2, 0), device=device)
1247 y = torch.randn((3, 0), device=device)
1248 self.assertEqual(torch.zeros(2, 3, device=device), torch.cdist(x, y))
1250 x = torch.randn((2, 0), device=device)
1251 y = torch.randn((0, 0), device=device)
1252 self.assertEqual(torch.empty(2, 0, device=device), torch.cdist(x, y))
1254 def test_cdist_norm(self):
1256 for device
in devices:
1257 for r1
in [3, 4, 5, 6]:
1258 for m
in [2, 3, 4, 10]:
1259 for r2
in [4, 6, 7, 8]:
1260 for p
in [0, 1, 2, 3, 1.5, 2.5, float(
'inf')]:
1261 x = torch.randn(r1, m, device=device)
1262 y = torch.randn(r2, m, device=device)
1263 actual = torch.cdist(x, y, p=p)
1264 expected = brute_cdist(x, y, p=p)
1265 self.assertTrue(torch.allclose(expected, actual))
1267 def test_cdist_large(self):
1269 for device
in devices:
1270 x = torch.randn(1000, 10, device=device)
1271 y = torch.randn(1000, 10, device=device)
1272 actual = torch.cdist(x, y, p=2)
1273 expected = brute_cdist(x, y, p=2)
1274 self.assertTrue(torch.allclose(expected, actual))
1276 def test_cdist_non_contiguous(self):
1278 for device
in devices:
1279 x = torch.randn(5, 7, device=device).t()
1280 y = torch.randn(5, 3, device=device).t()
1281 actual = torch.cdist(x, y, p=2)
1282 expected = brute_cdist(x, y, p=2)
1283 self.assertFalse(x.is_contiguous())
1284 self.assertFalse(y.is_contiguous())
1285 self.assertTrue(torch.allclose(expected, actual))
1287 x = torch.randn(7, 5, device=device)
1288 y = torch.randn(5, 3, device=device).t()
1289 actual = torch.cdist(x, y, p=2)
1290 expected = brute_cdist(x, y, p=2)
1291 self.assertTrue(x.is_contiguous())
1292 self.assertFalse(y.is_contiguous())
1293 self.assertTrue(torch.allclose(expected, actual))
1295 x = torch.randn(5, 7, device=device).t()
1296 y = torch.randn(3, 5, device=device)
1297 actual = torch.cdist(x, y, p=2)
1298 expected = brute_cdist(x, y, p=2)
1299 self.assertFalse(x.is_contiguous())
1300 self.assertTrue(y.is_contiguous())
1301 self.assertTrue(torch.allclose(expected, actual))
1303 @unittest.skipIf(
not TEST_SCIPY,
"Scipy not found")
1304 def test_logsumexp(self):
1305 from scipy.special
import logsumexp
1306 a = torch.randn(5, 4)
1309 actual = a.logsumexp(1)
1310 expected = logsumexp(a.numpy(), 1)
1311 self.assertEqual(expected.shape, actual.shape)
1312 self.assertTrue(np.allclose(expected, actual.numpy()))
1314 b = torch.zeros(5, 2)
1316 torch.logsumexp(a, 1, out=c)
1317 self.assertTrue(np.allclose(expected, b[:, 0].numpy()))
1319 @unittest.skipIf(
not TEST_NUMPY,
"Numpy not found")
1320 def test_cpu_parallel(self):
1326 def _run_test(size):
1327 for dim
in range(len(size) + 1):
1328 nv = np.round(np.random.rand(*size))
1329 tv = torch.from_numpy(nv)
1332 self.assertTrue(tv.numel() > 32768)
1333 if dim == len(size):
1339 diff = np.abs(nvs - tvs.numpy()).sum()
1340 self.assertEqual(diff, 0)
1342 _run_test([2, 3, 3, 3, 3, 2, 2, 3, 2, 3, 2, 3, 3])
1343 _run_test([4, 4, 4, 4, 4, 4, 4, 4, 4, 4])
1344 _run_test([1, 32 * 8 * 32 * 8])
1345 _run_test([1, 32770])
1347 def _testCSelection(self, torchfn, mathfn):
1350 a = torch.rand(*size)
1351 b = torch.rand(*size)
1353 expected_c = torch.zeros(*size)
1354 expected_c.map2_(a, b,
lambda _, a, b: mathfn(a, b))
1355 self.assertEqual(expected_c, c, 0)
1357 def test_max_elementwise(self):
1358 self._testCSelection(torch.max, max)
1360 def test_min_elementwise(self):
1361 self._testCSelection(torch.min, min)
1364 def _test_lerp(self, cast):
1365 start_end_shapes = [(), (5,), (5, 5), (5, 5, 5)]
1366 for shapes
in product(start_end_shapes, start_end_shapes):
1367 start = cast(torch.randn(shapes[0]))
1368 end = cast(torch.randn(shapes[1]))
1371 for weight
in [cast(torch.randn(shapes[0])), random.random()]:
1372 actual = torch.lerp(start, end, weight)
1373 actual_method = start.lerp(end, weight)
1374 self.assertEqual(actual, actual_method)
1375 actual_out = cast(torch.Tensor())
1376 torch.lerp(start, end, weight, out=actual_out)
1377 self.assertEqual(actual, actual_out)
1378 expected = start + weight * (end - start)
1379 self.assertEqual(expected, actual)
1381 def test_lerp(self):
1382 self._test_lerp(self,
lambda t: t)
1384 def test_all_any(self):
1386 x = torch.ones(*size).byte()
1387 self.assertTrue(x.all())
1388 self.assertTrue(x.any())
1391 self.assertFalse(x.all())
1392 self.assertTrue(x.any())
1395 self.assertFalse(x.all())
1396 self.assertFalse(x.any())
1399 self.assertTrue(x.all())
1400 self.assertTrue(x.any())
1405 def test_all_any_empty(self):
1406 x = torch.ByteTensor()
1407 self.assertTrue(x.all())
1408 self.assertFalse(x.any())
1410 def test_all_any_with_dim(self):
1412 r1 = x.prod(dim=0, keepdim=
False).byte()
1413 r2 = x.all(dim=0, keepdim=
False)
1414 self.assertEqual(r1.shape, r2.shape)
1415 self.assertTrue((r1 == r2).all())
1417 r3 = x.sum(dim=1, keepdim=
True).clamp(0, 1).byte()
1418 r4 = x.any(dim=1, keepdim=
True)
1419 self.assertEqual(r3.shape, r4.shape)
1420 self.assertTrue((r3 == r4).all())
1422 test(torch.ByteTensor([[0, 0, 0],
1428 def test_all_any_empty_cuda(self):
1429 x = torch.cuda.ByteTensor()
1430 self.assertTrue(x.all())
1431 self.assertFalse(x.any())
1434 m1 = torch.randn(100, 100)
1435 v1 = torch.randn(100)
1437 res1 = torch.mv(m1, v1)
1438 res2 = res1.clone().zero_()
1439 for i, j
in iter_indices(m1):
1440 res2[i] += m1[i][j] * v1[j]
1442 self.assertEqual(res1, res2)
1446 m1 = torch.randn(100, 100)
1447 v1 = torch.randn(100)
1450 res1 = torch.add(m1[4], v1)
1451 res2 = res1.clone().zero_()
1452 for i
in range(m1.size(1)):
1453 res2[i] = m1[4, i] + v1[i]
1454 self.assertEqual(res1, res2)
1456 m1 = torch.randn(100, 100)
1457 v1 = torch.randn(100)
1460 res1 = torch.add(m1[:, 4], v1)
1461 res2 = res1.clone().zero_()
1462 for i
in range(m1.size(0)):
1463 res2[i] = m1[i, 4] + v1[i]
1464 self.assertEqual(res1, res2)
1467 m1 = torch.randn(10, 10)
1473 for i
in range(m1.size(1)):
1474 res2[3, i] = res2[3, i] + 2
1475 self.assertEqual(res1, res2)
1478 m1 = torch.randn(10, 10)
1482 for i
in range(m1.size(0)):
1483 res2[i, 3] = res2[i, 3] + 2
1484 self.assertEqual(res1, res2)
1487 m1 = torch.randn(10, 10)
1491 self.assertEqual(torch.add(one, 1), 2)
1492 self.assertEqual(torch.add(one, 1).dtype, torch.uint8)
1495 m1 = torch.randn(10, 10)
1496 m2 = torch.randn(10, 10).t()
1498 self.assertTrue(res.is_contiguous())
1499 self.assertEqual(res, m1 + m2.contiguous())
1504 self.assertEqual(m1 + m2, [])
1508 def test_csub(self):
1510 a = torch.randn(100, 90)
1511 b = a.clone().normal_()
1513 res_add = torch.add(a, -1, b)
1514 res_csub = a.clone()
1516 self.assertEqual(res_add, res_csub)
1519 a = torch.randn(100, 100)
1522 res_add = torch.add(a, -scalar)
1523 res_csub = a.clone()
1524 res_csub.sub_(scalar)
1525 self.assertEqual(res_add, res_csub)
1528 def _test_neg(self, cast):
1529 float_types = [torch.DoubleTensor, torch.FloatTensor, torch.LongTensor]
1530 int_types = [torch.IntTensor, torch.ShortTensor, torch.ByteTensor,
1533 for t
in float_types + int_types:
1534 if t
in float_types:
1535 a = cast(torch.randn(100, 90).type(t))
1537 a = cast(torch.randint(-128, 128, (100, 90), dtype=t.dtype))
1538 zeros = cast(torch.Tensor().type(t)).resize_as_(a).zero_()
1540 if t == torch.ByteTensor:
1541 res_add = torch.add(zeros, a, alpha=255)
1543 res_add = torch.add(zeros, a, alpha=-1)
1546 self.assertEqual(res_neg, res_add)
1549 res_neg_out_place = a.clone().neg()
1550 self.assertEqual(res_neg_out_place, res_add)
1553 res_neg_op = -a.clone()
1554 self.assertEqual(res_neg_op, res_add)
1557 self._test_neg(self,
lambda t: t)
1559 def test_threshold(self):
1561 if dtype != torch.uint8
and dtype != torch.float16:
1563 x = torch.randn(100).sign().to(dtype=dtype)
1564 y = torch.threshold(x, 0, 0)
1565 self.assertTrue(y.le(0).any())
1567 def test_reciprocal(self):
1568 a = torch.randn(100, 89)
1570 res_reciprocal = a.clone()
1571 res_reciprocal.reciprocal_()
1572 self.assertEqual(res_reciprocal, res_div)
1575 m1 = torch.randn(10, 10)
1579 for i
in range(res1.size(0)):
1580 res2[i, 3] = res2[i, 3] * 2
1581 self.assertEqual(res1, res2)
1584 m1 = torch.randn(10, 10)
1588 for i
in range(m1.size(0)):
1589 res2[i, 3] = res2[i, 3] / 2
1590 self.assertEqual(res1, res2)
1592 def test_floordiv(self):
1594 if dtype
is torch.float16:
1596 x = torch.randn(100).mul(10).to(dtype)
1598 self.assertEqual(y.dtype, x.dtype)
1599 z =
torch.tensor([math.trunc(v.item() / 3.)
for v
in x], dtype=y.dtype)
1600 self.assertEqual(y, z)
1602 def test_rdiv(self):
1604 if dtype
is torch.float16:
1606 x = torch.rand(100).add(1).mul(4).to(dtype)
1608 if dtype.is_floating_point:
1609 z =
torch.tensor([30 / v.item()
for v
in x], dtype=dtype)
1611 z =
torch.tensor([math.trunc(30. / v.item())
for v
in x], dtype=dtype)
1612 self.assertEqual(y, z)
1614 def test_fmod(self):
1615 m1 = torch.Tensor(10, 10).uniform_(-10., 10.)
1620 for i
in range(m1.size(1)):
1621 res2[i, 3] = math.fmod(res2[i, 3], q)
1622 self.assertEqual(res1, res2)
1624 def test_remainder(self):
1626 for use_item
in [
True,
False]:
1627 m1 = torch.Tensor(10, 10).uniform_(-10., 10.)
1630 qs = torch.arange(-5.1, 4.1)
1632 for col_idx, q
in enumerate(qs):
1634 for i
in range(m1.size(0)):
1635 res2[i, col_idx] = res2[i, col_idx] % q
1637 res1[:, col_idx].remainder_(q
if not use_item
else q.item())
1638 self.assertEqual(res1, res2)
1641 res1.remainder_(qs.unsqueeze(0).expand_as(res1))
1642 self.assertEqual(res1, res2)
1645 for use_item
in [
True,
False]:
1646 long_m1 = torch.LongTensor(10, 10).random_(-10, 10)
1647 long_res1 = long_m1.clone()
1648 long_res2 = long_m1.clone()
1649 long_qs = torch.arange(-5, 5)
1651 for col_idx, long_q
in enumerate(long_qs):
1653 for i
in range(long_m1.size(0)):
1654 long_res2[i, col_idx] = long_res2[i, col_idx] % long_q
1656 long_res1[:, col_idx].remainder_(long_q
if not use_item
else long_q.item())
1657 self.assertEqual(long_res1, long_res2)
1659 long_res1 = long_m1.clone()
1660 long_res1.remainder_(long_qs.unsqueeze(0).expand_as(long_res1))
1663 def _test_remainder_overflow(self, dtype, device):
1667 self.assertEqual(x % q, x)
1668 self.assertEqual(-x % q, q - x)
1669 self.assertEqual(x % -q, x - q)
1670 self.assertEqual(-x % -q, -x)
1672 def test_remainder_overflow(self):
1673 self._test_remainder_overflow(self, dtype=torch.int64, device=
'cpu')
1676 def _test_mm(n, m, p, dtype, genf):
1678 def matrixmultiply(mat1, mat2):
1682 res = torch.zeros(n, p, dtype=dtype)
1683 for i, j
in iter_indices(res):
1684 res[i, j] = sum(mat1[i, k] * mat2[k, j]
for k
in range(m))
1690 res = torch.mm(mat1, mat2)
1692 res2 = matrixmultiply(mat1, mat2)
1693 self.assertEqual(res, res2)
1697 mat2 = genf(p, m).t()
1698 res = torch.mm(mat1, mat2)
1700 res2 = matrixmultiply(mat1, mat2)
1701 self.assertEqual(res, res2)
1704 mat1 = genf(m, n).t()
1706 res = torch.mm(mat1, mat2)
1708 res2 = matrixmultiply(mat1, mat2)
1709 self.assertEqual(res, res2)
1712 mat1 = genf(m, n).t()
1713 mat2 = genf(p, m).t()
1714 res = torch.mm(mat1, mat2)
1716 res2 = matrixmultiply(mat1, mat2)
1717 self.assertEqual(res, res2)
1721 mat2 = genf(m, 1).expand(m, p)
1722 res = torch.mm(mat1, mat2)
1724 res2 = matrixmultiply(mat1, mat2)
1725 self.assertEqual(res, res2)
1732 torch.mm(mat1, mat2, out=res)
1734 res2 = matrixmultiply(mat1, mat2)
1735 self.assertEqual(res, res2)
1739 mat1 = genf(m, n).t()
1740 mat2 = genf(p, m).t()
1742 torch.mm(mat1, mat2, out=res)
1744 res2 = matrixmultiply(mat1, mat2)
1745 self.assertEqual(res, res2)
1747 for (n, m, p)
in [(20, 10, 5), (15, 5, 10), (5, 18, 10)]:
1748 _test_mm(n, m, p, torch.float32,
lambda x, y: torch.randn(x, y, dtype=torch.float32))
1749 _test_mm(n, m, p, torch.float64,
lambda x, y: torch.randn(x, y, dtype=torch.float64))
1750 _test_mm(n, m, p, torch.int32,
lambda x, y: torch.randint(0, 100, (x, y), dtype=torch.int32))
1751 _test_mm(n, m, p, torch.int64,
lambda x, y: torch.randint(0, 100, (x, y), dtype=torch.int64))
1754 def _test_btrifact(self, cast):
1755 from common_utils
import random_fullrank_matrix_distinct_singular_value
as fullrank
1757 def run_test(matrix_size, batches, cast):
1758 a = cast(fullrank(matrix_size, *batches))
1759 a_LU_info, pivots_info, info_ = a.btrifact_with_info()
1760 self.assertEqual(a_LU_info.size(), torch.Size(batches + (matrix_size, matrix_size)))
1761 self.assertEqual(pivots_info.size(), torch.Size(batches + (matrix_size,)))
1762 self.assertEqual(info_.size(), torch.Size(batches))
1763 self.assertEqual(info_.abs().sum(), 0)
1764 a_LU, pivots = a.btrifact()
1765 self.assertEqual(a_LU, a_LU_info)
1766 self.assertEqual(pivots_info, pivots)
1768 a_LU_info_nopiv, nopiv, info_nopiv = a.btrifact_with_info(pivot=
False)
1769 self.assertIsNone(nopiv)
1770 self.assertEqual(info_, info_nopiv)
1771 P, L, U = torch.btriunpack(a_LU, pivots)
1772 self.assertEqual(P.matmul(L.matmul(U)), a)
1774 for ms, batch
in product([3, 5, 7], [(2,), (3,), (3, 5)]):
1775 run_test(ms, batch, cast)
1778 a = cast(fullrank(3, 5))
1779 if not (a.is_cuda
and any(x
in torch.version.cuda
for x
in [
'8.0',
'9.2'])):
1780 a[0, 1] = 2 * a[0, 0]
1781 self.assertGreater(a.btrifact_with_info()[2][0], 0)
1784 with self.assertRaisesRegex(RuntimeError,
1785 'btrifact without pivoting is not implemented on the CPU'):
1786 torch.btrifact(torch.empty(1, 2, 2), pivot=
False)
1790 def test_btrifact(self):
1791 self._test_btrifact(self,
lambda t: t)
1794 def _test_btrisolve(self, cast):
1795 a = torch.FloatTensor((((1.3722, -0.9020),
1801 b = torch.FloatTensor(((4.02, 6.19),
1804 a, b = cast(a), cast(b)
1805 LU_data, pivots, info = a.btrifact_with_info()
1806 self.assertEqual(info.abs().sum(), 0)
1807 x = torch.btrisolve(b, LU_data, pivots)
1808 b_ = torch.bmm(a, x.unsqueeze(2)).squeeze()
1809 self.assertEqual(b_, b)
1812 def test_btrisolve(self):
1813 self._test_btrisolve(self,
lambda t: t)
1816 def _test_btriunpack(self, cast):
1817 def run_test(shape, cast):
1818 a = cast(torch.randn(*shape))
1819 a_lu, p = torch.btrifact(a.reshape(-1, shape[-1], shape[-1]))
1820 a_lu = a_lu.reshape_as(a)
1821 p = p.reshape(a.shape[:-1])
1822 p_ref, l_ref, u_ref = torch.btriunpack(a_lu, p)
1823 self.assertEqual(p_ref.matmul(l_ref.matmul(u_ref)), a)
1825 run_test((5, 3, 3), cast)
1826 run_test((7, 3, 5, 5), cast)
1827 run_test((7, 5, 3, 3, 3), cast)
1830 def test_btriunpack(self):
1831 self._test_btriunpack(self,
lambda t: t)
1836 b1 = torch.randn(num_batches, M, N)
1837 b2 = torch.randn(num_batches, N, O)
1838 res = torch.bmm(b1, b2)
1839 for i
in range(num_batches):
1840 r = torch.mm(b1[i], b2[i])
1841 self.assertEqual(r, res[i])
1844 self.assertRaises(RuntimeError,
lambda: torch.bmm(b1, b2.cuda()))
1845 self.assertRaises(RuntimeError,
lambda: torch.bmm(b1.cuda(), b2))
1847 def test_addbmm(self):
1852 b1 = torch.randn(num_batches, M, N)
1853 b2 = torch.randn(num_batches, N, O)
1854 res = torch.bmm(b1, b2)
1855 res2 = torch.Tensor().resize_as_(res[0]).zero_()
1857 res2.addbmm_(b1, b2)
1858 self.assertEqual(res2, res.sum(0,
False))
1860 res2.addbmm_(1, b1, b2)
1861 self.assertEqual(res2, res.sum(0,
False) * 2)
1863 res2.addbmm_(1., .5, b1, b2)
1864 self.assertEqual(res2, res.sum(0,
False) * 2.5)
1866 res3 = torch.addbmm(1, res2, 0, b1, b2)
1867 self.assertEqual(res3, res2)
1869 res4 = torch.addbmm(1, res2, .5, b1, b2)
1870 self.assertEqual(res4, res.sum(0,
False) * 3)
1872 res5 = torch.addbmm(0, res2, 1, b1, b2)
1873 self.assertEqual(res5, res.sum(0,
False))
1875 res6 = torch.addbmm(.1, res2, .5, b1, b2)
1876 self.assertEqual(res6, res2 * .1 + (res.sum(0) * .5))
1878 def test_baddbmm(self):
1881 b1 = torch.randn(num_batches, M, N)
1882 b2 = torch.randn(num_batches, N, O)
1883 res = torch.bmm(b1, b2)
1884 res2 = torch.Tensor().resize_as_(res).zero_()
1886 res2.baddbmm_(b1, b2)
1887 self.assertEqual(res2, res)
1889 res2.baddbmm_(1, b1, b2)
1890 self.assertEqual(res2, res * 2)
1892 res2.baddbmm_(1, .5, b1, b2)
1893 self.assertEqual(res2, res * 2.5)
1895 res3 = torch.baddbmm(1, res2, 0, b1, b2)
1896 self.assertEqual(res3, res2)
1898 res4 = torch.baddbmm(1, res2, .5, b1, b2)
1899 self.assertEqual(res4, res * 3)
1901 res5 = torch.baddbmm(0, res2, 1, b1, b2)
1902 self.assertEqual(res5, res)
1904 res6 = torch.baddbmm(.1, res2, .5, b1, b2)
1905 self.assertEqual(res6, res2 * .1 + res * .5)
1908 def _test_clamp(self, device='cpu'):
1909 m1 = torch.rand(100, device=device).mul(5).add(-2.5)
1917 res1.clamp_(min_val, max_val)
1919 for i
in iter_indices(res2):
1920 res2[i] = max(min_val, min(max_val, res2[i]))
1921 self.assertEqual(res1, res2)
1924 torch.clamp(m1, min=min_val, max=max_val, out=out)
1925 self.assertEqual(out, res1)
1927 res1 = torch.clamp(m1, min=min_val)
1929 for i
in iter_indices(res2):
1930 res2[i] = max(min_val, res2[i])
1931 self.assertEqual(res1, res2)
1933 torch.clamp(m1, min=min_val, out=out)
1934 self.assertEqual(out, res1)
1936 res1 = torch.clamp(m1, max=max_val)
1938 for i
in iter_indices(res2):
1939 res2[i] = min(max_val, res2[i])
1940 self.assertEqual(res1, res2)
1942 torch.clamp(m1, max=max_val, out=out)
1943 self.assertEqual(out, res1)
1948 res1 = test_tens.clone()
1949 res1.clamp_(min_val, max_val)
1950 res2 = test_tens.clone()
1951 for i
in iter_indices(res2):
1952 res2[i] = max(min(res2[i], max_val), min_val)
1953 self.assertEqual(torch.isnan(res1), torch.isnan(res2))
1955 out = test_tens.clone()
1956 torch.clamp(test_tens, min=min_val, max=max_val, out=out)
1957 self.assertEqual(torch.isnan(out), torch.isnan(res1))
1959 res1 = torch.clamp(test_tens, min=min_val)
1960 res2 = test_tens.clone()
1961 for i
in iter_indices(res2):
1962 res2[i] = max(res2[i], min_val)
1963 self.assertEqual(torch.isnan(res1), torch.isnan(res2))
1965 torch.clamp(test_tens, min=min_val, out=out)
1966 self.assertEqual(torch.isnan(out), torch.isnan(res1))
1968 res1 = torch.clamp(test_tens, max=max_val)
1969 res2 = test_tens.clone()
1970 for i
in iter_indices(res2):
1971 res2[i] = min(res2[i], max_val)
1972 self.assertEqual(torch.isnan(res1), torch.isnan(res2))
1974 torch.clamp(test_tens, max=max_val, out=out)
1975 self.assertEqual(torch.isnan(out), torch.isnan(res1))
1977 error_msg =
'At least one of \'min\' or \'max\' must not be None' 1978 with self.assertRaisesRegex(RuntimeError, error_msg):
1980 with self.assertRaisesRegex(RuntimeError, error_msg):
1983 def test_clamp(self):
1984 self._test_clamp(self)
1990 for exponent
in [-2, -1, -0.5, 0.5, 1, 2, 3, 4]:
1993 m1 = torch.rand(100, 100) + 0.5
1994 res1 = torch.pow(m1[4], exponent)
1995 res2 = res1.clone().zero_()
1996 for i
in range(res2.size(0)):
1997 res2[i] = math.pow(m1[4][i], exponent)
1998 self.assertEqual(res1, res2)
2001 m1 = torch.rand(100, 100) + 0.5
2002 res1 = torch.pow(m1[:, 4], exponent)
2003 res2 = res1.clone().zero_()
2004 for i
in range(res2.size(0)):
2005 res2[i] = math.pow(m1[i, 4], exponent)
2006 self.assertEqual(res1, res2)
2010 m1 = torch.randn(100, 100)
2011 res1 = torch.pow(3, m1[4])
2012 res2 = res1.clone().zero_()
2013 for i
in range(res2.size(0)):
2014 res2[i] = math.pow(3, m1[4, i])
2015 self.assertEqual(res1, res2)
2018 m1 = torch.randn(100, 100)
2019 res1 = torch.pow(3, m1[:, 4])
2020 res2 = res1.clone().zero_()
2021 for i
in range(res2.size(0)):
2022 res2[i] = math.pow(3, m1[i][4])
2023 self.assertEqual(res1, res2)
2026 def _test_rpow(self, cast):
2027 m = cast(torch.randn(10, 10))
2028 self.assertEqual(torch.pow(2, m), 2**m)
2031 m = cast(torch.randn(1).squeeze())
2032 assert m.dim() == 0,
"m is intentionally a scalar" 2033 self.assertEqual(torch.pow(2, m), 2**m)
2035 def test_rpow(self):
2036 self._test_rpow(self,
lambda x: x)
2039 def _test_int_pow(self, cast):
2044 def check_against_np(tensor, exp):
2045 tensor_np = tensor.cpu().numpy()
2046 exp_np = exp
if isinstance(exp, int)
else exp.cpu().numpy()
2047 expected = torch.LongTensor(tensor_np ** exp_np).type_as(tensor)
2048 self.assertEqual(torch.pow(tensor, exp), expected)
2049 self.assertEqual(tensor.pow(exp), torch.pow(tensor, exp))
2053 lambda x: x.short(),
2058 typecasts.append(
lambda x: x.int())
2061 tensor = cast(torch.LongTensor(shape).random_(-10, 10))
2062 exps = [0, 1, 2, 5, cast(torch.LongTensor(shape).random_(0, 20))]
2064 for typecast
in typecasts:
2066 t = typecast(tensor)
2067 e = exp
if isinstance(exp, int)
else typecast(exp)
2068 check_against_np(t, e)
2070 def test_int_pow(self):
2071 self._test_int_pow(self,
lambda x: x)
2073 def _test_cop(self, torchfn, mathfn):
2074 def reference_implementation(res2):
2075 for i, j
in iter_indices(sm1):
2076 idx1d = i * sm1.size(0) + j
2077 res2[i, j] = mathfn(sm1[i, j], sm2[idx1d])
2081 m1 = torch.randn(10, 10, 10)
2082 m2 = torch.randn(10, 10 * 10)
2086 res1 = torchfn(sm1, sm2.view(10, 10))
2087 res2 = reference_implementation(res1.clone())
2088 self.assertEqual(res1, res2)
2091 m1 = torch.randn(10, 10, 10)
2092 m2 = torch.randn(10 * 10, 10 * 10)
2096 sm2.set_(sm2.storage(), sm2.storage_offset(), sm1.size(), (sm2.stride()[0] * 10, sm2.stride()[0]))
2097 res1 = torchfn(sm1, sm2)
2099 sm2.set_(sm2.storage(), sm2.storage_offset(), m2[:, 4].size(), m2[:, 4].stride())
2100 res2 = reference_implementation(res1.clone())
2101 self.assertEqual(res1, res2)
2103 def test_cdiv(self):
2104 self._test_cop(torch.div,
lambda x, y: x / y)
2106 def test_cfmod(self):
2107 self._test_cop(torch.fmod, math.fmod)
2109 def test_cremainder(self):
2110 self._test_cop(torch.remainder,
lambda x, y: x % y)
2112 def test_cmul(self):
2113 self._test_cop(torch.mul,
lambda x, y: x * y)
2115 def test_cpow(self):
2116 self._test_cop(torch.pow,
lambda x, y: nan
if x < 0
else math.pow(x, y))
2118 @unittest.skipIf(
not TEST_NUMPY,
'Numpy not found')
2119 def test_einsum(self):
2123 A = torch.randn(3, 5)
2124 B = torch.randn(2, 5)
2125 C = torch.randn(2, 3, 5)
2126 D = torch.randn(2, 5, 7)
2127 E = torch.randn(7, 9)
2128 F = torch.randn(2, 3, 5, 7)
2129 G = torch.randn(7, 11, 13)
2130 H = torch.randn(4, 4)
2131 I = torch.randn(3, 4, 4)
2132 l = torch.randn(5, 10)
2133 r = torch.randn(5, 20)
2134 w = torch.randn(30, 10, 20)
2145 (
"ij,ij->ij", A, A),
2147 (
"ij,kj->ik", A, B),
2148 (
"ij,ab->ijab", A, E),
2150 (
"aij,ajk->aik", C, D),
2151 (
"ijk,jk->i", C, A),
2152 (
"aij,jk->aik", D, E),
2153 (
"abcd,dfg->abcfg", F, G),
2154 (
"ijk,jk->ik", C, A),
2155 (
"ijk,jk->ij", C, A),
2156 (
"ijk,ik->j", C, B),
2157 (
"ijk,ik->jk", C, B),
2163 (
"ki,...k->i...", A.t(), B),
2164 (
"k...,jk", A.t(), B),
2167 (
"bn,anm,bm->ba", l, w, r),
2168 (
"... ii->...i ", I),
2170 for test
in test_list:
2171 actual = torch.einsum(test[0], test[1:])
2172 expected = np.einsum(test[0], *[t.numpy()
for t
in test[1:]])
2173 self.assertEqual(expected.shape, actual.shape, test[0])
2174 self.assertTrue(np.allclose(expected, actual.numpy()), test[0])
2176 actual2 = torch.einsum(test[0], *test[1:])
2177 self.assertEqual(expected.shape, actual2.shape, test[0])
2178 self.assertTrue(np.allclose(expected, actual2.numpy()), test[0])
2180 def do_einsum(*args):
2181 return torch.einsum(test[0], args)
2183 if test[0]
not in {
"i,i->",
"i,i->i",
"ij,ij->ij"}:
2184 gradcheck_inps = tuple(t.detach().requires_grad_()
for t
in test[1:])
2186 self.assertTrue(A._version == 0)
2188 def test_sum_all(self):
2189 def check_sum_all(tensor):
2190 pylist = tensor.reshape(-1).tolist()
2191 self.assertEqual(tensor.sum(), sum(pylist))
2194 check_sum_all(torch.randn(200000))
2195 check_sum_all(torch.randn(2000, 2)[:, 0])
2197 def _assert_matches_numpy(self, t, n):
2198 self.assertEqual(n.shape, t.shape)
2199 if t.dtype == torch.float:
2200 self.assertTrue(np.allclose(n, t.numpy(), rtol=1e-03, atol=1e-05,
2203 self.assertTrue(np.allclose(n, t.numpy(), equal_nan=
True))
2205 def _test_dim_ops(self, pytorch_op, numpy_op,
2206 use_floating=
True, use_integral=
True):
2207 def do_one(tensors_dict, dim):
2208 for category, tensors
in tensors_dict.items():
2209 if category ==
"slice":
2211 for tensor
in tensors:
2213 with warnings.catch_warnings():
2214 warnings.simplefilter(
"ignore")
2215 expected = numpy_op(tensor.numpy(), dim)
2216 actual = pytorch_op(tensor, dim)
2217 self._assert_matches_numpy(actual, expected)
2219 self._assert_matches_numpy(pytorch_op(tensor.cuda(),
2222 do_one(self._make_tensors((5, 400000), use_floating=use_floating,
2223 use_integral=use_integral), 1)
2224 do_one(self._make_tensors((3, 5, 7), use_floating=use_floating,
2225 use_integral=use_integral), 0)
2226 do_one(self._make_tensors((3, 5, 7), use_floating=use_floating,
2227 use_integral=use_integral), 1)
2228 do_one(self._make_tensors((3, 5, 7), use_floating=use_floating,
2229 use_integral=use_integral), 2)
2230 do_one(self._make_tensors((100000, ), use_floating=use_floating,
2231 use_integral=use_integral), -1)
2232 do_one(self._make_tensors((50, 50, 50), use_floating=use_floating,
2233 use_integral=use_integral), 0)
2234 do_one(self._make_tensors((50, 50, 50), use_floating=use_floating,
2235 use_integral=use_integral), 1)
2236 do_one(self._make_tensors((50, 50, 50), use_floating=use_floating,
2237 use_integral=use_integral), 2)
2238 do_one(self._make_tensors((50, 50, 50), use_floating=use_floating,
2239 use_integral=use_integral), (1, 2))
2240 do_one(self._make_tensors((50, 50, 50), use_floating=use_floating,
2241 use_integral=use_integral), (1, -1))
2242 do_one(self._make_tensors((50, 50, 50), use_floating=use_floating,
2243 use_integral=use_integral), (0, 2))
2244 do_one(self._make_tensors((50, 50, 50), use_floating=use_floating,
2245 use_integral=use_integral), (0, 2, 1))
2247 @unittest.skipIf(
not TEST_NUMPY,
'Numpy not found')
2248 def test_sum_dim(self):
2250 lambda t, d: t.sum(d),
2251 lambda n, d: n.sum(d))
2253 @unittest.skipIf(
not TEST_NUMPY,
'Numpy not found')
2254 def test_mean_dim(self):
2256 lambda t, d: t.mean(d),
2257 lambda n, d: n.mean(d),
2260 @unittest.skipIf(
not TEST_NUMPY,
'Numpy not found')
2261 def test_std_dim(self):
2262 for unbiased
in [
False,
True]:
2264 lambda t, d: t.std(d, unbiased=unbiased),
2265 lambda n, d: n.std(d, ddof=1
if unbiased
else 0),
2268 @unittest.skipIf(
not TEST_NUMPY,
'Numpy not found')
2269 def test_var_dim(self):
2270 for unbiased
in [
False,
True]:
2272 lambda t, d: t.var(d, unbiased=unbiased),
2273 lambda n, d: n.var(d, ddof=1
if unbiased
else 0),
2276 @unittest.skipIf(
not TEST_NUMPY,
'Numpy not found')
2277 @unittest.skipIf(
not TEST_SCIPY,
'Scipy not found')
2278 def test_logsumexp_dim(self):
2279 from scipy.special
import logsumexp
2281 lambda t, d: t.logsumexp(d),
2282 lambda n, d: logsumexp(n, d),
2285 def test_sum_out(self):
2286 x = torch.rand(100, 100)
2287 res1 = torch.sum(x, 1)
2288 res2 = torch.Tensor()
2289 torch.sum(x, 1, out=res2)
2290 self.assertEqual(res1, res2)
2291 x = torch.rand(100, 100, 100)
2292 res1 = x.sum(2).sum(1)
2293 res2 = torch.Tensor()
2294 torch.sum(x, (2, 1), out=res2)
2295 self.assertEqual(res1, res2)
2299 def test_prod(self):
2300 x = torch.rand(100, 100)
2301 res1 = torch.prod(x, 1)
2302 res2 = torch.Tensor()
2303 torch.prod(x, 1, out=res2)
2304 self.assertEqual(res1, res2)
2306 def test_cumsum(self):
2307 x = torch.rand(100, 100)
2308 res1 = torch.cumsum(x, 1)
2309 res2 = torch.Tensor()
2310 torch.cumsum(x, 1, out=res2)
2311 self.assertEqual(res1, res2)
2313 def test_cumprod(self):
2314 x = torch.rand(100, 100)
2315 res1 = torch.cumprod(x, 1)
2316 res2 = torch.Tensor()
2317 torch.cumprod(x, 1, out=res2)
2318 self.assertEqual(res1, res2)
2320 def _test_reduce_integer_upcast(self, fn, has_out=True):
2322 reduced_shape = fn(torch.ones(shape)).shape
2324 def _test_out(dtype, other_dtype):
2325 out = torch.ones(reduced_shape, dtype=dtype)
2326 result = fn(x, out=out)
2327 self.assertIs(out.dtype, result.dtype)
2328 self.assertEqual(fn(x.type(dtype)), result)
2329 result = fn(x, out=out, dtype=dtype)
2330 self.assertIs(out.dtype, result.dtype)
2331 self.assertEqual(fn(x.type(dtype)), result)
2333 self.assertRaises(RuntimeError,
lambda: fn(x, out=out, dtype=other_dtype))
2336 x = torch.ones(shape, dtype=dtype)
2337 expected_dtype = dtype
if dtype.is_floating_point
else torch.int64
2338 self.assertIs(expected_dtype, fn(x).dtype)
2339 self.assertEqual(fn(x.type(expected_dtype)), fn(x))
2341 if dtype.is_floating_point:
2342 other_dtype = torch.float32
if dtype == torch.float64
else torch.float64
2344 other_dtype = torch.int32
if dtype != torch.int32
else torch.int16
2345 self.assertIs(other_dtype, fn(x, dtype=other_dtype).dtype)
2346 self.assertEqual(fn(x.type(other_dtype)), fn(x, dtype=other_dtype))
2349 mixed_dtype = torch.int32
if dtype.is_floating_point
else torch.float32
2350 self.assertIs(mixed_dtype, fn(x, dtype=mixed_dtype).dtype)
2351 self.assertEqual(fn(x.type(mixed_dtype)), fn(x, dtype=mixed_dtype))
2354 _test_out(dtype, other_dtype)
2355 _test_out(dtype, mixed_dtype)
2357 def test_sum_integer_upcast(self):
2358 self._test_reduce_integer_upcast(
lambda x, **kwargs: torch.sum(x, **kwargs),
False)
2359 self._test_reduce_integer_upcast(
lambda x, **kwargs: torch.sum(x, 0, **kwargs))
2361 def test_prod_integer_upcast(self):
2362 self._test_reduce_integer_upcast(
lambda x, **kwargs: torch.prod(x, **kwargs),
False)
2363 self._test_reduce_integer_upcast(
lambda x, **kwargs: torch.prod(x, 0, **kwargs))
2365 def test_cumsum_integer_upcast(self):
2366 self._test_reduce_integer_upcast(
lambda x, **kwargs: torch.cumsum(x, 0, **kwargs))
2368 def test_cumprod_integer_upcast(self):
2369 self._test_reduce_integer_upcast(
lambda x, **kwargs: torch.cumprod(x, 0, **kwargs))
2371 def test_cross(self):
2372 x = torch.rand(100, 3, 100)
2373 y = torch.rand(100, 3, 100)
2374 res1 = torch.cross(x, y)
2375 res2 = torch.Tensor()
2376 torch.cross(x, y, out=res2)
2377 self.assertEqual(res1, res2)
2379 def test_zeros(self):
2380 res1 = torch.zeros(100, 100)
2381 res2 = torch.Tensor()
2382 torch.zeros(100, 100, out=res2)
2383 self.assertEqual(res1, res2)
2385 boolTensor = torch.zeros(2, 2, dtype=torch.bool)
2386 expected =
torch.tensor([[
False,
False], [
False,
False]], dtype=torch.bool)
2387 self.assertEqual(boolTensor, expected)
2389 halfTensor = torch.zeros(1, 1, dtype=torch.half)
2391 self.assertEqual(halfTensor, expected)
2393 def test_zeros_like(self):
2394 expected = torch.zeros(100, 100)
2396 res1 = torch.zeros_like(expected)
2397 self.assertEqual(res1, expected)
2400 def test_zeros_like_cuda(self):
2401 expected = torch.zeros(100, 100).cuda()
2403 res1 = torch.zeros_like(expected)
2404 self.assertEqual(res1, expected)
2407 def test_zeros_like_multiple_device(self):
2408 expected = torch.zeros(100, 100).cuda()
2409 x = torch.cuda.FloatTensor(100, 100, device=1)
2410 output = torch.zeros_like(x)
2411 self.assertEqual(output, expected)
2413 def test_zeros_out(self):
2415 out = torch.zeros(shape)
2416 torch.zeros(shape, out=out)
2419 self.assertRaises(RuntimeError,
lambda: torch.zeros(shape, dtype=torch.int64, out=out))
2420 self.assertRaises(RuntimeError,
lambda: torch.zeros(shape, layout=torch.sparse_coo, out=out))
2422 self.assertRaises(RuntimeError,
lambda: torch.zeros(shape, device=
'cuda', out=out))
2425 self.assertEqual(torch.zeros(shape), torch.zeros(shape, dtype=out.dtype, out=out))
2426 self.assertEqual(torch.zeros(shape), torch.zeros(shape, layout=torch.strided, out=out))
2427 self.assertEqual(torch.zeros(shape), torch.zeros(shape, device=
'cpu', out=out))
2430 def _test_histc(self, device):
2432 with self.assertRaisesRegex(RuntimeError,
'bins must be > 0'):
2433 torch.histc(
torch.tensor([1], dtype=torch.float, device=device), bins=-1)
2436 actual = torch.histc(
2437 torch.tensor([2, 5], dtype=torch.float, device=device))
2438 expected = torch.zeros(100, dtype=torch.float, device=device)
2439 expected.data[0] = 1
2440 expected.data[99] = 1
2441 self.assertEqual(expected, actual)
2443 actual = torch.histc(torch.ones(5, dtype=torch.float, device=device), bins=5)
2445 torch.tensor([0, 0, 5, 0, 0], dtype=torch.float, device=device),
2448 actual = torch.histc(
2449 torch.ones(5, dtype=torch.float, device=device), bins=5, min=2, max=3)
2451 torch.tensor([0, 0, 0, 0, 0], dtype=torch.float, device=device),
2454 actual = torch.histc(
2455 torch.tensor([2, 4, 2, 2, 5, 4], dtype=torch.float, device=device),
2456 bins=5, min=1, max=5)
2458 torch.tensor([0, 3, 0, 2, 1], dtype=torch.float, device=device),
2461 actual = torch.histc(
2462 torch.tensor([1, 2, 1], dtype=torch.float, device=device),
2463 bins=4, min=0, max=3)
2465 torch.tensor([0, 2, 1, 0], dtype=torch.float, device=device),
2468 actual = torch.histc(
2469 torch.tensor([1, 2, 1], dtype=torch.double, device=device),
2470 bins=4, min=0, max=3)
2472 torch.tensor([0, 2, 1, 0], dtype=torch.double, device=device),
2475 actual = torch.histc(
2476 torch.tensor([1., 2, 1], dtype=torch.float, device=device),
2477 bins=4, min=0, max=3)
2479 torch.tensor([0, 2, 1, 0], dtype=torch.float, device=device),
2483 def test_against_np(tensor, bins=100, min=0, max=0):
2484 if min == 0
and max == 0:
2485 min = tensor.min().item()
2486 max = tensor.max().item()
2487 nparr = tensor.cpu().numpy()
2488 actual = torch.histc(tensor, bins=bins, min=min, max=max)
2489 expected = torch.from_numpy(np.histogram(nparr, bins=bins, range=(min, max))[0])
2490 self.assertEqual(actual.cpu(), expected)
2493 test_against_np(
torch.tensor([1., 2, 1], device=device))
2494 test_against_np(torch.randn(5000, device=device))
2497 test_against_np(torch.randn(301, device=device), bins=10)
2500 test_against_np(torch.randn(201, device=device), min=0.1, max=1)
2502 noncontig = torch.randn(100, 3, device=device)[:, 2]
2503 test_against_np(noncontig)
2505 multidim = torch.randn(3, 5, 7, 2, device=device)
2506 test_against_np(multidim)
2508 expanded = torch.randn(1, 5, 1, 2, device=device).expand(3, 5, 7, 2)
2509 test_against_np(expanded)
2511 def test_histc_cpu(self):
2512 self._test_histc(self,
'cpu')
2514 def test_ones(self):
2515 res1 = torch.ones(100, 100)
2516 res2 = torch.Tensor()
2517 torch.ones(100, 100, out=res2)
2518 self.assertEqual(res1, res2)
2521 res1 = torch.ones(1, 2, dtype=torch.bool)
2522 expected =
torch.tensor([[
True,
True]], dtype=torch.bool)
2523 self.assertEqual(res1, expected)
2525 def test_ones_like(self):
2526 expected = torch.ones(100, 100)
2528 res1 = torch.ones_like(expected)
2529 self.assertEqual(res1, expected)
2532 expected =
torch.tensor([
True,
True], dtype=torch.bool)
2533 res1 = torch.ones_like(expected)
2534 self.assertEqual(res1, expected)
def test_ones_like_cuda(self):
    # ones_like of an all-ones CUDA tensor must equal that tensor.
    # NOTE(review): requires a CUDA device; no skip guard is visible in this
    # chunk -- confirm one is applied.
    reference = torch.ones(100, 100).cuda()
    self.assertEqual(torch.ones_like(reference), reference)
def test_ones_like_multiple_device(self):
    # ones_like of a tensor on CUDA device 1 must compare equal to an all-ones
    # tensor created via .cuda() on the current device.
    # NOTE(review): needs at least two GPUs; confirm a skip guard exists.
    on_current = torch.ones(100, 100).cuda()
    on_second = torch.cuda.FloatTensor(100, 100, device=1)
    self.assertEqual(torch.ones_like(on_second), on_current)
2550 def test_dtypes(self):
# Runs the shared do_test_dtypes checks (imported from common_utils) for the
# strided layout on the CPU and on cuda:0.
# NOTE(review): `all_dtypes` is not bound anywhere in this chunk, and no
# CUDA-availability guard is visible before the cuda:0 call -- confirm both.
2552 do_test_dtypes(self, all_dtypes, torch.strided, torch.device(
'cpu'))
2554 do_test_dtypes(self, all_dtypes, torch.strided, torch.device(
'cuda:0'))
2556 def test_copy_dtypes(self):
# copy.deepcopy of each dtype must return the identical object (assertIs),
# i.e. dtypes are not duplicated by deep copies.
# NOTE(review): `all_dtypes` is not bound anywhere in this chunk -- confirm.
2558 for dtype
in all_dtypes:
2559 copied_dtype = copy.deepcopy(dtype)
2560 self.assertIs(dtype, copied_dtype)
def test_device(self):
    # Accepted torch.device spellings:
    # (constructor args, expected str(), expected .type, expected .index).
    accepted = [
        (('cpu',), 'cpu', 'cpu', None),
        (('cpu:0',), 'cpu:0', 'cpu', 0),
        (('cpu', 0), 'cpu:0', 'cpu', 0),
        (('cuda',), 'cuda', 'cuda', None),
        (('cuda:1',), 'cuda:1', 'cuda', 1),
        (('cuda', 1), 'cuda:1', 'cuda', 1),
    ]
    for ctor_args, want_str, want_type, want_index in accepted:
        dev = torch.device(*ctor_args)
        self.assertEqual(want_str, str(dev))
        self.assertEqual(want_type, dev.type)
        self.assertEqual(want_index, dev.index)

    # Spellings that must raise RuntimeError: negative indices, a nonzero CPU
    # index, a bare -1, and unknown device types.
    rejected = [
        ('cpu:-1',), ('cpu:1',), ('cpu', -1), ('cpu', 1),
        ('cuda:-1',), ('cuda', -1),
        (-1,),
        ('other',), ('other:0',),
    ]
    for ctor_args in rejected:
        self.assertRaises(RuntimeError, torch.device, *ctor_args)

    # Distinct device specs must hash to distinct values.
    specs = {'cpu', 'cpu:0', 'cuda', 'cuda:0', 'cuda:1', 'cuda:10', 'cuda:100'}
    spec_hashes = {hash(torch.device(spec)) for spec in specs}
    self.assertEqual(len(specs), len(spec_hashes))
2610 def test_tensor_device(self):
# Checks the .device reported by tensors from several factories, both as a
# torch.device and via str(): CPU cases first, then cuda:0 and cuda:1.
# The local helper below shadows the name `assertEqual` inside this test; it
# calls fn() twice and compares each result's device against device_str.
2611 def assertEqual(device_str, fn):
2612 self.assertEqual(torch.device(device_str), fn().device)
2613 self.assertEqual(device_str, str(fn().device))
2616 assertEqual(
'cpu',
lambda: torch.ones((2, 3), dtype=torch.float32, device=
'cpu'))
# A 'cpu:0' request is reported back as plain 'cpu'.
2619 assertEqual(
'cpu',
lambda: torch.ones((2, 3), dtype=torch.float32, device=
'cpu:0'))
2620 assertEqual(
'cpu',
lambda:
torch.tensor(torch.ones((2, 3), dtype=torch.float32), device=
'cpu:0'))
2622 assertEqual(
'cpu',
lambda:
torch.tensor(np.random.randn(2, 3), device=
'cpu'))
# NOTE(review): no CUDA-availability guard is visible before the CUDA cases
# below -- confirm one exists. Passing a CPU device to .cuda() must raise.
2626 assertEqual(
'cuda:0',
lambda:
torch.tensor(5).cuda(
'cuda:0'))
2627 self.assertRaises(RuntimeError,
lambda:
torch.tensor(5).cuda(
'cpu'))
2628 self.assertRaises(RuntimeError,
lambda:
torch.tensor(5).cuda(
'cpu:0'))
2629 assertEqual(
'cuda:0',
lambda:
torch.tensor(5, dtype=torch.int64, device=0))
2630 assertEqual(
'cuda:0',
lambda:
torch.tensor(5, dtype=torch.int64, device=
'cuda:0'))
# NOTE(review): the line that opens the next call (presumably
# `assertEqual('cuda:0',`) is missing from this chunk; the fragment below is
# not valid on its own.
2632 lambda:
torch.tensor(5, dtype=torch.int64, device=
'cuda'))
2633 assertEqual(
'cuda:0',
lambda:
torch.tensor(torch.ones((2, 3), dtype=torch.float32), device=
'cuda:0'))
2635 assertEqual(
'cuda:0',
lambda:
torch.tensor(np.random.randn(2, 3), device=
'cuda:0'))
# NOTE(review): the cuda:1 cases need at least two GPUs; no guard is visible.
2639 assertEqual(
'cuda:1',
lambda:
torch.tensor(5).cuda(
'cuda:1'))
2640 assertEqual(
'cuda:1',
lambda:
torch.tensor(5, dtype=torch.int64, device=1))
2641 assertEqual(
'cuda:1',
lambda:
torch.tensor(5, dtype=torch.int64, device=
'cuda:1'))
2642 assertEqual(
'cuda:1',
lambda:
torch.tensor(torch.ones((2, 3), dtype=torch.float32), device=
'cuda:1'))
2644 assertEqual(
'cuda:1',
lambda:
torch.tensor(np.random.randn(2, 3), device=
'cuda:1'))
# NOTE(review): the enclosing test's `def` line is not present in this chunk,
# and `a`, `b` and `cuda` used below are not bound anywhere visible here.
2647 def test_copy_behavior(t, non_blocking=False):
# .to() must return `t` itself when nothing changes (same tensor spec, same
# dtype, or an empty_like template), and a new object when copy=True.
2648 self.assertIs(t, t.to(t, non_blocking=non_blocking))
2649 self.assertIs(t, t.to(t.dtype, non_blocking=non_blocking))
2650 self.assertIs(t, t.to(torch.empty_like(t), non_blocking=non_blocking))
2651 self.assertIsNot(t, t.to(t, non_blocking=non_blocking, copy=
True))
2652 self.assertIsNot(t, t.to(t.dtype, non_blocking=non_blocking, copy=
True))
2653 self.assertIsNot(t, t.to(torch.empty_like(t), non_blocking=non_blocking, copy=
True))
# Same no-op / copy checks using device arguments equivalent to t.device.
# NOTE(review): some branch lines of this device-list construction appear to
# be missing from this chunk.
2655 devices = [t.device]
2656 if t.device.type ==
'cuda':
2657 if t.device.index == -1:
2660 devices.append(
'cuda')
2661 for device
in devices:
2662 self.assertIs(t, t.to(device, non_blocking=non_blocking))
2663 self.assertIs(t, t.to(device, t.dtype, non_blocking=non_blocking))
2664 self.assertIsNot(t, t.to(device, non_blocking=non_blocking, copy=
True))
2665 self.assertIsNot(t, t.to(device, t.dtype, non_blocking=non_blocking, copy=
True))
# CPU tensor `a`: moving to 'cpu' (optionally with a dtype) keeps the device,
# and without copy=True keeps the same storage (data_ptr).
2668 test_copy_behavior(a)
2669 self.assertEqual(a.device, a.to(
'cpu').device)
2670 self.assertEqual(a.device, a.to(
'cpu', dtype=torch.float32).device)
2671 self.assertIs(torch.float32, a.to(
'cpu', dtype=torch.float32).dtype)
2672 self.assertEqual(a.device, a.to(torch.float32).device)
2673 self.assertIs(torch.float32, a.to(dtype=torch.float32).dtype)
2674 self.assertEqual(a.data_ptr(), a.to(
'cpu').data_ptr())
2675 self.assertEqual(a.data_ptr(), a.to(dtype=a.dtype, device=a.device, copy=
False).data_ptr())
2676 self.assertEqual(a.data_ptr(), a.to(
'cpu', copy=
False).data_ptr())
2677 self.assertNotEqual(a.data_ptr(), a.to(
'cpu', copy=
True).data_ptr())
# Tensor `b` (presumably on CUDA, created in lines not shown): CPU <-> CUDA
# round trips for both non_blocking settings.
2680 for non_blocking
in [
True,
False]:
2683 test_copy_behavior(b, non_blocking)
2684 self.assertEqual(b.device, b.to(cuda, non_blocking=non_blocking).device)
2685 self.assertEqual(a.device, b.to(
'cpu', non_blocking=non_blocking).device)
2686 self.assertEqual(b.device, a.to(cuda, non_blocking=non_blocking).device)
2687 self.assertIs(torch.int32, b.to(
'cpu', dtype=torch.int32, non_blocking=non_blocking).dtype)
2688 self.assertEqual(a.device, b.to(
'cpu', dtype=torch.int32, non_blocking=non_blocking).device)
2689 self.assertIs(torch.int32, b.to(dtype=torch.int32).dtype)
2690 self.assertEqual(b.device, b.to(dtype=torch.int32).device)
2692 def test_to_with_tensor(self):
# .to(other_tensor) must adopt other_tensor's device.
# NOTE(review): `a`, `b` and any CUDA guard are not visible in this chunk.
2694 self.assertEqual(a.device, a.to(a).device)
2697 for non_blocking
in [
True,
False]:
2700 self.assertEqual(b.device, b.to(b, non_blocking=non_blocking).device)
2701 self.assertEqual(a.device, b.to(a, non_blocking=non_blocking).device)
2702 self.assertEqual(b.device, a.to(b, non_blocking=non_blocking).device)
2704 def test_empty_full(self):
# NOTE(review): the body of this test (presumably a call to the imported
# do_test_empty_full helper) is not present in this chunk.
def test_dtype_out_match(self):
    # Asking for float32 while supplying a float64 `out=` tensor is a dtype
    # mismatch and must raise RuntimeError.
    double_out = torch.autograd.Variable(torch.DoubleTensor(2, 3))
    with self.assertRaises(RuntimeError):
        torch.zeros((2, 3), out=double_out, dtype=torch.float32)
2714 def test_constructor_dtypes(self):
# Checks the dtype attributes of legacy tensor classes and the value of
# torch.get_default_dtype() at several points.
# NOTE(review): the statements that change the default tensor type / dtype
# between the assertions below (and any CUDA guard) are not in this chunk;
# `default_type` is assigned but not used in what is visible here.
2715 default_type = torch.Tensor().type()
2716 self.assertIs(torch.Tensor().dtype, torch.get_default_dtype())
2718 self.assertIs(torch.uint8, torch.ByteTensor.dtype)
2719 self.assertIs(torch.float32, torch.FloatTensor.dtype)
2720 self.assertIs(torch.float64, torch.DoubleTensor.dtype)
2723 self.assertIs(torch.float32, torch.get_default_dtype())
2727 self.assertIs(torch.float64, torch.get_default_dtype())
2731 self.assertIs(torch.float32, torch.get_default_dtype())
2736 self.assertIs(torch.float32, torch.get_default_dtype())
2737 self.assertIs(torch.float32, torch.cuda.FloatTensor.dtype)
2741 self.assertIs(torch.float64, torch.get_default_dtype())
2753 def test_constructor_device_legacy(self):
# Legacy constructors (torch.FloatTensor, torch.Tensor, x.new,
# torch.cuda.FloatTensor) must raise RuntimeError when the requested device
# conflicts with the tensor type.
2754 self.assertRaises(RuntimeError,
lambda: torch.FloatTensor(device=
'cuda'))
2755 self.assertRaises(RuntimeError,
lambda: torch.FloatTensor(torch.Size([2, 3, 4]), device=
'cuda'))
2756 self.assertRaises(RuntimeError,
lambda: torch.FloatTensor((2.0, 3.0), device=
'cuda'))
2758 self.assertRaises(RuntimeError,
lambda: torch.Tensor(device=
'cuda'))
2759 self.assertRaises(RuntimeError,
lambda: torch.Tensor(torch.Size([2, 3, 4]), device=
'cuda'))
2760 self.assertRaises(RuntimeError,
lambda: torch.Tensor((2.0, 3.0), device=
'cuda'))
2762 x = torch.randn((3,), device=
'cpu')
2763 self.assertRaises(RuntimeError,
lambda: x.new(device=
'cuda'))
2764 self.assertRaises(RuntimeError,
lambda: x.new(torch.Size([2, 3, 4]), device=
'cuda'))
2765 self.assertRaises(RuntimeError,
lambda: x.new((2.0, 3.0), device=
'cuda'))
# NOTE(review): the CUDA cases below have no visible availability guard.
2768 self.assertRaises(RuntimeError,
lambda: torch.cuda.FloatTensor(device=
'cpu'))
2769 self.assertRaises(RuntimeError,
lambda: torch.cuda.FloatTensor(torch.Size([2, 3, 4]), device=
'cpu'))
2770 self.assertRaises(RuntimeError,
lambda: torch.cuda.FloatTensor((2.0, 3.0), device=
'cpu'))
# NOTE(review): the code that switches the default tensor type (and later
# restores `default_type`) is not in this chunk; the torch.Tensor(...,
# device='cpu') cases below presumably rely on a CUDA default type.
2772 default_type = torch.Tensor().type()
2774 self.assertRaises(RuntimeError,
lambda: torch.Tensor(device=
'cpu'))
2775 self.assertRaises(RuntimeError,
lambda: torch.Tensor(torch.Size([2, 3, 4]), device=
'cpu'))
2776 self.assertRaises(RuntimeError,
lambda: torch.Tensor((2.0, 3.0), device=
'cpu'))
2780 x = torch.randn((3,), device=
'cuda')
2781 self.assertRaises(RuntimeError,
lambda: x.new(device=
'cpu'))
2782 self.assertRaises(RuntimeError,
lambda: x.new(torch.Size([2, 3, 4]), device=
'cpu'))
2783 self.assertRaises(RuntimeError,
lambda: x.new((2.0, 3.0), device=
'cpu'))
def test_type(self):
    # Tensor.type() accepts a type name string, a legacy tensor class, or a
    # dtype; torch.Tensor maps to the current default dtype.
    src = torch.randn(3, 3).double()
    for float_target in ('torch.FloatTensor', torch.FloatTensor):
        self.assertEqual(src.type(float_target).dtype, torch.float32)
    as_default = src.int().type(torch.Tensor)
    self.assertEqual(as_default.dtype, torch.get_default_dtype())
    self.assertEqual(src.type(torch.int32).dtype, torch.int32)
2792 def test_tensor_factory(self):
# torch.tensor() factory checks: values, explicit dtype, construction from
# a tensor, numpy inputs of several dtypes, and conversion to bool.
# NOTE(review): the assignments of res1/res2 preceding several asserts below
# are missing from this chunk.
2793 expected = torch.Tensor([1, 1])
2796 self.assertEqual(res1, expected)
2799 self.assertEqual(res1, expected)
2800 self.assertIs(torch.int, res1.dtype)
2804 self.assertEqual(res2, expected)
2806 self.assertEqual(expected, torch.ones_like(expected))
2809 self.assertEqual(res1, expected)
2810 self.assertIs(torch.int, res1.dtype)
2814 for dtype
in [np.float64, np.int64, np.int8, np.uint8]:
2815 a = np.array([5.]).astype(dtype)
2817 self.assertEqual(5., res1[0].item())
2819 self.assertEqual(5., res1[0].item())
# Numbers converted to bool: 0 -> False, any nonzero value (negative or
# fractional included) -> True.
2822 a =
torch.tensor([
True,
True,
False,
True,
True], dtype=torch.bool)
2823 b =
torch.tensor([-1, -1.1, 0, 1, 1.1], dtype=torch.bool)
2824 self.assertEqual(a, b)
def test_tensor_factory_copy_var(self):
    # Checks whether tensor factories copy or alias an autograd-tracked source
    # tensor, and the resulting is_leaf / requires_grad flags.
    source = torch.randn(5, 5, dtype=torch.double, requires_grad=True)

    def check_copy(produced, is_leaf, requires_grad, data_ptr=None):
        """Assert properties of `produced` relative to `source`.

        Args:
            produced: tensor returned by the factory under test.
            is_leaf: expected value of produced.is_leaf.
            requires_grad: expected value of produced.requires_grad.
            data_ptr: integer storage address `produced` must share (aliasing
                case). When None, `produced` must own separate storage.
        """
        self.assertEqual(produced.data, source.data)
        self.assertEqual(produced.is_leaf, is_leaf)
        self.assertEqual(produced.requires_grad, requires_grad)
        # Compare the integer addresses returned by data_ptr(). Comparing the
        # bound methods themselves only checks object identity, and is
        # trivially true when the default was taken from `produced` itself.
        if data_ptr is None:
            self.assertNotEqual(produced.data_ptr(), source.data_ptr())
        else:
            self.assertEqual(produced.data_ptr(), data_ptr)

    # torch.tensor() copies and returns a leaf.
    check_copy(torch.tensor(source, requires_grad=False), True, False)
    check_copy(torch.tensor(source, requires_grad=True), True, True)

    # Tensor.new_tensor() copies and returns a leaf.
    template = torch.randn(1)
    check_copy(template.new_tensor(source), True, False)
    check_copy(template.new_tensor(source, requires_grad=False), True, False)
    check_copy(template.new_tensor(source, requires_grad=True), True, True)

    # torch.as_tensor() aliases when no conversion is needed...
    check_copy(torch.as_tensor(source), source.is_leaf, source.requires_grad,
               source.data_ptr())
    # ...and copies (keeping the autograd graph) when the dtype changes.
    check_copy(torch.as_tensor(source, dtype=torch.float), False, True)
2852 def test_tensor_factory_type_inference(self):
# Checks dtype inference of torch.tensor() from Python scalars, nested
# sequences and numpy arrays for a given default dtype; run for float64 and
# float32.
# NOTE(review): the lines that set and restore the default dtype (via
# saved_dtype) and an `else:` pairing with the np.int64 check appear to be
# missing from this chunk.
2853 def test_inference(default_dtype):
2854 saved_dtype = torch.get_default_dtype()
2860 self.assertIs(torch.int32,
torch.tensor(5, dtype=torch.int32).dtype)
2861 self.assertIs(default_dtype,
torch.tensor(((7, 5), (9, 5.))).dtype)
2862 self.assertIs(default_dtype,
torch.tensor(((5., 5), (3, 5))).dtype)
2863 self.assertIs(torch.int64,
torch.tensor(((5, 3), (3, 5))).dtype)
# numpy float arrays (including empty ones) map to float64 regardless of the
# default dtype.
2866 self.assertIs(torch.float64,
torch.tensor(np.array(())).dtype)
2867 self.assertIs(torch.float64,
torch.tensor(np.array(5.)).dtype)
# The integer dtype of np.array(5) is checked because it may not be int64 on
# every platform.
2868 if np.array(5).dtype == np.int64:
2869 self.assertIs(torch.int64,
torch.tensor(np.array(5)).dtype)
2871 self.assertIs(torch.int32,
torch.tensor(np.array(5)).dtype)
2872 self.assertIs(torch.uint8,
torch.tensor(np.array(3, dtype=np.uint8)).dtype)
2873 self.assertIs(default_dtype,
torch.tensor(((7, np.array(5)), (np.array(9), 5.))).dtype)
2874 self.assertIs(torch.float64,
torch.tensor(((7, 5), (9, np.array(5.)))).dtype)
2875 self.assertIs(torch.int64,
torch.tensor(((5, np.array(3)), (np.array(3), 5))).dtype)
2878 test_inference(torch.float64)
2879 test_inference(torch.float32)
2882 def test_tensor_factory_cuda_type_inference(self):
# Presumably checks that torch.tensor() follows a CUDA default tensor type.
# NOTE(review): the set_default_tensor_type calls (and restoring saved_type)
# plus any CUDA guard are missing from this chunk; as visible, the asserts
# expect torch.tensor(0.) to be on cuda:0 without any type switch.
2883 saved_type = torch.Tensor().type()
2887 self.assertEqual(torch.device(
'cuda:0'),
torch.tensor(0.).device)
2890 self.assertEqual(torch.device(
'cuda:0'),
torch.tensor(0.).device)
2894 def test_tensor_factory_cuda_type(self):
# Expects torch.zeros to produce CUDA tensors of float32 and then float64.
# NOTE(review): the default-tensor-type switches between the two halves (and
# restoring saved_type) are missing from this chunk.
2895 saved_type = torch.Tensor().type()
2897 x = torch.zeros((5, 5))
2898 self.assertIs(torch.float32, x.dtype)
2899 self.assertTrue(x.is_cuda)
2901 x = torch.zeros((5, 5))
2902 self.assertIs(torch.float64, x.dtype)
2903 self.assertTrue(x.is_cuda)
def test_tensor_factories_empty_bool(self):
    # Bool-dtype empty/full factories (and their *_like variants) must honour
    # the requested shape; full(..., True) must be all True.
    shape = (1, 2)
    empty_bool = torch.empty(shape, dtype=torch.bool)
    for produced in (empty_bool, torch.empty_like(empty_bool)):
        self.assertEqual(shape, produced.shape)

    full_true = torch.full(shape, True, dtype=torch.bool)
    self.assertEqual(full_true, torch.tensor([[True, True]], dtype=torch.bool))
    self.assertEqual(shape, full_true.shape)
    self.assertEqual(shape, torch.full_like(full_true, True).shape)
2919 def test_tensor_factories_empty(self):
# Every factory must accept shapes containing zero-size dimensions and return
# exactly that shape, on each device in `devices`.
# NOTE(review): `devices` is not bound anywhere in this chunk -- confirm.
2921 shapes = [(5, 0, 1), (0,), (0, 0, 1, 0, 2, 0, 0)]
2924 for device
in devices:
2925 for shape
in shapes:
2926 self.assertEqual(shape, torch.zeros(shape, device=device).shape)
2927 self.assertEqual(shape, torch.zeros_like(torch.zeros(shape, device=device)).shape)
2928 self.assertEqual(shape, torch.empty(shape, device=device).shape)
2929 self.assertEqual(shape, torch.empty_like(torch.zeros(shape, device=device)).shape)
2930 self.assertEqual(shape, torch.empty_strided(shape, (0,) * len(shape), device=device).shape)
2931 self.assertEqual(shape, torch.full(shape, 3, device=device).shape)
2932 self.assertEqual(shape, torch.full_like(torch.zeros(shape, device=device), 3).shape)
2933 self.assertEqual(shape, torch.ones(shape, device=device).shape)
2934 self.assertEqual(shape, torch.ones_like(torch.zeros(shape, device=device)).shape)
2935 self.assertEqual(shape, torch.rand(shape, device=device).shape)
2936 self.assertEqual(shape, torch.rand_like(torch.zeros(shape, device=device)).shape)
2937 self.assertEqual(shape, torch.randn(shape, device=device).shape)
2938 self.assertEqual(shape, torch.randn_like(torch.zeros(shape, device=device)).shape)
2939 self.assertEqual(shape, torch.randint(6, shape, device=device).shape)
2940 self.assertEqual(shape, torch.randint_like(torch.zeros(shape, device=device), 6).shape)
# Size-0 results from range, eye, window and literal factories.
2942 self.assertEqual((0,), torch.arange(0, device=device).shape)
2943 self.assertEqual((0, 0), torch.eye(0, device=device).shape)
2944 self.assertEqual((0, 0), torch.eye(0, 0, device=device).shape)
2945 self.assertEqual((5, 0), torch.eye(5, 0, device=device).shape)
2946 self.assertEqual((0, 5), torch.eye(0, 5, device=device).shape)
2947 self.assertEqual((0,), torch.linspace(1, 1, 0, device=device).shape)
2948 self.assertEqual((0,), torch.logspace(1, 1, 0, device=device).shape)
2949 self.assertEqual((0,), torch.randperm(0, device=device).shape)
2950 self.assertEqual((0,), torch.bartlett_window(0, device=device).shape)
2951 self.assertEqual((0,), torch.bartlett_window(0, periodic=
False, device=device).shape)
2952 self.assertEqual((0,), torch.hamming_window(0, device=device).shape)
2953 self.assertEqual((0,), torch.hann_window(0, device=device).shape)
2954 self.assertEqual((1, 1, 0),
torch.tensor([[[]]], device=device).shape)
2955 self.assertEqual((1, 1, 0), torch.as_tensor([[[]]], device=device).shape)
2957 def test_new_tensor(self):
2958 expected = torch.autograd.Variable(torch.ByteTensor([1, 1]))
2960 res1 = expected.new_tensor([1, 1])
2961 self.assertEqual(res1, expected)
2962 res1 = expected.new_tensor([1, 1], dtype=torch.int)
2963 self.assertEqual(res1, expected)
2964 self.assertIs(torch.int, res1.dtype)
2967 res2 = expected.new_tensor(expected)
2968 self.assertEqual(res2, expected)
2970 self.assertEqual(expected, torch.ones_like(expected))
2971 res2 = expected.new_tensor(expected, dtype=torch.int)
2972 self.assertEqual(res2, expected)
2973 self.assertIs(torch.int, res2.dtype)
2979 res1 = res1.new_tensor(a)
2980 self.assertEqual(5., res1[0].item())
2982 self.assertEqual(5., res1[0].item())
2985 expected = expected.cuda(1)
2986 res1 = expected.new_tensor([1, 1])
2987 self.assertEqual(res1.get_device(), expected.get_device())
2988 res1 = expected.new_tensor([1, 1], dtype=torch.int)
2989 self.assertIs(torch.int, res1.dtype)
2990 self.assertEqual(res1.get_device(), expected.get_device())
2992 res2 = expected.new_tensor(expected)
2993 self.assertEqual(res2.get_device(), expected.get_device())
2994 res2 = expected.new_tensor(expected, dtype=torch.int)
2995 self.assertIs(torch.int, res1.dtype)
2996 self.assertEqual(res2.get_device(), expected.get_device())
2997 res2 = expected.new_tensor(expected, dtype=torch.int, device=0)
2998 self.assertIs(torch.int, res1.dtype)
2999 self.assertEqual(res2.get_device(), 0)
3001 res1 = expected.new_tensor(1)
3002 self.assertEqual(res1.get_device(), expected.get_device())
3003 res1 = expected.new_tensor(1, dtype=torch.int)
3004 self.assertIs(torch.int, res1.dtype)
3005 self.assertEqual(res1.get_device(), expected.get_device())
3007 def test_as_tensor(self):
# torch.as_tensor(): matches torch.tensor() for nested lists, rejects invalid
# and self-referential inputs with TypeError, returns the input tensor itself
# when no conversion is needed, and behaves as checked below for numpy input.
# NOTE(review): several statements (bodies of the `with` blocks, the
# definition of `y`, and CUDA guards) are missing from this chunk.
3009 x = [[0, 1], [2, 3]]
3011 self.assertEqual(
torch.tensor(x, dtype=torch.float32), torch.as_tensor(x, dtype=torch.float32))
3015 with self.assertRaisesRegex(TypeError,
"invalid data type"):
3022 with self.assertRaisesRegex(TypeError,
"self-referential lists are incompatible"):
3027 with self.assertRaisesRegex(TypeError,
"self-referential lists are incompatible"):
# No conversion needed -> the same object is returned; a dtype or device
# change -> a different object.
3033 self.assertIs(y, torch.as_tensor(y))
3034 self.assertIsNot(y, torch.as_tensor(y, dtype=torch.float32))
3036 self.assertIsNot(y, torch.as_tensor(y, device=
'cuda'))
3037 y_cuda = y.to(
'cuda')
3038 self.assertIs(y_cuda, torch.as_tensor(y_cuda))
3039 self.assertIs(y_cuda, torch.as_tensor(y_cuda, device=
'cuda'))
# numpy input with a matching dtype: presumably checks memory sharing after
# the write below (the asserting lines are not in this chunk).
3043 for dtype
in [np.float64, np.int64, np.int8, np.uint8]:
3044 n = np.random.rand(5, 6).astype(dtype)
3045 n_astensor = torch.as_tensor(n)
3047 n_astensor[0][0] = 25.7
# A dtype change forces a copy: writing to the tensor must not affect a fresh
# conversion of `n`.
3051 n = np.random.rand(5, 6).astype(np.float32)
3052 n_astensor = torch.as_tensor(n, dtype=torch.float64)
3053 self.assertEqual(
torch.tensor(n, dtype=torch.float64), n_astensor)
3054 n_astensor[0][1] = 250.8
3055 self.assertNotEqual(
torch.tensor(n, dtype=torch.float64), n_astensor)
# Same for a device change to CUDA (no guard visible here).
3059 n = np.random.randn(5, 6)
3060 n_astensor = torch.as_tensor(n, device=
'cuda')
3061 self.assertEqual(
torch.tensor(n, device=
'cuda'), n_astensor)
3062 n_astensor[0][2] = 250.9
3063 self.assertNotEqual(
torch.tensor(n, device=
'cuda'), n_astensor)
def test_diag(self):
    # torch.diag with and without an `out=` tensor must agree.
    matrix = torch.rand(100, 100)
    via_out = torch.Tensor()
    direct = torch.diag(matrix)
    torch.diag(matrix, out=via_out)
    self.assertEqual(direct, via_out)
def _test_diagonal(self, dtype, device):
    # torch.diagonal(x, offset) must equal torch.diag(x, offset) for a fresh
    # 100x100 input, for the main diagonal and for offset 17.
    for offset in (0, 17):
        mat = torch.randn((100, 100), dtype=dtype, device=device)
        self.assertEqual(torch.diagonal(mat, offset), torch.diag(mat, offset))
def test_diagonal(self):
    # CPU / float32 instantiation of the shared diagonal check.
    self._test_diagonal(self, device='cpu', dtype=torch.float32)
3087 @unittest.skipIf(
not TEST_NUMPY,
'Numpy not found')
3088 def test_diagonal_multidim(self):
# torch.diagonal on a 4-D tensor must match numpy's ndarray.diagonal (shape
# and values) for several (offset, dim1, dim2) tuples, and for a permuted
# input.
# NOTE(review): `xn` (presumably x.numpy()) and the rest of the `args` list
# are not present in this chunk.
3089 x = torch.randn(10, 11, 12, 13)
3091 for args
in [(2, 2, 3),
3095 result = torch.diagonal(x, *args)
3096 expected = xn.diagonal(*args)
3097 self.assertEqual(expected.shape, result.shape)
3098 self.assertTrue(np.allclose(expected, result.numpy()))
# Permuted input: diagonal over the last two dims.
3100 xp = x.permute(1, 2, 3, 0)
3101 result = torch.diagonal(xp, 0, -2, -1)
3102 expected = xp.numpy().diagonal(0, -2, -1)
3103 self.assertEqual(expected.shape, result.shape)
3104 self.assertTrue(np.allclose(expected, result.numpy()))
def _test_diag_embed(self, dtype, device):
    # diag_embed must equal stacking torch.diag of each row: once with the
    # default placement (stacked on dim 0), once with offset=1, dim1=0, dim2=2
    # (stacked on dim 1).
    rows = torch.arange(3 * 4, dtype=dtype, device=device).view(3, 4)
    cases = [
        (dict(), 0, 0),
        (dict(offset=1, dim1=0, dim2=2), 1, 1),
    ]
    for embed_kwargs, diag_offset, stack_dim in cases:
        embedded = torch.diag_embed(rows, **embed_kwargs)
        stacked = torch.stack([torch.diag(row, diag_offset) for row in rows], stack_dim)
        self.assertEqual(embedded, stacked)
def test_diag_embed(self):
    # CPU / float32 instantiation of the shared diag_embed check.
    self._test_diag_embed(self, device='cpu', dtype=torch.float32)
def _test_diagflat(self, dtype, device):
    # diagflat must equal diag of the flattened input for: a 1-D input, a 1-D
    # input with offset 17, a contiguous 3-D input and a non-contiguous 3-D
    # input.
    vec = torch.randn((100,), dtype=dtype, device=device)
    self.assertEqual(torch.diagflat(vec), torch.diag(vec))

    vec = torch.randn((100,), dtype=dtype, device=device)
    self.assertEqual(torch.diagflat(vec, 17), torch.diag(vec, 17))

    for make_noncontiguous in (False, True):
        cube = torch.randn((2, 3, 4), dtype=dtype, device=device)
        if make_noncontiguous:
            cube = cube.transpose(2, 0)
            self.assertFalse(cube.is_contiguous())
        self.assertEqual(torch.diagflat(cube), torch.diag(cube.contiguous().view(-1)))
def test_diagflat(self):
    # CPU / float32 instantiation of the shared diagflat check.
    self._test_diagflat(self, device='cpu', dtype=torch.float32)
# NOTE(review): the `def` line enclosing these statements is not present in
# this chunk. They check that torch.eye with and without `out=` agree.
3151 res1 = torch.eye(100, 100)
3152 res2 = torch.Tensor()
3153 torch.eye(100, 100, out=res2)
3154 self.assertEqual(res1, res2)
def test_renorm(self):
    """Check Tensor.renorm_ against a reference built from norm / expand_as."""

    def renorm(matrix, value, dim, max_norm):
        # Reference: move `dim` to the front, take the p-norm of each slice
        # along it, and rescale slices whose norm exceeds max_norm.
        m1 = matrix.transpose(dim, 0).contiguous()
        # One row per slice; nelement() is always a multiple of size(0), so
        # integer division is exact.
        m2 = m1.clone().resize_(m1.size(0), m1.nelement() // m1.size(0))
        norms = m2.norm(value, 1, True)
        # Slices within max_norm get factor norm / (norm + eps) ~= 1; larger
        # ones get max_norm / (norm + eps).
        new_norms = norms.clone()
        new_norms[torch.gt(norms, max_norm)] = max_norm
        new_norms.div_(norms.add_(1e-7))
        m1.mul_(new_norms.expand_as(m1))
        return m1.transpose(dim, 0)

    # 2-D case.
    m1 = torch.randn(10, 5)
    maxnorm = m1.norm(2, 1).mean()
    m2 = renorm(m1, 2, 1, maxnorm)
    m1.renorm_(2, 1, maxnorm)
    self.assertEqual(m1, m2, 1e-5)
    self.assertEqual(m1.norm(2, 0), m2.norm(2, 0), 1e-5)

    # 3-D case: flatten the other dims so column j of the (15, 4) matrix holds
    # every element with index j along dim 1, then reuse the 2-D reference.
    m1 = torch.randn(3, 4, 5)
    m2 = m1.transpose(1, 2).contiguous().clone().resize_(15, 4)
    maxnorm = m2.norm(2, 0).mean()
    m2 = renorm(m2, 2, 1, maxnorm)
    m1.renorm_(2, 1, maxnorm)
    m3 = m1.transpose(1, 2).contiguous().clone().resize_(15, 4)
    self.assertEqual(m3, m2)
    self.assertEqual(m3.norm(2, 0), m2.norm(2, 0))
def _test_renorm_ps(self, device):
    """Check Tensor.renorm for several p-norms on `device`.

    With maxnorm=1 along dim 1, the result is expected to equal
    x / max(1, ||column||_p), computed column-wise (norm over dim 0).
    """
    # Allocate on `device`. The previous version ignored the argument and
    # always ran on the CPU, so the CUDA variant never tested CUDA.
    x = torch.randn(5, 5, device=device)
    for p in [1, 2, 3, 4, inf]:
        res = x.renorm(p, 1, 1)
        expected = x / x.norm(p, 0, keepdim=True).clamp(min=1)
        # .cpu() so the numpy comparison also works for CUDA tensors.
        self.assertEqual(res.cpu().numpy(), expected.cpu().numpy(),
                         "renorm failed for {}-norm".format(p))
def test_renorm_ps(self):
    # CPU instantiation of the shared p-norm renorm check.
    target = 'cpu'
    self._test_renorm_ps(self, device=target)
def test_renorm_ps_cuda(self):
    # CUDA instantiation of the shared p-norm renorm check.
    # NOTE(review): needs a CUDA device; no skip guard is visible in this chunk.
    target = 'cuda'
    self._test_renorm_ps(self, device=target)
3207 def _test_multinomial(self, type):
# Shared torch.multinomial checks for the tensor constructor `type`, run for
# contiguous and non-contiguous probability tensors: sampling with
# replacement never returns a zero-probability index; sampling without
# replacement additionally never repeats an index; 1-D distributions give 1-D
# samples.
# NOTE(review): many lines are missing from this chunk (e.g. the
# `if is_contiguous` / `else` branches of make_prob_dist, the n_row loops,
# the zeroing of probabilities, and the initialisation of `row_samples`).
3208 def make_prob_dist(shape, is_contiguous):
3210 return type(*shape).uniform_()
3211 elif len(shape) == 1:
3212 return type(*(shape + [5])).uniform_()[:, 2]
3215 new_shape = [2, shape[1], 7, 1, shape[0], 1, 10]
3216 prob_dist = type(*new_shape).uniform_()
3217 prob_dist = prob_dist.transpose(1, 4)
3218 prob_dist = prob_dist[1, :, 5, 0, :, 0, 4]
3219 assert not prob_dist.is_contiguous()
# Outer loop over memory layouts of the probability tensor.
3222 for is_contiguous
in (
True,
False):
3225 for n_col
in range(4, 5 + 1):
3226 prob_dist = make_prob_dist([n_row, n_col], is_contiguous)
3228 zero_prob_indices = torch.LongTensor(n_row).random_(-2, n_col).tolist()
3229 for i, j
in enumerate(zero_prob_indices):
# With replacement: draw three times as many samples as categories.
3232 n_sample = n_col * 3
3233 sample_indices = torch.multinomial(prob_dist, n_sample,
True)
3234 self.assertEqual(prob_dist.dim(), 2)
3235 self.assertEqual(sample_indices.size(1), n_sample)
3236 for i
in range(n_row):
3237 zero_prob_idx = zero_prob_indices[i]
3238 if zero_prob_idx < 0:
3240 for j
in range(n_sample):
3241 self.assertNotEqual(sample_indices[i, j], zero_prob_idx,
3242 "sampled an index with zero probability")
3246 for n_col
in range(2, 10 + 1, 2):
3247 prob_dist = make_prob_dist([n_row, n_col], is_contiguous)
3249 zero_prob_indices = torch.LongTensor(n_row).random_(-1, n_col).tolist()
3250 for i, j
in enumerate(zero_prob_indices):
# Without replacement: draw max(1, n_col - 2) distinct indices per row.
3253 n_sample = max(1, n_col - 2)
3254 sample_indices = torch.multinomial(prob_dist, n_sample,
False)
3255 self.assertEqual(prob_dist.dim(), 2)
3256 self.assertEqual(sample_indices.size(1), n_sample)
3257 for i
in range(n_row):
3259 zero_prob_idx = zero_prob_indices[i]
3260 for j
in range(n_sample):
3261 sample_idx = sample_indices[i, j]
3262 if zero_prob_idx >= 0:
3263 self.assertNotEqual(sample_idx, zero_prob_idx,
3264 "sampled an index with zero probability")
3265 self.assertNotIn(sample_idx, row_samples,
"sampled an index twice")
3266 row_samples[sample_idx] =
True 3270 prob_dist = make_prob_dist([n_col], is_contiguous).fill_(1)
3272 prob_dist[zero_prob_idx] = 0
# 1-D distribution: samples must be a 1-D tensor of length n_sample that
# never hits the zeroed index.
3274 sample_indices = torch.multinomial(prob_dist, n_sample,
True)
3275 for sample_index
in sample_indices:
3276 self.assertNotEqual(sample_index, zero_prob_idx,
"sampled an index with zero probability")
# NOTE(review): s_dim is assigned but never used.
3277 s_dim = sample_indices.dim()
3278 self.assertEqual(sample_indices.dim(), 1,
"wrong number of dimensions")
3279 self.assertEqual(prob_dist.dim(), 1,
"wrong number of prob_dist dimensions")
3280 self.assertEqual(sample_indices.size(0), n_sample,
"wrong number of samples")
def test_multinomial(self):
    # CPU float instantiation of the shared multinomial checks.
    tensor_type = torch.FloatTensor
    self._test_multinomial(self, tensor_type)
def _spawn_method(self, method, arg):
    """Run ``method(arg)`` in a one-worker pool using the 'spawn' start method.

    Asserts that the single value returned by the child is truthy.

    Args:
        method: picklable callable executed in the child process.
        arg: the single argument passed to ``method``.
    """
    try:
        mp.set_start_method('spawn')
    except RuntimeError:
        # set_start_method raises RuntimeError when the start method has
        # already been set in this process; keep the existing one.
        pass
    with mp.Pool(1) as pool:
        results = pool.map(method, [arg])
    # pool.map returns a one-element list, which is always truthy. Assert on
    # the element itself, otherwise this check can never fail.
    self.assertTrue(results[0])
3294 def _test_multinomial_invalid_probs(probs):
# Draws 2 samples from `probs` on the CPU; when torch.multinomial raises a
# RuntimeError, returns whether its message mentions
# 'invalid multinomial distribution'.
# NOTE(review): the `try:` opening this statement (and whatever follows the
# multinomial call on the success path) is missing from this chunk.
3297 torch.multinomial(probs.to(
'cpu'), 2)
3299 except RuntimeError
as e:
3300 return 'invalid multinomial distribution' in str(e)
3302 @unittest.skipIf(NO_MULTIPROCESSING_SPAWN,
"Disabled for environments that \ 3303 don't support multiprocessing with spawn start method")
3304 @unittest.skipIf(IS_WINDOWS,
'FIXME: CUDA OOM error on Windows')
3305 @unittest.skipIf(
not PY3,
3306 "spawn start method is not supported in Python 2, \ 3307 but we need it for for testing failure case for CPU RNG on Windows")
3308 def test_multinomial_invalid_probs(self):
# Each probability vector below (a negative, +inf, -inf or NaN entry, or only
# one nonzero category for 2 draws) is passed to
# _test_multinomial_invalid_probs in a spawned process via _spawn_method.
# NOTE(review): the skip messages above contain a scrape artefact inside the
# string and a duplicated "for"; they are runtime strings and are left as-is.
3309 test_method = _TestTorchMixin._test_multinomial_invalid_probs
3310 self._spawn_method(test_method, torch.Tensor([1, -1, 1]))
3311 self._spawn_method(test_method, torch.Tensor([1, inf, 1]))
3312 self._spawn_method(test_method, torch.Tensor([1, -inf, 1]))
3313 self._spawn_method(test_method, torch.Tensor([1, 1, nan]))
3314 self._spawn_method(test_method, torch.Tensor([0, 1, 0]))
def test_range(self):
    # The out= form must match the functional form.
    expected = torch.range(0, 1)
    via_out = torch.Tensor()
    torch.range(0, 1, out=via_out)
    self.assertEqual(expected, via_out, 0)

    # Writing into a narrowed (non-contiguous) view.
    grid = torch.zeros(2, 3)
    torch.range(0, 3, out=grid.narrow(1, 1, 2))
    self.assertEqual(grid, torch.Tensor(((0, 0, 1), (0, 2, 3))), 1e-16)

    # Negative step: the end point is included.
    via_out = torch.Tensor()
    torch.range(1, 0, -1, out=via_out)
    self.assertEqual(torch.Tensor((1, 0)), via_out, 0)

    # Equal bounds yield a single element for either step sign.
    single = torch.ones(1)
    via_out = torch.Tensor()
    for step in (-1, 1):
        torch.range(1, 1, step, out=via_out)
        self.assertEqual(single, via_out, 0)

    # Fractional steps: output sizes for float and double outputs.
    for out_factory in (torch.FloatTensor, torch.DoubleTensor):
        self.assertEqual(torch.range(0.6, 0.9, 0.1, out=out_factory()).size(0), 4)
        self.assertEqual(torch.range(1, 10, 0.3, out=out_factory()).size(0), 31)
3355 def test_range_warning(self):
# Expects exactly one warning to be recorded inside the `with` block.
# NOTE(review): the statement inside the `with` block (presumably a
# torch.range call) is not present in this chunk.
3356 with warnings.catch_warnings(record=
True)
as w:
3358 self.assertEqual(len(w), 1)
3360 def test_arange(self):
# torch.arange checks: out= form, single-argument form, writing into a
# narrowed view, negative steps, fractional steps (end exclusive), tolerance
# of end points within 1e-6, and rejection of non-finite bounds and of
# ranges whose size overflows.
3361 res1 = torch.arange(0, 1)
3362 res2 = torch.Tensor()
3363 torch.arange(0, 1, out=res2)
3364 self.assertEqual(res1, res2, 0)
3367 res1 = torch.arange(10)
3368 res2 = torch.arange(0, 10)
3369 self.assertEqual(res1, res2, 0)
3372 x = torch.zeros(2, 3)
3373 torch.arange(0, 4, out=x.narrow(1, 1, 2))
3374 res2 = torch.Tensor(((0, 0, 1), (0, 2, 3)))
3375 self.assertEqual(x, res2, 1e-16)
3378 res1 = torch.Tensor((1, 0))
3379 res2 = torch.Tensor()
3380 torch.arange(1, -1, -1, out=res2)
3381 self.assertEqual(res1, res2, 0)
3384 res1 = torch.ones(1)
3385 res2 = torch.Tensor()
3386 torch.arange(1, 0, -1, out=res2)
3387 self.assertEqual(res1, res2, 0)
3388 torch.arange(1, 2, 1, out=res2)
3389 self.assertEqual(res1, res2, 0)
# Fractional steps: the end is exclusive (e.g. 1..10 step 0.3 ends at 9.7).
3392 res1 = torch.arange(0.6, 0.89, 0.1, out=torch.FloatTensor())
3393 self.assertEqual(res1, [0.6, 0.7, 0.8])
3394 res1 = torch.arange(1, 10, 0.3, out=torch.FloatTensor())
3395 self.assertEqual(res1.size(0), 30)
3396 self.assertEqual(res1[0], 1)
3397 self.assertEqual(res1[29], 9.7)
3400 res1 = torch.arange(0.6, 0.89, 0.1, out=torch.DoubleTensor())
3401 self.assertEqual(res1, [0.6, 0.7, 0.8])
3402 res1 = torch.arange(1, 10, 0.3, out=torch.DoubleTensor())
3403 self.assertEqual(res1.size(0), 30)
3404 self.assertEqual(res1[0], 1)
3405 self.assertEqual(res1[29], 9.7)
3408 r = torch.arange(0, 5)
3409 self.assertEqual(r.min(), 0)
3410 self.assertEqual(r.max(), 4)
3411 self.assertEqual(r.numel(), 5)
3413 r = torch.arange(0, 5, 2)
3414 self.assertEqual(r.min(), 0)
3415 self.assertEqual(r.max(), 4)
3416 self.assertEqual(r.numel(), 3)
# End points within 1e-6 of an integer: a slightly larger end adds one
# element; a slightly smaller end does not remove one.
3418 r1 = torch.arange(0, 5 + 1e-6)
3419 r2 = torch.arange(0, 5)
3420 r3 = torch.arange(0, 5 - 1e-6)
3421 self.assertEqual(r1[:-1], r2, 0)
3422 self.assertEqual(r2, r3, 0)
# Same tolerance for a descending range.
3424 r1 = torch.arange(10, -1 + 1e-6, -1)
3425 r2 = torch.arange(10, -1, -1)
3426 r3 = torch.arange(10, -1 - 1e-6, -1)
3427 self.assertEqual(r1, r2, 0)
3428 self.assertEqual(r2, r3[:-1], 0)
# Non-finite bounds must be rejected with an 'unsupported range' error.
# NOTE(review): `devices` is not bound anywhere in this chunk -- confirm.
3430 msg =
"unsupported range" 3431 self.assertRaisesRegex(RuntimeError, msg,
lambda: torch.arange(0, float(
'inf')))
3432 self.assertRaisesRegex(RuntimeError, msg,
lambda: torch.arange(float(
'inf')))
3435 for device
in devices:
3436 self.assertRaisesRegex(RuntimeError, msg,
lambda: torch.arange(-5, float(
'nan'), device=device))
3438 self.assertRaisesRegex(RuntimeError, msg,
lambda: torch.arange(0, float(
'-inf'), -1, device=device))
3439 self.assertRaisesRegex(RuntimeError, msg,
lambda: torch.arange(0, float(
'inf'), device=device))
3440 self.assertRaisesRegex(RuntimeError, msg,
lambda: torch.arange(float(
'-inf'), 10, device=device))
3441 self.assertRaisesRegex(RuntimeError, msg,
lambda: torch.arange(float(
'nan'), 10, device=device))
3442 self.assertRaisesRegex(RuntimeError, msg,
lambda: torch.arange(float(
'inf'), device=device))
3443 self.assertRaisesRegex(RuntimeError, msg,
lambda: torch.arange(float(
'nan'), device=device))
# A range whose element count overflows must raise an "overflow" error.
3445 self.assertRaisesRegex(
3446 RuntimeError,
"overflow",
3447 lambda: torch.arange(1.175494351e-38, 3.402823466e+38, device=device))
3449 def test_arange_inference(self):
# dtype inference for torch.arange: when any argument is floating point
# (a Python float or a floating tensor, even float64), the asserts below
# expect torch.float32 (presumably the default dtype set in lines not shown);
# all-integer arguments (Python ints or integer tensors, even int16) give
# torch.int64.
# NOTE(review): the lines that set/restore the default dtype (saved_dtype)
# and the continuations of two multi-line asserts are missing from this
# chunk.
3450 saved_dtype = torch.get_default_dtype()
3453 self.assertIs(torch.float32, torch.arange(1.).dtype)
3454 self.assertIs(torch.float32, torch.arange(
torch.tensor(1.)).dtype)
3455 self.assertIs(torch.float32, torch.arange(
torch.tensor(1., dtype=torch.float64)).dtype)
3457 self.assertIs(torch.int64, torch.arange(1).dtype)
3458 self.assertIs(torch.int64, torch.arange(
torch.tensor(1)).dtype)
3459 self.assertIs(torch.int64, torch.arange(
torch.tensor(1, dtype=torch.int16)).dtype)
3462 self.assertIs(torch.float32, torch.arange(1., 3).dtype)
3463 self.assertIs(torch.float32, torch.arange(
torch.tensor(1., dtype=torch.float64), 3).dtype)
3464 self.assertIs(torch.float32, torch.arange(1, 3.).dtype)
3466 self.assertIs(torch.float32, torch.arange(1, 3, 1.).dtype)
3467 self.assertIs(torch.float32,
3472 self.assertIs(torch.int64, torch.arange(1, 3).dtype)
3473 self.assertIs(torch.int64, torch.arange(
torch.tensor(1), 3).dtype)
3475 self.assertIs(torch.int64, torch.arange(1, 3, 1).dtype)
3476 self.assertIs(torch.int64,
3482 def test_randint_inference(self):
3484 for args
in [(3,), (1, 3)]:
3485 self.assertIs(torch.int64, torch.randint(*args, size=size).dtype)
3486 self.assertIs(torch.int64, torch.randint(*args, size=size, layout=torch.strided).dtype)
3487 self.assertIs(torch.int64, torch.randint(*args, size=size, generator=torch.default_generator).dtype)
3488 self.assertIs(torch.float32, torch.randint(*args, size=size, dtype=torch.float32).dtype)
3489 out = torch.empty(size, dtype=torch.float32)
3490 self.assertIs(torch.float32, torch.randint(*args, size=size, out=out).dtype)
3491 self.assertIs(torch.float32, torch.randint(*args, size=size, out=out, dtype=torch.float32).dtype)
3492 out = torch.empty(size, dtype=torch.int64)
3493 self.assertIs(torch.int64, torch.randint(*args, size=size, out=out).dtype)
3494 self.assertIs(torch.int64, torch.randint(*args, size=size, out=out, dtype=torch.int64).dtype)
3497 def _select_broadcastable_dims(dims_full=None):
3499 if dims_full
is None:
3501 ndims = random.randint(1, 4)
3502 dims_full = [random.randint(1, 8)
for _
in range(ndims)]
3504 ndims = len(dims_full)
3509 smaller_ndims = random.randint(1, ndims)
3512 for i
in range(ndims - 1, -1, -1):
3513 j = random.randint(1, 3)
3519 dl = 1
if len(dims_small) < smaller_ndims
else dims_full[i]
3523 dims_large = [dl] + dims_large
3524 if len(dims_small) < smaller_ndims:
3525 dims_small = [ds] + dims_small
3526 return (dims_small, dims_large, dims_full)
3529 def _test_broadcast(self, cast):
3533 "dist",
"atan2",
"pow",
"lerp",
"add",
3534 "sub",
"mul",
"div",
"fmod",
"remainder",
3535 "eq",
"ge",
"gt",
"le",
"lt",
"max",
"min",
"ne",
3536 "addcdiv",
"addcmul",
"masked_scatter",
"masked_select",
"masked_fill",
3537 "map",
"map2",
"copy" 3540 fns_3_args = {
"addcdiv",
"addcmul",
"map2"}
3543 (dims_small, dims_large, dims_full) = self._select_broadcastable_dims()
3544 full1d = cast(torch.randn(*dims_full).flatten().float())
3545 small = cast(torch.randn(*dims_small).float())
3546 large = cast(torch.randn(*dims_large).float())
3547 small_expanded = small.expand(*dims_full)
3548 large_expanded = large.expand(*dims_full)
3550 small2_expanded =
None 3551 if fn
in fns_3_args:
3553 (dims_small2, _, _) = self._select_broadcastable_dims(dims_full)
3554 small2 = cast(torch.randn(*dims_small2).float())
3555 small2_expanded = small2.expand(*dims_full)
3557 if small.is_cuda
and fn
in [
'map',
'map2']:
3561 if hasattr(large_expanded, fn):
3564 expanded = {large: large_expanded, small: small_expanded, small2: small2_expanded}
3566 def tensorfn(myfn, t1, t2):
3568 return myfn(t1, 0.5)
3569 elif fn ==
"masked_select":
3571 elif fn ==
"masked_scatter":
3572 return myfn(t1 < 0.5, full1d)
3573 elif fn ==
"masked_fill":
3574 return myfn(t1 < 0.5, 1.0)
3575 elif fn
in fns_3_args:
3576 return myfn(1, t1, t2)
3581 for first, second, third
in [(large, small, small2), (small, large, small2),
3582 (small2, small, large), (small2, large, small)]:
3585 method_expanded = getattr(expanded[first], fn)
3586 method = getattr(first, fn)
3587 r1 = tensorfn(method_expanded, expanded[second], expanded[third])
3588 r2 = tensorfn(method, second, third)
3589 self.assertEqual(r1, r2)
3592 if hasattr(torch, fn):
3593 fntorch = getattr(torch, fn)
3594 expanded = {large: large_expanded, small: small_expanded, small2: small2_expanded}
3596 def torchfn(t1, t2, t3):
3598 return fntorch(t1, t2, 0.5)
3599 elif fn ==
"masked_select":
3600 return fntorch(t1, t2 < 0)
3601 elif fn ==
"masked_scatter":
3602 return fntorch(t1, t2 < 0.5, full1d)
3603 elif fn ==
"masked_fill":
3604 return fntorch(t1, t2 < 0.5, 1.0)
3605 elif fn
in fns_3_args:
3606 return fntorch(t1, 1.0, t2, t3)
3608 return fntorch(t1, t2)
3611 for first, second, third
in [(large, small, small2), (small, large, small2),
3612 (small2, small, large), (small2, large, small)]:
3615 r1 = torchfn(expanded[first], expanded[second], expanded[third])
3616 r2 = torchfn(first, second, third)
3617 self.assertEqual(r1, r2)
3622 if not hasattr(large_expanded, fn +
"_"):
3626 large_expanded_clone = large_expanded.clone()
3628 def tensorfn_inplace(t0, t1, t2=None):
3629 t0_fn = getattr(t0, fn +
"_")
3631 return t0_fn(t1, 0.5)
3632 elif fn ==
"masked_scatter":
3633 return t0_fn(t1 < 0.5, full1d)
3634 elif fn ==
"masked_fill":
3635 return t0_fn(t1 < 0.5, 1.0)
3637 return t0_fn(t1,
lambda x, y: x + y)
3639 return t0_fn(t1, t2,
lambda x, y, z: x + y + z)
3640 elif fn
in fns_3_args:
3641 return t0_fn(1.0, t1, t2)
3644 r1 = tensorfn_inplace(large_expanded, small_expanded, small2_expanded)
3645 r2 = tensorfn_inplace(large_expanded_clone, small, small2)
3648 if (0
not in large_expanded.stride()
and 0
not in large_expanded_clone.stride()):
3649 self.assertEqual(r1, r2)
3651 def broadcastable(t0, t1, t2=None):
3656 except RuntimeError:
3660 def _test_in_place_broadcastable(t0, t1, t2=None):
3661 if not broadcastable(t0, t1, t2):
3662 same_size = t0.numel() == t1.numel()
and (t0.numel() == t2.numel()
if t2
is not None else True)
3664 self.assertRaises(RuntimeError,
lambda: tensorfn_inplace(t0, t1, t2))
3666 tensorfn_inplace(t0, t1, t2)
3668 if fn
not in fns_3_args:
3669 _test_in_place_broadcastable(small, large_expanded)
3670 _test_in_place_broadcastable(small, large)
3672 _test_in_place_broadcastable(small2, small_expanded, large_expanded)
3673 _test_in_place_broadcastable(small2, small, large)
def test_broadcast(self):
    """Run the shared broadcasting checks with a no-op ``cast``.

    ``_test_broadcast`` takes the test case as an explicit first positional
    argument, so ``self`` is forwarded alongside the identity cast.
    """
    def identity(tensor):
        return tensor

    self._test_broadcast(self, identity)
def test_broadcast_empty(self):
    """Check elementwise-add broadcasting when some dimensions have size 0.

    Every tensor compared below has zero elements, so only the resulting
    shapes matter; the random contents never influence the comparisons.
    Tensors are created in the same order as the straight-line version, so
    the global RNG is consumed identically.
    """
    def added(lhs_shape, rhs_shape):
        # Broadcasted sum of two random tensors with the given shapes.
        return torch.randn(lhs_shape) + torch.randn(rhs_shape)

    # empty + empty with incompatible sizes must fail
    self.assertRaises(RuntimeError, lambda: added((5, 0), (0, 5)))

    compatible_cases = (
        # (expected_shape, lhs_shape, rhs_shape)
        ((5, 0), (0,), (5, 0)),
        ((5, 0, 0), (0,), (5, 0, 1)),
        # scalar + empty
        ((5, 0, 6), (), (5, 0, 6)),
        # empty + non-empty
        ((0,), (0,), (1,)),
        ((0, 7, 0, 6, 5, 0, 7), (0, 7, 0, 6, 5, 0, 1), (1, 1, 5, 1, 7)),
    )
    for expected_shape, lhs_shape, rhs_shape in compatible_cases:
        self.assertEqual(torch.randn(expected_shape), added(lhs_shape, rhs_shape))

    self.assertRaises(RuntimeError, lambda: added((7, 0), (2, 1)))
3693 def test_broadcast_tensors(self):
3694 x0 = torch.randn(2, 1, 3)
3696 x2 = torch.randn(3, 1)
3697 expected_size = (2, 3, 3)
3699 y0, y1, y2 = torch.broadcast_tensors(x0, x1, x2)
3700 self.assertTrue(y0.size() == expected_size)
3701 self.assertTrue(y1.size() == expected_size)
3702 self.assertTrue(y2.size() == expected_size)
3705 def _test_contiguous(self, cast):
3706 x = cast(torch.randn(1, 16, 5, 5))
3707 self.assertTrue(x.is_contiguous())
3708 stride = list(x.stride())
3711 x.set_(x.storage(), 0, x.size(), stride)
3712 self.assertTrue(x.is_contiguous())
def test_contiguous(self):
    """Delegate to ``_test_contiguous`` without converting tensors.

    The helper receives ``self`` explicitly plus a no-op ``cast``; whatever
    it returns is passed straight through.
    """
    def no_op_cast(tensor):
        return tensor

    return self._test_contiguous(self, no_op_cast)
3717 def test_empty_tensor_props(self):
3718 sizes = [(0,), (0, 3), (5, 0), (5, 0, 3, 0, 2), (0, 3, 0, 2), (0, 5, 0, 2, 0)]
3721 for device
in devices:
3722 x = torch.empty(tuple(size), device=device)
3723 self.assertEqual(size, x.shape)
3724 self.assertTrue(x.is_contiguous())
3725 size_ones_instead_of_zeros = (x
if x != 0
else 1
for x
in size)
3726 y = torch.empty(tuple(size_ones_instead_of_zeros), device=device)
3727 self.assertEqual(x.stride(), y.stride())
3729 def test_scalars_as_floats(self):
3730 "zero-dim variables that don't require grad should bind to scalar arguments" 3734 self.assertEqual(y.addcmul(y, y, value=x), 21)
3737 self.assertRaises(Exception,
lambda: y.addcmul(y, y, value=x))
3740 def _test_broadcast_fused_matmul(self, cast):
3741 fns = [
"baddbmm",
"addbmm",
"addmm",
"addmv",
"addr"]
3744 batch_dim = random.randint(1, 8)
3745 n_dim = random.randint(1, 8)
3746 m_dim = random.randint(1, 8)
3747 p_dim = random.randint(1, 8)
3749 def dims_full_for_fn():
3751 return ([batch_dim, n_dim, p_dim], [batch_dim, n_dim, m_dim], [batch_dim, m_dim, p_dim])
3752 elif fn ==
"addbmm":
3753 return ([n_dim, p_dim], [batch_dim, n_dim, m_dim], [batch_dim, m_dim, p_dim])
3755 return ([n_dim, p_dim], [n_dim, m_dim], [m_dim, p_dim])
3757 return ([n_dim], [n_dim, m_dim], [m_dim])
3759 return ([n_dim, m_dim], [n_dim], [m_dim])
3761 raise AssertionError(
"unknown function")
3763 (t0_dims_full, t1_dims, t2_dims) = dims_full_for_fn()
3764 (t0_dims_small, _, _) = self._select_broadcastable_dims(t0_dims_full)
3766 t0_small = cast(torch.randn(*t0_dims_small).float())
3767 t1 = cast(torch.randn(*t1_dims).float())
3768 t2 = cast(torch.randn(*t2_dims).float())
3770 t0_full = cast(t0_small.expand(*t0_dims_full))
3772 fntorch = getattr(torch, fn)
3773 r0 = fntorch(t0_small, t1, t2)
3774 r1 = fntorch(t0_full, t1, t2)
3775 self.assertEqual(r0, r1)
def test_broadcast_fused_matmul(self):
    """Run the fused-matmul broadcasting checks on unconverted tensors.

    ``self`` is forwarded explicitly because the helper takes the test case
    as its first positional argument.
    """
    def keep_as_is(tensor):
        return tensor

    self._test_broadcast_fused_matmul(self, keep_as_is)
3781 def _test_broadcast_batched_matmul(self, cast):
3782 n_dim = random.randint(1, 8)
3783 m_dim = random.randint(1, 8)
3784 p_dim = random.randint(1, 8)
3785 full_batch_dims = [random.randint(1, 3)
for i
in range(random.randint(1, 3))]
3786 (batch_dims_small, _, _) = self._select_broadcastable_dims(full_batch_dims)
3788 def verify_batched_matmul(full_lhs, one_dimensional):
3789 if not one_dimensional:
3790 lhs_dims = [n_dim, m_dim]
3791 rhs_dims = [m_dim, p_dim]
3792 result_dims = [n_dim, p_dim]
3794 lhs_dims = [n_dim, m_dim]
if full_lhs
else [m_dim]
3795 rhs_dims = [m_dim, p_dim]
if not full_lhs
else [m_dim]
3796 result_dims = [n_dim]
if full_lhs
else [p_dim]
3798 lhs_mat_dims = lhs_dims
if len(lhs_dims) != 1
else [1, m_dim]
3799 rhs_mat_dims = rhs_dims
if len(rhs_dims) != 1
else [m_dim, 1]
3800 full_mat_dims = lhs_mat_dims
if full_lhs
else rhs_mat_dims
3801 dim0_dims = rhs_dims
if full_lhs
else lhs_dims
3802 small_dims = batch_dims_small + (rhs_mat_dims
if full_lhs
else lhs_mat_dims)
3804 small = cast(torch.randn(*(small_dims)).float())
3805 dim0 = cast(torch.randn(*(dim0_dims)).float())
3806 full = cast(torch.randn(*(full_batch_dims + full_mat_dims)).float())
3807 if not one_dimensional:
3808 (lhsTensors, rhsTensors) = ((full,), (small, dim0))
if full_lhs
else ((small, dim0), (full,))
3810 (lhsTensors, rhsTensors) = ((full,), (dim0,))
if full_lhs
else ((dim0,), (full,))
3812 def maybe_squeeze_result(l, r, result):
3813 if len(lhs_dims) == 1
and l.dim() != 1:
3814 return result.squeeze(-2)
3815 elif len(rhs_dims) == 1
and r.dim() != 1:
3816 return result.squeeze(-1)
3820 for lhs
in lhsTensors:
3821 lhs_expanded = lhs.expand(*(torch.Size(full_batch_dims) + torch.Size(lhs_mat_dims)))
3822 lhs_expanded_matmul_fn = getattr(lhs_expanded,
"matmul")
3823 for rhs
in rhsTensors:
3824 rhs_expanded = ((rhs
if len(rhs_dims) != 1
else rhs.unsqueeze(-1)).
3825 expand(*(torch.Size(full_batch_dims) + torch.Size(rhs_mat_dims))))
3826 truth = maybe_squeeze_result(lhs_expanded, rhs_expanded, lhs_expanded_matmul_fn(rhs_expanded))
3827 for l
in (lhs, lhs_expanded):
3828 for r
in (rhs, rhs_expanded):
3829 l_matmul_fn = getattr(l,
"matmul")
3830 result = maybe_squeeze_result(l, r, l_matmul_fn(r))
3831 self.assertEqual(truth, result)
3833 torch_result = maybe_squeeze_result(l, r, torch.matmul(l, r))
3834 self.assertEqual(truth, torch_result)
3836 out = torch.zeros_like(torch_result)
3837 torch.matmul(l, r, out=out)
3838 self.assertEqual(truth, maybe_squeeze_result(l, r, out))
3841 bmm_result = (torch.bmm(lhs_expanded.contiguous().view(-1, *lhs_mat_dims),
3842 rhs_expanded.contiguous().view(-1, *rhs_mat_dims)))
3843 self.assertEqual(truth.view(-1, *result_dims), bmm_result.view(-1, *result_dims))
3845 for indices
in product((
True,
False), repeat=2):
3846 verify_batched_matmul(*indices)
def test_broadcast_batched_matmul(self):
    """Run the batched-matmul broadcasting checks on unconverted tensors.

    ``self`` is forwarded explicitly because the helper takes the test case
    as its first positional argument.
    """
    def keep_as_is(tensor):
        return tensor

    self._test_broadcast_batched_matmul(self, keep_as_is)
def test_copy_broadcast(self):
    """``copy_`` broadcasts a compatible source and rejects an incompatible one.

    A length-6 source is broadcast across every row of a 5x6 destination.
    A 30-element source has the same number of elements as the destination
    but a shape that cannot broadcast to (5, 6), so ``copy_`` must raise.
    """
    src = torch.arange(6.)
    dst = torch.zeros(5, 6).copy_(src)
    # Previously only "no exception" was checked; verify every row now
    # holds the source values.
    self.assertEqual(dst, src.expand(5, 6))
    self.assertRaises(RuntimeError, lambda: torch.zeros(5, 6).copy_(torch.zeros(30)))
def test_randperm(self):
    """``torch.randperm`` with ``out=`` matches the functional form under the
    same RNG state, and ``n == 0`` yields an empty result.
    """
    rng_state = torch.get_rng_state()
    expected = torch.randperm(100)
    out = torch.LongTensor()
    # Rewind the generator so the out= call draws the same permutation.
    torch.set_rng_state(rng_state)
    torch.randperm(100, out=out)
    self.assertEqual(expected, out, 0)

    # randperm(0) is empty, and it also resizes a 5-element out tensor to empty.
    empty = torch.randperm(0)
    resized = torch.LongTensor(5)
    torch.randperm(0, out=resized)
    self.assertEqual(empty.numel(), 0)
    self.assertEqual(resized.numel(), 0)
3870 def test_random(self):
3872 t = torch.FloatTensor(200)
3878 self.assertEqual(t.min(), lb)
3879 self.assertEqual(t.max(), ub - 1)
3883 self.assertEqual(t.min(), 0)
3884 self.assertEqual(t.max(), ub - 1)
3887 def _test_random_neg_values(self, use_cuda=False):
3888 signed_types = [
'torch.DoubleTensor',
'torch.FloatTensor',
'torch.LongTensor',
3889 'torch.IntTensor',
'torch.ShortTensor']
3890 for tname
in signed_types:
3891 res = torch.rand(SIZE, SIZE).type(tname)
3894 res.random_(-10, -1)
3895 self.assertLessEqual(res.max().item(), 9)
3896 self.assertGreaterEqual(res.min().item(), -10)
def test_random_neg_values(self):
    """Run the negative-range ``random_`` check with ``use_cuda`` left at its default."""
    run_check = self._test_random_neg_values
    run_check(self)
3901 def assertIsOrdered(self, order, x, mxx, ixx, task):
3903 if order ==
'descending':
3904 def check_order(a, b):
3908 return a != a
or a >= b
3909 elif order ==
'ascending':
3910 def check_order(a, b):
3912 return b != b
or a <= b
3914 error(
'unknown order "{}", must be "ascending" or "descending"'.format(order))
3917 for j, k
in product(range(SIZE), range(1, SIZE)):
3918 self.assertTrue(check_order(mxx[j][k - 1], mxx[j][k]),
3919 'torch.sort ({}) values unordered for {}'.format(order, task))
3922 indicesCorrect =
True 3923 size = x.size(x.dim() - 1)
3924 for k
in range(size):
3926 for j
in range(size):
3927 self.assertEqual(x[k][ixx[k][j]], mxx[k][j],
3928 'torch.sort ({}) indices wrong for {}'.format(order, task))
3930 self.assertEqual(len(seen), size)
3932 def test_sort(self):
3934 x = torch.rand(SIZE, SIZE)
3935 res1val, res1ind = torch.sort(x)
3938 res2val = torch.Tensor()
3939 res2ind = torch.LongTensor()
3940 torch.sort(x, out=(res2val, res2ind))
3941 self.assertEqual(res1val, res2val, 0)
3942 self.assertEqual(res1ind, res2ind, 0)
3943 self.assertEqual(torch.argsort(x), res1ind)
3944 self.assertEqual(x.argsort(), res1ind)
3947 self.assertIsOrdered(
'ascending', x, res2val, res2ind,
'random')
3951 torch.sort(torch.Tensor((50, 40, 30, 20, 10)))[0],
3952 torch.Tensor((10, 20, 30, 40, 50)),
3957 x = torch.floor(torch.rand(SIZE, SIZE) * 10)
3958 torch.sort(x, out=(res2val, res2ind))
3959 self.assertIsOrdered(
'ascending', x, res2val, res2ind,
'random with duplicate keys')
3962 x = torch.rand(SIZE, SIZE)
3963 res1val, res1ind = torch.sort(x, x.dim() - 1,
True)
3966 res2val = torch.Tensor()
3967 res2ind = torch.LongTensor()
3968 torch.sort(x, x.dim() - 1,
True, out=(res2val, res2ind))
3969 self.assertEqual(res1val, res2val, 0)
3970 self.assertEqual(res1ind, res2ind, 0)
3971 self.assertEqual(torch.argsort(x, x.dim() - 1,
True), res1ind)
3972 self.assertEqual(x.argsort(x.dim() - 1,
True), res1ind)
3975 self.assertIsOrdered(
'descending', x, res2val, res2ind,
'random')
3979 torch.sort(torch.Tensor((10, 20, 30, 40, 50)), 0,
True)[0],
3980 torch.Tensor((50, 40, 30, 20, 10)),
3985 self.assertIsOrdered(
'descending', x, res2val, res2ind,
'random with duplicate keys')
3988 x = torch.rand(SIZE, SIZE)
3989 x[1][2] = float(
'NaN')
3990 x[3][0] = float(
'NaN')
3991 torch.sort(x, out=(res2val, res2ind))
3992 self.assertIsOrdered(
'ascending', x, res2val, res2ind,
3994 torch.sort(x, out=(res2val, res2ind), descending=
True)
3995 self.assertIsOrdered(
'descending', x, res2val, res2ind,
3998 @unittest.skipIf(
not TEST_NUMPY,
'Numpy not found')
3999 def test_tensordot(self):
4002 a = torch.arange(60., device=d).reshape(3, 4, 5)
4003 b = torch.arange(24., device=d).reshape(4, 3, 2)
4004 c = torch.tensordot(a, b, dims=([1, 0], [0, 1])).cpu()
4005 cn = torch.from_numpy(np.tensordot(a.cpu().numpy(), b.cpu().numpy(),
4006 axes=([1, 0], [0, 1])))
4007 self.assertEqual(c, cn)
4008 a = torch.randn(2, 3, 4, 5, device=d)
4009 b = torch.randn(4, 5, 6, 7, device=d)
4010 c = torch.tensordot(a, b, dims=2).cpu()
4011 cn = torch.from_numpy(np.tensordot(a.cpu().numpy(), b.cpu().numpy(),
4013 self.assertEqual(c, cn)
4014 c = torch.tensordot(a, b).cpu()
4015 cn = torch.from_numpy(np.tensordot(a.cpu().numpy(), b.cpu().numpy()))
4016 self.assertEqual(c, cn)
4018 def test_topk(self):
4019 def topKViaSort(t, k, dim, dir):
4020 sorted, indices = t.sort(dim, dir)
4021 return sorted.narrow(dim, 0, k), indices.narrow(dim, 0, k)
4023 def compareTensors(t, res1, ind1, res2, ind2, dim):
4025 self.assertEqual(res1, res2, 0)
4029 if not ind1.eq(ind2).all():
4033 vals = t.gather(dim, ind2)
4034 self.assertEqual(res1, vals, 0)
4036 def compare(t, k, dim, dir):
4037 topKVal, topKInd = t.topk(k, dim, dir,
True)
4038 sortKVal, sortKInd = topKViaSort(t, k, dim, dir)
4039 compareTensors(t, sortKVal, sortKInd, topKVal, topKInd, dim)
4041 t = torch.rand(random.randint(1, SIZE),
4042 random.randint(1, SIZE),
4043 random.randint(1, SIZE))
4045 for _kTries
in range(3):
4046 for _dimTries
in range(3):
4047 for transpose
in (
True,
False):
4048 for dir
in (
True,
False):
4051 dim1 = random.randrange(t.ndimension())
4054 dim2 = random.randrange(t.ndimension())
4056 testTensor = t.transpose(dim1, dim2)
4058 dim = random.randrange(testTensor.ndimension())
4059 k = random.randint(1, testTensor.size(dim))
4060 compare(testTensor, k, dim, dir)
def test_topk_arguments(self):
    """Passing a bool as ``topk``'s second positional argument raises ``TypeError``.

    Presumably this guards against ``True`` being accepted as an integer
    ``dim`` -- NOTE(review): confirm intent.
    """
    tensor = torch.randn(10, 2, 10)

    def topk_with_bool():
        return tensor.topk(4, True)

    self.assertRaises(TypeError, topk_with_bool)
def test_topk_noncontiguous_gpu(self):
    """``topk`` on a strided CUDA view matches ``topk`` on its contiguous copy.

    NOTE(review): allocates on "cuda"; any skip decorator guarding this test
    lies outside this block.
    """
    strided = torch.randn(20, device="cuda")[::2]
    values, indices = strided.topk(5)
    ref_values, ref_indices = strided.contiguous().topk(5)
    self.assertEqual(values, ref_values)
    self.assertEqual(indices, ref_indices)
4076 def _test_kthvalue(self, device='cpu'):
4078 x = torch.rand(SIZE, SIZE, SIZE, device=device)
4081 k = random.randint(1, SIZE)
4082 res1val, res1ind = torch.kthvalue(x, k, keepdim=
False)
4083 res2val, res2ind = torch.sort(x)
4085 self.assertEqual(res1val[:, :], res2val[:, :, k - 1], 0)
4086 self.assertEqual(res1ind[:, :], res2ind[:, :, k - 1], 0)
4088 k = random.randint(1, SIZE)
4090 res1ind =
torch.tensor([], dtype=torch.long, device=device)
4091 torch.kthvalue(x, k, keepdim=
False, out=(res1val, res1ind))
4092 res2val, res2ind = torch.sort(x)
4093 self.assertEqual(res1val[:, :], res2val[:, :, k - 1], 0)
4094 self.assertEqual(res1ind[:, :], res2ind[:, :, k - 1], 0)
4097 k = random.randint(1, SIZE)
4098 res1val, res1ind = torch.kthvalue(x, k, 0, keepdim=
False)
4099 res2val, res2ind = torch.sort(x, 0)
4100 self.assertEqual(res1val, res2val[k - 1], 0)
4101 self.assertEqual(res1ind, res2ind[k - 1], 0)
4104 y = x.narrow(1, 0, 1)
4106 k = random.randint(1, SIZE)
4107 res1val, res1ind = torch.kthvalue(y, k)
4108 res2val, res2ind = torch.kthvalue(y0, k)
4109 self.assertEqual(res1val, res2val, 0)
4110 self.assertEqual(res1ind, res2ind, 0)
4113 self.assertEqual(x, x0, 0)
4117 self.assertEqual(torch.kthvalue(y, 3)[0], 3, 0)
4118 self.assertEqual(torch.kthvalue(y, 2)[0], 1, 0)
4122 x = torch.rand(SIZE, SIZE, SIZE, device=device)
4123 x[torch.arange(SIZE), :, torch.randint(50, (50,))] = nan
4124 ks = [random.randint(1, SIZE), 1, SIZE, SIZE - 1]
4125 res2val, res2ind = torch.sort(x)
4127 res1val, res1ind = torch.kthvalue(x, k, keepdim=
False)
4128 self.assertEqual(res1val[:, :], res2val[:, :, k - 1], 0)
4129 self.assertEqual(res1ind[:, :], res2ind[:, :, k - 1], 0)
def test_kthvalue(self):
    """Run the shared ``kthvalue`` checks on the helper's default device ('cpu')."""
    run_checks = self._test_kthvalue
    run_checks(self)
4134 def test_median(self):
4135 for size
in (155, 156):
4136 x = torch.rand(size, size)
4139 nelem = x.nelement()
4140 res1val = torch.median(x)
4141 res2val, _ = torch.sort(x.view(nelem))
4142 ind = int(math.floor((nelem + 1) / 2) - 1)
4144 self.assertEqual(res2val[ind], res1val, 0)
4146 res1val, res1ind = torch.median(x, dim=1, keepdim=
False)
4147 res2val, res2ind = torch.sort(x)
4148 ind = int(math.floor((size + 1) / 2) - 1)
4150 self.assertEqual(res2val.select(1, ind), res1val, 0)
4151 self.assertEqual(res2val.select(1, ind), res1val, 0)
4154 res2val = torch.Tensor()
4155 res2ind = torch.LongTensor()
4156 torch.median(x, dim=-1, keepdim=
False, out=(res2val, res2ind))
4157 self.assertEqual(res2val, res1val, 0)
4158 self.assertEqual(res2ind, res1ind, 0)
4161 res1val, res1ind = torch.median(x, 0, keepdim=
False)
4162 res2val, res2ind = torch.sort(x, 0)
4163 self.assertEqual(res1val, res2val[ind], 0)
4164 self.assertEqual(res1ind, res2ind[ind], 0)
4167 self.assertEqual(x, x0, 0)
4169 def test_mode(self):
4170 x = torch.arange(1., SIZE * SIZE + 1).clone().resize_(SIZE, SIZE)
4176 res1val = torch.Tensor(SIZE).fill_(1)
4178 res1ind = torch.LongTensor(SIZE).fill_(1)
4179 res1ind[0] = SIZE - 1
4180 res1ind[1] = SIZE - 1
4182 res2val, res2ind = torch.mode(x, keepdim=
False)
4183 self.assertEqual(res1val, res2val, 0)
4184 self.assertEqual(res1ind, res2ind, 0)
4187 res2val = torch.Tensor()
4188 res2ind = torch.LongTensor()
4189 torch.mode(x, keepdim=
False, out=(res2val, res2ind))
4190 self.assertEqual(res1val, res2val, 0)
4191 self.assertEqual(res1ind, res2ind, 0)
4194 res2val, res2ind = torch.mode(x, 0,
False)
4195 self.assertEqual(res1val, res2val, 0)
4196 self.assertEqual(res1ind, res2ind, 0)
4199 self.assertEqual(x, x0, 0)
4201 def test_trilu_indices(self):
4202 for test_args
in tri_tests_args:
4203 _compare_trilu_indices(self, *test_args)
4204 run_additional_tri_tests(self,
'cpu')
4208 3, 3, dtype=torch.long, device=
'cpu', layout=torch.strided)
4210 x.tril(0).nonzero().transpose(0, 1), torch.tril_indices(3, 3))
4212 x.triu(0).nonzero().transpose(0, 1), torch.triu_indices(3, 3))
4215 def _test_triu_tril(self, cast):
4216 def gen_mask(shape, diagonal, cast, upper):
4217 mask = torch.zeros(*shape[-2:]).byte()
4218 for i
in range(shape[-2]):
4219 for j
in range(shape[-1]):
4220 cond = j - i < diagonal
if upper
else j - i > diagonal
4223 return cast(mask.expand(*shape))
4225 torch_functions = {
True: torch.triu,
False: torch.tril}
4227 numpy_functions = {
True: np.triu,
False: np.tril}
4229 def run_test(shape, cast, diagonal):
4230 x_cpu = torch.randn(*shape)
4233 for upper
in [
True,
False]:
4235 torch_tri_func = torch_functions[upper]
4236 res1 = torch_tri_func(x, diagonal=diagonal)
4237 res2 = cast(torch.Tensor())
4238 torch_tri_func(x, diagonal=diagonal, out=res2)
4239 exp_mask = gen_mask(shape, diagonal, cast, upper)
4240 expected = torch.where(exp_mask,
torch.tensor(0).type_as(x), x)
4241 self.assertEqual(res1, res2, 0)
4242 self.assertEqual(expected, res1, 0)
4245 if not (0
in shape
or 1
in shape):
4246 for s
in range(-len(shape), -1):
4248 x_nc = x.clone().transpose(s, s + 1)
4249 exp_mask = gen_mask(x_nc.size(), diagonal, cast, upper)
4250 assert not x_nc.is_contiguous(),
"x is intentionally non-contiguous" 4251 exp_nc = torch.where(exp_mask,
torch.tensor(0).type_as(x), x_nc)
4252 self.assertEqual(torch_tri_func(x_nc, diagonal), exp_nc, 0)
4253 x_nc_is_contiguous = x_nc.is_contiguous()
4255 self.assertEqual(x_nc.triu_(diagonal), exp_nc, 0)
4257 self.assertEqual(x_nc.tril_(diagonal), exp_nc, 0)
4259 self.assertTrue(x_nc.is_contiguous() == x_nc_is_contiguous,
4260 "contiguity of x_nc should not be changed")
4263 expanded_size = (x.size(0),) + x.size()
4264 x_expanded = x.clone().expand(*expanded_size)
4265 assert 0
in x_expanded.stride(),
"x intentionally has 0 in its stride" 4266 output = torch_tri_func(x_expanded, diagonal)
4267 self.assertEqual(output, expected.expand(expanded_size), 0)
4268 self.assertTrue(0
in x_expanded.stride(),
4269 "geometry of x_expanded should be the same")
4271 self.assertEqual(output, x_expanded.triu_(diagonal), 0)
4273 self.assertEqual(output, x_expanded.tril_(diagonal), 0)
4279 numpy_tri_func = numpy_functions[upper]
4280 self.assertEqual(numpy_tri_func(x_cpu.numpy(), diagonal), res1.cpu().numpy())
4282 diagonals = [-2, -1, 0, 1, 2]
4283 shapes = [(3, 3), (5, 3, 3), (7, 5, 3, 3),
4284 (7, 3), (5, 7, 3), (7, 5, 7, 3),
4285 (3, 7), (5, 3, 7), (7, 5, 3, 7),
4286 (3, 0), (0, 3, 3), (3, 3, 0, 0),
4287 (3, 1), (5, 3, 1), (7, 5, 3, 1),
4288 (1, 3), (5, 1, 3), (7, 5, 1, 3)]
4289 for s, d
in product(shapes, diagonals):
4290 run_test(s, cast, d)
def test_triu_tril(self):
    """Run the shared ``triu``/``tril`` checks on unconverted tensors.

    ``self`` is forwarded explicitly because the helper takes the test case
    as its first positional argument.
    """
    def keep_as_is(tensor):
        return tensor

    self._test_triu_tril(self, keep_as_is)
4297 for dtype
in (torch.half, torch.double, torch.int):
4298 for dim
in range(-3, 3):
4299 pos_dim = dim
if dim >= 0
else 3 + dim
4300 x = torch.randint(low=-100, high=100, size=(13, SIZE, SIZE)).to(dtype).transpose(0, pos_dim)
4301 y = torch.randint(low=-100, high=100, size=(17, SIZE, SIZE)).to(dtype).transpose(0, pos_dim)
4302 z = torch.randint(low=-100, high=100, size=(19, SIZE, SIZE)).to(dtype).transpose(0, pos_dim)
4304 res1 = torch.cat((x, y, z), dim)
4305 self.assertEqual(res1.narrow(pos_dim, 0, 13), x, 0)
4306 self.assertEqual(res1.narrow(pos_dim, 13, 17), y, 0)
4307 self.assertEqual(res1.narrow(pos_dim, 30, 19), z, 0)
4309 x = torch.randint(low=-100, high=100, size=(20, SIZE, SIZE)).to(dtype)
4310 self.assertEqual(torch.cat(torch.split(x, 7)), x)
4311 self.assertEqual(torch.cat(torch.chunk(x, 7)), x)
4313 y = torch.randint(low=-100, high=100, size=(1, SIZE, SIZE)).to(dtype)
4314 z = torch.cat([x, y])
4315 self.assertEqual(z.size(), (21, SIZE, SIZE))
4317 self.assertRaises(RuntimeError,
lambda: torch.cat([]))
4318 self.assertRaisesRegex(TypeError,
'got None',
lambda: torch.cat([x,
None]))
def test_cat_bad_input_sizes(self):
    """``torch.cat`` raises when inputs disagree outside the concatenation dim."""
    bad_inputs = (
        # rank mismatch: a 2-D tensor next to 3-D ones, concatenated on dim 0
        (((2, 1), (2, 1, 1), (2, 1, 1)), 0),
        # concatenating on dim 1, but dim 2 sizes differ (2 vs 1)
        (((2, 1, 2), (2, 1, 1), (2, 2, 1)), 1),
    )
    for shapes, dim in bad_inputs:
        tensors = [torch.randn(*shape) for shape in shapes]
        self.assertRaises(RuntimeError, lambda: torch.cat(tensors, dim=dim))
4331 def test_cat_scalars(self):
4334 with self.assertRaisesRegex(RuntimeError,
'zero-dimensional.*cannot be concatenated'):
4338 def _test_cat_empty_legacy(self, use_cuda=False):
4341 dtype = torch.float32
4342 device =
'cuda' if use_cuda
else 'cpu' 4344 x = torch.randn((4, 3, 32, 32), dtype=dtype, device=device)
4345 empty = torch.randn((0,), dtype=dtype, device=device)
4347 res1 = torch.cat([x, empty], dim=1)
4348 res2 = torch.cat([empty, x], dim=1)
4349 self.assertEqual(res1, res2)
4351 conv = torch.nn.Conv2d(3, 3, kernel_size=1).float()
4354 res1 = torch.cat([
conv(x), empty], dim=1)
4355 res2 = torch.cat([empty,
conv(x)], dim=1)
4356 self.assertEqual(res1, res2)
4358 res1 = torch.cat([empty, empty], dim=1)
4359 self.assertEqual(res1, empty)
4361 with self.assertRaisesRegex(RuntimeError,
4362 'expected a non-empty list of Tensors'):
4363 torch.cat([], dim=1)
def test_cat_empty_legacy(self):
    """Run the legacy empty-``cat`` checks with ``use_cuda`` left at its default (False)."""
    run_checks = self._test_cat_empty_legacy
    run_checks(self)
4369 def _test_cat_empty(self, use_cuda=False):
4370 dtype = torch.float32
4371 device =
'cuda' if use_cuda
else 'cpu' 4373 x = torch.randn((4, 3, 32, 32), dtype=dtype, device=device)
4374 empty = torch.randn((4, 0, 32, 32), dtype=dtype, device=device)
4376 res1 = torch.cat([x, empty], dim=1)
4377 res2 = torch.cat([empty, x], dim=1)
4378 self.assertEqual(res1, res2)
4380 conv = torch.nn.Conv2d(3, 3, kernel_size=1).float()
4383 res1 = torch.cat([
conv(x), empty], dim=1)
4384 res2 = torch.cat([empty,
conv(x)], dim=1)
4385 self.assertEqual(res1, res2)
4387 res1 = torch.cat([empty, empty], dim=1)
4388 self.assertEqual(res1, empty)
4391 empty = torch.randn((4, 0, 31, 32), dtype=dtype, device=device)
4392 self.assertRaises(RuntimeError,
lambda: torch.cat([x, empty], dim=1))
4393 self.assertRaises(RuntimeError,
lambda: torch.cat([empty, x], dim=1))
4396 empty = torch.randn((4, 0), dtype=dtype, device=device)
4397 self.assertRaises(RuntimeError,
lambda: torch.cat([x, empty], dim=1))
4398 self.assertRaises(RuntimeError,
lambda: torch.cat([empty, x], dim=1))
def test_cat_empty(self):
    """Run the empty-``cat`` checks with ``use_cuda`` left at its default (False)."""
    run_checks = self._test_cat_empty
    run_checks(self)
def test_narrow(self):
    """``narrow(dim, start, length)`` on a 3x3 tensor, including negative ``dim``/``start``."""
    x = torch.Tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8]])
    expectations = (
        # ((dim, start, length), expected rows)
        ((0, 0, 1), [[0, 1, 2]]),
        ((0, 0, 2), [[0, 1, 2], [3, 4, 5]]),
        ((0, 1, 1), [[3, 4, 5]]),
        ((0, -1, 1), [[6, 7, 8]]),
        ((0, -2, 2), [[3, 4, 5], [6, 7, 8]]),
        ((0, -3, 3), [[0, 1, 2], [3, 4, 5], [6, 7, 8]]),
        ((-1, -1, 1), [[2], [5], [8]]),
        ((-2, -1, 1), [[6, 7, 8]]),
    )
    for narrow_args, expected in expectations:
        self.assertEqual(x.narrow(*narrow_args), torch.Tensor(expected))
4414 def test_narrow_empty(self):
4416 for device
in devices:
4417 x = torch.randn(2, 3, 4, device=device)
4418 for d
in range(x.dim()):
4419 y = x.narrow(d, x.size(d), 0)
4422 self.assertEqual(sz, y.size())
def test_stack(self):
    """``torch.stack`` along each dim 0..3 (and its negative alias) for several dtypes."""
    for dtype in (torch.half, torch.double, torch.int):
        parts = tuple(
            torch.randint(low=-100, high=100, size=(2, 3, 4)).to(dtype)
            for _ in range(3)
        )
        base_size = parts[0].size()
        for dim in range(4):
            stacked = torch.stack(parts, dim)
            stacked_neg = torch.stack(parts, dim - 4)
            expected_size = base_size[:dim] + (3,) + base_size[dim:]
            self.assertEqual(stacked, stacked_neg)
            self.assertEqual(stacked.size(), expected_size)
            # Slice i along the new dim must be exactly input i.
            for index, part in enumerate(parts):
                self.assertEqual(stacked.select(dim, index), part, 0)
def test_stack_out(self):
    """``torch.stack(..., out=...)`` fills a pre-sized buffer without reallocating it.

    Checks positive and negative ``dim`` agree, the size is right, each
    buffer keeps its original ``data_ptr``, and each slice equals its input.
    """
    for dtype in (torch.half, torch.double, torch.int):
        inputs = tuple(
            torch.randint(low=-100, high=100, size=(2, 3, 4)).to(dtype)
            for _ in range(3)
        )
        template = inputs[0]
        for dim in range(4):
            expected_size = template.size()[:dim] + (3,) + template.size()[dim:]
            out_pos = template.new(expected_size)
            out_neg = template.new(expected_size)
            ptr_pos = out_pos.data_ptr()
            ptr_neg = out_neg.data_ptr()
            torch.stack(inputs, dim, out=out_pos)
            torch.stack(inputs, dim - 4, out=out_neg)
            self.assertEqual(out_pos, out_neg)
            self.assertEqual(out_pos.size(), expected_size)
            self.assertEqual(ptr_pos, out_pos.data_ptr())
            self.assertEqual(ptr_neg, out_neg.data_ptr())
            for index, source in enumerate(inputs):
                self.assertEqual(out_pos.select(dim, index), source, 0)
def test_unbind(self):
    """``torch.unbind`` and ``Tensor.unbind`` split ``x`` into ``x.size(dim)`` slices.

    For every dim, checks the number of slices and that slice ``i`` equals
    ``x.select(dim, i)`` for both the function and the method form.
    """
    x = torch.rand(2, 3, 4, 5)
    for dim in range(x.dim()):
        res = torch.unbind(x, dim)
        res2 = x.unbind(dim)
        self.assertEqual(x.size(dim), len(res))
        self.assertEqual(x.size(dim), len(res2))
        # Compare every slice. The previous `range(dim)` checked no slices for
        # dim=0 and only the first `dim` slices otherwise.
        for i in range(x.size(dim)):
            self.assertEqual(x.select(dim, i), res[i])
            self.assertEqual(x.select(dim, i), res2[i])
4472 def test_linspace(self):
4474 for device
in devices:
4475 _from = random.random()
4476 to = _from + random.random()
4477 res1 = torch.linspace(_from, to, 137, device=device)
4479 torch.linspace(_from, to, 137, out=res2)
4480 self.assertEqual(res1, res2, 0)
4481 self.assertRaises(RuntimeError,
lambda: torch.linspace(0, 1, -1, device=device))
4482 self.assertEqual(torch.linspace(0, 1, 1, device=device), torch.zeros(1, device=device), 0)
4485 self.assertEqual(torch.linspace(2, 0, 3, device=device),
torch.tensor((2, 1, 0), device=device), 0)
4488 x = torch.zeros(2, 3, device=device)
4489 y = torch.linspace(0, 3, 4, out=x.narrow(1, 1, 2))
4490 self.assertEqual(x,
torch.tensor(((0, 0, 1), (0, 2, 3)), device=device), 0)
def test_logspace(self):
    """``torch.logspace``: out= parity, negative ``steps`` rejected, single step,
    descending range, and writing into a non-contiguous ``out`` view.
    """
    start = random.random()
    end = start + random.random()
    expected = torch.logspace(start, end, 137)
    out = torch.Tensor()
    torch.logspace(start, end, 137, out=out)
    self.assertEqual(expected, out, 0)
    self.assertRaises(RuntimeError, lambda: torch.logspace(0, 1, -1))
    # A single step gives just 10**0.
    self.assertEqual(torch.logspace(0, 1, 1), torch.ones(1), 0)

    # end < start yields a descending sequence.
    self.assertEqual(torch.logspace(1, 0, 2), torch.Tensor((10, 1)), 0)

    # Columns 1..2 of a 2x3 buffer form a non-contiguous view; only they change.
    buf = torch.zeros(2, 3)
    torch.logspace(0, 3, 4, out=buf.narrow(1, 1, 2))
    self.assertEqual(buf, torch.Tensor(((0, 1, 10), (0, 100, 1000))), 0)
def test_rand(self):
    """torch.rand(..., out=t) must reproduce torch.rand(...) under the same seed."""
    seed = 123456
    torch.manual_seed(seed)
    fresh = torch.rand(SIZE, SIZE)
    into = torch.Tensor()
    torch.manual_seed(seed)
    torch.rand(SIZE, SIZE, out=into)
    self.assertEqual(fresh, into)
def test_randint(self):
    """torch.randint: all call forms agree under one seed, and samples lie in [0, 6).

    Covers ``randint(low, high, size)`` and ``randint(high, size)``, each both
    functionally and with ``out=``.
    """
    torch.manual_seed(123456)
    res1 = torch.randint(0, 6, (SIZE, SIZE))
    res2 = torch.Tensor()
    torch.manual_seed(123456)
    torch.randint(0, 6, (SIZE, SIZE), out=res2)
    torch.manual_seed(123456)
    res3 = torch.randint(6, (SIZE, SIZE))
    res4 = torch.Tensor()
    torch.manual_seed(123456)
    torch.randint(6, (SIZE, SIZE), out=res4)
    # Pairwise equality, in the order (1,2), (1,3), (1,4), (2,3), (2,4), (3,4).
    for lhs, rhs in combinations((res1, res2, res3, res4), 2):
        self.assertEqual(lhs, rhs)

    # Range check through the TestCase API rather than bare `assert`, which
    # is stripped under `python -O` and reports no values when it fails.
    flat = res1.view(-1)
    numel = flat.size(0)
    self.assertEqual(numel, (flat < 6).long().sum().item())
    self.assertEqual(numel, (flat >= 0).long().sum().item())
def test_randn(self):
    """torch.randn(..., out=t) must reproduce torch.randn(...) under the same seed."""
    seed = 123456
    torch.manual_seed(seed)
    fresh = torch.randn(SIZE, SIZE)
    into = torch.Tensor()
    torch.manual_seed(seed)
    torch.randn(SIZE, SIZE, out=into)
    self.assertEqual(fresh, into)
4550 def test_slice(self):
4551 empty = torch.empty(0, 4)
4552 x = torch.arange(0., 16).view(4, 4)
4553 self.assertEqual(x[:], x)
4554 self.assertEqual(x[:4], x)
4556 self.assertEqual(x[:5], x)
4558 self.assertEqual(x[2:1], empty)
4559 self.assertEqual(x[2:2], empty)
4561 self.assertEqual(x[10:12], empty)
4563 self.assertEqual(x[:1].data.tolist(), [[0, 1, 2, 3]])
4564 self.assertEqual(x[:-3].data.tolist(), [[0, 1, 2, 3]])
4565 self.assertEqual(x[:, -2:3].data.tolist(), [[2], [6], [10], [14]])
4566 self.assertEqual(x[0:-1:2].data.tolist(), [[0, 1, 2, 3], [8, 9, 10, 11]])
4568 def test_is_signed(self):
4569 self.assertEqual(torch.IntTensor(5).is_signed(),
True)
4570 self.assertEqual(torch.ByteTensor(5).is_signed(),
False)
4571 self.assertEqual(torch.CharTensor(5).is_signed(),
True)
4572 self.assertEqual(torch.FloatTensor(5).is_signed(),
True)
4573 self.assertEqual(torch.HalfTensor(10).is_signed(),
True)
4576 def test_is_signed_cuda(self):
4577 self.assertEqual(torch.cuda.IntTensor(5).is_signed(),
True)
4578 self.assertEqual(torch.cuda.ByteTensor(5).is_signed(),
False)
4579 self.assertEqual(torch.cuda.CharTensor(5).is_signed(),
True)
4580 self.assertEqual(torch.cuda.FloatTensor(5).is_signed(),
True)
4581 self.assertEqual(torch.cuda.HalfTensor(10).is_signed(),
True)
def _test_solve(self, cast):
    """Shared torch.solve checks on a fixed 5x5 system with three right-hand sides.

    ``cast`` is applied to every freshly built tensor. The wrappers call this
    as ``self._test_solve(self, cast)``, so ``self`` is the running TestCase.
    """
    coeffs = cast(torch.Tensor(((6.80, -2.11, 5.66, 5.97, 8.23),
                                (-6.05, -3.30, 5.36, -4.44, 1.08),
                                (-0.45, 2.58, -2.70, 0.27, 9.04),
                                (8.32, 2.71, 4.35, -7.17, 2.14),
                                (-9.67, -5.14, -7.26, 6.08, -6.87)))).t()
    rhs = cast(torch.Tensor(((4.02, 6.19, -8.22, -7.57, -3.03),
                             (-1.56, 4.00, -8.67, 1.75, 2.86),
                             (9.81, -4.09, -4.57, -8.61, 8.99)))).t()

    solution = torch.solve(rhs, coeffs)[0]
    self.assertLessEqual(rhs.dist(torch.mm(coeffs, solution)), 1e-12)

    lu_out = cast(torch.Tensor())
    sol_out = cast(torch.Tensor())
    via_out = torch.solve(rhs, coeffs, out=(sol_out, lu_out))[0]
    # Solve into the inputs themselves: rhs receives the solution and
    # coeffs receives the second (LU) output.
    in_place = torch.solve(rhs, coeffs, out=(rhs, coeffs))[0]
    for candidate in (sol_out, rhs, via_out, in_place):
        self.assertEqual(solution, candidate)

    # Reuse the same out= buffers across repeated calls. rhs/coeffs still
    # hold the values written by the in-place solve above.
    solution = torch.solve(rhs, coeffs)[0]
    lu_out = cast(torch.Tensor())
    sol_out = cast(torch.Tensor())
    for _ in range(2):
        torch.solve(rhs, coeffs, out=(sol_out, lu_out))
        self.assertEqual(solution, sol_out)
4616 def test_solve(self):
4617 self._test_solve(self,
lambda t: t)
4620 def _test_solve_batched(self, cast):
4621 from common_utils
import random_fullrank_matrix_distinct_singular_value
4623 A = cast(random_fullrank_matrix_distinct_singular_value(5, 1))
4624 b = cast(torch.randn(1, 5, 10))
4625 x_exp, LU_exp = torch.solve(b.squeeze(0), A.squeeze(0))
4626 x, LU = torch.solve(b, A)
4627 self.assertEqual(x, x_exp.unsqueeze(0))
4628 self.assertEqual(LU, LU_exp.unsqueeze(0))
4631 A = cast(random_fullrank_matrix_distinct_singular_value(5, 4))
4632 b = cast(torch.randn(4, 5, 10))
4637 x_exp, LU_exp = torch.solve(b[i], A[i])
4638 x_exp_list.append(x_exp)
4639 LU_exp_list.append(LU_exp)
4640 x_exp = torch.stack(x_exp_list)
4641 LU_exp = torch.stack(LU_exp_list)
4643 x, LU = torch.solve(b, A)
4644 self.assertEqual(x, x_exp)
4645 self.assertEqual(LU, LU_exp)
4648 A = cast(random_fullrank_matrix_distinct_singular_value(5, 3))
4649 b = cast(torch.randn(3, 5, 10))
4650 x, LU = torch.solve(b, A)
4651 self.assertEqual(torch.matmul(A, x), b)
4657 from numpy.linalg
import solve
4658 A = cast(random_fullrank_matrix_distinct_singular_value(2, 2)).permute(1, 0, 2)
4659 b = cast(torch.randn(2, 2, 2)).permute(2, 1, 0)
4660 x, _ = torch.solve(b, A)
4661 x_exp = torch.Tensor(solve(A.cpu().numpy(), b.cpu().numpy()))
4662 self.assertEqual(x.data, cast(x_exp))
4665 def test_solve_batched(self):
4666 self._test_solve_batched(self,
lambda t: t)
4669 def _test_solve_batched_dims(self, cast):
4673 from numpy.linalg
import solve
4674 from common_utils
import random_fullrank_matrix_distinct_singular_value
4676 A = cast(random_fullrank_matrix_distinct_singular_value(4, 2, 1, 3))
4677 b = cast(torch.randn(2, 1, 3, 4, 6))
4678 x, _ = torch.solve(b, A)
4679 x_exp = torch.Tensor(solve(A.cpu().numpy(), b.cpu().numpy()))
4680 self.assertEqual(x.data, cast(x_exp))
4683 A = cast(random_fullrank_matrix_distinct_singular_value(4, 2, 1, 3)).transpose(-2, -1)
4684 b = cast(torch.randn(2, 1, 3, 6, 4)).transpose(-2, -1)
4685 assert not A.is_contiguous()
4686 assert not b.is_contiguous()
4687 x, _ = torch.solve(b, A)
4688 x_exp = torch.Tensor(solve(A.cpu().numpy(), b.cpu().numpy()))
4689 self.assertEqual(x.data, cast(x_exp))
4692 A = cast(random_fullrank_matrix_distinct_singular_value(4, 2, 1, 3))
4693 b = cast(torch.randn(4, 6))
4694 x, _ = torch.solve(b, A)
4695 x_exp = torch.Tensor(solve(A.cpu().numpy(), b.cpu().numpy()))
4696 self.assertEqual(x.data, cast(x_exp))
4699 A = cast(random_fullrank_matrix_distinct_singular_value(4))
4700 b = cast(torch.randn(2, 1, 3, 4, 2))
4701 x, _ = torch.solve(b, A)
4702 x_exp = torch.Tensor(solve(A.cpu().numpy(), b.cpu().numpy()))
4703 self.assertEqual(x.data, cast(x_exp))
4706 A = cast(random_fullrank_matrix_distinct_singular_value(4, 1, 3, 1))
4707 b = cast(torch.randn(2, 1, 3, 4, 5))
4708 x, _ = torch.solve(b, A)
4709 x_exp = torch.Tensor(solve(A.cpu().numpy(), b.cpu().numpy()))
4710 self.assertEqual(x.data, cast(x_exp))
4713 def test_solve_batched_dims(self):
4714 self._test_solve_batched_dims(self,
lambda t: t)
4716 def test_solve_methods_arg_device(self):
4720 for b_device, A_device
in product([
'cpu',
'cuda'], repeat=2):
4721 if b_device == A_device:
4724 b = torch.randn(3, 1, device=b_device)
4725 A = torch.randn(3, 3, device=A_device)
4726 err_str =
"Expected b and A to be on the same device" 4727 with self.assertRaisesRegex(RuntimeError, err_str):
4730 with self.assertRaisesRegex(RuntimeError, err_str):
4731 torch.cholesky_solve(b, A)
4738 def canonicalize(q, r):
4739 d = r.diag().sign().diag()
4740 return torch.mm(q, d), torch.mm(d, r)
4742 def canon_and_check(q, r, expected_q, expected_r):
4743 q_canon, r_canon = canonicalize(q, r)
4744 expected_q_canon, expected_r_canon = canonicalize(expected_q, expected_r)
4745 self.assertEqual(q_canon, expected_q_canon)
4746 self.assertEqual(r_canon, expected_r_canon)
4748 def check_qr(a, expected_q, expected_r):
4751 canon_and_check(q, r, expected_q, expected_r)
4754 q, r = torch.Tensor(), torch.Tensor()
4755 torch.qr(a, out=(q, r))
4756 canon_and_check(q, r, expected_q, expected_r)
4762 result, tau = torch.geqrf(a)
4763 self.assertEqual(result.size(0), m)
4764 self.assertEqual(result.size(1), n)
4765 self.assertEqual(tau.size(0), k)
4766 r = torch.triu(result.narrow(0, 0, k))
4767 q = torch.orgqr(result, tau)
4768 q, r = q.narrow(1, 0, k), r
4769 canon_and_check(q, r, expected_q, expected_r)
4772 a = torch.Tensor(((1, 2, 3), (4, 5, 6), (7, 8, 10)))
4774 expected_q = torch.Tensor((
4775 (-1.230914909793328e-01, 9.045340337332914e-01, 4.082482904638621e-01),
4776 (-4.923659639173310e-01, 3.015113445777629e-01, -8.164965809277264e-01),
4777 (-8.616404368553292e-01, -3.015113445777631e-01, 4.082482904638634e-01)))
4778 expected_r = torch.Tensor((
4779 (-8.124038404635959e+00, -9.601136296387955e+00, -1.193987e+01),
4780 (0.000000000000000e+00, 9.045340337332926e-01, 1.507557e+00),
4781 (0.000000000000000e+00, 0.000000000000000e+00, 4.082483e-01)))
4783 check_qr(a, expected_q, expected_r)
4792 expected_q = torch.Tensor((
4793 (-0.0776150525706334, -0.833052161400748, 0.3651483716701106),
4794 (-0.3104602102825332, -0.4512365874254053, -0.1825741858350556),
4795 (-0.5433053679944331, -0.0694210134500621, -0.7302967433402217),
4796 (-0.7761505257063329, 0.3123945605252804, 0.5477225575051663)
4798 expected_r = torch.Tensor((
4799 (-12.8840987267251261, -14.5916298832790581, -17.0753115655393231),
4800 (0, -1.0413152017509357, -1.770235842976589),
4801 (0, 0, 0.5477225575051664)
4804 check_qr(a, expected_q, expected_r)
4812 expected_q = torch.Tensor((
4813 (-0.0966736489045663, 0.907737593658436, 0.4082482904638653),
4814 (-0.4833682445228317, 0.3157348151855452, -0.8164965809277254),
4815 (-0.870062840141097, -0.2762679632873518, 0.4082482904638621)
4817 expected_r = torch.Tensor((
4818 (-1.0344080432788603e+01, -1.1794185166357092e+01,
4819 -1.3244289899925587e+01, -1.5564457473635180e+01),
4820 (0.0000000000000000e+00, 9.4720444555662542e-01,
4821 1.8944088911132546e+00, 2.5653453733825331e+00),
4822 (0.0000000000000000e+00, 0.0000000000000000e+00,
4823 1.5543122344752192e-15, 4.0824829046386757e-01)
4825 check_qr(a, expected_q, expected_r)
4828 a = torch.randn(1000, 1000)
4830 a_qr = torch.mm(q, r)
4831 self.assertEqual(a, a_qr, prec=1e-3)
4834 def test_ormqr(self):
4835 mat1 = torch.randn(7, 7)
4836 mat2 = torch.randn(7, 7)
4837 q, r = torch.qr(mat1)
4838 m, tau = torch.geqrf(mat1)
4839 out_holder = torch.empty_like(mat1)
4841 res1 = torch.mm(q, mat2)
4842 res2 = torch.ormqr(m, tau, mat2, left=
True, transpose=
False)
4843 torch.ormqr(m, tau, mat2, out=out_holder)
4844 self.assertEqual(res1, res2)
4845 self.assertEqual(res2, out_holder)
4847 res1 = torch.mm(mat2, q)
4848 res2 = torch.ormqr(m, tau, mat2, left=
False, transpose=
False)
4849 torch.ormqr(m, tau, mat2, left=
False, transpose=
False, out=out_holder)
4850 self.assertEqual(res1, res2)
4851 self.assertEqual(res2, out_holder)
4853 res1 = torch.mm(q.t(), mat2)
4854 res2 = torch.ormqr(m, tau, mat2, left=
True, transpose=
True)
4855 torch.ormqr(m, tau, mat2, left=
True, transpose=
True, out=out_holder)
4856 self.assertEqual(res1, res2)
4857 self.assertEqual(res2, out_holder)
4859 res1 = torch.mm(mat2, q.t())
4860 res2 = torch.ormqr(m, tau, mat2, left=
False, transpose=
True)
4861 torch.ormqr(m, tau, mat2, left=
False, transpose=
True, out=out_holder)
4862 self.assertEqual(res1, res2)
4863 self.assertEqual(res2, out_holder)
def _test_geqrf(self, cast):
    """torch.geqrf with out= buffers must reproduce the functional result.

    ``cast`` is applied to the random 5x5 input before factorizing.
    """
    source = cast(torch.randn(5, 5))
    expected = torch.geqrf(source)
    buffers = tuple(torch.empty_like(part) for part in expected)
    torch.geqrf(source, out=buffers)
    for want, got in zip(expected, buffers):
        self.assertEqual(want, got)
4875 def test_geqrf(self):
4876 self._test_geqrf(self,
lambda t: t)
4879 def _test_trtrs(self, cast):
4880 a = torch.Tensor(((6.80, -2.11, 5.66, 5.97, 8.23),
4881 (-6.05, -3.30, 5.36, -4.44, 1.08),
4882 (-0.45, 2.58, -2.70, 0.27, 9.04),
4883 (8.32, 2.71, 4.35, -7.17, 2.14),
4884 (-9.67, -5.14, -7.26, 6.08, -6.87))).t()
4885 b = torch.Tensor(((4.02, 6.19, -8.22, -7.57, -3.03),
4886 (-1.56, 4.00, -8.67, 1.75, 2.86),
4887 (9.81, -4.09, -4.57, -8.61, 8.99))).t()
4896 x = torch.trtrs(b, U)[0]
4897 self.assertLessEqual(b.dist(torch.mm(U, x)), 1e-12)
4898 x = torch.trtrs(b, U,
True,
False,
False)[0]
4899 self.assertLessEqual(b.dist(torch.mm(U, x)), 1e-12)
4902 x = torch.trtrs(b, L,
False)[0]
4903 self.assertLessEqual(b.dist(torch.mm(L, x)), 1e-12)
4904 x = torch.trtrs(b, L,
False,
False,
False)[0]
4905 self.assertLessEqual(b.dist(torch.mm(L, x)), 1e-12)
4908 x = torch.trtrs(b, U,
True,
True)[0]
4909 self.assertLessEqual(b.dist(torch.mm(U.t(), x)), 1e-12)
4910 x = torch.trtrs(b, U,
True,
True,
False)[0]
4911 self.assertLessEqual(b.dist(torch.mm(U.t(), x)), 1e-12)
4914 y = torch.trtrs(b, U.t(),
False,
False)[0]
4915 self.assertLessEqual(x.dist(y), 1e-12)
4918 x = torch.trtrs(b, L,
False,
True)[0]
4919 self.assertLessEqual(b.dist(torch.mm(L.t(), x)), 1e-12)
4920 x = torch.trtrs(b, L,
False,
True,
False)[0]
4921 self.assertLessEqual(b.dist(torch.mm(L.t(), x)), 1e-12)
4924 y = torch.trtrs(b, L.t(),
True,
False)[0]
4925 self.assertLessEqual(x.dist(y), 1e-12)
4928 res1 = torch.trtrs(b, a)[0]
4929 ta = cast(torch.Tensor())
4930 tb = cast(torch.Tensor())
4931 torch.trtrs(b, a, out=(tb, ta))
4932 self.assertEqual(res1, tb, 0)
4934 torch.trtrs(b, a, out=(tb, ta))
4935 self.assertEqual(res1, tb, 0)
4938 def test_trtrs(self):
4939 self._test_trtrs(self,
lambda t: t)
4942 def _test_trtrs_batched(self, cast):
4943 def trtrs_test_helper(A_dims, b_dims, cast, upper, unitriangular):
4944 A = cast(torch.randn(*A_dims))
4945 A = A.triu()
if upper
else A.tril()
4947 A.diagonal(dim1=-2, dim2=-1).fill_(1.)
4948 b = cast(torch.randn(*b_dims))
4951 for upper, transpose, unitriangular
in product([
True,
False], repeat=3):
4953 A, b = trtrs_test_helper((1, 5, 5), (1, 5, 10), cast, upper, unitriangular)
4954 x_exp = torch.trtrs(b.squeeze(0), A.squeeze(0),
4955 upper=upper, unitriangular=unitriangular, transpose=transpose)[0]
4956 x = torch.trtrs(b, A,
4957 upper=upper, unitriangular=unitriangular, transpose=transpose)[0]
4958 self.assertEqual(x, x_exp.unsqueeze(0))
4961 A, b = trtrs_test_helper((4, 5, 5), (4, 5, 10), cast, upper, unitriangular)
4964 x_exp = torch.trtrs(b[i], A[i],
4965 upper=upper, unitriangular=unitriangular, transpose=transpose)[0]
4966 x_exp_list.append(x_exp)
4967 x_exp = torch.stack(x_exp_list)
4969 x = torch.trtrs(b, A, upper=upper, unitriangular=unitriangular, transpose=transpose)[0]
4970 self.assertEqual(x, x_exp)
4973 A, b = trtrs_test_helper((3, 5, 5), (3, 5, 10), cast, upper, unitriangular)
4974 x = torch.trtrs(b, A, upper=upper, unitriangular=unitriangular, transpose=transpose)[0]
4976 self.assertLessEqual(b.dist(torch.matmul(A.transpose(-1, -2), x)), 2e-12)
4978 self.assertLessEqual(b.dist(torch.matmul(A, x)), 2e-12)
4981 def test_trtrs_batched(self):
4982 _TestTorchMixin._test_trtrs_batched(self,
lambda t: t)
4985 def _test_trtrs_batched_dims(self, cast):
4989 from scipy.linalg
import solve_triangular
as tri_solve
4991 def scipy_tri_solve_batched(A, B, upper, trans, diag):
4992 batch_dims_A, batch_dims_B = A.shape[:-2], B.shape[:-2]
4993 single_dim_A, single_dim_B = A.shape[-2:], B.shape[-2:]
4994 expand_dims = tuple(torch._C._infer_size(torch.Size(batch_dims_A),
4995 torch.Size(batch_dims_B)))
4996 expand_A = np.broadcast_to(A, expand_dims + single_dim_A)
4997 expand_B = np.broadcast_to(B, expand_dims + single_dim_B)
4998 flat_A = expand_A.reshape((-1,) + single_dim_A)
4999 flat_B = expand_B.reshape((-1,) + single_dim_B)
5000 flat_X = np.vstack([tri_solve(a, b, lower=(
not upper), trans=int(trans), unit_diagonal=diag)
5001 for a, b
in zip(flat_A, flat_B)])
5002 return flat_X.reshape(expand_B.shape)
5004 def run_test(A_dims, b_dims, cast, upper, transpose, unitriangular):
5005 A = torch.randn(*A_dims)
5006 A = A.triu()
if upper
else A.tril()
5008 A.diagonal(dim1=-2, dim2=-1).fill_(1.)
5009 b = torch.randn(*b_dims)
5010 x_exp = torch.Tensor(scipy_tri_solve_batched(A.numpy(), b.numpy(),
5011 upper, transpose, unitriangular))
5012 A, b = cast(A), cast(b)
5013 x = torch.trtrs(b, A, upper=upper, transpose=transpose, unitriangular=unitriangular)[0]
5015 self.assertEqual(x, cast(x_exp))
5017 for upper, transpose, unitriangular
in product([
True,
False], repeat=3):
5019 run_test((2, 1, 3, 4, 4), (2, 1, 3, 4, 6), cast, upper, transpose, unitriangular)
5020 run_test((2, 1, 3, 4, 4), (4, 6), cast, upper, transpose, unitriangular)
5021 run_test((4, 4), (2, 1, 3, 4, 2), cast, upper, transpose, unitriangular)
5022 run_test((1, 3, 1, 4, 4), (2, 1, 3, 4, 5), cast, upper, transpose, unitriangular)
5025 def test_trtrs_batched_dims(self):
5026 self._test_trtrs_batched_dims(self,
lambda t: t)
5029 def test_gels(self):
5030 def _test_underdetermined(a, b, expectedNorm):
5037 res1 = torch.gels(b, a)[0]
5038 self.assertEqual(a, a_copy, 0)
5039 self.assertEqual(b, b_copy, 0)
5040 self.assertEqual((torch.mm(a, res1) - b).norm(), expectedNorm, 1e-8)
5044 res2 = torch.gels(b, a, out=(tb, ta))[0]
5045 self.assertEqual(a, a_copy, 0)
5046 self.assertEqual(b, b_copy, 0)
5047 self.assertEqual((torch.mm(a, res1) - b).norm(), expectedNorm, 1e-8)
5049 res3 = torch.gels(b, a, out=(b, a))[0]
5050 self.assertEqual((torch.mm(a_copy, b) - b_copy).norm(), expectedNorm, 1e-8)
5051 self.assertEqual(res1, tb, 0)
5052 self.assertEqual(res1, b, 0)
5053 self.assertEqual(res1, res2, 0)
5054 self.assertEqual(res1, res3, 0)
5056 def _test_overdetermined(a, b, expectedNorm):
5061 def check_norm(a, b, expected_norm, gels_result):
5068 resid_info = gels_result[n:]
5070 resid_norm = (torch.mm(a, x) - b).norm()
5071 self.assertEqual(resid_norm, expectedNorm, 1e-8)
5072 self.assertEqual(resid_info.norm(), resid_norm, 1e-8)
5076 res1 = torch.gels(b, a)[0]
5077 self.assertEqual(a, a_copy, 0)
5078 self.assertEqual(b, b_copy, 0)
5079 check_norm(a, b, expectedNorm, res1)
5083 res2 = torch.gels(b, a, out=(tb, ta))[0]
5084 self.assertEqual(a, a_copy, 0)
5085 self.assertEqual(b, b_copy, 0)
5086 check_norm(a, b, expectedNorm, res2)
5088 res3 = torch.gels(b, a, out=(b, a))[0]
5089 check_norm(a_copy, b_copy, expectedNorm, res3)
5091 self.assertEqual(res1, tb, 0)
5092 self.assertEqual(res1, b, 0)
5093 self.assertEqual(res1, res2, 0)
5094 self.assertEqual(res1, res3, 0)
5098 a = torch.Tensor(((1.44, -9.96, -7.55, 8.34),
5099 (-7.84, -0.28, 3.24, 8.09),
5100 (-4.39, -3.24, 6.27, 5.28),
5101 (4.53, 3.83, -6.64, 2.06))).t()
5102 b = torch.Tensor(((8.58, 8.26, 8.48, -5.28),
5103 (9.35, -4.43, -0.70, -0.26))).t()
5104 _test_underdetermined(a, b, expectedNorm)
5107 expectedNorm = 17.390200628863
5108 a = torch.Tensor(((1.44, -9.96, -7.55, 8.34, 7.08, -5.45),
5109 (-7.84, -0.28, 3.24, 8.09, 2.52, -5.70),
5110 (-4.39, -3.24, 6.27, 5.28, 0.74, -1.19),
5111 (4.53, 3.83, -6.64, 2.06, -2.47, 4.70))).t()
5112 b = torch.Tensor(((8.58, 8.26, 8.48, -5.28, 5.72, 8.93),
5113 (9.35, -4.43, -0.70, -0.26, -7.36, -2.52))).t()
5114 _test_overdetermined(a, b, expectedNorm)
5118 a = torch.Tensor(((1.44, -9.96, -7.55),
5119 (-7.84, -0.28, 3.24),
5120 (-4.39, -3.24, 6.27),
5121 (4.53, 3.83, -6.64))).t()
5122 b = torch.Tensor(((8.58, 8.26, 8.48),
5123 (9.35, -4.43, -0.70))).t()
5124 _test_underdetermined(a, b, expectedNorm)
5128 a = torch.Tensor(((1.44, -9.96, -7.55, 8.34),
5129 (-7.84, -0.28, 3.24, 8.09),
5130 (-4.39, -3.24, 6.27, 5.28),
5131 (4.53, 3.83, -6.64, 2.06))).t()
5132 b = torch.Tensor(((8.58, 8.26, 8.48, -5.28),
5133 (9.35, -4.43, -0.70, -0.26))).t()
5136 torch.gels(b, a, out=(tb, ta))
5137 self.assertEqual((torch.mm(a, tb) - b).norm(), expectedNorm, 1e-8)
5138 torch.gels(b, a, out=(tb, ta))
5139 self.assertEqual((torch.mm(a, tb) - b).norm(), expectedNorm, 1e-8)
5140 torch.gels(b, a, out=(tb, ta))
5141 self.assertEqual((torch.mm(a, tb) - b).norm(), expectedNorm, 1e-8)
5145 a = torch.Tensor(((1.96, 0.00, 0.00, 0.00, 0.00),
5146 (-6.49, 3.80, 0.00, 0.00, 0.00),
5147 (-0.47, -6.39, 4.17, 0.00, 0.00),
5148 (-7.20, 1.50, -1.51, 5.70, 0.00),
5149 (-0.65, -6.34, 2.67, 1.80, -7.10))).t().contiguous()
5151 ee, vv = torch.eig(a,
True)
5154 eee, vvv = torch.eig(a,
True, out=(te, tv))
5155 self.assertEqual(e, ee, 1e-12)
5156 self.assertEqual(ee, eee, 1e-12)
5157 self.assertEqual(ee, te, 1e-12)
5158 self.assertEqual(vv, vvv, 1e-12)
5159 self.assertEqual(vv, tv, 1e-12)
5162 X = torch.randn(4, 4)
5163 X = torch.mm(X.t(), X)
5164 e, v = torch.zeros(4, 2), torch.zeros(4, 4)
5165 torch.eig(X,
True, out=(e, v))
5166 Xhat = torch.mm(torch.mm(v, torch.diag(e.select(1, 0))), v.t())
5167 self.assertEqual(X, Xhat, 1e-8,
'VeV\' wrong')
5168 self.assertFalse(v.is_contiguous(),
'V is contiguous')
5170 torch.eig(X,
True, out=(e, v))
5171 Xhat = torch.mm(v, torch.mm(e.select(1, 0).diag(), v.t()))
5172 self.assertEqual(X, Xhat, 1e-8,
'VeV\' wrong')
5173 self.assertFalse(v.is_contiguous(),
'V is contiguous')
5176 X = torch.randn(4, 4)
5177 X = torch.mm(X.t(), X)
5178 e = torch.zeros(4, 2, 2)[:, 1]
5179 v = torch.zeros(4, 2, 4)[:, 1]
5180 self.assertFalse(v.is_contiguous(),
'V is contiguous')
5181 self.assertFalse(e.is_contiguous(),
'E is contiguous')
5182 torch.eig(X,
True, out=(e, v))
5183 Xhat = torch.mm(torch.mm(v, torch.diag(e.select(1, 0))), v.t())
5184 self.assertEqual(X, Xhat, 1e-8,
'VeV\' wrong')
5187 def _test_symeig(self, conv_fn):
5188 xval = conv_fn(torch.rand(100, 3))
5189 cov = torch.mm(xval.t(), xval)
5190 rese = conv_fn(torch.zeros(3))
5191 resv = conv_fn(torch.zeros(3, 3))
5194 self.assertTrue(resv.is_contiguous(),
'resv is not contiguous')
5195 torch.symeig(cov.clone(),
True, out=(rese, resv))
5196 ahat = torch.mm(torch.mm(resv, torch.diag(rese)), resv.t())
5197 self.assertEqual(cov, ahat, 1e-8,
'VeV\' wrong')
5200 self.assertFalse(resv.is_contiguous(),
'resv is contiguous')
5201 torch.symeig(cov.clone(),
True, out=(rese, resv))
5202 ahat = torch.mm(torch.mm(resv, torch.diag(rese)), resv.t())
5203 self.assertEqual(cov, ahat, 1e-8,
'VeV\' wrong')
5206 rese2 = conv_fn(torch.zeros(3))
5207 resv2 = conv_fn(torch.randn(3, 3))
5208 expected_resv2 = conv_fn(torch.zeros(3, 3))
5209 torch.symeig(cov.clone(),
False, out=(rese2, resv2))
5210 self.assertEqual(rese, rese2)
5211 self.assertEqual(resv2, expected_resv2)
5214 X = conv_fn(torch.rand(5, 5))
5216 e = conv_fn(torch.zeros(4, 2)).select(1, 1)
5217 v = conv_fn(torch.zeros(4, 2, 4))[:, 1]
5218 self.assertFalse(v.is_contiguous(),
'V is contiguous')
5219 self.assertFalse(e.is_contiguous(),
'E is contiguous')
5220 torch.symeig(X,
True, out=(e, v))
5221 Xhat = torch.mm(torch.mm(v, torch.diag(e)), v.t())
5222 self.assertEqual(X, Xhat, 1e-8,
'VeV\' wrong')
5225 def test_symeig(self):
5226 self._test_symeig(self,
lambda x: x)
5230 a = torch.Tensor(((8.79, 6.11, -9.15, 9.57, -3.49, 9.84),
5231 (9.93, 6.91, -7.93, 1.64, 4.02, 0.15),
5232 (9.83, 5.04, 4.86, 8.83, 9.80, -8.99),
5233 (5.45, -0.27, 4.85, 0.74, 10.00, -6.02),
5234 (3.16, 7.98, 3.01, 5.80, 4.27, -5.31))).t().clone()
5235 u, s, v = torch.svd(a)
5239 uuu, sss, vvv = torch.svd(a, out=(uu, ss, vv))
5240 self.assertEqual(u, uu, 0,
'torch.svd')
5241 self.assertEqual(u, uuu, 0,
'torch.svd')
5242 self.assertEqual(s, ss, 0,
'torch.svd')
5243 self.assertEqual(s, sss, 0,
'torch.svd')
5244 self.assertEqual(v, vv, 0,
'torch.svd')
5245 self.assertEqual(v, vvv, 0,
'torch.svd')
5248 X = torch.randn(4, 4)
5249 U, S, V = torch.svd(X)
5250 Xhat = torch.mm(U, torch.mm(S.diag(), V.t()))
5251 self.assertEqual(X, Xhat, 1e-8,
'USV\' wrong')
5253 self.assertFalse(U.is_contiguous(),
'U is contiguous')
5254 torch.svd(X, out=(U, S, V))
5255 Xhat = torch.mm(U, torch.mm(S.diag(), V.t()))
5256 self.assertEqual(X, Xhat, 1e-8,
'USV\' wrong')
5259 X = torch.randn(5, 5)
5260 U = torch.zeros(5, 2, 5)[:, 1]
5261 S = torch.zeros(5, 2)[:, 1]
5262 V = torch.zeros(5, 2, 5)[:, 1]
5264 self.assertFalse(U.is_contiguous(),
'U is contiguous')
5265 self.assertFalse(S.is_contiguous(),
'S is contiguous')
5266 self.assertFalse(V.is_contiguous(),
'V is contiguous')
5267 torch.svd(X, out=(U, S, V))
5268 Xhat = torch.mm(U, torch.mm(S.diag(), V.t()))
5269 self.assertEqual(X, Xhat, 1e-8,
'USV\' wrong')
def _test_svd_no_singularvectors(self, cast):
    """Singular values from torch.svd(..., compute_uv=False) must match the full SVD's.

    Square, wide and tall random inputs are checked; ``cast`` is applied to
    each input.
    """
    for shape in ((5, 5), (5, 20), (20, 5)):
        mat = cast(torch.randn(*shape))
        _, full_s, _ = torch.svd(mat)
        _, only_s, _ = torch.svd(mat, compute_uv=False)
        self.assertEqual(full_s, only_s, "Singular values don't match")
5280 def test_svd_no_singularvectors(self):
5281 self._test_svd_no_singularvectors(self,
lambda t: t)
5284 def _test_matrix_rank(self, conv_fn):
5285 a = conv_fn(torch.eye(10))
5286 self.assertEqual(torch.matrix_rank(a).item(), 10)
5287 self.assertEqual(torch.matrix_rank(a,
True).item(), 10)
5290 self.assertEqual(torch.matrix_rank(a).item(), 9)
5291 self.assertEqual(torch.matrix_rank(a,
True).item(), 9)
5293 a = conv_fn(torch.randn(24, 42))
5294 self.assertEqual(torch.matrix_rank(a), torch.matrix_rank(a.t()))
5295 aaT = torch.mm(a, a.t())
5296 self.assertEqual(torch.matrix_rank(aaT), torch.matrix_rank(aaT,
True))
5297 aTa = torch.mm(a.t(), a)
5298 self.assertEqual(torch.matrix_rank(aTa), torch.matrix_rank(aTa,
True))
5301 from numpy.linalg
import matrix_rank
5302 a = conv_fn(torch.randn(35, 75))
5303 self.assertEqual(torch.matrix_rank(a).item(), matrix_rank(a.cpu().numpy()))
5304 self.assertEqual(torch.matrix_rank(a, 0.01).item(), matrix_rank(a.cpu().numpy(), 0.01))
5306 aaT = torch.mm(a, a.t())
5307 self.assertEqual(torch.matrix_rank(aaT).item(), matrix_rank(aaT.cpu().numpy()))
5308 self.assertEqual(torch.matrix_rank(aaT, 0.01).item(), matrix_rank(aaT.cpu().numpy(), 0.01))
5310 if np.lib.NumpyVersion(np.__version__) >=
'1.14.0':
5311 self.assertEqual(torch.matrix_rank(aaT,
True).item(), matrix_rank(aaT.cpu().numpy(),
True))
5312 self.assertEqual(torch.matrix_rank(aaT, 0.01,
True).item(),
5313 matrix_rank(aaT.cpu().numpy(), 0.01,
True))
5316 def test_matrix_rank(self):
5317 self._test_matrix_rank(self,
lambda x: x)
5320 def _test_signal_window_functions(self, device='cpu'):
5322 raise unittest.SkipTest(
'Scipy not found')
5325 torch_method = getattr(torch, name +
'_window')
5326 for size
in [1, 2, 5, 10, 50, 100, 1024, 2048]:
5327 for periodic
in [
True,
False]:
5328 res = torch_method(size, periodic=periodic, device=device)
5329 ref = torch.from_numpy(signal.get_window(name, size, fftbins=periodic))
5330 self.assertEqual(res, ref)
5331 with self.assertRaisesRegex(RuntimeError,
r'not implemented for sparse types'):
5332 torch_method(3, layout=torch.sparse_coo)
5333 with self.assertRaisesRegex(RuntimeError,
r'floating point'):
5334 torch_method(3, dtype=torch.long)
5335 self.assertTrue(torch_method(3, requires_grad=
True).requires_grad)
5336 self.assertFalse(torch_method(3).requires_grad)
5338 for window
in [
'hann',
'hamming',
'bartlett',
'blackman']:
5341 def test_signal_window_functions(self):
5342 self._test_signal_window_functions(self)
5345 def _test_inverse(self, conv_fn):
5346 from common_utils
import random_fullrank_matrix_distinct_singular_value
5349 matrix = conv_fn(random_fullrank_matrix_distinct_singular_value(5))
5350 matrix_inverse = torch.inverse(matrix)
5351 identity = conv_fn(torch.eye(5))
5352 self.assertEqual(identity, torch.mm(matrix, matrix_inverse), 1e-8,
'inverse value')
5353 self.assertEqual(identity, torch.mm(matrix_inverse, matrix), 1e-8,
'inverse value')
5355 matrix_inverse_out = conv_fn(torch.empty(5, 5))
5356 torch.inverse(matrix, out=matrix_inverse_out)
5357 self.assertEqual(matrix_inverse_out, matrix_inverse, 0,
'inverse value in-place')
5359 torch.inverse(matrix, out=matrix_inverse_out)
5360 self.assertEqual(matrix_inverse_out, matrix_inverse, 0,
'inverse value in-place')
5363 matrix = conv_fn(random_fullrank_matrix_distinct_singular_value(5, 1))
5364 matrix_inverse = torch.inverse(matrix)
5365 expected_inv = matrix.squeeze(0).inverse()
5366 self.assertEqual(matrix_inverse, expected_inv.unsqueeze(0))
5369 matrices = conv_fn(random_fullrank_matrix_distinct_singular_value(5, 4))
5370 expected_inv_list = []
5371 for i
in range(0, 4):
5372 expected_inv_list.append(torch.inverse(matrices[i]))
5373 expected_inv = torch.stack(expected_inv_list)
5374 matrices_inverse = torch.inverse(matrices)
5375 self.assertEqual(matrices_inverse, expected_inv)
5378 matrices = conv_fn(random_fullrank_matrix_distinct_singular_value(5, 2, 3))
5379 expected_inv_list = []
5380 for mat
in matrices.view(-1, 5, 5):
5381 expected_inv_list.append(torch.inverse(mat))
5382 expected_inv = torch.stack(expected_inv_list).view(2, 3, 5, 5)
5383 matrices_inverse = torch.inverse(matrices)
5384 self.assertEqual(matrices_inverse, expected_inv)
5387 with self.assertRaisesRegex(RuntimeError,
"must be batches of square matrices"):
5388 torch.inverse(torch.randn(2, 3, 4, 3))
5391 matrices = conv_fn(random_fullrank_matrix_distinct_singular_value(5, 3))
5392 matrices_inverse = torch.inverse(matrices)
5393 self.assertEqual(torch.matmul(matrices, matrices_inverse), identity.expand_as(matrices))
5394 self.assertEqual(torch.matmul(matrices_inverse, matrices), identity.expand_as(matrices))
5397 matrices = conv_fn(random_fullrank_matrix_distinct_singular_value(5, 3))
5398 matrices_inverse = conv_fn(torch.empty(3, 5, 5))
5399 torch.inverse(matrices, out=matrices_inverse)
5400 self.assertEqual(torch.inverse(matrices), matrices_inverse)
5406 from numpy.linalg
import inv
5407 matrices = conv_fn(random_fullrank_matrix_distinct_singular_value(3, 2)).permute(0, 2, 1)
5408 assert not matrices.is_contiguous()
5409 matrices_inverse = torch.inverse(matrices)
5410 expected_inv = torch.as_tensor(inv(matrices.cpu().numpy()))
5411 self.assertEqual(matrices_inverse, conv_fn(expected_inv))
5414 def test_inverse(self):
5415 self._test_inverse(self,
lambda t: t)
5418 def _test_pinverse(self, conv_fn):
5421 MPI = torch.pinverse(M)
5422 self.assertEqual(M, M.mm(MPI).mm(M), 1e-8,
'pseudo-inverse condition 1')
5423 self.assertEqual(MPI, MPI.mm(M).mm(MPI), 1e-8,
'pseudo-inverse condition 2')
5424 self.assertEqual(M.mm(MPI), (M.mm(MPI)).t(), 1e-8,
'pseudo-inverse condition 3')
5425 self.assertEqual(MPI.mm(M), (MPI.mm(M)).t(), 1e-8,
'pseudo-inverse condition 4')
5428 M = conv_fn(torch.randn(5, 5))
5432 M = conv_fn(torch.randn(3, 4))
5436 M = torch.randn(5, 5)
5437 M = conv_fn(M.mm(M.t()))
5438 self.assertEqual(conv_fn(torch.eye(5)), M.pinverse().mm(M), 1e-7,
'pseudo-inverse for invertible matrix')
5441 def test_pinverse(self):
5442 self._test_pinverse(self, conv_fn=
lambda x: x)
# Shared torch.matrix_power checks; `conv_fn` converts each input tensor.
5445 def _test_matrix_power(self, conv_fn):
# run_test compares matrix_power(M, p) for p in {2, 3, 4, 6} against products built with
# torch.matmul, and p == 0 against an identity broadcast to M's (batched) shape.
# NOTE(review): the visible body never reads `sign`; lines handling it appear to be missing
# from this extract -- confirm against the full file.
5446 def run_test(M, sign=1):
5449 MP2 = torch.matrix_power(M, 2)
5450 self.assertEqual(MP2, torch.matmul(M, M))
5452 MP3 = torch.matrix_power(M, 3)
5453 self.assertEqual(MP3, torch.matmul(MP2, M))
5455 MP4 = torch.matrix_power(M, 4)
5456 self.assertEqual(MP4, torch.matmul(MP2, MP2))
5458 MP6 = torch.matrix_power(M, 6)
5459 self.assertEqual(MP6, torch.matmul(MP3, MP3))
5461 MP0 = torch.matrix_power(M, 0)
5462 self.assertEqual(MP0, torch.eye(M.size(-2)).expand_as(M))
# Single, batched and doubly batched square inputs.
# NOTE(review): no run_test(M) call is visible after these assignments in this extract.
5465 M = conv_fn(torch.randn(5, 5))
5469 M = conv_fn(torch.randn(3, 3, 3))
5473 M = conv_fn(torch.randn(2, 3, 3, 3))
# Full-rank inputs, exercised with sign=-1 (presumably negative powers -- confirm).
5477 from common_utils
import random_fullrank_matrix_distinct_singular_value
5478 M = conv_fn(random_fullrank_matrix_distinct_singular_value(5))
5479 run_test(M, sign=-1)
5481 M = conv_fn(random_fullrank_matrix_distinct_singular_value(3, 3))
5482 run_test(M, sign=-1)
5484 M = conv_fn(random_fullrank_matrix_distinct_singular_value(3, 2, 3))
5485 run_test(M, sign=-1)
5488 def test_matrix_power(self):
5489 self._test_matrix_power(self, conv_fn=
lambda x: x)
# Shared torch.chain_matmul checks; `cast` converts each random factor.
# NOTE(review): the local name `product` shadows the module-level itertools.product inside
# this function.
5492 def _test_chain_matmul(self, cast):
# Reference: left-to-right multiplication folded into matrices[0] in place, so it mutates the
# list it is given. NOTE(review): no `return` is visible in this extract -- lines appear missing.
5493 def product(matrices):
5494 for mat
in matrices[1:]:
5495 matrices[0] = matrices[0].mm(mat)
# Builds a chain whose shapes are (p[0] x p[1]), (p[1] x p[2]), ... and compares
# chain_matmul against the reference. NOTE(review): `matrices` is appended to without a
# visible initialisation in this extract.
5498 def run_test(p, cast):
5500 for (pi, pi_1)
in zip(p[:-1], p[1:]):
5501 matrices.append(cast(torch.randn(pi, pi_1)))
# Arguments are evaluated left to right, so chain_matmul sees the factors before product()
# overwrites matrices[0].
5502 self.assertEqual(torch.chain_matmul(*matrices), product(matrices))
5504 run_test([10, 20, 30, 5], cast)
5505 run_test([15, 5, 10, 20, 25], cast)
5507 def test_chain_matmul(self):
5508 self._test_chain_matmul(self, cast=
lambda x: x)
# Shared det / logdet / slogdet checks; `conv_fn` converts each input tensor.
# NOTE(review): this extract is missing lines -- e.g. the assignments of `det`, `logdet`,
# `multiplier`, `M_det`, `M_clone`, `u`, `s`, `v`, the `test` helper's `def` line, and the
# negative-case `if` / positive-case `else:` are not visible. The comments below describe only
# what the visible lines do.
5511 def _test_det_logdet_slogdet(self, conv_fn):
# Reference determinant by Gaussian elimination: row swaps, then subtracting multiples of
# pivot rows; returns prod(diagonal) * `multiplier` (its computation is not visible here,
# presumably the row-swap sign -- confirm).
5512 def reference_det(M):
5520 M[0], M[i] = M[i], M[0]
5525 for i
in range(1, l):
5528 row -= row[j] / M[j, j] * M[j]
5530 return M.diag().prod() * multiplier
# Checks det(M) against `target`, then logdet/slogdet consistency for a negative determinant
# (logdet is NaN since x != x only for NaN; sign == -1), a zero determinant (sign == 0) and a
# positive determinant (sign == 1, exp(logdet) == det).
5532 def test_single_det(M, target, desc):
5535 sdet, logabsdet = M.slogdet()
5536 self.assertEqual(det, target, 1e-7,
'{} (det)'.format(desc))
5538 self.assertTrue(logdet.item() != logdet.item(),
'{} (logdet negative case)'.format(desc))
5539 self.assertTrue(sdet.item() == -1,
'{} (slogdet sign negative case)'.format(desc))
5540 self.assertEqual(logabsdet.exp(), det.abs(), 1e-7,
'{} (slogdet logabsdet negative case)'.format(desc))
5541 elif det.item() == 0:
5542 self.assertEqual(logdet.exp().item(), 0, 1e-7,
'{} (logdet zero case)'.format(desc))
5543 self.assertTrue(sdet.item() == 0,
'{} (slogdet sign zero case)'.format(desc))
5544 self.assertEqual(logabsdet.exp().item(), 0, 1e-7,
'{} (slogdet logabsdet zero case)'.format(desc))
5546 self.assertEqual(logdet.exp(), det, 1e-7,
'{} (logdet positive case)'.format(desc))
5547 self.assertTrue(sdet.item() == 1,
'{} (slogdet sign positive case)'.format(desc))
5548 self.assertEqual(logabsdet.exp(), det, 1e-7,
'{} (slogdet logabsdet positive case)'.format(desc))
# The identity matrix must have determinant 1.
5550 eye = conv_fn(torch.eye(5))
5551 test_single_det(eye,
torch.tensor(1, dtype=eye.dtype),
'identity')
# Flag for CUDA 8.0 / 9.2 builds; the branch it guards is only partly visible below.
5554 is_cuda_8_92 =
False 5556 is_cuda_8_92 = any(x
in torch.version.cuda
for x
in [
'8.0',
'9.2'])
# Body of a per-matrix check helper (its `def` line is not visible). It compares against the
# reference determinant, det(M^-1)^-1 (only when |det| >= 1e-10), and det(M^T).
5559 assert M.size(0) >= 5,
'this helper fn assumes M to be at least 5x5' 5562 if M.is_cuda
and is_cuda_8_92:
5566 ref_M_det = reference_det(M)
5568 test_single_det(M, ref_M_det,
'basic')
5569 if abs(ref_M_det.item()) >= 1e-10:
5570 test_single_det(M, M.inverse().det().pow_(-1),
'inverse')
5571 test_single_det(M, M.t().det(),
'transpose')
# Scaling one row/column by `scale` scales the determinant by `scale`.
# NOTE(review): `M_clone[:, x]` indexes a column while the message says 'scale a row' (and vice
# versa, here and in the next loop); harmless for det since det(M) == det(M^T), but the labels
# look swapped.
5574 for scale
in [-2, -0.1, 0, 10]:
5575 target = M_det * scale
5578 M_clone[:, x] *= scale
5579 test_single_det(M_clone, target,
'scale a row')
5582 M_clone[x, :] *= scale
5583 test_single_det(M_clone, target,
'scale a column')
# Two identical rows/columns give determinant 0.
5585 for x1, x2
in [(0, 3), (4, 1), (3, 2)]:
5586 assert x1 != x2,
'x1 and x2 needs to be different for this test' 5587 target = M_det.clone().zero_()
5590 M_clone[:, x2] = M_clone[:, x1]
5591 test_single_det(M_clone, target,
'two rows are same')
5594 M_clone[x2, :] = M_clone[x1, :]
5595 test_single_det(M_clone, target,
'two columns are same')
# Scaled combination of two rows/columns; expected det is -M_det * scale1 * scale2. The
# saved `t` is presumably written back into slot x2 on a line not visible here -- confirm.
5597 for scale1, scale2
in [(0.3, -1), (0, 2), (10, 0.1)]:
5598 target = -M_det * scale1 * scale2
5601 t = M_clone[:, x1] * scale1
5602 M_clone[:, x1] += M_clone[:, x2] * scale2
5604 test_single_det(M_clone, target,
'exchanging rows')
5607 t = M_clone[x1, :] * scale1
5608 M_clone[x1, :] += M_clone[x2, :] * scale2
5610 test_single_det(M_clone, target,
'exchanging columns')
# Returns ((n - 1)!) ** (-1 / (2n)); presumably chosen to keep |det| of a scaled n x n randn
# matrix near 1 -- confirm.
5612 def get_random_mat_scale(n):
5628 return math.factorial(n - 1) ** (-1.0 / (2 * n))
# Random inputs of size 5, 10 and 25: plain, symmetric (r r^T + 1e-6 I), a non-contiguous
# slice, and a product u @ diag(s) @ v built from factors whose reference_det is sign-checked.
5630 for n
in [5, 10, 25]:
5631 scale = get_random_mat_scale(n)
5632 test(torch.randn(n, n) * scale)
5633 r = torch.randn(n, n) * scale
5637 r = torch.randn(n, n) * scale
5638 test(r.mm(r.t()) + torch.eye(n) * 1e-6)
5640 r = torch.randn(n, n) * scale
5646 test((torch.randn(n, n, n + 1) * scale)[:, 2, 1:])
5648 r = torch.randn(n, n) * scale
5650 if reference_det(u) < 0:
5652 if reference_det(v) < 0:
5656 test(u.mm(s.diag()).mm(v))
5659 def test_det_logdet_slogdet(self):
5660 self._test_det_logdet_slogdet(self,
lambda x: x)
# Shared fft / ifft / rfft / irfft checks on random inputs created on `device`.
# NOTE(review): lines are missing from this extract -- e.g. the `else:` before the batched
# Hermitian-symmetry check and the `if not onesided:` guard before the two-sided ifft check.
5663 def _test_fft_ifft_rfft_irfft(self, device='cpu'):
# Complex input (trailing dimension of size 2 = real/imag): fft then ifft, and ifft then fft,
# must round-trip for both normalized settings. `prepro_fn` can make the input non-contiguous.
5664 def _test_complex(sizes, signal_ndim, prepro_fn=lambda x: x):
5665 x = prepro_fn(torch.randn(*sizes, device=device))
5666 for normalized
in (
True,
False):
5667 res = x.fft(signal_ndim, normalized=normalized)
5668 rec = res.ifft(signal_ndim, normalized=normalized)
5669 self.assertEqual(x, rec, 1e-8,
'fft and ifft')
5670 res = x.ifft(signal_ndim, normalized=normalized)
5671 rec = res.fft(signal_ndim, normalized=normalized)
5672 self.assertEqual(x, rec, 1e-8,
'ifft and fft')
# Real input: rfft for every (normalized, onesided) combination (itertools.product).
5674 def _test_real(sizes, signal_ndim, prepro_fn=lambda x: x):
5675 x = prepro_fn(torch.randn(*sizes, device=device))
5677 signal_sizes = x.size()[-signal_ndim:]
5678 for normalized, onesided
in product((
True,
False), repeat=2):
5679 res = x.rfft(signal_ndim, normalized=normalized, onesided=onesided)
# Spot-checks Hermitian symmetry X[idx] == conj(X[-idx mod size]) at `test_num` random
# indices: real parts equal, imaginary parts negated.
5681 def test_one_sample(res, test_num=10):
5682 idxs_per_dim = [torch.LongTensor(test_num).random_(s).tolist()
for s
in signal_sizes]
5683 for idx
in zip(*idxs_per_dim):
5684 reflected_idx = tuple((s - i) % s
for i, s
in zip(idx, res.size()))
5685 idx_val = res.__getitem__(idx)
5686 reflected_val = res.__getitem__(reflected_idx)
5687 self.assertEqual(idx_val[0], reflected_val[0],
'rfft hermitian symmetry on real part')
5688 self.assertEqual(idx_val[1], -reflected_val[1],
'rfft hermitian symmetry on imaginary part')
# Unbatched input: check the result directly; batched input: flatten the batch dims and
# check up to 4 randomly chosen batch entries.
5689 if len(sizes) == signal_ndim:
5690 test_one_sample(res)
5692 output_non_batch_shape = res.size()[-(signal_ndim + 1):]
5693 flatten_batch_res = res.view(-1, *output_non_batch_shape)
5694 nb = flatten_batch_res.size(0)
5695 test_idxs = torch.LongTensor(min(nb, 4)).random_(nb)
5696 for test_idx
in test_idxs.tolist():
5697 test_one_sample(flatten_batch_res[test_idx])
# rfft must match the complex fft of the same signal with a zero imaginary part.
5699 xc = torch.stack([x, torch.zeros_like(x)], -1)
5700 xc_res = xc.fft(signal_ndim, normalized=normalized)
5701 self.assertEqual(res, xc_res)
# NOTE(review): `test_input_signal_sizes` is not read by any visible line.
5702 test_input_signal_sizes = [signal_sizes]
5703 rec = res.irfft(signal_ndim, normalized=normalized,
5704 onesided=onesided, signal_sizes=signal_sizes)
5705 self.assertEqual(x, rec, 1e-8,
'rfft and irfft')
# Two-sided case (per the messages): ifft must recover x in the real part and ~0 in the
# imaginary part.
5707 rec = res.ifft(signal_ndim, normalized=normalized)
5708 self.assertEqual(x, rec.select(-1, 0), 1e-8,
'twosided rfft and ifft real')
5709 self.assertEqual(rec.select(-1, 1).data.abs().mean(), 0, 1e-8,
'twosided rfft and ifft imaginary')
# Contiguous inputs: 1-D, 2-D and 3-D signals, with and without batch dimensions.
5712 _test_real((100,), 1)
5713 _test_real((10, 1, 10, 100), 1)
5714 _test_real((100, 100), 2)
5715 _test_real((2, 2, 5, 80, 60), 2)
5716 _test_real((50, 40, 70), 3)
5717 _test_real((30, 1, 50, 25, 20), 3)
5719 _test_complex((100, 2), 1)
5720 _test_complex((100, 100, 2), 1)
5721 _test_complex((100, 100, 2), 2)
5722 _test_complex((1, 20, 80, 60, 2), 2)
5723 _test_complex((50, 40, 70, 2), 3)
5724 _test_complex((6, 5, 50, 25, 20, 2), 3)
# Non-contiguous inputs produced by narrow / select / transpose / expand / slicing.
5727 _test_real((165,), 1,
lambda x: x.narrow(0, 25, 100))
5728 _test_real((100, 100, 3), 1,
lambda x: x[:, :, 0])
5729 _test_real((100, 100), 2,
lambda x: x.t())
5730 _test_real((20, 100, 10, 10), 2,
lambda x: x.view(20, 100, 100)[:, :60])
5731 _test_real((65, 80, 115), 3,
lambda x: x[10:60, 13:53, 10:80])
5732 _test_real((30, 20, 50, 25), 3,
lambda x: x.transpose(1, 2).transpose(2, 3))
5734 _test_complex((2, 100), 1,
lambda x: x.t())
5735 _test_complex((100, 2), 1,
lambda x: x.expand(100, 100, 2))
5736 _test_complex((300, 200, 3), 2,
lambda x: x[:100, :100, 1:])
5737 _test_complex((20, 90, 110, 2), 2,
lambda x: x[:, 5:85].narrow(2, 5, 100))
5738 _test_complex((40, 60, 3, 80, 2), 3,
lambda x: x.transpose(2, 0).select(0, 2)[5:55, :, 10:])
5739 _test_complex((30, 55, 50, 22, 2), 3,
lambda x: x[:, 3:53, 15:40, 1:21])
# Unusual strides via as_strided on a 50-element buffer (some of these strides make elements
# overlap).
5742 _test_complex((50,), 1,
lambda x: x.as_strided([5, 5, 2], [3, 2, 1]))
5743 _test_complex((50,), 1,
lambda x: x.as_strided([5, 5, 2], [4, 2, 2]))
5744 _test_complex((50,), 1,
lambda x: x.as_strided([5, 5, 2], [4, 3, 1]))
5745 _test_complex((50,), 2,
lambda x: x.as_strided([5, 5, 2], [3, 3, 1]))
5746 _test_complex((50,), 2,
lambda x: x.as_strided([5, 5, 2], [4, 2, 2]))
5747 _test_complex((50,), 2,
lambda x: x.as_strided([5, 5, 2], [4, 3, 1]))
@unittest.skipIf(not TEST_MKL, "PyTorch is built without MKL support")
def test_fft_ifft_rfft_irfft(self):
    """Run the shared FFT round-trip checks on the default (CPU) device."""
    shared_fft_checks = self._test_fft_ifft_rfft_irfft
    shared_fft_checks(self)
# Shared Tensor.stft checks against a librosa reference, on tensors created on `device`.
# Raises unittest.SkipTest when librosa is unavailable.
# NOTE(review): lines are missing from this extract (e.g. the `if window is None:` / `else:`
# around the window setup, the per-signal loop defining `xi` and `result`, the return of
# librosa_stft, and the `window = None` default inside _test); confirm against the full file.
5754 def _test_stft(self, device='cpu'):
5755 if not TEST_LIBROSA:
5756 raise unittest.SkipTest(
'librosa not found')
# Reference STFT via librosa.stft. The default window is all ones, of length win_length (or
# n_fft if win_length is None). Each result is stacked as (real, imag) along a trailing
# dimension of size 2.
5758 def librosa_stft(x, n_fft, hop_length, win_length, window, center):
5760 window = np.ones(n_fft
if win_length
is None else win_length)
5762 window = window.cpu().numpy()
5763 input_1d = x.dim() == 1
5768 ri = librosa.stft(xi.cpu().numpy(), n_fft, hop_length, win_length, window, center=center)
5769 result.append(torch.from_numpy(np.stack([ri.real, ri.imag], -1)))
5770 result = torch.stack(result, 0)
# Builds a random input (plus a random window when `win_sizes` is given). Without
# `expected_error`, compares x.stft against the reference (tolerance 7e-6); otherwise asserts
# that x.stft raises `expected_error`.
5775 def _test(sizes, n_fft, hop_length=None, win_length=None, win_sizes=None,
5776 center=
True, expected_error=
None):
5777 x = torch.randn(*sizes, device=device)
5778 if win_sizes
is not None:
5779 window = torch.randn(*win_sizes, device=device)
5782 if expected_error
is None:
5783 result = x.stft(n_fft, hop_length, win_length, window, center=center)
5784 ref_result = librosa_stft(x, n_fft, hop_length, win_length, window, center)
5785 self.assertEqual(result, ref_result, 7e-6,
'stft comparison against librosa')
5787 self.assertRaises(expected_error,
5788 lambda: x.stft(n_fft, hop_length, win_length, window, center=center))
# Valid configurations, with and without centering: default hop, explicit hop, explicit
# window, and explicit win_length.
5790 for center
in [
True,
False]:
5791 _test((10,), 7, center=center)
5792 _test((10, 4000), 1024, center=center)
5794 _test((10,), 7, 2, center=center)
5795 _test((10, 4000), 1024, 512, center=center)
5797 _test((10,), 7, 2, win_sizes=(7,), center=center)
5798 _test((10, 4000), 1024, 512, win_sizes=(1024,), center=center)
5801 _test((10,), 7, 2, win_length=5, center=center)
5802 _test((10, 4000), 1024, 512, win_length=100, center=center)
# Invalid configurations that must raise RuntimeError: 3-D input, n_fft longer than an
# uncentered signal, negative n_fft, win_length > n_fft, and windows of the wrong size or rank.
5804 _test((10, 4, 2), 1, 1, expected_error=RuntimeError)
5805 _test((10,), 11, 1, center=
False, expected_error=RuntimeError)
5806 _test((10,), -1, 1, expected_error=RuntimeError)
5807 _test((10,), 3, win_length=5, expected_error=RuntimeError)
5808 _test((10,), 5, 4, win_sizes=(11,), expected_error=RuntimeError)
5809 _test((10,), 5, 4, win_sizes=(1, 1), expected_error=RuntimeError)
5811 def test_stft(self):
5812 self._test_stft(self)
# Disabled test (unittest.skip): compares torch.conv2 against torch.xcorr2 with a flipped
# kernel, in valid ('V') and full ('F') modes, then checks that batched (2-slice) inputs give
# the same per-slice results.
# NOTE(review): torch.uniform, torch.conv2 and torch.xcorr2 are not standard PyTorch functions,
# and `math.abs` does not exist in Python's math module (math.fabs does). `ks`, `kis` and `ki`
# have no visible assignment here, and the `xx`/`kk` filling lines are not visible. This looks
# like an unported test; confirm before enabling.
5814 @unittest.skip(
"Not implemented yet")
5815 def test_conv2(self):
5816 x = torch.rand(math.floor(torch.uniform(50, 100)), math.floor(torch.uniform(50, 100)))
5817 k = torch.rand(math.floor(torch.uniform(10, 20)), math.floor(torch.uniform(10, 20)))
5818 imvc = torch.conv2(x, k)
5819 imvc2 = torch.conv2(x, k,
'V')
5820 imfc = torch.conv2(x, k,
'F')
# Reverses the kernel (the index arithmetic looks 1-based -- confirm) to build the
# flipped kernel used with xcorr2.
5825 for i
in range(ks.size() - 1, 0, -1):
5826 kis[ks.size() - i + 1] = ks[i]
5828 imvx = torch.xcorr2(x, ki)
5829 imvx2 = torch.xcorr2(x, ki,
'V')
5830 imfx = torch.xcorr2(x, ki,
'F')
5832 self.assertEqual(imvc, imvc2, 0,
'torch.conv2')
5833 self.assertEqual(imvc, imvx, 0,
'torch.conv2')
5834 self.assertEqual(imvc, imvx2, 0,
'torch.conv2')
5835 self.assertEqual(imfc, imfx, 0,
'torch.conv2')
5836 self.assertLessEqual(math.abs(x.dot(x) - torch.xcorr2(x, x)[0][0]), 1e-10,
'torch.conv2')
# Batched inputs: two stacked slices must each reproduce the single-input results.
5838 xx = torch.Tensor(2, x.size(1), x.size(2))
5841 kk = torch.Tensor(2, k.size(1), k.size(2))
5845 immvc = torch.conv2(xx, kk)
5846 immvc2 = torch.conv2(xx, kk,
'V')
5847 immfc = torch.conv2(xx, kk,
'F')
5849 self.assertEqual(immvc[0], immvc[1], 0,
'torch.conv2')
5850 self.assertEqual(immvc[0], imvc, 0,
'torch.conv2')
5851 self.assertEqual(immvc2[0], imvc2, 0,
'torch.conv2')
5852 self.assertEqual(immfc[0], immfc[1], 0,
'torch.conv2')
5853 self.assertEqual(immfc[0], imfc, 0,
'torch.conv2')
# Disabled test (unittest.skip): the 3-D analogue of test_conv2. It compares torch.conv3 with
# torch.xcorr3 on a flipped kernel in valid ('V') and full ('F') modes, then checks batched
# (2-slice) inputs.
# NOTE(review): torch.uniform / torch.conv3 / torch.xcorr3 are not standard PyTorch functions,
# `math.abs` does not exist (math.fabs does), and `ks`, `kis`, `ki` and the `xx`/`kk` filling
# lines are not visible in this extract. Confirm before enabling.
5855 @unittest.skip(
"Not implemented yet")
5856 def test_conv3(self):
5857 x = torch.rand(math.floor(torch.uniform(20, 40)),
5858 math.floor(torch.uniform(20, 40)),
5859 math.floor(torch.uniform(20, 40)))
5860 k = torch.rand(math.floor(torch.uniform(5, 10)),
5861 math.floor(torch.uniform(5, 10)),
5862 math.floor(torch.uniform(5, 10)))
5863 imvc = torch.conv3(x, k)
5864 imvc2 = torch.conv3(x, k,
'V')
5865 imfc = torch.conv3(x, k,
'F')
# Kernel reversal; the index arithmetic looks 1-based -- confirm.
5870 for i
in range(ks.size() - 1, 0, -1):
5871 kis[ks.size() - i + 1] = ks[i]
5872 imvx = torch.xcorr3(x, ki)
5873 imvx2 = torch.xcorr3(x, ki,
'V')
5874 imfx = torch.xcorr3(x, ki,
'F')
5876 self.assertEqual(imvc, imvc2, 0,
'torch.conv3')
5877 self.assertEqual(imvc, imvx, 0,
'torch.conv3')
5878 self.assertEqual(imvc, imvx2, 0,
'torch.conv3')
5879 self.assertEqual(imfc, imfx, 0,
'torch.conv3')
5880 self.assertLessEqual(math.abs(x.dot(x) - torch.xcorr3(x, x)[0][0][0]), 4e-10,
'torch.conv3')
# Batched inputs: two stacked volumes must each reproduce the single-input results.
5882 xx = torch.Tensor(2, x.size(1), x.size(2), x.size(3))
5885 kk = torch.Tensor(2, k.size(1), k.size(2), k.size(3))
5889 immvc = torch.conv3(xx, kk)
5890 immvc2 = torch.conv3(xx, kk,
'V')
5891 immfc = torch.conv3(xx, kk,
'F')
5893 self.assertEqual(immvc[0], immvc[1], 0,
'torch.conv3')
5894 self.assertEqual(immvc[0], imvc, 0,
'torch.conv3')
5895 self.assertEqual(immvc2[0], imvc2, 0,
'torch.conv3')
5896 self.assertEqual(immfc[0], immfc[1], 0,
'torch.conv3')
5897 self.assertEqual(immfc[0], imfc, 0,
'torch.conv3')
# Helper: builds random x (20-40 per side) and k (5-10 per side) volumes. It lets
# `fn_2_to_3(x, k, o3, o32)` accumulate a 2-D-slice reference into the zero tensor o32, then
# compares it with o3.
# NOTE(review): the unittest.skip decorator is on this helper itself, so every call raises
# unittest.SkipTest and the callers are skipped regardless of their own decorators.
# `o3` has no visible assignment in this extract (presumably `fn(x, k)` -- confirm).
5899 @unittest.skip(
"Not implemented yet")
5900 def _test_conv_corr_eq(self, fn, fn_2_to_3):
5901 ix = math.floor(random.randint(20, 40))
5902 iy = math.floor(random.randint(20, 40))
5903 iz = math.floor(random.randint(20, 40))
5904 kx = math.floor(random.randint(5, 10))
5905 ky = math.floor(random.randint(5, 10))
5906 kz = math.floor(random.randint(5, 10))
5908 x = torch.rand(ix, iy, iz)
5909 k = torch.rand(kx, ky, kz)
5912 o32 = torch.zeros(o3.size())
5913 fn_2_to_3(x, k, o3, o32)
5914 self.assertEqual(o3, o32)
5916 @unittest.skip(
"Not implemented yet")
5917 def test_xcorr3_xcorr2_eq(self):
5918 def reference(x, k, o3, o32):
5919 for i
in range(o3.size(1)):
5920 for j
in range(k.size(1)):
5921 o32[i].add(torch.xcorr2(x[i + j - 1], k[j]))
5922 self._test_conv_corr_eq(torch.xcorr3, reference)
5924 @unittest.skip(
"Not implemented yet")
5925 def test_xcorr3_xcorr2_eq_full(self):
5926 def reference(x, k, o3, o32):
5927 for i
in range(x.size(1)):
5928 for j
in range(k.size(1)):
5929 o32[i].add(torch.xcorr2(x[i], k[k.size(1) - j + 1],
'F'))
5930 self._test_conv_corr_eq(
lambda x, k: torch.xcorr3(x, k,
'F'), reference)
5932 @unittest.skip(
"Not implemented yet")
5933 def test_conv3_conv2_eq_valid(self):
5934 def reference(x, k, o3, o32):
5935 for i
in range(o3.size(1)):
5936 for j
in range(k.size(1)):
5937 o32[i].add(torch.conv2(x[i + j - 1], k[k.size(1) - j + 1]))
5938 self._test_conv_corr_eq(torch.conv3, reference)
5940 @unittest.skip(
"Not implemented yet")
5941 def test_fconv3_fconv2_eq(self):
5942 def reference(x, k, o3, o32):
5943 for i
in range(o3.size(1)):
5944 for j
in range(k.size(1)):
5945 o32[i + j - 1].add(torch.conv2(x[i], k[j],
'F'))
5946 self._test_conv_corr_eq(
lambda x, k: torch.conv3(x, k,
'F'), reference)
# Elementwise comparison ops (gt / lt / eq / ne) against the constant 1, on x drawn uniformly
# from [-1, 1).
# NOTE(review): `neqs` and `all` have no visible assignment in this extract. As shown, `all` would
# be the Python builtin, which has no `.long()`. `xgt`, `xlt` and `xeq` are also not read by any
# visible line, so lines appear to be missing -- confirm against the full file.
5948 def test_logical(self):
5949 x = torch.rand(100, 100) * 2 - 1
5951 xgt = torch.gt(x, 1)
5952 xlt = torch.lt(x, 1)
5954 xeq = torch.eq(x, 1)
5955 xne = torch.ne(x, 1)
5959 self.assertEqual(neqs.long().sum(), xne.long().sum(), 0)
5960 self.assertEqual(x.nelement(), all.long().sum())
5962 def test_isfinite(self):
5963 x = torch.Tensor([1, inf, 2, -inf, nan, -10])
5964 self.assertEqual(torch.isfinite(x), torch.ByteTensor([1, 0, 1, 0, 0, 1]))
# torch.isfinite on a (presumably integer, per the test name) input: all three entries are
# expected to be finite.
# NOTE(review): `x` has no visible assignment in this extract; the line defining the input
# appears to be missing -- confirm against the full file.
5966 def test_isfinite_int(self):
5968 self.assertEqual(torch.isfinite(x), torch.ByteTensor([1, 1, 1]))
5971 def _test_isinf(self, cast):
5972 t1 = cast(torch.Tensor([1, inf, 2, -inf, nan]))
5973 t2 = cast(torch.ByteTensor([1, 2, 3]))
5974 t3 = cast(torch.CharTensor([1, 2, 3]))
5975 t4 = cast(torch.ShortTensor([1, 2, 3]))
5976 t5 = cast(torch.IntTensor([1, 2, 3]))
5977 t6 = cast(torch.LongTensor([1, 2, 3]))
5978 self.assertEqual(torch.isinf(t1), cast(torch.ByteTensor([0, 1, 0, 1, 0])))
5979 self.assertEqual(torch.isinf(t2), cast(torch.ByteTensor([0, 0, 0])))
5980 self.assertEqual(torch.isinf(t3), cast(torch.ByteTensor([0, 0, 0])))
5981 self.assertEqual(torch.isinf(t4), cast(torch.ByteTensor([0, 0, 0])))
5982 self.assertEqual(torch.isinf(t5), cast(torch.ByteTensor([0, 0, 0])))
5983 self.assertEqual(torch.isinf(t6), cast(torch.ByteTensor([0, 0, 0])))
5985 def test_isinf(self):
5986 self._test_isinf(self,
lambda t: t)
5988 def test_isnan(self):
5989 x = torch.Tensor([1, nan, 2])
5990 self.assertEqual(torch.isnan(x), torch.ByteTensor([0, 1, 0]))
5992 def test_RNGState(self):
5993 state = torch.get_rng_state()
5994 stateCloned = state.clone()
5995 before = torch.rand(1000)
5997 self.assertEqual(state.ne(stateCloned).long().sum(), 0, 0)
5999 torch.set_rng_state(state)
6000 after = torch.rand(1000)
6001 self.assertEqual(before, after, 0)
6003 def test_RNGStateAliasing(self):
6005 gen = torch.Generator()
6006 gen.set_state(torch.get_rng_state())
6007 self.assertEqual(gen.get_state(), torch.get_rng_state())
6009 target_value = torch.rand(1000)
6011 _ = torch.rand(100000)
6012 forked_value = torch.rand(1000, generator=gen)
6013 self.assertEqual(target_value, forked_value, 0,
"RNG has not forked correctly.")
# Checks that serialising a tensor with ForkingPickler does not consume numbers from the
# global RNG: `before` and `after` (each torch.rand(10)) must match exactly.
# NOTE(review): `buf` has no visible assignment, and no re-seeding before either draw is visible
# in this extract. As shown, two consecutive draws would not be expected to match, so lines
# appear to be missing -- confirm against the full file.
6015 def test_RNG_after_pickle(self):
6017 before = torch.rand(10)
6021 tensor = torch.Tensor([1, 2, 3])
6022 ForkingPickler(buf, pickle.HIGHEST_PROTOCOL).dump(tensor)
6023 after = torch.rand(10)
6025 self.assertEqual(before, after, 0)
# Normal sampling must be reproducible both after re-seeding with manual_seed(123) and after
# restoring a state captured mid-stream with get_rng_state / set_rng_state.
# NOTE(review): `odd_number` has no visible assignment in this extract. Presumably it is an
# odd sample count, so that a cached second Box-Muller value is pending when the state is
# captured -- confirm against the full file.
6027 def test_boxMullerState(self):
6028 torch.manual_seed(123)
6030 seeded = torch.randn(odd_number)
6031 state = torch.get_rng_state()
6032 midstream = torch.randn(odd_number)
6033 torch.set_rng_state(state)
6034 repeat_midstream = torch.randn(odd_number)
6035 torch.manual_seed(123)
6036 reseeded = torch.randn(odd_number)
6037 self.assertEqual(midstream, repeat_midstream, 0,
6038 'get_rng_state/set_rng_state not generating same sequence of normally distributed numbers')
6039 self.assertEqual(seeded, reseeded, 0,
6040 'repeated calls to manual_seed not generating same sequence of normally distributed numbers')
6042 def test_manual_seed(self):
6043 rng_state = torch.get_rng_state()
6044 torch.manual_seed(2)
6045 x = torch.randn(100)
6046 self.assertEqual(torch.initial_seed(), 2)
6047 torch.manual_seed(2)
6048 y = torch.randn(100)
6049 self.assertEqual(x, y)
6050 torch.set_rng_state(rng_state)
6053 def _test_cholesky(self, cast):
6054 x = cast(torch.rand(10, 10) + 1e-1)
6055 A = torch.mm(x, x.t())
6058 C = torch.cholesky(A)
6059 B = torch.mm(C, C.t())
6060 self.assertEqual(A, B, 1e-14)
6063 U = torch.cholesky(A,
True)
6064 B = torch.mm(U.t(), U)
6065 self.assertEqual(A, B, 1e-14,
'cholesky (upper) did not allow rebuilding the original matrix')
6068 L = torch.cholesky(A,
False)
6069 B = torch.mm(L, L.t())
6070 self.assertEqual(A, B, 1e-14,
'cholesky (lower) did not allow rebuilding the original matrix')
6073 def test_cholesky(self):
6074 self._test_cholesky(self,
lambda t: t)
def _test_cholesky_batched(self, cast):
    """Batched torch.cholesky must match factoring every matrix in the batch separately.

    Args:
        cast: applied to each generated batch (e.g. identity, or a move to another device).
    """
    from common_utils import random_symmetric_pd_matrix

    def check_batch(n, batch_dims, cast, upper):
        batch = cast(random_symmetric_pd_matrix(n, *batch_dims))
        # reference: factor each n x n matrix on its own, then restore the batch shape
        per_matrix = [mat.cholesky(upper=upper) for mat in batch.reshape(-1, n, n)]
        expected = torch.stack(per_matrix).reshape_as(batch)
        self.assertEqual(expected, torch.cholesky(batch, upper=upper))

    for upper, batch_dims in product([True, False], [(3,), (3, 4), (2, 3, 4)]):
        check_batch(3, batch_dims, cast, upper)
6090 def test_cholesky_batched(self):
6091 self._test_cholesky_batched(self,
lambda t: t)
6094 def _test_cholesky_solve(self, cast):
6095 a = torch.Tensor(((6.80, -2.11, 5.66, 5.97, 8.23),
6096 (-6.05, -3.30, 5.36, -4.44, 1.08),
6097 (-0.45, 2.58, -2.70, 0.27, 9.04),
6098 (8.32, 2.71, 4.35, -7.17, 2.14),
6099 (-9.67, -5.14, -7.26, 6.08, -6.87))).t()
6100 b = torch.Tensor(((4.02, 6.19, -8.22, -7.57, -3.03),
6101 (-1.56, 4.00, -8.67, 1.75, 2.86),
6102 (9.81, -4.09, -4.57, -8.61, 8.99))).t()
6105 a = torch.mm(a, a.t())
6106 a, b = cast(a), cast(b)
6109 U = torch.cholesky(a,
True)
6110 x = torch.cholesky_solve(b, U,
True)
6111 self.assertLessEqual(b.dist(torch.mm(a, x)), 1e-12)
6114 L = torch.cholesky(a,
False)
6115 x = torch.cholesky_solve(b, L,
False)
6116 self.assertLessEqual(b.dist(torch.mm(a, x)), 1e-12)
6119 L_def = torch.cholesky(a)
6120 x_def = torch.cholesky_solve(b, L_def)
6121 self.assertLessEqual(b.dist(torch.mm(a, x_def)), 1e-12)
6124 def test_cholesky_solve(self):
6125 self._test_cholesky_solve(self,
lambda t: t)
# Shared batched torch.cholesky_solve checks; `cast` converts generated inputs.
# NOTE(review): lines are missing from this extract -- the helper below has no visible `return`,
# although its callers unpack `A, L, b` from it, and `x_exp_list` / `i` are used without a visible
# initialisation or loop. Confirm against the full file.
6128 def _test_cholesky_solve_batched(self, cast):
6129 from common_utils
import random_symmetric_pd_matrix
# Builds a random SPD batch A (dims passed to random_symmetric_pd_matrix), its Cholesky factor
# L (`upper` passed positionally), and a random right-hand side b.
6131 def cholesky_solve_test_helper(A_dims, b_dims, cast, upper):
6132 A = cast(random_symmetric_pd_matrix(*A_dims))
6133 L = torch.cholesky(A, upper)
6134 b = cast(torch.randn(*b_dims))
6137 for upper
in [
True,
False]:
# Batch of one must match the unbatched solve.
6139 A, L, b = cholesky_solve_test_helper((5, 1), (1, 5, 10), cast, upper)
6140 x_exp = torch.cholesky_solve(b.squeeze(0), L.squeeze(0), upper=upper)
6141 x = torch.cholesky_solve(b, L, upper=upper)
6142 self.assertEqual(x, x_exp.unsqueeze(0))
# Batch of four must match per-matrix solves stacked together.
6145 A, L, b = cholesky_solve_test_helper((5, 4), (4, 5, 10), cast, upper)
6148 x_exp = torch.cholesky_solve(b[i], L[i], upper=upper)
6149 x_exp_list.append(x_exp)
6150 x_exp = torch.stack(x_exp_list)
6152 x = torch.cholesky_solve(b, L, upper=upper)
6153 self.assertEqual(x, x_exp)
# Residual check: A @ x must reproduce b.
6156 A, L, b = cholesky_solve_test_helper((5, 3), (3, 5, 10), cast, upper)
6157 x = torch.cholesky_solve(b, L, upper)
6158 self.assertLessEqual(b.dist(torch.matmul(A, x)), 1e-12)
# Non-contiguous inputs (created by permute) compared against numpy.linalg.solve.
6164 from numpy.linalg
import solve
6165 A = random_symmetric_pd_matrix(2, 2)
6166 b = torch.randn(2, 2, 2)
6167 x_exp = torch.Tensor(solve(A.permute(0, 2, 1).numpy(), b.permute(2, 1, 0).numpy()))
6168 A = cast(A).permute(0, 2, 1)
6169 b = cast(b).permute(2, 1, 0)
6170 assert not A.is_contiguous()
and not b.is_contiguous(),
"contiguous inputs" 6171 L = torch.cholesky(A, upper)
6172 x = torch.cholesky_solve(b, L, upper=upper)
6173 self.assertEqual(x, cast(x_exp))
6176 def test_cholesky_solve_batched(self):
6177 self._test_cholesky_solve_batched(self,
lambda t: t)
def _test_cholesky_solve_batched_dims(self, cast):
    """Check torch.cholesky_solve on mixed batch shapes against numpy.linalg.solve.

    Args:
        cast: applied to the inputs and to the numpy reference (e.g. identity, or a move to
            another device).
    """
    from numpy.linalg import solve
    from common_utils import random_symmetric_pd_matrix

    def check(a_dims, b_dims, cast, upper):
        # reference solution is computed in numpy before any conversion
        a = random_symmetric_pd_matrix(*a_dims)
        rhs = torch.randn(*b_dims)
        expected = torch.Tensor(solve(a.numpy(), rhs.numpy()))
        a, rhs = cast(a), cast(rhs)
        factor = torch.cholesky(a, upper)
        solved = torch.cholesky_solve(rhs, factor, upper=upper)
        self.assertEqual(solved, cast(expected))

    # (arguments for random_symmetric_pd_matrix, shape of the right-hand side)
    shape_pairs = (
        ((4, 2, 1, 3), (2, 1, 3, 4, 6)),
        ((4, 2, 1, 3), (4, 6)),
        ((4,), (2, 1, 3, 4, 2)),
        ((4, 1, 3, 1), (2, 1, 3, 4, 5)),
    )
    for upper in [True, False]:
        for a_dims, b_dims in shape_pairs:
            check(a_dims, b_dims, cast, upper)
6204 def test_cholesky_solve_batched_dims(self):
6205 self._test_cholesky_solve_batched_dims(self,
lambda t: t)
6208 def test_potri(self):
6209 a = torch.Tensor(((6.80, -2.11, 5.66, 5.97, 8.23),
6210 (-6.05, -3.30, 5.36, -4.44, 1.08),
6211 (-0.45, 2.58, -2.70, 0.27, 9.04),
6212 (8.32, 2.71, 4.35, -7.17, 2.14),
6213 (-9.67, -5.14, -7.26, 6.08, -6.87))).t()
6216 a = torch.mm(a, a.t())
6219 inv0 = torch.inverse(a)
6222 chol = torch.cholesky(a)
6223 inv1 = torch.potri(chol,
False)
6224 self.assertLessEqual(inv0.dist(inv1), 1e-12)
6227 chol = torch.cholesky(a,
True)
6228 inv1 = torch.potri(chol,
True)
6229 self.assertLessEqual(inv0.dist(inv1), 1e-12)
6232 chol = torch.cholesky(a,
False)
6233 inv1 = torch.potri(chol,
False)
6234 self.assertLessEqual(inv0.dist(inv1), 1e-12)
# Checks torch.pstrf on symmetric PSD matrices m m^T, with the diagonal nudged by 1e-7. The
# reconstructed factor product must equal `a` with rows and columns permuted by the returned
# pivots, for every inplace / uplo combination.
# NOTE(review): lines are missing from this extract -- `args` has no visible assignment, and the
# `if`/`else` choosing between the out= path and the returned values, and between u u^T and
# u^T u, are not visible. Confirm against the full file.
6237 def test_pstrf(self):
6238 def checkPsdCholesky(a, uplo, inplace):
# Pre-allocated outputs for the out= variant (pivots as an int tensor).
6240 u = torch.empty_like(a)
6241 piv = a.new(a.size(0)).int()
6242 kwargs = {
'out': (u, piv)}
6247 if uplo
is not None:
6250 u, piv = torch.pstrf(*args, **kwargs)
6253 a_reconstructed = torch.mm(u, u.t())
6255 a_reconstructed = torch.mm(u.t(), u)
# Apply the pivot permutation to both rows and columns of the input.
6258 a_permuted = a.index_select(0, piv).index_select(1, piv)
6259 self.assertEqual(a_permuted, a_reconstructed, 1e-14)
# Shapes of m; m m^T is square, with side dim[0].
6261 dimensions = ((5, 1), (5, 3), (5, 5), (10, 10))
6262 for dim
in dimensions:
6263 m = torch.Tensor(*dim).uniform_()
6264 a = torch.mm(m, m.t())
6266 for i
in range(m.size(0)):
6267 a[i][i] = a[i][i] + 1e-7
6268 for inplace
in (
True,
False):
6269 for uplo
in (
None,
True,
False):
6270 checkPsdCholesky(a, uplo, inplace)
6272 def test_numel(self):
6273 b = torch.ByteTensor(3, 100, 100)
6274 self.assertEqual(b.nelement(), 3 * 100 * 100)
6275 self.assertEqual(b.numel(), 3 * 100 * 100)
6277 def _consecutive(self, size, start=1):
6278 sequence = torch.ones(int(torch.Tensor(size).prod(0))).cumsum(0)
6279 sequence.add_(start - 1)
6280 return sequence.resize_(*size)
# Shared basic-indexing checks: integers, slices, Ellipsis, None (new axes), empty
# selections, stepped slices, random slice pairs versus Python lists, and error cases.
# `conv_fn` converts each tensor under test.
# NOTE(review): some lines appear to be missing from this extract -- e.g. the expected tensor
# opened on the `reference[..., 2]` line is never closed, the `else:` of the random-slice branch
# is not visible, and `delitem` has no visible definition. Confirm against the full file.
6283 def _test_index(self, conv_fn):
# Float tensor of shape `size` filled with start, start + 1, ... in row-major order.
6285 def consec(size, start=1):
6286 sequence = torch.ones(int(torch.Tensor(size).prod(0))).cumsum(0)
6287 sequence.add_(start - 1)
6288 return sequence.view(*size)
# 3x3x3 reference holding 1..27. An empty LongTensor index selects zero rows; integer and
# slice indices select the expected consecutive blocks.
6290 reference = conv_fn(consec((3, 3, 3)))
6293 self.assertEqual(reference[conv_fn(torch.LongTensor())], reference.new(0, 3, 3))
6295 self.assertEqual(reference[0], consec((3, 3)), 0)
6296 self.assertEqual(reference[1], consec((3, 3), 10), 0)
6297 self.assertEqual(reference[2], consec((3, 3), 19), 0)
6298 self.assertEqual(reference[0, 1], consec((3,), 4), 0)
6299 self.assertEqual(reference[0:2], consec((2, 3, 3)), 0)
6300 self.assertEqual(reference[2, 2, 2], 27, 0)
6301 self.assertEqual(reference[:], consec((3, 3, 3)), 0)
6304 self.assertEqual(reference[..., 2], torch.Tensor([[3, 6, 9],
6307 self.assertEqual(reference[0, ..., 2], torch.Tensor([3, 6, 9]), 0)
6308 self.assertEqual(reference[..., 2], reference[:, :, 2], 0)
6309 self.assertEqual(reference[0, ..., 2], reference[0, :, 2], 0)
6310 self.assertEqual(reference[0, 2, ...], reference[0, 2], 0)
6311 self.assertEqual(reference[..., 2, 2, 2], 27, 0)
6312 self.assertEqual(reference[2, ..., 2, 2], 27, 0)
6313 self.assertEqual(reference[2, 2, ..., 2], 27, 0)
6314 self.assertEqual(reference[2, 2, 2, ...], 27, 0)
6315 self.assertEqual(reference[...], reference, 0)
# Ellipsis in various positions of a 5-D tensor must match the explicit `:` form.
6317 reference_5d = conv_fn(consec((3, 3, 3, 3, 3)))
6318 self.assertEqual(reference_5d[..., 1, 0], reference_5d[:, :, :, 1, 0], 0)
6319 self.assertEqual(reference_5d[2, ..., 1, 0], reference_5d[2, :, :, 1, 0], 0)
6320 self.assertEqual(reference_5d[2, 1, 0, ..., 1], reference_5d[2, 1, 0, :, 1], 0)
6321 self.assertEqual(reference_5d[...], reference_5d, 0)
# LongTensor (advanced) indexing, then None inserting new axes like unsqueeze.
6324 reference = conv_fn(consec((5, 5, 5)))
6325 idx = conv_fn(torch.LongTensor([2, 4]))
6326 self.assertEqual(reference[idx], torch.stack([reference[2], reference[4]]))
6332 self.assertEqual(reference[2,
None], reference[2].unsqueeze(0))
6333 self.assertEqual(reference[2,
None,
None], reference[2].unsqueeze(0).unsqueeze(0))
6334 self.assertEqual(reference[2:4,
None], reference[2:4].unsqueeze(1))
6335 self.assertEqual(reference[
None, 2,
None,
None], reference.unsqueeze(0)[:, 2].unsqueeze(0).unsqueeze(0))
6336 self.assertEqual(reference[
None, 2:5,
None,
None], reference.unsqueeze(0)[:, 2:5].unsqueeze(2).unsqueeze(2))
# Zero-length selections produce empty tensors of the matching shape.
6339 self.assertEqual(torch.empty(0, 5, 5), reference[slice(0)])
6340 self.assertEqual(torch.empty(0, 5), reference[slice(0), 2])
6341 self.assertEqual(torch.empty(0, 5), reference[2, slice(0)])
6342 self.assertEqual(
torch.tensor([]), reference[2, 1:1, 2])
# Positive-step slices on a 10x10x10 reference (note: conv_fn is not applied here).
6345 reference = consec((10, 10, 10))
6346 self.assertEqual(reference[1:5:2], torch.stack([reference[1], reference[3]], 0))
6347 self.assertEqual(reference[1:6:2], torch.stack([reference[1], reference[3], reference[5]], 0))
6348 self.assertEqual(reference[1:9:4], torch.stack([reference[1], reference[5]], 0))
6349 self.assertEqual(reference[2:4, 1:5:2], torch.stack([reference[2:4, 1], reference[2:4, 3]], 1))
6350 self.assertEqual(reference[3, 1:6:2], torch.stack([reference[3, 1], reference[3, 3], reference[3, 5]], 0))
6351 self.assertEqual(reference[
None, 2, 1:9:4], torch.stack([reference[2, 1], reference[2, 5]], 0).unsqueeze(0))
6352 self.assertEqual(reference[:, 2, 1:6:2],
6353 torch.stack([reference[:, 2, 1], reference[:, 2, 3], reference[:, 2, 5]], 1))
# 100 random 1-D or 2-D slice combinations on a 10x10 tensor, compared with the same slicing
# applied to the equivalent nested Python list.
6355 lst = [list(range(i, i + 10))
for i
in range(0, 100, 10)]
6356 tensor = conv_fn(torch.DoubleTensor(lst))
6357 for _i
in range(100):
6358 idx1_start = random.randrange(10)
6359 idx1_end = idx1_start + random.randrange(1, 10 - idx1_start + 1)
6360 idx1_step = random.randrange(1, 8)
6361 idx1 = slice(idx1_start, idx1_end, idx1_step)
6362 if random.randrange(2) == 0:
6363 idx2_start = random.randrange(10)
6364 idx2_end = idx2_start + random.randrange(1, 10 - idx2_start + 1)
6365 idx2_step = random.randrange(1, 8)
6366 idx2 = slice(idx2_start, idx2_end, idx2_step)
6367 lst_indexed = list(map(
lambda l: l[idx2], lst[idx1]))
6368 tensor_indexed = tensor[idx1, idx2]
6370 lst_indexed = lst[idx1]
6371 tensor_indexed = tensor[idx1]
6372 self.assertEqual(torch.DoubleTensor(lst_indexed), tensor_indexed)
# Error cases: zero or negative slice steps raise ValueError, too many indices raise
# IndexError, float indices raise IndexError (TypeError for a bare float slice), and
# deletion raises TypeError.
6374 self.assertRaises(ValueError,
lambda: reference[1:9:0])
6375 self.assertRaises(ValueError,
lambda: reference[1:9:-1])
6377 self.assertRaises(IndexError,
lambda: reference[1, 1, 1, 1])
6378 self.assertRaises(IndexError,
lambda: reference[1, 1, 1, 1:1])
6379 self.assertRaises(IndexError,
lambda: reference[3, 3, 3, 3, 3, 3, 3, 3])
6381 self.assertRaises(IndexError,
lambda: reference[0.0])
6382 self.assertRaises(TypeError,
lambda: reference[0.0:2.0])
6383 self.assertRaises(IndexError,
lambda: reference[0.0, 0.0:2.0])
6384 self.assertRaises(IndexError,
lambda: reference[0.0, :, 0.0:2.0])
6385 self.assertRaises(IndexError,
lambda: reference[0.0, ..., 0.0:2.0])
6386 self.assertRaises(IndexError,
lambda: reference[0.0, :, 0.0])
6391 self.assertRaises(TypeError, delitem)
6393 def test_index(self):
6394 self._test_index(self,
lambda x: x)
6397 def _test_advancedindex(self, conv_fn):
6401 def consec(size, start=1):
6402 numel = reduce(
lambda x, y: x * y, size, 1)
6403 sequence = torch.ones(numel).cumsum(0)
6404 sequence.add_(start - 1)
6405 return sequence.view(*size)
6409 choice = random.randint(0, 2)
6411 return conv_fn(torch.LongTensor(indices))
6413 return list(indices)
6415 return tuple(indices)
6417 def validate_indexing(x):
6418 self.assertEqual(x[[0]], consec((1,)))
6419 self.assertEqual(x[ri([0]), ], consec((1,)))
6420 self.assertEqual(x[ri([3]), ], consec((1,), 4))
6421 self.assertEqual(x[[2, 3, 4]], consec((3,), 3))
6422 self.assertEqual(x[ri([2, 3, 4]), ], consec((3,), 3))
6423 self.assertEqual(x[ri([0, 2, 4]), ], torch.Tensor([1, 3, 5]))
6425 def validate_setting(x):
6428 self.assertEqual(x[[0]], torch.Tensor([-2]).type(dtype))
6430 self.assertEqual(x[ri([0]), ], torch.Tensor([-1]).type(dtype))
6432 self.assertEqual(x[[2, 3, 4]], torch.Tensor([4, 4, 4]).type(dtype))
6433 x[ri([2, 3, 4]), ] = 3
6434 self.assertEqual(x[ri([2, 3, 4]), ], torch.Tensor([3, 3, 3]).type(dtype))
6435 x[ri([0, 2, 4]), ] = conv_fn(torch.Tensor([5, 4, 3])).type(dtype)
6436 self.assertEqual(x[ri([0, 2, 4]), ], torch.Tensor([5, 4, 3]).type(dtype))
6441 reference = conv_fn(consec((10,)))
6442 validate_indexing(reference)
6443 validate_indexing(reference.type(torch.half))
6446 validate_setting(reference)
6447 validate_setting(reference.type(torch.half))
6452 reference = conv_fn(consec((10,)))
6453 strided = conv_fn(torch.Tensor())
6454 strided.set_(reference.storage(), storage_offset=0,
6455 size=torch.Size([4]), stride=[2])
6457 self.assertEqual(strided[[0]], torch.Tensor([1]))
6458 self.assertEqual(strided[ri([0]), ], torch.Tensor([1]))
6459 self.assertEqual(strided[ri([3]), ], torch.Tensor([7]))
6460 self.assertEqual(strided[[1, 2]], torch.Tensor([3, 5]))
6461 self.assertEqual(strided[ri([1, 2]), ], torch.Tensor([3, 5]))
6462 self.assertEqual(strided[ri([[2, 1], [0, 3]]), ],
6463 torch.Tensor([[5, 3], [1, 7]]))
6466 strided = conv_fn(torch.Tensor())
6467 strided.set_(reference.storage(), storage_offset=4,
6468 size=torch.Size([2]), stride=[4])
6469 self.assertEqual(strided[[0]], torch.Tensor([5]))
6470 self.assertEqual(strided[ri([0]), ], torch.Tensor([5]))
6471 self.assertEqual(strided[ri([1]), ], torch.Tensor([9]))
6472 self.assertEqual(strided[[0, 1]], torch.Tensor([5, 9]))
6473 self.assertEqual(strided[ri([0, 1]), ], torch.Tensor([5, 9]))
6474 self.assertEqual(strided[ri([[0, 1], [1, 0]]), ],
6475 torch.Tensor([[5, 9], [9, 5]]))
6480 reference = conv_fn(consec((3, 2)))
6481 self.assertEqual(reference[ri([0, 1, 2]), ri([0])], torch.Tensor([1, 3, 5]))
6482 self.assertEqual(reference[ri([0, 1, 2]), ri([1])], torch.Tensor([2, 4, 6]))
6483 self.assertEqual(reference[ri([0]), ri([0])], consec((1,)))
6484 self.assertEqual(reference[ri([2]), ri([1])], consec((1,), 6))
6485 self.assertEqual(reference[[ri([0, 0]), ri([0, 1])]], torch.Tensor([1, 2]))
6486 self.assertEqual(reference[[ri([0, 1, 1, 0, 2]), ri([1])]],
6487 torch.Tensor([2, 4, 4, 2, 6]))
6488 self.assertEqual(reference[[ri([0, 0, 1, 1]), ri([0, 1, 0, 0])]],
6489 torch.Tensor([1, 2, 3, 3]))
6494 self.assertEqual(reference[rows, columns], torch.Tensor([[1, 1],
6499 columns = ri([1, 0])
6500 self.assertEqual(reference[rows, columns], torch.Tensor([[2, 1],
6504 columns = ri([[0, 1],
6506 self.assertEqual(reference[rows, columns], torch.Tensor([[1, 2],
6510 reference[ri([0]), ri([1])] = -1
6511 self.assertEqual(reference[ri([0]), ri([1])], torch.Tensor([-1]))
6512 reference[ri([0, 1, 2]), ri([0])] = conv_fn(torch.Tensor([-1, 2, -4]))
6513 self.assertEqual(reference[ri([0, 1, 2]), ri([0])], torch.Tensor([-1,
6515 reference[rows, columns] = conv_fn(torch.Tensor([[4, 6], [2, 3]]))
6516 self.assertEqual(reference[rows, columns],
6517 torch.Tensor([[4, 6], [2, 3]]))
6521 reference = conv_fn(torch.Tensor([[0, 1, 2, 3],
6523 [8, 9, 10, 11]])).t_()
6530 self.assertEqual(reference[ri([0, 1, 2]), ri([0])], torch.Tensor([0, 1,
6532 self.assertEqual(reference[ri([0, 1, 2]), ri([1])], torch.Tensor([4, 5,
6534 self.assertEqual(reference[ri([0]), ri([0])], torch.Tensor([0]))
6535 self.assertEqual(reference[ri([2]), ri([1])], torch.Tensor([6]))
6536 self.assertEqual(reference[[ri([0, 0]), ri([0, 1])]], torch.Tensor([0, 4]))
6537 self.assertEqual(reference[[ri([0, 1, 1, 0, 3]), ri([1])]],
6538 torch.Tensor([4, 5, 5, 4, 7]))
6539 self.assertEqual(reference[[ri([0, 0, 1, 1]), ri([0, 1, 0, 0])]],
6540 torch.Tensor([0, 4, 1, 1]))
6545 self.assertEqual(reference[rows, columns], torch.Tensor([[0, 0],
6550 columns = ri([1, 0])
6551 self.assertEqual(reference[rows, columns], torch.Tensor([[4, 0],
6555 columns = ri([[0, 1],
6557 self.assertEqual(reference[rows, columns], torch.Tensor([[0, 4],
6561 reference[ri([0]), ri([1])] = -1
6562 self.assertEqual(reference[ri([0]), ri([1])], torch.Tensor([-1]))
6563 reference[ri([0, 1, 2]), ri([0])] = conv_fn(torch.Tensor([-1, 2, -4]))
6564 self.assertEqual(reference[ri([0, 1, 2]), ri([0])], torch.Tensor([-1,
6566 reference[rows, columns] = conv_fn(torch.Tensor([[4, 6], [2, 3]]))
6567 self.assertEqual(reference[rows, columns],
6568 torch.Tensor([[4, 6], [2, 3]]))
6575 reference = conv_fn(torch.arange(0., 24).view(3, 8))
6576 strided = conv_fn(torch.Tensor())
6577 strided.set_(reference.storage(), 1, size=torch.Size([2, 4]),
6580 self.assertEqual(strided[ri([0, 1]), ri([0])], torch.Tensor([1, 9]))
6581 self.assertEqual(strided[ri([0, 1]), ri([1])], torch.Tensor([3, 11]))
6582 self.assertEqual(strided[ri([0]), ri([0])], torch.Tensor([1]))
6583 self.assertEqual(strided[ri([1]), ri([3])], torch.Tensor([15]))
6584 self.assertEqual(strided[[ri([0, 0]), ri([0, 3])]], torch.Tensor([1, 7]))
6585 self.assertEqual(strided[[ri([1]), ri([0, 1, 1, 0, 3])]],
6586 torch.Tensor([9, 11, 11, 9, 15]))
6587 self.assertEqual(strided[[ri([0, 0, 1, 1]), ri([0, 1, 0, 0])]],
6588 torch.Tensor([1, 3, 9, 9]))
6593 self.assertEqual(strided[rows, columns], torch.Tensor([[1, 1],
6598 columns = ri([1, 2])
6599 self.assertEqual(strided[rows, columns], torch.Tensor([[3, 13],
6603 columns = ri([[0, 1],
6605 self.assertEqual(strided[rows, columns], torch.Tensor([[1, 3],
6613 reference = conv_fn(torch.arange(0., 24).view(3, 8))
6614 strided = conv_fn(torch.Tensor())
6615 strided.set_(reference.storage(), 10, size=torch.Size([2, 2]),
6617 self.assertEqual(strided[ri([0]), ri([1])], torch.Tensor([11]))
6618 strided[ri([0]), ri([1])] = -1
6619 self.assertEqual(strided[ri([0]), ri([1])], torch.Tensor([-1]))
6621 reference = conv_fn(torch.arange(0., 24).view(3, 8))
6622 strided = conv_fn(torch.Tensor())
6623 strided.set_(reference.storage(), 10, size=torch.Size([2, 2]),
6625 self.assertEqual(strided[ri([0, 1]), ri([1, 0])], torch.Tensor([11,
6627 strided[ri([0, 1]), ri([1, 0])] = conv_fn(torch.Tensor([-1, 2]))
6628 self.assertEqual(strided[ri([0, 1]), ri([1, 0])], torch.Tensor([-1,
6631 reference = conv_fn(torch.arange(0., 24).view(3, 8))
6632 strided = conv_fn(torch.Tensor())
6633 strided.set_(reference.storage(), 10, size=torch.Size([2, 2]),
6638 columns = ri([[0, 1],
6640 self.assertEqual(strided[rows, columns],
6641 torch.Tensor([[10, 11], [17, 18]]))
6642 strided[rows, columns] = conv_fn(torch.Tensor([[4, 6], [2, 3]]))
6643 self.assertEqual(strided[rows, columns],
6644 torch.Tensor([[4, 6], [2, 3]]))
6651 reference = conv_fn(consec((3, 2)))
6652 self.assertEqual(reference[ri([0, 2]), ], torch.Tensor([[1, 2], [5, 6]]))
6653 self.assertEqual(reference[ri([1]), ...], torch.Tensor([[3, 4]]))
6654 self.assertEqual(reference[..., ri([1])], torch.Tensor([[2], [4], [6]]))
6657 with self.assertRaises(IndexError):
6658 reference[ri([1]), ri([0, 2]), ri([3])]
6661 reference = conv_fn(torch.empty(10))
6663 if not reference.is_cuda:
6664 for err_idx
in (10, -11):
6665 with self.assertRaisesRegex(IndexError,
r'out of'):
6667 with self.assertRaisesRegex(IndexError,
r'out of'):
6668 reference[conv_fn(torch.LongTensor([err_idx]))]
6669 with self.assertRaisesRegex(IndexError,
r'out of'):
6670 reference[[err_idx]]
6677 def tensor_indices_to_np(tensor, indices):
6679 if (tensor.is_cuda):
6680 tensor = tensor.cpu()
6681 npt = tensor.numpy()
6684 idxs = tuple(i.tolist()
if isinstance(i, torch.LongTensor)
else 6689 def get_numpy(tensor, indices):
6690 npt, idxs = tensor_indices_to_np(tensor, indices)
6693 return torch.Tensor(npt[idxs])
6695 def set_numpy(tensor, indices, value):
6696 if not isinstance(value, int):
6699 value = value.numpy()
6701 npt, idxs = tensor_indices_to_np(tensor, indices)
6705 def assert_get_eq(tensor, indexer):
6706 self.assertEqual(tensor[indexer],
6707 conv_fn(get_numpy(tensor, indexer)))
6709 def assert_set_eq(tensor, indexer, val):
6710 pyt = tensor.clone()
6711 numt = tensor.clone()
6713 numt = conv_fn(torch.Tensor(set_numpy(numt, indexer, val)))
6714 self.assertEqual(pyt, numt)
6716 def get_set_tensor(indexed, indexer):
6717 set_size = indexed[indexer].size()
6718 set_count = indexed[indexer].numel()
6719 set_tensor = conv_fn(torch.randperm(set_count).view(set_size).double())
6726 reference = conv_fn(torch.arange(0., 20).view(4, 5))
6730 [slice(
None), [1, 3]],
6733 [[0, 2], slice(
None)],
6736 [slice(
None), [[0, 1],
6741 [slice(
None), [-1]],
6745 get_indices_to_test = indices_to_test + [[slice(
None), [0, 1, 1, 2, 2]]]
6747 for indexer
in get_indices_to_test:
6748 assert_get_eq(reference, indexer)
6750 for indexer
in indices_to_test:
6751 assert_set_eq(reference, indexer, 44)
6752 assert_set_eq(reference,
6754 get_set_tensor(reference, indexer))
6756 reference = conv_fn(torch.arange(0., 160).view(4, 8, 5))
6759 [slice(
None), slice(
None), [0, 3, 4]],
6760 [slice(
None), [2, 4, 5, 7], slice(
None)],
6761 [[2, 3], slice(
None), slice(
None)],
6762 [slice(
None), [0, 2, 3], [1, 3, 4]],
6763 [slice(
None), [0], [1, 2, 4]],
6764 [slice(
None), [0, 1, 3], [4]],
6765 [slice(
None), [[0, 1], [1, 0]], [[2, 3]]],
6766 [slice(
None), [[0, 1], [2, 3]], [[0]]],
6767 [slice(
None), [[5, 6]], [[0, 3], [4, 4]]],
6768 [[0, 2, 3], [1, 3, 4], slice(
None)],
6769 [[0], [1, 2, 4], slice(
None)],
6770 [[0, 1, 3], [4], slice(
None)],
6771 [[[0, 1], [1, 0]], [[2, 1], [3, 5]], slice(
None)],
6772 [[[0, 1], [1, 0]], [[2, 3]], slice(
None)],
6773 [[[0, 1], [2, 3]], [[0]], slice(
None)],
6774 [[[2, 1]], [[0, 3], [4, 4]], slice(
None)],
6775 [[[2]], [[0, 3], [4, 1]], slice(
None)],
6779 [[0, 2], slice(
None)],
6781 [[0, 2], slice(
None), Ellipsis],
6782 [[0, 2], Ellipsis, slice(
None)],
6784 [[0, 2], [1, 3], Ellipsis],
6785 [Ellipsis, [1, 3], [2, 3]],
6786 [Ellipsis, [2, 3, 4]],
6787 [Ellipsis, slice(
None), [2, 3, 4]],
6788 [slice(
None), Ellipsis, [2, 3, 4]],
6791 [Ellipsis, slice(
None), slice(
None), [0, 3, 4]],
6792 [slice(
None), Ellipsis, slice(
None), [0, 3, 4]],
6793 [slice(
None), slice(
None), Ellipsis, [0, 3, 4]],
6794 [slice(
None), slice(
None), [0, 3, 4], Ellipsis],
6795 [Ellipsis, [[0, 1], [1, 0]], [[2, 1], [3, 5]], slice(
None)],
6796 [[[0, 1], [1, 0]], [[2, 1], [3, 5]], Ellipsis, slice(
None)],
6797 [[[0, 1], [1, 0]], [[2, 1], [3, 5]], slice(
None), Ellipsis],
6800 for indexer
in indices_to_test:
6801 assert_get_eq(reference, indexer)
6802 assert_set_eq(reference, indexer, 212)
6803 assert_set_eq(reference,
6805 get_set_tensor(reference, indexer))
6807 reference = conv_fn(torch.arange(0., 1296).view(3, 9, 8, 6))
6810 [slice(
None), slice(
None), slice(
None), [0, 3, 4]],
6811 [slice(
None), slice(
None), [2, 4, 5, 7], slice(
None)],
6812 [slice(
None), [2, 3], slice(
None), slice(
None)],
6813 [[1, 2], slice(
None), slice(
None), slice(
None)],
6814 [slice(
None), slice(
None), [0, 2, 3], [1, 3, 4]],
6815 [slice(
None), slice(
None), [0], [1, 2, 4]],
6816 [slice(
None), slice(
None), [0, 1, 3], [4]],
6817 [slice(
None), slice(
None), [[0, 1], [1, 0]], [[2, 3]]],
6818 [slice(
None), slice(
None), [[0, 1], [2, 3]], [[0]]],
6819 [slice(
None), slice(
None), [[5, 6]], [[0, 3], [4, 4]]],
6820 [slice(
None), [0, 2, 3], [1, 3, 4], slice(
None)],
6821 [slice(
None), [0], [1, 2, 4], slice(
None)],
6822 [slice(
None), [0, 1, 3], [4], slice(
None)],
6823 [slice(
None), [[0, 1], [3, 4]], [[2, 3], [0, 1]], slice(
None)],
6824 [slice(
None), [[0, 1], [3, 4]], [[2, 3]], slice(
None)],
6825 [slice(
None), [[0, 1], [3, 2]], [[0]], slice(
None)],
6826 [slice(
None), [[2, 1]], [[0, 3], [6, 4]], slice(
None)],
6827 [slice(
None), [[2]], [[0, 3], [4, 2]], slice(
None)],
6828 [[0, 1, 2], [1, 3, 4], slice(
None), slice(
None)],
6829 [[0], [1, 2, 4], slice(
None), slice(
None)],
6830 [[0, 1, 2], [4], slice(
None), slice(
None)],
6831 [[[0, 1], [0, 2]], [[2, 4], [1, 5]], slice(
None), slice(
None)],
6832 [[[0, 1], [1, 2]], [[2, 0]], slice(
None), slice(
None)],
6833 [[[2, 2]], [[0, 3], [4, 5]], slice(
None), slice(
None)],
6834 [[[2]], [[0, 3], [4, 5]], slice(
None), slice(
None)],
6835 [slice(
None), [3, 4, 6], [0, 2, 3], [1, 3, 4]],
6836 [slice(
None), [2, 3, 4], [1, 3, 4], [4]],
6837 [slice(
None), [0, 1, 3], [4], [1, 3, 4]],
6838 [slice(
None), [6], [0, 2, 3], [1, 3, 4]],
6839 [slice(
None), [2, 3, 5], [3], [4]],
6840 [slice(
None), [0], [4], [1, 3, 4]],
6841 [slice(
None), [6], [0, 2, 3], [1]],
6842 [slice(
None), [[0, 3], [3, 6]], [[0, 1], [1, 3]], [[5, 3], [1, 2]]],
6843 [[2, 2, 1], [0, 2, 3], [1, 3, 4], slice(
None)],
6844 [[2, 0, 1], [1, 2, 3], [4], slice(
None)],
6845 [[0, 1, 2], [4], [1, 3, 4], slice(
None)],
6846 [[0], [0, 2, 3], [1, 3, 4], slice(
None)],
6847 [[0, 2, 1], [3], [4], slice(
None)],
6848 [[0], [4], [1, 3, 4], slice(
None)],
6849 [[1], [0, 2, 3], [1], slice(
None)],
6850 [[[1, 2], [1, 2]], [[0, 1], [2, 3]], [[2, 3], [3, 5]], slice(
None)],
6853 [Ellipsis, [0, 3, 4]],
6854 [Ellipsis, slice(
None), [0, 3, 4]],
6855 [Ellipsis, slice(
None), slice(
None), [0, 3, 4]],
6856 [slice(
None), Ellipsis, [0, 3, 4]],
6857 [slice(
None), slice(
None), Ellipsis, [0, 3, 4]],
6858 [slice(
None), [0, 2, 3], [1, 3, 4]],
6859 [slice(
None), [0, 2, 3], [1, 3, 4], Ellipsis],
6860 [Ellipsis, [0, 2, 3], [1, 3, 4], slice(
None)],
6862 [[0], [1, 2, 4], slice(
None)],
6863 [[0], [1, 2, 4], Ellipsis],
6864 [[0], [1, 2, 4], Ellipsis, slice(
None)],
6866 [[0, 2, 1], [3], [4]],
6867 [[0, 2, 1], [3], [4], slice(
None)],
6868 [[0, 2, 1], [3], [4], Ellipsis],
6869 [Ellipsis, [0, 2, 1], [3], [4]],
6872 for indexer
in indices_to_test:
6873 assert_get_eq(reference, indexer)
6874 assert_set_eq(reference, indexer, 1333)
6875 assert_set_eq(reference,
6877 get_set_tensor(reference, indexer))
6878 indices_to_test += [
6879 [slice(
None), slice(
None), [[0, 1], [1, 0]], [[2, 3], [3, 0]]],
6880 [slice(
None), slice(
None), [[2]], [[0, 3], [4, 4]]],
6882 for indexer
in indices_to_test:
6883 assert_get_eq(reference, indexer)
6884 assert_set_eq(reference, indexer, 1333)
6886 def test_advancedindex(self):
6887 self._test_advancedindex(self,
lambda x: x)
6890 def _test_advancedindex_big(self, conv_fn):
6891 reference = conv_fn(torch.arange(0, 123344).int())
6893 self.assertEqual(reference[[0, 123, 44488, 68807, 123343], ],
6894 torch.LongTensor([0, 123, 44488, 68807, 123343]))
6896 def test_advancedindex_big(self):
6897 self._test_advancedindex_big(self,
lambda x: x)
6899 @unittest.skipIf(
not TEST_NUMPY,
"Numpy not found")
6900 def test_newaxis_numpy_comparison(self):
# Compare torch indexing that mixes None (new axis), Ellipsis and integer
# indices against NumPy indexing of the same data. Skipped without NumPy.
# NOTE(review): the embedded line numbers jump (e.g. 6906 -> 6914). The
# index-case list elements below appear without an enclosing list assignment
# or any call to run_test, so lines look lost. Restore them from VCS.
6901 def run_test(tensor, *idx):
# Index the tensor and its NumPy array with the same index tuple and require
# equal results.
6902 npt = tensor.numpy()
6903 self.assertEqual(tensor[idx], npt[idx])
# 1-D input: 0..9.
6906 x = torch.arange(0, 10)
6914 [Ellipsis,
None, 2],
6915 [Ellipsis, 2,
None],
6916 [2, Ellipsis,
None],
6917 [2,
None, Ellipsis],
6918 [
None, 2, Ellipsis],
6919 [
None, Ellipsis, 2],
# 2-D input: 0..11 viewed as 3x4.
6926 x = torch.arange(0, 12).view(3, 4)
6932 [Ellipsis,
None,
None],
6934 [
None, Ellipsis,
None],
6935 [
None,
None, Ellipsis],
6937 [2,
None, Ellipsis],
6938 [2, Ellipsis,
None],
6939 [
None, 2, Ellipsis],
6940 [Ellipsis, 2,
None],
6941 [Ellipsis,
None, 2],
6942 [
None, Ellipsis, 2],
6944 [1, 2, Ellipsis,
None],
6945 [1, Ellipsis, 2,
None],
6946 [Ellipsis, 1,
None, 2],
6947 [Ellipsis, 1, 2,
None],
6948 [1,
None, 2, Ellipsis],
6949 [
None, 1, Ellipsis, 2],
6950 [
None, 1, 2, Ellipsis],
6956 def test_newindex(self):
# In-place index assignment (__setitem__) on a 3x3x3 tensor. First, partial
# assignments are written and then reset. Then invalid index forms must raise.
6957 reference = self._consecutive((3, 3, 3))
# Copy the indexed part of a fresh consecutive tensor into zeros and verify
# it. Zero it again and verify the whole tensor is back to zeros. The local
# `reference` shadows the outer one.
6960 def checkPartialAssign(index):
6961 reference = torch.zeros(3, 3, 3)
6962 reference[index] = self._consecutive((3, 3, 3))[index]
6963 self.assertEqual(reference[index], self._consecutive((3, 3, 3))[index], 0)
6964 reference[index] = 0
6965 self.assertEqual(reference, torch.zeros(3, 3, 3), 0)
# Integer, tuple and LongTensor indices.
6967 checkPartialAssign(0)
6968 checkPartialAssign(1)
6969 checkPartialAssign(2)
6970 checkPartialAssign((0, 1))
6971 checkPartialAssign((1, 2))
6972 checkPartialAssign((0, 2))
6973 checkPartialAssign(torch.LongTensor((0, 2)))
# Too many indices must raise IndexError.
6975 with self.assertRaises(IndexError):
6976 reference[1, 1, 1, 1] = 1
6977 with self.assertRaises(IndexError):
6978 reference[1, 1, 1, (1, 1)] = 1
6979 with self.assertRaises(IndexError):
6980 reference[3, 3, 3, 3, 3, 3, 3, 3] = 1
# NOTE(review): the `with` below has no body in the visible lines (6982 is
# missing). Restore it from VCS.
6981 with self.assertRaises(IndexError):
# Float slices/indices are rejected: TypeError for a bare float slice,
# IndexError for the multi-dimensional forms.
6983 with self.assertRaises(TypeError):
6984 reference[0.0:2.0] = 1
6985 with self.assertRaises(IndexError):
6986 reference[0.0, 0.0:2.0] = 1
6987 with self.assertRaises(IndexError):
6988 reference[0.0, :, 0.0:2.0] = 1
6989 with self.assertRaises(IndexError):
6990 reference[0.0, ..., 0.0:2.0] = 1
6991 with self.assertRaises(IndexError):
6992 reference[0.0, :, 0.0] = 1
6994 def test_index_copy(self):
6995 num_copy, num_dest = 3, 20
6996 dest = torch.randn(num_dest, 4, 5)
6997 src = torch.randn(num_copy, 4, 5)
6998 idx = torch.randperm(num_dest).narrow(0, 0, num_copy)
6999 dest2 = dest.clone()
7000 dest.index_copy_(0, idx, src)
7001 for i
in range(idx.size(0)):
7002 dest2[idx[i]] = src[i]
7003 self.assertEqual(dest, dest2, 0)
7005 dest = torch.randn(num_dest)
7006 src = torch.randn(num_copy)
7007 idx = torch.randperm(num_dest).narrow(0, 0, num_copy)
7008 dest2 = dest.clone()
7009 dest.index_copy_(0, idx, src)
7010 for i
in range(idx.size(0)):
7011 dest2[idx[i]] = src[i]
7012 self.assertEqual(dest, dest2, 0)
7014 def test_index_add(self):
7015 num_copy, num_dest = 3, 3
7016 dest = torch.randn(num_dest, 4, 5)
7017 src = torch.randn(num_copy, 4, 5)
7018 idx = torch.randperm(num_dest).narrow(0, 0, num_copy)
7019 dest2 = dest.clone()
7020 dest.index_add_(0, idx, src)
7021 for i
in range(idx.size(0)):
7022 dest2[idx[i]] += src[i]
7023 self.assertEqual(dest, dest2)
7025 dest = torch.randn(num_dest)
7026 src = torch.randn(num_copy)
7027 idx = torch.randperm(num_dest).narrow(0, 0, num_copy)
7028 dest2 = dest.clone()
7029 dest.index_add_(0, idx, src)
7030 for i
in range(idx.size(0)):
7031 dest2[idx[i]] = dest2[idx[i]] + src[i]
7032 self.assertEqual(dest, dest2)
7034 def test_index_select(self):
7035 src = torch.randn(3, 4, 5)
7037 idx = torch.LongTensor([2, 1, 0, 1, 2])
7038 dest = torch.index_select(src, 0, idx)
7039 self.assertEqual(dest.shape, (5, 4, 5))
7040 for i
in range(idx.size(0)):
7041 self.assertEqual(dest[i], src[idx[i]])
7044 out = torch.randn(5 * 4 * 5)
7045 dest = torch.index_select(src, 0, idx, out=out.view(5, 4, 5))
7046 self.assertEqual(dest.shape, (5, 4, 5))
7047 for i
in range(idx.size(0)):
7048 self.assertEqual(dest[i], src[idx[i]])
7050 self.assertEqual(out, dest.view(-1))
# NOTE(review): fragment of a transpose (`t()`) test whose enclosing `def`
# line is not present here. The first four assertions use an `x` that is not
# assigned in these lines, and neither `with` block at the end has a visible
# body. Restore the missing lines from VCS.
7055 self.assertEqual(x, x.t())
7057 self.assertEqual(x, x.t())
7061 self.assertEqual(x, x.t())
7063 self.assertEqual(x, x.t())
# 2-D: t() must equal transpose(0, 1).
7066 x = torch.rand((2, 2))
7067 self.assertEqual(x.t(), x.transpose(0, 1))
7069 self.assertEqual(x.t(), x.transpose(0, 1))
# 3-D input: t() is expected to raise RuntimeError with the messages below.
7072 x = torch.rand((2, 2, 2))
7073 with self.assertRaisesRegex(RuntimeError,
'expects a tensor with <= 2 dimensions, but self is 3D'):
7076 with self.assertRaisesRegex(RuntimeError,
'expects a tensor with <= 2 sparse and 0 dense dimensions'):
7079 def test_take(self):
7080 def check(src, idx):
7081 expected = src.contiguous().view(-1).index_select(
7082 0, idx.contiguous().view(-1)).view_as(idx)
7083 actual = src.take(idx)
7084 self.assertEqual(actual.size(), idx.size())
7085 self.assertEqual(expected, actual)
7087 src = torch.randn(2, 3, 5)
7088 idx = torch.LongTensor([[0, 2], [3, 4]])
7090 check(src.transpose(1, 2), idx)
7092 def test_take_empty(self):
7094 for device
in devices:
7095 for input_shape
in [(0,), (0, 1, 2, 0), (1, 2, 3)]:
7096 for indices_shape
in [(0,), (0, 1, 2, 0)]:
7097 input = torch.empty(input_shape, device=device)
7098 indices = torch.empty(indices_shape, dtype=torch.int64, device=device)
7099 self.assertEqual(indices, torch.take(input, indices))
7101 def test_put_(self):
7102 def check(dst, idx, value):
7103 expected = dst.clone().view(-1).index_copy_(
7104 0, idx.contiguous().view(-1), value.contiguous().view(-1))
7105 expected = expected.view_as(dst)
7106 dst.put_(idx, value)
7107 self.assertEqual(expected, dst)
7109 dst = torch.randn(2, 3, 5)
7110 idx = torch.LongTensor([[0, 2], [3, 4]])
7111 values = torch.randn(2, 2)
7112 check(dst, idx, values)
7113 check(dst.transpose(1, 2), idx, values)
7115 def test_put_accumulate(self):
7116 dst = torch.ones(2, 2)
7117 idx = torch.LongTensor([[0, 1], [0, 1]])
7118 src = torch.Tensor([1, 2, 3, 4])
7119 dst.put_(idx, src, accumulate=
True)
7120 self.assertEqual(dst.tolist(), [[5, 7], [1, 1]])
7122 def test_put_empty(self):
7124 for device
in devices:
7125 for dst_shape
in [(0,), (0, 1, 2, 0), (1, 2, 3)]:
7126 for indices_shape
in [(0,), (0, 1, 2, 0)]:
7127 for accumulate
in [
False,
True]:
7128 dst = torch.randn(dst_shape, device=device)
7129 indices = torch.empty(indices_shape, dtype=torch.int64, device=device)
7130 src = torch.randn(indices_shape, device=device)
7131 self.assertEqual(dst, dst.put_(indices, src, accumulate=accumulate))
7135 def _fill_indices(self, idx, dim, dim_size, elems_per_row, m, n, o):
7136 for i
in range(1
if dim == 0
else m):
7137 for j
in range(1
if dim == 1
else n):
7138 for k
in range(1
if dim == 2
else o):
7140 ii[dim] = slice(0, idx.size(dim) + 1)
7141 idx[tuple(ii)] = torch.randperm(dim_size)[0:elems_per_row]
7143 def test_flatten(self):
7144 src = torch.randn(5, 5, 5, 5)
7145 flat = src.flatten(0, -1)
7146 self.assertEqual(flat.shape, torch.Size([625]))
7147 self.assertEqual(src.view(-1), flat.view(-1))
7149 flat = src.flatten(0, 2)
7150 self.assertEqual(flat.shape, torch.Size([125, 5]))
7151 self.assertEqual(src.view(-1), flat.view(-1))
7153 flat = src.flatten(0, 1)
7154 self.assertEqual(flat.shape, torch.Size([25, 5, 5]))
7155 self.assertEqual(src.view(-1), flat.view(-1))
7157 flat = src.flatten(1, 2)
7158 self.assertEqual(flat.shape, torch.Size([5, 25, 5]))
7159 self.assertEqual(src.view(-1), flat.view(-1))
7161 flat = src.flatten(2, 3)
7162 self.assertEqual(flat.shape, torch.Size([5, 5, 25]))
7163 self.assertEqual(src.view(-1), flat.view(-1))
7165 flat = src.flatten(-2, -1)
7166 self.assertEqual(flat.shape, torch.Size([5, 5, 25]))
7167 self.assertEqual(src.view(-1), flat.view(-1))
7169 flat = src.flatten(2, 2)
7170 self.assertEqual(flat, src)
7173 with self.assertRaisesRegex(IndexError,
'Dimension out of range'):
7177 with self.assertRaisesRegex(RuntimeError,
'start_dim cannot come after end_dim'):
7181 def _test_gather(self, cast, test_bounds=True):
# Randomized check of torch.gather on a 3-D tensor against an element-wise
# reference loop, followed by a gather along the indices returned by max(2).
# cast: hook applied to created tensors (the identity in test_gather).
# test_bounds: NOTE(review) this flag is not read anywhere in the visible lines.
7182 m, n, o = random.randint(10, 20), random.randint(10, 20), random.randint(10, 20)
7183 elems_per_row = random.randint(1, 10)
7184 dim = random.randrange(3)
7186 src = torch.randn(m, n, o)
# Index tensor: the same shape as src except `elems_per_row` entries along
# `dim`, filled with distinct in-range indices by _fill_indices.
7187 idx_size = [m, n, o]
7188 idx_size[dim] = elems_per_row
7189 idx = torch.LongTensor().resize_(*idx_size)
7190 _TestTorchMixin._fill_indices(self, idx, dim, src.size(dim), elems_per_row, m, n, o)
# NOTE(review): `cast` is not applied to src/idx before this gather in the
# visible lines (7191-7194 are missing).
7195 actual = torch.gather(src, dim, idx)
7196 expected = cast(torch.Tensor().resize_(*idx_size))
7197 for i
in range(idx_size[0]):
7198 for j
in range(idx_size[1]):
7199 for k
in range(idx_size[2]):
# NOTE(review): `ii` is not assigned in the visible lines; presumably
# `ii = [i, j, k]` (this output element's position) was lost.
7201 ii[dim] = idx[i, j, k]
7202 expected[i, j, k] = src[tuple(ii)]
7203 self.assertEqual(actual, expected, 0)
# NOTE(review): idx still holds the in-range indices from _fill_indices here,
# so the expected RuntimeError presumably relies on a lost statement that
# corrupts idx (guarded by `test_bounds`).
7207 self.assertRaises(RuntimeError,
lambda: torch.gather(src, dim, idx))
# Gathering along dim 2 with the indices from src.max(2, True) must reproduce
# the max values.
7209 src = cast(torch.randn(3, 4, 5))
7210 expected, idx = src.max(2,
True)
7211 expected = cast(expected)
7213 actual = torch.gather(src, 2, idx)
7214 self.assertEqual(actual, expected, 0)
7216 def test_gather(self):
7217 self._test_gather(self,
lambda t: t)
7220 def _test_scatter_base(self, cast, method, is_scalar=False, test_bounds=True):
# Randomized check of an in-place scatter method (`method` is 'scatter_' or
# 'scatter_add_') on a 3-D tensor against an element-wise reference loop.
# cast: hook applied to created tensors. is_scalar: use a Python float
# source instead of a tensor. test_bounds: NOTE(review) not read in the
# visible lines.
7221 m, n, o = random.randint(10, 20), random.randint(10, 20), random.randint(10, 20)
7222 elems_per_row = random.randint(1, 10)
7223 dim = random.randrange(3)
7225 idx_size = [m, n, o]
7226 idx_size[dim] = elems_per_row
7227 idx = cast(torch.LongTensor().resize_(*idx_size))
7228 _TestTorchMixin._fill_indices(self, idx, dim, ([m, n, o])[dim], elems_per_row, m, n, o)
# NOTE(review): both assignments to `src` below are unconditional in the
# visible lines. Presumably an `if is_scalar:` / `else:` pair was lost;
# as visible, the scalar source is always overwritten by a tensor.
7231 src = random.random()
7233 src = cast(torch.Tensor(*idx_size).normal_())
7235 base = cast(torch.randn(m, n, o))
7236 actual = getattr(base.clone(), method)(dim, idx, src)
7237 expected = base.clone()
7238 for i
in range(idx_size[0]):
7239 for j
in range(idx_size[1]):
7240 for k
in range(idx_size[2]):
# NOTE(review): `ii` is not assigned in the visible lines; presumably
# `ii = [i, j, k]` was lost.
7242 ii[dim] = idx[i, j, k]
7243 if method ==
'scatter_' and not is_scalar:
7244 expected[tuple(ii)] = src[i, j, k]
7245 elif method ==
'scatter_add_':
7246 expected[tuple(ii)] += src[i, j, k]
# NOTE(review): no `else:` is visible before this line; presumably it is the
# scalar-scatter_ fallback.
7248 expected[tuple(ii)] = src
7249 self.assertEqual(actual, expected, 0)
# NOTE(review): idx is unchanged since the successful call above, so the
# expected RuntimeError presumably relies on a lost statement that corrupts
# idx (guarded by `test_bounds`).
7253 with self.assertRaises(RuntimeError):
7254 getattr(base.clone(), method)(dim, idx, src)
# An empty index tensor must leave the destination unchanged.
7257 idx = cast(torch.LongTensor())
7258 actual = getattr(base.clone(), method)(dim, idx, src)
7259 self.assertEqual(actual, base, 0)
7261 def test_scatter(self):
7262 self._test_scatter_base(self,
lambda t: t,
'scatter_')
7264 def test_scatterAdd(self):
7265 self._test_scatter_base(self,
lambda t: t,
'scatter_add_')
7267 def test_scatterFill(self):
7268 self._test_scatter_base(self,
lambda t: t,
'scatter_',
True)
7270 def test_masked_scatter(self):
# masked_scatter_ copies consecutive source elements into the positions where
# the mask is set. The mask below has exactly num_copy (3) set entries.
7271 num_copy, num_dest = 3, 10
7272 dest = torch.randn(num_dest)
7273 src = torch.randn(num_copy)
7274 mask = torch.ByteTensor((0, 0, 0, 0, 1, 0, 1, 0, 1, 0))
7275 dest2 = dest.clone()
7276 dest.masked_scatter_(mask, src)
# NOTE(review): this loop has no visible body (7279-7281 missing), so as
# visible `dest2` is never updated before the comparison below.
7278 for i
in range(num_dest):
7282 self.assertEqual(dest, dest2, 0)
# A source with more elements than set mask entries is accepted (only checks
# that the call does not raise).
7285 src = torch.randn(num_dest)
7286 dest.masked_scatter_(mask, src)
# A source with fewer elements than set mask entries must raise.
7289 src = torch.randn(num_copy - 1)
7290 with self.assertRaises(RuntimeError):
7291 dest.masked_scatter_(mask, src)
7293 def test_masked_select(self):
# masked_select with a random 0/1 byte mask must match collecting the
# selected elements by hand.
# NOTE(review): `num_src` and `dst2` are not defined in the visible lines,
# and the loop below has no visible body (7300-7301 missing).
7295 src = torch.randn(num_src)
# rand() in [0, 1), times 2 and floored, gives a 0/1 mask.
7296 mask = torch.rand(num_src).clamp(0, 1).mul(2).floor().byte()
7297 dst = src.masked_select(mask)
7299 for i
in range(num_src):
7302 self.assertEqual(dst, torch.Tensor(dst2), 0)
7304 def test_masked_fill(self):
# masked_fill_ with a random 0/1 byte mask must match filling by hand. It is
# then checked on a permuted (non-contiguous) 3-D tensor.
# NOTE(review): `num_dest` and `dst2` are not defined in the visible lines,
# and the loop below has no visible body (7312-7313 missing).
7306 dst = torch.randn(num_dest)
7307 mask = torch.rand(num_dest).mul(2).floor().byte()
7308 val = random.random()
7310 dst.masked_fill_(mask, val)
7311 for i
in range(num_dest):
7314 self.assertEqual(dst, dst2, 0)
# Non-contiguous destination via permute; `dst2` here presumably comes from a
# missing line (7318), e.g. a copy of dst. Confirm against VCS.
7317 dst = torch.randn(num_dest, num_dest, num_dest).permute((2, 0, 1))
7319 dst.masked_fill_(dst > 0, val)
7320 dst2.masked_fill_(dst2 > 0, val)
7321 self.assertEqual(dst, dst2, 0)
# NOTE(review): these lines use `self` but no enclosing method header is
# present here. They look like the body of an abs() test whose `def` line was
# lost; restore it from VCS.
7324 def _test_abs(tensors_dict):
# Run the single-tensor check on every tensor in every category of the dict.
7325 for _category, tensors
in tensors_dict.items():
7326 for data
in tensors:
7327 _test_abs_single(data)
7329 def _test_abs_single(data):
# Multiply by a random sign in {-1, +1} (rand -> {0, 1} -> {0, 2} -> {-1, 1}).
# abs() of the result must give back `data`, which the callers build as
# non-negative.
7330 switch = torch.rand(data.size()).mul(2).floor().mul(2).add(-1).type(data.dtype)
7331 res = torch.mul(data, switch)
7332 self.assertTensorsSlowEqual(res.abs(), data, 1e-16)
7334 shapes = [(3, 4), (3, 5, 7), (2, 2, 5, 8, 2, 3), (1000,), (10, 10, 10)]
7336 for shape
in shapes:
# Float/int tensors from _make_tensors with values in [0, 1000].
7338 _test_abs(self._make_tensors(shape, val_range=(0, 1000)))
7341 _test_abs_single(torch.CharTensor(*shape).random_(0, 100))
# Byte tensors are unsigned, so abs() must be the identity.
7344 byte_tensor = torch.ByteTensor(*shape).random_(0, 100)
7345 self.assertTensorsSlowEqual(byte_tensor, byte_tensor.abs(), 1e-16)
# NOTE(review): `2 ^ 31 + 1` is bitwise XOR, and `+` binds tighter, so the
# value is 2 ^ 32 == 34, not 2**31 + 1. If a value beyond the int32 range was
# intended for this LongTensor check, this should be `2 ** 31 + 1`.
7348 bignumber = 2 ^ 31 + 1
7349 res = torch.LongTensor((-bignumber,))
7350 self.assertGreater(res.abs()[0], 0)
# Values clamped to [0, 1] are non-negative, so abs() on a selected slice
# must not change its sum.
7353 rec = torch.randn(2, 2, 3, 7, 6, 2).type(torch.float64).clamp(0, 1)
7354 val1 = rec.select(-1, -1).data[0][0][0].sum()
7355 val2 = rec.select(-1, -1).data.abs()[0][0][0].sum()
7356 self.assertEqual(val1, val2, 1e-8,
'absolute value')
7358 def test_namedtuple_return(self):
# Reductions and decompositions return named tuples. Each named field must
# equal the matching positional entry, both for plain calls and for calls
# that write into `out=`.
7359 a = torch.randn(5, 5)
# values/indices reductions along dim 0.
7362 for f
in [
'max',
'min',
'median',
'mode']:
7363 ret = getattr(a, f)(dim=0)
7364 self.assertEqual(ret.values, ret[0])
7365 self.assertEqual(ret.indices, ret[1])
7366 ret1 = getattr(torch, f)(a, dim=0, out=ret)
7367 self.assertEqual(ret1.values, ret1[0])
7368 self.assertEqual(ret1.indices, ret1[1])
7369 self.assertEqual(ret1.values, ret[0])
7370 self.assertEqual(ret1.indices, ret[1])
# kthvalue takes an extra k argument, so it is checked separately.
7373 ret = a.kthvalue(1, dim=0)
7374 self.assertEqual(ret.values, ret[0])
7375 self.assertEqual(ret.indices, ret[1])
7376 ret1 = torch.kthvalue(a, 1, dim=0, out=ret)
7377 self.assertEqual(ret1.values, ret1[0])
7378 self.assertEqual(ret1.indices, ret1[1])
7379 self.assertEqual(ret1.values, ret[0])
7380 self.assertEqual(ret1.indices, ret[1])
# NOTE(review): no SVD call assigning `ret` is visible before this point
# (7381-7383 missing). As visible, `ret` is still the kthvalue result.
7384 self.assertEqual(ret.U, ret[0])
7385 self.assertEqual(ret.S, ret[1])
7386 self.assertEqual(ret.V, ret[2])
7387 ret1 = torch.svd(a, out=ret)
7388 self.assertEqual(ret1.U, ret1[0])
7389 self.assertEqual(ret1.S, ret1[1])
7390 self.assertEqual(ret1.V, ret1[2])
7391 self.assertEqual(ret1.U, ret[0])
7392 self.assertEqual(ret1.S, ret[1])
7393 self.assertEqual(ret1.V, ret[2])
# NOTE(review): `fn` is never iterated in the visible lines (7397 missing).
# As visible, `f` below is the last value of the earlier loop ('mode').
7396 fn = [
'symeig',
'eig']
7398 ret = getattr(torch, f)(a, eigenvectors=
True)
7399 self.assertEqual(ret.eigenvalues, ret[0])
7400 self.assertEqual(ret.eigenvectors, ret[1])
7401 ret1 = getattr(torch, f)(a, out=tuple(ret))
7402 self.assertEqual(ret1.eigenvalues, ret[0])
7403 self.assertEqual(ret1.eigenvectors, ret[1])
7404 self.assertEqual(ret1.eigenvalues, ret1[0])
7405 self.assertEqual(ret1.eigenvectors, ret1[1])
# b = a @ a.T with a tiny bump on the diagonal, presumably so the pstrf input
# is positive definite.
7408 b = torch.mm(a, a.t())
7410 for i
in range(a.size(0)):
7411 b[i][i] = b[i][i] + 1e-7
# NOTE(review): the pstrf call producing `ret` is not visible (7412 missing).
7413 self.assertEqual(ret.u, ret[0])
7414 self.assertEqual(ret.pivot, ret[1])
7415 ret1 = torch.pstrf(b, out=tuple(ret))
7416 self.assertEqual(ret1.u, ret1[0])
7417 self.assertEqual(ret1.pivot, ret1[1])
7418 self.assertEqual(ret1.u, ret[0])
7419 self.assertEqual(ret1.pivot, ret[1])
# NOTE(review): the qr call producing `ret` is not visible (7420-7422 missing).
7423 self.assertEqual(ret.Q, ret[0])
7424 self.assertEqual(ret.R, ret[1])
7425 ret1 = torch.qr(a, out=tuple(ret))
7426 self.assertEqual(ret1.Q, ret1[0])
7427 self.assertEqual(ret1.R, ret1[1])
# NOTE(review): the geqrf call producing `ret` is not visible (7428-7430 missing).
7431 self.assertEqual(ret.a, ret[0])
7432 self.assertEqual(ret.tau, ret[1])
7433 ret1 = torch.geqrf(a, out=tuple(ret))
7434 self.assertEqual(ret1.a, ret1[0])
7435 self.assertEqual(ret1.tau, ret1[1])
7437 def test_hardshrink(self):
# hardshrink zeroes entries whose magnitude is <= lambda. Per the expected
# tensors, 0.3 -> 0 at lambda 0.3 and 0.5 -> 0 at lambda 0.5. The checks run
# for each float type, the default lambda must behave like 0.5, and a
# transposed input is also checked.
7438 data_original =
torch.tensor([1, 0.5, 0.3, 0.6]).view(2, 2)
# NOTE(review): the `float_types` list is only partly visible. Its opening
# line (7439) and other entries are missing; 'torch.DoubleTensor' is the
# only element shown.
7440 'torch.DoubleTensor',
7443 for t
in float_types:
7444 data = data_original.type(t)
7445 self.assertEqual(
torch.tensor([1, 0.5, 0, 0.6]).view(2, 2), data.hardshrink(0.3))
7446 self.assertEqual(
torch.tensor([1, 0, 0, 0.6]).view(2, 2), data.hardshrink(0.5))
# Default lambda.
7449 self.assertEqual(data.hardshrink(), data.hardshrink(0.5))
# Transposed input: [[1, 0.3], [0.5, 0.6]] -> [[1, 0], [0.5, 0.6]].
7452 self.assertEqual(
torch.tensor([1, 0, 0.5, 0.6]).view(2, 2), data.t().hardshrink(0.3))
7454 def test_unbiased(self):
7455 tensor = torch.randn(100)
7456 self.assertEqual(tensor.var(0), tensor.var(0, unbiased=
True))
7457 self.assertEqual(tensor.var(), tensor.var(unbiased=
True))
7458 self.assertEqual(tensor.var(unbiased=
False), tensor.var(0, unbiased=
False))
7460 tensor = torch.FloatTensor([1.0, 2.0])
7461 self.assertEqual(tensor.var(unbiased=
True), 0.5)
7462 self.assertEqual(tensor.var(unbiased=
False), 0.25)
7464 tensor = torch.FloatTensor([1.0, 2.0, 3.0])
7465 self.assertEqual(tensor.var(unbiased=
True), 1.0)
7466 self.assertEqual(tensor.var(unbiased=
False), 2.0 / 3.0)
7468 tensor = torch.randn(100)
7469 self.assertEqual(tensor.std(0), tensor.std(0, unbiased=
True))
7470 self.assertEqual(tensor.std(), tensor.std(unbiased=
True))
7471 self.assertEqual(tensor.std(unbiased=
False), tensor.std(0, unbiased=
False))
7473 def test_structseq_repr(self):
7474 a = torch.arange(250).reshape(5, 5, 10)
7476 torch.return_types.max( 7477 values=tensor([[ 40, 41, 42, 43, 44, 45, 46, 47, 48, 49], 7478 [ 90, 91, 92, 93, 94, 95, 96, 97, 98, 99], 7479 [140, 141, 142, 143, 144, 145, 146, 147, 148, 149], 7480 [190, 191, 192, 193, 194, 195, 196, 197, 198, 199], 7481 [240, 241, 242, 243, 244, 245, 246, 247, 248, 249]]), 7482 indices=tensor([[4, 4, 4, 4, 4, 4, 4, 4, 4, 4], 7483 [4, 4, 4, 4, 4, 4, 4, 4, 4, 4], 7484 [4, 4, 4, 4, 4, 4, 4, 4, 4, 4], 7485 [4, 4, 4, 4, 4, 4, 4, 4, 4, 4], 7486 [4, 4, 4, 4, 4, 4, 4, 4, 4, 4]]))""" 7487 self.assertEqual(repr(a.max(1)), textwrap.dedent(expected).strip())
7489 def test_var_stability(self):
7490 tensor = torch.FloatTensor([2281.5, 2281.25])
7491 self.assertEqual(tensor.var(dim=0), 0.03125)
7492 self.assertEqual(tensor.var(), 0.03125)
7495 def _test_view(self, cast):
# Shared checks for Tensor.view / view_as: shape inference with -1,
# zero-element tensors, invalid views, and views of non-contiguous and
# expanded tensors compared with their clones.
# cast: hook applied to every freshly created tensor (the identity in test_view).
7496 tensor = cast(torch.rand(15))
7497 template = cast(torch.rand(3, 5))
7498 empty = cast(torch.empty(0))
7499 target = template.size()
# Every spelling of a 3x5 view of the 15-element tensor gives the same size.
7500 self.assertEqual(tensor.view_as(template).size(), target)
7501 self.assertEqual(tensor.view(3, 5).size(), target)
7502 self.assertEqual(tensor.view(torch.Size([3, 5])).size(), target)
7503 self.assertEqual(tensor.view(-1, 5).size(), target)
7504 self.assertEqual(tensor.view(3, -1).size(), target)
7505 tensor_view = tensor.view(5, 3)
7506 tensor_view.fill_(random.uniform(0, 1))
# Zero-element tensors can be viewed into any shape with zero elements.
7507 self.assertEqual(empty.view_as(empty), empty)
7508 self.assertEqual(empty.view(0), empty)
7509 self.assertEqual(empty.view(0, 3, 0, 1).size(), torch.Size([0, 3, 0, 1]))
7510 self.assertEqual(empty.view(0, 3, 0, 1).view(0), empty)
7513 self.assertEqual(empty.view(-1).size(), torch.Size([0]))
7514 self.assertEqual(empty.view(10, 3, -1).size(), torch.Size([10, 3, 0]))
# NOTE(review): the first `with` below has no visible body (7517-7518 missing).
7516 with self.assertRaisesRegex(RuntimeError,
r"because the unspecified dimension size -1 can be any value"):
7519 with self.assertRaisesRegex(RuntimeError,
r"because the unspecified dimension size -1 can be any value"):
7520 empty.view(3, 0, -1, 0)
# Invalid views of the 15-element tensor: wrong element count, 15 not
# divisible by 7, and two inferred (-1) dimensions.
7522 self.assertRaises(RuntimeError,
lambda: tensor.view(15, 0))
7523 self.assertRaises(RuntimeError,
lambda: tensor.view(7, -1))
7524 self.assertRaises(RuntimeError,
lambda: tensor.view(15, -1, -1))
# Non-contiguous tensor built by two transposes. Each view below must equal
# the same view of its clone.
7528 tensor = cast(torch.rand(4, 2, 5, 1, 6, 2, 9, 3)).transpose(-1, 2).transpose(-2, 3)
# NOTE(review): where clone() defaults to memory_format=torch.preserve_format,
# `contig_tensor` keeps the transposed strides and is not contiguous. Confirm
# the views on it are valid for the targeted PyTorch version; `.contiguous()`
# would make the intent explicit.
7533 contig_tensor = tensor.clone()
7539 view_size = [8, 1, 3, 3, 3, 4, 1, 3, 5]
7540 self.assertEqual(tensor.view(*view_size), contig_tensor.view(*view_size))
7546 view_size = [2, 4, 3, 1, 9, 2, 2, 3, 5, 1]
7547 self.assertEqual(tensor.view(*view_size), contig_tensor.view(*view_size))
7549 view_size = [1, 1, 2, 1, 4, 3, 1, 1, 9, 1, 2, 1, 2, 3, 1, 5, 1, 1]
7550 self.assertEqual(tensor.view(*view_size), contig_tensor.view(*view_size))
# These views of the transposed tensor are expected to be incompatible with
# its strides and must raise.
7553 self.assertRaises(RuntimeError,
lambda: tensor.view(-1))
7555 self.assertRaises(RuntimeError,
lambda: tensor.view(24, 9, 6, 2, 1, 5))
7557 self.assertRaises(RuntimeError,
lambda: tensor.view(8, 3, 9, 6, 10))
7559 self.assertRaises(RuntimeError,
lambda: tensor.view(8, 3, 54, 2, 1, 5))
# A 1x1 tensor expanded to 3x4 must support views that match its clone.
7562 tensor = cast(torch.empty(1, 1)).expand(3, 4)
7563 contig_tensor = tensor.clone()
7564 self.assertEqual(tensor.view(-1), contig_tensor.view(-1))
7565 self.assertEqual(tensor.view(1, -1, 1), contig_tensor.view(1, -1, 1))
7566 self.assertEqual(tensor.view(-1, 1), contig_tensor.view(-1, 1))
7567 self.assertEqual(tensor.view(6, 2, 1), contig_tensor.view(6, 2, 1))
7568 self.assertEqual(tensor.view(1, 6, 2, 1), contig_tensor.view(1, 6, 2, 1))
def test_view(self):
    """Run the shared view checks on unmodified CPU tensors."""
    def _identity(t):
        return t

    _TestTorchMixin._test_view(self, _identity)
def test_view_empty(self):
    """A (0, 6) tensor can be viewed with extra singleton dims around it."""
    zero_rows = torch.randn(0, 6)
    viewed = zero_rows.view(1, 0, 6, 1, 1)
    self.assertEqual((1, 0, 6, 1, 1), viewed.shape)
def test_reshape(self):
    """reshape/reshape_as alias the input when possible and copy otherwise."""
    x = torch.randn(3, 3)
    # Contiguous input: the result shares the data pointer.
    self.assertEqual(x.data_ptr(), x.reshape(-1).data_ptr())
    self.assertEqual(x.data_ptr(), x.reshape(1, 9, 1).data_ptr())
    self.assertEqual(torch.reshape(x, (9,)), x.reshape(9))
    self.assertRaises(RuntimeError, lambda: x.reshape(-1, -1))

    # Non-contiguous input: flattening copies, but a shape compatible with
    # the existing strides is still a view.
    y = torch.randn(4, 4, 4)[:, 0, :]
    self.assertNotEqual(y.data_ptr(), y.reshape(-1).data_ptr())
    self.assertEqual(y.contiguous().view(-1), y.reshape(-1))
    self.assertEqual(y.reshape(2, 2, 4).data_ptr(), y.data_ptr())

    # 0-dim input.
    s = torch.randn(())
    self.assertEqual(s.data_ptr(), s.reshape(()).data_ptr())
    self.assertEqual(s.reshape(-1).shape, (1,))
    self.assertRaises(RuntimeError, lambda: s.reshape(2))

    # 1-D input with zero elements.
    empty = torch.tensor([])
    self.assertEqual(empty, empty.reshape(-1))
    self.assertEqual(empty, empty.reshape([0]))
    self.assertEqual(empty.reshape([0, 1]).shape, (0, 1))
    self.assertEqual(empty.reshape([1, -1]).shape, (1, 0))
    self.assertRaises(RuntimeError, lambda: empty.reshape(1))

    # reshape_as takes the target shape from another tensor.
    x = torch.randn(3, 3)
    self.assertEqual(x.data_ptr(), x.reshape_as(torch.rand(9)).data_ptr())
    self.assertEqual(x.data_ptr(), x.reshape_as(torch.rand(1, 9, 1)).data_ptr())
    self.assertRaises(RuntimeError, lambda: x.reshape_as(torch.rand(10)))
def test_empty_reshape(self):
    """Reshaping a zero-element tensor aliases it and rejects an undetermined -1."""
    src = torch.randn(0, 6)
    new_shape = (1, 0, 6, 1, 1)
    self.assertEqual(new_shape, src.reshape(*new_shape).shape)
    # The reshape should be a view: same data pointer.
    self.assertEqual(src.data_ptr(), src.reshape(*new_shape).data_ptr())

    # With a 0 in the target shape, the size behind -1 is not determined.
    self.assertRaises(RuntimeError, lambda: src.reshape(0, -1))
def test_tensor_shape_empty(self):
    """Shape-manipulating ops give the expected shapes on zero-element tensors."""
    devices = ['cpu'] if not torch.cuda.is_available() else ['cpu', 'cuda']
    for device in devices:
        x = torch.randn((0, 1, 3, 0), device=device)

        # flatten
        self.assertEqual((0,), torch.flatten(x, 0, 3).shape)
        self.assertEqual((0, 0), torch.flatten(x, 0, 2).shape)
        self.assertEqual((0, 3, 0), torch.flatten(x, 1, 2).shape)

        # unsqueeze, squeeze
        self.assertEqual((0, 1, 1, 3, 0), torch.unsqueeze(x, 1).shape)
        self.assertEqual((0, 3, 0), torch.squeeze(x, 1).shape)
        self.assertEqual((0, 3, 0), torch.squeeze(x).shape)

        # transpose, t
        self.assertEqual((0, 0, 3, 1), torch.transpose(x, 1, 3).shape)
        y = torch.randn((5, 0), device=device)
        self.assertEqual((0, 5), y.t().shape)

        # select
        self.assertEqual((0, 1, 0), torch.select(x, 2, 2).shape)

        # unfold
        self.assertEqual((0, 1, 1, 0, 3), x.unfold(2, 3, 2).shape)
        y = torch.randn((0, 1, 3), device=device)
        self.assertEqual((1, 1, 3, 0), y.unfold(0, 0, 4).shape)

        # repeat, permute
        self.assertEqual((9, 0, 5, 6, 0), x.repeat(9, 7, 5, 2, 3).shape)
        self.assertEqual((3, 0, 0, 1), x.permute(2, 3, 0, 1).shape)

        # diagonal
        self.assertEqual((0,), torch.diagonal(torch.randn((5, 0), device=device)).shape)
        self.assertEqual((0,), torch.diagonal(torch.randn((0, 5), device=device)).shape)
        self.assertEqual((0,), torch.diagonal(torch.randn((5, 0), device=device), offset=1).shape)
        self.assertEqual((0,), torch.diagonal(torch.randn((0, 5), device=device), offset=1).shape)
        # An offset beyond the matrix yields an empty diagonal.
        self.assertEqual((5, 6, 0), torch.diagonal(torch.randn((3, 4, 5, 6), device=device), offset=45252).shape)
        self.assertEqual((5, 6, 0), torch.diagonal(torch.randn((3, 4, 5, 6), device=device), offset=-45252).shape)

        # diagflat (expected values live on the same device as the result)
        self.assertEqual((0, 0), torch.diagflat(torch.tensor([], device=device)).shape)
        self.assertEqual(torch.zeros(1, 1, device=device),
                         torch.diagflat(torch.tensor([], device=device), offset=1))
        self.assertEqual((0, 0), torch.diagflat(torch.tensor([[]], device=device)).shape)
        self.assertEqual(torch.zeros(1, 1, device=device),
                         torch.diagflat(torch.tensor([[]], device=device), offset=1))

        # stack, chunk
        self.assertEqual((4, 0, 1, 3, 0), torch.stack((x, x, x, x)).shape)
        self.assertEqual([(0, 1, 3, 0)],
                         [z.shape for z in torch.chunk(x, 1, dim=0)])
        self.assertEqual([(0, 1, 3, 0), ] * 3, [z.shape for z in torch.chunk(x, 3, dim=0)])
        self.assertEqual([(0, 1, 1, 0), ] * 3, [z.shape for z in torch.chunk(x, 3, dim=2)])

        # split with explicit sizes (zero-sized pieces allowed)
        self.assertEqual([(0, 1, 0, 0), (0, 1, 1, 0), (0, 1, 2, 0)],
                         [z.shape for z in torch.split(x, (0, 1, 2), dim=2)])

        # split_size 0 on a non-empty dim is rejected
        self.assertRaises(RuntimeError, lambda: torch.split(x, 0, dim=1))
        # splitting a zero-sized dim yields one empty piece
        self.assertEqual([(0, 1, 3, 0)], [z.shape for z in torch.split(x, 1, dim=0)])
        self.assertEqual([(0, 1, 3, 0)], [z.shape for z in torch.split(x, 0, dim=0)])
def test_dim_function_empty(self):
    """Ops that take a ``dim`` argument accept zero-element inputs."""
    devices = ['cpu'] if not torch.cuda.is_available() else ['cpu', 'cuda']
    for device in devices:
        shape = (0, 1, 2, 0)
        x = torch.randn(shape, device=device)

        # size, stride
        self.assertEqual(0, x.size(3))
        self.assertEqual(2, x.size(2))
        self.assertEqual(2, x.stride(0))
        self.assertEqual(1, x.stride(2))

        # cumsum, cumprod
        self.assertEqual(shape, torch.cumsum(x, 0).shape)
        self.assertEqual(shape, torch.cumsum(x, 2).shape)
        self.assertEqual(shape, torch.cumprod(x, 0).shape)
        self.assertEqual(shape, torch.cumprod(x, 2).shape)

        # flip
        self.assertEqual(x, x.flip(0))
        self.assertEqual(x, x.flip(2))

        # roll
        self.assertEqual(x, x.roll(0, 1).roll(0, -1))
        self.assertEqual(x, x.roll(1, x.size(1)))
        self.assertEqual(x, x.roll(1))
        self.assertEqual(x, x.roll((1, 1), (3, 1)))

        # unbind
        self.assertEqual((), x.unbind(0))
        self.assertEqual((torch.empty((0, 1, 0), device=device), torch.empty((0, 1, 0), device=device)),
                         x.unbind(2))

        # cross
        y = torch.randn((0, 1, 3, 0), device=device)
        self.assertEqual(y.shape, torch.cross(y, y).shape)

        # renorm
        self.assertEqual(shape, torch.renorm(x, 1, 0, 5).shape)
        self.assertEqual(shape, torch.renorm(x, 1, 2, 5).shape)

        # sort
        self.assertEqual([shape, shape], [z.shape for z in torch.sort(x, dim=0)])
        self.assertEqual([shape, shape], [z.shape for z in torch.sort(x, dim=2)])

        # topk
        self.assertEqual([shape, shape], [z.shape for z in torch.topk(x, 0, dim=0)])
        self.assertEqual([(0, 1, 1, 0), (0, 1, 1, 0)], [z.shape for z in torch.topk(x, 1, dim=2)])

        y = torch.randn((2, 3, 4), device=device)
        self.assertEqual([(2, 3, 0), (2, 3, 0)], [z.shape for z in torch.topk(y, 0)])

        # gather
        self.assertEqual(shape, torch.gather(x, 0, torch.empty(shape, dtype=torch.int64, device=device)).shape)
        self.assertEqual(shape, torch.gather(x, 2, torch.empty(shape, dtype=torch.int64, device=device)).shape)
        larger_shape = torch.empty((0, 1, 3, 0), dtype=torch.int64, device=device)
        self.assertEqual(larger_shape.shape, torch.gather(x, 2, larger_shape).shape)
        smaller_shape = torch.empty((0, 1, 0, 0), dtype=torch.int64, device=device)
        self.assertEqual(smaller_shape.shape, torch.gather(x, 2, smaller_shape).shape)
        y = torch.randn((2, 3, 4), device=device)
        self.assertEqual((0, 3, 4),
                         torch.gather(y, 0, torch.empty((0, 3, 4), dtype=torch.int64, device=device)).shape)

        # scatter_, scatter_add_ on an empty tensor, along the same dims
        # exercised above (0 and 2)
        for dim in (0, 2):
            y = torch.randn(shape, device=device)
            y_src = torch.randn(shape, device=device)
            ind = torch.empty(shape, dtype=torch.int64, device=device)
            self.assertEqual(shape, y.scatter_(dim, ind, y_src).shape)
            self.assertEqual(shape, y.scatter_add_(dim, ind, y_src).shape)

        # An empty index leaves a non-empty tensor unchanged. Compare against
        # a snapshot: the in-place ops return ``z`` itself, so comparing to
        # ``z`` would always pass.
        z = torch.randn((2, 3, 4), device=device)
        z_src = torch.randn((2, 3, 4), device=device)
        z_before = z.clone()
        self.assertEqual(z_before,
                         z.scatter_(2, torch.empty((2, 3, 0), dtype=torch.int64, device=device), z_src))
        self.assertEqual(z_before,
                         z.scatter_add_(2, torch.empty((2, 3, 0), dtype=torch.int64, device=device), z_src))

        # index_fill_, index_copy_, index_add_ on an empty tensor
        c = x.clone()
        c_clone = c.clone()
        ind_empty = torch.tensor([], dtype=torch.int64, device=device)
        ind_01 = torch.tensor([0, 1], dtype=torch.int64, device=device)
        self.assertEqual(c_clone, c.index_fill_(0, ind_empty, -1))
        self.assertEqual(c_clone, c.index_fill_(2, ind_empty, -1))
        self.assertEqual(c_clone, c.index_fill_(2, ind_01, -1))
        self.assertEqual(c_clone, c.index_copy_(0, ind_empty, torch.empty((0, 1, 2, 0), device=device)))
        self.assertEqual(c_clone, c.index_copy_(2, ind_empty, torch.empty((0, 1, 0, 0), device=device)))
        self.assertEqual(c_clone, c.index_copy_(2, ind_01, torch.empty((0, 1, 2, 0), device=device)))
        self.assertEqual(c_clone, c.index_add_(0, ind_empty, torch.empty((0, 1, 2, 0), device=device)))
        self.assertEqual(c_clone, c.index_add_(2, ind_empty, torch.empty((0, 1, 0, 0), device=device)))
        self.assertEqual(c_clone, c.index_add_(2, ind_01, torch.empty((0, 1, 2, 0), device=device)))

        c = torch.randn((0, 1, 2), device=device)
        c_clone = c.clone()
        self.assertEqual(c_clone, c.index_fill_(0, ind_empty, -1))
        self.assertEqual(c_clone, c.index_copy_(0, ind_empty, torch.empty((0, 1, 2), device=device)))
        self.assertEqual(c_clone, c.index_add_(0, ind_empty, torch.empty((0, 1, 2), device=device)))

        # Empty index into a non-empty tensor: unchanged (snapshot comparison).
        z = torch.randn((2, 3, 4), device=device)
        z_before = z.clone()
        self.assertEqual(z_before, z.index_fill_(0, ind_empty, -1))
        self.assertEqual(z_before, z.index_copy_(0, ind_empty, torch.empty((0, 3, 4), device=device)))
        self.assertEqual(z_before, z.index_add_(0, ind_empty, torch.empty((0, 3, 4), device=device)))

        # index_select
        self.assertEqual(x, x.index_select(0, ind_empty))
        self.assertEqual((0, 1, 0, 0), x.index_select(2, ind_empty).shape)
        self.assertEqual(x, x.index_select(2, ind_01))
        z = torch.randn((2, 3, 4), device=device)
        self.assertEqual((0, 3, 4), z.index_select(0, ind_empty).shape)
        c = torch.randn((0, 1, 2), device=device)
        self.assertEqual(c, c.index_select(0, ind_empty))
def test_blas_empty(self):
    """BLAS-backed ops accept zero-sized operands.

    Checks output shapes. Where the contraction dimension is empty, also
    checks that the product is all zeros.
    """
    devices = ['cpu'] if not torch.cuda.is_available() else ['cpu', 'cuda']
    for device in devices:

        def fn(torchfn, *args):
            # Tuple arguments are shapes and become random tensors on
            # ``device``; any other argument is passed through unchanged.
            return torchfn(*tuple(torch.randn(shape, device=device) if isinstance(shape, tuple) else shape
                                  for shape in args))

        # mm, addmm
        self.assertEqual((0, 0), fn(torch.mm, (0, 0), (0, 0)).shape)
        self.assertEqual((0, 5), fn(torch.mm, (0, 0), (0, 5)).shape)
        self.assertEqual((5, 0), fn(torch.mm, (5, 0), (0, 0)).shape)
        self.assertEqual((3, 0), fn(torch.mm, (3, 2), (2, 0)).shape)
        self.assertEqual(torch.zeros((5, 6), device=device), fn(torch.mm, (5, 0), (0, 6)))

        self.assertEqual((0, 0), fn(torch.addmm, (0, 0), (0, 0), (0, 0)).shape)
        self.assertEqual((5, 6), fn(torch.addmm, (5, 6), (5, 0), (0, 6)).shape)

        # mv, addmv
        self.assertEqual((0,), fn(torch.mv, (0, 0), (0,)).shape)
        self.assertEqual((0,), fn(torch.mv, (0, 2), (2,)).shape)
        self.assertEqual(torch.zeros((3,), device=device), fn(torch.mv, (3, 0), (0,)))

        self.assertEqual((0,), fn(torch.addmv, (0,), (0, 0), (0,)).shape)
        self.assertEqual((3,), fn(torch.addmv, (3,), (3, 0), (0,)).shape)

        # ger, addr
        self.assertEqual((0, 0), fn(torch.ger, (0,), (0,)).shape)
        self.assertEqual((5, 0), fn(torch.ger, (5,), (0,)).shape)
        self.assertEqual((0, 4), fn(torch.ger, (0,), (4,)).shape)

        self.assertEqual((0, 0), fn(torch.addr, (0, 0), (0,), (0,)).shape)
        self.assertEqual((5, 0), fn(torch.addr, (5, 0), (5,), (0,)).shape)
        self.assertEqual((0, 4), fn(torch.addr, (0, 4), (0,), (4,)).shape)

        # bmm, baddbmm
        self.assertEqual((0, 0, 0), fn(torch.bmm, (0, 0, 0), (0, 0, 0)).shape)
        self.assertEqual((3, 0, 5), fn(torch.bmm, (3, 0, 0), (3, 0, 5)).shape)
        self.assertEqual((0, 5, 6), fn(torch.bmm, (0, 5, 0), (0, 0, 6)).shape)
        self.assertEqual(torch.zeros((3, 5, 6), device=device), fn(torch.bmm, (3, 5, 0), (3, 0, 6)))

        self.assertEqual((0, 0, 0), fn(torch.baddbmm, (0, 0, 0), (0, 0, 0), (0, 0, 0)).shape)
        self.assertEqual((3, 0, 5), fn(torch.baddbmm, (3, 0, 5), (3, 0, 0), (3, 0, 5)).shape)
        self.assertEqual((0, 5, 6), fn(torch.baddbmm, (0, 5, 6), (0, 5, 0), (0, 0, 6)).shape)
        self.assertEqual((3, 5, 6), fn(torch.baddbmm, (3, 5, 6), (3, 5, 0), (3, 0, 6)).shape)

        # addbmm
        self.assertEqual((0, 0), fn(torch.addbmm, (0, 0), (0, 0, 0), (0, 0, 0)).shape)
        self.assertEqual((0, 5), fn(torch.addbmm, (0, 5), (3, 0, 0), (3, 0, 5)).shape)
        self.assertEqual((5, 6), fn(torch.addbmm, (5, 6), (0, 5, 0), (0, 0, 6)).shape)

        # matmul
        self.assertEqual(torch.tensor(0., device=device), fn(torch.matmul, (0,), (0,)))
        self.assertEqual((0, 0), fn(torch.matmul, (0, 0), (0, 0)).shape)
        self.assertEqual((0, 0, 0), fn(torch.matmul, (0, 0, 0), (0, 0, 0)).shape)
        self.assertEqual((5, 0, 0), fn(torch.matmul, (5, 0, 0), (5, 0, 0)).shape)
        self.assertEqual(torch.zeros((5, 3, 4), device=device), fn(torch.matmul, (5, 3, 0), (5, 0, 4)))

        # dot of two empty vectors
        self.assertEqual(torch.tensor(0., device=device), fn(torch.dot, (0,), (0,)))

        if torch._C.has_lapack:
            # btrifact on empty batches / empty matrices
            A_LU, pivots = fn(torch.btrifact, (0, 5, 5))
            self.assertEqual([(0, 5, 5), (0, 5)], [A_LU.shape, pivots.shape])
            A_LU, pivots = fn(torch.btrifact, (0, 0, 0))
            self.assertEqual([(0, 0, 0), (0, 0)], [A_LU.shape, pivots.shape])
            A_LU, pivots = fn(torch.btrifact, (2, 0, 0))
            self.assertEqual([(2, 0, 0), (2, 0)], [A_LU.shape, pivots.shape])
def test_blas_alpha_beta_empty(self):
    """With an empty contraction, addmv/addmm must still return ``beta * input``."""
    devices = ['cpu'] if not torch.cuda.is_available() else ['cpu', 'cuda']
    for device in devices:
        # Arbitrary values. beta != 1, so any ignored beta would be visible.
        # Floats keep the filled tensors floating point.
        value = 11.
        alpha = 6.
        beta = 3.

        # torch.addmv: (2, 0) @ (0,) contributes nothing
        inp = torch.full((2,), value, device=device)
        mat = torch.ones((2, 0), device=device)
        vec = torch.ones((0,), device=device)
        out = torch.randn((2,), device=device)
        self.assertEqual(torch.full((2,), beta * value, device=device),
                         torch.addmv(input=inp, mat=mat, vec=vec, alpha=alpha, beta=beta))
        self.assertEqual(torch.full((2,), beta * value, device=device),
                         torch.addmv(input=inp, mat=mat, vec=vec, alpha=alpha, beta=beta, out=out))

        # torch.addmm: (2, 0) @ (0, 3) contributes nothing
        inp = torch.full((2, 3), value, device=device)
        mat2 = torch.ones((0, 3), device=device)
        out = torch.randn((2, 3), device=device)
        self.assertEqual(torch.full((2, 3), beta * value, device=device),
                         torch.addmm(input=inp, mat1=mat, mat2=mat2, alpha=alpha, beta=beta))
        self.assertEqual(torch.full((2, 3), beta * value, device=device),
                         torch.addmm(input=inp, mat1=mat, mat2=mat2, alpha=alpha, beta=beta, out=out))
def test_lapack_empty(self):
    """A selection of LAPACK-backed ops on zero-sized matrices.

    Supported ops must return correctly shaped (or, for det-like ops,
    correctly valued) results. svd, qr and gels are expected to raise.
    """
    devices = ['cpu'] if not torch.cuda.is_available() else ['cpu', 'cuda']
    for device in devices:
        # Allocate on the device before querying has_magma (presumably so
        # CUDA is initialised first); CUDA without MAGMA is skipped.
        torch.randn((0, 0), device=device)
        if device == 'cuda' and not torch.cuda.has_magma:
            continue

        def fn(torchfn, *args):
            # Tuple arguments are shapes and become random tensors on
            # ``device``; any other argument is passed through unchanged.
            return torchfn(*tuple(torch.randn(shape, device=device) if isinstance(shape, tuple) else shape
                                  for shape in args))

        # inverse, pinverse
        self.assertEqual((0, 0), fn(torch.inverse, (0, 0)).shape)
        self.assertEqual((5, 0), fn(torch.pinverse, (0, 5)).shape)
        self.assertEqual((0, 5), fn(torch.pinverse, (5, 0)).shape)
        self.assertEqual((0, 0), fn(torch.pinverse, (0, 0)).shape)

        # svd
        self.assertRaises(RuntimeError, lambda: fn(torch.svd, (0, 0)))

        # det, logdet, slogdet of the 0x0 matrix
        self.assertEqual(torch.tensor(1., device=device), fn(torch.det, (0, 0)))
        self.assertEqual(torch.tensor(0., device=device), fn(torch.logdet, (0, 0)))
        self.assertEqual((torch.tensor(1., device=device), torch.tensor(0., device=device)),
                         fn(torch.slogdet, (0, 0)))

        # eig, symeig (with eigenvectors)
        evalues, evectors = fn(torch.eig, (0, 0), True)
        self.assertEqual([(0, 2), (0, 0)], [evalues.shape, evectors.shape])
        evalues, evectors = fn(torch.symeig, (0, 0), True)
        self.assertEqual([(0,), (0, 0)], [evalues.shape, evectors.shape])

        # qr, gels
        self.assertRaises(RuntimeError, lambda: torch.qr(torch.randn(0, 0)))
        self.assertRaises(RuntimeError, lambda: torch.gels(torch.randn(0, 0), torch.randn(0, 0)))
        self.assertRaises(RuntimeError, lambda: torch.gels(torch.randn(0,), torch.randn(0, 0)))
def test_expand(self):
    """expand/expand_as reach the target shape and agree with repeat/unsqueeze."""
    tensor = torch.rand(1, 8, 1)
    tensor2 = torch.rand(5)
    template = torch.rand(4, 8, 5)
    target = template.size()
    for source in (tensor, tensor2):
        self.assertEqual(source.expand_as(template).size(), target)
        self.assertEqual(source.expand(4, 8, 5).size(), target)
        self.assertEqual(source.expand(target).size(), target)

    # Expanding in two steps gives the same values as repeat().
    two_step = tensor2.expand(1, 5).expand(2, 2, 5)
    self.assertEqual(two_step, tensor2.repeat(2, 2, 1))

    # Non-contiguous source.
    noncontig = torch.randn(5, 2, 1, 3)[:, 0]
    self.assertFalse(noncontig.is_contiguous())
    self.assertEqual(noncontig.expand(2, 5, 4, 3), noncontig.contiguous().repeat(2, 1, 4, 1))

    # Adding only leading singleton dims matches unsqueeze, strides included.
    expanded = tensor2.expand(1, 1, 5)
    unsqueezed = tensor2.unsqueeze(0).unsqueeze(1)
    self.assertEqual(expanded, unsqueezed)
    self.assertEqual(expanded.stride(), unsqueezed.stride())

    # -1 keeps an existing dim but is rejected for a new leading dim.
    self.assertEqual(tensor.expand(4, -1, 5), tensor.expand(4, 8, 5))
    self.assertRaises(RuntimeError, lambda: tensor2.expand(-1, -1))

    # Zero-sized dim.
    self.assertEqual(torch.zeros(0).expand((0,)), torch.zeros(0))
def test_repeat(self):
    """repeat() accepts varargs or a torch.Size and tiles along a new leading dim."""
    initial_shape = (8, 4)
    tensor = torch.rand(*initial_shape)

    # A leading 3 adds a new dim; the 1s leave the existing dims as they are.
    size = (3, 1, 1)
    torch_size = torch.Size(size)
    target = [3, 8, 4]
    self.assertEqual(tensor.repeat(*size).size(), target, 'Error in repeat')
    self.assertEqual(tensor.repeat(torch_size).size(), target,
                     'Error in repeat using LongStorage')
    result = tensor.repeat(*size)
    self.assertEqual(result.size(), target, 'Error in repeat using result')
    result = tensor.repeat(torch_size)
    self.assertEqual(result.size(), target, 'Error in repeat using result and LongStorage')
    # All copies are identical, so their mean is the original tensor.
    self.assertEqual(result.mean(0).view(*initial_shape), tensor, 'Error in repeat (not equal)')
@unittest.skipIf(not TEST_NUMPY, "Numpy not found")
def test_repeat_tile(self):
    """Tensor.repeat matches numpy.tile for contiguous and non-contiguous inputs."""
    initial_shape = (8, 4)

    repeats = ((3, 1, 1),
               (3, 3, 3),
               (1, 2, 1),
               (2, 2, 2, 2))

    def _generate_noncontiguous_input():
        # Broadcasting a (1, 4) array to (8, 4) gives a zero-stride view.
        out = np.broadcast_to(np.random.random((1, 4)), initial_shape)
        # Checked with the test case rather than a bare ``assert``, which
        # would be stripped under ``python -O``.
        self.assertFalse(out.flags.c_contiguous or out.flags.f_contiguous)
        return out

    for repeat in repeats:
        for tensor in (torch.from_numpy(np.random.random(initial_shape)),
                       torch.from_numpy(_generate_noncontiguous_input())):
            self.assertEqual(tensor.repeat(*repeat).numpy(),
                             np.tile(tensor.numpy(), repeat))
def test_is_same_size(self):
    """is_same_size holds only for tensors with identical sizes."""
    reference = torch.Tensor(3, 4, 9, 10)
    cases = (
        (torch.Tensor(3, 4), False),
        (torch.Tensor(1, 9, 3, 3), False),
        (torch.Tensor(3, 4, 9, 10), True),
    )
    for other, same in cases:
        check = self.assertTrue if same else self.assertFalse
        check(reference.is_same_size(other))
def test_is_set_to(self):
    """is_set_to holds for tensors set to the same storage and is symmetric."""
    t1 = torch.Tensor(3, 4, 9, 10)
    t2 = torch.Tensor(3, 4, 9, 10)
    t3 = torch.Tensor().set_(t1)
    t4 = t3.clone().resize_(12, 90)
    self.assertFalse(t1.is_set_to(t2))
    self.assertTrue(t1.is_set_to(t3))
    self.assertTrue(t3.is_set_to(t1), "is_set_to should be symmetric")
    self.assertFalse(t1.is_set_to(t4))
    self.assertFalse(torch.Tensor().is_set_to(torch.Tensor()),
                     "Tensors with no storages should not appear to be set "
                     "to each other")

def test_tensor_set(self):
    """set_ accepts a tensor or a storage, positionally or by keyword, with optional size/stride."""
    t1 = torch.Tensor()
    t2 = torch.Tensor(3, 4, 9, 10).uniform_()
    t1.set_(t2)
    self.assertEqual(t1.storage()._cdata, t2.storage()._cdata)
    size = torch.Size([9, 3, 4, 10])
    t1.set_(t2.storage(), 0, size)
    self.assertEqual(t1.size(), size)
    t1.set_(t2.storage(), 0, tuple(size))
    self.assertEqual(t1.size(), size)
    # Without an explicit stride, the strides are the contiguous ones for ``size``.
    self.assertEqual(t1.stride(), (120, 40, 10, 1))
    stride = (10, 360, 90, 1)
    t1.set_(t2.storage(), 0, size, stride)
    self.assertEqual(t1.stride(), stride)
    t1.set_(t2.storage(), 0, size=size, stride=stride)
    self.assertEqual(t1.size(), size)
    self.assertEqual(t1.stride(), stride)

    # Keyword arguments.
    t1 = torch.Tensor()
    # 1. source is a tensor
    t1.set_(source=t2)
    self.assertEqual(t1.storage()._cdata, t2.storage()._cdata)
    # 2. source is a storage
    t1.set_(source=t2.storage())
    self.assertEqual(t1.storage()._cdata, t2.storage()._cdata)
    # 3. source is a storage, with offset, size and stride also given
    t1.set_(source=t2.storage(), storage_offset=0, size=size, stride=stride)
    self.assertEqual(t1.size(), size)
    self.assertEqual(t1.stride(), stride)
def test_equal(self):
    """Tensor.equal / torch.equal compare size and values, independent of contiguity."""
    # Contiguous, 1-D.
    t1 = torch.Tensor((3, 4, 9, 10))
    t2 = t1.contiguous()
    t3 = torch.Tensor((1, 9, 3, 10))
    t4 = torch.Tensor((3, 4, 9))
    t5 = torch.Tensor()
    self.assertTrue(t1.equal(t2))
    self.assertFalse(t1.equal(t3))
    self.assertFalse(t1.equal(t4))
    self.assertFalse(t1.equal(t5))
    self.assertTrue(torch.equal(t1, t2))
    self.assertFalse(torch.equal(t1, t3))
    self.assertFalse(torch.equal(t1, t4))
    self.assertFalse(torch.equal(t1, t5))

    # Non-contiguous, 2-D: s1 is the middle two columns of s.
    s = torch.Tensor(((1, 2, 3, 4), (5, 6, 7, 8)))
    s1 = s[:, 1:3]
    s2 = s1.clone()
    s3 = torch.Tensor(((2, 3), (6, 7)))
    s4 = torch.Tensor(((0, 0), (0, 0)))

    self.assertFalse(s1.is_contiguous())
    self.assertTrue(s1.equal(s2))
    self.assertTrue(s1.equal(s3))
    self.assertFalse(s1.equal(s4))
    self.assertTrue(torch.equal(s1, s2))
    self.assertTrue(torch.equal(s1, s3))
    self.assertFalse(torch.equal(s1, s4))
def test_element_size(self):
    """Storage and tensor element sizes agree and satisfy basic width relations."""
    # Suffixed names avoid shadowing the int/float/bool builtins.
    byte_size = torch.ByteStorage().element_size()
    char_size = torch.CharStorage().element_size()
    short_size = torch.ShortStorage().element_size()
    int_size = torch.IntStorage().element_size()
    long_size = torch.LongStorage().element_size()
    float_size = torch.FloatStorage().element_size()
    double_size = torch.DoubleStorage().element_size()
    bool_size = torch.BoolStorage().element_size()

    self.assertEqual(byte_size, torch.ByteTensor().element_size())
    self.assertEqual(char_size, torch.CharTensor().element_size())
    self.assertEqual(short_size, torch.ShortTensor().element_size())
    self.assertEqual(int_size, torch.IntTensor().element_size())
    self.assertEqual(long_size, torch.LongTensor().element_size())
    self.assertEqual(float_size, torch.FloatTensor().element_size())
    self.assertEqual(double_size, torch.DoubleTensor().element_size())
    self.assertEqual(bool_size, torch.BoolTensor().element_size())

    for size in (byte_size, char_size, short_size, int_size,
                 long_size, float_size, double_size, bool_size):
        self.assertGreater(size, 0)

    # Portable lower bounds and orderings only, not exact widths.
    self.assertEqual(byte_size, 1)
    self.assertEqual(char_size, 1)
    self.assertEqual(bool_size, 1)
    self.assertGreaterEqual(short_size, 2)
    self.assertGreaterEqual(int_size, 2)
    self.assertGreaterEqual(int_size, short_size)
    self.assertGreaterEqual(long_size, 4)
    self.assertGreaterEqual(long_size, int_size)
    self.assertGreaterEqual(double_size, float_size)
def test_split(self):
    """split() by one size or a list of sizes yields correctly sized, correctly placed pieces."""

    def check(tensor, splits, target_sizes, dim):
        # Each piece must have its target size and equal the matching
        # narrow() of the source. The count is checked first because zip()
        # would silently ignore missing or extra pieces.
        self.assertEqual(len(splits), len(target_sizes))
        start = 0
        for target_size, split in zip(target_sizes, splits):
            self.assertEqual(split.size(), target_size)
            self.assertEqual(tensor.narrow(dim, start, target_size[dim]), split, 0)
            start += target_size[dim]

    # Fixed split size; the last piece holds the remainder.
    tensor = torch.rand(7, 4)
    split_size = 3
    dim = 0
    check(tensor, tensor.split(split_size, dim), ([3, 4], [3, 4], [1, 4]), dim)

    # Variable section sizes, along dim 0 and dim 1.
    tensor = torch.randn(20, 10)
    dim = 0
    check(tensor, tensor.split([5, 5, 10], dim), ([5, 10], [5, 10], [10, 10]), dim)

    dim = 1
    check(tensor, tensor.split([2, 2, 6], dim), ([20, 2], [20, 2], [20, 6]), dim)
def test_chunk(self):
    """chunk() splits into near-equal pieces and rejects non-positive chunk counts."""
    tensor = torch.rand(4, 7)
    num_chunks = 3
    dim = 1
    target_sizes = ([4, 3], [4, 3], [4, 1])
    splits = tensor.chunk(num_chunks, dim)
    # zip() below would silently ignore missing or extra chunks.
    self.assertEqual(len(splits), len(target_sizes))
    start = 0
    for target_size, split in zip(target_sizes, splits):
        self.assertEqual(split.size(), target_size)
        self.assertEqual(tensor.narrow(dim, start, target_size[dim]), split, 0)
        start += target_size[dim]

    # Invalid chunk counts.
    error_regex = 'chunk expects.*greater than 0'
    with self.assertRaisesRegex(RuntimeError, error_regex):
        tensor.chunk(0)
    with self.assertRaisesRegex(RuntimeError, error_regex):
        tensor.chunk(-2)
def test_tolist(self):
    """tolist() reproduces the nested list a tensor/storage was built from, even if non-contiguous."""
    empty_list = []
    empty_tensor = torch.Tensor(empty_list)
    self.assertEqual(empty_tensor.tolist(), empty_list)

    table_1d = [1, 2, 3]
    tensor_1d = torch.Tensor(table_1d)
    storage = torch.Storage(table_1d)
    self.assertEqual(tensor_1d.tolist(), table_1d)
    self.assertEqual(storage.tolist(), table_1d)

    table_2d = [[1, 2], [3, 4]]
    tensor_2d = torch.Tensor(table_2d)
    self.assertEqual(tensor_2d.tolist(), table_2d)

    tensor_3d = torch.Tensor([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
    non_contig = tensor_3d.select(1, 1)
    self.assertFalse(non_contig.is_contiguous())
    self.assertEqual(non_contig.tolist(), [[3, 4], [7, 8]])
def test_permute(self):
    """permute() reorders sizes by the permutation and leaves the source's size unchanged."""
    orig = [1, 2, 3, 4, 5, 6, 7]
    perm = torch.randperm(7).tolist()
    x = torch.Tensor(*orig).fill_(0)
    permuted_sizes = x.permute(*perm).size()
    # Source dim i has size i + 1, so subtracting 1 recovers its index.
    new = [dim_size - 1 for dim_size in permuted_sizes]
    self.assertEqual(perm, new)
    self.assertEqual(x.size(), orig)
def _test_flip(self, use_cuda=False):
    """Shared checks for ``Tensor.flip`` on CPU or, with ``use_cuda``, on CUDA.

    The test case instance is passed explicitly as ``self``.
    """
    device = torch.device('cuda') if use_cuda else torch.device('cpu')
    data = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8], device=device).view(2, 2, 2)

    def expect(values, *shape):
        # Expected results are built on the same device as ``data``.
        return torch.tensor(values, device=device).view(*shape)

    self.assertEqual(expect([5, 6, 7, 8, 1, 2, 3, 4], 2, 2, 2), data.flip(0))
    self.assertEqual(expect([3, 4, 1, 2, 7, 8, 5, 6], 2, 2, 2), data.flip(1))
    self.assertEqual(expect([2, 1, 4, 3, 6, 5, 8, 7], 2, 2, 2), data.flip(2))
    self.assertEqual(expect([7, 8, 5, 6, 3, 4, 1, 2], 2, 2, 2), data.flip(0, 1))
    self.assertEqual(expect([8, 7, 6, 5, 4, 3, 2, 1], 2, 2, 2), data.flip(0, 1, 2))

    # Negative (wrapped) dim.
    self.assertEqual(expect([2, 1, 4, 3, 6, 5, 8, 7], 2, 2, 2), data.flip(-1))
    # The order of dims does not matter.
    self.assertEqual(expect([6, 5, 8, 7, 2, 1, 4, 3], 2, 2, 2), data.flip(0, 2))
    self.assertEqual(expect([6, 5, 8, 7, 2, 1, 4, 3], 2, 2, 2), data.flip(2, 0))

    # The same dim may not be flipped twice.
    self.assertRaises(RuntimeError, lambda: data.flip(0, 1, 1))
    # At least one dim is required.
    self.assertRaises(TypeError, lambda: data.flip())
    # More dims than the tensor has, or a dim out of range.
    self.assertRaises(IndexError, lambda: data.flip(0, 1, 2, 3))
    self.assertRaises(IndexError, lambda: data.flip(3))

    # Non-contiguous inputs (expanded and transposed).
    expanded_data = torch.arange(1, 4, device=device).view(3, 1).expand(3, 2)
    transposed_data = torch.arange(1, 9, device=device).view(2, 2, 2).transpose(0, 1)
    self.assertEqual(expect([3, 3, 2, 2, 1, 1], 3, 2), expanded_data.flip(0))
    self.assertEqual(expect([8, 7, 4, 3, 6, 5, 2, 1], 2, 2, 2), transposed_data.flip(0, 1, 2))

    # Flip never changes the shape: every 1- and 2-dim combination.
    data = torch.randn(2, 3, 4, device=device)
    size = [2, 3, 4]
    test_dims = [dims for r in range(1, 3) for dims in combinations(range(len(size)), r)]
    for ds in test_dims:
        self.assertEqual(size, list(data.flip(ds).size()))

    # Rectangular (non-square) input.
    data = torch.tensor([1, 2, 3, 4, 5, 6], device=device).view(2, 3)
    flip0_result = torch.tensor([[4, 5, 6], [1, 2, 3]], device=device)
    flip1_result = torch.tensor([[3, 2, 1], [6, 5, 4]], device=device)
    self.assertEqual(flip0_result, data.flip(0))
    self.assertEqual(flip1_result, data.flip(1))

    # Empty tensor: flipping returns an equal empty tensor.
    data = torch.tensor([], device=device)
    self.assertEqual(data, data.flip(0))
def test_flip(self):
    """Run the shared flip checks on CPU."""
    # Calling through the class passes the instance exactly once, whether or
    # not _test_flip is declared a staticmethod. ``self._test_flip(self, ...)``
    # would pass it twice for a plain method.
    _TestTorchMixin._test_flip(self, use_cuda=False)
def test_roll(self):
    """roll() over one, several or no dims, with negative/wrapping shifts and argument validation."""
    devices = ['cpu'] if not torch.cuda.is_available() else ['cpu', 'cuda']
    for device in devices:
        numbers = torch.arange(1, 9, device=device)

        single_roll = numbers.roll(1, 0)
        expected = torch.tensor([8, 1, 2, 3, 4, 5, 6, 7], device=device)
        self.assertEqual(single_roll, expected, "{} did not equal expected result".format(single_roll))

        roll_backwards = numbers.roll(-2, 0)
        expected = torch.tensor([3, 4, 5, 6, 7, 8, 1, 2], device=device)
        self.assertEqual(roll_backwards, expected, "{} did not equal expected result".format(roll_backwards))

        data = numbers.view(2, 2, 2)
        rolled = data.roll(1, 0)
        expected = torch.tensor([5, 6, 7, 8, 1, 2, 3, 4], device=device).view(2, 2, 2)
        self.assertEqual(expected, rolled, "{} did not equal expected result: {}".format(rolled, expected))

        data = data.view(2, 4)
        # Shifting by a full dim length returns to the start.
        loop_rolled = data.roll(2, 0).roll(4, 1)
        self.assertEqual(data, loop_rolled, "{} did not equal the original: {}".format(loop_rolled, data))
        # Negative multiples of the dim length are no-ops too.
        self.assertEqual(data, data.roll(-20, 0).roll(-40, 1))
        self.assertEqual(torch.tensor([8, 1, 2, 3, 4, 5, 6, 7], device=device), numbers.roll(1, 0))

        # Non-contiguous input (expected built on ``device`` like the others).
        strided = numbers.view(2, 4).transpose(0, 1)
        self.assertFalse(strided.is_contiguous(), "this test needs a non-contiguous tensor")
        expected = torch.tensor([4, 8, 1, 5, 2, 6, 3, 7], device=device).view(4, 2)
        rolled = strided.roll(1, 0)
        self.assertEqual(expected, rolled,
                         "non contiguous tensor rolled to {} instead of {} ".format(rolled, expected))

        # With no dims, the tensor is flattened, rolled, and given back its shape.
        expected = numbers.roll(1, 0).view(2, 4)
        self.assertEqual(expected, data.roll(1), "roll with no dims should flatten and roll.")
        self.assertEqual(expected, data.roll(1, dims=None), "roll with no dims should flatten and roll.")

        # Several dims at once.
        expected = torch.tensor([[7, 8, 5, 6], [3, 4, 1, 2]], device=device)
        double_rolled = data.roll(shifts=(2, -1), dims=(1, 0))
        self.assertEqual(double_rolled, expected,
                         "should be able to roll over two dimensions, got {}".format(double_rolled))

        # shifts must be non-empty and line up with dims.
        self.assertRaisesRegex(RuntimeError, "required", lambda: data.roll(shifts=(), dims=()))
        self.assertRaisesRegex(RuntimeError, "required", lambda: data.roll(shifts=(), dims=1))
        self.assertRaisesRegex(RuntimeError, "align", lambda: data.roll(shifts=(1, 2), dims=(1,)))
        self.assertRaisesRegex(RuntimeError, "align", lambda: data.roll(shifts=(1,), dims=(1, 2)))
def test_reversed(self):
    """reversed() on a tensor reverses it along the first dimension."""
    ascending = torch.arange(0, 10)
    descending = torch.arange(9, -1, -1)
    self.assertEqual(reversed(ascending), descending)

    grid = torch.arange(1, 10).view(3, 3)
    rows_reversed = torch.tensor([[7, 8, 9], [4, 5, 6], [1, 2, 3]])
    self.assertEqual(reversed(grid), rows_reversed)
def test_contains(self):
    """``in`` on a tensor works for scalars and for tensors compared element-wise."""
    x = torch.arange(0, 10)
    self.assertIn(4, x)
    self.assertNotIn(12, x)

    x = torch.arange(1, 10).view(3, 3)
    val = torch.arange(1, 4)
    self.assertIn(val, x)
    # Shift out of range so the row no longer occurs anywhere in x. Without
    # this, the assertion below would contradict the one above.
    val += 10
    self.assertNotIn(val, x)
def _test_rot90(self, use_cuda=False):
    """Shared checks for ``Tensor.rot90`` on CPU or, with ``use_cuda``, on CUDA.

    The test case instance is passed explicitly as ``self``.
    """
    device = torch.device("cuda" if use_cuda else "cpu")
    data = torch.arange(1, 5, device=device).view(2, 2)

    def expect(values, *shape):
        # Expected results are built on the same device as ``data``.
        return torch.tensor(values, device=device).view(*shape)

    self.assertEqual(expect([1, 2, 3, 4], 2, 2), data.rot90(0, [0, 1]))
    self.assertEqual(expect([2, 4, 1, 3], 2, 2), data.rot90(1, [0, 1]))
    self.assertEqual(expect([4, 3, 2, 1], 2, 2), data.rot90(2, [0, 1]))
    self.assertEqual(expect([3, 1, 4, 2], 2, 2), data.rot90(3, [0, 1]))

    # Defaults are k=1, dims=[0, 1].
    self.assertEqual(data.rot90(), data.rot90(1, [0, 1]))

    # Swapping the dims reverses the direction of rotation.
    self.assertEqual(data.rot90(3, [0, 1]), data.rot90(1, [1, 0]))

    # k is taken modulo 4, including negative k.
    self.assertEqual(data.rot90(5, [0, 1]), data.rot90(1, [0, 1]))
    self.assertEqual(data.rot90(3, [0, 1]), data.rot90(-1, [0, 1]))
    self.assertEqual(data.rot90(-5, [0, 1]), data.rot90(-1, [0, 1]))

    # Dims out of range.
    self.assertRaises(RuntimeError, lambda: data.rot90(1, [0, -3]))
    self.assertRaises(RuntimeError, lambda: data.rot90(1, [0, 2]))

    # Input with more than two dims.
    data = torch.arange(1, 9, device=device).view(2, 2, 2)
    self.assertEqual(expect([2, 4, 1, 3, 6, 8, 5, 7], 2, 2, 2), data.rot90(1, [1, 2]))
    self.assertEqual(data.rot90(1, [1, -1]), data.rot90(1, [1, 2]))

    # Invalid dims: out of range, repeated, or not exactly two.
    self.assertRaises(RuntimeError, lambda: data.rot90(1, [0, 3]))
    self.assertRaises(RuntimeError, lambda: data.rot90(1, [1, 1]))
    self.assertRaises(RuntimeError, lambda: data.rot90(1, [0, 1, 2]))
    self.assertRaises(RuntimeError, lambda: data.rot90(1, [0]))
def test_rot90(self):
    """Run the shared rot90 checks on CPU."""
    # Calling through the class passes the instance exactly once, whether or
    # not _test_rot90 is declared a staticmethod. ``self._test_rot90(self, ...)``
    # would pass it twice for a plain method.
    _TestTorchMixin._test_rot90(self, use_cuda=False)
def test_storage(self):
    """Flat storage indices follow the row-major layout of a fresh 3x5 tensor."""
    v = torch.randn(3, 5)
    # Flat index = row * 5 + col.
    for flat_index, (row, col) in ((0, (0, 0)), (14, (2, 4))):
        self.assertEqual(v.storage()[flat_index], v.data[row][col])
def test_nonzero(self):
    """torch.nonzero (function, method and ``out=`` forms) reports the non-zero entries."""
    num_src = 12

    types = [
        'torch.ByteTensor',
        'torch.CharTensor',
        'torch.ShortTensor',
        'torch.IntTensor',
        'torch.FloatTensor',
        'torch.DoubleTensor',
        'torch.LongTensor',
    ]

    shapes = [
        torch.Size((12,)),
        torch.Size((12, 1)),
        torch.Size((1, 12)),
        torch.Size((6, 2)),
        torch.Size((3, 2, 2)),
    ]

    for t in types:
        # Random 0/1 values; retry until at least one entry is non-zero.
        while True:
            tensor = torch.rand(num_src).mul(2).floor().type(t)
            if tensor.sum() > 0:
                break
        for shape in shapes:
            tensor = tensor.clone().resize_(shape)
            dst1 = torch.nonzero(tensor)
            dst2 = tensor.nonzero()
            dst3 = torch.LongTensor()
            torch.nonzero(tensor, out=dst3)
            # All three call forms agree, with one row per non-zero entry.
            self.assertEqual(dst1, dst2, 0)
            self.assertEqual(dst1, dst3, 0)
            self.assertEqual(dst1.size(0), int((tensor != 0).sum()))
            if len(shape) == 1:
                dst = [i for i in range(num_src) if tensor[i] != 0]
                self.assertEqual(dst1.select(1, 0), torch.LongTensor(dst), 0)
            else:
                # Every reported index must point at a non-zero entry.
                for index in dst1:
                    self.assertNotEqual(tensor[tuple(index.tolist())].item(), 0)
def test_nonzero_empty(self):
    """nonzero on empty and 0-dim inputs returns a (num_matches, ndim)-shaped index tensor."""
    devices = ['cpu'] if not torch.cuda.is_available() else ['cpu', 'cuda']
    for device in devices:
        x = torch.randn(0, 2, 0, 5, 0, device=device)
        y = torch.nonzero(x)
        self.assertEqual(0, y.numel())
        self.assertEqual(torch.Size([0, 5]), y.shape)

        # Non-zero 0-dim tensor: one match with zero index columns.
        x = torch.tensor(0.5, device=device)
        y = torch.nonzero(x)
        self.assertEqual(torch.Size([1, 0]), y.shape)

        # Zero 0-dim tensor: no matches.
        x = torch.zeros((), device=device)
        y = torch.nonzero(x)
        self.assertEqual(torch.Size([0, 0]), y.shape)
def test_deepcopy(self):
    """deepcopy copies tensors and storages and preserves aliasing within the copy."""
    from copy import deepcopy
    a = torch.randn(5, 5)
    b = torch.randn(5, 5)
    c = a.view(25)
    q = [a, [a.storage(), b.storage()], b, c]
    w = deepcopy(q)
    self.assertEqual(w[0], q[0], 0)
    self.assertEqual(w[1][0], q[1][0], 0)
    self.assertEqual(w[1][1], q[1][1], 0)
    self.assertEqual(w[1], q[1], 0)
    self.assertEqual(w[2], q[2], 0)

    # Aliasing is preserved: an in-place change to a copied tensor shows up
    # in the copied storage and the copied view, while the originals keep
    # their values.
    w[0].add_(1)
    for i in range(a.numel()):
        self.assertEqual(w[1][0][i], q[1][0][i] + 1)
    self.assertEqual(w[3], c + 1)
    w[2].sub_(1)
    for i in range(a.numel()):
        self.assertEqual(w[1][1][i], q[1][1][i] - 1)
8504 def test_deepcopy_scalar(self):
8505 from copy
import deepcopy
8507 self.assertEqual(a.size(), deepcopy(a).size())
8508 self.assertEqual(a, deepcopy(a))
8510 def test_deepcopy_parameter(self):
8511 from copy
import deepcopy
8512 l = torch.nn.Linear(10, 1)
8513 s = l.state_dict(keep_vars=
True)
8514 self.assertEqual(torch.nn.Parameter, type(s[
'weight']))
8515 self.assertEqual(torch.nn.Parameter, type(s[
'bias']))
8518 self.assertEqual(torch.nn.Parameter, type(s2[
'weight']))
8519 self.assertEqual(torch.nn.Parameter, type(s2[
'bias']))
8521 def test_copy(self):
8522 from copy
import copy
8523 a = torch.randn(5, 5)
8529 self.assertEqual(a, b)
8531 def test_pickle(self):
8532 if sys.version_info[0] == 2:
8533 import cPickle
as pickle
8536 a = torch.randn(5, 5)
8537 serialized = pickle.dumps(a)
8538 b = pickle.loads(serialized)
8539 self.assertEqual(a, b)
8541 def test_pickle_parameter(self):
8542 if sys.version_info[0] == 2:
8543 import cPickle
as pickle
8546 a = torch.nn.Parameter(torch.randn(5, 5))
8547 serialized = pickle.dumps(a)
8548 b = pickle.loads(serialized)
8549 self.assertTrue(isinstance(b, torch.nn.Parameter))
8550 self.assertEqual(a.requires_grad, b.requires_grad)
8551 self.assertEqual(a, b)
8553 def test_pickle_parameter_no_requires_grad(self):
8554 if sys.version_info[0] == 2:
8555 import cPickle
as pickle
8558 a = torch.nn.Parameter(torch.randn(5, 5), requires_grad=
False)
8559 serialized = pickle.dumps(a)
8560 b = pickle.loads(serialized)
8561 self.assertTrue(isinstance(b, torch.nn.Parameter))
8562 self.assertEqual(a.requires_grad, b.requires_grad)
8563 self.assertEqual(a, b)
8565 def test_norm_fastpaths(self):
8566 x = torch.randn(3, 5)
8569 result = torch.norm(x, 4.5, 1)
8570 expected = torch.pow(x.abs().pow(4.5).sum(1), 1.0 / 4.5)
8571 self.assertEqual(result, expected)
8574 result = torch.norm(x, 0, 1)
8575 expected = (x != 0).type_as(x).sum(1)
8576 self.assertEqual(result, expected)
8579 result = torch.norm(x, 1, 1)
8580 expected = x.abs().sum(1)
8581 self.assertEqual(result, expected)
8584 result = torch.norm(x, 2, 1)
8585 expected = torch.sqrt(x.pow(2).sum(1))
8586 self.assertEqual(result, expected)
8589 result = torch.norm(x, 3, 1)
8590 expected = torch.pow(x.pow(3).abs().sum(1), 1.0 / 3.0)
8591 self.assertEqual(result, expected)
8594 def _test_bernoulli(self, t_dtype, p_dtype, device):
8595 for trivial_p
in ([0, 1], [1, 0, 1, 1, 0, 1]):
8596 x =
torch.tensor(trivial_p, dtype=p_dtype, device=device)
8597 self.assertEqual(x.bernoulli().tolist(), trivial_p)
8600 return torch.ne(t, 0).mul_(torch.ne(t, 1)).sum().item() == 0
8602 p = torch.rand(5, 5, dtype=p_dtype, device=device)
8603 self.assertTrue(isBinary(p.bernoulli()))
8605 p = torch.rand(5, dtype=p_dtype, device=device).expand(5, 5)
8606 self.assertTrue(isBinary(p.bernoulli()))
8608 p = torch.rand(5, 5, dtype=p_dtype, device=device)
8609 torch.bernoulli(torch.rand_like(p), out=p)
8610 self.assertTrue(isBinary(p))
8612 p = torch.rand(5, dtype=p_dtype, device=device).expand(5, 5)
8613 torch.bernoulli(torch.rand_like(p), out=p)
8614 self.assertTrue(isBinary(p))
8616 t = torch.empty(10, 10, dtype=t_dtype, device=device)
8620 self.assertTrue(isBinary(t))
8622 p = torch.rand(10, dtype=p_dtype, device=device).expand(10, 10)
8625 self.assertTrue(isBinary(t))
8628 torch.bernoulli(torch.rand_like(t, dtype=p_dtype), out=t)
8629 self.assertTrue(isBinary(t))
8632 t.bernoulli_(torch.rand_like(t, dtype=p_dtype))
8633 self.assertTrue(isBinary(t))
8635 def test_bernoulli(self):
8636 self._test_bernoulli(self, torch.float32, torch.float64,
'cpu')
8638 self._test_bernoulli(self, torch.uint8, torch.float64,
'cpu')
8640 def test_normal(self):
8641 q = torch.Tensor(100, 100)
8643 self.assertEqual(q.mean(), 0, 0.2)
8644 self.assertEqual(q.std(), 1, 0.2)
8647 self.assertEqual(q.mean(), 2, 0.3)
8648 self.assertEqual(q.std(), 3, 0.3)
8650 q = torch.Tensor(100, 100)
8651 q_row1 = q[0:1].clone()
8653 self.assertEqual(q[99:100].mean(), 0, 0.2)
8654 self.assertEqual(q[99:100].std(), 1, 0.2)
8655 self.assertEqual(q[0:1].clone(), q_row1)
8657 mean = torch.Tensor(100, 100)
8658 std = torch.Tensor(100, 100)
8664 r = torch.normal(mean)
8665 self.assertEqual(r[:50].mean(), 0, 0.2)
8666 self.assertEqual(r[50:].mean(), 1, 0.2)
8667 self.assertEqual(r.std(), 1, 0.2)
8669 r = torch.normal(mean, 3)
8670 self.assertEqual(r[:50].mean(), 0, 0.2)
8671 self.assertEqual(r[50:].mean(), 1, 0.2)
8672 self.assertEqual(r.std(), 3, 0.2)
8674 r = torch.normal(2, std)
8675 self.assertEqual(r.mean(), 2, 0.2)
8676 self.assertEqual(r[:, :50].std(), 4, 0.3)
8677 self.assertEqual(r[:, 50:].std(), 1, 0.2)
8679 r = torch.normal(mean, std)
8680 self.assertEqual(r[:50].mean(), 0, 0.2)
8681 self.assertEqual(r[50:].mean(), 1, 0.2)
8682 self.assertEqual(r[:, :50].std(), 4, 0.3)
8683 self.assertEqual(r[:, 50:].std(), 1, 0.2)
8685 def test_parsing_int64(self):
8687 x = torch.cumsum(torch.ones(5, 5), 0)
8688 self.assertEqual(x, torch.cumsum(torch.ones(5, 5),
torch.tensor(0)))
8690 self.assertRaises(TypeError,
lambda: torch.cumsum(torch.ones(5, 5),
torch.tensor(0.)))
8692 def test_parsing_double(self):
8694 x = torch.randn(2, 3)
8695 torch.isclose(x, x, 1, 1)
8696 self.assertTrue(torch.isclose(x, x, 1, 1).all())
8697 self.assertTrue(torch.isclose(x, x, 1.5, 1.).all())
8702 self.assertRaises(TypeError,
8705 def test_parsing_intlist(self):
8711 self.assertEqual(torch.Size([3, 4]), torch.ones((np.array(3), np.int64(4))).shape)
8712 self.assertEqual(torch.Size([3, 4]), torch.ones(np.array(3), np.int64(4)).shape)
8713 self.assertEqual(torch.Size([3, 4]), torch.ones((np.int64(3), np.array(4))).shape)
8714 self.assertEqual(torch.Size([3, 4]), torch.ones(np.int64(3), np.array(4)).shape)
8720 self.assertRaises(TypeError,
lambda: torch.ones((np.float(3.),
torch.tensor(4))))
8721 self.assertRaises(TypeError,
lambda: torch.ones((np.array(3.),
torch.tensor(4))))
8724 self.assertRaises(TypeError,
lambda: torch.ones(
torch.tensor(3, 3)))
8725 self.assertRaises(TypeError,
lambda: torch.ones((
torch.tensor(3, 3))))
8727 self.assertRaises(TypeError,
lambda: torch.ones(np.array(3, 3)))
8728 self.assertRaises(TypeError,
lambda: torch.ones((np.array(3, 3))))
8731 self.assertRaisesRegex(TypeError,
8732 "received an invalid combination of arguments",
8733 lambda: torch.LongTensor((6, 0), 1, 1, 0))
8734 self.assertRaisesRegex(TypeError,
8735 "missing 1 required positional arguments",
8738 def _test_serialization_data(self):
8739 a = [torch.randn(5, 5).float()
for i
in range(2)]
8740 b = [a[i % 2]
for i
in range(4)]
8742 b += [a[0].reshape(-1)[1:4].
storage()]
8743 b += [torch.arange(1, 11).int()]
8744 t1 = torch.FloatTensor().set_(a[0].reshape(-1)[1:4].clone().
storage(), 0, (3,), (1,))
8745 t2 = torch.FloatTensor().set_(a[0].reshape(-1)[1:4].clone().
storage(), 0, (3,), (1,))
8746 b += [(t1.storage(), t1.storage(), t2.storage())]
8747 b += [a[0].reshape(-1)[0:2].
storage()]
8750 def _test_serialization_assert(self, b, c):
8751 self.assertEqual(b, c, 0)
8752 self.assertTrue(isinstance(c[0], torch.FloatTensor))
8753 self.assertTrue(isinstance(c[1], torch.FloatTensor))
8754 self.assertTrue(isinstance(c[2], torch.FloatTensor))
8755 self.assertTrue(isinstance(c[3], torch.FloatTensor))
8758 self.assertEqual(c[0], c[2], 0)
8761 self.assertEqual(c[1], c[3], 0)
8765 self.assertEqual(c[4][i + 1], c[5][i])
8770 self.assertEqual(views[0]._cdata, views[1]._cdata)
8771 self.assertEqual(views[0], views[2])
8772 self.assertNotEqual(views[0]._cdata, views[2]._cdata)
8775 self.assertEqual(rootview.data_ptr(), c[0].data_ptr())
8777 def test_serialization(self):
8779 b = self._test_serialization_data()
8780 for use_name
in (
False,
True):
8783 if sys.platform ==
"win32" and use_name:
8785 with tempfile.NamedTemporaryFile()
as f:
8786 handle = f
if not use_name
else f.name
8787 torch.save(b, handle)
8789 c = torch.load(handle)
8790 self._test_serialization_assert(b, c)
8797 b
'\x80\x02\x8a\nl\xfc\x9cF\xf9 j\xa8P\x19.\x80\x02M\xe9\x03.' 8798 b
'\x80\x02}q\x01(U\x10protocol_versionq\x02M\xe9\x03U\n' 8799 b
'type_sizesq\x03}q\x04(U\x03intq\x05K\x04U\x05shortq\x06K\x02U' 8800 b
'\x04longq\x07K\x04uU\rlittle_endianq\x08\x88u.\x80\x02]q' 8801 b
'\x01(U\x0e\xc5\xbc\xc4\x85\xc4\x85\xc3\xb3\xc5\xbc\xc4\x85' 8802 b
'\xc5\xbcq\x02ctorch._utils\n_rebuild_tensor_v2\nq\x03((U' 8803 b
'\x07storageq\x04ctorch\nFloatStorage\nq\x05U\x0845640624q' 8804 b
'\x06U\x03cpuq\x07\x8a\x01\x01NtQK\x00K\x01\x85K\x01\x85' 8805 b
'\x89NtRq\x08K\x02e.\x80\x02]q\x01U\x0845640624q\x02a.\x01\x00' 8806 b
'\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00' 8808 buf = io.BytesIO(serialized)
8809 utf8_bytes = b
'\xc5\xbc\xc4\x85\xc4\x85\xc3\xb3\xc5\xbc\xc4\x85\xc5\xbc' 8810 utf8_str = utf8_bytes.decode(
'utf-8')
8812 with self.assertRaisesRegex(UnicodeDecodeError,
"'ascii' codec can't decode byte"):
8813 loaded = torch.load(buf)
8815 loaded_utf8 = torch.load(buf, encoding=
'utf-8')
8816 self.assertEqual(loaded_utf8, [utf8_str, torch.zeros(1, dtype=torch.float), 2])
8818 loaded_bytes = torch.load(buf, encoding=
'bytes')
8820 loaded_bytes = torch.load(buf)
8821 self.assertEqual(loaded_bytes, [utf8_bytes, torch.zeros(1, dtype=torch.float), 2])
8823 def test_serialization_filelike(self):
8825 b = self._test_serialization_data()
8826 with BytesIOContext()
as f:
8830 self._test_serialization_assert(b, c)
8832 def test_serialization_gzip(self):
8834 b = self._test_serialization_data()
8835 f1 = tempfile.NamedTemporaryFile(delete=
False)
8836 f2 = tempfile.NamedTemporaryFile(delete=
False)
8838 with open(f1.name,
'rb')
as f_in, gzip.open(f2.name,
'wb')
as f_out:
8839 shutil.copyfileobj(f_in, f_out)
8841 with gzip.open(f2.name,
'rb')
as f:
8843 self._test_serialization_assert(b, c)
8845 def test_serialization_offset(self):
8846 a = torch.randn(5, 5)
8848 for use_name
in (
False,
True):
8851 if sys.platform ==
"win32" and use_name:
8853 with tempfile.NamedTemporaryFile()
as f:
8854 handle = f
if not use_name
else f.name
8860 self.assertTrue(torch.equal(a, b))
8861 self.assertEqual(i, j)
8863 def test_serialization_offset_filelike(self):
8864 a = torch.randn(5, 5)
8866 with BytesIOContext()
as f:
8872 self.assertTrue(torch.equal(a, b))
8873 self.assertEqual(i, j)
8875 def test_serialization_offset_gzip(self):
8876 a = torch.randn(5, 5)
8878 f1 = tempfile.NamedTemporaryFile(delete=
False)
8879 f2 = tempfile.NamedTemporaryFile(delete=
False)
8880 with open(f1.name,
'wb')
as f:
8883 with open(f1.name,
'rb')
as f_in, gzip.open(f2.name,
'wb')
as f_out:
8884 shutil.copyfileobj(f_in, f_out)
8886 with gzip.open(f2.name,
'rb')
as f:
8889 self.assertTrue(torch.equal(a, b))
8890 self.assertEqual(i, j)
8892 def test_half_tensor(self):
8893 x = torch.randn(5, 5).float()
8894 y = torch.randn(5, 5).float()
8895 xh, yh = x.half(), y.half()
8897 self.assertEqual(x.half().float(), x, 1e-3)
8899 z = torch.Tensor(5, 5)
8900 self.assertEqual(z.copy_(xh), x, 1e-3)
8902 with tempfile.NamedTemporaryFile()
as f:
8906 self.assertEqual(xh.float(), xh2.float())
8908 def test_serialize_device(self):
8909 device_str = [
'cpu',
'cpu:0',
'cuda',
'cuda:0']
8910 device_obj = [torch.device(d)
for d
in device_str]
8911 for device
in device_obj:
8912 device_copied = copy.deepcopy(device)
8913 self.assertEqual(device, device_copied)
8916 def test_half_tensor_cuda(self):
8917 x = torch.randn(5, 5).half()
8918 self.assertEqual(x.cuda(), x)
8921 with tempfile.NamedTemporaryFile()
as f:
8925 self.assertIsInstance(xc2, type(xc))
8926 self.assertEqual(xc.float(), xc2.float())
8928 def _test_serialization_cuda(self, filecontext_lambda):
8930 t0 = torch.cuda.FloatTensor(5).fill_(1)
8932 tn = torch.cuda.FloatTensor(3).fill_(2)
8935 with filecontext_lambda()
as f:
8939 self.assertEqual(b, c, 0)
8941 self.assertEqual(u0.get_device(), 0)
8942 self.assertEqual(un.get_device(), device_count - 1)
8945 def test_serialization_cuda(self):
8946 self._test_serialization_cuda(tempfile.NamedTemporaryFile)
8949 def test_serialization_cuda_filelike(self):
8950 self._test_serialization_cuda(BytesIOContext)
8952 def test_serialization_backwards_compat(self):
8953 a = [torch.arange(1 + i, 26 + i).view(5, 5).float()
for i
in range(2)]
8954 b = [a[i % 2]
for i
in range(4)]
8956 b += [a[0].reshape(-1)[1:4].clone().
storage()]
8957 path = download_file(
'https://download.pytorch.org/test_data/legacy_serialized.pt')
8958 c = torch.load(path)
8959 self.assertEqual(b, c, 0)
8960 self.assertTrue(isinstance(c[0], torch.FloatTensor))
8961 self.assertTrue(isinstance(c[1], torch.FloatTensor))
8962 self.assertTrue(isinstance(c[2], torch.FloatTensor))
8963 self.assertTrue(isinstance(c[3], torch.FloatTensor))
8966 self.assertEqual(c[0], c[2], 0)
8969 self.assertEqual(c[1], c[3], 0)
8972 class OldTensorBase(object):
8973 def __init__(self, new_tensor):
8974 self.new_tensor = new_tensor
8976 def __getstate__(self):
8977 return (self.new_tensor.storage(),
8978 self.new_tensor.storage_offset(),
8979 tuple(self.new_tensor.size()),
8980 self.new_tensor.stride())
8982 class OldTensorV1(OldTensorBase):
8983 def __reduce__(self):
8984 return (torch.Tensor, (), self.__getstate__())
8986 class OldTensorV2(OldTensorBase):
8987 def __reduce__(self):
8988 return (_rebuild_tensor, self.__getstate__())
8990 x = torch.randn(30).as_strided([2, 3], [9, 3], 2)
8991 for old_cls
in [OldTensorV1, OldTensorV2]:
8992 with tempfile.NamedTemporaryFile()
as f:
8994 torch.save(old_x, f)
8996 load_x = torch.load(f)
8997 self.assertEqual(x.storage(), load_x.storage())
8998 self.assertEqual(x.storage_offset(), load_x.storage_offset())
8999 self.assertEqual(x.size(), load_x.size())
9000 self.assertEqual(x.stride(), load_x.stride())
9004 def _test_serialization_container(self, unique_key, filecontext_lambda):
9005 tmpmodule_name =
'tmpmodule{}'.format(unique_key)
9007 def import_module(name, filename):
9008 if sys.version_info >= (3, 5):
9009 import importlib.util
9010 spec = importlib.util.spec_from_file_location(name, filename)
9011 module = importlib.util.module_from_spec(spec)
9012 spec.loader.exec_module(module)
9015 module = imp.load_source(name, filename)
9016 sys.modules[module.__name__] = module
9019 with filecontext_lambda()
as checkpoint:
9020 fname = get_file_path_2(os.path.dirname(__file__),
'data',
'network1.py')
9021 module = import_module(tmpmodule_name, fname)
9022 torch.save(module.Net(), checkpoint)
9026 with warnings.catch_warnings(record=
True)
as w:
9027 loaded = torch.load(checkpoint)
9028 self.assertTrue(isinstance(loaded, module.Net))
9029 if can_retrieve_source:
9030 self.assertEquals(len(w), 0)
9033 fname = get_file_path_2(os.path.dirname(__file__),
'data',
'network2.py')
9034 module = import_module(tmpmodule_name, fname)
9036 with warnings.catch_warnings(record=
True)
as w:
9037 loaded = torch.load(checkpoint)
9038 self.assertTrue(isinstance(loaded, module.Net))
9039 if can_retrieve_source:
9040 self.assertEquals(len(w), 1)
9041 self.assertTrue(w[0].category,
'SourceChangeWarning')
9043 def test_serialization_container(self):
9044 self._test_serialization_container(
'file', tempfile.NamedTemporaryFile)
9046 def test_serialization_container_filelike(self):
9047 self._test_serialization_container(
'filelike', BytesIOContext)
9049 def test_serialization_map_location(self):
9050 test_file_path = download_file(
'https://download.pytorch.org/test_data/gpu_tensors.pt')
9052 def map_location(storage, loc):
9056 with open(test_file_path,
'rb')
as f:
9057 return io.BytesIO(f.read())
9059 fileobject_lambdas = [
lambda: test_file_path, load_bytes]
9060 cpu_map_locations = [
9064 torch.device(
'cpu'),
9066 gpu_0_map_locations = [
9067 {
'cuda:0':
'cuda:0'},
9070 torch.device(
'cuda'),
9071 torch.device(
'cuda', 0)
9073 gpu_last_map_locations = [
9077 def check_map_locations(map_locations, tensor_class, intended_device):
9078 for fileobject_lambda
in fileobject_lambdas:
9079 for map_location
in map_locations:
9080 tensor = torch.load(fileobject_lambda(), map_location=map_location)
9082 self.assertEqual(tensor.device, intended_device)
9083 self.assertIsInstance(tensor, tensor_class)
9084 self.assertEqual(tensor, tensor_class([[1.0, 2.0], [3.0, 4.0]]))
9086 check_map_locations(cpu_map_locations, torch.FloatTensor, torch.device(
'cpu'))
9088 check_map_locations(gpu_0_map_locations, torch.cuda.FloatTensor, torch.device(
'cuda', 0))
9089 check_map_locations(
9090 gpu_last_map_locations,
9091 torch.cuda.FloatTensor,
9096 @unittest.skipIf(
not PY3,
"Test tensors were serialized using python 3")
9097 def test_load_nonexistent_device(self):
9100 serialized = (b
'\x80\x02\x8a\nl\xfc\x9cF\xf9 j\xa8P\x19.\x80\x02M\xe9' 9101 b
'\x03.\x80\x02}q\x00(X\x10\x00\x00\x00protocol_versionq' 9102 b
'\x01M\xe9\x03X\r\x00\x00\x00little_endianq\x02\x88X\n' 9103 b
'\x00\x00\x00type_sizesq\x03}q\x04(X\x05\x00\x00\x00shortq' 9104 b
'\x05K\x02X\x03\x00\x00\x00intq\x06K\x04X\x04\x00\x00\x00' 9105 b
'longq\x07K\x04uu.\x80\x02ctorch._utils\n_rebuild_tensor_v2' 9106 b
'\nq\x00((X\x07\x00\x00\x00storageq\x01ctorch\nFloatStorage' 9107 b
'\nq\x02X\x0e\x00\x00\x0094919395964320q\x03X\x06\x00\x00' 9108 b
'\x00cuda:0q\x04K\x02Ntq\x05QK\x00K\x02\x85q\x06K\x01\x85q' 9109 b
'\x07\x89Ntq\x08Rq\t.\x80\x02]q\x00X\x0e\x00\x00\x00' 9110 b
'94919395964320q\x01a.\x02\x00\x00\x00\x00\x00\x00\x00\xbb' 9111 b
'\x1f\x82\xbe\xea\x81\xd1>')
9113 buf = io.BytesIO(serialized)
9115 error_msg =
r'Attempting to deserialize object on a CUDA device' 9116 with self.assertRaisesRegex(RuntimeError, error_msg):
9119 def test_serialization_filelike_api_requirements(self):
9120 filemock = FilelikeMock(b
'', has_readinto=
False)
9121 tensor = torch.randn(3, 5)
9122 torch.save(tensor, filemock)
9123 expected_superset = {
'write',
'flush'}
9124 self.assertTrue(expected_superset.issuperset(filemock.calls))
9128 filemock.calls.clear()
9130 _ = torch.load(filemock)
9131 expected_superset = {
'read',
'readline',
'seek',
'tell'}
9132 self.assertTrue(expected_superset.issuperset(filemock.calls))
9134 def _test_serialization_filelike(self, tensor, mock, desc):
9136 torch.save(tensor, f)
9138 data = mock(f.read())
9140 msg =
'filelike serialization with {}' 9142 b = torch.load(data)
9143 self.assertTrue(torch.equal(tensor, b), msg.format(desc))
9145 def test_serialization_filelike_missing_attrs(self):
9150 (
'no readinto',
lambda x: FilelikeMock(x)),
9151 (
'has readinto',
lambda x: FilelikeMock(x, has_readinto=
True)),
9152 (
'no fileno',
lambda x: FilelikeMock(x, has_fileno=
False)),
9155 to_serialize = torch.randn(3, 10)
9156 for desc, mock
in mocks:
9157 self._test_serialization_filelike(to_serialize, mock, desc)
9159 def test_serialization_filelike_stress(self):
9160 a = torch.randn(11 * (2 ** 9) + 1, 5 * (2 ** 9))
9163 self._test_serialization_filelike(a,
lambda x: FilelikeMock(x, has_readinto=
False),
9164 'read() stress test')
9165 self._test_serialization_filelike(a,
lambda x: FilelikeMock(x, has_readinto=
True),
9166 'readinto() stress test')
9168 def test_serialization_filelike_uses_readinto(self):
9171 a = torch.randn(5, 4)
9176 data = FilelikeMock(f.read(), has_readinto=
True)
9178 b = torch.load(data)
9179 self.assertTrue(data.was_called(
'readinto'))
9181 def test_serialization_storage_slice(self):
9190 serialized = (b
'\x80\x02\x8a\nl\xfc\x9cF\xf9 j\xa8P\x19.\x80\x02M\xe9\x03' 9191 b
'.\x80\x02}q\x00(X\n\x00\x00\x00type_sizesq\x01}q\x02(X\x03' 9192 b
'\x00\x00\x00intq\x03K\x04X\x05\x00\x00\x00shortq\x04K\x02X' 9193 b
'\x04\x00\x00\x00longq\x05K\x04uX\x10\x00\x00\x00protocol_versionq' 9194 b
'\x06M\xe9\x03X\r\x00\x00\x00little_endianq\x07\x88u.\x80\x02' 9195 b
'(X\x07\x00\x00\x00storageq\x00ctorch\nFloatStorage\nq\x01X\x0e' 9196 b
'\x00\x00\x0094279043900432q\x02X\x03\x00\x00\x00cpuq\x03K\x02' 9197 b
'X\x0e\x00\x00\x0094279029750368q\x04K\x00K\x01\x87q\x05tq\x06' 9198 b
'Q(h\x00h\x01X\x0e\x00\x00\x0094279043900432q\x07h\x03K\x02X' 9199 b
'\x0e\x00\x00\x0094279029750432q\x08K\x01K\x01\x87q\ttq\nQ' 9200 b
'\x86q\x0b.\x80\x02]q\x00X\x0e\x00\x00\x0094279043900432q' 9201 b
'\x01a.\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00' 9202 b
'\x00\x00\x00\x00')
9204 buf = io.BytesIO(serialized)
9205 (s1, s2) = torch.load(buf)
9206 self.assertEqual(s1[0], 0)
9207 self.assertEqual(s2[0], 0)
9208 self.assertEqual(s1.data_ptr() + 4, s2.data_ptr())
9210 def test_load_error_msg(self):
9211 expected_err_msg = (
".*You can only torch.load from a file that is seekable. " +
9212 "Please pre-load the data into a buffer like io.BytesIO and " +
9213 "try to load from it instead.")
9215 resource = FilelikeMock(data=b
"data")
9216 delattr(resource,
"tell")
9217 delattr(resource,
"seek")
9218 self.assertRaisesRegex(AttributeError, expected_err_msg,
lambda: torch.load(resource))
9220 def test_from_buffer(self):
9221 a = bytearray([1, 2, 3, 4])
9222 self.assertEqual(torch.ByteStorage.from_buffer(a).tolist(), [1, 2, 3, 4])
9223 shorts = torch.ShortStorage.from_buffer(a,
'big')
9224 self.assertEqual(shorts.size(), 2)
9225 self.assertEqual(shorts.tolist(), [258, 772])
9226 ints = torch.IntStorage.from_buffer(a,
'little')
9227 self.assertEqual(ints.size(), 1)
9228 self.assertEqual(ints[0], 67305985)
9229 f = bytearray([0x40, 0x10, 0x00, 0x00])
9230 floats = torch.FloatStorage.from_buffer(f,
'big')
9231 self.assertEqual(floats.size(), 1)
9232 self.assertEqual(floats[0], 2.25)
9234 f = bytearray([0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x10, 0x40])
9235 bools = torch.BoolStorage.from_buffer(f,
'big')
9236 self.assertEqual(bools.size(), 8)
9237 self.assertEqual(bools.tolist(), [
False,
True,
True,
True,
True,
True,
True,
True])
9238 self.assertEqual(bools.type(),
'torch.BoolStorage')
9240 f = bytearray(b
'\x80\x02\x8a\nl\xfc\x9cF\xf9 j\xa8P\x19.\x80\x02M\xe9')
9241 bools = torch.BoolStorage.from_buffer(f,
'big')
9242 self.assertEqual(bools.size(), 19)
9244 f = bytearray(b
'\0x4A')
9245 bools = torch.BoolStorage.from_buffer(f,
'big')
9246 self.assertEqual(bools.size(), 4)
9247 self.assertEqual(bools.tolist(), [
False,
True,
True,
True])
9249 def test_storage_casts(self):
9251 self.assertEqual(storage.size(), 6)
9252 self.assertEqual(storage.tolist(), [-1, 0, 1, 2, 3, 4])
9253 self.assertEqual(storage.type(),
'torch.IntStorage')
9255 floatStorage = storage.float()
9256 self.assertEqual(floatStorage.size(), 6)
9257 self.assertEqual(floatStorage.tolist(), [-1, 0, 1, 2, 3, 4])
9258 self.assertEqual(floatStorage.type(),
'torch.FloatStorage')
9259 self.assertEqual(floatStorage.int().tolist(), [-1, 0, 1, 2, 3, 4])
9261 halfStorage = storage.half()
9262 self.assertEqual(halfStorage.size(), 6)
9263 self.assertEqual(halfStorage.tolist(), [-1, 0, 1, 2, 3, 4])
9264 self.assertEqual(halfStorage.type(),
'torch.HalfStorage')
9265 self.assertEqual(halfStorage.int().tolist(), [-1, 0, 1, 2, 3, 4])
9267 longStorage = storage.long()
9268 self.assertEqual(longStorage.size(), 6)
9269 self.assertEqual(longStorage.tolist(), [-1, 0, 1, 2, 3, 4])
9270 self.assertEqual(longStorage.type(),
'torch.LongStorage')
9271 self.assertEqual(longStorage.int().tolist(), [-1, 0, 1, 2, 3, 4])
9273 shortStorage = storage.short()
9274 self.assertEqual(shortStorage.size(), 6)
9275 self.assertEqual(shortStorage.tolist(), [-1, 0, 1, 2, 3, 4])
9276 self.assertEqual(shortStorage.type(),
'torch.ShortStorage')
9277 self.assertEqual(shortStorage.int().tolist(), [-1, 0, 1, 2, 3, 4])
9279 doubleStorage = storage.double()
9280 self.assertEqual(doubleStorage.size(), 6)
9281 self.assertEqual(doubleStorage.tolist(), [-1.0, 0.0, 1.0, 2.0, 3.0, 4.0])
9282 self.assertEqual(doubleStorage.type(),
'torch.DoubleStorage')
9283 self.assertEqual(doubleStorage.int().tolist(), [-1, 0, 1, 2, 3, 4])
9285 charStorage = storage.char()
9286 self.assertEqual(charStorage.size(), 6)
9287 self.assertEqual(charStorage.tolist(), [-1.0, 0.0, 1.0, 2.0, 3.0, 4.0])
9288 self.assertEqual(charStorage.type(),
'torch.CharStorage')
9289 self.assertEqual(charStorage.int().tolist(), [-1, 0, 1, 2, 3, 4])
9291 byteStorage = storage.byte()
9292 self.assertEqual(byteStorage.size(), 6)
9293 self.assertEqual(byteStorage.tolist(), [255, 0, 1, 2, 3, 4])
9294 self.assertEqual(byteStorage.type(),
'torch.ByteStorage')
9295 self.assertEqual(byteStorage.int().tolist(), [255, 0, 1, 2, 3, 4])
9297 boolStorage = storage.bool()
9298 self.assertEqual(boolStorage.size(), 6)
9299 self.assertEqual(boolStorage.tolist(), [
True,
False,
True,
True,
True,
True])
9300 self.assertEqual(boolStorage.type(),
'torch.BoolStorage')
9301 self.assertEqual(boolStorage.int().tolist(), [1, 0, 1, 1, 1, 1])
9303 @unittest.skipIf(IS_WINDOWS,
"TODO: need to fix this test case for Windows")
9304 def test_from_file(self):
9306 with tempfile.NamedTemporaryFile()
as f:
9307 s1 = torch.FloatStorage.from_file(f.name,
True, size)
9308 t1 = torch.FloatTensor(s1).copy_(torch.randn(size))
9311 s2 = torch.FloatStorage.from_file(f.name,
True, size)
9312 t2 = torch.FloatTensor(s2)
9313 self.assertEqual(t1, t2, 0)
9316 rnum = random.uniform(-1, 1)
9318 self.assertEqual(t1, t2, 0)
9321 rnum = random.uniform(-1, 1)
9323 self.assertEqual(t1, t2, 0)
9325 def test_print(self):
9326 default_type = torch.Tensor().type()
9327 for t
in torch._tensor_classes:
9328 if t == torch.HalfTensor:
9334 obj = t(100, 100).fill_(1)
9338 obj = torch.rand(100, 100, device=
'cpu').half()
9341 for t
in torch._storage_classes:
9345 obj = t(100).fill_(
True)
9347 obj = t(100).fill_(1)
9353 self.assertEqual(x.__repr__(), str(x))
9354 self.assertExpectedInline(str(x),
'''tensor(2341234123412341)''')
9358 self.assertEqual(x.__repr__(), str(x))
9359 self.assertExpectedInline(str(x),
'''tensor([1.0000e+28, 1.0000e-28])''')
9363 torch.set_printoptions(sci_mode=
True)
9364 self.assertEqual(x.__repr__(), str(x))
9365 self.assertExpectedInline(str(x),
'''tensor([1.0000e+02, 1.0000e-02])''')
9366 torch.set_printoptions(sci_mode=
False)
9367 self.assertEqual(x.__repr__(), str(x))
9368 self.assertExpectedInline(str(x),
'''tensor([ 100.0000, 0.0100])''')
9369 torch.set_printoptions(sci_mode=
None)
9373 self.assertEqual(x.__repr__(), str(x))
9374 self.assertExpectedInline(str(x),
'''tensor([1, 2])''')
9378 self.assertEqual(x.__repr__(), str(x))
9379 self.assertExpectedInline(str(x),
'''tensor([ 1, -2])''')
9383 self.assertEqual(x.__repr__(), str(x))
9384 self.assertExpectedInline(str(x),
'''tensor([4.0000, inf, 1.5000, -inf, 0.0000, nan, 1.0000])''')
9388 x =
torch.tensor([1e-324, 1e-323, 1e-322, 1e307, 1e308, 1e309], dtype=torch.float64)
9389 self.assertEqual(x.__repr__(), str(x))
9391 tensor([ 0.0000e+00, 9.8813e-324, 9.8813e-323, 1.0000e+307, 1.0000e+308, 9392 inf], dtype=torch.float64)''' 9393 self.assertExpectedInline(str(x), expected_str)
9397 self.assertEqual(x.__repr__(), str(x))
9399 tensor([ 0.0000e+00, 9.8813e-324, 9.8813e-323, 1.0000e+307, 1.0000e+308, 9401 self.assertExpectedInline(str(x), expected_str)
9404 x = torch.zeros(10000)
9405 self.assertEqual(x.__repr__(), str(x))
9406 self.assertExpectedInline(str(x),
'''tensor([0., 0., 0., ..., 0., 0., 0.])''')
9409 x = torch.rand(1, 20, 5, 30)
9411 self.assertEqual(summary.shape, (1, 6, 5, 6))
9412 first_and_last = [0, 1, 2, -3, -2, -1]
9413 self.assertEqual(summary, x[:, first_and_last][..., first_and_last])
9418 self.assertEqual(x.__repr__(), str(x))
9419 self.assertExpectedInline(str(x),
'''tensor([123], device='cuda:0')''')
9423 self.assertEqual(x.__repr__(), str(x))
9424 self.assertExpectedInline(str(x),
'''tensor([123])''')
9429 self.assertEqual(x.__repr__(), str(x))
9430 self.assertExpectedInline(str(x),
'''tensor([123.], requires_grad=True)''')
9434 x = torch.ones(100, 2, 2, 10)
9435 y = x.as_strided(size=(100, 2, 10), stride=(2 * 2 * 10, 2 * 10, 1))
9436 self.assertEqual(str(y), y.__repr__())
9438 tensor([[[1., 1., 1., ..., 1., 1., 1.], 9439 [1., 1., 1., ..., 1., 1., 1.]], 9441 [[1., 1., 1., ..., 1., 1., 1.], 9442 [1., 1., 1., ..., 1., 1., 1.]], 9444 [[1., 1., 1., ..., 1., 1., 1.], 9445 [1., 1., 1., ..., 1., 1., 1.]], 9449 [[1., 1., 1., ..., 1., 1., 1.], 9450 [1., 1., 1., ..., 1., 1., 1.]], 9452 [[1., 1., 1., ..., 1., 1., 1.], 9453 [1., 1., 1., ..., 1., 1., 1.]], 9455 [[1., 1., 1., ..., 1., 1., 1.], 9456 [1., 1., 1., ..., 1., 1., 1.]]])\ 9459 self.assertExpectedInline(str(y), expected_str)
9463 self.assertEqual(x.__repr__(), str(x))
9464 self.assertExpectedInline(str(x),
'''tensor(2.0000e-05)''')
9468 self.assertEqual(x.__repr__(), str(x))
9469 self.assertExpectedInline(str(x),
'''tensor([2.0000e-05])''')
9475 self.assertEqual(x.__repr__(), str(x))
9476 self.assertExpectedInline(str(x),
'''tensor([1.2346e+08])''')
9480 self.assertEqual(x.__repr__(), str(x))
9481 self.assertExpectedInline(str(x),
'''tensor([1.0000e-02, 1.1000e+01])''')
9485 self.assertEqual(x.__repr__(), str(x))
9486 self.assertExpectedInline(str(x),
'''tensor([ 1, 1010])''')
9490 self.assertEqual(x.__repr__(), str(x))
9491 self.assertExpectedInline(str(x),
'''tensor([1000000000])''')
9495 self.assertEqual(x.__repr__(), str(x))
9496 self.assertExpectedInline(str(x),
'''tensor([ 1., 1000.])''')
9500 self.assertEqual(x.__repr__(), str(x))
9501 self.assertExpectedInline(str(x),
'''tensor([1.0000e+00, 1.0100e+03])''')
9503 def test_sizeof(self):
9504 sizeof_empty = torch.randn(0).
storage().__sizeof__()
9505 sizeof_10 = torch.randn(10).
storage().__sizeof__()
9506 sizeof_100 = torch.randn(100).
storage().__sizeof__()
9507 self.assertEqual((sizeof_100 - sizeof_empty) // (sizeof_10 - sizeof_empty), 10)
9508 self.assertEqual((sizeof_100 - sizeof_empty) % (sizeof_10 - sizeof_empty), 0)
9510 sizeof_empty = torch.randn(0).type(torch.ByteTensor).
storage().__sizeof__()
9511 sizeof_10 = torch.randn(10).type(torch.ByteTensor).
storage().__sizeof__()
9512 sizeof_100 = torch.randn(100).type(torch.ByteTensor).
storage().__sizeof__()
9513 self.assertEqual((sizeof_100 - sizeof_empty) // (sizeof_10 - sizeof_empty), 10)
9514 self.assertEqual((sizeof_100 - sizeof_empty) % (sizeof_10 - sizeof_empty), 0)
9516 def test_unsqueeze(self):
9517 x = torch.randn(2, 3, 4)
9519 self.assertEqual(y, x.view(2, 1, 3, 4))
9520 y = x.clone().unsqueeze_(2)
9521 self.assertEqual(y, x.view(2, 3, 1, 4))
9524 self.assertFalse(x.is_contiguous())
9526 self.assertEqual(y, x.contiguous().view(2, 1, 4))
9527 y = x.clone().unsqueeze_(2)
9528 self.assertEqual(y, x.contiguous().view(2, 4, 1))
9530 def test_iter(self):
9531 x = torch.randn(5, 5)
9532 for i, sub
in enumerate(x):
9533 self.assertEqual(sub, x[i])
9536 self.assertEqual(list(x), [])
9538 def test_accreal_type(self):
9539 x = torch.ones(2, 3, 4)
9540 self.assertIsInstance(x.double().sum().item(), float)
9541 self.assertIsInstance(x.float().sum().item(), float)
9542 self.assertIsInstance(x.long().sum().item(), int)
9543 self.assertIsInstance(x.int().sum().item(), int)
9544 self.assertIsInstance(x.short().sum().item(), int)
9545 self.assertIsInstance(x.char().sum().item(), int)
9546 self.assertIsInstance(x.byte().sum().item(), int)
9548 def test_assertEqual(self):
9549 x = torch.FloatTensor([0])
9550 self.assertEqual(x, 0)
9551 xv = torch.autograd.Variable(x)
9552 self.assertEqual(xv, 0)
9553 self.assertEqual(x, xv)
9554 self.assertEqual(xv, x)
9557 x = torch.autograd.Variable(torch.Tensor())
9558 y = torch.autograd.Variable(torch.randn(4, 4))
9559 z = torch.autograd.Variable(torch.IntTensor([1, 2, 3]))
9560 self.assertEqual(x.new().shape, [0])
9561 self.assertEqual(x.new(), x)
9562 self.assertEqual(x.new(1, 2).shape, [1, 2])
9563 self.assertEqual(x.new(torch.Size([3, 4])).shape, [3, 4])
9564 self.assertEqual(x.new([3, 4]).shape, [2])
9565 self.assertEqual(x.new([3, 4]).tolist(), [3, 4])
9566 self.assertEqual(x.new((3, 4)).tolist(), [3, 4])
9568 self.assertEqual(x.new([np.int32(3), np.float64(4)]).tolist(), [3, 4])
9569 self.assertEqual(x.new(np.array((3, 4))).tolist(), [3, 4])
9570 self.assertEqual(x.new([z[2], z[0] + 3]).tolist(), [3, 4])
9571 self.assertEqual(x.new(size=(3, 4)).shape, [3, 4])
9572 self.assertEqual(x.new(()).shape, [0])
9573 self.assertEqual(x.new(y.storage()).data_ptr(), y.data_ptr())
9574 self.assertEqual(x.new(y).data_ptr(), y.data_ptr())
9575 self.assertIsNot(x.new(y), y)
9577 self.assertRaises(TypeError,
lambda: x.new(z))
9579 self.assertRaises(RuntimeError,
lambda: x.new(z.storage()))
9581 def test_empty_like(self):
9582 x = torch.autograd.Variable(torch.Tensor())
9583 y = torch.autograd.Variable(torch.randn(4, 4))
9584 z = torch.autograd.Variable(torch.IntTensor([1, 2, 3]))
9586 self.assertEqual(torch.empty_like(a).shape, a.shape)
9587 self.assertEqual(torch.empty_like(a).type(), a.type())
9589 def test_empty_strided(self):
9591 for device
in devices:
9592 for shape
in [(2, 3, 4), (0, 2, 0)]:
9595 for strides
in [(12, 4, 1), (2, 4, 6), (0, 0, 0)]:
9596 empty_strided = torch.empty_strided(shape, strides, device=device)
9600 as_strided = torch.empty(empty_strided.storage().size(),
9601 device=device).as_strided(shape, strides)
9602 self.assertEqual(empty_strided.shape, as_strided.shape)
9603 self.assertEqual(empty_strided.stride(), as_strided.stride())
9606 def test_pin_memory(self):
9607 x = torch.randn(3, 5)
9608 self.assertFalse(x.is_pinned())
9609 pinned = x.pin_memory()
9610 self.assertTrue(pinned.is_pinned())
9611 self.assertEqual(pinned, x)
9612 self.assertNotEqual(pinned.data_ptr(), x.data_ptr())
9614 @unittest.skipIf(
not TEST_NUMPY,
"Numpy not found")
9615 def test_numpy_unresizable(self):
9616 x = np.zeros((2, 2))
9617 y = torch.from_numpy(x)
9618 with self.assertRaises(ValueError):
9621 z = torch.randn(5, 5)
9623 with self.assertRaises(RuntimeError):
9625 with self.assertRaises(ValueError):
9628 @unittest.skipIf(
not TEST_NUMPY,
"Numpy not found")
9629 def test_to_numpy(self):
9630 def get_castable_tensor(shape, tp):
9632 if dtype.is_floating_point:
9633 dtype_info = torch.finfo(dtype)
9636 low = max(dtype_info.min, -1e10)
9637 high = min(dtype_info.max, 1e10)
9638 t = torch.empty(shape, dtype=torch.float64).uniform_(low, high)
9642 dtype_info = torch.iinfo(dtype)
9643 low = max(dtype_info.min, int(-1e10))
9644 high = min(dtype_info.max, int(1e10))
9645 dtype_info = torch.iinfo(dtype)
9646 t = torch.empty(shape, dtype=torch.int64).random_(low, high)
9662 x = get_castable_tensor(sz, tp)
9665 self.assertEqual(x[i], y[i])
9668 xm = get_castable_tensor(sz * 2, tp)
9669 x = xm.narrow(0, sz - 1, sz)
9670 self.assertTrue(x.storage_offset() > 0)
9673 self.assertEqual(x[i], y[i])
9676 for i
in range(sz1):
9677 for j
in range(sz2):
9678 self.assertEqual(x[i][j], y[i][j])
9681 x = torch.Tensor().type(tp)
9683 self.assertEqual(y.size, 0)
9688 x = get_castable_tensor((sz1, sz2), tp)
9691 self.assertTrue(y.flags[
'C_CONTIGUOUS'])
9694 xm = get_castable_tensor((sz1 * 2, sz2), tp)
9695 x = xm.narrow(0, sz1 - 1, sz1)
9697 self.assertTrue(x.storage_offset() > 0)
9699 self.assertTrue(y.flags[
'C_CONTIGUOUS'])
9702 x = get_castable_tensor((sz2, sz1), tp).t()
9705 self.assertFalse(y.flags[
'C_CONTIGUOUS'])
9708 xm = get_castable_tensor((sz2 * 2, sz1), tp)
9709 x = xm.narrow(0, sz2 - 1, sz2).t()
9711 self.assertTrue(x.storage_offset() > 0)
9715 xm = get_castable_tensor((sz2 * 2, sz1 * 2), tp)
9716 x = xm.narrow(0, sz2 - 1, sz2).narrow(1, sz1 - 1, sz1).t()
9718 self.assertTrue(x.storage_offset() > 0)
9721 if tp != torch.HalfTensor:
9723 x = get_castable_tensor((3, 4), tp)
9725 self.assertTrue(y.flags.writeable)
9727 self.assertTrue(x[0][1] == 3)
9729 self.assertTrue(y.flags.writeable)
9731 self.assertTrue(x[0][1] == 3)
9733 def test_dlpack_conversion(self):
9734 x = torch.randn(1, 2, 3, 4).type(
'torch.FloatTensor')
9735 z = from_dlpack(to_dlpack(x))
9736 self.assertEqual(z, x)
9739 def test_dlpack_cuda(self):
9740 x = torch.randn(1, 2, 3, 4).cuda()
9741 z = from_dlpack(to_dlpack(x))
9742 self.assertEqual(z, x)
9744 @unittest.skipIf(
not TEST_NUMPY,
"Numpy not found")
9745 def test_from_numpy(self):
9757 for dtype
in dtypes:
9758 array = np.array([1, 2, 3, 4], dtype=dtype)
9759 tensor_from_array = torch.from_numpy(array)
9762 for i
in range(len(array)):
9763 self.assertEqual(tensor_from_array[i], array[i])
9766 x = np.linspace(1, 125, 125)
9769 expected = torch.arange(1, 126).view(5, 5, 5)[1]
9770 self.assertEqual(torch.from_numpy(x), expected)
9773 x = np.linspace(1, 25, 25)
9775 expected = torch.arange(1, 26).view(5, 5).t()
9776 self.assertEqual(torch.from_numpy(x.T), expected)
9779 x = np.linspace(1, 125, 125)
9782 expected = torch.arange(1, 126).view(5, 5, 5)[:, 1]
9783 self.assertEqual(torch.from_numpy(x), expected)
9786 x = np.zeros((0, 2))
9787 self.assertEqual(torch.from_numpy(x).shape, (0, 2))
9788 x = np.zeros((2, 0))
9789 self.assertEqual(torch.from_numpy(x).shape, (2, 0))
9792 x = np.array([3., 5., 8.])
9794 self.assertRaises(ValueError,
lambda: torch.from_numpy(x))
9796 @unittest.skipIf(
not TEST_NUMPY,
"Numpy not found")
9797 def test_ctor_with_numpy_array(self):
9809 incorrect_byteorder =
'>' if sys.byteorder ==
'little' else '<' 9810 incorrect_dtypes = map(
lambda t: incorrect_byteorder + t, [
'd',
'f'])
9812 for dtype
in correct_dtypes:
9813 array = np.array([1, 2, 3, 4], dtype=dtype)
9816 tensor = torch.DoubleTensor(array)
9817 for i
in range(len(array)):
9818 self.assertEqual(tensor[i], array[i])
9821 tensor = torch.cuda.DoubleTensor(array)
9822 for i
in range(len(array)):
9823 self.assertEqual(tensor[i], array[i])
9826 tensor = torch.FloatTensor(array)
9827 for i
in range(len(array)):
9828 self.assertEqual(tensor[i], array[i])
9830 tensor = torch.HalfTensor(array)
9831 for i
in range(len(array)):
9832 self.assertEqual(tensor[i], array[i])
9835 tensor = torch.cuda.FloatTensor(array)
9836 for i
in range(len(array)):
9837 self.assertEqual(tensor[i], array[i])
9839 tensor = torch.cuda.HalfTensor(array)
9840 for i
in range(len(array)):
9841 self.assertEqual(tensor[i], array[i])
9843 @unittest.skipIf(
not TEST_NUMPY,
"Numpy not found")
9844 def test_ctor_with_numpy_scalar_ctor(self):
9854 for dtype
in dtypes:
9855 self.assertEqual(dtype(42),
torch.tensor(dtype(42)).item())
9857 @unittest.skipIf(
not TEST_NUMPY,
"Numpy not found")
9858 def test_numpy_index(self):
9859 i = np.int32([0, 1, 2])
9860 x = torch.randn(5, 5)
9862 self.assertFalse(isinstance(idx, int))
9863 self.assertEqual(x[idx], x[int(idx)])
9865 @unittest.skipIf(
not TEST_NUMPY,
"Numpy not found")
9866 def test_numpy_array_interface(self):
9885 for tp, dtype
in zip(types, dtypes):
9886 if np.dtype(dtype).kind ==
'u': 9887 x = torch.Tensor([1, 2, 3, 4]).type(tp) 9888 array = np.array([1, 2, 3, 4], dtype=dtype) 9890 x = torch.Tensor([1, -2, 3, -4]).type(tp)
9891 array = np.array([1, -2, 3, -4], dtype=dtype)
9894 asarray = np.asarray(x)
9895 self.assertIsInstance(asarray, np.ndarray)
9896 self.assertEqual(asarray.dtype, dtype)
9897 for i
in range(len(x)):
9898 self.assertEqual(asarray[i], x[i])
9902 abs_array = np.abs(array)
9903 self.assertIsInstance(abs_x, tp)
9904 for i
in range(len(x)):
9905 self.assertEqual(abs_x[i], abs_array[i])
9908 for dtype
in dtypes:
9909 x = torch.IntTensor([1, -2, 3, -4])
9910 asarray = np.asarray(x, dtype=dtype)
9911 self.assertEqual(asarray.dtype, dtype)
9912 if np.dtype(dtype).kind ==
'u': 9913 wrapped_x = np.array([1, -2, 3, -4], dtype=dtype) 9914 for i
in range(len(x)):
9915 self.assertEqual(asarray[i], wrapped_x[i])
9917 for i
in range(len(x)):
9918 self.assertEqual(asarray[i], x[i])
9921 float_types = [torch.DoubleTensor, torch.FloatTensor]
9922 float_dtypes = [np.float64, np.float32]
9923 for tp, dtype
in zip(float_types, float_dtypes):
9924 x = torch.Tensor([1, 2, 3, 4]).type(tp)
9925 array = np.array([1, 2, 3, 4], dtype=dtype)
9926 for func
in [
'sin',
'sqrt',
'ceil']:
9927 ufunc = getattr(np, func)
9929 res_array = ufunc(array)
9930 self.assertIsInstance(res_x, tp)
9931 for i
in range(len(x)):
9932 self.assertEqual(res_x[i], res_array[i])
9935 for tp, dtype
in zip(types, dtypes):
9936 x = torch.Tensor([1, 2, 3, 4]).type(tp)
9937 array = np.array([1, 2, 3, 4], dtype=dtype)
9938 geq2_x = np.greater_equal(x, 2)
9939 geq2_array = np.greater_equal(array, 2).astype(
'uint8')
9940 self.assertIsInstance(geq2_x, torch.ByteTensor)
9941 for i
in range(len(x)):
9942 self.assertEqual(geq2_x[i], geq2_array[i])
9944 @unittest.skipIf(
not TEST_NUMPY,
"Numpy not found")
9945 def test_multiplication_numpy_scalar(self):
9946 for np_dtype
in [np.float32, np.float64, np.int32, np.int64, np.int16, np.uint8]:
9947 for t_dtype
in [torch.float, torch.double]:
9948 np_sc = np_dtype(2.0)
9949 t = torch.ones(2, requires_grad=
True, dtype=t_dtype)
9951 self.assertIsInstance(r1, torch.Tensor)
9952 self.assertTrue(r1.dtype == t_dtype)
9953 self.assertTrue(r1.requires_grad)
9955 self.assertIsInstance(r2, torch.Tensor)
9956 self.assertTrue(r2.dtype == t_dtype)
9957 self.assertTrue(r2.requires_grad)
9959 def test_error_msg_type_translation(self):
9960 with self.assertRaisesRegex(
9963 '(?=.*Double)(?=.*Long)'):
9966 input = torch.autograd.Variable(torch.randn(1, 1, 1, 6).double())
9967 weight = torch.zeros(1, 1, 1, 3).long()
9968 model = torch.nn.Conv2d(1, 1, (1, 3), stride=1, padding=0, bias=
False)
9969 model.weight.data = weight
9972 def test_tensor_from_sequence(self):
9973 class MockSequence(object):
9974 def __init__(self, lst):
9978 return len(self.lst)
9980 def __getitem__(self, item):
9983 class GoodMockSequence(MockSequence):
9984 def __getitem__(self, item):
9985 return self.lst[item]
9987 bad_mock_seq = MockSequence([1.0, 2.0, 3.0])
9988 good_mock_seq = GoodMockSequence([1.0, 2.0, 3.0])
9989 with self.assertRaisesRegex(ValueError,
'could not determine the shape'):
9990 torch.Tensor(bad_mock_seq)
9991 self.assertEqual(torch.Tensor([1.0, 2.0, 3.0]), torch.Tensor(good_mock_seq))
9993 def test_comparison_ops(self):
9994 x = torch.randn(5, 5)
9995 y = torch.randn(5, 5)
9998 for idx
in iter_indices(x):
9999 self.assertEqual(x[idx] == y[idx], eq[idx] == 1)
10002 for idx
in iter_indices(x):
10003 self.assertEqual(x[idx] != y[idx], ne[idx] == 1)
10006 for idx
in iter_indices(x):
10007 self.assertEqual(x[idx] < y[idx], lt[idx] == 1)
10010 for idx
in iter_indices(x):
10011 self.assertEqual(x[idx] <= y[idx], le[idx] == 1)
10014 for idx
in iter_indices(x):
10015 self.assertEqual(x[idx] > y[idx], gt[idx] == 1)
10018 for idx
in iter_indices(x):
10019 self.assertEqual(x[idx] >= y[idx], ge[idx] == 1)
10021 def test_bitwise_ops(self):
10022 x = torch.randn(5, 5).gt(0)
10023 y = torch.randn(5, 5).gt(0)
10026 for idx
in iter_indices(x):
10027 if and_result[idx]:
10028 self.assertTrue(x[idx]
and y[idx])
10030 self.assertFalse(x[idx]
and y[idx])
10033 for idx
in iter_indices(x):
10035 self.assertTrue(x[idx]
or y[idx])
10037 self.assertFalse(x[idx]
or y[idx])
10040 for idx
in iter_indices(x):
10041 if xor_result[idx]:
10042 self.assertTrue(x[idx] ^ y[idx])
10044 self.assertFalse(x[idx] ^ y[idx])
10047 for idx
in iter_indices(x):
10048 self.assertEqual(1 - x[idx], invert_result[idx])
10050 x_clone = x.clone()
10052 self.assertEqual(x_clone, and_result)
10054 x_clone = x.clone()
10056 self.assertEqual(x_clone, or_result)
10058 x_clone = x.clone()
10060 self.assertEqual(x_clone, xor_result)
10062 def test_invert(self):
10063 x = torch.ByteTensor([0, 1, 1])
10064 self.assertEqual((~x).tolist(), [1, 0, 0])
10066 def test_apply(self):
10067 x = torch.arange(1, 6)
10068 res = x.clone().apply_(
lambda k: k + k)
10069 self.assertEqual(res, x * 2)
10070 self.assertRaises(TypeError,
lambda: x.apply_(
lambda k:
"str"))
10072 def test_map(self):
10073 x = torch.autograd.Variable(torch.randn(3, 3))
10074 y = torch.autograd.Variable(torch.randn(3))
10076 res.map_(y,
lambda a, b: a + b)
10077 self.assertEqual(res, x + y)
10078 self.assertRaisesRegex(TypeError,
"not callable",
lambda: res.map_(y,
"str"))
10080 def test_map2(self):
10081 x = torch.autograd.Variable(torch.randn(3, 3))
10082 y = torch.autograd.Variable(torch.randn(3))
10083 z = torch.autograd.Variable(torch.randn(1, 3))
10085 res.map2_(y, z,
lambda a, b, c: a + b * c)
10086 self.assertEqual(res, x + y * z)
10087 z.requires_grad =
True 10088 self.assertRaisesRegex(
10089 RuntimeError,
"requires grad",
10090 lambda: res.map2_(y, z,
lambda a, b, c: a + b * c))
10092 def test_Size(self):
10093 x = torch.Size([1, 2, 3])
10094 self.assertIsInstance(x, tuple)
10095 self.assertEqual(x[0], 1)
10096 self.assertEqual(x[1], 2)
10097 self.assertEqual(x[2], 3)
10098 self.assertEqual(len(x), 3)
10099 self.assertRaises(TypeError,
lambda: torch.Size(torch.ones(3)))
10101 self.assertIsInstance(x * 2, torch.Size)
10102 self.assertIsInstance(x[:-1], torch.Size)
10103 self.assertIsInstance(x + x, torch.Size)
10105 def test_Size_scalar(self):
10108 x = torch.Size([0, 1, two, three, 4])
10109 for i
in range(1, 5):
10110 self.assertEqual(x[i], i)
10112 def test_Size_iter(self):
10113 for sizes
in [iter([1, 2, 3, 4, 5]), range(1, 6)]:
10114 x = torch.Size(sizes)
10115 for i
in range(0, 5):
10116 self.assertEqual(x[i], i + 1)
10118 def test_t_not_2d_error(self):
10119 self.assertRaises(RuntimeError,
lambda: torch.randn(2, 3, 4).t())
10120 self.assertRaises(RuntimeError,
lambda: torch.randn(2, 3, 4).t_())
10123 @unittest.skipIf(
not TEST_NUMPY,
"Numpy not found")
10124 def test_big_transpose(self):
10125 t = torch.rand(456, 789)
10126 t1 = t.t().contiguous()
10127 t2 = torch.from_numpy(t.numpy().transpose())
10128 self.assertEqual(t1, t2)
10130 def test_inplace_division(self):
10131 t = torch.rand(5, 5)
10135 self.assertEqual(id_before, id_after)
10137 def test_simple_scalar_cast(self):
10138 ok = [torch.Tensor([1.5]), torch.zeros(1, 1, 1, 1)]
10139 ok_values = [1.5, 0]
10141 not_ok = map(torch.Tensor, [[], [1, 2], [[1, 2], [3, 4]]])
10143 for tensor, value
in zip(ok, ok_values):
10144 self.assertEqual(int(tensor), int(value))
10145 self.assertEqual(float(tensor), float(value))
10146 if sys.version_info[0] < 3:
10147 self.assertEqual(long(tensor), long(value))
10149 for tensor
in not_ok:
10150 self.assertRaises(ValueError,
lambda: int(tensor))
10151 self.assertRaises(ValueError,
lambda: float(tensor))
10152 if sys.version_info[0] < 3:
10153 self.assertRaises(ValueError,
lambda: long(tensor))
10155 def test_offset_scalar_cast(self):
10156 x = torch.Tensor([1, 2, 3])
10158 self.assertEqual(int(y), 3)
10161 @unittest.skipIf(
True,
"flush_denormal not supported")
10162 def test_set_flush_denormal(self):
10164 tiny_double = 1e-320
10165 float_tensor = torch.FloatTensor([1.0, tiny_float])
10166 double_tensor = torch.DoubleTensor([1.0, tiny_float, tiny_double])
10168 self.assertEqual(float_tensor[0], 1.0, prec=0.0)
10169 self.assertEqual(float_tensor[1], tiny_float, prec=tiny_float / 16)
10170 self.assertEqual(double_tensor[0], 1.0, prec=0.0)
10171 self.assertEqual(double_tensor[1], tiny_float, prec=0.0)
10172 self.assertEqual(double_tensor[2], tiny_double, prec=0.0)
10174 torch.set_flush_denormal(
True)
10175 self.assertEqual(float_tensor[0], 1.0, prec=0.0)
10176 self.assertEqual(float_tensor[1], 0.0, prec=0.0)
10177 self.assertEqual(double_tensor[0], 1.0, prec=0.0)
10179 self.assertEqual(double_tensor[1], tiny_float, prec=0.0)
10180 self.assertEqual(double_tensor[2], 0.0, prec=0.0)
10181 torch.set_flush_denormal(
False)
10183 def test_unique(self):
10184 x = torch.LongTensor([1, 2, 3, 2, 8, 5, 2, 3])
10185 expected_unique = torch.LongTensor([1, 2, 3, 5, 8])
10186 expected_inverse = torch.LongTensor([0, 1, 2, 1, 4, 3, 1, 2])
10188 x_unique = torch.unique(x)
10190 expected_unique.tolist(), sorted(x_unique.tolist()))
10192 x_unique, x_inverse = x.unique(return_inverse=
True)
10194 expected_unique.tolist(), sorted(x_unique.tolist()))
10195 self.assertEqual(expected_inverse.numel(), x_inverse.numel())
10197 x_unique = x.unique(sorted=
True)
10198 self.assertEqual(expected_unique, x_unique)
10200 x_unique, x_inverse = torch.unique(
10201 x, sorted=
True, return_inverse=
True)
10202 self.assertEqual(expected_unique, x_unique)
10203 self.assertEqual(expected_inverse, x_inverse)
10206 y = x.view(2, 2, 2)
10207 y_unique, y_inverse = y.unique(sorted=
True, return_inverse=
True)
10208 self.assertEqual(expected_unique, y_unique)
10209 self.assertEqual(expected_inverse.view(y.size()), y_inverse)
10212 int_unique, int_inverse = torch.unique(
10213 torch.IntTensor([2, 1, 2]), sorted=
True, return_inverse=
True)
10214 self.assertEqual(torch.IntTensor([1, 2]), int_unique)
10215 self.assertEqual(torch.LongTensor([1, 0, 1]), int_inverse)
10217 double_unique, double_inverse = torch.unique(
10218 torch.DoubleTensor([2., 1.5, 2.1, 2.]),
10220 return_inverse=
True,
10222 self.assertEqual(torch.DoubleTensor([1.5, 2., 2.1]), double_unique)
10223 self.assertEqual(torch.LongTensor([1, 0, 2, 1]), double_inverse)
10225 byte_unique, byte_inverse = torch.unique(
10226 torch.ByteTensor([133, 7, 7, 7, 42, 128]),
10228 return_inverse=
True,
10230 self.assertEqual(torch.ByteTensor([7, 42, 128, 133]), byte_unique)
10231 self.assertEqual(torch.LongTensor([3, 0, 0, 0, 1, 2]), byte_inverse)
10233 def test_unique_dim(self):
10234 def run_test(dtype=torch.float):
10242 [0., 1.]]], dtype=dtype)
10246 [0., 1.]]], dtype=dtype)
10253 [2., 1.]]], dtype=dtype)
10262 [0., 1.]]], dtype=dtype)
10266 x_unique = torch.unique(x, dim=0)
10267 self.assertEqual(expected_unique_dim0, x_unique)
10269 x_unique, x_inverse = torch.unique(x, return_inverse=
True, dim=0)
10270 self.assertEqual(expected_unique_dim0, x_unique)
10271 self.assertEqual(expected_inverse_dim0, x_inverse)
10274 x_unique = torch.unique(x, dim=1)
10275 self.assertEqual(expected_unique_dim1, x_unique)
10277 x_unique, x_inverse = torch.unique(x, return_inverse=
True, dim=1)
10278 self.assertEqual(expected_unique_dim1, x_unique)
10279 self.assertEqual(expected_inverse_dim1, x_inverse)
10282 x_unique = torch.unique(x, dim=2)
10283 self.assertEqual(expected_unique_dim2, x_unique)
10285 x_unique, x_inverse = torch.unique(x, return_inverse=
True, dim=2)
10286 self.assertEqual(expected_unique_dim2, x_unique)
10287 self.assertEqual(expected_inverse_dim2, x_inverse)
10289 run_test(torch.float)
10290 run_test(torch.double)
10291 run_test(torch.long)
10292 run_test(torch.uint8)
10295 def _test_bincount(self, device):
10297 with self.assertRaisesRegex(RuntimeError,
'1-d non-negative integral'):
10300 with self.assertRaisesRegex(RuntimeError,
'1-d non-negative integral'):
10301 torch.bincount(
torch.tensor([[1, 2], [3, 4]], device=device))
10303 with self.assertRaisesRegex(RuntimeError,
'not implemented'):
10304 torch.bincount(
torch.tensor([1., 0.3], device=device))
10306 with self.assertRaisesRegex(RuntimeError,
'minlength should be >= 0'):
10311 with self.assertRaisesRegex(RuntimeError,
'same length'):
10315 self.assertEqual(torch.bincount(
torch.tensor([], device=device, dtype=torch.long)),
10316 torch.zeros(0, dtype=torch.long, device=device))
10318 self.assertEqual(torch.bincount(
torch.tensor([], device=device, dtype=torch.long), minlength=10),
10319 torch.zeros(10, dtype=torch.long, device=device))
10323 [0, 3, 2, 1, 3], dtype=torch.uint8, device=device).bincount()
10325 torch.tensor([1, 1, 1, 2], dtype=torch.int64, device=device),
10328 int_counts = torch.bincount(
10329 torch.tensor([1, 1, 1, 1], device=device), minlength=5)
10331 torch.tensor([0, 4, 0, 0, 0], dtype=torch.int64, device=device),
10334 byte_counts = torch.bincount(
10338 torch.tensor([0.1, 0.9, 0, 0, 0.5], device=device), byte_counts)
10339 byte_counts = torch.bincount(
10341 torch.tensor([1, 2, 3, 4, 5], dtype=torch.int8, device=device))
10343 torch.tensor([1, 9, 0, 0, 5], device=device), byte_counts)
10345 inputs =
torch.tensor([[0, 0], [3, 1], [2, 1], [1, 1], [3, 4]], device=device)
10346 weights =
torch.tensor([[.1, 1], [.2, 2], [.3, 3], [.4, 4], [.5, 5]], device=device)
10348 assert not inputs[:, i].is_contiguous(),
"Inputs are supposed to be non-contiguous" 10349 assert not weights[:, i].is_contiguous(),
"Weights are supposed to be non-contiguous" 10351 self.assertEqual(inputs[:, 0].bincount(),
torch.tensor([1, 1, 1, 2]))
10353 self.assertEqual(inputs[:, 1].bincount(weights[:, 1]),
torch.tensor([1, 9, 0, 0, 5]))
10355 self.assertEqual(inputs[:, 1].contiguous().bincount(weights[:, 1]),
10359 all0s = torch.zeros((32, 2), dtype=torch.int64, device=device)
10360 self.assertEqual(all0s[:, 0].bincount(),
torch.tensor([32]))
10362 all1s = torch.ones((32, 2), dtype=torch.int64, device=device)
10363 self.assertEqual(all1s[:, 0].bincount(),
torch.tensor([0, 32]))
10366 big_exp = torch.zeros(10000000, device=device)
10369 big_out =
torch.tensor([9999999] * 100, device=device).bincount(big_w)
10370 self.assertEqual(big_exp, big_out)
10372 big_exp = torch.zeros(2, device=device)
10373 big_exp[1] = 1000000
10374 big_out = torch.ones(1000000, dtype=torch.int8, device=device).bincount()
10375 self.assertEqual(big_exp, big_out)
10377 def test_bincount_cpu(self):
10378 self._test_bincount(self, device=
'cpu')
10380 def test_is_nonzero(self):
10381 self.assertExpectedRaises(RuntimeError,
lambda:
torch.tensor([]).is_nonzero(), subname=
"empty")
10382 self.assertExpectedRaises(RuntimeError,
lambda:
torch.tensor([0, 0]).is_nonzero(), subname=
"multiple")
10390 def test_meshgrid(self):
10394 grid_a, grid_b, grid_c = torch.meshgrid([a, b, c])
10395 self.assertEqual(grid_a.shape, torch.Size([1, 3, 2]))
10396 self.assertEqual(grid_b.shape, torch.Size([1, 3, 2]))
10397 self.assertEqual(grid_c.shape, torch.Size([1, 3, 2]))
10398 grid_a2, grid_b2, grid_c2 = torch.meshgrid(a, b, c)
10399 self.assertEqual(grid_a2.shape, torch.Size([1, 3, 2]))
10400 self.assertEqual(grid_b2.shape, torch.Size([1, 3, 2]))
10401 self.assertEqual(grid_c2.shape, torch.Size([1, 3, 2]))
10402 expected_grid_a = torch.ones(1, 3, 2, dtype=torch.int64)
10409 self.assertTrue(grid_a.equal(expected_grid_a))
10410 self.assertTrue(grid_b.equal(expected_grid_b))
10411 self.assertTrue(grid_c.equal(expected_grid_c))
10412 self.assertTrue(grid_a2.equal(expected_grid_a))
10413 self.assertTrue(grid_b2.equal(expected_grid_b))
10414 self.assertTrue(grid_c2.equal(expected_grid_c))
10416 @unittest.skipIf(
torch.cuda.is_available()
or IS_SANDCASTLE,
"CUDA is available, can't test CUDA not built error")
10417 def test_cuda_not_built(self):
10418 msg =
"Torch not compiled with CUDA enabled" 10420 self.assertRaisesRegex(AssertionError, msg,
lambda:
torch.tensor([1], device=
"cuda"))
10421 self.assertRaisesRegex(AssertionError, msg,
lambda:
torch.tensor([1]).cuda())
10422 self.assertRaisesRegex(AssertionError, msg,
lambda: torch.cuda.FloatTensor())
10423 self.assertRaisesRegex(AssertionError, msg,
lambda:
torch.tensor([1]).to(device=
"cuda"))
10425 def test_cast_binary_op(self):
10434 self.assertEqual(a.type(), a_copy.type())
10435 self.assertEqual(a.data.type(), a_copy.data.type())
10436 self.assertEqual(b.type(), b_copy.type())
10437 self.assertEqual(b.data.type(), b_copy.type())
10439 def test_cartesian_prod(self):
10443 prod = torch.cartesian_prod(a, b, c)
10445 self.assertEqual(expected, prod)
10448 d = torch.empty(0, dtype=b.dtype)
10449 prod = torch.cartesian_prod(a, b, c, d)
10450 expected = torch.empty(0, 4, dtype=b.dtype)
10451 self.assertEqual(expected, prod)
10454 prod = torch.cartesian_prod(b)
10455 self.assertEqual(b, prod)
10457 def test_combinations(self):
10460 c = torch.combinations(a, r=1)
10462 self.assertEqual(c, expected)
10464 c = torch.combinations(a, r=1, with_replacement=
True)
10465 expected =
torch.tensor(list(combinations_with_replacement(a, r=1)))
10466 self.assertEqual(c, expected)
10468 c = torch.combinations(a)
10470 self.assertEqual(c, expected)
10472 c = torch.combinations(a, with_replacement=
True)
10473 expected =
torch.tensor(list(combinations_with_replacement(a, r=2)))
10474 self.assertEqual(c, expected)
10476 c = torch.combinations(a, r=3)
10478 self.assertEqual(c, expected)
10480 c = torch.combinations(a, r=4)
10481 expected = torch.empty(0, 4, dtype=a.dtype)
10482 self.assertEqual(c, expected)
10484 c = torch.combinations(a, r=5)
10485 expected = torch.empty(0, 5, dtype=a.dtype)
10486 self.assertEqual(c, expected)
10490 c1 = torch.combinations(a)
10491 c2 = torch.combinations(a, with_replacement=
True)
10492 expected = torch.empty(0, 2, dtype=a.dtype)
10493 self.assertEqual(c1, expected)
10494 self.assertEqual(c2, expected)
10496 def test_has_internal_overlap(self):
10499 OVERLAP_TOO_HARD = 2
10502 a = torch.randn(3, 3)
10503 self.assertEqual(torch._debug_has_internal_overlap(a), OVERLAP_NO)
10506 b = torch.randn(1, 3)
10507 b_expanded = b.expand(4, 3)
10508 self.assertEqual(torch._debug_has_internal_overlap(b_expanded), OVERLAP_YES)
10511 def unary_check_mem_overlap(self, inplace_op, value=-0.5, device='cpu'):
10512 tensor =
torch.tensor(value, device=device).expand(3, 3)
10513 with self.assertRaisesRegex(RuntimeError,
'single memory location'):
10517 def _test_inplace_unary_mem_overlap(self, device='cpu'):
10518 TestTorch.unary_check_mem_overlap(self,
lambda t: t.acos_(), device=device)
10519 TestTorch.unary_check_mem_overlap(self,
lambda t: t.asin_(), device=device)
10520 TestTorch.unary_check_mem_overlap(self,
lambda t: t.atan_(), device=device)
10521 TestTorch.unary_check_mem_overlap(self,
lambda t: t.ceil_(), device=device)
10522 TestTorch.unary_check_mem_overlap(self,
lambda t: t.cos_(), device=device)
10523 TestTorch.unary_check_mem_overlap(self,
lambda t: t.erf_(), device=device)
10524 TestTorch.unary_check_mem_overlap(self,
lambda t: t.erfc_(), device=device)
10525 TestTorch.unary_check_mem_overlap(self,
lambda t: t.exp_(), device=device)
10526 TestTorch.unary_check_mem_overlap(self,
lambda t: t.expm1_(), device=device)
10527 TestTorch.unary_check_mem_overlap(self,
lambda t: t.floor_(), device=device)
10528 TestTorch.unary_check_mem_overlap(self,
lambda t: t.log_(), device=device)
10529 TestTorch.unary_check_mem_overlap(self,
lambda t: t.log10_(), device=device)
10530 TestTorch.unary_check_mem_overlap(self,
lambda t: t.log1p_(), device=device)
10531 TestTorch.unary_check_mem_overlap(self,
lambda t: t.log2_(), device=device)
10532 TestTorch.unary_check_mem_overlap(self,
lambda t: t.round_(), device=device)
10533 TestTorch.unary_check_mem_overlap(self,
lambda t: t.rsqrt_(), device=device)
10534 TestTorch.unary_check_mem_overlap(self,
lambda t: t.sin_(), device=device)
10535 TestTorch.unary_check_mem_overlap(self,
lambda t: t.sqrt_(), device=device)
10536 TestTorch.unary_check_mem_overlap(self,
lambda t: t.tan_(), device=device)
10537 TestTorch.unary_check_mem_overlap(self,
lambda t: t.tanh_(), device=device)
10538 TestTorch.unary_check_mem_overlap(self,
lambda t: t.trunc_(), device=device)
10540 def test_inplace_unary_mem_overlap(self):
10541 return self._test_inplace_unary_mem_overlap(self)
10543 @unittest.expectedFailure
10544 def test_abs_unary_mem_overlap(self):
10545 self.unary_check_mem_overlap(
lambda t: t.abs_())
10547 @unittest.expectedFailure
10548 def test_sinh_unary_mem_overlap(self):
10549 self.unary_check_mem_overlap(
lambda t: t.sinh_())
10551 @unittest.expectedFailure
10552 def test_cosh_unary_mem_overlap(self):
10553 self.unary_check_mem_overlap(
lambda t: t.cosh_())
10556 def test_reverse_binary_ops_multiple_device(self):
10563 with self.assertRaisesRegex(RuntimeError,
"expected both inputs to be on same device"):
10565 with self.assertRaisesRegex(RuntimeError,
"expected both inputs to be on same device"):
10567 with self.assertRaisesRegex(RuntimeError,
"expected both inputs to be on same device"):
10569 with self.assertRaisesRegex(RuntimeError,
"expected both inputs to be on same device"):
10571 with self.assertRaisesRegex(RuntimeError,
"expected both inputs to be on same device"):
10574 def test_allow_tensor_metadata_change(self):
10576 with self.assertRaisesRegex(
10578 "set_sizes_contiguous is not allowed on Tensor created from .data or .detach()"):
10580 with self.assertRaisesRegex(
10582 "set_storage is not allowed on Tensor created from .data or .detach()"):
10584 with self.assertRaisesRegex(
10586 "set_storage_offset is not allowed on Tensor created from .data or .detach()"):
10587 t.set_(t.storage(), 0, t.size(), list(t.stride()))
10592 def test_c10_layer_norm(self):
10594 X = torch.rand(5, 5, dtype=torch.float)
10598 actual_norm, actual_mean, actual_stdev = \
10599 torch.ops._caffe2.LayerNorm(
torch.tensor(X), 1, epsilon)
10609 def make_neg_dim_test(name, tensor_arg, arg_constr, types, extra_dim=0):
10610 def neg_dim_test(self):
10611 if isinstance(tensor_arg, list):
10612 assert METHOD
not in types
and INPLACE_METHOD
not in types
10613 x = [torch.randn(arg)
for arg
in tensor_arg]
10614 ndim = len(tensor_arg[-1])
10616 x = torch.randn(*tensor_arg)
10617 ndim = len(tensor_arg)
10620 n_dim_to_test = sum(map(
lambda e: e
is DIM_ARG, arg_constr()))
10622 for dims_val
in combinations(range(ndim), n_dim_to_test):
10624 arg_neg = copy.deepcopy(arg)
10626 for i, v
in enumerate(arg):
10628 arg[i] = dims_val[idx]
10629 arg_neg[i] = dims_val[idx] - ndim
10632 if METHOD
in types:
10633 a = getattr(x, name)(*arg)
10634 b = getattr(x, name)(*arg_neg)
10635 self.assertEqual(a, b)
10637 if INPLACE_METHOD
in types:
10639 getattr(a, name +
'_')(*arg)
10641 getattr(b, name +
'_')(*arg_neg)
10642 self.assertEqual(a, b)
10644 if FUNCTIONAL
in types:
10645 a = getattr(torch, name)(x, *arg)
10646 b = getattr(torch, name)(x, *arg_neg)
10647 self.assertEqual(a, b)
10649 return neg_dim_test
10652 def idx_tensor(size, max_val):
10653 return torch.LongTensor(*size).random_(0, max_val - 1)
10656 def add_neg_dim_tests():
10658 (
'narrow', (10, 20, 30),
lambda: [DIM_ARG, 0, 5], [METHOD]),
10659 (
'transpose', (10, 20, 30),
lambda: [DIM_ARG, DIM_ARG], [METHOD, INPLACE_METHOD, FUNCTIONAL]),
10660 (
'size', (10, 20, 30),
lambda: [DIM_ARG], [METHOD]),
10661 (
'cat', [(2, 3, 4), (2, 3, 4)],
lambda: [DIM_ARG], [FUNCTIONAL]),
10662 (
'chunk', (10, 20, 30),
lambda: [5, DIM_ARG], [METHOD, FUNCTIONAL]),
10663 (
'gather', (10, 20),
lambda: [DIM_ARG, idx_tensor((10, 20), 10)], [METHOD, FUNCTIONAL]),
10664 (
'index_select', (10, 10),
lambda: [DIM_ARG, idx_tensor((10,), 10)], [METHOD, FUNCTIONAL]),
10665 (
'split', (10, 20),
lambda: [5, DIM_ARG], [METHOD, FUNCTIONAL]),
10666 (
'squeeze', (10, 1, 20, 1),
lambda: [DIM_ARG], [METHOD, INPLACE_METHOD, FUNCTIONAL]),
10667 (
'unbind', (2, 3, 4),
lambda: [DIM_ARG], [FUNCTIONAL]),
10668 (
'unsqueeze', (10, 20),
lambda: [DIM_ARG], [METHOD, INPLACE_METHOD, FUNCTIONAL], 1),
10669 (
'cumprod', (10, 20),
lambda: [DIM_ARG], [METHOD, FUNCTIONAL]),
10670 (
'cumsum', (10, 20),
lambda: [DIM_ARG], [METHOD, FUNCTIONAL]),
10671 (
'mean', (10, 20),
lambda: [DIM_ARG], [METHOD, FUNCTIONAL]),
10672 (
'median', (10, 20),
lambda: [DIM_ARG], [METHOD, FUNCTIONAL]),
10673 (
'mode', (10, 20),
lambda: [DIM_ARG], [METHOD, FUNCTIONAL]),
10674 (
'norm', (10, 20),
lambda: [2, DIM_ARG], [METHOD, FUNCTIONAL]),
10675 (
'prod', (10, 20),
lambda: [DIM_ARG], [METHOD, FUNCTIONAL]),
10676 (
'std', (10, 20),
lambda: [DIM_ARG], [METHOD, FUNCTIONAL]),
10677 (
'sum', (10, 20),
lambda: [DIM_ARG], [METHOD, FUNCTIONAL]),
10678 (
'var', (10, 20),
lambda: [DIM_ARG], [METHOD, FUNCTIONAL]),
10679 (
'kthvalue', (10, 20),
lambda: [3, DIM_ARG], [METHOD, FUNCTIONAL]),
10680 (
'max', (10, 20),
lambda: [DIM_ARG], [METHOD, FUNCTIONAL]),
10681 (
'min', (10, 20),
lambda: [DIM_ARG], [METHOD, FUNCTIONAL]),
10682 (
'sort', (10, 20),
lambda: [DIM_ARG], [METHOD, FUNCTIONAL]),
10683 (
'topk', (10, 20),
lambda: [5, DIM_ARG], [METHOD, FUNCTIONAL]),
10684 (
'renorm', (10, 20),
lambda: [2, DIM_ARG, 1], [METHOD, INPLACE_METHOD, FUNCTIONAL]),
10685 (
'index_add', (10, 10),
lambda: [DIM_ARG, idx_tensor((10,), 10), torch.randn(10, 10)], [INPLACE_METHOD]),
10686 (
'index_copy', (10, 10),
lambda: [DIM_ARG, idx_tensor((10,), 10), torch.randn(10, 10)], [INPLACE_METHOD]),
10687 (
'index_fill', (10, 10),
lambda: [DIM_ARG, idx_tensor((10,), 10), 12], [INPLACE_METHOD]),
10688 (
'scatter', (10, 10),
lambda: [DIM_ARG, idx_tensor((10, 10), 10), torch.randn(10, 10)], [INPLACE_METHOD]),
10689 (
'select', (10, 20),
lambda: [DIM_ARG, 3], [METHOD]),
10690 (
'unfold', (10, 20),
lambda: [DIM_ARG, 5, 2], [METHOD]),
10693 for decl
in neg_dim_tests:
10695 name, tensor_arg, arg_constr, types = decl
10697 elif len(decl) == 5:
10698 name, tensor_arg, arg_constr, types, extra_dim = decl
10700 test_name =
'test_' + name +
'_neg_dim' 10702 assert not hasattr(_TestTorchMixin, test_name),
"Duplicated test name: " + test_name
10703 setattr(_TestTorchMixin, test_name, make_neg_dim_test(name, tensor_arg, arg_constr, types, extra_dim))
10705 add_neg_dim_tests()
10711 if __name__ ==
'__main__':
def assert_allclose(actual, expected, rtol=None, atol=None, equal_nan=True)
def softmax(input, dim=None, _stacklevel=3, dtype=None)
def log_softmax(input, dim=None, _stacklevel=3, dtype=None)
def layer_norm(input, normalized_shape, weight=None, bias=None, eps=1e-5)
def get_summarized_data(self)
Module caffe2.python.helpers.conv.
def set_default_tensor_type(t)
def readinto_opt(self, view)