Caffe2 - Python API
A deep learning, cross-platform ML framework
test_torch.py
1 import sys
2 import io
3 import os
4 import math
5 import random
6 import operator
7 import copy
8 import shutil
9 import torch
10 import torch.cuda
11 import tempfile
12 import unittest
13 import warnings
14 import pickle
15 import gzip
16 import types
17 import textwrap
18 import re
19 from torch._utils_internal import get_file_path, get_file_path_2
20 from torch.utils.dlpack import from_dlpack, to_dlpack
21 from torch._utils import _rebuild_tensor
22 from torch._six import inf, nan, string_classes, istuple
23 from itertools import product, combinations, combinations_with_replacement
24 from functools import reduce
25 from torch import multiprocessing as mp
26 from common_methods_invocations import tri_tests_args, run_additional_tri_tests, \
27  _compare_trilu_indices
28 from common_utils import TestCase, iter_indices, TEST_NUMPY, TEST_SCIPY, TEST_MKL, \
29  TEST_LIBROSA, run_tests, download_file, skipIfNoLapack, suppress_warnings, \
30  IS_WINDOWS, PY3, NO_MULTIPROCESSING_SPAWN, skipIfRocm, do_test_dtypes, do_test_empty_full, \
31  IS_SANDCASTLE, load_tests, brute_pdist, brute_cdist
32 from multiprocessing.reduction import ForkingPickler
33 
34 # load_tests from common_utils is used to automatically filter tests for
35 # sharding on sandcastle. This line silences flake warnings
36 load_tests = load_tests
37 
38 if TEST_NUMPY:
39  import numpy as np
40 
41 if TEST_SCIPY:
42  from scipy import signal
43 
44 if TEST_LIBROSA:
45  import librosa
46 
47 SIZE = 100
48 
49 can_retrieve_source = True
50 with warnings.catch_warnings(record=True) as warns:
51  with tempfile.NamedTemporaryFile() as checkpoint:
52  x = torch.save(torch.nn.Module(), checkpoint)
53  for warn in warns:
54  if "Couldn't retrieve source code" in warn.message.args[0]:
55  can_retrieve_source = False
56  break
57 
58 
59 class FilelikeMock(object):
60  def __init__(self, data, has_fileno=True, has_readinto=False):
61  if has_readinto:
62  setattr(self, 'readinto', self.readinto_opt)
63  if has_fileno:
64  # Python 2's StringIO.StringIO has no fileno attribute.
65  # This is used to test that.
66  setattr(self, 'fileno', self.fileno_opt)
67 
68  self.calls = set()
69  self.bytesio = io.BytesIO(data)
70 
71  def trace(fn, name):
72  def result(*args, **kwargs):
73  self.calls.add(name)
74  return fn(*args, **kwargs)
75  return result
76 
77  for attr in ['read', 'readline', 'seek', 'tell', 'write', 'flush']:
78  traced_fn = trace(getattr(self.bytesio, attr), attr)
79  setattr(self, attr, traced_fn)
80 
81  def fileno_opt(self):
82  raise io.UnsupportedOperation('Not a real file')
83 
84  def readinto_opt(self, view):
85  self.calls.add('readinto')
86  return self.bytesio.readinto(view)
87 
88  def was_called(self, name):
89  return name in self.calls
90 
91 
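# A minimal usage sketch of FilelikeMock (assuming torch.save/torch.load
# accept any object exposing the usual file methods, which is exactly what
# this mock is built to verify by recording the methods that get called):
#
#     f = FilelikeMock(b'', has_readinto=True)
#     torch.save(torch.randn(3), f)
#     assert f.was_called('write')
#     f.seek(0)
#     assert torch.load(f).size() == (3,)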
92 class BytesIOContext(io.BytesIO):
93  def __enter__(self):
94  return self
95 
96  def __exit__(self, *args):
97  pass
98 
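# BytesIOContext overrides __exit__ so the buffer is NOT closed when the
# with-block ends, letting an in-memory buffer stand in for
# "with open(path) as f" and still be read back afterwards; a minimal sketch:
#
#     with BytesIOContext() as f:
#         torch.save(torch.ones(2), f)
#     f.seek(0)          # buffer is still usable after the with-block
#     torch.load(f)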
99 
100 # This is intentionally prefixed by an underscore. Otherwise pytest will try to
101 # run its methods as test cases.
102 class _TestTorchMixin(object):
103  def _check_sum_dim(self, tensors, dim):
104  for tensor in tensors:
105  expected = tensor.numpy().sum(dim)
106  actual = tensor.sum(dim)
107  self.assertEqual(expected.shape, actual.shape)
108  if actual.dtype == torch.float:
109  self.assertTrue(np.allclose(expected, actual.numpy(), rtol=1e-03, atol=1e-05))
110  else:
111  self.assertTrue(np.allclose(expected, actual.numpy()))
112 
113  def _make_tensors(self, shape, val_range=(-100, 100), use_floating=True, use_integral=True):
114  float_types = [torch.double,
115  torch.float]
116  int_types = [torch.int64,
117  torch.int32,
118  torch.int16]
119 
120  def make_contiguous(shape, dtype):
121  if dtype in float_types:
122  val = torch.randn(shape, dtype=dtype)
123  val = val * ((val_range[1] - val_range[0]) / (math.pi * 2.0))
124  val = val + ((val_range[1] - val_range[0]) / 2.0)
125  val = torch.clamp(val, min=val_range[0], max=val_range[1])
126  return val
127  result = torch.zeros(shape, dtype=dtype)
128  result.apply_(lambda x: random.randint(val_range[0], val_range[1]))
129  return result
130 
131  def make_non_contiguous(shape, dtype):
132  contig = make_contiguous(shape, dtype)
133  non_contig = torch.empty(shape + (2, 2), dtype=dtype)[..., 0]
134  non_contig = non_contig.select(-1, -1)
135  non_contig.copy_(contig)
136  self.assertFalse(non_contig.is_contiguous())
137  return non_contig
138 
139  def make_contiguous_slice(size, dtype):
140  contig = make_contiguous((1, size), dtype)
141  non_contig = contig[:1, 1:size - 1]
142  self.assertTrue(non_contig.is_contiguous())
143  return contig
144 
145  types = []
146  if use_floating:
147  types += float_types
148  if use_integral:
149  types += int_types
150  tensors = {"cont": [], "noncont": [], "slice": []}
151  for dtype in types:
152  tensors["cont"].append(make_contiguous(shape, dtype))
153  tensors["noncont"].append(make_non_contiguous(shape, dtype))
154  tensors["slice"].append(make_contiguous_slice(sum(list(shape)), dtype))
155 
156  return tensors
157 
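# _make_tensors returns one tensor per requested dtype under each layout key;
# roughly (sketch):
#
#     tensors = self._make_tensors((5, 7))
#     tensors["cont"]     # contiguous (5, 7) tensors
#     tensors["noncont"]  # same shape seen through a strided, non-contiguous view
#     tensors["slice"]    # contiguous (1, sum(shape)) tensors, here (1, 12)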
158  def test_dir(self):
159  dir(torch)
160 
161  def test_doc(self):
162  checked_types = (types.MethodType, types.FunctionType,
163  types.BuiltinFunctionType, types.BuiltinMethodType)
164 
165  def test_namespace(ns, *skips):
166  if isinstance(ns, object):
167  ns_name = ns.__class__.__name__
168  else:
169  ns_name = ns.__name__
170  skip_regexes = []
171  for r in skips:
172  if isinstance(r, string_classes):
173  skip_regexes.append(re.compile('^{}$'.format(re.escape(r))))
174  else:
175  skip_regexes.append(r)
176  for name in dir(ns):
177  if name.startswith('_'):
178  continue
179  var = getattr(ns, name)
180  if not isinstance(var, checked_types):
181  continue
182  doc = var.__doc__
183  has_doc = doc is not None and len(doc.strip()) > 0
184  full_name = ns_name + '.' + name
185  if any(r.match(name) for r in skip_regexes):
186  self.assertFalse(has_doc,
187  'New docs have been added for {}, please remove '
188  'it from the skipped list in TestTorch.test_doc'.format(full_name))
189  else:
190  self.assertTrue(has_doc, '{} is missing documentation'.format(full_name))
191 
192  # FIXME: fix all the skipped ones below!
193  test_namespace(torch.randn(1),
194  'as_strided',
195  'as_strided_',
196  re.compile('^clamp_(min|max)_?$'),
197  'coalesce',
198  'index_put',
199  'is_coalesced',
200  'is_distributed',
201  'is_complex',
202  'is_nonzero',
203  'is_same_size',
204  'is_signed',
205  'isclose',
206  'lgamma',
207  'lgamma_',
208  'log_softmax',
209  'map2_',
210  'new',
211  'pin_memory',
212  'polygamma',
213  'polygamma_',
214  'record_stream',
215  'reinforce',
216  'relu',
217  'relu_',
218  'prelu',
219  'resize',
220  'resize_as',
221  'smm',
222  'softmax',
223  'split_with_sizes',
224  'sspaddmm',
225  'storage_type',
226  'tan',
227  'to_dense',
228  'sparse_resize_',
229  'sparse_resize_and_clear_',
230  )
231  test_namespace(torch.nn)
232  test_namespace(torch.nn.functional, 'assert_int_or_pair', 'bilinear', 'feature_alpha_dropout')
233  # TODO: add torch.* tests when we have proper namespacing on ATen functions
234  # test_namespace(torch)
235 
236  def test_dot(self):
237  types = {
238  'torch.DoubleTensor': 1e-8,
239  'torch.FloatTensor': 1e-4,
240  }
241  for tname, _prec in types.items():
242  v1 = torch.randn(100).type(tname)
243  v2 = torch.randn(100).type(tname)
244  res1 = torch.dot(v1, v2)
245  res2 = 0
246  for i, j in zip(v1, v2):
247  res2 += i * j
248  self.assertEqual(res1, res2)
249  out = torch.randn(()).type(tname)
250  torch.dot(v1, v2, out=out)
251  self.assertEqual(res1, out)
252 
253  # Test 0-strided
254  for tname, _prec in types.items():
255  v1 = torch.randn(1).type(tname).expand(100)
256  v2 = torch.randn(100).type(tname)
257  res1 = torch.dot(v1, v2)
258  res2 = 0
259  for i, j in zip(v1, v2):
260  res2 += i * j
261  self.assertEqual(res1, res2)
262  out = torch.randn(()).type(tname)
263  torch.dot(v1, v2, out=out)
264  self.assertEqual(res1, out)
265 
266  def test_ger(self):
267  types = {
268  'torch.DoubleTensor': 1e-8,
269  'torch.FloatTensor': 1e-4,
270  }
271  for tname, _prec in types.items():
272  v1 = torch.randn(100).type(tname)
273  v2 = torch.randn(100).type(tname)
274  res1 = torch.ger(v1, v2)
275  res2 = torch.zeros(100, 100).type(tname)
276  for i in range(100):
277  for j in range(100):
278  res2[i, j] = v1[i] * v2[j]
279  self.assertEqual(res1, res2)
280 
281  # Test 0-strided
282  for tname, _prec in types.items():
283  v1 = torch.randn(1).type(tname).expand(100)
284  v2 = torch.randn(100).type(tname)
285  res1 = torch.ger(v1, v2)
286  res2 = torch.zeros(100, 100).type(tname)
287  for i in range(100):
288  for j in range(100):
289  res2[i, j] = v1[i] * v2[j]
290  self.assertEqual(res1, res2)
291 
292  def test_addr(self):
293  types = {
294  'torch.DoubleTensor': 1e-8,
295  'torch.FloatTensor': 1e-4,
296  }
297 
298  def run_test(m, v1, v2, m_transform=lambda x: x):
299  m = m_transform(m.clone())
300  ref = m.clone()
301  torch.addr(m, v1, v2, out=m)
302  for i in range(m.size(0)):
303  for j in range(m.size(1)):
304  ref[i, j] += v1[i] * v2[j]
305  self.assertEqual(m, ref)
306 
307  for tname, _prec in types.items():
308  for h, w in [(100, 110), (1, 20), (200, 2)]:
309  m = torch.randn(h, w).type(tname)
310  v1 = torch.randn(h).type(tname)
311  v2 = torch.randn(w).type(tname)
312  run_test(m, v1, v2)
313  # test transpose
314  run_test(m, v2, v1, lambda x: x.transpose(0, 1))
315  # test 0 strided
316  v1 = torch.randn(1).type(tname).expand(h)
317  run_test(m, v1, v2)
318  run_test(m, v2, v1, lambda x: x.transpose(0, 1))
319 
320  def test_addmv(self):
321  types = {
322  'torch.DoubleTensor': 1e-8,
323  'torch.FloatTensor': 1e-4,
324  }
325  for tname, _prec in types.items():
326  t = torch.randn(10).type(tname)
327  m = torch.randn(10, 100).type(tname)
328  v = torch.randn(100).type(tname)
329  res1 = torch.addmv(t, m, v)
330  res2 = torch.zeros(10).type(tname)
331  res2 += t
332  for i in range(10):
333  for j in range(100):
334  res2[i] += m[i, j] * v[j]
335  self.assertEqual(res1, res2)
336 
337  # Test 0-strided
338  for tname, _prec in types.items():
339  t = torch.randn(1).type(tname).expand(10)
340  m = torch.randn(10, 1).type(tname).expand(10, 100)
341  v = torch.randn(100).type(tname)
342  res1 = torch.addmv(t, m, v)
343  res2 = torch.zeros(10).type(tname)
344  res2 += t
345  for i in range(10):
346  for j in range(100):
347  res2[i] += m[i, j] * v[j]
348  self.assertEqual(res1, res2)
349 
350  def test_addmm(self):
351  types = {
352  'torch.DoubleTensor': 1e-8,
353  'torch.FloatTensor': 1e-4,
354  }
355  for tname, _prec in types.items():
356  M = torch.randn(10, 25).type(tname)
357  m1 = torch.randn(10, 50).type(tname)
358  m2 = torch.randn(50, 25).type(tname)
359  res1 = torch.addmm(M, m1, m2)
360  res2 = torch.zeros(10, 25).type(tname)
361  res2 += M
362  for i in range(10):
363  for j in range(25):
364  for k in range(50):
365  res2[i, j] += m1[i, k] * m2[k, j]
366  self.assertEqual(res1, res2)
367 
368  # Test 0-strided
369  for tname, _prec in types.items():
370  M = torch.randn(10, 1).type(tname).expand(10, 25)
371  m1 = torch.randn(10, 1).type(tname).expand(10, 50)
372  m2 = torch.randn(50, 25).type(tname)
373  res1 = torch.addmm(M, m1, m2)
374  res2 = torch.zeros(10, 25).type(tname)
375  res2 += M
376  for i in range(10):
377  for j in range(25):
378  for k in range(50):
379  res2[i, j] += m1[i, k] * m2[k, j]
380  self.assertEqual(res1, res2)
381 
382  def test_logical_any(self):
383  devices = ['cpu'] if not torch.cuda.is_available() else ['cpu', 'cuda']
384  for device in devices:
385  x = torch.zeros([2, 3, 400], dtype=torch.uint8, device=device)
386 
387  self.assertEqual(
388  torch.tensor(0, dtype=torch.uint8, device=device),
389  x.any())
390 
391  self.assertEqual(
392  torch.zeros([1, 3, 400], dtype=torch.uint8, device=device),
393  x.any(0, keepdim=True))
394 
395  self.assertEqual(
396  torch.zeros([2, 1, 400], dtype=torch.uint8, device=device),
397  x.any(1, keepdim=True))
398 
399  self.assertEqual(
400  torch.zeros([2, 3, 1], dtype=torch.uint8, device=device),
401  x.any(2, keepdim=True))
402 
403  # set the last element to 1
404  x[-1][-1][-1] = 1
405 
406  self.assertEqual(
407  torch.tensor(1, dtype=torch.uint8, device=device),
408  x.any())
409 
410  y = torch.zeros([1, 3, 400], dtype=torch.uint8, device=device)
411  y[-1][-1][-1] = 1
412  self.assertEqual(y, x.any(0, keepdim=True))
413 
414  y = torch.zeros([2, 1, 400], dtype=torch.uint8, device=device)
415  y[-1][-1][-1] = 1
416  self.assertEqual(y, x.any(1, keepdim=True))
417 
418  y = torch.zeros([2, 3, 1], dtype=torch.uint8, device=device)
419  y[-1][-1][-1] = 1
420  self.assertEqual(y, x.any(2, keepdim=True))
421 
422  def test_logical_all(self):
423  devices = ['cpu'] if not torch.cuda.is_available() else ['cpu', 'cuda']
424  for device in devices:
425  x = torch.ones([2, 3, 400], dtype=torch.uint8, device=device)
426 
427  self.assertEqual(
428  torch.tensor(1, dtype=torch.uint8, device=device),
429  x.all())
430 
431  self.assertEqual(
432  torch.ones([1, 3, 400], dtype=torch.uint8, device=device),
433  x.all(0, keepdim=True))
434 
435  self.assertEqual(
436  torch.ones([2, 1, 400], dtype=torch.uint8, device=device),
437  x.all(1, keepdim=True))
438 
439  self.assertEqual(
440  torch.ones([2, 3, 1], dtype=torch.uint8, device=device),
441  x.all(2, keepdim=True))
442 
443  # set the last element to 0
444  x[-1][-1][-1] = 0
445 
446  self.assertEqual(
447  torch.tensor(0, dtype=torch.uint8, device=device),
448  x.all())
449 
450  y = torch.ones([1, 3, 400], dtype=torch.uint8, device=device)
451  y[-1][-1][-1] = 0
452  self.assertEqual(y, x.all(0, keepdim=True))
453 
454  y = torch.ones([2, 1, 400], dtype=torch.uint8, device=device)
455  y[-1][-1][-1] = 0
456  self.assertEqual(y, x.all(1, keepdim=True))
457 
458  y = torch.ones([2, 3, 1], dtype=torch.uint8, device=device)
459  y[-1][-1][-1] = 0
460  self.assertEqual(y, x.all(2, keepdim=True))
461 
462  def test_allclose(self):
463  x = torch.tensor([1.0, 2.0, 3.0])
464  y = torch.tensor([1.01, 2.01, 3.01])
465  self.assertTrue(torch.allclose(x, y, rtol=0, atol=0.02))
466  self.assertTrue(torch.allclose(x, y, rtol=0.01, atol=0.0))
467  self.assertFalse(torch.allclose(x, y))
468  self.assertTrue(torch.allclose(torch.tensor([0.0]), torch.tensor([1e-8])))
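# torch.allclose checks |x - y| <= atol + rtol * |y| elementwise; with the
# defaults (rtol=1e-05, atol=1e-08) the line above holds because
# |0.0 - 1e-8| = 1e-8 <= 1e-8 + 1e-5 * 1e-8.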
469  x = torch.tensor([2.0, 3.0, nan])
470  y = torch.tensor([2.01, 3.01, nan])
471  self.assertFalse(torch.allclose(x, y, rtol=1e-2))
472  self.assertTrue(torch.allclose(x, y, rtol=1e-2, equal_nan=True))
473  self.assertFalse(torch.allclose(x, y, rtol=1e-3, equal_nan=True))
474  inf_t = torch.tensor([inf])
475  self.assertTrue(torch.allclose(inf_t, inf_t))
476  self.assertTrue(torch.allclose(-inf_t, -inf_t))
477  self.assertFalse(torch.allclose(inf_t, -inf_t))
478  self.assertFalse(torch.allclose(inf_t, torch.tensor([1e20])))
479  self.assertFalse(torch.allclose(-inf_t, torch.tensor([-1e20])))
480 
481  def test_linear_algebra_scalar_raises(self):
482  m = torch.randn(5, 5)
483  v = torch.randn(5)
484  s = torch.tensor(7)
485  self.assertRaises(RuntimeError, lambda: torch.mv(m, s))
486  self.assertRaises(RuntimeError, lambda: torch.addmv(v, m, s))
487  self.assertRaises(RuntimeError, lambda: torch.ger(v, s))
488  self.assertRaises(RuntimeError, lambda: torch.ger(s, v))
489  self.assertRaises(RuntimeError, lambda: torch.addr(m, v, s))
490  self.assertRaises(RuntimeError, lambda: torch.addr(m, s, v))
491 
492  def _test_math(self, torchfn, mathfn, input=None, test_expand=False):
493  if input is None:
494  input = []
495  input.append(list(range(-5, 5)))
496  input.append([0 for x in range(-5, 5)])
497  input.append([x + 1e-6 for x in range(-5, 5)])
498  # Some vectorized implementations don't support large ranges
499  input.append([x + 1e10 for x in range(-5, 5)])
500  input.append([x - 1e10 for x in range(-5, 5)])
501  input.append(torch.randn(10).tolist())
502  input.append((torch.randn(10) + 1e6).tolist())
503  input.append([math.pi * (x / 2) for x in range(-5, 5)])
504 
505  def compare_reference(input, dtype):
506  input = torch.tensor(input, dtype=dtype)
507  res1 = torchfn(input.clone())
508  res2 = input.clone().apply_(mathfn)
509  self.assertEqual(res1, res2)
510
511  # compare against the reference math function
512  compare_reference(input, torch.double)
513  compare_reference(input, torch.float)
514 
515  def check_non_contiguous(shape, dtype):
516  contig = torch.randn(shape, dtype=dtype)
517  non_contig = torch.empty(shape + (2,), dtype=dtype)[..., 0]
518  non_contig.copy_(contig)
519  self.assertFalse(non_contig.is_contiguous())
520  self.assertEqual(torchfn(contig), torchfn(non_contig), 'non-contiguous')
521 
522  # compare application against contiguous vs. non-contiguous
523  check_non_contiguous((5, 7), torch.double)
524  check_non_contiguous((1024,), torch.double)
525  check_non_contiguous((5, 7), torch.float)
526  check_non_contiguous((1024,), torch.float)
527 
528  def check_non_contiguous_index(dtype):
529  contig = torch.randn((2, 2, 1, 2), dtype=dtype)
530  non_contig = contig[:, 1, ...]
531  contig = non_contig.clone()
532  self.assertFalse(non_contig.is_contiguous())
533  self.assertEqual(torchfn(contig), torchfn(non_contig), 'non-contiguous index')
534 
535  check_non_contiguous_index(torch.float)
536  check_non_contiguous_index(torch.double)
537 
538  def check_non_contiguous_expand(shape, dtype):
539  contig = torch.randn(shape, dtype=dtype)
540  non_contig = contig.clone().expand(3, -1, -1)
541  self.assertFalse(non_contig.is_contiguous())
542  contig = torchfn(contig)
543  non_contig = torchfn(non_contig)
544  for i in range(3):
545  self.assertEqual(contig, non_contig[i], 'non-contiguous expand[' + str(i) + ']')
546 
547  # Expand is not defined for in-place operations
548  if test_expand:
549  # The size 1 case is special as it leads to 0 stride and needs to persist
550  check_non_contiguous_expand((1, 3), torch.double)
551  check_non_contiguous_expand((1, 7), torch.double)
552  check_non_contiguous_expand((5, 7), torch.float)
553 
554  # If size(dim) == 1, stride(dim) is not defined.
555  # The code needs to be able to handle this
556  def check_contiguous_size1(dtype):
557  contig = torch.randn((5, 100), dtype=dtype)
558  contig = contig[:1, :50]
559  contig2 = torch.empty(contig.size(), dtype=dtype)
560  contig2.copy_(contig)
561  self.assertTrue(contig.is_contiguous())
562  self.assertTrue(contig2.is_contiguous())
563  self.assertEqual(torchfn(contig), torchfn(contig2), 'contiguous size1')
564 
565  check_contiguous_size1(torch.double)
566  check_contiguous_size1(torch.float)
567 
568  def check_contiguous_size1_largedim(dtype):
569  contig = torch.randn((5, 2, 3, 1, 4, 5, 3, 2, 1, 2, 3, 4), dtype=dtype)
570  contig = contig[:1, :, :, :, :, :, :, :, :, :, :, :]
571  contig2 = torch.empty(contig.size(), dtype=dtype)
572  contig2.copy_(contig)
573  self.assertTrue(contig.is_contiguous())
574  self.assertTrue(contig2.is_contiguous())
575  self.assertEqual(torchfn(contig), torchfn(contig2), 'contiguous size1')
576 
577  check_contiguous_size1_largedim(torch.double)
578  check_contiguous_size1_largedim(torch.float)
579 
580  def check_large(dtype):
581  input = torch.randn(1024, 512, dtype=dtype)
582  actual = torchfn(input)
583  expected = torch.stack([torchfn(slice) for slice in input])
584  self.assertEqual(actual, expected, 'large')
585 
586  # compare large tensor vs. repeated small applications to expose
587  # possible parallelism bugs.
588  check_large(torch.double)
589  check_large(torch.float)
590 
591  def __test_math_by_name(self, function_name, mathfn, selffn):
592  mathfn = getattr(math, mathfn)
593  if selffn:
594  def torchfn(x):
595  return getattr(x, function_name)()
596  else:
597  torchfn = getattr(torch, function_name)
598  self._test_math(torchfn, mathfn, test_expand=(not selffn))
599 
600  def _test_math_by_name(self, function_name, test_self=True):
601  if test_self:
602  self.__test_math_by_name(function_name + "_", function_name, True)
603  self.__test_math_by_name(function_name, function_name, False)
604 
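# _test_math_by_name('sin') therefore exercises both the functional and the
# in-place form against the math module, roughly (sketch):
#
#     self._test_math(lambda x: x.sin_(), math.sin)              # 'sin_'
#     self._test_math(torch.sin, math.sin, test_expand=True)     # 'sin'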
605  def test_sin(self):
606  self._test_math_by_name('sin')
607 
608  def test_sinh(self):
609  def sinh(x):
610  try:
611  return math.sinh(x)
612  except OverflowError:
613  return inf if x > 0 else -inf
614  self._test_math(torch.sinh, sinh)
615 
616  def test_lgamma(self):
617  def lgamma(x):
618  if x <= 0 and x == int(x):
619  return inf
620  return math.lgamma(x)
621  self._test_math(torch.lgamma, lgamma)
622 
623  @unittest.skipIf(not TEST_SCIPY, "Scipy not found")
624  def test_mvlgamma(self):
625  from scipy.special import multigammaln
626  for d in range(1, 5):
627  input = torch.empty(10).uniform_(d, 10)
628  res_torch = torch.mvlgamma(input, d)
629  res_scipy = multigammaln(input.numpy(), d)
630  self.assertEqual(res_torch.numpy(), res_scipy)
631 
632  def test_mvlgamma_argcheck(self):
633  def run_test(d):
634  input = torch.linspace((d - 2) / 2, 10, 10)
635  torch.mvlgamma(input, d)
636 
637  with self.assertRaisesRegex(RuntimeError, "Condition for computing multivariate log-gamma not met"):
638  run_test(3)
639 
640  def _digamma_input(self, test_poles=True):
641  input = []
642  input.append((torch.randn(10).abs() + 1e-4).tolist())
643  input.append((torch.randn(10).abs() + 1e6).tolist())
644  zeros = torch.linspace(-9.5, -0.5, 10)
645  input.append(zeros.tolist())
646  input.append((zeros - 0.49).tolist())
647  input.append((zeros + 0.49).tolist())
648  input.append((zeros + (torch.rand(10) * 0.99) - 0.5).tolist())
649 
650  if test_poles:
651  input.append([-0.999999994, -1.999999994, -2.0000000111,
652  -100.99999994, -1931.99999994, 0.000000111,
653  -0.000000111, 0, -2, -329])
654  return input
655 
656  @unittest.skipIf(not TEST_SCIPY, "Scipy not found")
657  def test_digamma(self):
658  from scipy.special import digamma
659 
660  # scipy 1.1.0 changed when it returns +/-inf vs. NaN
661  def torch_digamma_without_inf(inp):
662  res = torch.digamma(inp)
663  res[(res == -inf) | (res == inf)] = nan
664  return res
665 
666  def scipy_digamma_without_inf(inp):
667  res = digamma(inp)
668  if np.isscalar(res):
669  return res if np.isfinite(res) else nan
670  res[np.isinf(res)] = nan
671  return res
672 
673  self._test_math(torch_digamma_without_inf, scipy_digamma_without_inf, self._digamma_input())
674 
675  @unittest.skipIf(not TEST_SCIPY, "Scipy not found")
676  def test_polygamma(self):
677  from scipy.special import polygamma
678  for n in [0, 1]:
679  self._test_math(lambda x: torch.polygamma(n, x),
680  lambda x: polygamma(n, x).item(),
681  self._digamma_input(test_poles=False))
682 
683  def test_asin(self):
684  self._test_math(torch.asin, lambda x: math.asin(x) if abs(x) <= 1 else nan)
685 
686  def test_cos(self):
687  self._test_math_by_name('cos')
688 
689  def test_cosh(self):
690  def cosh(x):
691  try:
692  return math.cosh(x)
693  except OverflowError:
694  # Return inf on overflow.
695  # See http://en.cppreference.com/w/cpp/numeric/math/cosh
696  return inf
697  self._test_math(torch.cosh, cosh)
698 
699  def test_acos(self):
700  self._test_math(torch.acos, lambda x: math.acos(x) if abs(x) <= 1 else nan)
701 
702  def test_tan(self):
703  self._test_math_by_name('tan')
704 
705  def test_tanh(self):
706  self._test_math_by_name('tanh')
707 
708  def test_atan(self):
709  self._test_math_by_name('atan')
710 
711  def test_log(self):
712  def log(x):
713  if x == 0:
714  return -inf
715  elif x < 0:
716  return nan
717  return math.log(x)
718  self._test_math(torch.log, log)
719 
720  def test_log10(self):
721  def log10(x):
722  if x == 0:
723  return -inf
724  elif x < 0:
725  return nan
726  return math.log10(x)
727  self._test_math(torch.log10, log10)
728 
729  def test_log1p(self):
730  def log1p(x):
731  if x == -1:
732  return -inf
733  elif x < -1:
734  return nan
735  return math.log1p(x)
736  self._test_math(torch.log1p, log1p)
737 
738  def test_log2(self):
739  def log2(x):
740  if x == 0:
741  return -inf
742  elif x < 0:
743  return nan
744  try:
745  return math.log2(x)
746  except AttributeError:
747  return math.log(x, 2)
748  self._test_math(torch.log2, log2)
749 
750  def test_sqrt(self):
751  self._test_math(torch.sqrt, lambda x: math.sqrt(x) if x >= 0 else nan)
752 
753  def test_erf(self):
754  self._test_math_by_name('erf')
755 
756  def test_erfc(self):
757  self._test_math_by_name('erfc')
758 
759  def test_erfinv(self):
760  def checkType(tensor):
761  inputValues = torch.randn(4, 4, out=tensor()).clamp(-2., 2.)
762  self.assertEqual(tensor(inputValues).erf().erfinv(), tensor(inputValues))
763  # test inf
764  self.assertTrue(torch.equal(tensor([-1, 1]).erfinv(), tensor([-inf, inf])))
765  # test nan
766  self.assertEqual(tensor([-2, 2]).erfinv(), tensor([nan, nan]))
767 
768  checkType(torch.FloatTensor)
769  checkType(torch.DoubleTensor)
770 
771  def test_exp(self):
772  def exp(x):
773  try:
774  return math.exp(x)
775  except OverflowError:
776  return inf
777  self._test_math(torch.exp, exp)
778 
779  def test_expm1(self):
780  def expm1(x):
781  try:
782  return math.expm1(x)
783  except OverflowError:
784  return inf
785  self._test_math(torch.expm1, expm1)
786 
787  def test_floor(self):
788  self._test_math_by_name('floor')
789 
790  def test_ceil(self):
791  self._test_math_by_name('ceil')
792 
793  @unittest.skipIf(not torch.cuda.is_available(), 'no CUDA')
794  def test_ceil_out_cpu_cuda(self):
795  a = torch.randn(1)
796  b = torch.randn(1, device="cuda")
797  self.assertRaises(RuntimeError, lambda: torch.ceil(a, out=b))
798 
799  def test_rsqrt(self):
800  def rsqrt(x):
801  if x == 0:
802  return inf
803  elif x < 0:
804  return nan
805  return 1.0 / math.sqrt(x)
806 
807  self._test_math(torch.rsqrt, rsqrt)
808 
809  def test_sigmoid(self):
810  # TODO: why not simulate math.sigmoid like with rsqrt?
811  inputValues = [-1000, -1, 0, 0.5, 1, 2, 1000]
812  expectedOutput = [0.0000, 0.2689, 0.5, 0.6225, 0.7311, 0.8808, 1.000]
813  precision_4dps = 0.0002
814 
815  def checkType(tensor):
816  self.assertEqual(tensor(inputValues).sigmoid(), tensor(expectedOutput), precision_4dps)
817 
818  checkType(torch.FloatTensor)
819  checkType(torch.DoubleTensor)
820 
821  def test_frac(self):
822  self._test_math(torch.frac, lambda x: math.fmod(x, 1))
823 
824  def test_trunc(self):
825  self._test_math(torch.trunc, lambda x: x - math.fmod(x, 1))
826 
827  def test_round(self):
828  self._test_math(torch.round, round)
829 
830  def test_has_storage(self):
831  self.assertIsNotNone(torch.Tensor().storage())
832  self.assertIsNotNone(torch.Tensor(0).storage())
833  self.assertIsNotNone(torch.Tensor([]).storage())
834  self.assertIsNotNone(torch.Tensor().clone().storage())
835  self.assertIsNotNone(torch.Tensor([0, 0, 0]).nonzero().storage())
836  self.assertIsNotNone(torch.Tensor().new().storage())
837 
838  @unittest.skipIf(not TEST_NUMPY, "Numpy not found")
839  def test_has_storage_numpy(self):
840  for dtype in [np.float32, np.float64, np.int64,
841  np.int32, np.int16, np.uint8]:
842  arr = np.array([1], dtype=dtype)
843  self.assertIsNotNone(torch.FloatTensor(arr).storage())
844  self.assertIsNotNone(torch.DoubleTensor(arr).storage())
845  self.assertIsNotNone(torch.IntTensor(arr).storage())
846  self.assertIsNotNone(torch.LongTensor(arr).storage())
847  self.assertIsNotNone(torch.ByteTensor(arr).storage())
848  if torch.cuda.is_available():
849  self.assertIsNotNone(torch.cuda.FloatTensor(arr).storage())
850  self.assertIsNotNone(torch.cuda.DoubleTensor(arr).storage())
851  self.assertIsNotNone(torch.cuda.IntTensor(arr).storage())
852  self.assertIsNotNone(torch.cuda.LongTensor(arr).storage())
853  self.assertIsNotNone(torch.cuda.ByteTensor(arr).storage())
854 
855  def _testSelection(self, torchfn, mathfn):
856  # contiguous
857  m1 = torch.randn(100, 100)
858  res1 = torchfn(m1)
859  res2 = m1[0, 0]
860  for i, j in iter_indices(m1):
861  res2 = mathfn(res2, m1[i, j])
862  self.assertEqual(res1, res2)
863 
864  # non-contiguous
865  m1 = torch.randn(10, 10, 10)
866  m2 = m1[:, 4]
867  res1 = torchfn(m2)
868  res2 = m2[0, 0]
869  for i, j in iter_indices(m2):
870  res2 = mathfn(res2, m2[i][j])
871  self.assertEqual(res1, res2)
872 
873  # with indices
874  m1 = torch.randn(100, 100)
875  res1val, res1ind = torchfn(m1, 1, False)
876  res2val = m1[:, 0:1].clone().squeeze()
877  res2ind = res1ind.clone().fill_(0)
878  for i, j in iter_indices(m1):
879  if mathfn(res2val[i], m1[i, j]) != res2val[i]:
880  res2val[i] = m1[i, j]
881  res2ind[i] = j
882 
883  maxerr = 0
884  for i in range(res1val.size(0)):
885  maxerr = max(maxerr, abs(res1val[i] - res2val[i]))
886  self.assertEqual(res1ind[i], res2ind[i])
887  self.assertLessEqual(abs(maxerr), 1e-5)
888 
889  # NaNs
890  for index in (0, 4, 99):
891  m1 = torch.randn(100)
892  m1[index] = nan
893  res1val, res1ind = torch.max(m1, 0)
894  self.assertTrue(math.isnan(res1val))
895  self.assertEqual(res1ind, index)
896  res1val = torchfn(m1)
897  self.assertTrue(math.isnan(res1val))
898 
899  def test_max(self):
900  self._testSelection(torch.max, max)
901 
902  @staticmethod
903  def _test_max_with_inf(self, dtypes=(torch.float, torch.double), device='cpu'):
904  for dtype in dtypes:
905  a = torch.tensor([[-inf, -inf, inf, 3], [inf, inf, -inf, -1]], dtype=dtype, device=device)
906  self.assertTrue(torch.all(torch.max(a, dim=1)[0] == inf).item())
907  self.assertTrue(torch.max(a).item() == inf)
908 
909  def test_max_with_inf(self):
910  self._test_max_with_inf(self)
911 
912  def test_min(self):
913  self._testSelection(torch.min, min)
914 
915  @staticmethod
916  def _test_min_with_inf(self, dtypes=(torch.float, torch.double), device='cpu'):
917  for dtype in dtypes:
918  a = torch.tensor([[-inf, -inf, inf, 3], [inf, inf, -inf, -1]], dtype=dtype, device=device)
919  self.assertTrue(torch.all(torch.min(a, dim=1)[0] == (-inf)).item())
920  self.assertTrue(torch.min(a).item() == -inf)
921 
922  def test_min_with_inf(self):
923  self._test_min_with_inf(self)
924 
925  @staticmethod
926  def _test_norm(self, device):
927  # full reduction
928  x = torch.randn(25, device=device)
929  xn = x.cpu().numpy()
930  for p in [0, 1, 2, 3, 4, inf, -inf]:
931  res = x.norm(p).item()
932  expected = np.linalg.norm(xn, p)
933  self.assertEqual(res, expected, "full reduction failed for {}-norm".format(p))
934 
935  # one dimension
936  x = torch.randn(25, 25, device=device)
937  xn = x.cpu().numpy()
938  for p in [0, 1, 2, 3, 4, inf, -inf]:
939  res = x.norm(p, 1).cpu().numpy()
940  expected = np.linalg.norm(xn, p, 1)
941  self.assertEqual(res.shape, expected.shape)
942  self.assertTrue(np.allclose(res, expected), "dim reduction failed for {}-norm".format(p))
943 
944  # matrix norm
945  for p in ['fro', 'nuc']:
946  res = x.norm(p).cpu().numpy()
947  expected = np.linalg.norm(xn, p)
948  self.assertEqual(res.shape, expected.shape)
949  self.assertTrue(np.allclose(res, expected), "dim reduction failed for {}-norm".format(p))
950 
951  # larger tensor sanity check
952  self.assertEqual(2 * torch.norm(torch.ones(10000)), torch.norm(torch.ones(40000)))
953 
954  @unittest.skipIf(not TEST_NUMPY, "Numpy not found")
955  @skipIfNoLapack
956  def test_norm(self):
957  self._test_norm(self, device='cpu')
958 
959  @staticmethod
960  def _test_dist(self, device):
961  def run_test(x, y):
962  for p in [0, 1, 2, 3, 4, inf, -inf]:
963  dist_xy = torch.dist(x, y, p)
964  dist_xy_norm = torch.norm(x - y, p)
965  self.assertEqual(dist_xy, dist_xy_norm)
966 
967  run_test(torch.randn(5, device=device), torch.randn(5, device=device))
968 
969  x = torch.zeros(3, device=device)
970  y = torch.zeros(3, device=device)
971  y[1] = 1.
972  run_test(x, y)
973 
974  def test_dist(self):
975  self._test_dist(self, device='cpu')
976 
977  def test_dim_reduction_uint8_overflow(self):
978  example = [[-1, 2, 1], [5, 3, 6]]
979  x = torch.tensor(example, dtype=torch.uint8)
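# In uint8 the -1 entry wraps to 255, so the full sum is
# 255 + 2 + 1 + 5 + 3 + 6 = 272, which overflows to 272 % 256 = 16,
# matching the assertions below (sketch of the arithmetic).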
980  self.assertEqual(x.sum(dtype=torch.uint8).item(), 16)
981  self.assertEqual(x.sum(0, dtype=torch.uint8), torch.FloatTensor([4, 5, 7]))
982  self.assertEqual(x.sum(1, dtype=torch.uint8), torch.FloatTensor([2, 14]))
983  y = torch.tensor(example, dtype=torch.uint8)
984  torch.sum(x, 0, out=y)
985  self.assertEqual(x.sum(0, dtype=torch.uint8), y)
986 
987  @staticmethod
988  def _test_dim_reduction(self, cast):
989  example = [[-1, 2, 1], [5, 3, 6]]
990 
991  types = [torch.double,
992  torch.float,
993  torch.int64,
994  torch.int32,
995  torch.int16]
996 
997  # This won't test for 256bit instructions, since we usually
998  # only work on 1 cacheline (1024bit) at a time and these
999  # examples aren't big enough to trigger that.
1000  for dtype in types:
1001  x = cast(torch.tensor(example, dtype=dtype))
1002  self.assertEqual(x.sum().item(), 16)
1003  self.assertEqual(x.sum(0), torch.FloatTensor([4, 5, 7]))
1004  self.assertEqual(x.sum(1), torch.FloatTensor([2, 14]))
1005  y = cast(torch.tensor(example, dtype=dtype))
1006  torch.sum(x, 0, out=y)
1007  self.assertEqual(x.sum(0), y)
1008 
1009  # Mean not supported for Int types
1010  for dtype in types[:2]:
1011  x = cast(torch.tensor(example, dtype=dtype))
1012  self.assertEqual(x.mean().item(), 16.0 / 6)
1013  self.assertEqual(x.mean(0), torch.FloatTensor([2.0, 2.5, 7.0 / 2]))
1014  self.assertEqual(x.mean(1), torch.FloatTensor([2.0 / 3, 14.0 / 3]))
1015  self.assertEqual(x.mean(), x.mean((0, 1)))
1016 
1017  for dtype in types:
1018  x = cast(torch.tensor(example, dtype=dtype))
1019  self.assertEqual(x.prod().item(), -180)
1020  self.assertEqual(x.prod(0), torch.FloatTensor([-5, 6, 6]))
1021  self.assertEqual(x.prod(1), torch.FloatTensor([-2, 90]))
1022 
1023  for dtype in types:
1024  x = cast(torch.tensor(example, dtype=dtype))
1025  self.assertEqual(x.max().item(), 6)
1026  self.assertEqual(x.max(0), (torch.FloatTensor([5, 3, 6]), torch.FloatTensor([1, 1, 1])))
1027  self.assertEqual(x.max(1), (torch.FloatTensor([2, 6]), torch.FloatTensor([1, 2])))
1028 
1029  for dtype in types:
1030  x = cast(torch.tensor(example, dtype=dtype))
1031  self.assertEqual(x.min().item(), -1)
1032  self.assertEqual(x.min(0), (torch.FloatTensor([-1, 2, 1]), torch.FloatTensor([0, 0, 0])))
1033  self.assertEqual(x.min(1), (torch.FloatTensor([-1, 3]), torch.FloatTensor([0, 1])))
1034 
1035  for dtype in types:
1036  x = cast(torch.tensor(example, dtype=dtype))
1037  self.assertEqual(x.argmax().item(), 5)
1038  self.assertEqual(x.argmax(dim=None).item(), 5)
1039  self.assertEqual(x.argmax(dim=0), torch.FloatTensor([1, 1, 1]))
1040  self.assertEqual(x.argmax(dim=1), torch.FloatTensor([1, 2]))
1041  self.assertEqual(x.argmax(dim=0, keepdim=True), torch.FloatTensor([[1, 1, 1]]))
1042  # test that non-contiguous tensors work
1043  self.assertEqual(x[:, :2].argmax().item(), 2)
1044 
1045  for dtype in types:
1046  x = cast(torch.tensor(example, dtype=dtype))
1047  self.assertEqual(x.argmin().item(), 0)
1048  self.assertEqual(x.argmin(dim=None).item(), 0)
1049  self.assertEqual(x.argmin(dim=0), torch.FloatTensor([0, 0, 0]))
1050  self.assertEqual(x.argmin(dim=1), torch.FloatTensor([0, 1]))
1051  self.assertEqual(x.argmin(dim=1, keepdim=True), torch.FloatTensor([[0], [1]]))
1052  # test that non-contiguous tensors work
1053  self.assertEqual(x[:, :2].argmin().item(), 0)
1054 
1055  dim_red_fns = [
1056  "mean", "median", "mode", "norm", "prod",
1057  "std", "sum", "var", "max", "min"]
1058 
1059  def normfn_attr(t, dim, keepdim=False, out=None):
1060  attr = getattr(torch, "norm")
1061  return attr(t, 2, dim, keepdim, out=out)
1062 
1063  for fn_name in dim_red_fns:
1064  fn_attr = getattr(torch, fn_name) if fn_name != "norm" else normfn_attr
1065 
1066  def fn(x, dim, keepdim=False, out=None):
1067  ans = fn_attr(x, dim, keepdim=keepdim, out=out)
1068  return ans if not istuple(ans) else ans[0]
1069 
1070  def fn_tuple(x, dim, keepdim=False, out=None):
1071  return fn_attr(x, dim, keepdim=keepdim, out=out)
1072 
1073  def test_multidim(x, dim):
1074  self.assertEqual(fn(x, dim).unsqueeze(dim), fn(x, dim, keepdim=True))
1075  self.assertEqual(x.ndimension() - 1, fn(x, dim).ndimension())
1076  self.assertEqual(x.ndimension(), fn(x, dim, keepdim=True).ndimension())
1077 
1078  # general case
1079  x = cast(torch.randn(3, 4, 5))
1080  dim = random.randint(0, 2)
1081  test_multidim(x, dim)
1082 
1083  # check 1-d behavior
1084  x = cast(torch.randn(1))
1085  dim = 0
1086  self.assertEqual(fn(x, dim).shape, ())
1087  self.assertEqual(fn(x, dim, keepdim=True).shape, (1,))
1088 
1089  # check reducing of a singleton dimension
1090  dims = [3, 4, 5]
1091  singleton_dim = random.randint(0, 2)
1092  dims[singleton_dim] = 1
1093  x = cast(torch.randn(dims))
1094  test_multidim(x, singleton_dim)
1095 
1096  # check reducing with output kwargs
1097  if fn_name in ['median', 'mode', 'max', 'min']:
1098  y = cast(torch.randn(5, 3))
1099  values = cast(torch.randn(5, 3))
1100  indices = cast(torch.zeros(5, 3).long() - 1)
1101  fn_tuple(y, 1, keepdim=False, out=(values[:, 1], indices[:, 1]))
1102  values_expected, indices_expected = fn_tuple(y, 1, keepdim=False)
1103  self.assertEqual(values[:, 1], values_expected,
1104  '{} values with out= kwarg'.format(fn_name))
1105  self.assertEqual(indices[:, 1], indices_expected,
1106  '{} indices with out= kwarg'.format(fn_name))
1107  continue
1108 
1109  x = cast(torch.randn(5, 3))
1110  y = cast(torch.randn(5, 3))
1111  fn(y, 1, keepdim=False, out=x[:, 1])
1112  expected = fn(y, 1, keepdim=False)
1113  self.assertEqual(x[:, 1], expected, '{} with out= kwarg'.format(fn_name))
1114 
1115  def test_dim_reduction(self):
1116  self._test_dim_reduction(self, lambda t: t)
1117 
1118  def test_reduction_empty(self):
1119  fns_to_test = [
1120  # name, function, identity
1121  ('max', torch.max, None),
1122  ('kthvalue', lambda *args, **kwargs: torch.kthvalue(*args, k=1, **kwargs), None),
1123  ('argmax', torch.argmax, None),
1124  ('min', torch.min, None),
1125  ('argmin', torch.argmin, None),
1126  ('mode', torch.mode, None),
1127  ('median', torch.median, None),
1128 
1129  ('prod', torch.prod, 1),
1130  ('sum', torch.sum, 0),
1131  ('norm', torch.norm, 0),
1132  ('mean', torch.mean, nan),
1133  ('var', torch.var, nan),
1134  ('std', torch.std, nan),
1135  ('logsumexp', torch.logsumexp, -inf),
1136  ]
1137 
1138  devices = ['cpu'] if not torch.cuda.is_available() else ['cpu', 'cuda']
1139  shape = (2, 0, 4)
1140  for device in devices:
1141  x = torch.randn(shape, device=device)
1142 
1143  for item in fns_to_test:
1144  name, fn, identity = item
1145  if identity is None:
1146  ident_err = 'does not have an identity'
1147  self.assertRaisesRegex(RuntimeError, ident_err, lambda: fn(x, dim=2))
1148  self.assertRaisesRegex(RuntimeError, ident_err, lambda: fn(x, dim=2, keepdim=True))
1149  self.assertRaisesRegex(RuntimeError, ident_err, lambda: fn(x, dim=1))
1150  self.assertRaisesRegex(RuntimeError, ident_err, lambda: fn(x, dim=1, keepdim=True))
1151  else:
1152  self.assertEqual(torch.empty((2, 0), device=device), fn(x, dim=2))
1153  self.assertEqual(torch.empty((2, 0, 1), device=device), fn(x, dim=2, keepdim=True))
1154  # assertEqual doesn't work with inf, -inf, nan and two tensors.
1155  check = (torch.testing.assert_allclose if math.isnan(identity) or math.isinf(identity) else
1156  self.assertEqual)
1157  check(torch.full((2, 4), identity, device=device), fn(x, dim=1))
1158  check(torch.full((2, 1, 4), identity, device=device), fn(x, dim=1, keepdim=True))
1159  try:
1160  check(torch.full((), identity, device=device), fn(x))
1161  except TypeError as err:
1162  # ignore if there is no allreduce.
1163  self.assertTrue('required positional arguments: "dim"' in str(err))
1164 
1165  # any
1166  xb = x.to(torch.uint8)
1167  yb = x.to(torch.uint8)
1168  self.assertEqual((2, 0), xb.any(2).shape)
1169  self.assertEqual((2, 0, 1), xb.any(2, keepdim=True).shape)
1170  self.assertEqual(torch.zeros((2, 4), device=device), xb.any(1))
1171  self.assertEqual(torch.zeros((2, 1, 4), device=device), xb.any(1, keepdim=True))
1172  self.assertEqual(torch.zeros((), device=device), xb.any())
1173 
1174  # all
1175  self.assertEqual((2, 0), xb.all(2).shape)
1176  self.assertEqual((2, 0, 1), xb.all(2, keepdim=True).shape)
1177  self.assertEqual(torch.ones((2, 4), device=device), xb.all(1))
1178  self.assertEqual(torch.ones((2, 1, 4), device=device), xb.all(1, keepdim=True))
1179  self.assertEqual(torch.ones((), device=device), xb.all())
1180 
1181  def test_pairwise_distance_empty(self):
1182  devices = ['cpu'] if not torch.cuda.is_available() else ['cpu', 'cuda']
1183  for device in devices:
1184  shape = (2, 0)
1185  x = torch.randn(shape, device=device)
1186  y = torch.randn(shape, device=device)
1187 
1188  self.assertEqual(torch.zeros(2, device=device), torch.pairwise_distance(x, y))
1189  self.assertEqual(torch.zeros((2, 1), device=device), torch.pairwise_distance(x, y, keepdim=True))
1190 
1191  shape = (0, 2)
1192  x = torch.randn(shape, device=device)
1193  y = torch.randn(shape, device=device)
1194  self.assertEqual(torch.zeros(0, device=device), torch.pairwise_distance(x, y))
1195  self.assertEqual(torch.zeros((0, 1), device=device), torch.pairwise_distance(x, y, keepdim=True))
1196 
1197  def test_pdist_empty(self):
1198  devices = ['cpu'] if not torch.cuda.is_available() else ['cpu', 'cuda']
1199  for device in devices:
1200  shape = (0, 2)
1201  x = torch.randn(shape, device=device)
1202  self.assertEqual(torch.empty(0, device=device), torch.pdist(x))
1203 
1204  shape = (1, 2)
1205  x = torch.randn(shape, device=device)
1206  self.assertEqual(torch.empty(0, device=device), torch.pdist(x))
1207 
1208  shape = (3, 0)
1209  x = torch.randn(shape, device=device)
1210  self.assertEqual(torch.zeros(3, device=device), torch.pdist(x))
1211 
1212  def test_pdist_norm(self):
1213  def test_pdist_single(shape, device, p, dtype, trans):
1214  x = torch.randn(shape, dtype=dtype, device=device)
1215  if trans:
1216  x.transpose_(-2, -1)
1217  actual = torch.pdist(x, p=p)
1218  expected = brute_pdist(x, p=p)
1219  self.assertEqual(expected.shape, actual.shape)
1220  self.assertTrue(torch.allclose(expected, actual))
1221 
1222  devices = ['cpu'] if not torch.cuda.is_available() else ['cpu', 'cuda']
1223  for device in devices:
1224  for shape in [(4, 5), (3, 2), (2, 1)]:
1225  for p in [0, 1, 2, 3, 1.5, 2.5, float('inf')]:
1226  for trans in [False, True]:
1227  for dtype in [torch.float32, torch.float64]:
1228  test_pdist_single(shape, device, p, dtype, trans)
1229 
1230  # do a simplified comparison with big inputs, see:
1231  # https://github.com/pytorch/pytorch/issues/15511
1232  for dtype in [torch.float32, torch.float64]:
1233  test_pdist_single((1000, 2), device, 2, dtype, False)
1234 
1235  def test_cdist_empty(self):
1236  devices = ['cpu'] if not torch.cuda.is_available() else ['cpu', 'cuda']
1237  for device in devices:
1238  x = torch.randn((0, 5), device=device)
1239  y = torch.randn((4, 5), device=device)
1240  self.assertEqual(torch.empty(0, 4, device=device), torch.cdist(x, y))
1241 
1242  x = torch.randn((2, 5), device=device)
1243  y = torch.randn((0, 5), device=device)
1244  self.assertEqual(torch.empty(2, 0, device=device), torch.cdist(x, y))
1245 
1246  x = torch.randn((2, 0), device=device)
1247  y = torch.randn((3, 0), device=device)
1248  self.assertEqual(torch.zeros(2, 3, device=device), torch.cdist(x, y))
1249 
1250  x = torch.randn((2, 0), device=device)
1251  y = torch.randn((0, 0), device=device)
1252  self.assertEqual(torch.empty(2, 0, device=device), torch.cdist(x, y))
1253 
1254  def test_cdist_norm(self):
1255  devices = ['cpu'] if not torch.cuda.is_available() else ['cpu', 'cuda']
1256  for device in devices:
1257  for r1 in [3, 4, 5, 6]:
1258  for m in [2, 3, 4, 10]:
1259  for r2 in [4, 6, 7, 8]:
1260  for p in [0, 1, 2, 3, 1.5, 2.5, float('inf')]:
1261  x = torch.randn(r1, m, device=device)
1262  y = torch.randn(r2, m, device=device)
1263  actual = torch.cdist(x, y, p=p)
1264  expected = brute_cdist(x, y, p=p)
1265  self.assertTrue(torch.allclose(expected, actual))
1266 
1267  def test_cdist_large(self):
1268  devices = ['cpu'] if not torch.cuda.is_available() else ['cpu', 'cuda']
1269  for device in devices:
1270  x = torch.randn(1000, 10, device=device)
1271  y = torch.randn(1000, 10, device=device)
1272  actual = torch.cdist(x, y, p=2)
1273  expected = brute_cdist(x, y, p=2)
1274  self.assertTrue(torch.allclose(expected, actual))
1275 
1276  def test_cdist_non_contiguous(self):
1277  devices = ['cpu'] if not torch.cuda.is_available() else ['cpu', 'cuda']
1278  for device in devices:
1279  x = torch.randn(5, 7, device=device).t()
1280  y = torch.randn(5, 3, device=device).t()
1281  actual = torch.cdist(x, y, p=2)
1282  expected = brute_cdist(x, y, p=2)
1283  self.assertFalse(x.is_contiguous())
1284  self.assertFalse(y.is_contiguous())
1285  self.assertTrue(torch.allclose(expected, actual))
1286 
1287  x = torch.randn(7, 5, device=device)
1288  y = torch.randn(5, 3, device=device).t()
1289  actual = torch.cdist(x, y, p=2)
1290  expected = brute_cdist(x, y, p=2)
1291  self.assertTrue(x.is_contiguous())
1292  self.assertFalse(y.is_contiguous())
1293  self.assertTrue(torch.allclose(expected, actual))
1294 
1295  x = torch.randn(5, 7, device=device).t()
1296  y = torch.randn(3, 5, device=device)
1297  actual = torch.cdist(x, y, p=2)
1298  expected = brute_cdist(x, y, p=2)
1299  self.assertFalse(x.is_contiguous())
1300  self.assertTrue(y.is_contiguous())
1301  self.assertTrue(torch.allclose(expected, actual))
1302 
1303  @unittest.skipIf(not TEST_SCIPY, "Scipy not found")
1304  def test_logsumexp(self):
1305  from scipy.special import logsumexp
1306  a = torch.randn(5, 4)
1307  a[0, 0] = inf
1308  a[1, :] = -inf
1309  actual = a.logsumexp(1)
1310  expected = logsumexp(a.numpy(), 1)
1311  self.assertEqual(expected.shape, actual.shape)
1312  self.assertTrue(np.allclose(expected, actual.numpy()))
1313  # check that out is actually inplace
1314  b = torch.zeros(5, 2)
1315  c = b[:, 0]
1316  torch.logsumexp(a, 1, out=c)
1317  self.assertTrue(np.allclose(expected, b[:, 0].numpy()))
1318 
1319  @unittest.skipIf(not TEST_NUMPY, "Numpy not found")
1320  def test_cpu_parallel(self):
1321  # To use parallel branches we'll need to compare on tensors
1322  # that are relatively large. Even if this is run on a single
1323  # core machine these tests will still give you a signal on
1324  # correctness.
1325 
1326  def _run_test(size):
1327  for dim in range(len(size) + 1):
1328  nv = np.round(np.random.rand(*size)) # 0s and 1s
1329  tv = torch.from_numpy(nv)
1330  # Parallelism is only used if numel is
1331  # larger than grainsize defined in Parallel.h
1332  self.assertTrue(tv.numel() > 32768)
1333  if dim == len(size):
1334  nvs = nv.sum()
1335  tvs = tv.sum()
1336  else:
1337  nvs = nv.sum(dim)
1338  tvs = tv.sum(dim)
1339  diff = np.abs(nvs - tvs.numpy()).sum()
1340  self.assertEqual(diff, 0)
1341 
1342  _run_test([2, 3, 3, 3, 3, 2, 2, 3, 2, 3, 2, 3, 3])
1343  _run_test([4, 4, 4, 4, 4, 4, 4, 4, 4, 4])
1344  _run_test([1, 32 * 8 * 32 * 8])
1345  _run_test([1, 32770])
1346 
1347  def _testCSelection(self, torchfn, mathfn):
1348  # Two tensors
1349  size = (100, 100)
1350  a = torch.rand(*size)
1351  b = torch.rand(*size)
1352  c = torchfn(a, b)
1353  expected_c = torch.zeros(*size)
1354  expected_c.map2_(a, b, lambda _, a, b: mathfn(a, b))
1355  self.assertEqual(expected_c, c, 0)
1356 
1357  def test_max_elementwise(self):
1358  self._testCSelection(torch.max, max)
1359 
1360  def test_min_elementwise(self):
1361  self._testCSelection(torch.min, min)
1362 
1363  @staticmethod
1364  def _test_lerp(self, cast):
1365  start_end_shapes = [(), (5,), (5, 5), (5, 5, 5)]
1366  for shapes in product(start_end_shapes, start_end_shapes):
1367  start = cast(torch.randn(shapes[0]))
1368  end = cast(torch.randn(shapes[1]))
1369 
1370  # Tensor weights
1371  for weight in [cast(torch.randn(shapes[0])), random.random()]:
1372  actual = torch.lerp(start, end, weight)
1373  actual_method = start.lerp(end, weight)
1374  self.assertEqual(actual, actual_method)
1375  actual_out = cast(torch.Tensor())
1376  torch.lerp(start, end, weight, out=actual_out)
1377  self.assertEqual(actual, actual_out)
1378  expected = start + weight * (end - start)
1379  self.assertEqual(expected, actual)
1380 
1381  def test_lerp(self):
1382  self._test_lerp(self, lambda t: t)
1383 
1384  def test_all_any(self):
1385  def test(size):
1386  x = torch.ones(*size).byte()
1387  self.assertTrue(x.all())
1388  self.assertTrue(x.any())
1389 
1390  x[3] = 0
1391  self.assertFalse(x.all())
1392  self.assertTrue(x.any())
1393 
1394  x.zero_()
1395  self.assertFalse(x.all())
1396  self.assertFalse(x.any())
1397 
1398  x.fill_(2)
1399  self.assertTrue(x.all())
1400  self.assertTrue(x.any())
1401 
1402  test((10,))
1403  test((5, 5))
1404 
1405  def test_all_any_empty(self):
1406  x = torch.ByteTensor()
1407  self.assertTrue(x.all())
1408  self.assertFalse(x.any())
1409 
1410  def test_all_any_with_dim(self):
1411  def test(x):
1412  r1 = x.prod(dim=0, keepdim=False).byte()
1413  r2 = x.all(dim=0, keepdim=False)
1414  self.assertEqual(r1.shape, r2.shape)
1415  self.assertTrue((r1 == r2).all())
1416 
1417  r3 = x.sum(dim=1, keepdim=True).clamp(0, 1).byte()
1418  r4 = x.any(dim=1, keepdim=True)
1419  self.assertEqual(r3.shape, r4.shape)
1420  self.assertTrue((r3 == r4).all())
1421 
1422  test(torch.ByteTensor([[0, 0, 0],
1423  [0, 0, 1],
1424  [0, 1, 1],
1425  [1, 1, 1]]))
1426 
1427  @unittest.skipIf(not torch.cuda.is_available(), 'no CUDA')
1428  def test_all_any_empty_cuda(self):
1429  x = torch.cuda.ByteTensor()
1430  self.assertTrue(x.all())
1431  self.assertFalse(x.any())
1432 
1433  def test_mv(self):
1434  m1 = torch.randn(100, 100)
1435  v1 = torch.randn(100)
1436 
1437  res1 = torch.mv(m1, v1)
1438  res2 = res1.clone().zero_()
1439  for i, j in iter_indices(m1):
1440  res2[i] += m1[i][j] * v1[j]
1441 
1442  self.assertEqual(res1, res2)
1443 
1444  def test_add(self):
1445  # [res] torch.add([res,] tensor1, tensor2)
1446  m1 = torch.randn(100, 100)
1447  v1 = torch.randn(100)
1448 
1449  # contiguous
1450  res1 = torch.add(m1[4], v1)
1451  res2 = res1.clone().zero_()
1452  for i in range(m1.size(1)):
1453  res2[i] = m1[4, i] + v1[i]
1454  self.assertEqual(res1, res2)
1455 
1456  m1 = torch.randn(100, 100)
1457  v1 = torch.randn(100)
1458 
1459  # non-contiguous
1460  res1 = torch.add(m1[:, 4], v1)
1461  res2 = res1.clone().zero_()
1462  for i in range(m1.size(0)):
1463  res2[i] = m1[i, 4] + v1[i]
1464  self.assertEqual(res1, res2)
1465 
1466  # [res] torch.add([res,] tensor, value)
1467  m1 = torch.randn(10, 10)
1468 
1469  # contiguous
1470  res1 = m1.clone()
1471  res1[3].add_(2)
1472  res2 = m1.clone()
1473  for i in range(m1.size(1)):
1474  res2[3, i] = res2[3, i] + 2
1475  self.assertEqual(res1, res2)
1476 
1477  # non-contiguous
1478  m1 = torch.randn(10, 10)
1479  res1 = m1.clone()
1480  res1[:, 3].add_(2)
1481  res2 = m1.clone()
1482  for i in range(m1.size(0)):
1483  res2[i, 3] = res2[i, 3] + 2
1484  self.assertEqual(res1, res2)
1485 
1486  # inter-type
1487  m1 = torch.randn(10, 10)
1488  self.assertEqual(m1 + 3, m1 + torch.tensor(3))
1489  self.assertEqual(3 + m1, torch.tensor(3) + m1)
1490  one = torch.tensor(1, dtype=torch.uint8)
1491  self.assertEqual(torch.add(one, 1), 2)
1492  self.assertEqual(torch.add(one, 1).dtype, torch.uint8)
1493 
1494  # contiguous + non-contiguous
1495  m1 = torch.randn(10, 10)
1496  m2 = torch.randn(10, 10).t()
1497  res = m1 + m2
1498  self.assertTrue(res.is_contiguous())
1499  self.assertEqual(res, m1 + m2.contiguous())
1500 
1501  # 1d + empty
1502  m1 = torch.tensor([1.0], dtype=torch.float)
1503  m2 = torch.tensor([], dtype=torch.float)
1504  self.assertEqual(m1 + m2, [])
1505 
1506  # [res] torch.add([res,] tensor1, value, tensor2)
1507 
1508  def test_csub(self):
1509  # with a tensor
1510  a = torch.randn(100, 90)
1511  b = a.clone().normal_()
1512 
1513  res_add = torch.add(a, -1, b)
1514  res_csub = a.clone()
1515  res_csub.sub_(b)
1516  self.assertEqual(res_add, res_csub)
1517 
1518  # with a scalar
1519  a = torch.randn(100, 100)
1520 
1521  scalar = 123.5
1522  res_add = torch.add(a, -scalar)
1523  res_csub = a.clone()
1524  res_csub.sub_(scalar)
1525  self.assertEqual(res_add, res_csub)
1526 
1527  @staticmethod
1528  def _test_neg(self, cast):
1529  float_types = [torch.DoubleTensor, torch.FloatTensor, torch.LongTensor]
1530  int_types = [torch.IntTensor, torch.ShortTensor, torch.ByteTensor,
1531  torch.CharTensor]
1532 
1533  for t in float_types + int_types:
1534  if t in float_types:
1535  a = cast(torch.randn(100, 90).type(t))
1536  else:
1537  a = cast(torch.randint(-128, 128, (100, 90), dtype=t.dtype))
1538  zeros = cast(torch.Tensor().type(t)).resize_as_(a).zero_()
1539 
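# For ByteTensor a negative alpha would leave the uint8 range, but 255 is
# congruent to -1 mod 256, so alpha=255 computes the same negation; e.g. for
# a value of 3: 255 * 3 = 765, and 765 % 256 == 253 == (-3) % 256 (worked sketch).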
1540  if t == torch.ByteTensor:
1541  res_add = torch.add(zeros, a, alpha=255)
1542  else:
1543  res_add = torch.add(zeros, a, alpha=-1)
1544  res_neg = a.clone()
1545  res_neg.neg_()
1546  self.assertEqual(res_neg, res_add)
1547 
1548  # test out of place as well
1549  res_neg_out_place = a.clone().neg()
1550  self.assertEqual(res_neg_out_place, res_add)
1551 
1552  # test via __neg__ operator
1553  res_neg_op = -a.clone()
1554  self.assertEqual(res_neg_op, res_add)
1555 
1556  def test_neg(self):
1557  self._test_neg(self, lambda t: t)
1558 
1559  def test_threshold(self):
1560  for dtype in torch.testing.get_all_dtypes():
1561  if dtype != torch.uint8 and dtype != torch.float16:
1562  # 100 is wide enough to use AVX2 instructions for all types
1563  x = torch.randn(100).sign().to(dtype=dtype)
1564  y = torch.threshold(x, 0, 0)
1565  self.assertTrue(y.le(0).any())
1566 
1567  def test_reciprocal(self):
1568  a = torch.randn(100, 89)
1569  res_div = 1 / a
1570  res_reciprocal = a.clone()
1571  res_reciprocal.reciprocal_()
1572  self.assertEqual(res_reciprocal, res_div)
1573 
1574  def test_mul(self):
1575  m1 = torch.randn(10, 10)
1576  res1 = m1.clone()
1577  res1[:, 3].mul_(2)
1578  res2 = m1.clone()
1579  for i in range(res1.size(0)):
1580  res2[i, 3] = res2[i, 3] * 2
1581  self.assertEqual(res1, res2)
1582 
1583  def test_div(self):
1584  m1 = torch.randn(10, 10)
1585  res1 = m1.clone()
1586  res1[:, 3].div_(2)
1587  res2 = m1.clone()
1588  for i in range(m1.size(0)):
1589  res2[i, 3] = res2[i, 3] / 2
1590  self.assertEqual(res1, res2)
1591 
1592  def test_floordiv(self):
1593  for dtype in torch.testing.get_all_dtypes():
1594  if dtype is torch.float16:
1595  continue
1596  x = torch.randn(100).mul(10).to(dtype)
1597  y = x // 3
1598  self.assertEqual(y.dtype, x.dtype)
1599  z = torch.tensor([math.trunc(v.item() / 3.) for v in x], dtype=y.dtype)
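# The reference is built with math.trunc (round toward zero) rather than
# Python's floor-based //, so the assertion below pins torch's // to
# truncation semantics here: e.g. trunc(-10 / 3.) == -3 while -10 // 3 == -4.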
1600  self.assertEqual(y, z)
1601 
1602  def test_rdiv(self):
1603  for dtype in torch.testing.get_all_dtypes():
1604  if dtype is torch.float16:
1605  continue
1606  x = torch.rand(100).add(1).mul(4).to(dtype)
1607  y = 30 / x
1608  if dtype.is_floating_point:
1609  z = torch.tensor([30 / v.item() for v in x], dtype=dtype)
1610  else:
1611  z = torch.tensor([math.trunc(30. / v.item()) for v in x], dtype=dtype)
1612  self.assertEqual(y, z)
1613 
1614  def test_fmod(self):
1615  m1 = torch.Tensor(10, 10).uniform_(-10., 10.)
1616  res1 = m1.clone()
1617  q = 2.1
1618  res1[:, 3].fmod_(q)
1619  res2 = m1.clone()
1620  for i in range(m1.size(1)):
1621  res2[i, 3] = math.fmod(res2[i, 3], q)
1622  self.assertEqual(res1, res2)
1623 
1624  def test_remainder(self):
1625  # Check the Floating point case, both tensor and scalar overloads
1626  for use_item in [True, False]:
1627  m1 = torch.Tensor(10, 10).uniform_(-10., 10.)
1628  res1 = m1.clone()
1629  res2 = m1.clone()
1630  qs = torch.arange(-5.1, 4.1)
1631  # Check the case where the divisor is a simple float
1632  for col_idx, q in enumerate(qs):
1633  # Reference
1634  for i in range(m1.size(0)):
1635  res2[i, col_idx] = res2[i, col_idx] % q
1636  # To test
1637  res1[:, col_idx].remainder_(q if not use_item else q.item())
1638  self.assertEqual(res1, res2)
1639  # Check the case where the divisor is a tensor
1640  res1 = m1.clone()
1641  res1.remainder_(qs.unsqueeze(0).expand_as(res1))
1642  self.assertEqual(res1, res2)
1643 
1644  # Check the LongTensor case, both tensor and scalar overloads
1645  for use_item in [True, False]:
1646  long_m1 = torch.LongTensor(10, 10).random_(-10, 10)
1647  long_res1 = long_m1.clone()
1648  long_res2 = long_m1.clone()
1649  long_qs = torch.arange(-5, 5)
1650  long_qs[5] = 5 # Can't handle the divisor=0 case
1651  for col_idx, long_q in enumerate(long_qs):
1652  # Reference
1653  for i in range(long_m1.size(0)):
1654  long_res2[i, col_idx] = long_res2[i, col_idx] % long_q
1655  # To test
1656  long_res1[:, col_idx].remainder_(long_q if not use_item else long_q.item())
1657  self.assertEqual(long_res1, long_res2)
1658  # Divisor is a tensor case
1659  long_res1 = long_m1.clone()
1660  long_res1.remainder_(long_qs.unsqueeze(0).expand_as(long_res1))
1661 
1662  @staticmethod
1663  def _test_remainder_overflow(self, dtype, device):
1664  # Check Integer Overflows
1665  x = torch.tensor(23500, dtype=dtype, device=device)
1666  q = 392486996410368
1667  self.assertEqual(x % q, x)
1668  self.assertEqual(-x % q, q - x)
1669  self.assertEqual(x % -q, x - q)
1670  self.assertEqual(-x % -q, -x)
1671 
1672  def test_remainder_overflow(self):
1673  self._test_remainder_overflow(self, dtype=torch.int64, device='cpu')
1674 
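 # torch.mm is checked against a naive triple-loop reference below, covering
 # contiguous inputs, transposed (non-contiguous) inputs, a zero-stride expanded
 # operand, and the out= variant, for float and integer dtypes.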
1675  def test_mm(self):
1676  def _test_mm(n, m, p, dtype, genf):
1677  # helper function
1678  def matrixmultiply(mat1, mat2):
1679  n = mat1.size(0)
1680  m = mat1.size(1)
1681  p = mat2.size(1)
1682  res = torch.zeros(n, p, dtype=dtype)
1683  for i, j in iter_indices(res):
1684  res[i, j] = sum(mat1[i, k] * mat2[k, j] for k in range(m))
1685  return res
1686 
1687  # contiguous case
1688  mat1 = genf(n, m)
1689  mat2 = genf(m, p)
1690  res = torch.mm(mat1, mat2)
1691 
1692  res2 = matrixmultiply(mat1, mat2)
1693  self.assertEqual(res, res2)
1694 
1695  # non contiguous case 1
1696  mat1 = genf(n, m)
1697  mat2 = genf(p, m).t()
1698  res = torch.mm(mat1, mat2)
1699 
1700  res2 = matrixmultiply(mat1, mat2)
1701  self.assertEqual(res, res2)
1702 
1703  # non contiguous case 2
1704  mat1 = genf(m, n).t()
1705  mat2 = genf(m, p)
1706  res = torch.mm(mat1, mat2)
1707 
1708  res2 = matrixmultiply(mat1, mat2)
1709  self.assertEqual(res, res2)
1710 
1711  # non contiguous case 3
1712  mat1 = genf(m, n).t()
1713  mat2 = genf(p, m).t()
1714  res = torch.mm(mat1, mat2)
1715 
1716  res2 = matrixmultiply(mat1, mat2)
1717  self.assertEqual(res, res2)
1718 
1719  # test with zero stride
1720  mat1 = genf(n, m)
1721  mat2 = genf(m, 1).expand(m, p)
1722  res = torch.mm(mat1, mat2)
1723 
1724  res2 = matrixmultiply(mat1, mat2)
1725  self.assertEqual(res, res2)
1726 
1727  # explicitly exercise the _out variant in torch.mm().
1728  # contiguous case
1729  mat1 = genf(n, m)
1730  mat2 = genf(m, p)
1731  res = genf(n, p)
1732  torch.mm(mat1, mat2, out=res)
1733 
1734  res2 = matrixmultiply(mat1, mat2)
1735  self.assertEqual(res, res2)
1736 
1737  # explicitly exercise the _out variant in torch.mm().
1738  # non contiguous case 3
1739  mat1 = genf(m, n).t()
1740  mat2 = genf(p, m).t()
1741  res = genf(n, p)
1742  torch.mm(mat1, mat2, out=res)
1743 
1744  res2 = matrixmultiply(mat1, mat2)
1745  self.assertEqual(res, res2)
1746 
1747  for (n, m, p) in [(20, 10, 5), (15, 5, 10), (5, 18, 10)]:
1748  _test_mm(n, m, p, torch.float32, lambda x, y: torch.randn(x, y, dtype=torch.float32))
1749  _test_mm(n, m, p, torch.float64, lambda x, y: torch.randn(x, y, dtype=torch.float64))
1750  _test_mm(n, m, p, torch.int32, lambda x, y: torch.randint(0, 100, (x, y), dtype=torch.int32))
1751  _test_mm(n, m, p, torch.int64, lambda x, y: torch.randint(0, 100, (x, y), dtype=torch.int64))
1752 
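 # For reference: btrifact computes a batched LU factorization and returns the
 # packed LU factors and pivots; btrifact_with_info additionally returns an info
 # tensor (zero means success), and btriunpack expands the result into P, L and U
 # such that P @ L @ U reconstructs the input, as checked below.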
1753  @staticmethod
1754  def _test_btrifact(self, cast):
1755  from common_utils import random_fullrank_matrix_distinct_singular_value as fullrank
1756 
1757  def run_test(matrix_size, batches, cast):
1758  a = cast(fullrank(matrix_size, *batches))
1759  a_LU_info, pivots_info, info_ = a.btrifact_with_info()
1760  self.assertEqual(a_LU_info.size(), torch.Size(batches + (matrix_size, matrix_size)))
1761  self.assertEqual(pivots_info.size(), torch.Size(batches + (matrix_size,)))
1762  self.assertEqual(info_.size(), torch.Size(batches))
1763  self.assertEqual(info_.abs().sum(), 0)
1764  a_LU, pivots = a.btrifact()
1765  self.assertEqual(a_LU, a_LU_info)
1766  self.assertEqual(pivots_info, pivots)
1767  if a.is_cuda:
1768  a_LU_info_nopiv, nopiv, info_nopiv = a.btrifact_with_info(pivot=False)
1769  self.assertIsNone(nopiv)
1770  self.assertEqual(info_, info_nopiv)
1771  P, L, U = torch.btriunpack(a_LU, pivots)
1772  self.assertEqual(P.matmul(L.matmul(U)), a)
1773 
1774  for ms, batch in product([3, 5, 7], [(2,), (3,), (3, 5)]):
1775  run_test(ms, batch, cast)
1776 
1777  # Info should be positive for rank deficient matrices
1778  a = cast(fullrank(3, 5))
1779  if not (a.is_cuda and any(x in torch.version.cuda for x in ['8.0', '9.2'])):
1780  a[0, 1] = 2 * a[0, 0] # Row 2 of a[0] is 2 times Row 1 of a[0], thereby causing a rank deficiency
1781  self.assertGreater(a.btrifact_with_info()[2][0], 0)
1782 
1783  # Error checking, no pivoting variant on CPU
1784  with self.assertRaisesRegex(RuntimeError,
1785  'btrifact without pivoting is not implemented on the CPU'):
1786  torch.btrifact(torch.empty(1, 2, 2), pivot=False)
1787 
1788  @skipIfNoLapack
1789  @skipIfRocm
1790  def test_btrifact(self):
1791  self._test_btrifact(self, lambda t: t)
1792 
1793  @staticmethod
1794  def _test_btrisolve(self, cast):
1795  a = torch.FloatTensor((((1.3722, -0.9020),
1796  (1.8849, 1.9169)),
1797  ((0.7187, -1.1695),
1798  (-0.0139, 1.3572)),
1799  ((-1.6181, 0.7148),
1800  (1.3728, 0.1319))))
1801  b = torch.FloatTensor(((4.02, 6.19),
1802  (-1.56, 4.00),
1803  (9.81, -4.09)))
1804  a, b = cast(a), cast(b)
1805  LU_data, pivots, info = a.btrifact_with_info()
1806  self.assertEqual(info.abs().sum(), 0)
1807  x = torch.btrisolve(b, LU_data, pivots)
1808  b_ = torch.bmm(a, x.unsqueeze(2)).squeeze()
1809  self.assertEqual(b_, b)
1810 
1811  @skipIfNoLapack
1812  def test_btrisolve(self):
1813  self._test_btrisolve(self, lambda t: t)
1814 
1815  @staticmethod
1816  def _test_btriunpack(self, cast):
1817  def run_test(shape, cast):
1818  a = cast(torch.randn(*shape))
1819  a_lu, p = torch.btrifact(a.reshape(-1, shape[-1], shape[-1]))
1820  a_lu = a_lu.reshape_as(a)
1821  p = p.reshape(a.shape[:-1])
1822  p_ref, l_ref, u_ref = torch.btriunpack(a_lu, p)
1823  self.assertEqual(p_ref.matmul(l_ref.matmul(u_ref)), a)
1824 
1825  run_test((5, 3, 3), cast)
1826  run_test((7, 3, 5, 5), cast)
1827  run_test((7, 5, 3, 3, 3), cast)
1828 
1829  @skipIfNoLapack
1830  def test_btriunpack(self):
1831  self._test_btriunpack(self, lambda t: t)
1832 
1833  def test_bmm(self):
1834  num_batches = 10
1835  M, N, O = 23, 8, 12
1836  b1 = torch.randn(num_batches, M, N)
1837  b2 = torch.randn(num_batches, N, O)
1838  res = torch.bmm(b1, b2)
1839  for i in range(num_batches):
1840  r = torch.mm(b1[i], b2[i])
1841  self.assertEqual(r, res[i])
1842  if torch.cuda.is_available():
1843  # check that mixed arguments are rejected
1844  self.assertRaises(RuntimeError, lambda: torch.bmm(b1, b2.cuda()))
1845  self.assertRaises(RuntimeError, lambda: torch.bmm(b1.cuda(), b2))
1846 
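 # For reference: addbmm reduces over the batch dimension while baddbmm keeps it.
 # With the legacy positional form used below, roughly:
 #   addbmm(beta, M, alpha, b1, b2)  ~ beta * M + alpha * sum_i(b1[i] @ b2[i])
 #   baddbmm(beta, M, alpha, b1, b2) ~ beta * M[i] + alpha * (b1[i] @ b2[i]) per batch i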
1847  def test_addbmm(self):
1848  # num_batches = 10
1849  # M, N, O = 12, 8, 5
1850  num_batches = 2
1851  M, N, O = 2, 3, 4
1852  b1 = torch.randn(num_batches, M, N)
1853  b2 = torch.randn(num_batches, N, O)
1854  res = torch.bmm(b1, b2)
1855  res2 = torch.Tensor().resize_as_(res[0]).zero_()
1856 
1857  res2.addbmm_(b1, b2)
1858  self.assertEqual(res2, res.sum(0, False))
1859 
1860  res2.addbmm_(1, b1, b2)
1861  self.assertEqual(res2, res.sum(0, False) * 2)
1862 
1863  res2.addbmm_(1., .5, b1, b2)
1864  self.assertEqual(res2, res.sum(0, False) * 2.5)
1865 
1866  res3 = torch.addbmm(1, res2, 0, b1, b2)
1867  self.assertEqual(res3, res2)
1868 
1869  res4 = torch.addbmm(1, res2, .5, b1, b2)
1870  self.assertEqual(res4, res.sum(0, False) * 3)
1871 
1872  res5 = torch.addbmm(0, res2, 1, b1, b2)
1873  self.assertEqual(res5, res.sum(0, False))
1874 
1875  res6 = torch.addbmm(.1, res2, .5, b1, b2)
1876  self.assertEqual(res6, res2 * .1 + (res.sum(0) * .5))
1877 
1878  def test_baddbmm(self):
1879  num_batches = 10
1880  M, N, O = 12, 8, 5
1881  b1 = torch.randn(num_batches, M, N)
1882  b2 = torch.randn(num_batches, N, O)
1883  res = torch.bmm(b1, b2)
1884  res2 = torch.Tensor().resize_as_(res).zero_()
1885 
1886  res2.baddbmm_(b1, b2)
1887  self.assertEqual(res2, res)
1888 
1889  res2.baddbmm_(1, b1, b2)
1890  self.assertEqual(res2, res * 2)
1891 
1892  res2.baddbmm_(1, .5, b1, b2)
1893  self.assertEqual(res2, res * 2.5)
1894 
1895  res3 = torch.baddbmm(1, res2, 0, b1, b2)
1896  self.assertEqual(res3, res2)
1897 
1898  res4 = torch.baddbmm(1, res2, .5, b1, b2)
1899  self.assertEqual(res4, res * 3)
1900 
1901  res5 = torch.baddbmm(0, res2, 1, b1, b2)
1902  self.assertEqual(res5, res)
1903 
1904  res6 = torch.baddbmm(.1, res2, .5, b1, b2)
1905  self.assertEqual(res6, res2 * .1 + res * .5)
1906 
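 # For reference: clamp clips each element into [min, max]; when only one bound is
 # given the other side is left unbounded. The NaN case below compares only the
 # isnan masks rather than values.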
1907  @staticmethod
1908  def _test_clamp(self, device='cpu'):
1909  m1 = torch.rand(100, device=device).mul(5).add(-2.5) # uniform in [-2.5, 2.5]
1910  # just in case we're extremely lucky.
1911  min_val = -1
1912  max_val = 1
1913  m1[1] = min_val
1914  m1[2] = max_val
1915 
1916  res1 = m1.clone()
1917  res1.clamp_(min_val, max_val)
1918  res2 = m1.clone()
1919  for i in iter_indices(res2):
1920  res2[i] = max(min_val, min(max_val, res2[i]))
1921  self.assertEqual(res1, res2)
1922 
1923  out = m1.clone()
1924  torch.clamp(m1, min=min_val, max=max_val, out=out)
1925  self.assertEqual(out, res1)
1926 
1927  res1 = torch.clamp(m1, min=min_val)
1928  res2 = m1.clone()
1929  for i in iter_indices(res2):
1930  res2[i] = max(min_val, res2[i])
1931  self.assertEqual(res1, res2)
1932 
1933  torch.clamp(m1, min=min_val, out=out)
1934  self.assertEqual(out, res1)
1935 
1936  res1 = torch.clamp(m1, max=max_val)
1937  res2 = m1.clone()
1938  for i in iter_indices(res2):
1939  res2[i] = min(max_val, res2[i])
1940  self.assertEqual(res1, res2)
1941 
1942  torch.clamp(m1, max=max_val, out=out)
1943  self.assertEqual(out, res1)
1944 
1945  # the case where the tensor contains NaN values
1946  test_tens = torch.tensor([nan], device=device)
1947 
1948  res1 = test_tens.clone()
1949  res1.clamp_(min_val, max_val)
1950  res2 = test_tens.clone()
1951  for i in iter_indices(res2):
1952  res2[i] = max(min(res2[i], max_val), min_val)
1953  self.assertEqual(torch.isnan(res1), torch.isnan(res2))
1954 
1955  out = test_tens.clone()
1956  torch.clamp(test_tens, min=min_val, max=max_val, out=out)
1957  self.assertEqual(torch.isnan(out), torch.isnan(res1))
1958 
1959  res1 = torch.clamp(test_tens, min=min_val)
1960  res2 = test_tens.clone()
1961  for i in iter_indices(res2):
1962  res2[i] = max(res2[i], min_val)
1963  self.assertEqual(torch.isnan(res1), torch.isnan(res2))
1964 
1965  torch.clamp(test_tens, min=min_val, out=out)
1966  self.assertEqual(torch.isnan(out), torch.isnan(res1))
1967 
1968  res1 = torch.clamp(test_tens, max=max_val)
1969  res2 = test_tens.clone()
1970  for i in iter_indices(res2):
1971  res2[i] = min(res2[i], max_val)
1972  self.assertEqual(torch.isnan(res1), torch.isnan(res2))
1973 
1974  torch.clamp(test_tens, max=max_val, out=out)
1975  self.assertEqual(torch.isnan(out), torch.isnan(res1))
1976 
1977  error_msg = 'At least one of \'min\' or \'max\' must not be None'
1978  with self.assertRaisesRegex(RuntimeError, error_msg):
1979  m1.clamp()
1980  with self.assertRaisesRegex(RuntimeError, error_msg):
1981  m1.clamp_()
1982 
1983  def test_clamp(self):
1984  self._test_clamp(self)
1985 
1986  def test_pow(self):
1987  # [res] torch.pow([res,] x)
1988 
1989  # pow has dedicated implementation for different exponents
1990  for exponent in [-2, -1, -0.5, 0.5, 1, 2, 3, 4]:
1991  # base - tensor, exponent - number
1992  # contiguous
1993  m1 = torch.rand(100, 100) + 0.5
1994  res1 = torch.pow(m1[4], exponent)
1995  res2 = res1.clone().zero_()
1996  for i in range(res2.size(0)):
1997  res2[i] = math.pow(m1[4][i], exponent)
1998  self.assertEqual(res1, res2)
1999 
2000  # non-contiguous
2001  m1 = torch.rand(100, 100) + 0.5
2002  res1 = torch.pow(m1[:, 4], exponent)
2003  res2 = res1.clone().zero_()
2004  for i in range(res2.size(0)):
2005  res2[i] = math.pow(m1[i, 4], exponent)
2006  self.assertEqual(res1, res2)
2007 
2008  # base - number, exponent - tensor
2009  # contiguous
2010  m1 = torch.randn(100, 100)
2011  res1 = torch.pow(3, m1[4])
2012  res2 = res1.clone().zero_()
2013  for i in range(res2.size(0)):
2014  res2[i] = math.pow(3, m1[4, i])
2015  self.assertEqual(res1, res2)
2016 
2017  # non-contiguous
2018  m1 = torch.randn(100, 100)
2019  res1 = torch.pow(3, m1[:, 4])
2020  res2 = res1.clone().zero_()
2021  for i in range(res2.size(0)):
2022  res2[i] = math.pow(3, m1[i][4])
2023  self.assertEqual(res1, res2)
2024 
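 # For reference: with a number on the left, 2 ** m dispatches to Tensor.__rpow__,
 # which should match torch.pow(2, m) elementwise, e.g.
 #   2 ** torch.tensor([1., 2., 3.]) -> tensor([2., 4., 8.])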
2025  @staticmethod
2026  def _test_rpow(self, cast):
2027  m = cast(torch.randn(10, 10))
2028  self.assertEqual(torch.pow(2, m), 2**m)
2029 
2030  # test with scalar
2031  m = cast(torch.randn(1).squeeze())
2032  assert m.dim() == 0, "m is intentionally a scalar"
2033  self.assertEqual(torch.pow(2, m), 2**m)
2034 
2035  def test_rpow(self):
2036  self._test_rpow(self, lambda x: x)
2037 
2038  @staticmethod
2039  def _test_int_pow(self, cast):
2040  if not TEST_NUMPY:
2041  return
2042  import numpy as np
2043 
2044  def check_against_np(tensor, exp):
2045  tensor_np = tensor.cpu().numpy()
2046  exp_np = exp if isinstance(exp, int) else exp.cpu().numpy()
2047  expected = torch.LongTensor(tensor_np ** exp_np).type_as(tensor)
2048  self.assertEqual(torch.pow(tensor, exp), expected)
2049  self.assertEqual(tensor.pow(exp), torch.pow(tensor, exp))
2050 
2051  typecasts = [
2052  lambda x: x.long(),
2053  lambda x: x.short(),
2054  lambda x: x.byte(),
2055  ]
2056 
2057  if not IS_WINDOWS:
2058  typecasts.append(lambda x: x.int())
2059 
2060  shape = (11, 5)
2061  tensor = cast(torch.LongTensor(shape).random_(-10, 10))
2062  exps = [0, 1, 2, 5, cast(torch.LongTensor(shape).random_(0, 20))]
2063 
2064  for typecast in typecasts:
2065  for exp in exps:
2066  t = typecast(tensor)
2067  e = exp if isinstance(exp, int) else typecast(exp)
2068  check_against_np(t, e)
2069 
2070  def test_int_pow(self):
2071  self._test_int_pow(self, lambda x: x)
2072 
2073  def _test_cop(self, torchfn, mathfn):
2074  def reference_implementation(res2):
2075  for i, j in iter_indices(sm1):
2076  idx1d = i * sm1.size(0) + j
2077  res2[i, j] = mathfn(sm1[i, j], sm2[idx1d])
2078  return res2
2079 
2080  # contiguous
2081  m1 = torch.randn(10, 10, 10)
2082  m2 = torch.randn(10, 10 * 10)
2083  sm1 = m1[4]
2084  sm2 = m2[4]
2085 
2086  res1 = torchfn(sm1, sm2.view(10, 10))
2087  res2 = reference_implementation(res1.clone())
2088  self.assertEqual(res1, res2)
2089 
2090  # non-contiguous
2091  m1 = torch.randn(10, 10, 10)
2092  m2 = torch.randn(10 * 10, 10 * 10)
2093  sm1 = m1[:, 4]
2094  sm2 = m2[:, 4]
2095  # view as sm1.size()
2096  sm2.set_(sm2.storage(), sm2.storage_offset(), sm1.size(), (sm2.stride()[0] * 10, sm2.stride()[0]))
2097  res1 = torchfn(sm1, sm2)
2098  # reference_implementation assumes 1-d sm2
2099  sm2.set_(sm2.storage(), sm2.storage_offset(), m2[:, 4].size(), m2[:, 4].stride())
2100  res2 = reference_implementation(res1.clone())
2101  self.assertEqual(res1, res2)
2102 
2103  def test_cdiv(self):
2104  self._test_cop(torch.div, lambda x, y: x / y)
2105 
2106  def test_cfmod(self):
2107  self._test_cop(torch.fmod, math.fmod)
2108 
2109  def test_cremainder(self):
2110  self._test_cop(torch.remainder, lambda x, y: x % y)
2111 
2112  def test_cmul(self):
2113  self._test_cop(torch.mul, lambda x, y: x * y)
2114 
2115  def test_cpow(self):
2116  self._test_cop(torch.pow, lambda x, y: nan if x < 0 else math.pow(x, y))
2117 
2118  @unittest.skipIf(not TEST_NUMPY, 'Numpy not found')
2119  def test_einsum(self):
2120  # test cases taken from https://gist.github.com/rockt/15ee013889d65342088e9260a377dc8f
2121  x = torch.randn(5)
2122  y = torch.randn(7)
2123  A = torch.randn(3, 5)
2124  B = torch.randn(2, 5)
2125  C = torch.randn(2, 3, 5)
2126  D = torch.randn(2, 5, 7)
2127  E = torch.randn(7, 9)
2128  F = torch.randn(2, 3, 5, 7)
2129  G = torch.randn(7, 11, 13)
2130  H = torch.randn(4, 4)
2131  I = torch.randn(3, 4, 4)
2132  l = torch.randn(5, 10)
2133  r = torch.randn(5, 20)
2134  w = torch.randn(30, 10, 20)
2135  test_list = [
2136  # -- Vector
2137  ("i->", x), # sum
2138  ("i,i->", x, x), # dot
2139  ("i,i->i", x, x), # vector element-wise mul
2140  ("i,j->ij", x, y), # outer
2141  # -- Matrix
2142  ("ij->ji", A), # transpose
2143  ("ij->j", A), # row sum
2144  ("ij->i", A), # col sum
2145  ("ij,ij->ij", A, A), # matrix element-wise mul
2146  ("ij,j->i", A, x), # matrix vector multiplication
2147  ("ij,kj->ik", A, B), # matmul
2148  ("ij,ab->ijab", A, E), # matrix outer product
2149  # -- Tensor
2150  ("aij,ajk->aik", C, D), # batch matmul
2151  ("ijk,jk->i", C, A), # tensor matrix contraction
2152  ("aij,jk->aik", D, E), # tensor matrix contraction
2153  ("abcd,dfg->abcfg", F, G), # tensor tensor contraction
2154  ("ijk,jk->ik", C, A), # tensor matrix contraction with double indices
2155  ("ijk,jk->ij", C, A), # tensor matrix contraction with double indices
2156  ("ijk,ik->j", C, B), # non contiguous
2157  ("ijk,ik->jk", C, B), # non contiguous with double indices
2158  # -- Diagonal
2159  ("ii", H), # trace
2160  ("ii->i", H), # diagonal
2161  # -- Ellipsis
2162  ("i...->...", H),
2163  ("ki,...k->i...", A.t(), B),
2164  ("k...,jk", A.t(), B),
2165  ("...ii->...i", I), # batch diagonal
2166  # -- Other
2167  ("bn,anm,bm->ba", l, w, r), # as torch.bilinear
2168  ("... ii->...i ", I), # batch diagonal with spaces
2169  ]
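 # Each spec above is evaluated with torch.einsum and checked against np.einsum;
 # e.g. "ij,kj->ik" with A of shape (3, 5) and B of shape (2, 5) is equivalent
 # to A.matmul(B.t()) and yields a (3, 2) result.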
2170  for test in test_list:
2171  actual = torch.einsum(test[0], test[1:])
2172  expected = np.einsum(test[0], *[t.numpy() for t in test[1:]])
2173  self.assertEqual(expected.shape, actual.shape, test[0])
2174  self.assertTrue(np.allclose(expected, actual.numpy()), test[0])
2175  # test vararg
2176  actual2 = torch.einsum(test[0], *test[1:])
2177  self.assertEqual(expected.shape, actual2.shape, test[0])
2178  self.assertTrue(np.allclose(expected, actual2.numpy()), test[0])
2179 
2180  def do_einsum(*args):
2181  return torch.einsum(test[0], args)
2182  # FIXME: following test cases fail gradcheck
2183  if test[0] not in {"i,i->", "i,i->i", "ij,ij->ij"}:
2184  gradcheck_inps = tuple(t.detach().requires_grad_() for t in test[1:])
2185  self.assertTrue(torch.autograd.gradcheck(do_einsum, gradcheck_inps))
2186  self.assertTrue(A._version == 0) # check that we do not use inplace ops
2187 
2188  def test_sum_all(self):
2189  def check_sum_all(tensor):
2190  pylist = tensor.reshape(-1).tolist()
2191  self.assertEqual(tensor.sum(), sum(pylist))
2192 
2193  check_sum_all(torch.tensor([1, 2, 3, 4, 5]))
2194  check_sum_all(torch.randn(200000))
2195  check_sum_all(torch.randn(2000, 2)[:, 0])
2196 
2197  def _assert_matches_numpy(self, t, n):
2198  self.assertEqual(n.shape, t.shape)
2199  if t.dtype == torch.float:
2200  self.assertTrue(np.allclose(n, t.numpy(), rtol=1e-03, atol=1e-05,
2201  equal_nan=True))
2202  else:
2203  self.assertTrue(np.allclose(n, t.numpy(), equal_nan=True))
2204 
2205  def _test_dim_ops(self, pytorch_op, numpy_op,
2206  use_floating=True, use_integral=True):
2207  def do_one(tensors_dict, dim):
2208  for category, tensors in tensors_dict.items():
2209  if category == "slice":
2210  dim = 0
2211  for tensor in tensors:
2212  # we have no control over NumPy warnings...
2213  with warnings.catch_warnings():
2214  warnings.simplefilter("ignore")
2215  expected = numpy_op(tensor.numpy(), dim)
2216  actual = pytorch_op(tensor, dim)
2217  self._assert_matches_numpy(actual, expected)
2218  if torch.cuda.is_available():
2219  self._assert_matches_numpy(pytorch_op(tensor.cuda(),
2220  dim).cpu(),
2221  expected)
2222  do_one(self._make_tensors((5, 400000), use_floating=use_floating,
2223  use_integral=use_integral), 1)
2224  do_one(self._make_tensors((3, 5, 7), use_floating=use_floating,
2225  use_integral=use_integral), 0)
2226  do_one(self._make_tensors((3, 5, 7), use_floating=use_floating,
2227  use_integral=use_integral), 1)
2228  do_one(self._make_tensors((3, 5, 7), use_floating=use_floating,
2229  use_integral=use_integral), 2)
2230  do_one(self._make_tensors((100000, ), use_floating=use_floating,
2231  use_integral=use_integral), -1)
2232  do_one(self._make_tensors((50, 50, 50), use_floating=use_floating,
2233  use_integral=use_integral), 0)
2234  do_one(self._make_tensors((50, 50, 50), use_floating=use_floating,
2235  use_integral=use_integral), 1)
2236  do_one(self._make_tensors((50, 50, 50), use_floating=use_floating,
2237  use_integral=use_integral), 2)
2238  do_one(self._make_tensors((50, 50, 50), use_floating=use_floating,
2239  use_integral=use_integral), (1, 2))
2240  do_one(self._make_tensors((50, 50, 50), use_floating=use_floating,
2241  use_integral=use_integral), (1, -1))
2242  do_one(self._make_tensors((50, 50, 50), use_floating=use_floating,
2243  use_integral=use_integral), (0, 2))
2244  do_one(self._make_tensors((50, 50, 50), use_floating=use_floating,
2245  use_integral=use_integral), (0, 2, 1))
2246 
2247  @unittest.skipIf(not TEST_NUMPY, 'Numpy not found')
2248  def test_sum_dim(self):
2249  self._test_dim_ops(
2250  lambda t, d: t.sum(d),
2251  lambda n, d: n.sum(d))
2252 
2253  @unittest.skipIf(not TEST_NUMPY, 'Numpy not found')
2254  def test_mean_dim(self):
2255  self._test_dim_ops(
2256  lambda t, d: t.mean(d),
2257  lambda n, d: n.mean(d),
2258  use_integral=False)
2259 
2260  @unittest.skipIf(not TEST_NUMPY, 'Numpy not found')
2261  def test_std_dim(self):
2262  for unbiased in [False, True]:
2263  self._test_dim_ops(
2264  lambda t, d: t.std(d, unbiased=unbiased),
2265  lambda n, d: n.std(d, ddof=1 if unbiased else 0),
2266  use_integral=False)
2267 
2268  @unittest.skipIf(not TEST_NUMPY, 'Numpy not found')
2269  def test_var_dim(self):
2270  for unbiased in [False, True]:
2271  self._test_dim_ops(
2272  lambda t, d: t.var(d, unbiased=unbiased),
2273  lambda n, d: n.var(d, ddof=1 if unbiased else 0),
2274  use_integral=False)
2275 
2276  @unittest.skipIf(not TEST_NUMPY, 'Numpy not found')
2277  @unittest.skipIf(not TEST_SCIPY, 'Scipy not found')
2278  def test_logsumexp_dim(self):
2279  from scipy.special import logsumexp
2280  self._test_dim_ops(
2281  lambda t, d: t.logsumexp(d),
2282  lambda n, d: logsumexp(n, d),
2283  use_integral=False)
2284 
2285  def test_sum_out(self):
2286  x = torch.rand(100, 100)
2287  res1 = torch.sum(x, 1)
2288  res2 = torch.Tensor()
2289  torch.sum(x, 1, out=res2)
2290  self.assertEqual(res1, res2)
2291  x = torch.rand(100, 100, 100)
2292  res1 = x.sum(2).sum(1)
2293  res2 = torch.Tensor()
2294  torch.sum(x, (2, 1), out=res2)
2295  self.assertEqual(res1, res2)
2296 
2297  # TODO: these tests only check if it's possible to pass a return value
2298  # it'd be good to expand them
2299  def test_prod(self):
2300  x = torch.rand(100, 100)
2301  res1 = torch.prod(x, 1)
2302  res2 = torch.Tensor()
2303  torch.prod(x, 1, out=res2)
2304  self.assertEqual(res1, res2)
2305 
2306  def test_cumsum(self):
2307  x = torch.rand(100, 100)
2308  res1 = torch.cumsum(x, 1)
2309  res2 = torch.Tensor()
2310  torch.cumsum(x, 1, out=res2)
2311  self.assertEqual(res1, res2)
2312 
2313  def test_cumprod(self):
2314  x = torch.rand(100, 100)
2315  res1 = torch.cumprod(x, 1)
2316  res2 = torch.Tensor()
2317  torch.cumprod(x, 1, out=res2)
2318  self.assertEqual(res1, res2)
2319 
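 # For reference: sum/prod/cumsum/cumprod promote integral inputs to torch.int64 by
 # default to avoid overflow, while floating inputs keep their dtype; passing dtype=
 # or an out= tensor overrides the accumulation type, as exercised below.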
2320  def _test_reduce_integer_upcast(self, fn, has_out=True):
2321  shape = (3, 4, 5)
2322  reduced_shape = fn(torch.ones(shape)).shape
2323 
2324  def _test_out(dtype, other_dtype):
2325  out = torch.ones(reduced_shape, dtype=dtype)
2326  result = fn(x, out=out)
2327  self.assertIs(out.dtype, result.dtype)
2328  self.assertEqual(fn(x.type(dtype)), result)
2329  result = fn(x, out=out, dtype=dtype)
2330  self.assertIs(out.dtype, result.dtype)
2331  self.assertEqual(fn(x.type(dtype)), result)
2332  # 'out' is favored over dtype, check error
2333  self.assertRaises(RuntimeError, lambda: fn(x, out=out, dtype=other_dtype))
2334 
2335  for dtype in [dtype for dtype in torch.testing.get_all_dtypes() if dtype != torch.float16]:
2336  x = torch.ones(shape, dtype=dtype)
2337  expected_dtype = dtype if dtype.is_floating_point else torch.int64
2338  self.assertIs(expected_dtype, fn(x).dtype)
2339  self.assertEqual(fn(x.type(expected_dtype)), fn(x))
2340 
2341  if dtype.is_floating_point:
2342  other_dtype = torch.float32 if dtype == torch.float64 else torch.float64
2343  else:
2344  other_dtype = torch.int32 if dtype != torch.int32 else torch.int16
2345  self.assertIs(other_dtype, fn(x, dtype=other_dtype).dtype)
2346  self.assertEqual(fn(x.type(other_dtype)), fn(x, dtype=other_dtype))
2347 
2348  # test mixed int/float
2349  mixed_dtype = torch.int32 if dtype.is_floating_point else torch.float32
2350  self.assertIs(mixed_dtype, fn(x, dtype=mixed_dtype).dtype)
2351  self.assertEqual(fn(x.type(mixed_dtype)), fn(x, dtype=mixed_dtype))
2352 
2353  if has_out:
2354  _test_out(dtype, other_dtype)
2355  _test_out(dtype, mixed_dtype)
2356 
2357  def test_sum_integer_upcast(self):
2358  self._test_reduce_integer_upcast(lambda x, **kwargs: torch.sum(x, **kwargs), False)
2359  self._test_reduce_integer_upcast(lambda x, **kwargs: torch.sum(x, 0, **kwargs))
2360 
2361  def test_prod_integer_upcast(self):
2362  self._test_reduce_integer_upcast(lambda x, **kwargs: torch.prod(x, **kwargs), False)
2363  self._test_reduce_integer_upcast(lambda x, **kwargs: torch.prod(x, 0, **kwargs))
2364 
2365  def test_cumsum_integer_upcast(self):
2366  self._test_reduce_integer_upcast(lambda x, **kwargs: torch.cumsum(x, 0, **kwargs))
2367 
2368  def test_cumprod_integer_upcast(self):
2369  self._test_reduce_integer_upcast(lambda x, **kwargs: torch.cumprod(x, 0, **kwargs))
2370 
2371  def test_cross(self):
2372  x = torch.rand(100, 3, 100)
2373  y = torch.rand(100, 3, 100)
2374  res1 = torch.cross(x, y)
2375  res2 = torch.Tensor()
2376  torch.cross(x, y, out=res2)
2377  self.assertEqual(res1, res2)
2378 
2379  def test_zeros(self):
2380  res1 = torch.zeros(100, 100)
2381  res2 = torch.Tensor()
2382  torch.zeros(100, 100, out=res2)
2383  self.assertEqual(res1, res2)
2384 
2385  boolTensor = torch.zeros(2, 2, dtype=torch.bool)
2386  expected = torch.tensor([[False, False], [False, False]], dtype=torch.bool)
2387  self.assertEqual(boolTensor, expected)
2388 
2389  halfTensor = torch.zeros(1, 1, dtype=torch.half)
2390  expected = torch.tensor([[0.]], dtype=torch.float16)
2391  self.assertEqual(halfTensor, expected)
2392 
2393  def test_zeros_like(self):
2394  expected = torch.zeros(100, 100)
2395 
2396  res1 = torch.zeros_like(expected)
2397  self.assertEqual(res1, expected)
2398 
2399  @unittest.skipIf(not torch.cuda.is_available(), 'no CUDA')
2400  def test_zeros_like_cuda(self):
2401  expected = torch.zeros(100, 100).cuda()
2402 
2403  res1 = torch.zeros_like(expected)
2404  self.assertEqual(res1, expected)
2405 
2406  @unittest.skipIf(torch.cuda.device_count() < 2, 'only one GPU detected')
2407  def test_zeros_like_multiple_device(self):
2408  expected = torch.zeros(100, 100).cuda()
2409  x = torch.cuda.FloatTensor(100, 100, device=1)
2410  output = torch.zeros_like(x)
2411  self.assertEqual(output, expected)
2412 
2413  def test_zeros_out(self):
2414  shape = (3, 4)
2415  out = torch.zeros(shape)
2416  torch.zeros(shape, out=out)
2417 
2418  # change the dtype, layout, device
2419  self.assertRaises(RuntimeError, lambda: torch.zeros(shape, dtype=torch.int64, out=out))
2420  self.assertRaises(RuntimeError, lambda: torch.zeros(shape, layout=torch.sparse_coo, out=out))
2421  if torch.cuda.is_available():
2422  self.assertRaises(RuntimeError, lambda: torch.zeros(shape, device='cuda', out=out))
2423 
2424  # leave them the same
2425  self.assertEqual(torch.zeros(shape), torch.zeros(shape, dtype=out.dtype, out=out))
2426  self.assertEqual(torch.zeros(shape), torch.zeros(shape, layout=torch.strided, out=out))
2427  self.assertEqual(torch.zeros(shape), torch.zeros(shape, device='cpu', out=out))
2428 
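 # For reference: torch.histc buckets values into `bins` equal-width bins between
 # min and max; elements outside [min, max] are ignored, and if min and max are both
 # zero the minimum and maximum of the data are used instead.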
2429  @staticmethod
2430  def _test_histc(self, device):
2431  # negative nbins throws
2432  with self.assertRaisesRegex(RuntimeError, 'bins must be > 0'):
2433  torch.histc(torch.tensor([1], dtype=torch.float, device=device), bins=-1)
2434 
2435  # without nbins
2436  actual = torch.histc(
2437  torch.tensor([2, 5], dtype=torch.float, device=device))
2438  expected = torch.zeros(100, dtype=torch.float, device=device)
2439  expected.data[0] = 1
2440  expected.data[99] = 1
2441  self.assertEqual(expected, actual)
2442  # tensor with the same element
2443  actual = torch.histc(torch.ones(5, dtype=torch.float, device=device), bins=5)
2444  self.assertEqual(
2445  torch.tensor([0, 0, 5, 0, 0], dtype=torch.float, device=device),
2446  actual)
2447  # no element falls between [min, max]
2448  actual = torch.histc(
2449  torch.ones(5, dtype=torch.float, device=device), bins=5, min=2, max=3)
2450  self.assertEqual(
2451  torch.tensor([0, 0, 0, 0, 0], dtype=torch.float, device=device),
2452  actual)
2453  # elements fall between [min, max]
2454  actual = torch.histc(
2455  torch.tensor([2, 4, 2, 2, 5, 4], dtype=torch.float, device=device),
2456  bins=5, min=1, max=5)
2457  self.assertEqual(
2458  torch.tensor([0, 3, 0, 2, 1], dtype=torch.float, device=device),
2459  actual)
2460  # non-integral bin size
2461  actual = torch.histc(
2462  torch.tensor([1, 2, 1], dtype=torch.float, device=device),
2463  bins=4, min=0, max=3)
2464  self.assertEqual(
2465  torch.tensor([0, 2, 1, 0], dtype=torch.float, device=device),
2466  actual)
2467  # double input
2468  actual = torch.histc(
2469  torch.tensor([1, 2, 1], dtype=torch.double, device=device),
2470  bins=4, min=0, max=3)
2471  self.assertEqual(
2472  torch.tensor([0, 2, 1, 0], dtype=torch.double, device=device),
2473  actual)
2474  # mixed input
2475  actual = torch.histc(
2476  torch.tensor([1., 2, 1], dtype=torch.float, device=device),
2477  bins=4, min=0, max=3)
2478  self.assertEqual(
2479  torch.tensor([0, 2, 1, 0], dtype=torch.float, device=device),
2480  actual)
2481 
2482  # test against numpy.histogram()
2483  def test_against_np(tensor, bins=100, min=0, max=0):
2484  if min == 0 and max == 0:
2485  min = tensor.min().item()
2486  max = tensor.max().item()
2487  nparr = tensor.cpu().numpy()
2488  actual = torch.histc(tensor, bins=bins, min=min, max=max)
2489  expected = torch.from_numpy(np.histogram(nparr, bins=bins, range=(min, max))[0])
2490  self.assertEqual(actual.cpu(), expected)
2491 
2492  if TEST_NUMPY:
2493  test_against_np(torch.tensor([1., 2, 1], device=device))
2494  test_against_np(torch.randn(5000, device=device))
2495 
2496  # Test bins arg
2497  test_against_np(torch.randn(301, device=device), bins=10)
2498 
2499  # Test truncated range
2500  test_against_np(torch.randn(201, device=device), min=0.1, max=1)
2501 
2502  noncontig = torch.randn(100, 3, device=device)[:, 2]
2503  test_against_np(noncontig)
2504 
2505  multidim = torch.randn(3, 5, 7, 2, device=device)
2506  test_against_np(multidim)
2507 
2508  expanded = torch.randn(1, 5, 1, 2, device=device).expand(3, 5, 7, 2)
2509  test_against_np(expanded)
2510 
2511  def test_histc_cpu(self):
2512  self._test_histc(self, 'cpu')
2513 
2514  def test_ones(self):
2515  res1 = torch.ones(100, 100)
2516  res2 = torch.Tensor()
2517  torch.ones(100, 100, out=res2)
2518  self.assertEqual(res1, res2)
2519 
2520  # test boolean tensor
2521  res1 = torch.ones(1, 2, dtype=torch.bool)
2522  expected = torch.tensor([[True, True]], dtype=torch.bool)
2523  self.assertEqual(res1, expected)
2524 
2525  def test_ones_like(self):
2526  expected = torch.ones(100, 100)
2527 
2528  res1 = torch.ones_like(expected)
2529  self.assertEqual(res1, expected)
2530 
2531  # test boolean tensor
2532  expected = torch.tensor([True, True], dtype=torch.bool)
2533  res1 = torch.ones_like(expected)
2534  self.assertEqual(res1, expected)
2535 
2536  @unittest.skipIf(not torch.cuda.is_available(), 'no CUDA')
2537  def test_ones_like_cuda(self):
2538  expected = torch.ones(100, 100).cuda()
2539 
2540  res1 = torch.ones_like(expected)
2541  self.assertEqual(res1, expected)
2542 
2543  @unittest.skipIf(torch.cuda.device_count() < 2, 'only one GPU detected')
2544  def test_ones_like_multiple_device(self):
2545  expected = torch.ones(100, 100).cuda()
2546  x = torch.cuda.FloatTensor(100, 100, device=1)
2547  output = torch.ones_like(x)
2548  self.assertEqual(output, expected)
2549 
2550  def test_dtypes(self):
2551  all_dtypes = torch.testing.get_all_dtypes()
2552  do_test_dtypes(self, all_dtypes, torch.strided, torch.device('cpu'))
2553  if torch.cuda.is_available():
2554  do_test_dtypes(self, all_dtypes, torch.strided, torch.device('cuda:0'))
2555 
2556  def test_copy_dtypes(self):
2557  all_dtypes = torch.testing.get_all_dtypes()
2558  for dtype in all_dtypes:
2559  copied_dtype = copy.deepcopy(dtype)
2560  self.assertIs(dtype, copied_dtype)
2561 
2562  def test_device(self):
2563  cpu = torch.device('cpu')
2564  self.assertEqual('cpu', str(cpu))
2565  self.assertEqual('cpu', cpu.type)
2566  self.assertEqual(None, cpu.index)
2567 
2568  cpu0 = torch.device('cpu:0')
2569  self.assertEqual('cpu:0', str(cpu0))
2570  self.assertEqual('cpu', cpu0.type)
2571  self.assertEqual(0, cpu0.index)
2572 
2573  cpu0 = torch.device('cpu', 0)
2574  self.assertEqual('cpu:0', str(cpu0))
2575  self.assertEqual('cpu', cpu0.type)
2576  self.assertEqual(0, cpu0.index)
2577 
2578  cuda = torch.device('cuda')
2579  self.assertEqual('cuda', str(cuda))
2580  self.assertEqual('cuda', cuda.type)
2581  self.assertEqual(None, cuda.index)
2582 
2583  cuda1 = torch.device('cuda:1')
2584  self.assertEqual('cuda:1', str(cuda1))
2585  self.assertEqual('cuda', cuda1.type)
2586  self.assertEqual(1, cuda1.index)
2587 
2588  cuda1 = torch.device('cuda', 1)
2589  self.assertEqual('cuda:1', str(cuda1))
2590  self.assertEqual('cuda', cuda1.type)
2591  self.assertEqual(1, cuda1.index)
2592 
2593  self.assertRaises(RuntimeError, lambda: torch.device('cpu:-1'))
2594  self.assertRaises(RuntimeError, lambda: torch.device('cpu:1'))
2595  self.assertRaises(RuntimeError, lambda: torch.device('cpu', -1))
2596  self.assertRaises(RuntimeError, lambda: torch.device('cpu', 1))
2597  self.assertRaises(RuntimeError, lambda: torch.device('cuda:-1'))
2598  self.assertRaises(RuntimeError, lambda: torch.device('cuda', -1))
2599  self.assertRaises(RuntimeError, lambda: torch.device(-1))
2600 
2601  self.assertRaises(RuntimeError, lambda: torch.device('other'))
2602  self.assertRaises(RuntimeError, lambda: torch.device('other:0'))
2603 
2604  device_set = {'cpu', 'cpu:0', 'cuda', 'cuda:0', 'cuda:1', 'cuda:10', 'cuda:100'}
2605  device_hash_set = set()
2606  for device in list(device_set):
2607  device_hash_set.add(hash(torch.device(device)))
2608  self.assertEqual(len(device_set), len(device_hash_set))
2609 
2610  def test_tensor_device(self):
2611  def assertEqual(device_str, fn):
2612  self.assertEqual(torch.device(device_str), fn().device)
2613  self.assertEqual(device_str, str(fn().device))
2614 
2615  assertEqual('cpu', lambda: torch.tensor(5))
2616  assertEqual('cpu', lambda: torch.ones((2, 3), dtype=torch.float32, device='cpu'))
2617  # NOTE: 'cpu' is the canonical representation of 'cpu:0', but 'cuda:X' is the canonical
2618  # representation of cuda devices.
2619  assertEqual('cpu', lambda: torch.ones((2, 3), dtype=torch.float32, device='cpu:0'))
2620  assertEqual('cpu', lambda: torch.tensor(torch.ones((2, 3), dtype=torch.float32), device='cpu:0'))
2621  if TEST_NUMPY:
2622  assertEqual('cpu', lambda: torch.tensor(np.random.randn(2, 3), device='cpu'))
2623 
2624  if torch.cuda.is_available():
2625  assertEqual('cuda:0', lambda: torch.tensor(5).cuda(0))
2626  assertEqual('cuda:0', lambda: torch.tensor(5).cuda('cuda:0'))
2627  self.assertRaises(RuntimeError, lambda: torch.tensor(5).cuda('cpu'))
2628  self.assertRaises(RuntimeError, lambda: torch.tensor(5).cuda('cpu:0'))
2629  assertEqual('cuda:0', lambda: torch.tensor(5, dtype=torch.int64, device=0))
2630  assertEqual('cuda:0', lambda: torch.tensor(5, dtype=torch.int64, device='cuda:0'))
2631  assertEqual('cuda:' + str(torch.cuda.current_device()),
2632  lambda: torch.tensor(5, dtype=torch.int64, device='cuda'))
2633  assertEqual('cuda:0', lambda: torch.tensor(torch.ones((2, 3), dtype=torch.float32), device='cuda:0'))
2634  if TEST_NUMPY:
2635  assertEqual('cuda:0', lambda: torch.tensor(np.random.randn(2, 3), device='cuda:0'))
2636 
2637  if torch.cuda.device_count() > 1:
2638  assertEqual('cuda:1', lambda: torch.tensor(5).cuda(1))
2639  assertEqual('cuda:1', lambda: torch.tensor(5).cuda('cuda:1'))
2640  assertEqual('cuda:1', lambda: torch.tensor(5, dtype=torch.int64, device=1))
2641  assertEqual('cuda:1', lambda: torch.tensor(5, dtype=torch.int64, device='cuda:1'))
2642  assertEqual('cuda:1', lambda: torch.tensor(torch.ones((2, 3), dtype=torch.float32), device='cuda:1'))
2643  if TEST_NUMPY:
2644  assertEqual('cuda:1', lambda: torch.tensor(np.random.randn(2, 3), device='cuda:1'))
2645 
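 # For reference: Tensor.to returns self when the requested dtype and device already
 # match (and copy=False), and returns a new tensor when copy=True or when a
 # conversion is actually required, which is what test_copy_behavior checks below.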
2646  def test_to(self):
2647  def test_copy_behavior(t, non_blocking=False):
2648  self.assertIs(t, t.to(t, non_blocking=non_blocking))
2649  self.assertIs(t, t.to(t.dtype, non_blocking=non_blocking))
2650  self.assertIs(t, t.to(torch.empty_like(t), non_blocking=non_blocking))
2651  self.assertIsNot(t, t.to(t, non_blocking=non_blocking, copy=True))
2652  self.assertIsNot(t, t.to(t.dtype, non_blocking=non_blocking, copy=True))
2653  self.assertIsNot(t, t.to(torch.empty_like(t), non_blocking=non_blocking, copy=True))
2654 
2655  devices = [t.device]
2656  if t.device.type == 'cuda':
2657  if t.device.index == -1:
2658  devices.append('cuda:{}'.format(torch.cuda.current_device()))
2659  elif t.device.index == torch.cuda.current_device():
2660  devices.append('cuda')
2661  for device in devices:
2662  self.assertIs(t, t.to(device, non_blocking=non_blocking))
2663  self.assertIs(t, t.to(device, t.dtype, non_blocking=non_blocking))
2664  self.assertIsNot(t, t.to(device, non_blocking=non_blocking, copy=True))
2665  self.assertIsNot(t, t.to(device, t.dtype, non_blocking=non_blocking, copy=True))
2666 
2667  a = torch.tensor(5)
2668  test_copy_behavior(a)
2669  self.assertEqual(a.device, a.to('cpu').device)
2670  self.assertEqual(a.device, a.to('cpu', dtype=torch.float32).device)
2671  self.assertIs(torch.float32, a.to('cpu', dtype=torch.float32).dtype)
2672  self.assertEqual(a.device, a.to(torch.float32).device)
2673  self.assertIs(torch.float32, a.to(dtype=torch.float32).dtype)
2674  self.assertEqual(a.data_ptr(), a.to('cpu').data_ptr())
2675  self.assertEqual(a.data_ptr(), a.to(dtype=a.dtype, device=a.device, copy=False).data_ptr())
2676  self.assertEqual(a.data_ptr(), a.to('cpu', copy=False).data_ptr())
2677  self.assertNotEqual(a.data_ptr(), a.to('cpu', copy=True).data_ptr())
2678 
2679  if torch.cuda.is_available():
2680  for non_blocking in [True, False]:
2681  for cuda in ['cuda', 'cuda:0' if torch.cuda.device_count() == 1 else 'cuda:1']:
2682  b = torch.tensor(5., device=cuda)
2683  test_copy_behavior(b, non_blocking)
2684  self.assertEqual(b.device, b.to(cuda, non_blocking=non_blocking).device)
2685  self.assertEqual(a.device, b.to('cpu', non_blocking=non_blocking).device)
2686  self.assertEqual(b.device, a.to(cuda, non_blocking=non_blocking).device)
2687  self.assertIs(torch.int32, b.to('cpu', dtype=torch.int32, non_blocking=non_blocking).dtype)
2688  self.assertEqual(a.device, b.to('cpu', dtype=torch.int32, non_blocking=non_blocking).device)
2689  self.assertIs(torch.int32, b.to(dtype=torch.int32).dtype)
2690  self.assertEqual(b.device, b.to(dtype=torch.int32).device)
2691 
2692  def test_to_with_tensor(self):
2693  a = torch.tensor(5)
2694  self.assertEqual(a.device, a.to(a).device)
2695 
2696  if torch.cuda.is_available():
2697  for non_blocking in [True, False]:
2698  for cuda in ['cuda', 'cuda:0' if torch.cuda.device_count() == 1 else 'cuda:1']:
2699  b = torch.tensor(5., device=cuda)
2700  self.assertEqual(b.device, b.to(b, non_blocking=non_blocking).device)
2701  self.assertEqual(a.device, b.to(a, non_blocking=non_blocking).device)
2702  self.assertEqual(b.device, a.to(b, non_blocking=non_blocking).device)
2703 
2704  def test_empty_full(self):
2705  do_test_empty_full(self, torch.testing.get_all_dtypes(), torch.strided, torch.device('cpu'))
2706  if torch.cuda.device_count() > 0:
2707  do_test_empty_full(self, torch.testing.get_all_dtypes(), torch.strided, None)
2708  do_test_empty_full(self, torch.testing.get_all_dtypes(), torch.strided, torch.device('cuda:0'))
2709 
2710  def test_dtype_out_match(self):
2711  d = torch.autograd.Variable(torch.DoubleTensor(2, 3))
2712  self.assertRaises(RuntimeError, lambda: torch.zeros((2, 3), out=d, dtype=torch.float32))
2713 
2714  def test_constructor_dtypes(self):
2715  default_type = torch.Tensor().type()
2716  self.assertIs(torch.Tensor().dtype, torch.get_default_dtype())
2717 
2718  self.assertIs(torch.uint8, torch.ByteTensor.dtype)
2719  self.assertIs(torch.float32, torch.FloatTensor.dtype)
2720  self.assertIs(torch.float64, torch.DoubleTensor.dtype)
2721 
2722  torch.set_default_tensor_type('torch.FloatTensor')
2723  self.assertIs(torch.float32, torch.get_default_dtype())
2724  self.assertIs(torch.FloatStorage, torch.Storage)
2725 
2726  torch.set_default_dtype(torch.float64)
2727  self.assertIs(torch.float64, torch.get_default_dtype())
2728  self.assertIs(torch.DoubleStorage, torch.Storage)
2729 
2730  torch.set_default_tensor_type(torch.FloatTensor)
2731  self.assertIs(torch.float32, torch.get_default_dtype())
2732  self.assertIs(torch.FloatStorage, torch.Storage)
2733 
2734  if torch.cuda.is_available():
2735  torch.set_default_tensor_type(torch.cuda.FloatTensor)
2736  self.assertIs(torch.float32, torch.get_default_dtype())
2737  self.assertIs(torch.float32, torch.cuda.FloatTensor.dtype)
2738  self.assertIs(torch.cuda.FloatStorage, torch.Storage)
2739 
2740  torch.set_default_dtype(torch.float64)
2741  self.assertIs(torch.float64, torch.get_default_dtype())
2742  self.assertIs(torch.cuda.DoubleStorage, torch.Storage)
2743 
2744  # don't support integral or sparse default types.
2745  self.assertRaises(TypeError, lambda: torch.set_default_tensor_type('torch.IntTensor'))
2746  self.assertRaises(TypeError, lambda: torch.set_default_dtype(torch.int64))
2747 
2748  # don't allow passing dtype to set_default_tensor_type
2749  self.assertRaises(TypeError, lambda: torch.set_default_tensor_type(torch.float32))
2750 
2751  torch.set_default_tensor_type(default_type)
2752 
2753  def test_constructor_device_legacy(self):
2754  self.assertRaises(RuntimeError, lambda: torch.FloatTensor(device='cuda'))
2755  self.assertRaises(RuntimeError, lambda: torch.FloatTensor(torch.Size([2, 3, 4]), device='cuda'))
2756  self.assertRaises(RuntimeError, lambda: torch.FloatTensor((2.0, 3.0), device='cuda'))
2757 
2758  self.assertRaises(RuntimeError, lambda: torch.Tensor(device='cuda'))
2759  self.assertRaises(RuntimeError, lambda: torch.Tensor(torch.Size([2, 3, 4]), device='cuda'))
2760  self.assertRaises(RuntimeError, lambda: torch.Tensor((2.0, 3.0), device='cuda'))
2761 
2762  x = torch.randn((3,), device='cpu')
2763  self.assertRaises(RuntimeError, lambda: x.new(device='cuda'))
2764  self.assertRaises(RuntimeError, lambda: x.new(torch.Size([2, 3, 4]), device='cuda'))
2765  self.assertRaises(RuntimeError, lambda: x.new((2.0, 3.0), device='cuda'))
2766 
2767  if torch.cuda.is_available():
2768  self.assertRaises(RuntimeError, lambda: torch.cuda.FloatTensor(device='cpu'))
2769  self.assertRaises(RuntimeError, lambda: torch.cuda.FloatTensor(torch.Size([2, 3, 4]), device='cpu'))
2770  self.assertRaises(RuntimeError, lambda: torch.cuda.FloatTensor((2.0, 3.0), device='cpu'))
2771 
2772  default_type = torch.Tensor().type()
2773  torch.set_default_tensor_type(torch.cuda.FloatTensor)
2774  self.assertRaises(RuntimeError, lambda: torch.Tensor(device='cpu'))
2775  self.assertRaises(RuntimeError, lambda: torch.Tensor(torch.Size([2, 3, 4]), device='cpu'))
2776  self.assertRaises(RuntimeError, lambda: torch.Tensor((2.0, 3.0), device='cpu'))
2777  torch.set_default_tensor_type(torch.cuda.FloatTensor)
2778  torch.set_default_tensor_type(default_type)
2779 
2780  x = torch.randn((3,), device='cuda')
2781  self.assertRaises(RuntimeError, lambda: x.new(device='cpu'))
2782  self.assertRaises(RuntimeError, lambda: x.new(torch.Size([2, 3, 4]), device='cpu'))
2783  self.assertRaises(RuntimeError, lambda: x.new((2.0, 3.0), device='cpu'))
2784 
2785  def test_type(self):
2786  x = torch.randn(3, 3).double()
2787  self.assertEqual(x.type('torch.FloatTensor').dtype, torch.float32)
2788  self.assertEqual(x.type(torch.FloatTensor).dtype, torch.float32)
2789  self.assertEqual(x.int().type(torch.Tensor).dtype, torch.get_default_dtype())
2790  self.assertEqual(x.type(torch.int32).dtype, torch.int32)
2791 
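 # For reference: torch.tensor always copies its input data, so mutating the source
 # afterwards must not affect the constructed tensor; the "test copy" assertions
 # below rely on this.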
2792  def test_tensor_factory(self):
2793  expected = torch.Tensor([1, 1])
2794  # test data
2795  res1 = torch.tensor([1, 1])
2796  self.assertEqual(res1, expected)
2797 
2798  res1 = torch.tensor([1, 1], dtype=torch.int)
2799  self.assertEqual(res1, expected)
2800  self.assertIs(torch.int, res1.dtype)
2801 
2802  # test copy
2803  res2 = torch.tensor(expected)
2804  self.assertEqual(res2, expected)
2805  res2[1] = 2
2806  self.assertEqual(expected, torch.ones_like(expected))
2807 
2808  res2 = torch.tensor(expected, dtype=torch.int)
2809  self.assertEqual(res1, expected)
2810  self.assertIs(torch.int, res1.dtype)
2811 
2812  # test copy with numpy
2813  if TEST_NUMPY:
2814  for dtype in [np.float64, np.int64, np.int8, np.uint8]:
2815  a = np.array([5.]).astype(dtype)
2816  res1 = torch.tensor(a)
2817  self.assertEqual(5., res1[0].item())
2818  a[0] = 7.
2819  self.assertEqual(5., res1[0].item())
2820 
2821  # test boolean tensor
2822  a = torch.tensor([True, True, False, True, True], dtype=torch.bool)
2823  b = torch.tensor([-1, -1.1, 0, 1, 1.1], dtype=torch.bool)
2824  self.assertEqual(a, b)
2825 
2826  def test_tensor_factory_copy_var(self):
2827 
2828  def check_copy(copy, is_leaf, requires_grad, data_ptr=None):
2829  if data_ptr is None:
2830  data_ptr = copy.data_ptr
2831  self.assertEqual(copy.data, source.data)
2832  self.assertTrue(copy.is_leaf == is_leaf)
2833  self.assertTrue(copy.requires_grad == requires_grad)
2834  self.assertTrue(copy.data_ptr == data_ptr)
2835 
2836  source = torch.randn(5, 5, dtype=torch.double, requires_grad=True)
2837  # test torch.tensor()
2838  check_copy(torch.tensor(source), True, False)
2839  check_copy(torch.tensor(source, requires_grad=False), True, False)
2840  check_copy(torch.tensor(source, requires_grad=True), True, True)
2841 
2842  # test tensor.new_tensor()
2843  copy = torch.randn(1)
2844  check_copy(copy.new_tensor(source), True, False)
2845  check_copy(copy.new_tensor(source, requires_grad=False), True, False)
2846  check_copy(copy.new_tensor(source, requires_grad=True), True, True)
2847 
2848  # test torch.as_tensor()
2849  check_copy(torch.as_tensor(source), source.is_leaf, source.requires_grad, source.data_ptr) # not copy
2850  check_copy(torch.as_tensor(source, dtype=torch.float), False, True) # copy and keep the graph
2851 
2852  def test_tensor_factory_type_inference(self):
2853  def test_inference(default_dtype):
2854  saved_dtype = torch.get_default_dtype()
2855  torch.set_default_dtype(default_dtype)
2856  self.assertIs(default_dtype, torch.tensor(()).dtype)
2857  self.assertIs(default_dtype, torch.tensor(5.).dtype)
2858  self.assertIs(torch.int64, torch.tensor(5).dtype)
2859  self.assertIs(torch.uint8, torch.tensor(True).dtype)
2860  self.assertIs(torch.int32, torch.tensor(5, dtype=torch.int32).dtype)
2861  self.assertIs(default_dtype, torch.tensor(((7, 5), (9, 5.))).dtype)
2862  self.assertIs(default_dtype, torch.tensor(((5., 5), (3, 5))).dtype)
2863  self.assertIs(torch.int64, torch.tensor(((5, 3), (3, 5))).dtype)
2864 
2865  if TEST_NUMPY:
2866  self.assertIs(torch.float64, torch.tensor(np.array(())).dtype)
2867  self.assertIs(torch.float64, torch.tensor(np.array(5.)).dtype)
2868  if np.array(5).dtype == np.int64: # np long, which can be 4 bytes (e.g. on windows)
2869  self.assertIs(torch.int64, torch.tensor(np.array(5)).dtype)
2870  else:
2871  self.assertIs(torch.int32, torch.tensor(np.array(5)).dtype)
2872  self.assertIs(torch.uint8, torch.tensor(np.array(3, dtype=np.uint8)).dtype)
2873  self.assertIs(default_dtype, torch.tensor(((7, np.array(5)), (np.array(9), 5.))).dtype)
2874  self.assertIs(torch.float64, torch.tensor(((7, 5), (9, np.array(5.)))).dtype)
2875  self.assertIs(torch.int64, torch.tensor(((5, np.array(3)), (np.array(3), 5))).dtype)
2876  torch.set_default_dtype(saved_dtype)
2877 
2878  test_inference(torch.float64)
2879  test_inference(torch.float32)
2880 
2881  @unittest.skipIf(not torch.cuda.is_available(), 'no CUDA')
2882  def test_tensor_factory_cuda_type_inference(self):
2883  saved_type = torch.Tensor().type()
2884  torch.set_default_tensor_type(torch.cuda.DoubleTensor)
2885  torch.set_default_dtype(torch.float32)
2886  self.assertIs(torch.float32, torch.tensor(0.).dtype)
2887  self.assertEqual(torch.device('cuda:0'), torch.tensor(0.).device)
2888  torch.set_default_dtype(torch.float64)
2889  self.assertIs(torch.float64, torch.tensor(0.).dtype)
2890  self.assertEqual(torch.device('cuda:0'), torch.tensor(0.).device)
2891  torch.set_default_tensor_type(saved_type)
2892 
2893  @unittest.skipIf(not torch.cuda.is_available(), 'no CUDA')
2894  def test_tensor_factory_cuda_type(self):
2895  saved_type = torch.Tensor().type()
2896  torch.set_default_tensor_type(torch.cuda.FloatTensor)
2897  x = torch.zeros((5, 5))
2898  self.assertIs(torch.float32, x.dtype)
2899  self.assertTrue(x.is_cuda)
2900  torch.set_default_tensor_type(torch.cuda.DoubleTensor)
2901  x = torch.zeros((5, 5))
2902  self.assertIs(torch.float64, x.dtype)
2903  self.assertTrue(x.is_cuda)
2904  torch.set_default_tensor_type(saved_type)
2905 
2906  # This is a temporary test for boolean tensors on CPU. Once the CUDA part
2907  # is done, these test cases will be moved into the test_tensor_factories_empty test.
2908  def test_tensor_factories_empty_bool(self):
2909  expectedShape = (1, 2)
2910  test = torch.empty(expectedShape, dtype=torch.bool)
2911  self.assertEqual(expectedShape, test.shape)
2912  self.assertEqual(expectedShape, torch.empty_like(test).shape)
2913 
2914  test = torch.full(expectedShape, True, dtype=torch.bool)
2915  self.assertEqual(test, torch.tensor([[True, True]], dtype=torch.bool))
2916  self.assertEqual(expectedShape, test.shape)
2917  self.assertEqual(expectedShape, torch.full_like(test, True).shape)
2918 
2919  def test_tensor_factories_empty(self):
2920  # ensure we can create empty tensors from each factory function
2921  shapes = [(5, 0, 1), (0,), (0, 0, 1, 0, 2, 0, 0)]
2922  devices = ['cpu'] if not torch.cuda.is_available() else ['cpu', 'cuda']
2923 
2924  for device in devices:
2925  for shape in shapes:
2926  self.assertEqual(shape, torch.zeros(shape, device=device).shape)
2927  self.assertEqual(shape, torch.zeros_like(torch.zeros(shape, device=device)).shape)
2928  self.assertEqual(shape, torch.empty(shape, device=device).shape)
2929  self.assertEqual(shape, torch.empty_like(torch.zeros(shape, device=device)).shape)
2930  self.assertEqual(shape, torch.empty_strided(shape, (0,) * len(shape), device=device).shape)
2931  self.assertEqual(shape, torch.full(shape, 3, device=device).shape)
2932  self.assertEqual(shape, torch.full_like(torch.zeros(shape, device=device), 3).shape)
2933  self.assertEqual(shape, torch.ones(shape, device=device).shape)
2934  self.assertEqual(shape, torch.ones_like(torch.zeros(shape, device=device)).shape)
2935  self.assertEqual(shape, torch.rand(shape, device=device).shape)
2936  self.assertEqual(shape, torch.rand_like(torch.zeros(shape, device=device)).shape)
2937  self.assertEqual(shape, torch.randn(shape, device=device).shape)
2938  self.assertEqual(shape, torch.randn_like(torch.zeros(shape, device=device)).shape)
2939  self.assertEqual(shape, torch.randint(6, shape, device=device).shape)
2940  self.assertEqual(shape, torch.randint_like(torch.zeros(shape, device=device), 6).shape)
2941 
2942  self.assertEqual((0,), torch.arange(0, device=device).shape)
2943  self.assertEqual((0, 0), torch.eye(0, device=device).shape)
2944  self.assertEqual((0, 0), torch.eye(0, 0, device=device).shape)
2945  self.assertEqual((5, 0), torch.eye(5, 0, device=device).shape)
2946  self.assertEqual((0, 5), torch.eye(0, 5, device=device).shape)
2947  self.assertEqual((0,), torch.linspace(1, 1, 0, device=device).shape)
2948  self.assertEqual((0,), torch.logspace(1, 1, 0, device=device).shape)
2949  self.assertEqual((0,), torch.randperm(0, device=device).shape)
2950  self.assertEqual((0,), torch.bartlett_window(0, device=device).shape)
2951  self.assertEqual((0,), torch.bartlett_window(0, periodic=False, device=device).shape)
2952  self.assertEqual((0,), torch.hamming_window(0, device=device).shape)
2953  self.assertEqual((0,), torch.hann_window(0, device=device).shape)
2954  self.assertEqual((1, 1, 0), torch.tensor([[[]]], device=device).shape)
2955  self.assertEqual((1, 1, 0), torch.as_tensor([[[]]], device=device).shape)
2956 
2957  def test_new_tensor(self):
2958  expected = torch.autograd.Variable(torch.ByteTensor([1, 1]))
2959  # test data
2960  res1 = expected.new_tensor([1, 1])
2961  self.assertEqual(res1, expected)
2962  res1 = expected.new_tensor([1, 1], dtype=torch.int)
2963  self.assertEqual(res1, expected)
2964  self.assertIs(torch.int, res1.dtype)
2965 
2966  # test copy
2967  res2 = expected.new_tensor(expected)
2968  self.assertEqual(res2, expected)
2969  res2[1] = 2
2970  self.assertEqual(expected, torch.ones_like(expected))
2971  res2 = expected.new_tensor(expected, dtype=torch.int)
2972  self.assertEqual(res2, expected)
2973  self.assertIs(torch.int, res2.dtype)
2974 
2975  # test copy with numpy
2976  if TEST_NUMPY:
2977  a = np.array([5.])
2978  res1 = torch.tensor(a)
2979  res1 = res1.new_tensor(a)
2980  self.assertEqual(5., res1[0].item())
2981  a[0] = 7.
2982  self.assertEqual(5., res1[0].item())
2983 
2984  if torch.cuda.device_count() >= 2:
2985  expected = expected.cuda(1)
2986  res1 = expected.new_tensor([1, 1])
2987  self.assertEqual(res1.get_device(), expected.get_device())
2988  res1 = expected.new_tensor([1, 1], dtype=torch.int)
2989  self.assertIs(torch.int, res1.dtype)
2990  self.assertEqual(res1.get_device(), expected.get_device())
2991 
2992  res2 = expected.new_tensor(expected)
2993  self.assertEqual(res2.get_device(), expected.get_device())
2994  res2 = expected.new_tensor(expected, dtype=torch.int)
2995  self.assertIs(torch.int, res1.dtype)
2996  self.assertEqual(res2.get_device(), expected.get_device())
2997  res2 = expected.new_tensor(expected, dtype=torch.int, device=0)
2998  self.assertIs(torch.int, res1.dtype)
2999  self.assertEqual(res2.get_device(), 0)
3000 
3001  res1 = expected.new_tensor(1)
3002  self.assertEqual(res1.get_device(), expected.get_device())
3003  res1 = expected.new_tensor(1, dtype=torch.int)
3004  self.assertIs(torch.int, res1.dtype)
3005  self.assertEqual(res1.get_device(), expected.get_device())
3006 
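 # For reference: torch.as_tensor shares memory with the input whenever possible
 # (e.g. an existing tensor, or a NumPy array with a matching dtype on CPU) and only
 # copies when a dtype or device conversion is required, as checked below.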
3007  def test_as_tensor(self):
3008  # from python data
3009  x = [[0, 1], [2, 3]]
3010  self.assertEqual(torch.tensor(x), torch.as_tensor(x))
3011  self.assertEqual(torch.tensor(x, dtype=torch.float32), torch.as_tensor(x, dtype=torch.float32))
3012 
3013  # python data with heterogeneous types
3014  z = [0, 'torch']
3015  with self.assertRaisesRegex(TypeError, "invalid data type"):
3016  torch.tensor(z)
3017  torch.as_tensor(z)
3018 
3019  # python data with self-referential lists
3020  z = [0]
3021  z += [z]
3022  with self.assertRaisesRegex(TypeError, "self-referential lists are incompatible"):
3023  torch.tensor(z)
3024  torch.as_tensor(z)
3025 
3026  z = [[1, 2], z]
3027  with self.assertRaisesRegex(TypeError, "self-referential lists are incompatible"):
3028  torch.tensor(z)
3029  torch.as_tensor(z)
3030 
3031  # from tensor (doesn't copy unless type is different)
3032  y = torch.tensor(x)
3033  self.assertIs(y, torch.as_tensor(y))
3034  self.assertIsNot(y, torch.as_tensor(y, dtype=torch.float32))
3035  if torch.cuda.is_available():
3036  self.assertIsNot(y, torch.as_tensor(y, device='cuda'))
3037  y_cuda = y.to('cuda')
3038  self.assertIs(y_cuda, torch.as_tensor(y_cuda))
3039  self.assertIs(y_cuda, torch.as_tensor(y_cuda, device='cuda'))
3040 
3041  if TEST_NUMPY:
3042  # doesn't copy
3043  for dtype in [np.float64, np.int64, np.int8, np.uint8]:
3044  n = np.random.rand(5, 6).astype(dtype)
3045  n_astensor = torch.as_tensor(n)
3046  self.assertEqual(torch.tensor(n), n_astensor)
3047  n_astensor[0][0] = 25.7
3048  self.assertEqual(torch.tensor(n), n_astensor)
3049 
3050  # changing dtype causes copy
3051  n = np.random.rand(5, 6).astype(np.float32)
3052  n_astensor = torch.as_tensor(n, dtype=torch.float64)
3053  self.assertEqual(torch.tensor(n, dtype=torch.float64), n_astensor)
3054  n_astensor[0][1] = 250.8
3055  self.assertNotEqual(torch.tensor(n, dtype=torch.float64), n_astensor)
3056 
3057  # changing device causes copy
3058  if torch.cuda.is_available():
3059  n = np.random.randn(5, 6)
3060  n_astensor = torch.as_tensor(n, device='cuda')
3061  self.assertEqual(torch.tensor(n, device='cuda'), n_astensor)
3062  n_astensor[0][2] = 250.9
3063  self.assertNotEqual(torch.tensor(n, device='cuda'), n_astensor)
3064 
3065  def test_diag(self):
3066  x = torch.rand(100, 100)
3067  res1 = torch.diag(x)
3068  res2 = torch.Tensor()
3069  torch.diag(x, out=res2)
3070  self.assertEqual(res1, res2)
3071 
3072  @staticmethod
3073  def _test_diagonal(self, dtype, device):
3074  x = torch.randn((100, 100), dtype=dtype, device=device)
3075  result = torch.diagonal(x)
3076  expected = torch.diag(x)
3077  self.assertEqual(result, expected)
3078 
3079  x = torch.randn((100, 100), dtype=dtype, device=device)
3080  result = torch.diagonal(x, 17)
3081  expected = torch.diag(x, 17)
3082  self.assertEqual(result, expected)
3083 
3084  def test_diagonal(self):
3085  self._test_diagonal(self, dtype=torch.float32, device='cpu')
3086 
3087  @unittest.skipIf(not TEST_NUMPY, 'Numpy not found')
3088  def test_diagonal_multidim(self):
3089  x = torch.randn(10, 11, 12, 13)
3090  xn = x.numpy()
3091  for args in [(2, 2, 3),
3092  (2,),
3093  (-2, 1, 2),
3094  (0, -2, -1)]:
3095  result = torch.diagonal(x, *args)
3096  expected = xn.diagonal(*args)
3097  self.assertEqual(expected.shape, result.shape)
3098  self.assertTrue(np.allclose(expected, result.numpy()))
3099  # test non-contiguous
3100  xp = x.permute(1, 2, 3, 0)
3101  result = torch.diagonal(xp, 0, -2, -1)
3102  expected = xp.numpy().diagonal(0, -2, -1)
3103  self.assertEqual(expected.shape, result.shape)
3104  self.assertTrue(np.allclose(expected, result.numpy()))
3105 
3106  @staticmethod
3107  def _test_diag_embed(self, dtype, device):
3108  x = torch.arange(3 * 4, dtype=dtype, device=device).view(3, 4)
3109  result = torch.diag_embed(x)
3110  expected = torch.stack([torch.diag(r) for r in x], 0)
3111  self.assertEqual(result, expected)
3112 
3113  result = torch.diag_embed(x, offset=1, dim1=0, dim2=2)
3114  expected = torch.stack([torch.diag(r, 1) for r in x], 1)
3115  self.assertEqual(result, expected)
3116 
3117  def test_diag_embed(self):
3118  self._test_diag_embed(self, dtype=torch.float32, device='cpu')
3119 
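# Editorial sketch (a hypothetical helper, not part of the upstream suite):
# torch.diag_embed turns the last dimension of its input into the diagonal of
# a new matrix, batching over all leading dimensions; offset/dim1/dim2 pick
# where that diagonal goes, mirroring torch.diag for a single vector.
@staticmethod
def _sketch_diag_embed():
    v = torch.tensor([[1., 2.], [3., 4.]])   # batch of two 2-vectors
    m = torch.diag_embed(v)                  # shape (2, 2, 2)
    assert torch.equal(m[0], torch.diag(v[0]))
    assert torch.equal(m[1], torch.diag(v[1]))
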
3120  @staticmethod
3121  def _test_diagflat(self, dtype, device):
3122  # Basic sanity test
3123  x = torch.randn((100,), dtype=dtype, device=device)
3124  result = torch.diagflat(x)
3125  expected = torch.diag(x)
3126  self.assertEqual(result, expected)
3127 
3128  # Test offset
3129  x = torch.randn((100,), dtype=dtype, device=device)
3130  result = torch.diagflat(x, 17)
3131  expected = torch.diag(x, 17)
3132  self.assertEqual(result, expected)
3133 
3134  # Test where input has more than one dimension
3135  x = torch.randn((2, 3, 4), dtype=dtype, device=device)
3136  result = torch.diagflat(x)
3137  expected = torch.diag(x.contiguous().view(-1))
3138  self.assertEqual(result, expected)
3139 
3140  # Noncontig input
3141  x = torch.randn((2, 3, 4), dtype=dtype, device=device).transpose(2, 0)
3142  self.assertFalse(x.is_contiguous())
3143  result = torch.diagflat(x)
3144  expected = torch.diag(x.contiguous().view(-1))
3145  self.assertEqual(result, expected)
3146 
3147  def test_diagflat(self):
3148  self._test_diagflat(self, dtype=torch.float32, device='cpu')
3149 
3150  def test_eye(self):
3151  res1 = torch.eye(100, 100)
3152  res2 = torch.Tensor()
3153  torch.eye(100, 100, out=res2)
3154  self.assertEqual(res1, res2)
3155 
3156  def test_renorm(self):
3157  m1 = torch.randn(10, 5)
3158  res1 = torch.Tensor()
3159 
3160  def renorm(matrix, value, dim, max_norm):
3161  m1 = matrix.transpose(dim, 0).contiguous()
3162  # collapse non-dim dimensions.
3163  m2 = m1.clone().resize_(m1.size(0), int(math.floor(m1.nelement() / m1.size(0))))
3164  norms = m2.norm(value, 1, True)
3165  # clip
3166  new_norms = norms.clone()
3167  new_norms[torch.gt(norms, max_norm)] = max_norm
3168  new_norms.div_(norms.add_(1e-7))
3169  # renormalize
3170  m1.mul_(new_norms.expand_as(m1))
3171  return m1.transpose(dim, 0)
3172 
3173  # note that the axis fed to torch.renorm is different (2~=1)
3174  maxnorm = m1.norm(2, 1).mean()
3175  m2 = renorm(m1, 2, 1, maxnorm)
3176  m1.renorm_(2, 1, maxnorm)
3177  self.assertEqual(m1, m2, 1e-5)
3178  self.assertEqual(m1.norm(2, 0), m2.norm(2, 0), 1e-5)
3179 
3180  m1 = torch.randn(3, 4, 5)
3181  m2 = m1.transpose(1, 2).contiguous().clone().resize_(15, 4)
3182  maxnorm = m2.norm(2, 0).mean()
3183  m2 = renorm(m2, 2, 1, maxnorm)
3184  m1.renorm_(2, 1, maxnorm)
3185  m3 = m1.transpose(1, 2).contiguous().clone().resize_(15, 4)
3186  self.assertEqual(m3, m2)
3187  self.assertEqual(m3.norm(2, 0), m2.norm(2, 0))
3188 
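# Editorial sketch (a hypothetical helper, not part of the upstream suite):
# renorm(p, dim, maxnorm) rescales each slice taken along `dim` whose p-norm
# exceeds maxnorm so that its norm becomes (approximately) maxnorm; slices
# already within the limit are left unchanged, which is what the reference
# implementation above reproduces.
@staticmethod
def _sketch_renorm():
    x = torch.tensor([[3., 4.], [0.3, 0.4]])
    y = x.renorm(2, 0, 1.0)
    assert abs(y[0].norm(2).item() - 1.0) < 1e-5   # row norm 5 clipped to ~1
    assert torch.allclose(y[1], x[1])              # row norm 0.5 untouched
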
3189  @staticmethod
3190  def _test_renorm_ps(self, device):
3191  # full reduction
3192  x = torch.randn(5, 5)
3193  xn = x.numpy()
3194  for p in [1, 2, 3, 4, inf]:
3195  res = x.renorm(p, 1, 1)
3196  expected = x / x.norm(p, 0, keepdim=True).clamp(min=1)
3197  self.assertEqual(res.numpy(), expected.numpy(), "renorm failed for {}-norm".format(p))
3198 
3199  def test_renorm_ps(self):
3200  self._test_renorm_ps(self, device='cpu')
3201 
3202  @unittest.skipIf(not torch.cuda.is_available(), 'no CUDA')
3203  def test_renorm_ps_cuda(self):
3204  self._test_renorm_ps(self, device='cuda')
3205 
3206  @staticmethod
3207  def _test_multinomial(self, type):
3208  def make_prob_dist(shape, is_contiguous):
3209  if is_contiguous:
3210  return type(*shape).uniform_()
3211  elif len(shape) == 1:
3212  return type(*(shape + [5])).uniform_()[:, 2]
3213  else:
3214  # num dim = 2
3215  new_shape = [2, shape[1], 7, 1, shape[0], 1, 10]
3216  prob_dist = type(*new_shape).uniform_()
3217  prob_dist = prob_dist.transpose(1, 4)
3218  prob_dist = prob_dist[1, :, 5, 0, :, 0, 4]
3219  assert not prob_dist.is_contiguous() # sanity check
3220  return prob_dist
3221 
3222  for is_contiguous in (True, False):
3223  # with replacement
3224  n_row = 3
3225  for n_col in range(4, 5 + 1):
3226  prob_dist = make_prob_dist([n_row, n_col], is_contiguous)
3227  # indices that shouldn't be sampled (<0 means none)
3228  zero_prob_indices = torch.LongTensor(n_row).random_(-2, n_col).tolist()
3229  for i, j in enumerate(zero_prob_indices):
3230  if j >= 0:
3231  prob_dist[i, j] = 0
3232  n_sample = n_col * 3
3233  sample_indices = torch.multinomial(prob_dist, n_sample, True)
3234  self.assertEqual(prob_dist.dim(), 2)
3235  self.assertEqual(sample_indices.size(1), n_sample)
3236  for i in range(n_row):
3237  zero_prob_idx = zero_prob_indices[i]
3238  if zero_prob_idx < 0:
3239  continue
3240  for j in range(n_sample):
3241  self.assertNotEqual(sample_indices[i, j], zero_prob_idx,
3242  "sampled an index with zero probability")
3243 
3244  # without replacement
3245  n_row = 3
3246  for n_col in range(2, 10 + 1, 2):
3247  prob_dist = make_prob_dist([n_row, n_col], is_contiguous)
3248  # indices that shouldn't be sampled (<0 means none)
3249  zero_prob_indices = torch.LongTensor(n_row).random_(-1, n_col).tolist()
3250  for i, j in enumerate(zero_prob_indices):
3251  if j >= 0:
3252  prob_dist[i, j] = 0
3253  n_sample = max(1, n_col - 2)
3254  sample_indices = torch.multinomial(prob_dist, n_sample, False)
3255  self.assertEqual(prob_dist.dim(), 2)
3256  self.assertEqual(sample_indices.size(1), n_sample)
3257  for i in range(n_row):
3258  row_samples = {}
3259  zero_prob_idx = zero_prob_indices[i]
3260  for j in range(n_sample):
3261  sample_idx = sample_indices[i, j]
3262  if zero_prob_idx >= 0:
3263  self.assertNotEqual(sample_idx, zero_prob_idx,
3264  "sampled an index with zero probability")
3265  self.assertNotIn(sample_idx, row_samples, "sampled an index twice")
3266  row_samples[sample_idx] = True
3267 
3268  # vector
3269  n_col = 4
3270  prob_dist = make_prob_dist([n_col], is_contiguous).fill_(1)
3271  zero_prob_idx = 1 # index that shouldn't be sampled
3272  prob_dist[zero_prob_idx] = 0
3273  n_sample = 20
3274  sample_indices = torch.multinomial(prob_dist, n_sample, True)
3275  for sample_index in sample_indices:
3276  self.assertNotEqual(sample_index, zero_prob_idx, "sampled an index with zero probability")
3277  s_dim = sample_indices.dim()
3278  self.assertEqual(sample_indices.dim(), 1, "wrong number of dimensions")
3279  self.assertEqual(prob_dist.dim(), 1, "wrong number of prob_dist dimensions")
3280  self.assertEqual(sample_indices.size(0), n_sample, "wrong number of samples")
3281 
3282  def test_multinomial(self):
3283  self._test_multinomial(self, torch.FloatTensor)
3284 
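# Editorial sketch (a hypothetical helper, not part of the upstream suite):
# torch.multinomial samples indices proportionally to non-negative (not
# necessarily normalized) weights; zero-weight entries are never drawn and,
# without replacement, an index is drawn at most once per row.
@staticmethod
def _sketch_multinomial():
    weights = torch.tensor([0., 10., 3., 0.])
    with_repl = torch.multinomial(weights, 20, replacement=True)
    assert not (with_repl == 0).any() and not (with_repl == 3).any()
    without_repl = torch.multinomial(weights, 2, replacement=False)
    assert without_repl.unique().numel() == 2
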
3285  def _spawn_method(self, method, arg):
3286  try:
3287  mp.set_start_method('spawn')
3288  except RuntimeError:
3289  pass
3290  with mp.Pool(1) as pool:
3291  self.assertTrue(pool.map(method, [arg]))
3292 
3293  @staticmethod
3294  def _test_multinomial_invalid_probs(probs):
3295  try:
3296  # n_sample = 1 is a special case, test n_sample=2 which is more general
3297  torch.multinomial(probs.to('cpu'), 2)
3298  return False # Should not be reached
3299  except RuntimeError as e:
3300  return 'invalid multinomial distribution' in str(e)
3301 
3302  @unittest.skipIf(NO_MULTIPROCESSING_SPAWN, "Disabled for environments that \
3303  don't support multiprocessing with spawn start method")
3304  @unittest.skipIf(IS_WINDOWS, 'FIXME: CUDA OOM error on Windows')
3305  @unittest.skipIf(not PY3,
3306  "spawn start method is not supported in Python 2, \
3307  but we need it for for testing failure case for CPU RNG on Windows")
3308  def test_multinomial_invalid_probs(self):
3309  test_method = _TestTorchMixin._test_multinomial_invalid_probs
3310  self._spawn_method(test_method, torch.Tensor([1, -1, 1]))
3311  self._spawn_method(test_method, torch.Tensor([1, inf, 1]))
3312  self._spawn_method(test_method, torch.Tensor([1, -inf, 1]))
3313  self._spawn_method(test_method, torch.Tensor([1, 1, nan]))
3314  self._spawn_method(test_method, torch.Tensor([0, 1, 0]))
3315 
3316  @suppress_warnings
3317  def test_range(self):
3318  res1 = torch.range(0, 1)
3319  res2 = torch.Tensor()
3320  torch.range(0, 1, out=res2)
3321  self.assertEqual(res1, res2, 0)
3322 
3323  # Check range for non-contiguous tensors.
3324  x = torch.zeros(2, 3)
3325  torch.range(0, 3, out=x.narrow(1, 1, 2))
3326  res2 = torch.Tensor(((0, 0, 1), (0, 2, 3)))
3327  self.assertEqual(x, res2, 1e-16)
3328 
3329  # Check negative
3330  res1 = torch.Tensor((1, 0))
3331  res2 = torch.Tensor()
3332  torch.range(1, 0, -1, out=res2)
3333  self.assertEqual(res1, res2, 0)
3334 
3335  # Equal bounds
3336  res1 = torch.ones(1)
3337  res2 = torch.Tensor()
3338  torch.range(1, 1, -1, out=res2)
3339  self.assertEqual(res1, res2, 0)
3340  torch.range(1, 1, 1, out=res2)
3341  self.assertEqual(res1, res2, 0)
3342 
3343  # FloatTensor
3344  res1 = torch.range(0.6, 0.9, 0.1, out=torch.FloatTensor())
3345  self.assertEqual(res1.size(0), 4)
3346  res1 = torch.range(1, 10, 0.3, out=torch.FloatTensor())
3347  self.assertEqual(res1.size(0), 31)
3348 
3349  # DoubleTensor
3350  res1 = torch.range(0.6, 0.9, 0.1, out=torch.DoubleTensor())
3351  self.assertEqual(res1.size(0), 4)
3352  res1 = torch.range(1, 10, 0.3, out=torch.DoubleTensor())
3353  self.assertEqual(res1.size(0), 31)
3354 
3355  def test_range_warning(self):
3356  with warnings.catch_warnings(record=True) as w:
3357  torch.range(0, 10)
3358  self.assertEqual(len(w), 1)
3359 
3360  def test_arange(self):
3361  res1 = torch.arange(0, 1)
3362  res2 = torch.Tensor()
3363  torch.arange(0, 1, out=res2)
3364  self.assertEqual(res1, res2, 0)
3365 
3366  # Check arange with only one argument
3367  res1 = torch.arange(10)
3368  res2 = torch.arange(0, 10)
3369  self.assertEqual(res1, res2, 0)
3370 
3371  # Check arange for non-contiguous tensors.
3372  x = torch.zeros(2, 3)
3373  torch.arange(0, 4, out=x.narrow(1, 1, 2))
3374  res2 = torch.Tensor(((0, 0, 1), (0, 2, 3)))
3375  self.assertEqual(x, res2, 1e-16)
3376 
3377  # Check negative
3378  res1 = torch.Tensor((1, 0))
3379  res2 = torch.Tensor()
3380  torch.arange(1, -1, -1, out=res2)
3381  self.assertEqual(res1, res2, 0)
3382 
3383  # Equal bounds
3384  res1 = torch.ones(1)
3385  res2 = torch.Tensor()
3386  torch.arange(1, 0, -1, out=res2)
3387  self.assertEqual(res1, res2, 0)
3388  torch.arange(1, 2, 1, out=res2)
3389  self.assertEqual(res1, res2, 0)
3390 
3391  # FloatTensor
3392  res1 = torch.arange(0.6, 0.89, 0.1, out=torch.FloatTensor())
3393  self.assertEqual(res1, [0.6, 0.7, 0.8])
3394  res1 = torch.arange(1, 10, 0.3, out=torch.FloatTensor())
3395  self.assertEqual(res1.size(0), 30)
3396  self.assertEqual(res1[0], 1)
3397  self.assertEqual(res1[29], 9.7)
3398 
3399  # DoubleTensor
3400  res1 = torch.arange(0.6, 0.89, 0.1, out=torch.DoubleTensor())
3401  self.assertEqual(res1, [0.6, 0.7, 0.8])
3402  res1 = torch.arange(1, 10, 0.3, out=torch.DoubleTensor())
3403  self.assertEqual(res1.size(0), 30)
3404  self.assertEqual(res1[0], 1)
3405  self.assertEqual(res1[29], 9.7)
3406 
3407  # Check that the end point is exclusive
3408  r = torch.arange(0, 5)
3409  self.assertEqual(r.min(), 0)
3410  self.assertEqual(r.max(), 4)
3411  self.assertEqual(r.numel(), 5)
3412 
3413  r = torch.arange(0, 5, 2)
3414  self.assertEqual(r.min(), 0)
3415  self.assertEqual(r.max(), 4)
3416  self.assertEqual(r.numel(), 3)
3417 
3418  r1 = torch.arange(0, 5 + 1e-6)
3419  r2 = torch.arange(0, 5)
3420  r3 = torch.arange(0, 5 - 1e-6)
3421  self.assertEqual(r1[:-1], r2, 0)
3422  self.assertEqual(r2, r3, 0)
3423 
3424  r1 = torch.arange(10, -1 + 1e-6, -1)
3425  r2 = torch.arange(10, -1, -1)
3426  r3 = torch.arange(10, -1 - 1e-6, -1)
3427  self.assertEqual(r1, r2, 0)
3428  self.assertEqual(r2, r3[:-1], 0)
3429 
3430  msg = "unsupported range"
3431  self.assertRaisesRegex(RuntimeError, msg, lambda: torch.arange(0, float('inf')))
3432  self.assertRaisesRegex(RuntimeError, msg, lambda: torch.arange(float('inf')))
3433 
3434  devices = ['cpu'] if not torch.cuda.is_available() else ['cpu', 'cuda']
3435  for device in devices:
3436  self.assertRaisesRegex(RuntimeError, msg, lambda: torch.arange(-5, float('nan'), device=device))
3437  # check with step size
3438  self.assertRaisesRegex(RuntimeError, msg, lambda: torch.arange(0, float('-inf'), -1, device=device))
3439  self.assertRaisesRegex(RuntimeError, msg, lambda: torch.arange(0, float('inf'), device=device))
3440  self.assertRaisesRegex(RuntimeError, msg, lambda: torch.arange(float('-inf'), 10, device=device))
3441  self.assertRaisesRegex(RuntimeError, msg, lambda: torch.arange(float('nan'), 10, device=device))
3442  self.assertRaisesRegex(RuntimeError, msg, lambda: torch.arange(float('inf'), device=device))
3443  self.assertRaisesRegex(RuntimeError, msg, lambda: torch.arange(float('nan'), device=device))
3444 
3445  self.assertRaisesRegex(
3446  RuntimeError, "overflow",
3447  lambda: torch.arange(1.175494351e-38, 3.402823466e+38, device=device))
3448 
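# Editorial sketch (a hypothetical helper, not part of the upstream suite):
# torch.arange excludes its end point (like Python's range), whereas the
# deprecated torch.range includes it -- which is why test_range above expects
# one element more for the same bounds.
@staticmethod
def _sketch_arange_vs_range():
    assert torch.arange(0, 5).tolist() == [0, 1, 2, 3, 4]
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        assert torch.range(0, 5).tolist() == [0., 1., 2., 3., 4., 5.]
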
3449  def test_arange_inference(self):
3450  saved_dtype = torch.get_default_dtype()
3451  torch.set_default_dtype(torch.float32)
3452  # end only
3453  self.assertIs(torch.float32, torch.arange(1.).dtype)
3454  self.assertIs(torch.float32, torch.arange(torch.tensor(1.)).dtype)
3455  self.assertIs(torch.float32, torch.arange(torch.tensor(1., dtype=torch.float64)).dtype)
3456 
3457  self.assertIs(torch.int64, torch.arange(1).dtype)
3458  self.assertIs(torch.int64, torch.arange(torch.tensor(1)).dtype)
3459  self.assertIs(torch.int64, torch.arange(torch.tensor(1, dtype=torch.int16)).dtype)
3460 
3461  # start, end, [step]
3462  self.assertIs(torch.float32, torch.arange(1., 3).dtype)
3463  self.assertIs(torch.float32, torch.arange(torch.tensor(1., dtype=torch.float64), 3).dtype)
3464  self.assertIs(torch.float32, torch.arange(1, 3.).dtype)
3465  self.assertIs(torch.float32, torch.arange(torch.tensor(1, dtype=torch.int16), torch.tensor(3.)).dtype)
3466  self.assertIs(torch.float32, torch.arange(1, 3, 1.).dtype)
3467  self.assertIs(torch.float32,
3468  torch.arange(torch.tensor(1),
3469  torch.tensor(3, dtype=torch.int16),
3470  torch.tensor(1., dtype=torch.float64)).dtype)
3471 
3472  self.assertIs(torch.int64, torch.arange(1, 3).dtype)
3473  self.assertIs(torch.int64, torch.arange(torch.tensor(1), 3).dtype)
3474  self.assertIs(torch.int64, torch.arange(torch.tensor(1), torch.tensor(3, dtype=torch.int16)).dtype)
3475  self.assertIs(torch.int64, torch.arange(1, 3, 1).dtype)
3476  self.assertIs(torch.int64,
3477  torch.arange(torch.tensor(1),
3478  torch.tensor(3),
3479  torch.tensor(1, dtype=torch.int16)).dtype)
3480  torch.set_default_dtype(saved_dtype)
3481 
3482  def test_randint_inference(self):
3483  size = (2, 1)
3484  for args in [(3,), (1, 3)]: # (low,) and (low, high)
3485  self.assertIs(torch.int64, torch.randint(*args, size=size).dtype)
3486  self.assertIs(torch.int64, torch.randint(*args, size=size, layout=torch.strided).dtype)
3487  self.assertIs(torch.int64, torch.randint(*args, size=size, generator=torch.default_generator).dtype)
3488  self.assertIs(torch.float32, torch.randint(*args, size=size, dtype=torch.float32).dtype)
3489  out = torch.empty(size, dtype=torch.float32)
3490  self.assertIs(torch.float32, torch.randint(*args, size=size, out=out).dtype)
3491  self.assertIs(torch.float32, torch.randint(*args, size=size, out=out, dtype=torch.float32).dtype)
3492  out = torch.empty(size, dtype=torch.int64)
3493  self.assertIs(torch.int64, torch.randint(*args, size=size, out=out).dtype)
3494  self.assertIs(torch.int64, torch.randint(*args, size=size, out=out, dtype=torch.int64).dtype)
3495 
3496  @staticmethod
3497  def _select_broadcastable_dims(dims_full=None):
3498  # select full dimensionality
3499  if dims_full is None:
3500  dims_full = []
3501  ndims = random.randint(1, 4)
3502  dims_full = [random.randint(1, 8) for _ in range(ndims)]
3503  else:
3504  ndims = len(dims_full)
3505 
3506  # select actual dimensions for ops:
3507  # larger: full ndims, individual sizes may be reduced
3508  # smaller: possibly reduced ndims, sizes may be reduced
3509  smaller_ndims = random.randint(1, ndims)
3510  dims_small = []
3511  dims_large = []
3512  for i in range(ndims - 1, -1, -1):
3513  j = random.randint(1, 3)
3514  if j == 1: # no reduced singleton dimension
3515  ds = dims_full[i]
3516  dl = dims_full[i]
3517  elif j == 2: # larger may have reduced singleton dimension
3518  ds = dims_full[i]
3519  dl = 1 if len(dims_small) < smaller_ndims else dims_full[i]
3520  elif j == 3: # smaller may have reduced singleton dimension
3521  ds = 1
3522  dl = dims_full[i]
3523  dims_large = [dl] + dims_large
3524  if len(dims_small) < smaller_ndims:
3525  dims_small = [ds] + dims_small
3526  return (dims_small, dims_large, dims_full)
3527 
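# Editorial sketch (a hypothetical helper, not part of the upstream suite):
# the broadcasting rule the helper above randomizes -- shapes are aligned on
# their trailing dimensions and each pair of sizes must either match or
# contain a 1, which is virtually expanded.
@staticmethod
def _sketch_broadcasting_rule():
    a = torch.randn(2, 1, 3)
    b = torch.randn(4, 3)
    c = a + b                       # broadcasts to (2, 4, 3)
    assert c.shape == (2, 4, 3)
    assert torch.equal(c, a.expand(2, 4, 3) + b.expand(2, 4, 3))
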
3528  @staticmethod
3529  def _test_broadcast(self, cast):
3530 
3531  # all functions
3532  fns = {
3533  "dist", "atan2", "pow", "lerp", "add",
3534  "sub", "mul", "div", "fmod", "remainder",
3535  "eq", "ge", "gt", "le", "lt", "max", "min", "ne",
3536  "addcdiv", "addcmul", "masked_scatter", "masked_select", "masked_fill",
3537  "map", "map2", "copy"
3538  }
3539  # functions with three tensor arguments
3540  fns_3_args = {"addcdiv", "addcmul", "map2"}
3541 
3542  for fn in fns:
3543  (dims_small, dims_large, dims_full) = self._select_broadcastable_dims()
3544  full1d = cast(torch.randn(*dims_full).flatten().float())
3545  small = cast(torch.randn(*dims_small).float())
3546  large = cast(torch.randn(*dims_large).float())
3547  small_expanded = small.expand(*dims_full)
3548  large_expanded = large.expand(*dims_full)
3549  small2 = None
3550  small2_expanded = None
3551  if fn in fns_3_args:
3552  # create another smaller tensor
3553  (dims_small2, _, _) = self._select_broadcastable_dims(dims_full)
3554  small2 = cast(torch.randn(*dims_small2).float())
3555  small2_expanded = small2.expand(*dims_full)
3556 
3557  if small.is_cuda and fn in ['map', 'map2']:
3558  # map and map2 are not implemented on CUDA tensors
3559  continue
3560 
3561  if hasattr(large_expanded, fn):
3562  # run through tensor versions of functions
3563  # and verify fully expanded inputs give same results
3564  expanded = {large: large_expanded, small: small_expanded, small2: small2_expanded}
3565 
3566  def tensorfn(myfn, t1, t2):
3567  if fn == "lerp":
3568  return myfn(t1, 0.5)
3569  elif fn == "masked_select":
3570  return myfn(t1 < 0)
3571  elif fn == "masked_scatter":
3572  return myfn(t1 < 0.5, full1d)
3573  elif fn == "masked_fill":
3574  return myfn(t1 < 0.5, 1.0)
3575  elif fn in fns_3_args:
3576  return myfn(1, t1, t2)
3577  else:
3578  return myfn(t1)
3579 
3580  # test various orders
3581  for first, second, third in [(large, small, small2), (small, large, small2),
3582  (small2, small, large), (small2, large, small)]:
3583  if first is None:
3584  break # ignore last iter when small2 is None
3585  method_expanded = getattr(expanded[first], fn)
3586  method = getattr(first, fn)
3587  r1 = tensorfn(method_expanded, expanded[second], expanded[third])
3588  r2 = tensorfn(method, second, third)
3589  self.assertEqual(r1, r2)
3590 
3591  # now for torch. versions of functions
3592  if hasattr(torch, fn):
3593  fntorch = getattr(torch, fn)
3594  expanded = {large: large_expanded, small: small_expanded, small2: small2_expanded}
3595 
3596  def torchfn(t1, t2, t3):
3597  if fn == "lerp":
3598  return fntorch(t1, t2, 0.5)
3599  elif fn == "masked_select":
3600  return fntorch(t1, t2 < 0)
3601  elif fn == "masked_scatter":
3602  return fntorch(t1, t2 < 0.5, full1d)
3603  elif fn == "masked_fill":
3604  return fntorch(t1, t2 < 0.5, 1.0)
3605  elif fn in fns_3_args:
3606  return fntorch(t1, 1.0, t2, t3)
3607  else:
3608  return fntorch(t1, t2)
3609 
3610  # test various orders
3611  for first, second, third in [(large, small, small2), (small, large, small2),
3612  (small2, small, large), (small2, large, small)]:
3613  if first is None:
3614  break # ignore last iter when small2 is None
3615  r1 = torchfn(expanded[first], expanded[second], expanded[third])
3616  r2 = torchfn(first, second, third)
3617  self.assertEqual(r1, r2)
3618 
3619  # now for in place functions
3620  # in-place tensor is not broadcastable; test only guaranteed
3621  # to work by broadcasting other argument(s)
3622  if not hasattr(large_expanded, fn + "_"):
3623  continue
3624 
3625  # need to clone large_expanded so we can reuse it, since these functions are in-place
3626  large_expanded_clone = large_expanded.clone()
3627 
3628  def tensorfn_inplace(t0, t1, t2=None):
3629  t0_fn = getattr(t0, fn + "_")
3630  if fn == "lerp":
3631  return t0_fn(t1, 0.5)
3632  elif fn == "masked_scatter":
3633  return t0_fn(t1 < 0.5, full1d)
3634  elif fn == "masked_fill":
3635  return t0_fn(t1 < 0.5, 1.0)
3636  elif fn == "map":
3637  return t0_fn(t1, lambda x, y: x + y)
3638  elif fn == "map2":
3639  return t0_fn(t1, t2, lambda x, y, z: x + y + z)
3640  elif fn in fns_3_args:
3641  return t0_fn(1.0, t1, t2)
3642  else:
3643  return t0_fn(t1)
3644  r1 = tensorfn_inplace(large_expanded, small_expanded, small2_expanded)
3645  r2 = tensorfn_inplace(large_expanded_clone, small, small2)
3646  # in-place pointwise operations don't actually work if the in-place
3647  # tensor is 0-strided (numpy has the same issue)
3648  if (0 not in large_expanded.stride() and 0 not in large_expanded_clone.stride()):
3649  self.assertEqual(r1, r2)
3650 
3651  def broadcastable(t0, t1, t2=None):
3652  try:
3653  t1.expand_as(t0)
3654  if t2 is not None:
3655  t2.expand_as(t0)
3656  except RuntimeError:
3657  return False
3658  return True
3659 
3660  def _test_in_place_broadcastable(t0, t1, t2=None):
3661  if not broadcastable(t0, t1, t2):
3662  same_size = t0.numel() == t1.numel() and (t0.numel() == t2.numel() if t2 is not None else True)
3663  if not same_size:
3664  self.assertRaises(RuntimeError, lambda: tensorfn_inplace(t0, t1, t2))
3665  else:
3666  tensorfn_inplace(t0, t1, t2)
3667 
3668  if fn not in fns_3_args:
3669  _test_in_place_broadcastable(small, large_expanded)
3670  _test_in_place_broadcastable(small, large)
3671  else:
3672  _test_in_place_broadcastable(small2, small_expanded, large_expanded)
3673  _test_in_place_broadcastable(small2, small, large)
3674 
3675  def test_broadcast(self):
3676  self._test_broadcast(self, lambda t: t)
3677 
3678  def test_broadcast_empty(self):
3679  # empty + empty
3680  self.assertRaises(RuntimeError, lambda: torch.randn(5, 0) + torch.randn(0, 5))
3681  self.assertEqual(torch.randn(5, 0), torch.randn(0) + torch.randn(5, 0))
3682  self.assertEqual(torch.randn(5, 0, 0), torch.randn(0) + torch.randn(5, 0, 1))
3683 
3684  # scalar + empty
3685  self.assertEqual(torch.randn(5, 0, 6), torch.randn(()) + torch.randn(5, 0, 6))
3686 
3687  # non-empty, empty
3688  self.assertEqual(torch.randn(0), torch.randn(0) + torch.randn(1))
3689  self.assertEqual(torch.randn(0, 7, 0, 6, 5, 0, 7),
3690  torch.randn(0, 7, 0, 6, 5, 0, 1) + torch.randn(1, 1, 5, 1, 7))
3691  self.assertRaises(RuntimeError, lambda: torch.randn(7, 0) + torch.randn(2, 1))
3692 
3693  def test_broadcast_tensors(self):
3694  x0 = torch.randn(2, 1, 3)
3695  x1 = torch.randn(3)
3696  x2 = torch.randn(3, 1)
3697  expected_size = (2, 3, 3)
3698 
3699  y0, y1, y2 = torch.broadcast_tensors(x0, x1, x2)
3700  self.assertTrue(y0.size() == expected_size)
3701  self.assertTrue(y1.size() == expected_size)
3702  self.assertTrue(y2.size() == expected_size)
3703 
3704  @staticmethod
3705  def _test_contiguous(self, cast):
3706  x = cast(torch.randn(1, 16, 5, 5))
3707  self.assertTrue(x.is_contiguous())
3708  stride = list(x.stride())
3709  stride[0] = 20
3710  # change the stride in dimension 0. the tensor is still contiguous because size[0] is 1
3711  x.set_(x.storage(), 0, x.size(), stride)
3712  self.assertTrue(x.is_contiguous())
3713 
3714  def test_contiguous(self):
3715  return self._test_contiguous(self, lambda t: t)
3716 
3717  def test_empty_tensor_props(self):
3718  sizes = [(0,), (0, 3), (5, 0), (5, 0, 3, 0, 2), (0, 3, 0, 2), (0, 5, 0, 2, 0)]
3719  devices = ['cpu'] if not torch.cuda.is_available() else ['cpu', 'cuda']
3720  for size in sizes:
3721  for device in devices:
3722  x = torch.empty(tuple(size), device=device)
3723  self.assertEqual(size, x.shape)
3724  self.assertTrue(x.is_contiguous())
3725  size_ones_instead_of_zeros = (d if d != 0 else 1 for d in size)
3726  y = torch.empty(tuple(size_ones_instead_of_zeros), device=device)
3727  self.assertEqual(x.stride(), y.stride())
3728 
3729  def test_scalars_as_floats(self):
3730  "zero-dim variables that don't require grad should bind to scalar arguments"
3731  x = torch.tensor(2.)
3732  y = torch.tensor(3.)
3733  # 3 + (3 * 3) * 2
3734  self.assertEqual(y.addcmul(y, y, value=x), 21)
3735 
3736  x = torch.tensor(2., requires_grad=True)
3737  self.assertRaises(Exception, lambda: y.addcmul(y, y, value=x))
3738 
3739  @staticmethod
3740  def _test_broadcast_fused_matmul(self, cast):
3741  fns = ["baddbmm", "addbmm", "addmm", "addmv", "addr"]
3742 
3743  for fn in fns:
3744  batch_dim = random.randint(1, 8)
3745  n_dim = random.randint(1, 8)
3746  m_dim = random.randint(1, 8)
3747  p_dim = random.randint(1, 8)
3748 
3749  def dims_full_for_fn():
3750  if fn == "baddbmm":
3751  return ([batch_dim, n_dim, p_dim], [batch_dim, n_dim, m_dim], [batch_dim, m_dim, p_dim])
3752  elif fn == "addbmm":
3753  return ([n_dim, p_dim], [batch_dim, n_dim, m_dim], [batch_dim, m_dim, p_dim])
3754  elif fn == "addmm":
3755  return ([n_dim, p_dim], [n_dim, m_dim], [m_dim, p_dim])
3756  elif fn == "addmv":
3757  return ([n_dim], [n_dim, m_dim], [m_dim])
3758  elif fn == "addr":
3759  return ([n_dim, m_dim], [n_dim], [m_dim])
3760  else:
3761  raise AssertionError("unknown function")
3762 
3763  (t0_dims_full, t1_dims, t2_dims) = dims_full_for_fn()
3764  (t0_dims_small, _, _) = self._select_broadcastable_dims(t0_dims_full)
3765 
3766  t0_small = cast(torch.randn(*t0_dims_small).float())
3767  t1 = cast(torch.randn(*t1_dims).float())
3768  t2 = cast(torch.randn(*t2_dims).float())
3769 
3770  t0_full = cast(t0_small.expand(*t0_dims_full))
3771 
3772  fntorch = getattr(torch, fn)
3773  r0 = fntorch(t0_small, t1, t2)
3774  r1 = fntorch(t0_full, t1, t2)
3775  self.assertEqual(r0, r1)
3776 
3777  def test_broadcast_fused_matmul(self):
3778  self._test_broadcast_fused_matmul(self, lambda t: t)
3779 
3780  @staticmethod
3781  def _test_broadcast_batched_matmul(self, cast):
3782  n_dim = random.randint(1, 8)
3783  m_dim = random.randint(1, 8)
3784  p_dim = random.randint(1, 8)
3785  full_batch_dims = [random.randint(1, 3) for i in range(random.randint(1, 3))]
3786  (batch_dims_small, _, _) = self._select_broadcastable_dims(full_batch_dims)
3787 
3788  def verify_batched_matmul(full_lhs, one_dimensional):
3789  if not one_dimensional:
3790  lhs_dims = [n_dim, m_dim]
3791  rhs_dims = [m_dim, p_dim]
3792  result_dims = [n_dim, p_dim]
3793  else:
3794  lhs_dims = [n_dim, m_dim] if full_lhs else [m_dim]
3795  rhs_dims = [m_dim, p_dim] if not full_lhs else [m_dim]
3796  result_dims = [n_dim] if full_lhs else [p_dim]
3797 
3798  lhs_mat_dims = lhs_dims if len(lhs_dims) != 1 else [1, m_dim]
3799  rhs_mat_dims = rhs_dims if len(rhs_dims) != 1 else [m_dim, 1]
3800  full_mat_dims = lhs_mat_dims if full_lhs else rhs_mat_dims
3801  dim0_dims = rhs_dims if full_lhs else lhs_dims
3802  small_dims = batch_dims_small + (rhs_mat_dims if full_lhs else lhs_mat_dims)
3803 
3804  small = cast(torch.randn(*(small_dims)).float())
3805  dim0 = cast(torch.randn(*(dim0_dims)).float())
3806  full = cast(torch.randn(*(full_batch_dims + full_mat_dims)).float())
3807  if not one_dimensional:
3808  (lhsTensors, rhsTensors) = ((full,), (small, dim0)) if full_lhs else ((small, dim0), (full,))
3809  else:
3810  (lhsTensors, rhsTensors) = ((full,), (dim0,)) if full_lhs else ((dim0,), (full,))
3811 
3812  def maybe_squeeze_result(l, r, result):
3813  if len(lhs_dims) == 1 and l.dim() != 1:
3814  return result.squeeze(-2)
3815  elif len(rhs_dims) == 1 and r.dim() != 1:
3816  return result.squeeze(-1)
3817  else:
3818  return result
3819 
3820  for lhs in lhsTensors:
3821  lhs_expanded = lhs.expand(*(torch.Size(full_batch_dims) + torch.Size(lhs_mat_dims)))
3822  lhs_expanded_matmul_fn = getattr(lhs_expanded, "matmul")
3823  for rhs in rhsTensors:
3824  rhs_expanded = ((rhs if len(rhs_dims) != 1 else rhs.unsqueeze(-1)).
3825  expand(*(torch.Size(full_batch_dims) + torch.Size(rhs_mat_dims))))
3826  truth = maybe_squeeze_result(lhs_expanded, rhs_expanded, lhs_expanded_matmul_fn(rhs_expanded))
3827  for l in (lhs, lhs_expanded):
3828  for r in (rhs, rhs_expanded):
3829  l_matmul_fn = getattr(l, "matmul")
3830  result = maybe_squeeze_result(l, r, l_matmul_fn(r))
3831  self.assertEqual(truth, result)
3832  # test torch.matmul function as well
3833  torch_result = maybe_squeeze_result(l, r, torch.matmul(l, r))
3834  self.assertEqual(truth, torch_result)
3835  # test torch.matmul with out
3836  out = torch.zeros_like(torch_result)
3837  torch.matmul(l, r, out=out)
3838  self.assertEqual(truth, maybe_squeeze_result(l, r, out))
3839 
3840  # compare to bmm
3841  bmm_result = (torch.bmm(lhs_expanded.contiguous().view(-1, *lhs_mat_dims),
3842  rhs_expanded.contiguous().view(-1, *rhs_mat_dims)))
3843  self.assertEqual(truth.view(-1, *result_dims), bmm_result.view(-1, *result_dims))
3844 
3845  for indices in product((True, False), repeat=2):
3846  verify_batched_matmul(*indices)
3847 
3848  def test_broadcast_batched_matmul(self):
3849  self._test_broadcast_batched_matmul(self, lambda t: t)
3850 
3851  def test_copy_broadcast(self):
3852  torch.zeros(5, 6).copy_(torch.zeros(6))
3853  self.assertRaises(RuntimeError, lambda: torch.zeros(5, 6).copy_(torch.zeros(30)))
3854 
3855  def test_randperm(self):
3856  _RNGState = torch.get_rng_state()
3857  res1 = torch.randperm(100)
3858  res2 = torch.LongTensor()
3859  torch.set_rng_state(_RNGState)
3860  torch.randperm(100, out=res2)
3861  self.assertEqual(res1, res2, 0)
3862 
3863  # randperm of 0 elements is an empty tensor
3864  res1 = torch.randperm(0)
3865  res2 = torch.LongTensor(5)
3866  torch.randperm(0, out=res2)
3867  self.assertEqual(res1.numel(), 0)
3868  self.assertEqual(res2.numel(), 0)
3869 
3870  def test_random(self):
3871  # This test is flaky with p<=(2/(ub-lb))^200=6e-36
3872  t = torch.FloatTensor(200)
3873  lb = 1
3874  ub = 4
3875 
3876  t.fill_(-1)
3877  t.random_(lb, ub)
3878  self.assertEqual(t.min(), lb)
3879  self.assertEqual(t.max(), ub - 1)
3880 
3881  t.fill_(-1)
3882  t.random_(ub)
3883  self.assertEqual(t.min(), 0)
3884  self.assertEqual(t.max(), ub - 1)
3885 
3886  @staticmethod
3887  def _test_random_neg_values(self, use_cuda=False):
3888  signed_types = ['torch.DoubleTensor', 'torch.FloatTensor', 'torch.LongTensor',
3889  'torch.IntTensor', 'torch.ShortTensor']
3890  for tname in signed_types:
3891  res = torch.rand(SIZE, SIZE).type(tname)
3892  if use_cuda:
3893  res = res.cuda()
3894  res.random_(-10, -1)
3895  self.assertLessEqual(res.max().item(), 9)
3896  self.assertGreaterEqual(res.min().item(), -10)
3897 
3898  def test_random_neg_values(self):
3899  self._test_random_neg_values(self)
3900 
3901  def assertIsOrdered(self, order, x, mxx, ixx, task):
3902  SIZE = 4
3903  if order == 'descending':
3904  def check_order(a, b):
3905  # `a != a` because we put NaNs
3906  # at the end of ascending sorted lists,
3907  # and the beginning of descending ones.
3908  return a != a or a >= b
3909  elif order == 'ascending':
3910  def check_order(a, b):
3911  # see above
3912  return b != b or a <= b
3913  else:
3914  error('unknown order "{}", must be "ascending" or "descending"'.format(order))
3915 
3916  are_ordered = True
3917  for j, k in product(range(SIZE), range(1, SIZE)):
3918  self.assertTrue(check_order(mxx[j][k - 1], mxx[j][k]),
3919  'torch.sort ({}) values unordered for {}'.format(order, task))
3920 
3921  seen = set()
3922  indicesCorrect = True
3923  size = x.size(x.dim() - 1)
3924  for k in range(size):
3925  seen.clear()
3926  for j in range(size):
3927  self.assertEqual(x[k][ixx[k][j]], mxx[k][j],
3928  'torch.sort ({}) indices wrong for {}'.format(order, task))
3929  seen.add(ixx[k][j])
3930  self.assertEqual(len(seen), size)
3931 
3932  def test_sort(self):
3933  SIZE = 4
3934  x = torch.rand(SIZE, SIZE)
3935  res1val, res1ind = torch.sort(x)
3936 
3937  # Test use of result tensor
3938  res2val = torch.Tensor()
3939  res2ind = torch.LongTensor()
3940  torch.sort(x, out=(res2val, res2ind))
3941  self.assertEqual(res1val, res2val, 0)
3942  self.assertEqual(res1ind, res2ind, 0)
3943  self.assertEqual(torch.argsort(x), res1ind)
3944  self.assertEqual(x.argsort(), res1ind)
3945 
3946  # Test sorting of random numbers
3947  self.assertIsOrdered('ascending', x, res2val, res2ind, 'random')
3948 
3949  # Test simple sort
3950  self.assertEqual(
3951  torch.sort(torch.Tensor((50, 40, 30, 20, 10)))[0],
3952  torch.Tensor((10, 20, 30, 40, 50)),
3953  0
3954  )
3955 
3956  # Test that we still have proper sorting with duplicate keys
3957  x = torch.floor(torch.rand(SIZE, SIZE) * 10)
3958  torch.sort(x, out=(res2val, res2ind))
3959  self.assertIsOrdered('ascending', x, res2val, res2ind, 'random with duplicate keys')
3960 
3961  # DESCENDING SORT
3962  x = torch.rand(SIZE, SIZE)
3963  res1val, res1ind = torch.sort(x, x.dim() - 1, True)
3964 
3965  # Test use of result tensor
3966  res2val = torch.Tensor()
3967  res2ind = torch.LongTensor()
3968  torch.sort(x, x.dim() - 1, True, out=(res2val, res2ind))
3969  self.assertEqual(res1val, res2val, 0)
3970  self.assertEqual(res1ind, res2ind, 0)
3971  self.assertEqual(torch.argsort(x, x.dim() - 1, True), res1ind)
3972  self.assertEqual(x.argsort(x.dim() - 1, True), res1ind)
3973 
3974  # Test sorting of random numbers
3975  self.assertIsOrdered('descending', x, res2val, res2ind, 'random')
3976 
3977  # Test simple sort task
3978  self.assertEqual(
3979  torch.sort(torch.Tensor((10, 20, 30, 40, 50)), 0, True)[0],
3980  torch.Tensor((50, 40, 30, 20, 10)),
3981  0
3982  )
3983 
3984  # Test that we still have proper sorting with duplicate keys
3985  self.assertIsOrdered('descending', x, res2val, res2ind, 'random with duplicate keys')
3986 
3987  # Test sorting with NaNs
3988  x = torch.rand(SIZE, SIZE)
3989  x[1][2] = float('NaN')
3990  x[3][0] = float('NaN')
3991  torch.sort(x, out=(res2val, res2ind))
3992  self.assertIsOrdered('ascending', x, res2val, res2ind,
3993  'random with NaNs')
3994  torch.sort(x, out=(res2val, res2ind), descending=True)
3995  self.assertIsOrdered('descending', x, res2val, res2ind,
3996  'random with NaNs')
3997 
3998  @unittest.skipIf(not TEST_NUMPY, 'Numpy not found')
3999  def test_tensordot(self):
4000  devices = ['cpu'] if not torch.cuda.is_available() else ['cpu', 'cuda']
4001  for d in devices:
4002  a = torch.arange(60., device=d).reshape(3, 4, 5)
4003  b = torch.arange(24., device=d).reshape(4, 3, 2)
4004  c = torch.tensordot(a, b, dims=([1, 0], [0, 1])).cpu()
4005  cn = torch.from_numpy(np.tensordot(a.cpu().numpy(), b.cpu().numpy(),
4006  axes=([1, 0], [0, 1])))
4007  self.assertEqual(c, cn)
4008  a = torch.randn(2, 3, 4, 5, device=d)
4009  b = torch.randn(4, 5, 6, 7, device=d)
4010  c = torch.tensordot(a, b, dims=2).cpu()
4011  cn = torch.from_numpy(np.tensordot(a.cpu().numpy(), b.cpu().numpy(),
4012  axes=2))
4013  self.assertEqual(c, cn)
4014  c = torch.tensordot(a, b).cpu()
4015  cn = torch.from_numpy(np.tensordot(a.cpu().numpy(), b.cpu().numpy()))
4016  self.assertEqual(c, cn)
4017 
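# Editorial sketch (a hypothetical helper, not part of the upstream suite):
# with an integer `dims`, torch.tensordot contracts the last `dims`
# dimensions of the first operand against the first `dims` dimensions of the
# second, so dims=1 on matching inner sizes is just a matrix product.
@staticmethod
def _sketch_tensordot():
    a = torch.randn(3, 4)
    b = torch.randn(4, 5)
    assert torch.allclose(torch.tensordot(a, b, dims=1), a.mm(b))
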
4018  def test_topk(self):
4019  def topKViaSort(t, k, dim, dir):
4020  sorted, indices = t.sort(dim, dir)
4021  return sorted.narrow(dim, 0, k), indices.narrow(dim, 0, k)
4022 
4023  def compareTensors(t, res1, ind1, res2, ind2, dim):
4024  # Values should be exactly equivalent
4025  self.assertEqual(res1, res2, 0)
4026 
4027  # Indices might differ based on the implementation, since there is
4028  # no guarantee of the relative order of selection
4029  if not ind1.eq(ind2).all():
4030  # To verify that the indices represent equivalent elements,
4031  # gather from the input using the topk indices and compare against
4032  # the sort indices
4033  vals = t.gather(dim, ind2)
4034  self.assertEqual(res1, vals, 0)
4035 
4036  def compare(t, k, dim, dir):
4037  topKVal, topKInd = t.topk(k, dim, dir, True)
4038  sortKVal, sortKInd = topKViaSort(t, k, dim, dir)
4039  compareTensors(t, sortKVal, sortKInd, topKVal, topKInd, dim)
4040 
4041  t = torch.rand(random.randint(1, SIZE),
4042  random.randint(1, SIZE),
4043  random.randint(1, SIZE))
4044 
4045  for _kTries in range(3):
4046  for _dimTries in range(3):
4047  for transpose in (True, False):
4048  for dir in (True, False):
4049  testTensor = t
4050  if transpose:
4051  dim1 = random.randrange(t.ndimension())
4052  dim2 = dim1
4053  while dim1 == dim2:
4054  dim2 = random.randrange(t.ndimension())
4055 
4056  testTensor = t.transpose(dim1, dim2)
4057 
4058  dim = random.randrange(testTensor.ndimension())
4059  k = random.randint(1, testTensor.size(dim))
4060  compare(testTensor, k, dim, dir)
4061 
4062  def test_topk_arguments(self):
4063  q = torch.randn(10, 2, 10)
4064  # Make sure True isn't mistakenly taken as the 2nd dimension (interpreted as 1)
4065  self.assertRaises(TypeError, lambda: q.topk(4, True))
4066 
4067  @unittest.skipIf(not torch.cuda.is_available(), 'no CUDA')
4068  def test_topk_noncontiguous_gpu(self):
4069  t = torch.randn(20, device="cuda")[::2]
4070  top1, idx1 = t.topk(5)
4071  top2, idx2 = t.contiguous().topk(5)
4072  self.assertEqual(top1, top2)
4073  self.assertEqual(idx1, idx2)
4074 
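# Editorial sketch (a hypothetical helper, not part of the upstream suite):
# topk returns the k largest values (or smallest, with largest=False) along a
# dimension together with their indices into the original tensor.
@staticmethod
def _sketch_topk():
    x = torch.tensor([1., 5., 3.])
    vals, idxs = x.topk(2)
    assert vals.tolist() == [5., 3.] and idxs.tolist() == [1, 2]
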
4075  @staticmethod
4076  def _test_kthvalue(self, device='cpu'):
4077  SIZE = 50
4078  x = torch.rand(SIZE, SIZE, SIZE, device=device)
4079  x0 = x.clone()
4080 
4081  k = random.randint(1, SIZE)
4082  res1val, res1ind = torch.kthvalue(x, k, keepdim=False)
4083  res2val, res2ind = torch.sort(x)
4084 
4085  self.assertEqual(res1val[:, :], res2val[:, :, k - 1], 0)
4086  self.assertEqual(res1ind[:, :], res2ind[:, :, k - 1], 0)
4087  # test use of result tensors
4088  k = random.randint(1, SIZE)
4089  res1val = torch.tensor([], device=device)
4090  res1ind = torch.tensor([], dtype=torch.long, device=device)
4091  torch.kthvalue(x, k, keepdim=False, out=(res1val, res1ind))
4092  res2val, res2ind = torch.sort(x)
4093  self.assertEqual(res1val[:, :], res2val[:, :, k - 1], 0)
4094  self.assertEqual(res1ind[:, :], res2ind[:, :, k - 1], 0)
4095 
4096  # test non-default dim
4097  k = random.randint(1, SIZE)
4098  res1val, res1ind = torch.kthvalue(x, k, 0, keepdim=False)
4099  res2val, res2ind = torch.sort(x, 0)
4100  self.assertEqual(res1val, res2val[k - 1], 0)
4101  self.assertEqual(res1ind, res2ind[k - 1], 0)
4102 
4103  # non-contiguous
4104  y = x.narrow(1, 0, 1)
4105  y0 = y.contiguous()
4106  k = random.randint(1, SIZE)
4107  res1val, res1ind = torch.kthvalue(y, k)
4108  res2val, res2ind = torch.kthvalue(y0, k)
4109  self.assertEqual(res1val, res2val, 0)
4110  self.assertEqual(res1ind, res2ind, 0)
4111 
4112  # check that the input wasn't modified
4113  self.assertEqual(x, x0, 0)
4114 
4115  # simple test case (with repetitions)
4116  y = torch.tensor((3., 5, 4, 1, 1, 5), device=device)
4117  self.assertEqual(torch.kthvalue(y, 3)[0], 3, 0)
4118  self.assertEqual(torch.kthvalue(y, 2)[0], 1, 0)
4119 
4120  # simple test case (with NaN)
4121  SIZE = 50
4122  x = torch.rand(SIZE, SIZE, SIZE, device=device)
4123  x[torch.arange(SIZE), :, torch.randint(50, (50,))] = nan
4124  ks = [random.randint(1, SIZE), 1, SIZE, SIZE - 1]
4125  res2val, res2ind = torch.sort(x)
4126  for k in ks:
4127  res1val, res1ind = torch.kthvalue(x, k, keepdim=False)
4128  self.assertEqual(res1val[:, :], res2val[:, :, k - 1], 0)
4129  self.assertEqual(res1ind[:, :], res2ind[:, :, k - 1], 0)
4130 
4131  def test_kthvalue(self):
4132  self._test_kthvalue(self)
4133 
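# Editorial sketch (a hypothetical helper, not part of the upstream suite):
# kthvalue returns the k-th smallest value (k is 1-based) along a dimension
# plus its index, so it agrees with element k-1 of the sorted values -- the
# equivalence the test above checks on large random inputs.
@staticmethod
def _sketch_kthvalue():
    x = torch.tensor([3., 1., 2.])
    val, idx = torch.kthvalue(x, 2)
    assert val.item() == 2. and idx.item() == 2
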
4134  def test_median(self):
4135  for size in (155, 156):
4136  x = torch.rand(size, size)
4137  x0 = x.clone()
4138 
4139  nelem = x.nelement()
4140  res1val = torch.median(x)
4141  res2val, _ = torch.sort(x.view(nelem))
4142  ind = int(math.floor((nelem + 1) / 2) - 1)
4143 
4144  self.assertEqual(res2val[ind], res1val, 0)
4145 
4146  res1val, res1ind = torch.median(x, dim=1, keepdim=False)
4147  res2val, res2ind = torch.sort(x)
4148  ind = int(math.floor((size + 1) / 2) - 1)
4149 
4150  self.assertEqual(res2val.select(1, ind), res1val, 0)
4151  self.assertEqual(res2ind.select(1, ind), res1ind, 0)
4152 
4153  # Test use of result tensor
4154  res2val = torch.Tensor()
4155  res2ind = torch.LongTensor()
4156  torch.median(x, dim=-1, keepdim=False, out=(res2val, res2ind))
4157  self.assertEqual(res2val, res1val, 0)
4158  self.assertEqual(res2ind, res1ind, 0)
4159 
4160  # Test non-default dim
4161  res1val, res1ind = torch.median(x, 0, keepdim=False)
4162  res2val, res2ind = torch.sort(x, 0)
4163  self.assertEqual(res1val, res2val[ind], 0)
4164  self.assertEqual(res1ind, res2ind[ind], 0)
4165 
4166  # input unchanged
4167  self.assertEqual(x, x0, 0)
4168 
4169  def test_mode(self):
4170  x = torch.arange(1., SIZE * SIZE + 1).clone().resize_(SIZE, SIZE)
4171  x[:2] = 1
4172  x[:, :2] = 1
4173  x0 = x.clone()
4174 
4175  # Pre-calculated results.
4176  res1val = torch.Tensor(SIZE).fill_(1)
4177  # The indices are the position of the last appearance of the mode element.
4178  res1ind = torch.LongTensor(SIZE).fill_(1)
4179  res1ind[0] = SIZE - 1
4180  res1ind[1] = SIZE - 1
4181 
4182  res2val, res2ind = torch.mode(x, keepdim=False)
4183  self.assertEqual(res1val, res2val, 0)
4184  self.assertEqual(res1ind, res2ind, 0)
4185 
4186  # Test use of result tensor
4187  res2val = torch.Tensor()
4188  res2ind = torch.LongTensor()
4189  torch.mode(x, keepdim=False, out=(res2val, res2ind))
4190  self.assertEqual(res1val, res2val, 0)
4191  self.assertEqual(res1ind, res2ind, 0)
4192 
4193  # Test non-default dim
4194  res2val, res2ind = torch.mode(x, 0, False)
4195  self.assertEqual(res1val, res2val, 0)
4196  self.assertEqual(res1ind, res2ind, 0)
4197 
4198  # input unchanged
4199  self.assertEqual(x, x0, 0)
4200 
4201  def test_trilu_indices(self):
4202  for test_args in tri_tests_args:
4203  _compare_trilu_indices(self, *test_args)
4204  run_additional_tri_tests(self, 'cpu')
4205 
4206  # test default options
4207  x = torch.ones(
4208  3, 3, dtype=torch.long, device='cpu', layout=torch.strided)
4209  self.assertEqual(
4210  x.tril(0).nonzero().transpose(0, 1), torch.tril_indices(3, 3))
4211  self.assertEqual(
4212  x.triu(0).nonzero().transpose(0, 1), torch.triu_indices(3, 3))
4213 
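# Editorial sketch (a hypothetical helper, not part of the upstream suite):
# torch.tril_indices(rows, cols) returns a 2 x N tensor of the (row, col)
# coordinates of the lower triangle in row-major order, which is why it can
# be compared against nonzero() of a tril mask as done above.
@staticmethod
def _sketch_tril_indices():
    idx = torch.tril_indices(2, 2)
    assert idx.tolist() == [[0, 1, 1], [0, 0, 1]]
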
4214  @staticmethod
4215  def _test_triu_tril(self, cast):
4216  def gen_mask(shape, diagonal, cast, upper):
4217  mask = torch.zeros(*shape[-2:]).byte()
4218  for i in range(shape[-2]):
4219  for j in range(shape[-1]):
4220  cond = j - i < diagonal if upper else j - i > diagonal
4221  if cond:
4222  mask[i, j] = 1
4223  return cast(mask.expand(*shape))
4224 
4225  torch_functions = {True: torch.triu, False: torch.tril}
4226  if TEST_NUMPY:
4227  numpy_functions = {True: np.triu, False: np.tril}
4228 
4229  def run_test(shape, cast, diagonal):
4230  x_cpu = torch.randn(*shape)
4231  x = cast(x_cpu)
4232 
4233  for upper in [True, False]:
4234  # normal test with mask
4235  torch_tri_func = torch_functions[upper]
4236  res1 = torch_tri_func(x, diagonal=diagonal)
4237  res2 = cast(torch.Tensor())
4238  torch_tri_func(x, diagonal=diagonal, out=res2)
4239  exp_mask = gen_mask(shape, diagonal, cast, upper)
4240  expected = torch.where(exp_mask, torch.tensor(0).type_as(x), x)
4241  self.assertEqual(res1, res2, 0)
4242  self.assertEqual(expected, res1, 0)
4243 
4244  # non-contiguous and expanded tensors test
4245  if not (0 in shape or 1 in shape):
4246  for s in range(-len(shape), -1):
4247  # non-contiguous tensors
4248  x_nc = x.clone().transpose(s, s + 1)
4249  exp_mask = gen_mask(x_nc.size(), diagonal, cast, upper)
4250  assert not x_nc.is_contiguous(), "x is intentionally non-contiguous"
4251  exp_nc = torch.where(exp_mask, torch.tensor(0).type_as(x), x_nc)
4252  self.assertEqual(torch_tri_func(x_nc, diagonal), exp_nc, 0)
4253  x_nc_is_contiguous = x_nc.is_contiguous()
4254  if upper:
4255  self.assertEqual(x_nc.triu_(diagonal), exp_nc, 0)
4256  else:
4257  self.assertEqual(x_nc.tril_(diagonal), exp_nc, 0)
4258 
4259  self.assertTrue(x_nc.is_contiguous() == x_nc_is_contiguous,
4260  "contiguity of x_nc should not be changed")
4261 
4262  # expanded tensors
4263  expanded_size = (x.size(0),) + x.size()
4264  x_expanded = x.clone().expand(*expanded_size)
4265  assert 0 in x_expanded.stride(), "x intentionally has 0 in its stride"
4266  output = torch_tri_func(x_expanded, diagonal)
4267  self.assertEqual(output, expected.expand(expanded_size), 0)
4268  self.assertTrue(0 in x_expanded.stride(),
4269  "geometry of x_expanded should be the same")
4270  if upper:
4271  self.assertEqual(output, x_expanded.triu_(diagonal), 0)
4272  else:
4273  self.assertEqual(output, x_expanded.tril_(diagonal), 0)
4274 
4275  if not TEST_NUMPY:
4276  continue
4277 
4278  # numpy test
4279  numpy_tri_func = numpy_functions[upper]
4280  self.assertEqual(numpy_tri_func(x_cpu.numpy(), diagonal), res1.cpu().numpy())
4281 
4282  diagonals = [-2, -1, 0, 1, 2]
4283  shapes = [(3, 3), (5, 3, 3), (7, 5, 3, 3), # square matrices
4284  (7, 3), (5, 7, 3), (7, 5, 7, 3), # fat matrices
4285  (3, 7), (5, 3, 7), (7, 5, 3, 7), # thin matrices
4286  (3, 0), (0, 3, 3), (3, 3, 0, 0), # no numel matrices
4287  (3, 1), (5, 3, 1), (7, 5, 3, 1), # very fat matrices
4288  (1, 3), (5, 1, 3), (7, 5, 1, 3)] # very thin matrices
4289  for s, d in product(shapes, diagonals):
4290  run_test(s, cast, d)
4291 
4292  def test_triu_tril(self):
4293  self._test_triu_tril(self, lambda t: t)
4294 
4295  def test_cat(self):
4296  SIZE = 10
4297  for dtype in (torch.half, torch.double, torch.int):
4298  for dim in range(-3, 3):
4299  pos_dim = dim if dim >= 0 else 3 + dim
4300  x = torch.randint(low=-100, high=100, size=(13, SIZE, SIZE)).to(dtype).transpose(0, pos_dim)
4301  y = torch.randint(low=-100, high=100, size=(17, SIZE, SIZE)).to(dtype).transpose(0, pos_dim)
4302  z = torch.randint(low=-100, high=100, size=(19, SIZE, SIZE)).to(dtype).transpose(0, pos_dim)
4303 
4304  res1 = torch.cat((x, y, z), dim)
4305  self.assertEqual(res1.narrow(pos_dim, 0, 13), x, 0)
4306  self.assertEqual(res1.narrow(pos_dim, 13, 17), y, 0)
4307  self.assertEqual(res1.narrow(pos_dim, 30, 19), z, 0)
4308 
4309  x = torch.randint(low=-100, high=100, size=(20, SIZE, SIZE)).to(dtype)
4310  self.assertEqual(torch.cat(torch.split(x, 7)), x)
4311  self.assertEqual(torch.cat(torch.chunk(x, 7)), x)
4312 
4313  y = torch.randint(low=-100, high=100, size=(1, SIZE, SIZE)).to(dtype)
4314  z = torch.cat([x, y])
4315  self.assertEqual(z.size(), (21, SIZE, SIZE))
4316 
4317  self.assertRaises(RuntimeError, lambda: torch.cat([]))
4318  self.assertRaisesRegex(TypeError, 'got None', lambda: torch.cat([x, None]))
4319 
4320  def test_cat_bad_input_sizes(self):
4321  x = torch.randn(2, 1)
4322  y = torch.randn(2, 1, 1)
4323  z = torch.randn(2, 1, 1)
4324  self.assertRaises(RuntimeError, lambda: torch.cat([x, y, z]))
4325 
4326  x = torch.randn(2, 1, 2)
4327  y = torch.randn(2, 1, 1)
4328  z = torch.randn(2, 2, 1)
4329  self.assertRaises(RuntimeError, lambda: torch.cat([x, y, z], dim=1))
4330 
4331  def test_cat_scalars(self):
4332  x = torch.tensor(0)
4333  y = torch.tensor(1)
4334  with self.assertRaisesRegex(RuntimeError, 'zero-dimensional.*cannot be concatenated'):
4335  torch.cat([x, y])
4336 
4337  @staticmethod
4338  def _test_cat_empty_legacy(self, use_cuda=False):
4339  # FIXME: this is legacy behavior and should be removed
4340  # when we support empty tensors with arbitrary sizes
4341  dtype = torch.float32
4342  device = 'cuda' if use_cuda else 'cpu'
4343 
4344  x = torch.randn((4, 3, 32, 32), dtype=dtype, device=device)
4345  empty = torch.randn((0,), dtype=dtype, device=device)
4346 
4347  res1 = torch.cat([x, empty], dim=1)
4348  res2 = torch.cat([empty, x], dim=1)
4349  self.assertEqual(res1, res2)
4350 
4351  conv = torch.nn.Conv2d(3, 3, kernel_size=1).float()
4352  if use_cuda:
4353  conv = conv.cuda()
4354  res1 = torch.cat([conv(x), empty], dim=1)
4355  res2 = torch.cat([empty, conv(x)], dim=1)
4356  self.assertEqual(res1, res2)
4357 
4358  res1 = torch.cat([empty, empty], dim=1)
4359  self.assertEqual(res1, empty)
4360 
4361  with self.assertRaisesRegex(RuntimeError,
4362  'expected a non-empty list of Tensors'):
4363  torch.cat([], dim=1)
4364 
4365  def test_cat_empty_legacy(self):
4366  self._test_cat_empty_legacy(self)
4367 
4368  @staticmethod
4369  def _test_cat_empty(self, use_cuda=False):
4370  dtype = torch.float32
4371  device = 'cuda' if use_cuda else 'cpu'
4372 
4373  x = torch.randn((4, 3, 32, 32), dtype=dtype, device=device)
4374  empty = torch.randn((4, 0, 32, 32), dtype=dtype, device=device)
4375 
4376  res1 = torch.cat([x, empty], dim=1)
4377  res2 = torch.cat([empty, x], dim=1)
4378  self.assertEqual(res1, res2)
4379 
4380  conv = torch.nn.Conv2d(3, 3, kernel_size=1).float()
4381  if use_cuda:
4382  conv = conv.cuda()
4383  res1 = torch.cat([conv(x), empty], dim=1)
4384  res2 = torch.cat([empty, conv(x)], dim=1)
4385  self.assertEqual(res1, res2)
4386 
4387  res1 = torch.cat([empty, empty], dim=1)
4388  self.assertEqual(res1, empty)
4389 
4390  # check non-legacy-behavior (sizes don't match)
4391  empty = torch.randn((4, 0, 31, 32), dtype=dtype, device=device)
4392  self.assertRaises(RuntimeError, lambda: torch.cat([x, empty], dim=1))
4393  self.assertRaises(RuntimeError, lambda: torch.cat([empty, x], dim=1))
4394 
4395  # check non-legacy-behavior (dimensions don't match)
4396  empty = torch.randn((4, 0), dtype=dtype, device=device)
4397  self.assertRaises(RuntimeError, lambda: torch.cat([x, empty], dim=1))
4398  self.assertRaises(RuntimeError, lambda: torch.cat([empty, x], dim=1))
4399 
4400  def test_cat_empty(self):
4401  self._test_cat_empty(self)
4402 
4403  def test_narrow(self):
4404  x = torch.Tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8]])
4405  self.assertEqual(x.narrow(0, 0, 1), torch.Tensor([[0, 1, 2]]))
4406  self.assertEqual(x.narrow(0, 0, 2), torch.Tensor([[0, 1, 2], [3, 4, 5]]))
4407  self.assertEqual(x.narrow(0, 1, 1), torch.Tensor([[3, 4, 5]]))
4408  self.assertEqual(x.narrow(0, -1, 1), torch.Tensor([[6, 7, 8]]))
4409  self.assertEqual(x.narrow(0, -2, 2), torch.Tensor([[3, 4, 5], [6, 7, 8]]))
4410  self.assertEqual(x.narrow(0, -3, 3), torch.Tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8]]))
4411  self.assertEqual(x.narrow(-1, -1, 1), torch.Tensor([[2], [5], [8]]))
4412  self.assertEqual(x.narrow(-2, -1, 1), torch.Tensor([[6, 7, 8]]))
4413 
4414  def test_narrow_empty(self):
4415  devices = ['cpu'] if not torch.cuda.is_available() else ['cpu', 'cuda']
4416  for device in devices:
4417  x = torch.randn(2, 3, 4, device=device)
4418  for d in range(x.dim()):
4419  y = x.narrow(d, x.size(d), 0)
4420  sz = list(x.size())
4421  sz[d] = 0
4422  self.assertEqual(sz, y.size())
4423 
4424  def test_stack(self):
4425  for dtype in (torch.half, torch.double, torch.int):
4426  x = torch.randint(low=-100, high=100, size=(2, 3, 4)).to(dtype)
4427  y = torch.randint(low=-100, high=100, size=(2, 3, 4)).to(dtype)
4428  z = torch.randint(low=-100, high=100, size=(2, 3, 4)).to(dtype)
4429  for dim in range(4):
4430  res = torch.stack((x, y, z), dim)
4431  res_neg = torch.stack((x, y, z), dim - 4)
4432  expected_size = x.size()[:dim] + (3,) + x.size()[dim:]
4433  self.assertEqual(res, res_neg)
4434  self.assertEqual(res.size(), expected_size)
4435  self.assertEqual(res.select(dim, 0), x, 0)
4436  self.assertEqual(res.select(dim, 1), y, 0)
4437  self.assertEqual(res.select(dim, 2), z, 0)
4438 
4439  def test_stack_out(self):
4440  for dtype in (torch.half, torch.double, torch.int):
4441  x = torch.randint(low=-100, high=100, size=(2, 3, 4)).to(dtype)
4442  y = torch.randint(low=-100, high=100, size=(2, 3, 4)).to(dtype)
4443  z = torch.randint(low=-100, high=100, size=(2, 3, 4)).to(dtype)
4444  for dim in range(4):
4445  expected_size = x.size()[:dim] + (3,) + x.size()[dim:]
4446  res_out = x.new(expected_size)
4447  res_neg_out = x.new(expected_size)
4448  res_out_dp = res_out.data_ptr()
4449  res_out_neg_dp = res_neg_out.data_ptr()
4450  torch.stack((x, y, z), dim, out=res_out)
4451  torch.stack((x, y, z), dim - 4, out=res_neg_out)
4452  self.assertEqual(res_out, res_neg_out)
4453  self.assertEqual(res_out.size(), expected_size)
4454  self.assertEqual(res_out_dp, res_out.data_ptr())
4455  self.assertEqual(res_out_neg_dp, res_neg_out.data_ptr())
4456  self.assertEqual(res_out.select(dim, 0), x, 0)
4457  self.assertEqual(res_out.select(dim, 1), y, 0)
4458  self.assertEqual(res_out.select(dim, 2), z, 0)
4459 
4460  def test_unbind(self):
4461  x = torch.rand(2, 3, 4, 5)
4462  for dim in range(4):
4463  res = torch.unbind(x, dim)
4464  res2 = x.unbind(dim)
4465  self.assertEqual(x.size(dim), len(res))
4466  self.assertEqual(x.size(dim), len(res2))
4467  for i in range(x.size(dim)):
4468  self.assertEqual(x.select(dim, i), res[i])
4469  self.assertEqual(x.select(dim, i), res2[i])
4470 
4471  @skipIfRocm
4472  def test_linspace(self):
4473  devices = ['cpu'] if not torch.cuda.is_available() else ['cpu', 'cuda']
4474  for device in devices:
4475  _from = random.random()
4476  to = _from + random.random()
4477  res1 = torch.linspace(_from, to, 137, device=device)
4478  res2 = torch.tensor((), device=device)
4479  torch.linspace(_from, to, 137, out=res2)
4480  self.assertEqual(res1, res2, 0)
4481  self.assertRaises(RuntimeError, lambda: torch.linspace(0, 1, -1, device=device))
4482  self.assertEqual(torch.linspace(0, 1, 1, device=device), torch.zeros(1, device=device), 0)
4483 
4484  # Check linspace for generating with start > end.
4485  self.assertEqual(torch.linspace(2, 0, 3, device=device), torch.tensor((2, 1, 0), device=device), 0)
4486 
4487  # Check linspace for non-contiguous tensors.
4488  x = torch.zeros(2, 3, device=device)
4489  y = torch.linspace(0, 3, 4, out=x.narrow(1, 1, 2))
4490  self.assertEqual(x, torch.tensor(((0, 0, 1), (0, 2, 3)), device=device), 0)
4491 
4492  def test_logspace(self):
4493  _from = random.random()
4494  to = _from + random.random()
4495  res1 = torch.logspace(_from, to, 137)
4496  res2 = torch.Tensor()
4497  torch.logspace(_from, to, 137, out=res2)
4498  self.assertEqual(res1, res2, 0)
4499  self.assertRaises(RuntimeError, lambda: torch.logspace(0, 1, -1))
4500  self.assertEqual(torch.logspace(0, 1, 1), torch.ones(1), 0)
4501 
4502  # Check logspace_ for generating with start > end.
4503  self.assertEqual(torch.logspace(1, 0, 2), torch.Tensor((10, 1)), 0)
4504 
4505  # Check logspace_ for non-contiguous tensors.
4506  x = torch.zeros(2, 3)
4507  y = torch.logspace(0, 3, 4, out=x.narrow(1, 1, 2))
4508  self.assertEqual(x, torch.Tensor(((0, 1, 10), (0, 100, 1000))), 0)
4509 
4510  def test_rand(self):
4511  torch.manual_seed(123456)
4512  res1 = torch.rand(SIZE, SIZE)
4513  res2 = torch.Tensor()
4514  torch.manual_seed(123456)
4515  torch.rand(SIZE, SIZE, out=res2)
4516  self.assertEqual(res1, res2)
4517 
4518  def test_randint(self):
4519  torch.manual_seed(123456)
4520  res1 = torch.randint(0, 6, (SIZE, SIZE))
4521  res2 = torch.Tensor()
4522  torch.manual_seed(123456)
4523  torch.randint(0, 6, (SIZE, SIZE), out=res2)
4524  torch.manual_seed(123456)
4525  res3 = torch.randint(6, (SIZE, SIZE))
4526  res4 = torch.Tensor()
4527  torch.manual_seed(123456)
4528  torch.randint(6, (SIZE, SIZE), out=res4)
4529  self.assertEqual(res1, res2)
4530  self.assertEqual(res1, res3)
4531  self.assertEqual(res1, res4)
4532  self.assertEqual(res2, res3)
4533  self.assertEqual(res2, res4)
4534  self.assertEqual(res3, res4)
4535  res1 = res1.view(-1)
4536  high = (res1 < 6).type(torch.LongTensor)
4537  low = (res1 >= 0).type(torch.LongTensor)
4538  tensorSize = res1.size()[0]
4539  assert(tensorSize == high.sum())
4540  assert(tensorSize == low.sum())
4541 
4542  def test_randn(self):
4543  torch.manual_seed(123456)
4544  res1 = torch.randn(SIZE, SIZE)
4545  res2 = torch.Tensor()
4546  torch.manual_seed(123456)
4547  torch.randn(SIZE, SIZE, out=res2)
4548  self.assertEqual(res1, res2)
4549 
4550  def test_slice(self):
4551  empty = torch.empty(0, 4)
4552  x = torch.arange(0., 16).view(4, 4)
4553  self.assertEqual(x[:], x)
4554  self.assertEqual(x[:4], x)
4555  # start and stop are clamped to the size of dim
4556  self.assertEqual(x[:5], x)
4557  # if start >= stop then the result is empty
4558  self.assertEqual(x[2:1], empty)
4559  self.assertEqual(x[2:2], empty)
4560  # out of bounds is also empty
4561  self.assertEqual(x[10:12], empty)
4562  # additional correctness checks
4563  self.assertEqual(x[:1].data.tolist(), [[0, 1, 2, 3]])
4564  self.assertEqual(x[:-3].data.tolist(), [[0, 1, 2, 3]])
4565  self.assertEqual(x[:, -2:3].data.tolist(), [[2], [6], [10], [14]])
4566  self.assertEqual(x[0:-1:2].data.tolist(), [[0, 1, 2, 3], [8, 9, 10, 11]])
4567 
4568  def test_is_signed(self):
4569  self.assertEqual(torch.IntTensor(5).is_signed(), True)
4570  self.assertEqual(torch.ByteTensor(5).is_signed(), False)
4571  self.assertEqual(torch.CharTensor(5).is_signed(), True)
4572  self.assertEqual(torch.FloatTensor(5).is_signed(), True)
4573  self.assertEqual(torch.HalfTensor(10).is_signed(), True)
4574 
4575  @unittest.skipIf(not torch.cuda.is_available(), 'no CUDA')
4576  def test_is_signed_cuda(self):
4577  self.assertEqual(torch.cuda.IntTensor(5).is_signed(), True)
4578  self.assertEqual(torch.cuda.ByteTensor(5).is_signed(), False)
4579  self.assertEqual(torch.cuda.CharTensor(5).is_signed(), True)
4580  self.assertEqual(torch.cuda.FloatTensor(5).is_signed(), True)
4581  self.assertEqual(torch.cuda.HalfTensor(10).is_signed(), True)
4582 
4583  @staticmethod
4584  def _test_solve(self, cast):
4585  a = cast(torch.Tensor(((6.80, -2.11, 5.66, 5.97, 8.23),
4586  (-6.05, -3.30, 5.36, -4.44, 1.08),
4587  (-0.45, 2.58, -2.70, 0.27, 9.04),
4588  (8.32, 2.71, 4.35, -7.17, 2.14),
4589  (-9.67, -5.14, -7.26, 6.08, -6.87)))).t()
4590  b = cast(torch.Tensor(((4.02, 6.19, -8.22, -7.57, -3.03),
4591  (-1.56, 4.00, -8.67, 1.75, 2.86),
4592  (9.81, -4.09, -4.57, -8.61, 8.99)))).t()
4593 
4594  res1 = torch.solve(b, a)[0]
4595  self.assertLessEqual(b.dist(torch.mm(a, res1)), 1e-12)
4596 
4597  ta = cast(torch.Tensor())
4598  tb = cast(torch.Tensor())
4599  res2 = torch.solve(b, a, out=(tb, ta))[0]
4600  res3 = torch.solve(b, a, out=(b, a))[0]
4601  self.assertEqual(res1, tb)
4602  self.assertEqual(res1, b)
4603  self.assertEqual(res1, res2)
4604  self.assertEqual(res1, res3)
4605 
4606  # test reuse
4607  res1 = torch.solve(b, a)[0]
4608  ta = cast(torch.Tensor())
4609  tb = cast(torch.Tensor())
4610  torch.solve(b, a, out=(tb, ta))[0]
4611  self.assertEqual(res1, tb)
4612  torch.solve(b, a, out=(tb, ta))[0]
4613  self.assertEqual(res1, tb)
4614 
4615  @skipIfNoLapack
4616  def test_solve(self):
4617  self._test_solve(self, lambda t: t)
4618 
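# Illustrative sketch (not part of the original suite): in this legacy signature
# the right-hand side comes first, and torch.solve(b, a) returns the solution of
# a x = b together with the LU factorization it used.
def _example_solve():
    import torch
    a = torch.Tensor([[3.0, 1.0], [1.0, 2.0]])
    b = torch.Tensor([[9.0], [8.0]])
    x, lu = torch.solve(b, a)                  # x should be [[2.], [3.]]
    assert (torch.mm(a, x) - b).abs().max().item() < 1e-5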
4619  @staticmethod
4620  def _test_solve_batched(self, cast):
4621  from common_utils import random_fullrank_matrix_distinct_singular_value
4622  # test against solve: one batch
4623  A = cast(random_fullrank_matrix_distinct_singular_value(5, 1))
4624  b = cast(torch.randn(1, 5, 10))
4625  x_exp, LU_exp = torch.solve(b.squeeze(0), A.squeeze(0))
4626  x, LU = torch.solve(b, A)
4627  self.assertEqual(x, x_exp.unsqueeze(0))
4628  self.assertEqual(LU, LU_exp.unsqueeze(0))
4629 
4630  # test against solve in a loop: four batches
4631  A = cast(random_fullrank_matrix_distinct_singular_value(5, 4))
4632  b = cast(torch.randn(4, 5, 10))
4633 
4634  x_exp_list = []
4635  LU_exp_list = []
4636  for i in range(4):
4637  x_exp, LU_exp = torch.solve(b[i], A[i])
4638  x_exp_list.append(x_exp)
4639  LU_exp_list.append(LU_exp)
4640  x_exp = torch.stack(x_exp_list)
4641  LU_exp = torch.stack(LU_exp_list)
4642 
4643  x, LU = torch.solve(b, A)
4644  self.assertEqual(x, x_exp)
4645  self.assertEqual(LU, LU_exp)
4646 
4647  # basic correctness test
4648  A = cast(random_fullrank_matrix_distinct_singular_value(5, 3))
4649  b = cast(torch.randn(3, 5, 10))
4650  x, LU = torch.solve(b, A)
4651  self.assertEqual(torch.matmul(A, x), b)
4652 
4653  # Test non-contiguous inputs.
4654  if not TEST_NUMPY:
4655  return
4656  import numpy
4657  from numpy.linalg import solve
4658  A = cast(random_fullrank_matrix_distinct_singular_value(2, 2)).permute(1, 0, 2)
4659  b = cast(torch.randn(2, 2, 2)).permute(2, 1, 0)
4660  x, _ = torch.solve(b, A)
4661  x_exp = torch.Tensor(solve(A.cpu().numpy(), b.cpu().numpy()))
4662  self.assertEqual(x.data, cast(x_exp))
4663 
4664  @skipIfNoLapack
4665  def test_solve_batched(self):
4666  self._test_solve_batched(self, lambda t: t)
4667 
4668  @staticmethod
4669  def _test_solve_batched_dims(self, cast):
4670  if not TEST_NUMPY:
4671  return
4672 
4673  from numpy.linalg import solve
4674  from common_utils import random_fullrank_matrix_distinct_singular_value
4675  # test against numpy.linalg.solve
4676  A = cast(random_fullrank_matrix_distinct_singular_value(4, 2, 1, 3))
4677  b = cast(torch.randn(2, 1, 3, 4, 6))
4678  x, _ = torch.solve(b, A)
4679  x_exp = torch.Tensor(solve(A.cpu().numpy(), b.cpu().numpy()))
4680  self.assertEqual(x.data, cast(x_exp))
4681 
4682  # test column major format
4683  A = cast(random_fullrank_matrix_distinct_singular_value(4, 2, 1, 3)).transpose(-2, -1)
4684  b = cast(torch.randn(2, 1, 3, 6, 4)).transpose(-2, -1)
4685  assert not A.is_contiguous()
4686  assert not b.is_contiguous()
4687  x, _ = torch.solve(b, A)
4688  x_exp = torch.Tensor(solve(A.cpu().numpy(), b.cpu().numpy()))
4689  self.assertEqual(x.data, cast(x_exp))
4690 
4691  # broadcasting b
4692  A = cast(random_fullrank_matrix_distinct_singular_value(4, 2, 1, 3))
4693  b = cast(torch.randn(4, 6))
4694  x, _ = torch.solve(b, A)
4695  x_exp = torch.Tensor(solve(A.cpu().numpy(), b.cpu().numpy()))
4696  self.assertEqual(x.data, cast(x_exp))
4697 
4698  # broadcasting A
4699  A = cast(random_fullrank_matrix_distinct_singular_value(4))
4700  b = cast(torch.randn(2, 1, 3, 4, 2))
4701  x, _ = torch.solve(b, A)
4702  x_exp = torch.Tensor(solve(A.cpu().numpy(), b.cpu().numpy()))
4703  self.assertEqual(x.data, cast(x_exp))
4704 
4705  # broadcasting both A & b
4706  A = cast(random_fullrank_matrix_distinct_singular_value(4, 1, 3, 1))
4707  b = cast(torch.randn(2, 1, 3, 4, 5))
4708  x, _ = torch.solve(b, A)
4709  x_exp = torch.Tensor(solve(A.cpu().numpy(), b.cpu().numpy()))
4710  self.assertEqual(x.data, cast(x_exp))
4711 
4712  @skipIfNoLapack
4713  def test_solve_batched_dims(self):
4714  self._test_solve_batched_dims(self, lambda t: t)
4715 
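# Illustrative sketch (not part of the original suite): as the broadcasting cases
# above exercise, torch.solve broadcasts batch dimensions, so a single matrix can
# be solved against a whole batch of right-hand sides.
def _example_solve_broadcast():
    import torch
    a = torch.eye(3)                           # (3, 3), broadcast over the batch
    b = torch.randn(4, 3, 2)                   # batch of 4 right-hand sides
    x, _ = torch.solve(b, a)
    assert x.shape == (4, 3, 2)
    assert (x - b).abs().max().item() < 1e-6   # identity system: x equals b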
4716  def test_solve_methods_arg_device(self):
4717  if not torch.cuda.is_available():
4718  return
4719 
4720  for b_device, A_device in product(['cpu', 'cuda'], repeat=2):
4721  if b_device == A_device:
4722  continue
4723 
4724  b = torch.randn(3, 1, device=b_device)
4725  A = torch.randn(3, 3, device=A_device)
4726  err_str = "Expected b and A to be on the same device"
4727  with self.assertRaisesRegex(RuntimeError, err_str):
4728  torch.gesv(b, A)
4729 
4730  with self.assertRaisesRegex(RuntimeError, err_str):
4731  torch.cholesky_solve(b, A)
4732 
4733  @skipIfNoLapack
4734  def test_qr(self):
4735 
4736  # Since the QR decomposition is unique only up to the signs of the rows of
4737  # R, we must ensure these are positive before doing the comparison.
4738  def canonicalize(q, r):
4739  d = r.diag().sign().diag()
4740  return torch.mm(q, d), torch.mm(d, r)
4741 
4742  def canon_and_check(q, r, expected_q, expected_r):
4743  q_canon, r_canon = canonicalize(q, r)
4744  expected_q_canon, expected_r_canon = canonicalize(expected_q, expected_r)
4745  self.assertEqual(q_canon, expected_q_canon)
4746  self.assertEqual(r_canon, expected_r_canon)
4747 
4748  def check_qr(a, expected_q, expected_r):
4749  # standard invocation
4750  q, r = torch.qr(a)
4751  canon_and_check(q, r, expected_q, expected_r)
4752 
4753  # in-place
4754  q, r = torch.Tensor(), torch.Tensor()
4755  torch.qr(a, out=(q, r))
4756  canon_and_check(q, r, expected_q, expected_r)
4757 
4758  # manually calculate qr using geqrf and orgqr
4759  m = a.size(0)
4760  n = a.size(1)
4761  k = min(m, n)
4762  result, tau = torch.geqrf(a)
4763  self.assertEqual(result.size(0), m)
4764  self.assertEqual(result.size(1), n)
4765  self.assertEqual(tau.size(0), k)
4766  r = torch.triu(result.narrow(0, 0, k))
4767  q = torch.orgqr(result, tau)
4768  q, r = q.narrow(1, 0, k), r
4769  canon_and_check(q, r, expected_q, expected_r)
4770 
4771  # check square case
4772  a = torch.Tensor(((1, 2, 3), (4, 5, 6), (7, 8, 10)))
4773 
4774  expected_q = torch.Tensor((
4775  (-1.230914909793328e-01, 9.045340337332914e-01, 4.082482904638621e-01),
4776  (-4.923659639173310e-01, 3.015113445777629e-01, -8.164965809277264e-01),
4777  (-8.616404368553292e-01, -3.015113445777631e-01, 4.082482904638634e-01)))
4778  expected_r = torch.Tensor((
4779  (-8.124038404635959e+00, -9.601136296387955e+00, -1.193987e+01),
4780  (0.000000000000000e+00, 9.045340337332926e-01, 1.507557e+00),
4781  (0.000000000000000e+00, 0.000000000000000e+00, 4.082483e-01)))
4782 
4783  check_qr(a, expected_q, expected_r)
4784 
4785  # check rectangular thin
4786  a = torch.Tensor((
4787  (1, 2, 3),
4788  (4, 5, 6),
4789  (7, 8, 9),
4790  (10, 11, 13),
4791  ))
4792  expected_q = torch.Tensor((
4793  (-0.0776150525706334, -0.833052161400748, 0.3651483716701106),
4794  (-0.3104602102825332, -0.4512365874254053, -0.1825741858350556),
4795  (-0.5433053679944331, -0.0694210134500621, -0.7302967433402217),
4796  (-0.7761505257063329, 0.3123945605252804, 0.5477225575051663)
4797  ))
4798  expected_r = torch.Tensor((
4799  (-12.8840987267251261, -14.5916298832790581, -17.0753115655393231),
4800  (0, -1.0413152017509357, -1.770235842976589),
4801  (0, 0, 0.5477225575051664)
4802  ))
4803 
4804  check_qr(a, expected_q, expected_r)
4805 
4806  # check rectangular fat
4807  a = torch.Tensor((
4808  (1, 2, 3, 4),
4809  (5, 6, 7, 8),
4810  (9, 10, 11, 13)
4811  ))
4812  expected_q = torch.Tensor((
4813  (-0.0966736489045663, 0.907737593658436, 0.4082482904638653),
4814  (-0.4833682445228317, 0.3157348151855452, -0.8164965809277254),
4815  (-0.870062840141097, -0.2762679632873518, 0.4082482904638621)
4816  ))
4817  expected_r = torch.Tensor((
4818  (-1.0344080432788603e+01, -1.1794185166357092e+01,
4819  -1.3244289899925587e+01, -1.5564457473635180e+01),
4820  (0.0000000000000000e+00, 9.4720444555662542e-01,
4821  1.8944088911132546e+00, 2.5653453733825331e+00),
4822  (0.0000000000000000e+00, 0.0000000000000000e+00,
4823  1.5543122344752192e-15, 4.0824829046386757e-01)
4824  ))
4825  check_qr(a, expected_q, expected_r)
4826 
4827  # check big matrix
4828  a = torch.randn(1000, 1000)
4829  q, r = torch.qr(a)
4830  a_qr = torch.mm(q, r)
4831  self.assertEqual(a, a_qr, prec=1e-3)
4832 
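# Illustrative sketch (not part of the original suite): torch.qr matches the
# LAPACK-style route used in check_qr above -- torch.geqrf returns a packed
# factorization plus Householder scalars tau, torch.orgqr expands them into the
# first min(m, n) columns of Q, and R is the upper triangle of the packed result.
def _example_qr_via_geqrf():
    import torch
    a = torch.randn(5, 3)
    packed, tau = torch.geqrf(a)
    q = torch.orgqr(packed, tau)               # 5 x 3, the thin Q factor
    r = torch.triu(packed.narrow(0, 0, 3))     # 3 x 3 upper-triangular R
    assert (torch.mm(q, r) - a).abs().max().item() < 1e-4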
4833  @skipIfNoLapack
4834  def test_ormqr(self):
4835  mat1 = torch.randn(7, 7)
4836  mat2 = torch.randn(7, 7)
4837  q, r = torch.qr(mat1)
4838  m, tau = torch.geqrf(mat1)
4839  out_holder = torch.empty_like(mat1)
4840 
4841  res1 = torch.mm(q, mat2)
4842  res2 = torch.ormqr(m, tau, mat2, left=True, transpose=False)
4843  torch.ormqr(m, tau, mat2, out=out_holder)
4844  self.assertEqual(res1, res2)
4845  self.assertEqual(res2, out_holder)
4846 
4847  res1 = torch.mm(mat2, q)
4848  res2 = torch.ormqr(m, tau, mat2, left=False, transpose=False)
4849  torch.ormqr(m, tau, mat2, left=False, transpose=False, out=out_holder)
4850  self.assertEqual(res1, res2)
4851  self.assertEqual(res2, out_holder)
4852 
4853  res1 = torch.mm(q.t(), mat2)
4854  res2 = torch.ormqr(m, tau, mat2, left=True, transpose=True)
4855  torch.ormqr(m, tau, mat2, left=True, transpose=True, out=out_holder)
4856  self.assertEqual(res1, res2)
4857  self.assertEqual(res2, out_holder)
4858 
4859  res1 = torch.mm(mat2, q.t())
4860  res2 = torch.ormqr(m, tau, mat2, left=False, transpose=True)
4861  torch.ormqr(m, tau, mat2, left=False, transpose=True, out=out_holder)
4862  self.assertEqual(res1, res2)
4863  self.assertEqual(res2, out_holder)
4864 
4865  @staticmethod
4866  def _test_geqrf(self, cast):
4867  a = cast(torch.randn(5, 5))
4868  b, c = torch.geqrf(a)
4869  b_placeholder, c_placeholder = torch.empty_like(b), torch.empty_like(c)
4870  torch.geqrf(a, out=(b_placeholder, c_placeholder))
4871  self.assertEqual(b, b_placeholder)
4872  self.assertEqual(c, c_placeholder)
4873 
4874  @skipIfNoLapack
4875  def test_geqrf(self):
4876  self._test_geqrf(self, lambda t: t)
4877 
4878  @staticmethod
4879  def _test_trtrs(self, cast):
4880  a = torch.Tensor(((6.80, -2.11, 5.66, 5.97, 8.23),
4881  (-6.05, -3.30, 5.36, -4.44, 1.08),
4882  (-0.45, 2.58, -2.70, 0.27, 9.04),
4883  (8.32, 2.71, 4.35, -7.17, 2.14),
4884  (-9.67, -5.14, -7.26, 6.08, -6.87))).t()
4885  b = torch.Tensor(((4.02, 6.19, -8.22, -7.57, -3.03),
4886  (-1.56, 4.00, -8.67, 1.75, 2.86),
4887  (9.81, -4.09, -4.57, -8.61, 8.99))).t()
4888 
4889  a = cast(a)
4890  b = cast(b)
4891 
4892  U = torch.triu(a)
4893  L = torch.tril(a)
4894 
4895  # solve Ux = b
4896  x = torch.trtrs(b, U)[0]
4897  self.assertLessEqual(b.dist(torch.mm(U, x)), 1e-12)
4898  x = torch.trtrs(b, U, True, False, False)[0]
4899  self.assertLessEqual(b.dist(torch.mm(U, x)), 1e-12)
4900 
4901  # solve Lx = b
4902  x = torch.trtrs(b, L, False)[0]
4903  self.assertLessEqual(b.dist(torch.mm(L, x)), 1e-12)
4904  x = torch.trtrs(b, L, False, False, False)[0]
4905  self.assertLessEqual(b.dist(torch.mm(L, x)), 1e-12)
4906 
4907  # solve U'x = b
4908  x = torch.trtrs(b, U, True, True)[0]
4909  self.assertLessEqual(b.dist(torch.mm(U.t(), x)), 1e-12)
4910  x = torch.trtrs(b, U, True, True, False)[0]
4911  self.assertLessEqual(b.dist(torch.mm(U.t(), x)), 1e-12)
4912 
4913  # solve U'x = b by manual transposition
4914  y = torch.trtrs(b, U.t(), False, False)[0]
4915  self.assertLessEqual(x.dist(y), 1e-12)
4916 
4917  # solve L'x = b
4918  x = torch.trtrs(b, L, False, True)[0]
4919  self.assertLessEqual(b.dist(torch.mm(L.t(), x)), 1e-12)
4920  x = torch.trtrs(b, L, False, True, False)[0]
4921  self.assertLessEqual(b.dist(torch.mm(L.t(), x)), 1e-12)
4922 
4923  # solve L'x = b by manual transposition
4924  y = torch.trtrs(b, L.t(), True, False)[0]
4925  self.assertLessEqual(x.dist(y), 1e-12)
4926 
4927  # test reuse
4928  res1 = torch.trtrs(b, a)[0]
4929  ta = cast(torch.Tensor())
4930  tb = cast(torch.Tensor())
4931  torch.trtrs(b, a, out=(tb, ta))
4932  self.assertEqual(res1, tb, 0)
4933  tb.zero_()
4934  torch.trtrs(b, a, out=(tb, ta))
4935  self.assertEqual(res1, tb, 0)
4936 
4937  @skipIfNoLapack
4938  def test_trtrs(self):
4939  self._test_trtrs(self, lambda t: t)
4940 
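# Illustrative sketch (not part of the original suite): the positional flags of
# the legacy torch.trtrs(b, A, upper, transpose, unitriangular) pick which
# triangle of A is read, whether A is transposed, and whether its diagonal is
# taken to be all ones; the defaults (True, False, False) match the calls above.
def _example_trtrs():
    import torch
    A = torch.triu(torch.rand(4, 4)) + torch.eye(4)   # well-conditioned upper-triangular
    b = torch.randn(4, 2)
    x = torch.trtrs(b, A)[0]                           # solves A x = b
    assert (torch.mm(A, x) - b).abs().max().item() < 1e-4
    xt = torch.trtrs(b, A, True, True)[0]              # transpose=True solves A^T x = b
    assert (torch.mm(A.t(), xt) - b).abs().max().item() < 1e-4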
4941  @staticmethod
4942  def _test_trtrs_batched(self, cast):
4943  def trtrs_test_helper(A_dims, b_dims, cast, upper, unitriangular):
4944  A = cast(torch.randn(*A_dims))
4945  A = A.triu() if upper else A.tril()
4946  if unitriangular:
4947  A.diagonal(dim1=-2, dim2=-1).fill_(1.)
4948  b = cast(torch.randn(*b_dims))
4949  return A, b
4950 
4951  for upper, transpose, unitriangular in product([True, False], repeat=3):
4952  # test against trtrs: one batch with all possible arguments
4953  A, b = trtrs_test_helper((1, 5, 5), (1, 5, 10), cast, upper, unitriangular)
4954  x_exp = torch.trtrs(b.squeeze(0), A.squeeze(0),
4955  upper=upper, unitriangular=unitriangular, transpose=transpose)[0]
4956  x = torch.trtrs(b, A,
4957  upper=upper, unitriangular=unitriangular, transpose=transpose)[0]
4958  self.assertEqual(x, x_exp.unsqueeze(0))
4959 
4960  # test against trtrs in a loop: four batches with all possible arguments
4961  A, b = trtrs_test_helper((4, 5, 5), (4, 5, 10), cast, upper, unitriangular)
4962  x_exp_list = []
4963  for i in range(4):
4964  x_exp = torch.trtrs(b[i], A[i],
4965  upper=upper, unitriangular=unitriangular, transpose=transpose)[0]
4966  x_exp_list.append(x_exp)
4967  x_exp = torch.stack(x_exp_list)
4968 
4969  x = torch.trtrs(b, A, upper=upper, unitriangular=unitriangular, transpose=transpose)[0]
4970  self.assertEqual(x, x_exp)
4971 
4972  # basic correctness test
4973  A, b = trtrs_test_helper((3, 5, 5), (3, 5, 10), cast, upper, unitriangular)
4974  x = torch.trtrs(b, A, upper=upper, unitriangular=unitriangular, transpose=transpose)[0]
4975  if transpose:
4976  self.assertLessEqual(b.dist(torch.matmul(A.transpose(-1, -2), x)), 2e-12)
4977  else:
4978  self.assertLessEqual(b.dist(torch.matmul(A, x)), 2e-12)
4979 
4980  @skipIfNoLapack
4981  def test_trtrs_batched(self):
4982  _TestTorchMixin._test_trtrs_batched(self, lambda t: t)
4983 
4984  @staticmethod
4985  def _test_trtrs_batched_dims(self, cast):
4986  if not TEST_SCIPY:
4987  return
4988 
4989  from scipy.linalg import solve_triangular as tri_solve
4990 
4991  def scipy_tri_solve_batched(A, B, upper, trans, diag):
4992  batch_dims_A, batch_dims_B = A.shape[:-2], B.shape[:-2]
4993  single_dim_A, single_dim_B = A.shape[-2:], B.shape[-2:]
4994  expand_dims = tuple(torch._C._infer_size(torch.Size(batch_dims_A),
4995  torch.Size(batch_dims_B)))
4996  expand_A = np.broadcast_to(A, expand_dims + single_dim_A)
4997  expand_B = np.broadcast_to(B, expand_dims + single_dim_B)
4998  flat_A = expand_A.reshape((-1,) + single_dim_A)
4999  flat_B = expand_B.reshape((-1,) + single_dim_B)
5000  flat_X = np.vstack([tri_solve(a, b, lower=(not upper), trans=int(trans), unit_diagonal=diag)
5001  for a, b in zip(flat_A, flat_B)])
5002  return flat_X.reshape(expand_B.shape)
5003 
5004  def run_test(A_dims, b_dims, cast, upper, transpose, unitriangular):
5005  A = torch.randn(*A_dims)
5006  A = A.triu() if upper else A.tril()
5007  if unitriangular:
5008  A.diagonal(dim1=-2, dim2=-1).fill_(1.)
5009  b = torch.randn(*b_dims)
5010  x_exp = torch.Tensor(scipy_tri_solve_batched(A.numpy(), b.numpy(),
5011  upper, transpose, unitriangular))
5012  A, b = cast(A), cast(b)
5013  x = torch.trtrs(b, A, upper=upper, transpose=transpose, unitriangular=unitriangular)[0]
5014 
5015  self.assertEqual(x, cast(x_exp))
5016 
5017  for upper, transpose, unitriangular in product([True, False], repeat=3):
5018  # test against scipy.linalg.solve_triangular
5019  run_test((2, 1, 3, 4, 4), (2, 1, 3, 4, 6), cast, upper, transpose, unitriangular) # no broadcasting
5020  run_test((2, 1, 3, 4, 4), (4, 6), cast, upper, transpose, unitriangular) # broadcasting b
5021  run_test((4, 4), (2, 1, 3, 4, 2), cast, upper, transpose, unitriangular) # broadcasting A
5022  run_test((1, 3, 1, 4, 4), (2, 1, 3, 4, 5), cast, upper, transpose, unitriangular) # broadcasting A & b
5023 
5024  @skipIfNoLapack
5025  def test_trtrs_batched_dims(self):
5026  self._test_trtrs_batched_dims(self, lambda t: t)
5027 
5028  @skipIfNoLapack
5029  def test_gels(self):
5030  def _test_underdetermined(a, b, expectedNorm):
5031  m = a.size()[0]
5032  n = a.size()[1]
5033  assert(m <= n)
5034 
5035  a_copy = a.clone()
5036  b_copy = b.clone()
5037  res1 = torch.gels(b, a)[0]
5038  self.assertEqual(a, a_copy, 0)
5039  self.assertEqual(b, b_copy, 0)
5040  self.assertEqual((torch.mm(a, res1) - b).norm(), expectedNorm, 1e-8)
5041 
5042  ta = torch.Tensor()
5043  tb = torch.Tensor()
5044  res2 = torch.gels(b, a, out=(tb, ta))[0]
5045  self.assertEqual(a, a_copy, 0)
5046  self.assertEqual(b, b_copy, 0)
5047  self.assertEqual((torch.mm(a, res2) - b).norm(), expectedNorm, 1e-8)
5048 
5049  res3 = torch.gels(b, a, out=(b, a))[0]
5050  self.assertEqual((torch.mm(a_copy, b) - b_copy).norm(), expectedNorm, 1e-8)
5051  self.assertEqual(res1, tb, 0)
5052  self.assertEqual(res1, b, 0)
5053  self.assertEqual(res1, res2, 0)
5054  self.assertEqual(res1, res3, 0)
5055 
5056  def _test_overdetermined(a, b, expectedNorm):
5057  m = a.size()[0]
5058  n = a.size()[1]
5059  assert(m > n)
5060 
5061  def check_norm(a, b, expected_norm, gels_result):
5062  # Checks |ax - b| and the residual info from the result
5063  n = a.size()[1]
5064 
5065  # The first n rows are the least-squares solution.
5066  # Rows n to m-1 contain residual information.
5067  x = gels_result[:n]
5068  resid_info = gels_result[n:]
5069 
5070  resid_norm = (torch.mm(a, x) - b).norm()
5071  self.assertEqual(resid_norm, expected_norm, 1e-8)
5072  self.assertEqual(resid_info.norm(), resid_norm, 1e-8)
5073 
5074  a_copy = a.clone()
5075  b_copy = b.clone()
5076  res1 = torch.gels(b, a)[0]
5077  self.assertEqual(a, a_copy, 0)
5078  self.assertEqual(b, b_copy, 0)
5079  check_norm(a, b, expectedNorm, res1)
5080 
5081  ta = torch.Tensor()
5082  tb = torch.Tensor()
5083  res2 = torch.gels(b, a, out=(tb, ta))[0]
5084  self.assertEqual(a, a_copy, 0)
5085  self.assertEqual(b, b_copy, 0)
5086  check_norm(a, b, expectedNorm, res2)
5087 
5088  res3 = torch.gels(b, a, out=(b, a))[0]
5089  check_norm(a_copy, b_copy, expectedNorm, res3)
5090 
5091  self.assertEqual(res1, tb, 0)
5092  self.assertEqual(res1, b, 0)
5093  self.assertEqual(res1, res2, 0)
5094  self.assertEqual(res1, res3, 0)
5095 
5096  # basic test
5097  expectedNorm = 0
5098  a = torch.Tensor(((1.44, -9.96, -7.55, 8.34),
5099  (-7.84, -0.28, 3.24, 8.09),
5100  (-4.39, -3.24, 6.27, 5.28),
5101  (4.53, 3.83, -6.64, 2.06))).t()
5102  b = torch.Tensor(((8.58, 8.26, 8.48, -5.28),
5103  (9.35, -4.43, -0.70, -0.26))).t()
5104  _test_underdetermined(a, b, expectedNorm)
5105 
5106  # test overdetermined
5107  expectedNorm = 17.390200628863
5108  a = torch.Tensor(((1.44, -9.96, -7.55, 8.34, 7.08, -5.45),
5109  (-7.84, -0.28, 3.24, 8.09, 2.52, -5.70),
5110  (-4.39, -3.24, 6.27, 5.28, 0.74, -1.19),
5111  (4.53, 3.83, -6.64, 2.06, -2.47, 4.70))).t()
5112  b = torch.Tensor(((8.58, 8.26, 8.48, -5.28, 5.72, 8.93),
5113  (9.35, -4.43, -0.70, -0.26, -7.36, -2.52))).t()
5114  _test_overdetermined(a, b, expectedNorm)
5115 
5116  # test underdetermined
5117  expectedNorm = 0
5118  a = torch.Tensor(((1.44, -9.96, -7.55),
5119  (-7.84, -0.28, 3.24),
5120  (-4.39, -3.24, 6.27),
5121  (4.53, 3.83, -6.64))).t()
5122  b = torch.Tensor(((8.58, 8.26, 8.48),
5123  (9.35, -4.43, -0.70))).t()
5124  _test_underdetermined(a, b, expectedNorm)
5125 
5126  # test reuse
5127  expectedNorm = 0
5128  a = torch.Tensor(((1.44, -9.96, -7.55, 8.34),
5129  (-7.84, -0.28, 3.24, 8.09),
5130  (-4.39, -3.24, 6.27, 5.28),
5131  (4.53, 3.83, -6.64, 2.06))).t()
5132  b = torch.Tensor(((8.58, 8.26, 8.48, -5.28),
5133  (9.35, -4.43, -0.70, -0.26))).t()
5134  ta = torch.Tensor()
5135  tb = torch.Tensor()
5136  torch.gels(b, a, out=(tb, ta))
5137  self.assertEqual((torch.mm(a, tb) - b).norm(), expectedNorm, 1e-8)
5138  torch.gels(b, a, out=(tb, ta))
5139  self.assertEqual((torch.mm(a, tb) - b).norm(), expectedNorm, 1e-8)
5140  torch.gels(b, a, out=(tb, ta))
5141  self.assertEqual((torch.mm(a, tb) - b).norm(), expectedNorm, 1e-8)
5142 
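# Illustrative sketch (not part of the original suite): torch.gels(b, a) returns
# the least-squares solution of a x = b; as check_norm above notes, only the
# first n rows of the result are the solution, the rest carry residual
# information for overdetermined systems.
def _example_gels():
    import torch
    a = torch.Tensor([[1.0, 1.0], [1.0, 2.0], [1.0, 3.0]])   # 3 x 2, overdetermined
    b = torch.Tensor([[6.0], [0.0], [0.0]])
    x = torch.gels(b, a)[0][:a.size(1)]                       # keep the first n rows
    residual = torch.mm(a, x) - b
    # the least-squares solution satisfies the normal equations a^T (a x - b) = 0
    assert torch.mm(a.t(), residual).abs().max().item() < 1e-3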
5143  @skipIfNoLapack
5144  def test_eig(self):
5145  a = torch.Tensor(((1.96, 0.00, 0.00, 0.00, 0.00),
5146  (-6.49, 3.80, 0.00, 0.00, 0.00),
5147  (-0.47, -6.39, 4.17, 0.00, 0.00),
5148  (-7.20, 1.50, -1.51, 5.70, 0.00),
5149  (-0.65, -6.34, 2.67, 1.80, -7.10))).t().contiguous()
5150  e = torch.eig(a)[0]
5151  ee, vv = torch.eig(a, True)
5152  te = torch.Tensor()
5153  tv = torch.Tensor()
5154  eee, vvv = torch.eig(a, True, out=(te, tv))
5155  self.assertEqual(e, ee, 1e-12)
5156  self.assertEqual(ee, eee, 1e-12)
5157  self.assertEqual(ee, te, 1e-12)
5158  self.assertEqual(vv, vvv, 1e-12)
5159  self.assertEqual(vv, tv, 1e-12)
5160 
5161  # test reuse
5162  X = torch.randn(4, 4)
5163  X = torch.mm(X.t(), X)
5164  e, v = torch.zeros(4, 2), torch.zeros(4, 4)
5165  torch.eig(X, True, out=(e, v))
5166  Xhat = torch.mm(torch.mm(v, torch.diag(e.select(1, 0))), v.t())
5167  self.assertEqual(X, Xhat, 1e-8, 'VeV\' wrong')
5168  self.assertFalse(v.is_contiguous(), 'V is contiguous')
5169 
5170  torch.eig(X, True, out=(e, v))
5171  Xhat = torch.mm(v, torch.mm(e.select(1, 0).diag(), v.t()))
5172  self.assertEqual(X, Xhat, 1e-8, 'VeV\' wrong')
5173  self.assertFalse(v.is_contiguous(), 'V is contiguous')
5174 
5175  # test non-contiguous
5176  X = torch.randn(4, 4)
5177  X = torch.mm(X.t(), X)
5178  e = torch.zeros(4, 2, 2)[:, 1]
5179  v = torch.zeros(4, 2, 4)[:, 1]
5180  self.assertFalse(v.is_contiguous(), 'V is contiguous')
5181  self.assertFalse(e.is_contiguous(), 'E is contiguous')
5182  torch.eig(X, True, out=(e, v))
5183  Xhat = torch.mm(torch.mm(v, torch.diag(e.select(1, 0))), v.t())
5184  self.assertEqual(X, Xhat, 1e-8, 'VeV\' wrong')
5185 
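# Illustrative sketch (not part of the original suite): torch.eig returns the
# eigenvalues as an (n, 2) tensor of (real, imaginary) pairs and only computes
# eigenvectors when the second argument is True; for the symmetric matrices used
# above the imaginary column is zero, so X can be rebuilt as V diag(e) V^T.
def _example_eig():
    import torch
    X = torch.randn(4, 4)
    X = torch.mm(X.t(), X)                         # symmetric => real spectrum
    e, v = torch.eig(X, True)
    Xhat = torch.mm(torch.mm(v, torch.diag(e[:, 0])), v.t())
    assert (X - Xhat).abs().max().item() < 1e-3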
5186  @staticmethod
5187  def _test_symeig(self, conv_fn):
5188  xval = conv_fn(torch.rand(100, 3))
5189  cov = torch.mm(xval.t(), xval)
5190  rese = conv_fn(torch.zeros(3))
5191  resv = conv_fn(torch.zeros(3, 3))
5192 
5193  # First call to symeig
5194  self.assertTrue(resv.is_contiguous(), 'resv is not contiguous')
5195  torch.symeig(cov.clone(), True, out=(rese, resv))
5196  ahat = torch.mm(torch.mm(resv, torch.diag(rese)), resv.t())
5197  self.assertEqual(cov, ahat, 1e-8, 'VeV\' wrong')
5198 
5199  # Second call to symeig
5200  self.assertFalse(resv.is_contiguous(), 'resv is contiguous')
5201  torch.symeig(cov.clone(), True, out=(rese, resv))
5202  ahat = torch.mm(torch.mm(resv, torch.diag(rese)), resv.t())
5203  self.assertEqual(cov, ahat, 1e-8, 'VeV\' wrong')
5204 
5205  # test eigenvectors=False
5206  rese2 = conv_fn(torch.zeros(3))
5207  resv2 = conv_fn(torch.randn(3, 3))
5208  expected_resv2 = conv_fn(torch.zeros(3, 3))
5209  torch.symeig(cov.clone(), False, out=(rese2, resv2))
5210  self.assertEqual(rese, rese2)
5211  self.assertEqual(resv2, expected_resv2)
5212 
5213  # test non-contiguous
5214  X = conv_fn(torch.rand(5, 5))
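 # note: this is an elementwise product, not torch.mm, but (X.t() * X) is still
 # symmetric because its (i, j) entry is X[j, i] * X[i, j]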
5215  X = X.t() * X
5216  e = conv_fn(torch.zeros(4, 2)).select(1, 1)
5217  v = conv_fn(torch.zeros(4, 2, 4))[:, 1]
5218  self.assertFalse(v.is_contiguous(), 'V is contiguous')
5219  self.assertFalse(e.is_contiguous(), 'E is contiguous')
5220  torch.symeig(X, True, out=(e, v))
5221  Xhat = torch.mm(torch.mm(v, torch.diag(e)), v.t())
5222  self.assertEqual(X, Xhat, 1e-8, 'VeV\' wrong')
5223 
5224  @skipIfNoLapack
5225  def test_symeig(self):
5226  self._test_symeig(self, lambda x: x)
5227 
5228  @skipIfNoLapack
5229  def test_svd(self):
5230  a = torch.Tensor(((8.79, 6.11, -9.15, 9.57, -3.49, 9.84),
5231  (9.93, 6.91, -7.93, 1.64, 4.02, 0.15),
5232  (9.83, 5.04, 4.86, 8.83, 9.80, -8.99),
5233  (5.45, -0.27, 4.85, 0.74, 10.00, -6.02),
5234  (3.16, 7.98, 3.01, 5.80, 4.27, -5.31))).t().clone()
5235  u, s, v = torch.svd(a)
5236  uu = torch.Tensor()
5237  ss = torch.Tensor()
5238  vv = torch.Tensor()
5239  uuu, sss, vvv = torch.svd(a, out=(uu, ss, vv))
5240  self.assertEqual(u, uu, 0, 'torch.svd')
5241  self.assertEqual(u, uuu, 0, 'torch.svd')
5242  self.assertEqual(s, ss, 0, 'torch.svd')
5243  self.assertEqual(s, sss, 0, 'torch.svd')
5244  self.assertEqual(v, vv, 0, 'torch.svd')
5245  self.assertEqual(v, vvv, 0, 'torch.svd')
5246 
5247  # test reuse
5248  X = torch.randn(4, 4)
5249  U, S, V = torch.svd(X)
5250  Xhat = torch.mm(U, torch.mm(S.diag(), V.t()))
5251  self.assertEqual(X, Xhat, 1e-8, 'USV\' wrong')
5252 
5253  self.assertFalse(U.is_contiguous(), 'U is contiguous')
5254  torch.svd(X, out=(U, S, V))
5255  Xhat = torch.mm(U, torch.mm(S.diag(), V.t()))
5256  self.assertEqual(X, Xhat, 1e-8, 'USV\' wrong')
5257 
5258  # test non-contiguous
5259  X = torch.randn(5, 5)
5260  U = torch.zeros(5, 2, 5)[:, 1]
5261  S = torch.zeros(5, 2)[:, 1]
5262  V = torch.zeros(5, 2, 5)[:, 1]
5263 
5264  self.assertFalse(U.is_contiguous(), 'U is contiguous')
5265  self.assertFalse(S.is_contiguous(), 'S is contiguous')
5266  self.assertFalse(V.is_contiguous(), 'V is contiguous')
5267  torch.svd(X, out=(U, S, V))
5268  Xhat = torch.mm(U, torch.mm(S.diag(), V.t()))
5269  self.assertEqual(X, Xhat, 1e-8, 'USV\' wrong')
5270 
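# Illustrative sketch (not part of the original suite): torch.svd returns
# (U, S, V) with X = U diag(S) V^T (note that V itself, not V^T, is returned);
# with compute_uv=False only the singular values are meaningful, as
# _test_svd_no_singularvectors below checks.
def _example_svd():
    import torch
    X = torch.randn(6, 4)
    U, S, V = torch.svd(X)
    Xhat = torch.mm(U, torch.mm(torch.diag(S), V.t()))
    assert (X - Xhat).abs().max().item() < 1e-4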
5271  @staticmethod
5272  def _test_svd_no_singularvectors(self, cast):
5273  for size in [(5, 5), (5, 20), (20, 5)]:
5274  a = cast(torch.randn(*size))
5275  u, s_expect, v = torch.svd(a)
5276  u, s_actual, v = torch.svd(a, compute_uv=False)
5277  self.assertEqual(s_expect, s_actual, "Singular values don't match")
5278 
5279  @skipIfNoLapack
5280  def test_svd_no_singularvectors(self):
5281  self._test_svd_no_singularvectors(self, lambda t: t)
5282 
5283  @staticmethod
5284  def _test_matrix_rank(self, conv_fn):
5285  a = conv_fn(torch.eye(10))
5286  self.assertEqual(torch.matrix_rank(a).item(), 10)
5287  self.assertEqual(torch.matrix_rank(a, True).item(), 10)
5288 
5289  a[5, 5] = 0
5290  self.assertEqual(torch.matrix_rank(a).item(), 9)
5291  self.assertEqual(torch.matrix_rank(a, True).item(), 9)
5292 
5293  a = conv_fn(torch.randn(24, 42))
5294  self.assertEqual(torch.matrix_rank(a), torch.matrix_rank(a.t()))
5295  aaT = torch.mm(a, a.t())
5296  self.assertEqual(torch.matrix_rank(aaT), torch.matrix_rank(aaT, True))
5297  aTa = torch.mm(a.t(), a)
5298  self.assertEqual(torch.matrix_rank(aTa), torch.matrix_rank(aTa, True))
5299 
5300  if TEST_NUMPY:
5301  from numpy.linalg import matrix_rank
5302  a = conv_fn(torch.randn(35, 75))
5303  self.assertEqual(torch.matrix_rank(a).item(), matrix_rank(a.cpu().numpy()))
5304  self.assertEqual(torch.matrix_rank(a, 0.01).item(), matrix_rank(a.cpu().numpy(), 0.01))
5305 
5306  aaT = torch.mm(a, a.t())
5307  self.assertEqual(torch.matrix_rank(aaT).item(), matrix_rank(aaT.cpu().numpy()))
5308  self.assertEqual(torch.matrix_rank(aaT, 0.01).item(), matrix_rank(aaT.cpu().numpy(), 0.01))
5309 
5310  if np.lib.NumpyVersion(np.__version__) >= '1.14.0':
5311  self.assertEqual(torch.matrix_rank(aaT, True).item(), matrix_rank(aaT.cpu().numpy(), True))
5312  self.assertEqual(torch.matrix_rank(aaT, 0.01, True).item(),
5313  matrix_rank(aaT.cpu().numpy(), 0.01, True))
5314 
5315  @skipIfNoLapack
5316  def test_matrix_rank(self):
5317  self._test_matrix_rank(self, lambda x: x)
5318 
5319  @staticmethod
5320  def _test_signal_window_functions(self, device='cpu'):
5321  if not TEST_SCIPY:
5322  raise unittest.SkipTest('Scipy not found')
5323 
5324  def test(name):
5325  torch_method = getattr(torch, name + '_window')
5326  for size in [1, 2, 5, 10, 50, 100, 1024, 2048]:
5327  for periodic in [True, False]:
5328  res = torch_method(size, periodic=periodic, device=device)
5329  ref = torch.from_numpy(signal.get_window(name, size, fftbins=periodic))
5330  self.assertEqual(res, ref)
5331  with self.assertRaisesRegex(RuntimeError, r'not implemented for sparse types'):
5332  torch_method(3, layout=torch.sparse_coo)
5333  with self.assertRaisesRegex(RuntimeError, r'floating point'):
5334  torch_method(3, dtype=torch.long)
5335  self.assertTrue(torch_method(3, requires_grad=True).requires_grad)
5336  self.assertFalse(torch_method(3).requires_grad)
5337 
5338  for window in ['hann', 'hamming', 'bartlett', 'blackman']:
5339  test(window)
5340 
5341  def test_signal_window_functions(self):
5342  self._test_signal_window_functions(self)
5343 
5344  @staticmethod
5345  def _test_inverse(self, conv_fn):
5346  from common_utils import random_fullrank_matrix_distinct_singular_value
5347 
5348  # no batches: 2-D tensors
5349  matrix = conv_fn(random_fullrank_matrix_distinct_singular_value(5))
5350  matrix_inverse = torch.inverse(matrix)
5351  identity = conv_fn(torch.eye(5))
5352  self.assertEqual(identity, torch.mm(matrix, matrix_inverse), 1e-8, 'inverse value')
5353  self.assertEqual(identity, torch.mm(matrix_inverse, matrix), 1e-8, 'inverse value')
5354 
5355  matrix_inverse_out = conv_fn(torch.empty(5, 5))
5356  torch.inverse(matrix, out=matrix_inverse_out)
5357  self.assertEqual(matrix_inverse_out, matrix_inverse, 0, 'inverse value in-place')
5358  # second call, now that matrix_inverse_out is transposed
5359  torch.inverse(matrix, out=matrix_inverse_out)
5360  self.assertEqual(matrix_inverse_out, matrix_inverse, 0, 'inverse value in-place')
5361 
5362  # one batch
5363  matrix = conv_fn(random_fullrank_matrix_distinct_singular_value(5, 1))
5364  matrix_inverse = torch.inverse(matrix)
5365  expected_inv = matrix.squeeze(0).inverse()
5366  self.assertEqual(matrix_inverse, expected_inv.unsqueeze(0))
5367 
5368  # four batches
5369  matrices = conv_fn(random_fullrank_matrix_distinct_singular_value(5, 4))
5370  expected_inv_list = []
5371  for i in range(0, 4):
5372  expected_inv_list.append(torch.inverse(matrices[i]))
5373  expected_inv = torch.stack(expected_inv_list)
5374  matrices_inverse = torch.inverse(matrices)
5375  self.assertEqual(matrices_inverse, expected_inv)
5376 
5377  # six batches (2 x 3)
5378  matrices = conv_fn(random_fullrank_matrix_distinct_singular_value(5, 2, 3))
5379  expected_inv_list = []
5380  for mat in matrices.view(-1, 5, 5):
5381  expected_inv_list.append(torch.inverse(mat))
5382  expected_inv = torch.stack(expected_inv_list).view(2, 3, 5, 5)
5383  matrices_inverse = torch.inverse(matrices)
5384  self.assertEqual(matrices_inverse, expected_inv)
5385 
5386  # incorrect input test
5387  with self.assertRaisesRegex(RuntimeError, "must be batches of square matrices"):
5388  torch.inverse(torch.randn(2, 3, 4, 3))
5389 
5390  # correctness test
5391  matrices = conv_fn(random_fullrank_matrix_distinct_singular_value(5, 3))
5392  matrices_inverse = torch.inverse(matrices)
5393  self.assertEqual(torch.matmul(matrices, matrices_inverse), identity.expand_as(matrices))
5394  self.assertEqual(torch.matmul(matrices_inverse, matrices), identity.expand_as(matrices))
5395 
5396  # torch.inverse with out and batches
5397  matrices = conv_fn(random_fullrank_matrix_distinct_singular_value(5, 3))
5398  matrices_inverse = conv_fn(torch.empty(3, 5, 5))
5399  torch.inverse(matrices, out=matrices_inverse)
5400  self.assertEqual(torch.inverse(matrices), matrices_inverse)
5401 
5402  # non-contiguous inputs
5403  if not TEST_NUMPY:
5404  return
5405 
5406  from numpy.linalg import inv
5407  matrices = conv_fn(random_fullrank_matrix_distinct_singular_value(3, 2)).permute(0, 2, 1)
5408  assert not matrices.is_contiguous()
5409  matrices_inverse = torch.inverse(matrices)
5410  expected_inv = torch.as_tensor(inv(matrices.cpu().numpy()))
5411  self.assertEqual(matrices_inverse, conv_fn(expected_inv))
5412 
5413  @skipIfNoLapack
5414  def test_inverse(self):
5415  self._test_inverse(self, lambda t: t)
5416 
5417  @staticmethod
5418  def _test_pinverse(self, conv_fn):
5419  def run_test(M):
5420  # Testing against definition for pseudo-inverses
5421  MPI = torch.pinverse(M)
5422  self.assertEqual(M, M.mm(MPI).mm(M), 1e-8, 'pseudo-inverse condition 1')
5423  self.assertEqual(MPI, MPI.mm(M).mm(MPI), 1e-8, 'pseudo-inverse condition 2')
5424  self.assertEqual(M.mm(MPI), (M.mm(MPI)).t(), 1e-8, 'pseudo-inverse condition 3')
5425  self.assertEqual(MPI.mm(M), (MPI.mm(M)).t(), 1e-8, 'pseudo-inverse condition 4')
5426 
5427  # Square matrix
5428  M = conv_fn(torch.randn(5, 5))
5429  run_test(M)
5430 
5431  # Rectangular matrix
5432  M = conv_fn(torch.randn(3, 4))
5433  run_test(M)
5434 
5435  # Test inverse and pseudo-inverse for invertible matrix
5436  M = torch.randn(5, 5)
5437  M = conv_fn(M.mm(M.t()))
5438  self.assertEqual(conv_fn(torch.eye(5)), M.pinverse().mm(M), 1e-7, 'pseudo-inverse for invertible matrix')
5439 
5440  @skipIfNoLapack
5441  def test_pinverse(self):
5442  self._test_pinverse(self, conv_fn=lambda x: x)
5443 
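# Illustrative sketch (not part of the original suite): for a full column rank M,
# torch.pinverse(M) behaves like (M^T M)^{-1} M^T, so pinverse(M) applied to b is
# the least-squares solution of M x = b (this follows from conditions 1-4 above).
def _example_pinverse():
    import torch
    M = torch.randn(6, 3)
    b = torch.randn(6, 1)
    x = torch.mm(torch.pinverse(M), b)
    # the least-squares solution satisfies the normal equations M^T (M x - b) = 0
    assert torch.mm(M.t(), torch.mm(M, x) - b).abs().max().item() < 1e-3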
5444  @staticmethod
5445  def _test_matrix_power(self, conv_fn):
5446  def run_test(M, sign=1):
5447  if sign == -1:
5448  M = M.inverse()
5449  MP2 = torch.matrix_power(M, 2)
5450  self.assertEqual(MP2, torch.matmul(M, M))
5451 
5452  MP3 = torch.matrix_power(M, 3)
5453  self.assertEqual(MP3, torch.matmul(MP2, M))
5454 
5455  MP4 = torch.matrix_power(M, 4)
5456  self.assertEqual(MP4, torch.matmul(MP2, MP2))
5457 
5458  MP6 = torch.matrix_power(M, 6)
5459  self.assertEqual(MP6, torch.matmul(MP3, MP3))
5460 
5461  MP0 = torch.matrix_power(M, 0)
5462  self.assertEqual(MP0, torch.eye(M.size(-2)).expand_as(M))
5463 
5464  # Single matrix
5465  M = conv_fn(torch.randn(5, 5))
5466  run_test(M)
5467 
5468  # Batch matrices
5469  M = conv_fn(torch.randn(3, 3, 3))
5470  run_test(M)
5471 
5472  # Many batch matrices
5473  M = conv_fn(torch.randn(2, 3, 3, 3))
5474  run_test(M)
5475 
5476  # This is for negative powers
5477  from common_utils import random_fullrank_matrix_distinct_singular_value
5478  M = conv_fn(random_fullrank_matrix_distinct_singular_value(5))
5479  run_test(M, sign=-1)
5480 
5481  M = conv_fn(random_fullrank_matrix_distinct_singular_value(3, 3))
5482  run_test(M, sign=-1)
5483 
5484  M = conv_fn(random_fullrank_matrix_distinct_singular_value(3, 2, 3))
5485  run_test(M, sign=-1)
5486 
5487  @skipIfNoLapack
5488  def test_matrix_power(self):
5489  self._test_matrix_power(self, conv_fn=lambda x: x)
5490 
5491  @staticmethod
5492  def _test_chain_matmul(self, cast):
5493  def product(matrices):
5494  for mat in matrices[1:]:
5495  matrices[0] = matrices[0].mm(mat)
5496  return matrices[0]
5497 
5498  def run_test(p, cast):
5499  matrices = []
5500  for (pi, pi_1) in zip(p[:-1], p[1:]):
5501  matrices.append(cast(torch.randn(pi, pi_1)))
5502  self.assertEqual(torch.chain_matmul(*matrices), product(matrices))
5503 
5504  run_test([10, 20, 30, 5], cast)
5505  run_test([15, 5, 10, 20, 25], cast)
5506 
5507  def test_chain_matmul(self):
5508  self._test_chain_matmul(self, cast=lambda x: x)
5509 
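# Illustrative sketch (not part of the original suite): torch.chain_matmul
# multiplies a sequence of 2-D tensors, picking the association order that
# minimizes scalar multiplications, so it agrees with repeated torch.mm up to
# round-off.
def _example_chain_matmul():
    import torch
    a, b, c = torch.randn(10, 20), torch.randn(20, 5), torch.randn(5, 30)
    out = torch.chain_matmul(a, b, c)
    ref = torch.mm(torch.mm(a, b), c)
    assert (out - ref).abs().max().item() < 1e-3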
5510  @staticmethod
5511  def _test_det_logdet_slogdet(self, conv_fn):
5512  def reference_det(M):
5513  # naive row reduction
5514  M = M.clone()
5515  l = M.size(0)
5516  multiplier = 1
5517  for i in range(l):
5518  if M[i, 0] != 0:
5519  if i != 0:
5520  M[0], M[i] = M[i], M[0]
5521  multiplier = -1
5522  break
5523  else:
5524  return 0
5525  for i in range(1, l):
5526  row = M[i]
5527  for j in range(i):
5528  row -= row[j] / M[j, j] * M[j]
5529  M[i] = row
5530  return M.diag().prod() * multiplier
5531 
5532  def test_single_det(M, target, desc):
5533  det = M.det()
5534  logdet = M.logdet()
5535  sdet, logabsdet = M.slogdet()
5536  self.assertEqual(det, target, 1e-7, '{} (det)'.format(desc))
5537  if det.item() < 0:
5538  self.assertTrue(logdet.item() != logdet.item(), '{} (logdet negative case)'.format(desc))
5539  self.assertTrue(sdet.item() == -1, '{} (slogdet sign negative case)'.format(desc))
5540  self.assertEqual(logabsdet.exp(), det.abs(), 1e-7, '{} (slogdet logabsdet negative case)'.format(desc))
5541  elif det.item() == 0:
5542  self.assertEqual(logdet.exp().item(), 0, 1e-7, '{} (logdet zero case)'.format(desc))
5543  self.assertTrue(sdet.item() == 0, '{} (slogdet sign zero case)'.format(desc))
5544  self.assertEqual(logabsdet.exp().item(), 0, 1e-7, '{} (slogdet logabsdet zero case)'.format(desc))
5545  else:
5546  self.assertEqual(logdet.exp(), det, 1e-7, '{} (logdet positive case)'.format(desc))
5547  self.assertTrue(sdet.item() == 1, '{} (slogdet sign positive case)'.format(desc))
5548  self.assertEqual(logabsdet.exp(), det, 1e-7, '{} (slogdet logabsdet positive case)'.format(desc))
5549 
5550  eye = conv_fn(torch.eye(5))
5551  test_single_det(eye, torch.tensor(1, dtype=eye.dtype), 'identity')
5552 
5553  # TODO: Remove when MAGMA 2.5.0 is built for CUDA 8 and CUDA 9.2
5554  is_cuda_8_92 = False
5555  if torch.cuda.is_available() and torch.version.cuda is not None:
5556  is_cuda_8_92 = any(x in torch.version.cuda for x in ['8.0', '9.2'])
5557 
5558  def test(M):
5559  assert M.size(0) >= 5, 'this helper fn assumes M to be at least 5x5'
5560  M = conv_fn(M)
5561 
5562  if M.is_cuda and is_cuda_8_92:
5563  return
5564 
5565  M_det = M.det()
5566  ref_M_det = reference_det(M)
5567 
5568  test_single_det(M, ref_M_det, 'basic')
5569  if abs(ref_M_det.item()) >= 1e-10: # skip singular
5570  test_single_det(M, M.inverse().det().pow_(-1), 'inverse')
5571  test_single_det(M, M.t().det(), 'transpose')
5572 
5573  for x in [0, 2, 4]:
5574  for scale in [-2, -0.1, 0, 10]:
5575  target = M_det * scale
5576  # dim 0
5577  M_clone = M.clone()
5578  M_clone[:, x] *= scale
5579  test_single_det(M_clone, target, 'scale a column')
5580  # dim 1
5581  M_clone = M.clone()
5582  M_clone[x, :] *= scale
5583  test_single_det(M_clone, target, 'scale a row')
5584 
5585  for x1, x2 in [(0, 3), (4, 1), (3, 2)]:
5586  assert x1 != x2, 'x1 and x2 need to be different for this test'
5587  target = M_det.clone().zero_()
5588  # dim 0
5589  M_clone = M.clone()
5590  M_clone[:, x2] = M_clone[:, x1]
5591  test_single_det(M_clone, target, 'two columns are same')
5592  # dim 1
5593  M_clone = M.clone()
5594  M_clone[x2, :] = M_clone[x1, :]
5595  test_single_det(M_clone, target, 'two rows are same')
5596 
5597  for scale1, scale2 in [(0.3, -1), (0, 2), (10, 0.1)]:
5598  target = -M_det * scale1 * scale2
5599  # dim 0
5600  M_clone = M.clone()
5601  t = M_clone[:, x1] * scale1
5602  M_clone[:, x1] += M_clone[:, x2] * scale2
5603  M_clone[:, x2] = t
5604  test_single_det(M_clone, target, 'exchanging columns')
5605  # dim 1
5606  M_clone = M.clone()
5607  t = M_clone[x1, :] * scale1
5608  M_clone[x1, :] += M_clone[x2, :] * scale2
5609  M_clone[x2, :] = t
5610  test_single_det(M_clone, target, 'exchanging rows')
5611 
5612  def get_random_mat_scale(n):
5613  # For matrices with values i.i.d. with 0 mean, unit variance, and
5614  # subexponential tail, we have:
5615  # E[log det(A^2)] \approx log((n-1)!)
5616  #
5617  # Notice:
5618  # log Var[det(A)] = log E[det(A^2)] >= E[log det(A^2)]
5619  #
5620  # So:
5621  # stddev[det(A)] >= sqrt( (n-1)! )
5622  #
5623  # We use this as an intuitive guideline to scale randomly generated
5624  # matrices so our closeness tests can work more robustly:
5625  # scale by sqrt( (n-1)! )^(-1/n) = ( (n-1)! )^(-1/(2n))
5626  #
5627  # source: https://arxiv.org/pdf/1112.0752.pdf
5628  return math.factorial(n - 1) ** (-1.0 / (2 * n))
5629 
5630  for n in [5, 10, 25]:
5631  scale = get_random_mat_scale(n)
5632  test(torch.randn(n, n) * scale)
5633  r = torch.randn(n, n) * scale
5634  # symmetric psd
5635  test(r.mm(r.t()))
5636  # symmetric pd
5637  r = torch.randn(n, n) * scale
5638  test(r.mm(r.t()) + torch.eye(n) * 1e-6)
5639  # symmetric
5640  r = torch.randn(n, n) * scale
5641  for i in range(n):
5642  for j in range(i):
5643  r[i, j] = r[j, i]
5644  test(r)
5645  # non-contiguous
5646  test((torch.randn(n, n, n + 1) * scale)[:, 2, 1:])
5647  # det = 0
5648  r = torch.randn(n, n) * scale
5649  u, s, v = r.svd()
5650  if reference_det(u) < 0:
5651  u = -u
5652  if reference_det(v) < 0:
5653  v = -v
5654  s[0] *= -1
5655  s[-1] = 0
5656  test(u.mm(s.diag()).mm(v))
5657 
5658  @skipIfNoLapack
5659  def test_det_logdet_slogdet(self):
5660  self._test_det_logdet_slogdet(self, lambda x: x)
5661 
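# Illustrative sketch (not part of the original suite): slogdet returns
# (sign, log|det|), which stays well defined where logdet would be NaN
# (negative determinant) or would over/underflow, as test_single_det above checks.
def _example_slogdet():
    import torch
    M = torch.Tensor([[0.0, 1.0], [1.0, 0.0]])     # determinant is exactly -1
    sign, logabsdet = M.slogdet()
    assert sign.item() == -1
    assert abs(logabsdet.item()) < 1e-6            # log|det| = log(1) = 0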
5662  @staticmethod
5663  def _test_fft_ifft_rfft_irfft(self, device='cpu'):
5664  def _test_complex(sizes, signal_ndim, prepro_fn=lambda x: x):
5665  x = prepro_fn(torch.randn(*sizes, device=device))
5666  for normalized in (True, False):
5667  res = x.fft(signal_ndim, normalized=normalized)
5668  rec = res.ifft(signal_ndim, normalized=normalized)
5669  self.assertEqual(x, rec, 1e-8, 'fft and ifft')
5670  res = x.ifft(signal_ndim, normalized=normalized)
5671  rec = res.fft(signal_ndim, normalized=normalized)
5672  self.assertEqual(x, rec, 1e-8, 'ifft and fft')
5673 
5674  def _test_real(sizes, signal_ndim, prepro_fn=lambda x: x):
5675  x = prepro_fn(torch.randn(*sizes, device=device))
5676  signal_numel = 1
5677  signal_sizes = x.size()[-signal_ndim:]
5678  for normalized, onesided in product((True, False), repeat=2):
5679  res = x.rfft(signal_ndim, normalized=normalized, onesided=onesided)
5680  if not onesided: # check Hermitian symmetry
5681  def test_one_sample(res, test_num=10):
5682  idxs_per_dim = [torch.LongTensor(test_num).random_(s).tolist() for s in signal_sizes]
5683  for idx in zip(*idxs_per_dim):
5684  reflected_idx = tuple((s - i) % s for i, s in zip(idx, res.size()))
5685  idx_val = res.__getitem__(idx)
5686  reflected_val = res.__getitem__(reflected_idx)
5687  self.assertEqual(idx_val[0], reflected_val[0], 'rfft hermitian symmetry on real part')
5688  self.assertEqual(idx_val[1], -reflected_val[1], 'rfft hermitian symmetry on imaginary part')
5689  if len(sizes) == signal_ndim:
5690  test_one_sample(res)
5691  else:
5692  output_non_batch_shape = res.size()[-(signal_ndim + 1):]
5693  flatten_batch_res = res.view(-1, *output_non_batch_shape)
5694  nb = flatten_batch_res.size(0)
5695  test_idxs = torch.LongTensor(min(nb, 4)).random_(nb)
5696  for test_idx in test_idxs.tolist():
5697  test_one_sample(flatten_batch_res[test_idx])
5698  # compare with C2C
5699  xc = torch.stack([x, torch.zeros_like(x)], -1)
5700  xc_res = xc.fft(signal_ndim, normalized=normalized)
5701  self.assertEqual(res, xc_res)
5702  test_input_signal_sizes = [signal_sizes]
5703  rec = res.irfft(signal_ndim, normalized=normalized,
5704  onesided=onesided, signal_sizes=signal_sizes)
5705  self.assertEqual(x, rec, 1e-8, 'rfft and irfft')
5706  if not onesided: # check that we can use C2C ifft
5707  rec = res.ifft(signal_ndim, normalized=normalized)
5708  self.assertEqual(x, rec.select(-1, 0), 1e-8, 'twosided rfft and ifft real')
5709  self.assertEqual(rec.select(-1, 1).data.abs().mean(), 0, 1e-8, 'twosided rfft and ifft imaginary')
5710 
5711  # contiguous case
5712  _test_real((100,), 1)
5713  _test_real((10, 1, 10, 100), 1)
5714  _test_real((100, 100), 2)
5715  _test_real((2, 2, 5, 80, 60), 2)
5716  _test_real((50, 40, 70), 3)
5717  _test_real((30, 1, 50, 25, 20), 3)
5718 
5719  _test_complex((100, 2), 1)
5720  _test_complex((100, 100, 2), 1)
5721  _test_complex((100, 100, 2), 2)
5722  _test_complex((1, 20, 80, 60, 2), 2)
5723  _test_complex((50, 40, 70, 2), 3)
5724  _test_complex((6, 5, 50, 25, 20, 2), 3)
5725 
5726  # non-contiguous case
5727  _test_real((165,), 1, lambda x: x.narrow(0, 25, 100)) # input is not aligned to complex type
5728  _test_real((100, 100, 3), 1, lambda x: x[:, :, 0])
5729  _test_real((100, 100), 2, lambda x: x.t())
5730  _test_real((20, 100, 10, 10), 2, lambda x: x.view(20, 100, 100)[:, :60])
5731  _test_real((65, 80, 115), 3, lambda x: x[10:60, 13:53, 10:80])
5732  _test_real((30, 20, 50, 25), 3, lambda x: x.transpose(1, 2).transpose(2, 3))
5733 
5734  _test_complex((2, 100), 1, lambda x: x.t())
5735  _test_complex((100, 2), 1, lambda x: x.expand(100, 100, 2))
5736  _test_complex((300, 200, 3), 2, lambda x: x[:100, :100, 1:]) # input is not aligned to complex type
5737  _test_complex((20, 90, 110, 2), 2, lambda x: x[:, 5:85].narrow(2, 5, 100))
5738  _test_complex((40, 60, 3, 80, 2), 3, lambda x: x.transpose(2, 0).select(0, 2)[5:55, :, 10:])
5739  _test_complex((30, 55, 50, 22, 2), 3, lambda x: x[:, 3:53, 15:40, 1:21])
5740 
5741  # non-contiguous with strides not representable as aligned with complex type
5742  _test_complex((50,), 1, lambda x: x.as_strided([5, 5, 2], [3, 2, 1]))
5743  _test_complex((50,), 1, lambda x: x.as_strided([5, 5, 2], [4, 2, 2]))
5744  _test_complex((50,), 1, lambda x: x.as_strided([5, 5, 2], [4, 3, 1]))
5745  _test_complex((50,), 2, lambda x: x.as_strided([5, 5, 2], [3, 3, 1]))
5746  _test_complex((50,), 2, lambda x: x.as_strided([5, 5, 2], [4, 2, 2]))
5747  _test_complex((50,), 2, lambda x: x.as_strided([5, 5, 2], [4, 3, 1]))
5748 
5749  @unittest.skipIf(not TEST_MKL, "PyTorch is built without MKL support")
5750  def test_fft_ifft_rfft_irfft(self):
5751  self._test_fft_ifft_rfft_irfft(self)
5752 
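# Illustrative sketch (not part of the original suite): the legacy rfft stores
# complex values in a trailing dimension of size 2, and with the default
# onesided=True only the non-redundant half of the spectrum is kept, so irfft
# needs signal_sizes to recover an odd-length input exactly.
def _example_rfft_roundtrip():
    import torch
    x = torch.randn(101)
    freq = x.rfft(1)                               # shape (51, 2): half spectrum
    rec = freq.irfft(1, signal_sizes=x.size())
    assert (x - rec).abs().max().item() < 1e-3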
5753  @staticmethod
5754  def _test_stft(self, device='cpu'):
5755  if not TEST_LIBROSA:
5756  raise unittest.SkipTest('librosa not found')
5757 
5758  def librosa_stft(x, n_fft, hop_length, win_length, window, center):
5759  if window is None:
5760  window = np.ones(n_fft if win_length is None else win_length)
5761  else:
5762  window = window.cpu().numpy()
5763  input_1d = x.dim() == 1
5764  if input_1d:
5765  x = x.view(1, -1)
5766  result = []
5767  for xi in x:
5768  ri = librosa.stft(xi.cpu().numpy(), n_fft, hop_length, win_length, window, center=center)
5769  result.append(torch.from_numpy(np.stack([ri.real, ri.imag], -1)))
5770  result = torch.stack(result, 0)
5771  if input_1d:
5772  result = result[0]
5773  return result
5774 
5775  def _test(sizes, n_fft, hop_length=None, win_length=None, win_sizes=None,
5776  center=True, expected_error=None):
5777  x = torch.randn(*sizes, device=device)
5778  if win_sizes is not None:
5779  window = torch.randn(*win_sizes, device=device)
5780  else:
5781  window = None
5782  if expected_error is None:
5783  result = x.stft(n_fft, hop_length, win_length, window, center=center)
5784  ref_result = librosa_stft(x, n_fft, hop_length, win_length, window, center)
5785  self.assertEqual(result, ref_result, 7e-6, 'stft comparison against librosa')
5786  else:
5787  self.assertRaises(expected_error,
5788  lambda: x.stft(n_fft, hop_length, win_length, window, center=center))
5789 
5790  for center in [True, False]:
5791  _test((10,), 7, center=center)
5792  _test((10, 4000), 1024, center=center)
5793 
5794  _test((10,), 7, 2, center=center)
5795  _test((10, 4000), 1024, 512, center=center)
5796 
5797  _test((10,), 7, 2, win_sizes=(7,), center=center)
5798  _test((10, 4000), 1024, 512, win_sizes=(1024,), center=center)
5799 
5800  # spectral oversample
5801  _test((10,), 7, 2, win_length=5, center=center)
5802  _test((10, 4000), 1024, 512, win_length=100, center=center)
5803 
5804  _test((10, 4, 2), 1, 1, expected_error=RuntimeError)
5805  _test((10,), 11, 1, center=False, expected_error=RuntimeError)
5806  _test((10,), -1, 1, expected_error=RuntimeError)
5807  _test((10,), 3, win_length=5, expected_error=RuntimeError)
5808  _test((10,), 5, 4, win_sizes=(11,), expected_error=RuntimeError)
5809  _test((10,), 5, 4, win_sizes=(1, 1), expected_error=RuntimeError)
5810 
5811  def test_stft(self):
5812  self._test_stft(self)
5813 
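# Illustrative sketch (not part of the original suite): with the signature used
# above, x.stft(n_fft, hop_length, ...) returns a real tensor of shape
# (n_fft // 2 + 1, n_frames, 2) for a 1-D input, holding a (real, imag) pair per
# frequency bin and frame.
def _example_stft_shape():
    import torch
    x = torch.randn(1000)
    spec = x.stft(128, 64, center=True)
    assert spec.size(0) == 128 // 2 + 1 and spec.size(-1) == 2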
5814  @unittest.skip("Not implemented yet")
5815  def test_conv2(self):
5816  x = torch.rand(math.floor(torch.uniform(50, 100)), math.floor(torch.uniform(50, 100)))
5817  k = torch.rand(math.floor(torch.uniform(10, 20)), math.floor(torch.uniform(10, 20)))
5818  imvc = torch.conv2(x, k)
5819  imvc2 = torch.conv2(x, k, 'V')
5820  imfc = torch.conv2(x, k, 'F')
5821 
5822  ki = k.clone()
5823  ks = k.storage()
5824  kis = ki.storage()
5825  for i in range(ks.size() - 1, 0, -1):
5826  kis[ks.size() - i + 1] = ks[i]
5827  # for i=ks.size(), 1, -1 do kis[ks.size()-i+1]=ks[i] end
5828  imvx = torch.xcorr2(x, ki)
5829  imvx2 = torch.xcorr2(x, ki, 'V')
5830  imfx = torch.xcorr2(x, ki, 'F')
5831 
5832  self.assertEqual(imvc, imvc2, 0, 'torch.conv2')
5833  self.assertEqual(imvc, imvx, 0, 'torch.conv2')
5834  self.assertEqual(imvc, imvx2, 0, 'torch.conv2')
5835  self.assertEqual(imfc, imfx, 0, 'torch.conv2')
5836  self.assertLessEqual(abs(x.dot(x) - torch.xcorr2(x, x)[0][0]), 1e-10, 'torch.conv2')
5837 
5838  xx = torch.Tensor(2, x.size(1), x.size(2))
5839  xx[1].copy_(x)
5840  xx[2].copy_(x)
5841  kk = torch.Tensor(2, k.size(1), k.size(2))
5842  kk[1].copy_(k)
5843  kk[2].copy_(k)
5844 
5845  immvc = torch.conv2(xx, kk)
5846  immvc2 = torch.conv2(xx, kk, 'V')
5847  immfc = torch.conv2(xx, kk, 'F')
5848 
5849  self.assertEqual(immvc[0], immvc[1], 0, 'torch.conv2')
5850  self.assertEqual(immvc[0], imvc, 0, 'torch.conv2')
5851  self.assertEqual(immvc2[0], imvc2, 0, 'torch.conv2')
5852  self.assertEqual(immfc[0], immfc[1], 0, 'torch.conv2')
5853  self.assertEqual(immfc[0], imfc, 0, 'torch.conv2')
5854 
5855  @unittest.skip("Not implemented yet")
5856  def test_conv3(self):
5857  x = torch.rand(math.floor(torch.uniform(20, 40)),
5858  math.floor(torch.uniform(20, 40)),
5859  math.floor(torch.uniform(20, 40)))
5860  k = torch.rand(math.floor(torch.uniform(5, 10)),
5861  math.floor(torch.uniform(5, 10)),
5862  math.floor(torch.uniform(5, 10)))
5863  imvc = torch.conv3(x, k)
5864  imvc2 = torch.conv3(x, k, 'V')
5865  imfc = torch.conv3(x, k, 'F')
5866 
5867  ki = k.clone()
5868  ks = k.storage()
5869  kis = ki.storage()
5870  for i in range(ks.size() - 1, 0, -1):
5871  kis[ks.size() - i + 1] = ks[i]
5872  imvx = torch.xcorr3(x, ki)
5873  imvx2 = torch.xcorr3(x, ki, 'V')
5874  imfx = torch.xcorr3(x, ki, 'F')
5875 
5876  self.assertEqual(imvc, imvc2, 0, 'torch.conv3')
5877  self.assertEqual(imvc, imvx, 0, 'torch.conv3')
5878  self.assertEqual(imvc, imvx2, 0, 'torch.conv3')
5879  self.assertEqual(imfc, imfx, 0, 'torch.conv3')
5880  self.assertLessEqual(abs(x.dot(x) - torch.xcorr3(x, x)[0][0][0]), 4e-10, 'torch.conv3')
5881 
5882  xx = torch.Tensor(2, x.size(1), x.size(2), x.size(3))
5883  xx[1].copy_(x)
5884  xx[2].copy_(x)
5885  kk = torch.Tensor(2, k.size(1), k.size(2), k.size(3))
5886  kk[1].copy_(k)
5887  kk[2].copy_(k)
5888 
5889  immvc = torch.conv3(xx, kk)
5890  immvc2 = torch.conv3(xx, kk, 'V')
5891  immfc = torch.conv3(xx, kk, 'F')
5892 
5893  self.assertEqual(immvc[0], immvc[1], 0, 'torch.conv3')
5894  self.assertEqual(immvc[0], imvc, 0, 'torch.conv3')
5895  self.assertEqual(immvc2[0], imvc2, 0, 'torch.conv3')
5896  self.assertEqual(immfc[0], immfc[1], 0, 'torch.conv3')
5897  self.assertEqual(immfc[0], imfc, 0, 'torch.conv3')
5898 
5899  @unittest.skip("Not implemented yet")
5900  def _test_conv_corr_eq(self, fn, fn_2_to_3):
5901  ix = math.floor(random.randint(20, 40))
5902  iy = math.floor(random.randint(20, 40))
5903  iz = math.floor(random.randint(20, 40))
5904  kx = math.floor(random.randint(5, 10))
5905  ky = math.floor(random.randint(5, 10))
5906  kz = math.floor(random.randint(5, 10))
5907 
5908  x = torch.rand(ix, iy, iz)
5909  k = torch.rand(kx, ky, kz)
5910 
5911  o3 = fn(x, k)
5912  o32 = torch.zeros(o3.size())
5913  fn_2_to_3(x, k, o3, o32)
5914  self.assertEqual(o3, o32)
5915 
5916  @unittest.skip("Not implemented yet")
5917  def test_xcorr3_xcorr2_eq(self):
5918  def reference(x, k, o3, o32):
5919  for i in range(o3.size(1)):
5920  for j in range(k.size(1)):
5921  o32[i].add(torch.xcorr2(x[i + j - 1], k[j]))
5922  self._test_conv_corr_eq(torch.xcorr3, reference)
5923 
5924  @unittest.skip("Not implemented yet")
5925  def test_xcorr3_xcorr2_eq_full(self):
5926  def reference(x, k, o3, o32):
5927  for i in range(x.size(1)):
5928  for j in range(k.size(1)):
5929  o32[i].add(torch.xcorr2(x[i], k[k.size(1) - j + 1], 'F'))
5930  self._test_conv_corr_eq(lambda x, k: torch.xcorr3(x, k, 'F'), reference)
5931 
5932  @unittest.skip("Not implemented yet")
5933  def test_conv3_conv2_eq_valid(self):
5934  def reference(x, k, o3, o32):
5935  for i in range(o3.size(1)):
5936  for j in range(k.size(1)):
5937  o32[i].add(torch.conv2(x[i + j - 1], k[k.size(1) - j + 1]))
5938  self._test_conv_corr_eq(torch.conv3, reference)
5939 
5940  @unittest.skip("Not implemented yet")
5941  def test_fconv3_fconv2_eq(self):
5942  def reference(x, k, o3, o32):
5943  for i in range(o3.size(1)):
5944  for j in range(k.size(1)):
5945  o32[i + j - 1].add(torch.conv2(x[i], k[j], 'F'))
5946  self._test_conv_corr_eq(lambda x, k: torch.conv3(x, k, 'F'), reference)
5947 
5948  def test_logical(self):
5949  x = torch.rand(100, 100) * 2 - 1
5950 
5951  xgt = torch.gt(x, 1)
5952  xlt = torch.lt(x, 1)
5953 
5954  xeq = torch.eq(x, 1)
5955  xne = torch.ne(x, 1)
5956 
5957  neqs = xgt + xlt
5958  all = neqs + xeq
5959  self.assertEqual(neqs.long().sum(), xne.long().sum(), 0)
5960  self.assertEqual(x.nelement(), all.long().sum())
5961 
5962  def test_isfinite(self):
5963  x = torch.Tensor([1, inf, 2, -inf, nan, -10])
5964  self.assertEqual(torch.isfinite(x), torch.ByteTensor([1, 0, 1, 0, 0, 1]))
5965 
5966  def test_isfinite_int(self):
5967  x = torch.tensor([1, 2, 3])
5968  self.assertEqual(torch.isfinite(x), torch.ByteTensor([1, 1, 1]))
5969 
5970  @staticmethod
5971  def _test_isinf(self, cast):
5972  t1 = cast(torch.Tensor([1, inf, 2, -inf, nan]))
5973  t2 = cast(torch.ByteTensor([1, 2, 3]))
5974  t3 = cast(torch.CharTensor([1, 2, 3]))
5975  t4 = cast(torch.ShortTensor([1, 2, 3]))
5976  t5 = cast(torch.IntTensor([1, 2, 3]))
5977  t6 = cast(torch.LongTensor([1, 2, 3]))
5978  self.assertEqual(torch.isinf(t1), cast(torch.ByteTensor([0, 1, 0, 1, 0])))
5979  self.assertEqual(torch.isinf(t2), cast(torch.ByteTensor([0, 0, 0])))
5980  self.assertEqual(torch.isinf(t3), cast(torch.ByteTensor([0, 0, 0])))
5981  self.assertEqual(torch.isinf(t4), cast(torch.ByteTensor([0, 0, 0])))
5982  self.assertEqual(torch.isinf(t5), cast(torch.ByteTensor([0, 0, 0])))
5983  self.assertEqual(torch.isinf(t6), cast(torch.ByteTensor([0, 0, 0])))
5984 
5985  def test_isinf(self):
5986  self._test_isinf(self, lambda t: t)
5987 
5988  def test_isnan(self):
5989  x = torch.Tensor([1, nan, 2])
5990  self.assertEqual(torch.isnan(x), torch.ByteTensor([0, 1, 0]))
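
 # A minimal usage sketch of the three predicates checked above, assuming the
 # default CPU float tensor type used throughout this file:
 #
 #     x = torch.Tensor([1, inf, 2, -inf, nan, -10])
 #     torch.isfinite(x)   # ByteTensor([1, 0, 1, 0, 0, 1])
 #     torch.isinf(x)      # ByteTensor([0, 1, 0, 1, 0, 0])
 #     torch.isnan(x)      # ByteTensor([0, 0, 0, 0, 1, 0])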
5991 
5992  def test_RNGState(self):
5993  state = torch.get_rng_state()
5994  stateCloned = state.clone()
5995  before = torch.rand(1000)
5996 
5997  self.assertEqual(state.ne(stateCloned).long().sum(), 0, 0)
5998 
5999  torch.set_rng_state(state)
6000  after = torch.rand(1000)
6001  self.assertEqual(before, after, 0)
6002 
6003  def test_RNGStateAliasing(self):
6004  # Fork the random number stream at this point
6005  gen = torch.Generator()
6006  gen.set_state(torch.get_rng_state())
6007  self.assertEqual(gen.get_state(), torch.get_rng_state())
6008 
6009  target_value = torch.rand(1000)
6010  # Dramatically alter the internal state of the main generator
6011  _ = torch.rand(100000)
6012  forked_value = torch.rand(1000, generator=gen)
6013  self.assertEqual(target_value, forked_value, 0, "RNG has not forked correctly.")
6014 
6015  def test_RNG_after_pickle(self):
6016  torch.random.manual_seed(100)
6017  before = torch.rand(10)
6018 
6019  torch.random.manual_seed(100)
6020  buf = io.BytesIO()
6021  tensor = torch.Tensor([1, 2, 3])
6022  ForkingPickler(buf, pickle.HIGHEST_PROTOCOL).dump(tensor)
6023  after = torch.rand(10)
6024 
6025  self.assertEqual(before, after, 0)
6026 
6027  def test_boxMullerState(self):
6028  torch.manual_seed(123)
6029  odd_number = 101
6030  seeded = torch.randn(odd_number)
6031  state = torch.get_rng_state()
6032  midstream = torch.randn(odd_number)
6033  torch.set_rng_state(state)
6034  repeat_midstream = torch.randn(odd_number)
6035  torch.manual_seed(123)
6036  reseeded = torch.randn(odd_number)
6037  self.assertEqual(midstream, repeat_midstream, 0,
6038  'get_rng_state/set_rng_state not generating same sequence of normally distributed numbers')
6039  self.assertEqual(seeded, reseeded, 0,
6040  'repeated calls to manual_seed not generating same sequence of normally distributed numbers')
6041 
6042  def test_manual_seed(self):
6043  rng_state = torch.get_rng_state()
6044  torch.manual_seed(2)
6045  x = torch.randn(100)
6046  self.assertEqual(torch.initial_seed(), 2)
6047  torch.manual_seed(2)
6048  y = torch.randn(100)
6049  self.assertEqual(x, y)
6050  torch.set_rng_state(rng_state)
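
 # Sketch of the RNG round trip these tests depend on: capturing the global
 # state, drawing, restoring the state, and drawing again yields the same
 # samples.
 #
 #     state = torch.get_rng_state()
 #     a = torch.randn(5)
 #     torch.set_rng_state(state)
 #     b = torch.randn(5)    # a and b match element-wise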
6051 
6052  @staticmethod
6053  def _test_cholesky(self, cast):
6054  x = cast(torch.rand(10, 10) + 1e-1)
6055  A = torch.mm(x, x.t())
6056 
6057  # default Case
6058  C = torch.cholesky(A)
6059  B = torch.mm(C, C.t())
6060  self.assertEqual(A, B, 1e-14)
6061 
6062  # test Upper Triangular
6063  U = torch.cholesky(A, True)
6064  B = torch.mm(U.t(), U)
6065  self.assertEqual(A, B, 1e-14, 'cholesky (upper) did not allow rebuilding the original matrix')
6066 
6067  # test Lower Triangular
6068  L = torch.cholesky(A, False)
6069  B = torch.mm(L, L.t())
6070  self.assertEqual(A, B, 1e-14, 'cholesky (lower) did not allow rebuilding the original matrix')
6071 
6072  @skipIfNoLapack
6073  def test_cholesky(self):
6074  self._test_cholesky(self, lambda t: t)
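
 # Sketch of the property test_cholesky verifies: for a symmetric positive
 # definite A (here built as x @ x.t() for a square full-rank x), the Cholesky
 # factor reconstructs A.
 #
 #     L = torch.cholesky(A, False)   # lower-triangular factor
 #     torch.mm(L, L.t())             # equals A up to round-off
 #     U = torch.cholesky(A, True)    # upper-triangular factor
 #     torch.mm(U.t(), U)             # equals A up to round-off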
6075 
6076  @staticmethod
6077  def _test_cholesky_batched(self, cast):
6078  from common_utils import random_symmetric_pd_matrix
6079 
6080  def cholesky_test_helper(n, batch_dims, cast, upper):
6081  A = cast(random_symmetric_pd_matrix(n, *batch_dims))
6082  cholesky_exp = torch.stack([m.cholesky(upper=upper) for m in A.reshape(-1, n, n)])
6083  cholesky_exp = cholesky_exp.reshape_as(A)
6084  self.assertEqual(cholesky_exp, torch.cholesky(A, upper=upper))
6085 
6086  for upper, batchsize in product([True, False], [(3,), (3, 4), (2, 3, 4)]):
6087  cholesky_test_helper(3, batchsize, cast, upper)
6088 
6089  @skipIfNoLapack
6090  def test_cholesky_batched(self):
6091  self._test_cholesky_batched(self, lambda t: t)
6092 
6093  @staticmethod
6094  def _test_cholesky_solve(self, cast):
6095  a = torch.Tensor(((6.80, -2.11, 5.66, 5.97, 8.23),
6096  (-6.05, -3.30, 5.36, -4.44, 1.08),
6097  (-0.45, 2.58, -2.70, 0.27, 9.04),
6098  (8.32, 2.71, 4.35, -7.17, 2.14),
6099  (-9.67, -5.14, -7.26, 6.08, -6.87))).t()
6100  b = torch.Tensor(((4.02, 6.19, -8.22, -7.57, -3.03),
6101  (-1.56, 4.00, -8.67, 1.75, 2.86),
6102  (9.81, -4.09, -4.57, -8.61, 8.99))).t()
6103 
6104  # make sure 'a' is symmetric PSD
6105  a = torch.mm(a, a.t())
6106  a, b = cast(a), cast(b)
6107 
6108  # upper Triangular Test
6109  U = torch.cholesky(a, True)
6110  x = torch.cholesky_solve(b, U, True)
6111  self.assertLessEqual(b.dist(torch.mm(a, x)), 1e-12)
6112 
6113  # lower Triangular Test
6114  L = torch.cholesky(a, False)
6115  x = torch.cholesky_solve(b, L, False)
6116  self.assertLessEqual(b.dist(torch.mm(a, x)), 1e-12)
6117 
6118  # default arg Test
6119  L_def = torch.cholesky(a)
6120  x_def = torch.cholesky_solve(b, L_def)
6121  self.assertLessEqual(b.dist(torch.mm(a, x_def)), 1e-12)
6122 
6123  @skipIfNoLapack
6124  def test_cholesky_solve(self):
6125  self._test_cholesky_solve(self, lambda t: t)
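
 # Sketch of the cholesky_solve pattern used above: once 'a' has been factored,
 # a @ x = b can be solved without refactorizing.
 #
 #     L = torch.cholesky(a)            # lower factor by default
 #     x = torch.cholesky_solve(b, L)   # solves a @ x = b
 #     torch.mm(a, x)                   # equals b up to round-off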
6126 
6127  @staticmethod
6128  def _test_cholesky_solve_batched(self, cast):
6129  from common_utils import random_symmetric_pd_matrix
6130 
6131  def cholesky_solve_test_helper(A_dims, b_dims, cast, upper):
6132  A = cast(random_symmetric_pd_matrix(*A_dims))
6133  L = torch.cholesky(A, upper)
6134  b = cast(torch.randn(*b_dims))
6135  return A, L, b
6136 
6137  for upper in [True, False]:
6138  # test against cholesky_solve: one batch with both choices of upper
6139  A, L, b = cholesky_solve_test_helper((5, 1), (1, 5, 10), cast, upper)
6140  x_exp = torch.cholesky_solve(b.squeeze(0), L.squeeze(0), upper=upper)
6141  x = torch.cholesky_solve(b, L, upper=upper)
6142  self.assertEqual(x, x_exp.unsqueeze(0))
6143 
6144  # test against cholesky_solve in a loop: four batches with both choices of upper
6145  A, L, b = cholesky_solve_test_helper((5, 4), (4, 5, 10), cast, upper)
6146  x_exp_list = []
6147  for i in range(4):
6148  x_exp = torch.cholesky_solve(b[i], L[i], upper=upper)
6149  x_exp_list.append(x_exp)
6150  x_exp = torch.stack(x_exp_list)
6151 
6152  x = torch.cholesky_solve(b, L, upper=upper)
6153  self.assertEqual(x, x_exp)
6154 
6155  # basic correctness test
6156  A, L, b = cholesky_solve_test_helper((5, 3), (3, 5, 10), cast, upper)
6157  x = torch.cholesky_solve(b, L, upper)
6158  self.assertLessEqual(b.dist(torch.matmul(A, x)), 1e-12)
6159 
6160  # Test non-contiguous inputs.
6161  if not TEST_NUMPY:
6162  return
6163  import numpy
6164  from numpy.linalg import solve
6165  A = random_symmetric_pd_matrix(2, 2)
6166  b = torch.randn(2, 2, 2)
6167  x_exp = torch.Tensor(solve(A.permute(0, 2, 1).numpy(), b.permute(2, 1, 0).numpy()))
6168  A = cast(A).permute(0, 2, 1)
6169  b = cast(b).permute(2, 1, 0)
6170  assert not A.is_contiguous() and not b.is_contiguous(), "contiguous inputs"
6171  L = torch.cholesky(A, upper)
6172  x = torch.cholesky_solve(b, L, upper=upper)
6173  self.assertEqual(x, cast(x_exp))
6174 
6175  @skipIfNoLapack
6176  def test_cholesky_solve_batched(self):
6177  self._test_cholesky_solve_batched(self, lambda t: t)
6178 
6179  @staticmethod
6180  def _test_cholesky_solve_batched_dims(self, cast):
6181  if not TEST_NUMPY:
6182  return
6183 
6184  from numpy.linalg import solve
6185  from common_utils import random_symmetric_pd_matrix
6186 
6187  def run_test(A_dims, b_dims, cast, upper):
6188  A = random_symmetric_pd_matrix(*A_dims)
6189  b = torch.randn(*b_dims)
6190  x_exp = torch.Tensor(solve(A.numpy(), b.numpy()))
6191  A, b = cast(A), cast(b)
6192  L = torch.cholesky(A, upper)
6193  x = torch.cholesky_solve(b, L, upper=upper)
6194  self.assertEqual(x, cast(x_exp))
6195 
6196  for upper in [True, False]:
6197  # test against numpy.linalg.solve
6198  run_test((4, 2, 1, 3), (2, 1, 3, 4, 6), cast, upper) # no broadcasting
6199  run_test((4, 2, 1, 3), (4, 6), cast, upper) # broadcasting b
6200  run_test((4,), (2, 1, 3, 4, 2), cast, upper) # broadcasting A
6201  run_test((4, 1, 3, 1), (2, 1, 3, 4, 5), cast, upper) # broadcasting A & b
6202 
6203  @skipIfNoLapack
6204  def test_cholesky_solve_batched_dims(self):
6205  self._test_cholesky_solve_batched_dims(self, lambda t: t)
6206 
6207  @skipIfNoLapack
6208  def test_potri(self):
6209  a = torch.Tensor(((6.80, -2.11, 5.66, 5.97, 8.23),
6210  (-6.05, -3.30, 5.36, -4.44, 1.08),
6211  (-0.45, 2.58, -2.70, 0.27, 9.04),
6212  (8.32, 2.71, 4.35, -7.17, 2.14),
6213  (-9.67, -5.14, -7.26, 6.08, -6.87))).t()
6214 
6215  # make sure 'a' is symmetric PSD
6216  a = torch.mm(a, a.t())
6217 
6218  # compute inverse directly
6219  inv0 = torch.inverse(a)
6220 
6221  # default case
6222  chol = torch.cholesky(a)
6223  inv1 = torch.potri(chol, False)
6224  self.assertLessEqual(inv0.dist(inv1), 1e-12)
6225 
6226  # upper Triangular Test
6227  chol = torch.cholesky(a, True)
6228  inv1 = torch.potri(chol, True)
6229  self.assertLessEqual(inv0.dist(inv1), 1e-12)
6230 
6231  # lower Triangular Test
6232  chol = torch.cholesky(a, False)
6233  inv1 = torch.potri(chol, False)
6234  self.assertLessEqual(inv0.dist(inv1), 1e-12)
6235 
6236  @skipIfNoLapack
6237  def test_pstrf(self):
6238  def checkPsdCholesky(a, uplo, inplace):
6239  if inplace:
6240  u = torch.empty_like(a)
6241  piv = a.new(a.size(0)).int()
6242  kwargs = {'out': (u, piv)}
6243  else:
6244  kwargs = {}
6245  args = [a]
6246 
6247  if uplo is not None:
6248  args += [uplo]
6249 
6250  u, piv = torch.pstrf(*args, **kwargs)
6251 
6252  if uplo is False:
6253  a_reconstructed = torch.mm(u, u.t())
6254  else:
6255  a_reconstructed = torch.mm(u.t(), u)
6256 
6257  piv = piv.long()
6258  a_permuted = a.index_select(0, piv).index_select(1, piv)
6259  self.assertEqual(a_permuted, a_reconstructed, 1e-14)
6260 
6261  dimensions = ((5, 1), (5, 3), (5, 5), (10, 10))
6262  for dim in dimensions:
6263  m = torch.Tensor(*dim).uniform_()
6264  a = torch.mm(m, m.t())
6265  # add a small number to the diagonal to make the matrix numerically positive semidefinite
6266  for i in range(m.size(0)):
6267  a[i][i] = a[i][i] + 1e-7
6268  for inplace in (True, False):
6269  for uplo in (None, True, False):
6270  checkPsdCholesky(a, uplo, inplace)
6271 
6272  def test_numel(self):
6273  b = torch.ByteTensor(3, 100, 100)
6274  self.assertEqual(b.nelement(), 3 * 100 * 100)
6275  self.assertEqual(b.numel(), 3 * 100 * 100)
6276 
6277  def _consecutive(self, size, start=1):
6278  sequence = torch.ones(int(torch.Tensor(size).prod(0))).cumsum(0)
6279  sequence.add_(start - 1)
6280  return sequence.resize_(*size)
6281 
6282  @staticmethod
6283  def _test_index(self, conv_fn):
6284 
6285  def consec(size, start=1):
6286  sequence = torch.ones(int(torch.Tensor(size).prod(0))).cumsum(0)
6287  sequence.add_(start - 1)
6288  return sequence.view(*size)
6289 
6290  reference = conv_fn(consec((3, 3, 3)))
6291 
6292  # empty tensor indexing
6293  self.assertEqual(reference[conv_fn(torch.LongTensor())], reference.new(0, 3, 3))
6294 
6295  self.assertEqual(reference[0], consec((3, 3)), 0)
6296  self.assertEqual(reference[1], consec((3, 3), 10), 0)
6297  self.assertEqual(reference[2], consec((3, 3), 19), 0)
6298  self.assertEqual(reference[0, 1], consec((3,), 4), 0)
6299  self.assertEqual(reference[0:2], consec((2, 3, 3)), 0)
6300  self.assertEqual(reference[2, 2, 2], 27, 0)
6301  self.assertEqual(reference[:], consec((3, 3, 3)), 0)
6302 
6303  # indexing with Ellipsis
6304  self.assertEqual(reference[..., 2], torch.Tensor([[3, 6, 9],
6305  [12, 15, 18],
6306  [21, 24, 27]]), 0)
6307  self.assertEqual(reference[0, ..., 2], torch.Tensor([3, 6, 9]), 0)
6308  self.assertEqual(reference[..., 2], reference[:, :, 2], 0)
6309  self.assertEqual(reference[0, ..., 2], reference[0, :, 2], 0)
6310  self.assertEqual(reference[0, 2, ...], reference[0, 2], 0)
6311  self.assertEqual(reference[..., 2, 2, 2], 27, 0)
6312  self.assertEqual(reference[2, ..., 2, 2], 27, 0)
6313  self.assertEqual(reference[2, 2, ..., 2], 27, 0)
6314  self.assertEqual(reference[2, 2, 2, ...], 27, 0)
6315  self.assertEqual(reference[...], reference, 0)
6316 
6317  reference_5d = conv_fn(consec((3, 3, 3, 3, 3)))
6318  self.assertEqual(reference_5d[..., 1, 0], reference_5d[:, :, :, 1, 0], 0)
6319  self.assertEqual(reference_5d[2, ..., 1, 0], reference_5d[2, :, :, 1, 0], 0)
6320  self.assertEqual(reference_5d[2, 1, 0, ..., 1], reference_5d[2, 1, 0, :, 1], 0)
6321  self.assertEqual(reference_5d[...], reference_5d, 0)
6322 
6323  # LongTensor indexing
6324  reference = conv_fn(consec((5, 5, 5)))
6325  idx = conv_fn(torch.LongTensor([2, 4]))
6326  self.assertEqual(reference[idx], torch.stack([reference[2], reference[4]]))
6327  # TODO: enable once indexing is implemented like in numpy
6328  # self.assertEqual(reference[2, idx], torch.stack([reference[2, 2], reference[2, 4]]))
6329  # self.assertEqual(reference[3, idx, 1], torch.stack([reference[3, 2], reference[3, 4]])[:, 1])
6330 
6331  # None indexing
6332  self.assertEqual(reference[2, None], reference[2].unsqueeze(0))
6333  self.assertEqual(reference[2, None, None], reference[2].unsqueeze(0).unsqueeze(0))
6334  self.assertEqual(reference[2:4, None], reference[2:4].unsqueeze(1))
6335  self.assertEqual(reference[None, 2, None, None], reference.unsqueeze(0)[:, 2].unsqueeze(0).unsqueeze(0))
6336  self.assertEqual(reference[None, 2:5, None, None], reference.unsqueeze(0)[:, 2:5].unsqueeze(2).unsqueeze(2))
6337 
6338  # indexing 0-length slice
6339  self.assertEqual(torch.empty(0, 5, 5), reference[slice(0)])
6340  self.assertEqual(torch.empty(0, 5), reference[slice(0), 2])
6341  self.assertEqual(torch.empty(0, 5), reference[2, slice(0)])
6342  self.assertEqual(torch.tensor([]), reference[2, 1:1, 2])
6343 
6344  # indexing with step
6345  reference = consec((10, 10, 10))
6346  self.assertEqual(reference[1:5:2], torch.stack([reference[1], reference[3]], 0))
6347  self.assertEqual(reference[1:6:2], torch.stack([reference[1], reference[3], reference[5]], 0))
6348  self.assertEqual(reference[1:9:4], torch.stack([reference[1], reference[5]], 0))
6349  self.assertEqual(reference[2:4, 1:5:2], torch.stack([reference[2:4, 1], reference[2:4, 3]], 1))
6350  self.assertEqual(reference[3, 1:6:2], torch.stack([reference[3, 1], reference[3, 3], reference[3, 5]], 0))
6351  self.assertEqual(reference[None, 2, 1:9:4], torch.stack([reference[2, 1], reference[2, 5]], 0).unsqueeze(0))
6352  self.assertEqual(reference[:, 2, 1:6:2],
6353  torch.stack([reference[:, 2, 1], reference[:, 2, 3], reference[:, 2, 5]], 1))
6354 
6355  lst = [list(range(i, i + 10)) for i in range(0, 100, 10)]
6356  tensor = conv_fn(torch.DoubleTensor(lst))
6357  for _i in range(100):
6358  idx1_start = random.randrange(10)
6359  idx1_end = idx1_start + random.randrange(1, 10 - idx1_start + 1)
6360  idx1_step = random.randrange(1, 8)
6361  idx1 = slice(idx1_start, idx1_end, idx1_step)
6362  if random.randrange(2) == 0:
6363  idx2_start = random.randrange(10)
6364  idx2_end = idx2_start + random.randrange(1, 10 - idx2_start + 1)
6365  idx2_step = random.randrange(1, 8)
6366  idx2 = slice(idx2_start, idx2_end, idx2_step)
6367  lst_indexed = list(map(lambda l: l[idx2], lst[idx1]))
6368  tensor_indexed = tensor[idx1, idx2]
6369  else:
6370  lst_indexed = lst[idx1]
6371  tensor_indexed = tensor[idx1]
6372  self.assertEqual(torch.DoubleTensor(lst_indexed), tensor_indexed)
6373 
6374  self.assertRaises(ValueError, lambda: reference[1:9:0])
6375  self.assertRaises(ValueError, lambda: reference[1:9:-1])
6376 
6377  self.assertRaises(IndexError, lambda: reference[1, 1, 1, 1])
6378  self.assertRaises(IndexError, lambda: reference[1, 1, 1, 1:1])
6379  self.assertRaises(IndexError, lambda: reference[3, 3, 3, 3, 3, 3, 3, 3])
6380 
6381  self.assertRaises(IndexError, lambda: reference[0.0])
6382  self.assertRaises(TypeError, lambda: reference[0.0:2.0])
6383  self.assertRaises(IndexError, lambda: reference[0.0, 0.0:2.0])
6384  self.assertRaises(IndexError, lambda: reference[0.0, :, 0.0:2.0])
6385  self.assertRaises(IndexError, lambda: reference[0.0, ..., 0.0:2.0])
6386  self.assertRaises(IndexError, lambda: reference[0.0, :, 0.0])
6387 
6388  def delitem():
6389  del reference[0]
6390 
6391  self.assertRaises(TypeError, delitem)
6392 
6393  def test_index(self):
6394  self._test_index(self, lambda x: x)
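
 # Quick reference for the basic-indexing rules _test_index exercises, on a
 # small tensor (the shapes are what the assertions above rely on):
 #
 #     t = torch.arange(27).view(3, 3, 3)
 #     t[0].shape         # torch.Size([3, 3])     integer index drops a dim
 #     t[0:2].shape       # torch.Size([2, 3, 3])  slice keeps the dim
 #     t[..., 2].shape    # torch.Size([3, 3])     Ellipsis stands in for the remaining dims
 #     t[1, None].shape   # torch.Size([1, 3, 3])  None inserts a size-1 dim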
6395 
6396  @staticmethod
6397  def _test_advancedindex(self, conv_fn):
6398  # Tests for Integer Array Indexing, Part I - Purely integer array
6399  # indexing
6400 
6401  def consec(size, start=1):
6402  numel = reduce(lambda x, y: x * y, size, 1)
6403  sequence = torch.ones(numel).cumsum(0)
6404  sequence.add_(start - 1)
6405  return sequence.view(*size)
6406 
6407  # pick a random valid indexer type
6408  def ri(indices):
6409  choice = random.randint(0, 2)
6410  if choice == 0:
6411  return conv_fn(torch.LongTensor(indices))
6412  elif choice == 1:
6413  return list(indices)
6414  else:
6415  return tuple(indices)
6416 
6417  def validate_indexing(x):
6418  self.assertEqual(x[[0]], consec((1,)))
6419  self.assertEqual(x[ri([0]), ], consec((1,)))
6420  self.assertEqual(x[ri([3]), ], consec((1,), 4))
6421  self.assertEqual(x[[2, 3, 4]], consec((3,), 3))
6422  self.assertEqual(x[ri([2, 3, 4]), ], consec((3,), 3))
6423  self.assertEqual(x[ri([0, 2, 4]), ], torch.Tensor([1, 3, 5]))
6424 
6425  def validate_setting(x):
6426  dtype = x.type()
6427  x[[0]] = -2
6428  self.assertEqual(x[[0]], torch.Tensor([-2]).type(dtype))
6429  x[[0]] = -1
6430  self.assertEqual(x[ri([0]), ], torch.Tensor([-1]).type(dtype))
6431  x[[2, 3, 4]] = 4
6432  self.assertEqual(x[[2, 3, 4]], torch.Tensor([4, 4, 4]).type(dtype))
6433  x[ri([2, 3, 4]), ] = 3
6434  self.assertEqual(x[ri([2, 3, 4]), ], torch.Tensor([3, 3, 3]).type(dtype))
6435  x[ri([0, 2, 4]), ] = conv_fn(torch.Tensor([5, 4, 3])).type(dtype)
6436  self.assertEqual(x[ri([0, 2, 4]), ], torch.Tensor([5, 4, 3]).type(dtype))
6437 
6438  # First, we will test indexing to generate return values
6439 
6440  # Case 1: Purely Integer Array Indexing
6441  reference = conv_fn(consec((10,)))
6442  validate_indexing(reference)
6443  validate_indexing(reference.type(torch.half))
6444 
6445  # setting values
6446  validate_setting(reference)
6447  validate_setting(reference.type(torch.half))
6448 
6449  # Tensor with stride != 1
6450 
6451  # strided is [1, 3, 5, 7]
6452  reference = conv_fn(consec((10,)))
6453  strided = conv_fn(torch.Tensor())
6454  strided.set_(reference.storage(), storage_offset=0,
6455  size=torch.Size([4]), stride=[2])
6456 
6457  self.assertEqual(strided[[0]], torch.Tensor([1]))
6458  self.assertEqual(strided[ri([0]), ], torch.Tensor([1]))
6459  self.assertEqual(strided[ri([3]), ], torch.Tensor([7]))
6460  self.assertEqual(strided[[1, 2]], torch.Tensor([3, 5]))
6461  self.assertEqual(strided[ri([1, 2]), ], torch.Tensor([3, 5]))
6462  self.assertEqual(strided[ri([[2, 1], [0, 3]]), ],
6463  torch.Tensor([[5, 3], [1, 7]]))
6464 
6465  # strided is [5, 9]
6466  strided = conv_fn(torch.Tensor())
6467  strided.set_(reference.storage(), storage_offset=4,
6468  size=torch.Size([2]), stride=[4])
6469  self.assertEqual(strided[[0]], torch.Tensor([5]))
6470  self.assertEqual(strided[ri([0]), ], torch.Tensor([5]))
6471  self.assertEqual(strided[ri([1]), ], torch.Tensor([9]))
6472  self.assertEqual(strided[[0, 1]], torch.Tensor([5, 9]))
6473  self.assertEqual(strided[ri([0, 1]), ], torch.Tensor([5, 9]))
6474  self.assertEqual(strided[ri([[0, 1], [1, 0]]), ],
6475  torch.Tensor([[5, 9], [9, 5]]))
6476 
6477  # reference is 1 2
6478  # 3 4
6479  # 5 6
6480  reference = conv_fn(consec((3, 2)))
6481  self.assertEqual(reference[ri([0, 1, 2]), ri([0])], torch.Tensor([1, 3, 5]))
6482  self.assertEqual(reference[ri([0, 1, 2]), ri([1])], torch.Tensor([2, 4, 6]))
6483  self.assertEqual(reference[ri([0]), ri([0])], consec((1,)))
6484  self.assertEqual(reference[ri([2]), ri([1])], consec((1,), 6))
6485  self.assertEqual(reference[[ri([0, 0]), ri([0, 1])]], torch.Tensor([1, 2]))
6486  self.assertEqual(reference[[ri([0, 1, 1, 0, 2]), ri([1])]],
6487  torch.Tensor([2, 4, 4, 2, 6]))
6488  self.assertEqual(reference[[ri([0, 0, 1, 1]), ri([0, 1, 0, 0])]],
6489  torch.Tensor([1, 2, 3, 3]))
6490 
6491  rows = ri([[0, 0],
6492  [1, 2]])
6493  columns = [0],
6494  self.assertEqual(reference[rows, columns], torch.Tensor([[1, 1],
6495  [3, 5]]))
6496 
6497  rows = ri([[0, 0],
6498  [1, 2]])
6499  columns = ri([1, 0])
6500  self.assertEqual(reference[rows, columns], torch.Tensor([[2, 1],
6501  [4, 5]]))
6502  rows = ri([[0, 0],
6503  [1, 2]])
6504  columns = ri([[0, 1],
6505  [1, 0]])
6506  self.assertEqual(reference[rows, columns], torch.Tensor([[1, 2],
6507  [4, 5]]))
6508 
6509  # setting values
6510  reference[ri([0]), ri([1])] = -1
6511  self.assertEqual(reference[ri([0]), ri([1])], torch.Tensor([-1]))
6512  reference[ri([0, 1, 2]), ri([0])] = conv_fn(torch.Tensor([-1, 2, -4]))
6513  self.assertEqual(reference[ri([0, 1, 2]), ri([0])], torch.Tensor([-1,
6514  2, -4]))
6515  reference[rows, columns] = conv_fn(torch.Tensor([[4, 6], [2, 3]]))
6516  self.assertEqual(reference[rows, columns],
6517  torch.Tensor([[4, 6], [2, 3]]))
6518 
6519  # Verify still works with Transposed (i.e. non-contiguous) Tensors
6520 
6521  reference = conv_fn(torch.Tensor([[0, 1, 2, 3],
6522  [4, 5, 6, 7],
6523  [8, 9, 10, 11]])).t_()
6524 
6525  # Transposed: [[0, 4, 8],
6526  # [1, 5, 9],
6527  # [2, 6, 10],
6528  # [3, 7, 11]]
6529 
6530  self.assertEqual(reference[ri([0, 1, 2]), ri([0])], torch.Tensor([0, 1,
6531  2]))
6532  self.assertEqual(reference[ri([0, 1, 2]), ri([1])], torch.Tensor([4, 5,
6533  6]))
6534  self.assertEqual(reference[ri([0]), ri([0])], torch.Tensor([0]))
6535  self.assertEqual(reference[ri([2]), ri([1])], torch.Tensor([6]))
6536  self.assertEqual(reference[[ri([0, 0]), ri([0, 1])]], torch.Tensor([0, 4]))
6537  self.assertEqual(reference[[ri([0, 1, 1, 0, 3]), ri([1])]],
6538  torch.Tensor([4, 5, 5, 4, 7]))
6539  self.assertEqual(reference[[ri([0, 0, 1, 1]), ri([0, 1, 0, 0])]],
6540  torch.Tensor([0, 4, 1, 1]))
6541 
6542  rows = ri([[0, 0],
6543  [1, 2]])
6544  columns = [0],
6545  self.assertEqual(reference[rows, columns], torch.Tensor([[0, 0],
6546  [1, 2]]))
6547 
6548  rows = ri([[0, 0],
6549  [1, 2]])
6550  columns = ri([1, 0])
6551  self.assertEqual(reference[rows, columns], torch.Tensor([[4, 0],
6552  [5, 2]]))
6553  rows = ri([[0, 0],
6554  [1, 3]])
6555  columns = ri([[0, 1],
6556  [1, 2]])
6557  self.assertEqual(reference[rows, columns], torch.Tensor([[0, 4],
6558  [5, 11]]))
6559 
6560  # setting values
6561  reference[ri([0]), ri([1])] = -1
6562  self.assertEqual(reference[ri([0]), ri([1])], torch.Tensor([-1]))
6563  reference[ri([0, 1, 2]), ri([0])] = conv_fn(torch.Tensor([-1, 2, -4]))
6564  self.assertEqual(reference[ri([0, 1, 2]), ri([0])], torch.Tensor([-1,
6565  2, -4]))
6566  reference[rows, columns] = conv_fn(torch.Tensor([[4, 6], [2, 3]]))
6567  self.assertEqual(reference[rows, columns],
6568  torch.Tensor([[4, 6], [2, 3]]))
6569 
6570  # stride != 1
6571 
6572  # strided is [[1 3 5 7],
6573  # [9 11 13 15]]
6574 
6575  reference = conv_fn(torch.arange(0., 24).view(3, 8))
6576  strided = conv_fn(torch.Tensor())
6577  strided.set_(reference.storage(), 1, size=torch.Size([2, 4]),
6578  stride=[8, 2])
6579 
6580  self.assertEqual(strided[ri([0, 1]), ri([0])], torch.Tensor([1, 9]))
6581  self.assertEqual(strided[ri([0, 1]), ri([1])], torch.Tensor([3, 11]))
6582  self.assertEqual(strided[ri([0]), ri([0])], torch.Tensor([1]))
6583  self.assertEqual(strided[ri([1]), ri([3])], torch.Tensor([15]))
6584  self.assertEqual(strided[[ri([0, 0]), ri([0, 3])]], torch.Tensor([1, 7]))
6585  self.assertEqual(strided[[ri([1]), ri([0, 1, 1, 0, 3])]],
6586  torch.Tensor([9, 11, 11, 9, 15]))
6587  self.assertEqual(strided[[ri([0, 0, 1, 1]), ri([0, 1, 0, 0])]],
6588  torch.Tensor([1, 3, 9, 9]))
6589 
6590  rows = ri([[0, 0],
6591  [1, 1]])
6592  columns = [0],
6593  self.assertEqual(strided[rows, columns], torch.Tensor([[1, 1],
6594  [9, 9]]))
6595 
6596  rows = ri([[0, 1],
6597  [1, 0]])
6598  columns = ri([1, 2])
6599  self.assertEqual(strided[rows, columns], torch.Tensor([[3, 13],
6600  [11, 5]]))
6601  rows = ri([[0, 0],
6602  [1, 1]])
6603  columns = ri([[0, 1],
6604  [1, 2]])
6605  self.assertEqual(strided[rows, columns], torch.Tensor([[1, 3],
6606  [11, 13]]))
6607 
6608  # setting values
6609 
6610  # strided is [[10, 11],
6611  # [17, 18]]
6612 
6613  reference = conv_fn(torch.arange(0., 24).view(3, 8))
6614  strided = conv_fn(torch.Tensor())
6615  strided.set_(reference.storage(), 10, size=torch.Size([2, 2]),
6616  stride=[7, 1])
6617  self.assertEqual(strided[ri([0]), ri([1])], torch.Tensor([11]))
6618  strided[ri([0]), ri([1])] = -1
6619  self.assertEqual(strided[ri([0]), ri([1])], torch.Tensor([-1]))
6620 
6621  reference = conv_fn(torch.arange(0., 24).view(3, 8))
6622  strided = conv_fn(torch.Tensor())
6623  strided.set_(reference.storage(), 10, size=torch.Size([2, 2]),
6624  stride=[7, 1])
6625  self.assertEqual(strided[ri([0, 1]), ri([1, 0])], torch.Tensor([11,
6626  17]))
6627  strided[ri([0, 1]), ri([1, 0])] = conv_fn(torch.Tensor([-1, 2]))
6628  self.assertEqual(strided[ri([0, 1]), ri([1, 0])], torch.Tensor([-1,
6629  2]))
6630 
6631  reference = conv_fn(torch.arange(0., 24).view(3, 8))
6632  strided = conv_fn(torch.Tensor())
6633  strided.set_(reference.storage(), 10, size=torch.Size([2, 2]),
6634  stride=[7, 1])
6635 
6636  rows = ri([[0],
6637  [1]])
6638  columns = ri([[0, 1],
6639  [0, 1]])
6640  self.assertEqual(strided[rows, columns],
6641  torch.Tensor([[10, 11], [17, 18]]))
6642  strided[rows, columns] = conv_fn(torch.Tensor([[4, 6], [2, 3]]))
6643  self.assertEqual(strided[rows, columns],
6644  torch.Tensor([[4, 6], [2, 3]]))
6645 
6646  # Tests using less than the number of dims, and ellipsis
6647 
6648  # reference is 1 2
6649  # 3 4
6650  # 5 6
6651  reference = conv_fn(consec((3, 2)))
6652  self.assertEqual(reference[ri([0, 2]), ], torch.Tensor([[1, 2], [5, 6]]))
6653  self.assertEqual(reference[ri([1]), ...], torch.Tensor([[3, 4]]))
6654  self.assertEqual(reference[..., ri([1])], torch.Tensor([[2], [4], [6]]))
6655 
6656  # verify too many indices fails
6657  with self.assertRaises(IndexError):
6658  reference[ri([1]), ri([0, 2]), ri([3])]
6659 
6660  # test invalid index fails
6661  reference = conv_fn(torch.empty(10))
6662  # can't test cuda because an out-of-bounds index triggers a device-side assert
6663  if not reference.is_cuda:
6664  for err_idx in (10, -11):
6665  with self.assertRaisesRegex(IndexError, r'out of'):
6666  reference[err_idx]
6667  with self.assertRaisesRegex(IndexError, r'out of'):
6668  reference[conv_fn(torch.LongTensor([err_idx]))]
6669  with self.assertRaisesRegex(IndexError, r'out of'):
6670  reference[[err_idx]]
6671 
6672  if TEST_NUMPY:
6673  # we use numpy to compare against, to verify that our advanced
6674  # indexing semantics are the same, and also for ease of test
6675  # writing
6676 
6677  def tensor_indices_to_np(tensor, indices):
6678  # convert the Torch Tensor to a numpy array
6679  if (tensor.is_cuda):
6680  tensor = tensor.cpu()
6681  npt = tensor.numpy()
6682 
6683  # convert indices
6684  idxs = tuple(i.tolist() if isinstance(i, torch.LongTensor) else
6685  i for i in indices)
6686 
6687  return npt, idxs
6688 
6689  def get_numpy(tensor, indices):
6690  npt, idxs = tensor_indices_to_np(tensor, indices)
6691 
6692  # index and return as a Torch Tensor
6693  return torch.Tensor(npt[idxs])
6694 
6695  def set_numpy(tensor, indices, value):
6696  if not isinstance(value, int):
6697  if value.is_cuda:
6698  value = value.cpu()
6699  value = value.numpy()
6700 
6701  npt, idxs = tensor_indices_to_np(tensor, indices)
6702  npt[idxs] = value
6703  return npt
6704 
6705  def assert_get_eq(tensor, indexer):
6706  self.assertEqual(tensor[indexer],
6707  conv_fn(get_numpy(tensor, indexer)))
6708 
6709  def assert_set_eq(tensor, indexer, val):
6710  pyt = tensor.clone()
6711  numt = tensor.clone()
6712  pyt[indexer] = val
6713  numt = conv_fn(torch.Tensor(set_numpy(numt, indexer, val)))
6714  self.assertEqual(pyt, numt)
6715 
6716  def get_set_tensor(indexed, indexer):
6717  set_size = indexed[indexer].size()
6718  set_count = indexed[indexer].numel()
6719  set_tensor = conv_fn(torch.randperm(set_count).view(set_size).double())
6720  return set_tensor
6721 
6722  # Tensor is 0 1 2 3 4
6723  # 5 6 7 8 9
6724  # 10 11 12 13 14
6725  # 15 16 17 18 19
6726  reference = conv_fn(torch.arange(0., 20).view(4, 5))
6727 
6728  indices_to_test = [
6729  # grab the second, fourth columns
6730  [slice(None), [1, 3]],
6731 
6732  # first, third rows,
6733  [[0, 2], slice(None)],
6734 
6735  # weird shape
6736  [slice(None), [[0, 1],
6737  [2, 3]]],
6738  # negatives
6739  [[-1], [0]],
6740  [[0, 2], [-1]],
6741  [slice(None), [-1]],
6742  ]
6743 
6744  # only test dupes on gets
6745  get_indices_to_test = indices_to_test + [[slice(None), [0, 1, 1, 2, 2]]]
6746 
6747  for indexer in get_indices_to_test:
6748  assert_get_eq(reference, indexer)
6749 
6750  for indexer in indices_to_test:
6751  assert_set_eq(reference, indexer, 44)
6752  assert_set_eq(reference,
6753  indexer,
6754  get_set_tensor(reference, indexer))
6755 
6756  reference = conv_fn(torch.arange(0., 160).view(4, 8, 5))
6757 
6758  indices_to_test = [
6759  [slice(None), slice(None), [0, 3, 4]],
6760  [slice(None), [2, 4, 5, 7], slice(None)],
6761  [[2, 3], slice(None), slice(None)],
6762  [slice(None), [0, 2, 3], [1, 3, 4]],
6763  [slice(None), [0], [1, 2, 4]],
6764  [slice(None), [0, 1, 3], [4]],
6765  [slice(None), [[0, 1], [1, 0]], [[2, 3]]],
6766  [slice(None), [[0, 1], [2, 3]], [[0]]],
6767  [slice(None), [[5, 6]], [[0, 3], [4, 4]]],
6768  [[0, 2, 3], [1, 3, 4], slice(None)],
6769  [[0], [1, 2, 4], slice(None)],
6770  [[0, 1, 3], [4], slice(None)],
6771  [[[0, 1], [1, 0]], [[2, 1], [3, 5]], slice(None)],
6772  [[[0, 1], [1, 0]], [[2, 3]], slice(None)],
6773  [[[0, 1], [2, 3]], [[0]], slice(None)],
6774  [[[2, 1]], [[0, 3], [4, 4]], slice(None)],
6775  [[[2]], [[0, 3], [4, 1]], slice(None)],
6776 
6777  # less dim, ellipsis
6778  [[0, 2], ],
6779  [[0, 2], slice(None)],
6780  [[0, 2], Ellipsis],
6781  [[0, 2], slice(None), Ellipsis],
6782  [[0, 2], Ellipsis, slice(None)],
6783  [[0, 2], [1, 3]],
6784  [[0, 2], [1, 3], Ellipsis],
6785  [Ellipsis, [1, 3], [2, 3]],
6786  [Ellipsis, [2, 3, 4]],
6787  [Ellipsis, slice(None), [2, 3, 4]],
6788  [slice(None), Ellipsis, [2, 3, 4]],
6789 
6790  # ellipsis counts for nothing
6791  [Ellipsis, slice(None), slice(None), [0, 3, 4]],
6792  [slice(None), Ellipsis, slice(None), [0, 3, 4]],
6793  [slice(None), slice(None), Ellipsis, [0, 3, 4]],
6794  [slice(None), slice(None), [0, 3, 4], Ellipsis],
6795  [Ellipsis, [[0, 1], [1, 0]], [[2, 1], [3, 5]], slice(None)],
6796  [[[0, 1], [1, 0]], [[2, 1], [3, 5]], Ellipsis, slice(None)],
6797  [[[0, 1], [1, 0]], [[2, 1], [3, 5]], slice(None), Ellipsis],
6798  ]
6799 
6800  for indexer in indices_to_test:
6801  assert_get_eq(reference, indexer)
6802  assert_set_eq(reference, indexer, 212)
6803  assert_set_eq(reference,
6804  indexer,
6805  get_set_tensor(reference, indexer))
6806 
6807  reference = conv_fn(torch.arange(0., 1296).view(3, 9, 8, 6))
6808 
6809  indices_to_test = [
6810  [slice(None), slice(None), slice(None), [0, 3, 4]],
6811  [slice(None), slice(None), [2, 4, 5, 7], slice(None)],
6812  [slice(None), [2, 3], slice(None), slice(None)],
6813  [[1, 2], slice(None), slice(None), slice(None)],
6814  [slice(None), slice(None), [0, 2, 3], [1, 3, 4]],
6815  [slice(None), slice(None), [0], [1, 2, 4]],
6816  [slice(None), slice(None), [0, 1, 3], [4]],
6817  [slice(None), slice(None), [[0, 1], [1, 0]], [[2, 3]]],
6818  [slice(None), slice(None), [[0, 1], [2, 3]], [[0]]],
6819  [slice(None), slice(None), [[5, 6]], [[0, 3], [4, 4]]],
6820  [slice(None), [0, 2, 3], [1, 3, 4], slice(None)],
6821  [slice(None), [0], [1, 2, 4], slice(None)],
6822  [slice(None), [0, 1, 3], [4], slice(None)],
6823  [slice(None), [[0, 1], [3, 4]], [[2, 3], [0, 1]], slice(None)],
6824  [slice(None), [[0, 1], [3, 4]], [[2, 3]], slice(None)],
6825  [slice(None), [[0, 1], [3, 2]], [[0]], slice(None)],
6826  [slice(None), [[2, 1]], [[0, 3], [6, 4]], slice(None)],
6827  [slice(None), [[2]], [[0, 3], [4, 2]], slice(None)],
6828  [[0, 1, 2], [1, 3, 4], slice(None), slice(None)],
6829  [[0], [1, 2, 4], slice(None), slice(None)],
6830  [[0, 1, 2], [4], slice(None), slice(None)],
6831  [[[0, 1], [0, 2]], [[2, 4], [1, 5]], slice(None), slice(None)],
6832  [[[0, 1], [1, 2]], [[2, 0]], slice(None), slice(None)],
6833  [[[2, 2]], [[0, 3], [4, 5]], slice(None), slice(None)],
6834  [[[2]], [[0, 3], [4, 5]], slice(None), slice(None)],
6835  [slice(None), [3, 4, 6], [0, 2, 3], [1, 3, 4]],
6836  [slice(None), [2, 3, 4], [1, 3, 4], [4]],
6837  [slice(None), [0, 1, 3], [4], [1, 3, 4]],
6838  [slice(None), [6], [0, 2, 3], [1, 3, 4]],
6839  [slice(None), [2, 3, 5], [3], [4]],
6840  [slice(None), [0], [4], [1, 3, 4]],
6841  [slice(None), [6], [0, 2, 3], [1]],
6842  [slice(None), [[0, 3], [3, 6]], [[0, 1], [1, 3]], [[5, 3], [1, 2]]],
6843  [[2, 2, 1], [0, 2, 3], [1, 3, 4], slice(None)],
6844  [[2, 0, 1], [1, 2, 3], [4], slice(None)],
6845  [[0, 1, 2], [4], [1, 3, 4], slice(None)],
6846  [[0], [0, 2, 3], [1, 3, 4], slice(None)],
6847  [[0, 2, 1], [3], [4], slice(None)],
6848  [[0], [4], [1, 3, 4], slice(None)],
6849  [[1], [0, 2, 3], [1], slice(None)],
6850  [[[1, 2], [1, 2]], [[0, 1], [2, 3]], [[2, 3], [3, 5]], slice(None)],
6851 
6852  # less dim, ellipsis
6853  [Ellipsis, [0, 3, 4]],
6854  [Ellipsis, slice(None), [0, 3, 4]],
6855  [Ellipsis, slice(None), slice(None), [0, 3, 4]],
6856  [slice(None), Ellipsis, [0, 3, 4]],
6857  [slice(None), slice(None), Ellipsis, [0, 3, 4]],
6858  [slice(None), [0, 2, 3], [1, 3, 4]],
6859  [slice(None), [0, 2, 3], [1, 3, 4], Ellipsis],
6860  [Ellipsis, [0, 2, 3], [1, 3, 4], slice(None)],
6861  [[0], [1, 2, 4]],
6862  [[0], [1, 2, 4], slice(None)],
6863  [[0], [1, 2, 4], Ellipsis],
6864  [[0], [1, 2, 4], Ellipsis, slice(None)],
6865  [[1], ],
6866  [[0, 2, 1], [3], [4]],
6867  [[0, 2, 1], [3], [4], slice(None)],
6868  [[0, 2, 1], [3], [4], Ellipsis],
6869  [Ellipsis, [0, 2, 1], [3], [4]],
6870  ]
6871 
6872  for indexer in indices_to_test:
6873  assert_get_eq(reference, indexer)
6874  assert_set_eq(reference, indexer, 1333)
6875  assert_set_eq(reference,
6876  indexer,
6877  get_set_tensor(reference, indexer))
6878  indices_to_test += [
6879  [slice(None), slice(None), [[0, 1], [1, 0]], [[2, 3], [3, 0]]],
6880  [slice(None), slice(None), [[2]], [[0, 3], [4, 4]]],
6881  ]
6882  for indexer in indices_to_test:
6883  assert_get_eq(reference, indexer)
6884  assert_set_eq(reference, indexer, 1333)
6885 
6886  def test_advancedindex(self):
6887  self._test_advancedindex(self, lambda x: x)
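
 # Sketch of the integer-array ("advanced") indexing semantics compared against
 # numpy above: index tensors or lists broadcast against each other, and the
 # result gathers one element per broadcast position.
 #
 #     t = torch.arange(6).view(3, 2)   # [[0, 1], [2, 3], [4, 5]]
 #     t[[0, 1, 2], [0]]                # tensor([0, 2, 4])
 #     t[[0, 2], [1, 0]]                # tensor([1, 4])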
6888 
6889  @staticmethod
6890  def _test_advancedindex_big(self, conv_fn):
6891  reference = conv_fn(torch.arange(0, 123344).int())
6892 
6893  self.assertEqual(reference[[0, 123, 44488, 68807, 123343], ],
6894  torch.LongTensor([0, 123, 44488, 68807, 123343]))
6895 
6896  def test_advancedindex_big(self):
6897  self._test_advancedindex_big(self, lambda x: x)
6898 
6899  @unittest.skipIf(not TEST_NUMPY, "Numpy not found")
6900  def test_newaxis_numpy_comparison(self):
6901  def run_test(tensor, *idx):
6902  npt = tensor.numpy()
6903  self.assertEqual(tensor[idx], npt[idx])
6904 
6905  # 1D Tensor Tests
6906  x = torch.arange(0, 10)
6907  cases = [
6908  [None],
6909  [None, None],
6910  [Ellipsis, None],
6911  [None, Ellipsis],
6912  [2, None],
6913  [None, 2],
6914  [Ellipsis, None, 2],
6915  [Ellipsis, 2, None],
6916  [2, Ellipsis, None],
6917  [2, None, Ellipsis],
6918  [None, 2, Ellipsis],
6919  [None, Ellipsis, 2],
6920  ]
6921 
6922  for case in cases:
6923  run_test(x, *case)
6924 
6925  # 2D Tensor Tests
6926  x = torch.arange(0, 12).view(3, 4)
6927  cases = [
6928  [None],
6929  [None, None],
6930  [None, None, None],
6931  [Ellipsis, None],
6932  [Ellipsis, None, None],
6933  [None, Ellipsis],
6934  [None, Ellipsis, None],
6935  [None, None, Ellipsis],
6936  [2, None],
6937  [2, None, Ellipsis],
6938  [2, Ellipsis, None],
6939  [None, 2, Ellipsis],
6940  [Ellipsis, 2, None],
6941  [Ellipsis, None, 2],
6942  [None, Ellipsis, 2],
6943  [1, 2, None],
6944  [1, 2, Ellipsis, None],
6945  [1, Ellipsis, 2, None],
6946  [Ellipsis, 1, None, 2],
6947  [Ellipsis, 1, 2, None],
6948  [1, None, 2, Ellipsis],
6949  [None, 1, Ellipsis, 2],
6950  [None, 1, 2, Ellipsis],
6951  ]
6952 
6953  for case in cases:
6954  run_test(x, *case)
6955 
6956  def test_newindex(self):
6957  reference = self._consecutive((3, 3, 3))
6958  # This relies on __index__() being correct - but we have separate tests for that
6959 
6960  def checkPartialAssign(index):
6961  reference = torch.zeros(3, 3, 3)
6962  reference[index] = self._consecutive((3, 3, 3))[index]
6963  self.assertEqual(reference[index], self._consecutive((3, 3, 3))[index], 0)
6964  reference[index] = 0
6965  self.assertEqual(reference, torch.zeros(3, 3, 3), 0)
6966 
6967  checkPartialAssign(0)
6968  checkPartialAssign(1)
6969  checkPartialAssign(2)
6970  checkPartialAssign((0, 1))
6971  checkPartialAssign((1, 2))
6972  checkPartialAssign((0, 2))
6973  checkPartialAssign(torch.LongTensor((0, 2)))
6974 
6975  with self.assertRaises(IndexError):
6976  reference[1, 1, 1, 1] = 1
6977  with self.assertRaises(IndexError):
6978  reference[1, 1, 1, (1, 1)] = 1
6979  with self.assertRaises(IndexError):
6980  reference[3, 3, 3, 3, 3, 3, 3, 3] = 1
6981  with self.assertRaises(IndexError):
6982  reference[0.0] = 1
6983  with self.assertRaises(TypeError):
6984  reference[0.0:2.0] = 1
6985  with self.assertRaises(IndexError):
6986  reference[0.0, 0.0:2.0] = 1
6987  with self.assertRaises(IndexError):
6988  reference[0.0, :, 0.0:2.0] = 1
6989  with self.assertRaises(IndexError):
6990  reference[0.0, ..., 0.0:2.0] = 1
6991  with self.assertRaises(IndexError):
6992  reference[0.0, :, 0.0] = 1
6993 
6994  def test_index_copy(self):
6995  num_copy, num_dest = 3, 20
6996  dest = torch.randn(num_dest, 4, 5)
6997  src = torch.randn(num_copy, 4, 5)
6998  idx = torch.randperm(num_dest).narrow(0, 0, num_copy)
6999  dest2 = dest.clone()
7000  dest.index_copy_(0, idx, src)
7001  for i in range(idx.size(0)):
7002  dest2[idx[i]] = src[i]
7003  self.assertEqual(dest, dest2, 0)
7004 
7005  dest = torch.randn(num_dest)
7006  src = torch.randn(num_copy)
7007  idx = torch.randperm(num_dest).narrow(0, 0, num_copy)
7008  dest2 = dest.clone()
7009  dest.index_copy_(0, idx, src)
7010  for i in range(idx.size(0)):
7011  dest2[idx[i]] = src[i]
7012  self.assertEqual(dest, dest2, 0)
7013 
7014  def test_index_add(self):
7015  num_copy, num_dest = 3, 3
7016  dest = torch.randn(num_dest, 4, 5)
7017  src = torch.randn(num_copy, 4, 5)
7018  idx = torch.randperm(num_dest).narrow(0, 0, num_copy)
7019  dest2 = dest.clone()
7020  dest.index_add_(0, idx, src)
7021  for i in range(idx.size(0)):
7022  dest2[idx[i]] += src[i]
7023  self.assertEqual(dest, dest2)
7024 
7025  dest = torch.randn(num_dest)
7026  src = torch.randn(num_copy)
7027  idx = torch.randperm(num_dest).narrow(0, 0, num_copy)
7028  dest2 = dest.clone()
7029  dest.index_add_(0, idx, src)
7030  for i in range(idx.size(0)):
7031  dest2[idx[i]] = dest2[idx[i]] + src[i]
7032  self.assertEqual(dest, dest2)
7033 
7034  def test_index_select(self):
7035  src = torch.randn(3, 4, 5)
7036  # Indices can be duplicated.
7037  idx = torch.LongTensor([2, 1, 0, 1, 2])
7038  dest = torch.index_select(src, 0, idx)
7039  self.assertEqual(dest.shape, (5, 4, 5))
7040  for i in range(idx.size(0)):
7041  self.assertEqual(dest[i], src[idx[i]])
7042 
7043  # Check that 'out' is used correctly.
7044  out = torch.randn(5 * 4 * 5)
7045  dest = torch.index_select(src, 0, idx, out=out.view(5, 4, 5))
7046  self.assertEqual(dest.shape, (5, 4, 5))
7047  for i in range(idx.size(0)):
7048  self.assertEqual(dest[i], src[idx[i]])
7049  out.fill_(0.123)
7050  self.assertEqual(out, dest.view(-1)) # Must point to the same storage.
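
 # Sketch of index_select: rows are gathered along the given dim, indices may
 # repeat, and a provided `out` tensor is written in place (so the result
 # aliases its storage, as the fill_ check above confirms).
 #
 #     src = torch.arange(6.).view(3, 2)
 #     torch.index_select(src, 0, torch.LongTensor([2, 0, 2]))
 #     # tensor([[4., 5.], [0., 1.], [4., 5.]])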
7051 
7052  def test_t(self):
7053  # Test 0D tensors
7054  x = torch.randn(())
7055  self.assertEqual(x, x.t())
7056  x = x.to_sparse()
7057  self.assertEqual(x, x.t())
7058 
7059  # Test 1D tensors
7060  x = torch.arange(4)
7061  self.assertEqual(x, x.t())
7062  x = x.to_sparse()
7063  self.assertEqual(x, x.t())
7064 
7065  # Test 2D tensors
7066  x = torch.rand((2, 2))
7067  self.assertEqual(x.t(), x.transpose(0, 1))
7068  x = x.to_sparse()
7069  self.assertEqual(x.t(), x.transpose(0, 1))
7070 
7071  # Test 3D tensor
7072  x = torch.rand((2, 2, 2))
7073  with self.assertRaisesRegex(RuntimeError, 'expects a tensor with <= 2 dimensions, but self is 3D'):
7074  x.t()
7075  x = x.to_sparse()
7076  with self.assertRaisesRegex(RuntimeError, 'expects a tensor with <= 2 sparse and 0 dense dimensions'):
7077  x.t()
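
 # Sketch of the .t() contract checked above: 0-D and 1-D tensors come back
 # unchanged, 2-D tensors are transposed, and higher dimensions raise.
 #
 #     torch.arange(4).t()          # unchanged 1-D tensor
 #     torch.rand(2, 3).t().shape   # torch.Size([3, 2])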
7078 
7079  def test_take(self):
7080  def check(src, idx):
7081  expected = src.contiguous().view(-1).index_select(
7082  0, idx.contiguous().view(-1)).view_as(idx)
7083  actual = src.take(idx)
7084  self.assertEqual(actual.size(), idx.size())
7085  self.assertEqual(expected, actual)
7086 
7087  src = torch.randn(2, 3, 5)
7088  idx = torch.LongTensor([[0, 2], [3, 4]])
7089  check(src, idx)
7090  check(src.transpose(1, 2), idx)
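
 # Sketch of torch.take: indices address the source as if it were flattened in
 # row-major order, and the output takes the shape of the index tensor.
 #
 #     src = torch.arange(10., 16.).view(2, 3)        # [[10, 11, 12], [13, 14, 15]]
 #     src.take(torch.LongTensor([[0, 2], [3, 5]]))   # tensor([[10., 12.], [13., 15.]])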
7091 
7092  def test_take_empty(self):
7093  devices = ['cpu'] if not torch.cuda.is_available() else ['cpu', 'cuda']
7094  for device in devices:
7095  for input_shape in [(0,), (0, 1, 2, 0), (1, 2, 3)]:
7096  for indices_shape in [(0,), (0, 1, 2, 0)]:
7097  input = torch.empty(input_shape, device=device)
7098  indices = torch.empty(indices_shape, dtype=torch.int64, device=device)
7099  self.assertEqual(indices, torch.take(input, indices))
7100 
7101  def test_put_(self):
7102  def check(dst, idx, value):
7103  expected = dst.clone().view(-1).index_copy_(
7104  0, idx.contiguous().view(-1), value.contiguous().view(-1))
7105  expected = expected.view_as(dst)
7106  dst.put_(idx, value)
7107  self.assertEqual(expected, dst)
7108 
7109  dst = torch.randn(2, 3, 5)
7110  idx = torch.LongTensor([[0, 2], [3, 4]])
7111  values = torch.randn(2, 2)
7112  check(dst, idx, values)
7113  check(dst.transpose(1, 2), idx, values)
7114 
7115  def test_put_accumulate(self):
7116  dst = torch.ones(2, 2)
7117  idx = torch.LongTensor([[0, 1], [0, 1]])
7118  src = torch.Tensor([1, 2, 3, 4])
7119  dst.put_(idx, src, accumulate=True)
7120  self.assertEqual(dst.tolist(), [[5, 7], [1, 1]])
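
 # Worked check of the expected value above: dst starts as all ones, idx indexes
 # the flattened dst, and accumulate=True adds instead of overwriting, so
 # position 0 receives 1 + 1 + 3 = 5 and position 1 receives 1 + 2 + 4 = 7 while
 # the remaining entries stay 1.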
7121 
7122  def test_put_empty(self):
7123  devices = ['cpu'] if not torch.cuda.is_available() else ['cpu', 'cuda']
7124  for device in devices:
7125  for dst_shape in [(0,), (0, 1, 2, 0), (1, 2, 3)]:
7126  for indices_shape in [(0,), (0, 1, 2, 0)]:
7127  for accumulate in [False, True]:
7128  dst = torch.randn(dst_shape, device=device)
7129  indices = torch.empty(indices_shape, dtype=torch.int64, device=device)
7130  src = torch.randn(indices_shape, device=device)
7131  self.assertEqual(dst, dst.put_(indices, src, accumulate=accumulate))
7132 
7133  # Fill idx with valid indices.
7134  @staticmethod
7135  def _fill_indices(self, idx, dim, dim_size, elems_per_row, m, n, o):
7136  for i in range(1 if dim == 0 else m):
7137  for j in range(1 if dim == 1 else n):
7138  for k in range(1 if dim == 2 else o):
7139  ii = [i, j, k]
7140  ii[dim] = slice(0, idx.size(dim) + 1)
7141  idx[tuple(ii)] = torch.randperm(dim_size)[0:elems_per_row]
7142 
7143  def test_flatten(self):
7144  src = torch.randn(5, 5, 5, 5)
7145  flat = src.flatten(0, -1)
7146  self.assertEqual(flat.shape, torch.Size([625]))
7147  self.assertEqual(src.view(-1), flat.view(-1))
7148 
7149  flat = src.flatten(0, 2)
7150  self.assertEqual(flat.shape, torch.Size([125, 5]))
7151  self.assertEqual(src.view(-1), flat.view(-1))
7152 
7153  flat = src.flatten(0, 1)
7154  self.assertEqual(flat.shape, torch.Size([25, 5, 5]))
7155  self.assertEqual(src.view(-1), flat.view(-1))
7156 
7157  flat = src.flatten(1, 2)
7158  self.assertEqual(flat.shape, torch.Size([5, 25, 5]))
7159  self.assertEqual(src.view(-1), flat.view(-1))
7160 
7161  flat = src.flatten(2, 3)
7162  self.assertEqual(flat.shape, torch.Size([5, 5, 25]))
7163  self.assertEqual(src.view(-1), flat.view(-1))
7164 
7165  flat = src.flatten(-2, -1)
7166  self.assertEqual(flat.shape, torch.Size([5, 5, 25]))
7167  self.assertEqual(src.view(-1), flat.view(-1))
7168 
7169  flat = src.flatten(2, 2)
7170  self.assertEqual(flat, src)
7171 
7172  # out of bounds index
7173  with self.assertRaisesRegex(IndexError, 'Dimension out of range'):
7174  src.flatten(5, 10)
7175 
7176  # invalid start and end
7177  with self.assertRaisesRegex(RuntimeError, 'start_dim cannot come after end_dim'):
7178  src.flatten(2, 0)
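
 # Sketch of flatten(start_dim, end_dim): the dims in the inclusive range
 # (negative indices allowed) collapse into a single dim, the rest are kept.
 #
 #     x = torch.randn(5, 5, 5, 5)
 #     x.flatten(1, 2).shape     # torch.Size([5, 25, 5])
 #     x.flatten(-2, -1).shape   # torch.Size([5, 5, 25])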
7179 
7180  @staticmethod
7181  def _test_gather(self, cast, test_bounds=True):
7182  m, n, o = random.randint(10, 20), random.randint(10, 20), random.randint(10, 20)
7183  elems_per_row = random.randint(1, 10)
7184  dim = random.randrange(3)
7185 
7186  src = torch.randn(m, n, o)
7187  idx_size = [m, n, o]
7188  idx_size[dim] = elems_per_row
7189  idx = torch.LongTensor().resize_(*idx_size)
7190  _TestTorchMixin._fill_indices(self, idx, dim, src.size(dim), elems_per_row, m, n, o)
7191 
7192  src = cast(src)
7193  idx = cast(idx)
7194 
7195  actual = torch.gather(src, dim, idx)
7196  expected = cast(torch.Tensor().resize_(*idx_size))
7197  for i in range(idx_size[0]):
7198  for j in range(idx_size[1]):
7199  for k in range(idx_size[2]):
7200  ii = [i, j, k]
7201  ii[dim] = idx[i, j, k]
7202  expected[i, j, k] = src[tuple(ii)]
7203  self.assertEqual(actual, expected, 0)
7204 
7205  if test_bounds:
7206  idx[0][0][0] = 23
7207  self.assertRaises(RuntimeError, lambda: torch.gather(src, dim, idx))
7208 
7209  src = cast(torch.randn(3, 4, 5))
7210  expected, idx = src.max(2, True)
7211  expected = cast(expected)
7212  idx = cast(idx)
7213  actual = torch.gather(src, 2, idx)
7214  self.assertEqual(actual, expected, 0)
7215 
7216  def test_gather(self):
7217  self._test_gather(self, lambda t: t)
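
 # The relation _test_gather verifies, written out for the 3-D case used above
 # (dim == 1 shown; the other dims are analogous): idx has the same number of
 # dims as src, the output takes idx's shape, and
 #
 #     out[i][j][k] = src[i][ idx[i][j][k] ][k]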
7218 
7219  @staticmethod
7220  def _test_scatter_base(self, cast, method, is_scalar=False, test_bounds=True):
7221  m, n, o = random.randint(10, 20), random.randint(10, 20), random.randint(10, 20)
7222  elems_per_row = random.randint(1, 10)
7223