Caffe2 - Python API
A deep learning, cross platform ML framework
test_indexing.py
1 from common_utils import TestCase, run_tests
2 import torch
3 import warnings
4 from torch import tensor
5 import unittest
6 
7 
def test_single_int(self):
    """Indexing with one int removes the leading dimension."""
    dims = (5, 7, 3)
    src = torch.randn(*dims)
    self.assertEqual(src[4].shape, dims[1:])
12 
def test_multiple_int(self):
    """Each int index removes the dimension it indexes; slices keep theirs."""
    src = torch.randn(5, 7, 3)
    checks = [
        (src[4], (7, 3)),
        (src[4, :, 1], (7,)),
    ]
    for result, expected_shape in checks:
        self.assertEqual(result.shape, expected_shape)
17 
def test_none(self):
    """``None`` inserts a size-1 dimension at the position where it appears."""
    src = torch.randn(5, 7, 3)
    checks = [
        (src[None], (1, 5, 7, 3)),
        (src[:, None], (5, 1, 7, 3)),
        (src[:, None, None], (5, 1, 1, 7, 3)),
        (src[..., None], (5, 7, 3, 1)),
    ]
    for result, expected_shape in checks:
        self.assertEqual(result.shape, expected_shape)
24 
def test_step(self):
    """Positive slice steps select every k-th element, honouring start/stop."""
    seq = torch.arange(10)
    self.assertEqual(seq[::1], seq)
    stepped = [
        (slice(None, None, 2), [0, 2, 4, 6, 8]),
        (slice(None, None, 3), [0, 3, 6, 9]),
        (slice(None, None, 11), [0]),
        (slice(1, 6, 2), [1, 3, 5]),
    ]
    for sl, expected in stepped:
        self.assertEqual(seq[sl].tolist(), expected)
32 
def test_step_assignment(self):
    """Assigning through a stepped slice writes only the selected positions."""
    grid = torch.zeros(4, 4)
    grid[0, 1::2] = torch.tensor([3., 4.])
    first_row, other_rows = grid[0], grid[1:]
    self.assertEqual(first_row.tolist(), [0, 3, 0, 4])
    # Rows other than the first must be untouched.
    self.assertEqual(other_rows.sum(), 0)
38 
def test_byte_mask(self):
    """A uint8 mask keeps the leading-dim entries where the mask is nonzero."""
    src = torch.randn(5, 7, 3)
    keep = torch.ByteTensor([1, 0, 1, 1, 0])
    picked = src[keep]
    self.assertEqual(picked.shape, (3, 7, 3))
    self.assertEqual(picked, torch.stack([src[i] for i in (0, 2, 3)]))

    # A mask that selects nothing yields an empty result.
    single = torch.tensor([1.])
    self.assertEqual(single[single == 0], torch.tensor([]))
47 
def test_byte_mask_accumulate(self):
    """``index_put_`` with an all-zero uint8 mask and accumulate=True changes nothing."""
    expected = torch.ones(size=(10, 10))
    no_rows = torch.zeros(size=(10,), dtype=torch.uint8)
    target = torch.ones(size=(10, 10))
    target.index_put_((no_rows,), target[no_rows], accumulate=True)
    self.assertEqual(target, expected)
53 
def test_multiple_byte_mask(self):
    """Two uint8 masks on separated dims broadcast together into one leading dim."""
    src = torch.randn(5, 7, 3)
    # note: these broadcast together and are transposed to the first dim
    first_mask = torch.ByteTensor([1, 0, 1, 1, 0])
    last_mask = torch.ByteTensor([1, 1, 1])
    result = src[first_mask, :, last_mask]
    self.assertEqual(result.shape, (3, 7))
60 
def test_byte_mask2d(self):
    """A 2-D mask over the first two dims collapses them into one selection dim."""
    src = torch.randn(5, 7, 3)
    positive = torch.randn(5, 7) > 0
    selected = src[positive]
    self.assertEqual(selected.shape, (positive.sum(), 3))
67 
def test_int_indices(self):
    """List indices replace the indexed dim with the list's shape."""
    src = torch.randn(5, 7, 3)
    picks = [0, 4, 2]
    nested = [[0, 1], [4, 3]]
    self.assertEqual(src[picks].shape, (3, 7, 3))
    self.assertEqual(src[:, picks].shape, (5, 3, 3))
    self.assertEqual(src[:, nested].shape, (5, 2, 2, 3))
73 
def test_int_indices2d(self):
    """Paired 2-D row/column index tensors gather element-wise (NumPy example)."""
    grid = torch.arange(0, 12).view(4, 3)
    row_idx = torch.tensor([[0, 0], [3, 3]])
    col_idx = torch.tensor([[0, 2], [0, 2]])
    corners = grid[row_idx, col_idx]
    self.assertEqual(corners.tolist(), [[0, 2], [9, 11]])
80 
def test_int_indices_broadcast(self):
    """A column-shaped row index broadcasts against a 1-D column index (NumPy example)."""
    grid = torch.arange(0, 12).view(4, 3)
    row_idx = torch.tensor([0, 3]).unsqueeze(1)
    col_idx = torch.tensor([0, 2])
    self.assertEqual(grid[row_idx, col_idx].tolist(), [[0, 2], [9, 11]])
88 
def test_empty_index(self):
    """Empty long indices and all-zero masks select nothing and assign nothing."""
    base = torch.arange(0, 12).view(4, 3)
    no_rows = torch.tensor([], dtype=torch.long)
    self.assertEqual(base[no_rows].numel(), 0)

    # empty assignment should have no effect but not throw an exception
    copy = base.clone()
    copy[no_rows] = -1
    self.assertEqual(base, copy)

    copy[torch.zeros(4, 3).byte()] = -1
    self.assertEqual(base, copy)
102 
def test_empty_ndim_index(self):
    """Indexing with empty multi-dimensional int64 index tensors, on every available device.

    Also checks advanced indexing into a tensor that has a zero-size
    dimension: list indices on the non-empty dim work, and any index into the
    zero-size dim raises IndexError.
    """
    devices = ['cpu'] if not torch.cuda.is_available() else ['cpu', 'cuda']
    for device in devices:
        x = torch.randn(5, device=device)
        self.assertEqual(torch.empty(0, 2, device=device),
                         x[torch.empty(0, 2, dtype=torch.int64, device=device)])

        x = torch.randn(2, 3, 4, 5, device=device)
        self.assertEqual(torch.empty(2, 0, 6, 4, 5, device=device),
                         x[:, torch.empty(0, 6, dtype=torch.int64, device=device)])

        # Allocate on `device` as well, so the zero-size-dimension checks
        # cover every device in the loop rather than repeating on the CPU.
        x = torch.empty(10, 0, device=device)
        self.assertEqual(x[[1, 2]].shape, (2, 0))
        self.assertEqual(x[[], []].shape, (0,))
        with self.assertRaisesRegex(IndexError, 'for dimension with size 0'):
            x[:, [0, 1]]
118 
def test_empty_ndim_index_bool(self):
    """An empty 2-D uint8 mask on a 1-D tensor raises IndexError on every device."""
    available = ['cpu'] + (['cuda'] if torch.cuda.is_available() else [])
    for dev in available:
        vec = torch.randn(5, device=dev)
        with self.assertRaises(IndexError):
            vec[torch.empty(0, 2, dtype=torch.uint8, device=dev)]
124 
def test_empty_slice(self):
    """An empty slice gives a zero-size, contiguous result with NumPy-like strides."""
    available = ['cpu'] + (['cuda'] if torch.cuda.is_available() else [])
    for dev in available:
        base = torch.randn(2, 3, 4, 5, device=dev)
        empty = base[:, :, :, 1][:, 1:1, :]
        self.assertEqual((2, 0, 4), empty.shape)
        # this isn't technically necessary, but matches NumPy stride calculations.
        self.assertEqual((60, 20, 5), empty.stride())
        self.assertTrue(empty.is_contiguous())
135 
def test_index_getitem_copy_bools_slices(self):
    """Bool/uint8-scalar indices yield new storage; None and Ellipsis share it."""
    true = torch.tensor(1, dtype=torch.uint8)
    false = torch.tensor(0, dtype=torch.uint8)

    for a in (torch.randn(2, 3), torch.tensor(3)):
        empty_like_a = torch.empty(0, *a.shape)
        # Python bools first, then 0-dim uint8 tensors, in that order.
        for yes, no in ((True, False), (true, false)):
            self.assertNotEqual(a.data_ptr(), a[yes].data_ptr())
            self.assertEqual(empty_like_a, a[no])
        self.assertEqual(a.data_ptr(), a[None].data_ptr())
        self.assertEqual(a.data_ptr(), a[...].data_ptr())
149 
def test_index_setitem_bools_slices(self):
    """Assigning via True/None/Ellipsis writes the whole tensor; False writes nothing."""
    true = torch.tensor(1, dtype=torch.uint8)
    false = torch.tensor(0, dtype=torch.uint8)

    for a in [torch.randn(2, 3), torch.tensor(3)]:
        # prefix with a 1,1, to ensure we are compatible with numpy which cuts off prefix 1s
        # (some of these ops already prefix a 1 to the size)
        neg_ones = torch.ones_like(a).mul(-1)
        prefixed = neg_ones[None, None]

        # A true index overwrites `a`; a false index must leave it alone.
        for factor, (yes, no) in enumerate([(True, False), (true, false)], start=1):
            a[yes] = prefixed * factor
            self.assertEqual(a, neg_ones * factor)
            a[no] = 5
            self.assertEqual(a, neg_ones * factor)

        for factor, key in ((3, None), (4, Ellipsis)):
            a[key] = prefixed * factor
            self.assertEqual(a, neg_ones * factor)

        if a.dim() == 0:
            with self.assertRaises(IndexError):
                a[:] = prefixed * 5
176 
def test_setitem_expansion_error(self):
    """A value with non-1 extra leading dims cannot be assigned through a bool index."""
    true = torch.tensor(True)
    a = torch.randn(2, 3)
    # check prefix with non-1s doesn't work
    too_big = a.expand(5, 1, *a.size())
    # NumPy: ValueError
    for key in (True, true):
        with self.assertRaises(RuntimeError):
            a[key] = too_big
187 
def test_getitem_scalars(self):
    """0-dim integer tensors index like Python ints and share storage with the source."""
    zero = torch.tensor(0, dtype=torch.int64)
    one = torch.tensor(1, dtype=torch.int64)

    # non-scalar indexed with scalars
    mat = torch.randn(2, 3)
    pairs = [
        (mat[0], mat[zero]),
        (mat[0][1], mat[zero][one]),
        (mat[0, 1], mat[zero, one]),
        (mat[0, one], mat[zero, 1]),
    ]
    for by_int, by_tensor in pairs:
        self.assertEqual(by_int, by_tensor)

    # indexing by a scalar should slice (not copy)
    self.assertEqual(mat[0, 1].data_ptr(), mat[zero, one].data_ptr())
    for narrow in (one.int(), one.short()):
        self.assertEqual(mat[1].data_ptr(), mat[narrow].data_ptr())

    # scalar indexed with scalar
    r = torch.randn(())
    for bad in (slice(None), zero):
        with self.assertRaises(IndexError):
            r[bad]
    self.assertEqual(r, r[...])
211 
def test_setitem_scalars(self):
    """Assignment with a 0-dim int64 index matches assignment with a Python int."""
    zero = torch.tensor(0, dtype=torch.int64)

    # non-scalar indexed with scalars
    a = torch.randn(2, 3)
    by_number, by_scalar = a.clone(), a.clone()
    row = torch.randn(3)
    by_number[0] = row
    by_scalar[zero] = row
    self.assertEqual(by_number, by_scalar)
    a[1, zero] = 7.7
    self.assertEqual(7.7, a[1, 0])

    # scalar indexed with scalars
    r = torch.randn(())
    for bad in (slice(None), zero):
        with self.assertRaises(IndexError):
            r[bad] = 8.8
    r[...] = 9.9
    self.assertEqual(9.9, r)
235 
def test_basic_advanced_combined(self):
    """Mixing a slice with a list index reads a copy but writes in place (NumPy example)."""
    grid = torch.arange(0, 12).view(4, 3)
    self.assertEqual(grid[1:2, 1:3], grid[1:2, [1, 2]])
    self.assertEqual(grid[1:2, 1:3].tolist(), [[4, 5]])

    # Zeroing the result of advanced indexing must not touch `grid` (it is a copy).
    snapshot = grid.clone()
    grid[1:2, [1, 2]].zero_()
    self.assertEqual(grid, snapshot)

    # Assigning through the same index, however, does modify `grid`.
    snapshot = grid.clone()
    grid[1:2, [1, 2]] = 0
    self.assertNotEqual(grid, snapshot)
251 
def test_int_assignment(self):
    """Assigning to an int-indexed row broadcasts a scalar or copies a matching tensor."""
    cases = [
        (5, [[0, 1], [5, 5]]),
        (torch.arange(5, 7), [[0, 1], [5, 6]]),
    ]
    for value, expected in cases:
        grid = torch.arange(0, 4).view(2, 2)
        grid[1] = value
        self.assertEqual(grid.tolist(), expected)
260 
def test_byte_tensor_assignment(self):
    """Assigning through a uint8 row mask overwrites only the masked rows."""
    grid = torch.arange(0., 16).view(4, 4)
    rows = torch.ByteTensor([True, False, True, False])
    fill = torch.tensor([3., 4., 5., 6.])
    grid[rows] = fill
    expected_rows = [fill, torch.arange(4, 8), fill, torch.arange(12, 16)]
    for r, want in enumerate(expected_rows):
        self.assertEqual(grid[r], want)
270 
def test_variable_slicing(self):
    """Slice bounds may be elements unpacked from an IntTensor."""
    grid = torch.arange(0, 16).view(4, 4)
    start, stop = torch.IntTensor([0, 1])
    self.assertEqual(grid[start:stop], grid[0:1])
276 
def test_ellipsis_tensor(self):
    """A tensor index combined with Ellipsis selects along the matching dim."""
    grid = torch.arange(0, 9).view(3, 3)
    picks = torch.tensor([0, 2])
    self.assertEqual(grid[..., picks].tolist(), [[0, 2], [3, 5], [6, 8]])
    self.assertEqual(grid[picks, ...].tolist(), [[0, 1, 2], [6, 7, 8]])
285 
def test_invalid_index(self):
    """String slice bounds are rejected with a TypeError."""
    grid = torch.arange(0, 16).view(4, 4)
    with self.assertRaisesRegex(TypeError, 'slice indices'):
        grid["0":"1"]
289 
def test_out_of_bound_index(self):
    """Out-of-range int indices raise IndexError naming the index, dim and size."""
    cube = torch.arange(0, 100).view(2, 5, 10)
    cases = [
        ((0, 5), 'index 5 is out of bounds for dimension 1 with size 5'),
        ((4, 5), 'index 4 is out of bounds for dimension 0 with size 2'),
        ((0, 1, 15), 'index 15 is out of bounds for dimension 2 with size 10'),
        ((slice(None), slice(None), 12), 'index 12 is out of bounds for dimension 2 with size 10'),
    ]
    for idx, message in cases:
        with self.assertRaisesRegex(IndexError, message):
            cube[idx]
298 
def test_zero_dim_index(self):
    """A 0-dim tensor equals its ``item()`` but cannot be indexed with an int."""
    x = torch.tensor(10)
    self.assertEqual(x, x.item())

    # Only the indexing expression itself needs to raise; no helper or
    # debug output is required around it.
    with self.assertRaisesRegex(IndexError, 'invalid index'):
        x[0]
308 
309 
# The tests below are adapted from NumPy's test_indexing.py, with some modifications
# to make them compatible with PyTorch. That file is licensed under the BSD license below:
312 #
313 # Copyright (c) 2005-2017, NumPy Developers.
314 # All rights reserved.
315 #
316 # Redistribution and use in source and binary forms, with or without
317 # modification, are permitted provided that the following conditions are
318 # met:
319 #
320 # * Redistributions of source code must retain the above copyright
321 # notice, this list of conditions and the following disclaimer.
322 #
323 # * Redistributions in binary form must reproduce the above
324 # copyright notice, this list of conditions and the following
325 # disclaimer in the documentation and/or other materials provided
326 # with the distribution.
327 #
328 # * Neither the name of the NumPy Developers nor the names of any
329 # contributors may be used to endorse or promote products derived
330 # from this software without specific prior written permission.
331 #
332 # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
333 # "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
334 # LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
335 # A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
336 # OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
337 # SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
338 # LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
339 # DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
340 # THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
341 # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
342 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
343 
344 
def test_index_no_floats(self):
    """Float indices are rejected with IndexError in every position."""
    a = torch.tensor([[[5.]]])
    full = slice(None)
    for f in (0.0, -1.4):
        positions = (
            f, (0, f), (f, 0), (f, full), (full, f), (full, f, full),
            (f, full, full), (0, 0, f), (f, 0, 0), (0, f, 0),
        )
        for idx in positions:
            self.assertRaises(IndexError, a.__getitem__, idx)
    # NOTE: float slice bounds (e.g. ``a[0.0:, 0.0]``) are not checked here.
371 
def test_none_index(self):
    """`None` as an index adds a new axis."""
    vec = tensor([1, 2, 3])
    expanded = vec[None]
    self.assertEqual(expanded.dim(), vec.dim() + 1)
376 
def test_empty_tuple_index(self):
    """An empty tuple index returns an equal tensor sharing the same data pointer."""
    vec = tensor([1, 2, 3])
    view = vec[()]
    self.assertEqual(view, vec)
    self.assertEqual(view.data_ptr(), vec.data_ptr())
382 
def test_empty_fancy_index(self):
    """Empty index lists and empty long tensors select nothing; empty float tensors raise."""
    a = tensor([1, 2, 3])

    # Empty list index creates an empty array
    self.assertEqual(a[[]], torch.tensor([]))

    # Index with the empty long tensor itself so the tensor-index path is
    # exercised (indexing with `[]` again would only repeat the check above).
    b = tensor([]).long()
    self.assertEqual(a[b], torch.tensor([], dtype=torch.long))

    # Float tensors are not valid indices, even when empty.
    b = tensor([]).float()
    self.assertRaises(IndexError, lambda: a[b])
393 
def test_ellipsis_index(self):
    """Ellipsis returns a new object over the same data and stands in for any run of dims."""
    a = tensor([[1, 2, 3],
                [4, 5, 6],
                [7, 8, 9]])
    everything = a[...]
    self.assertIsNot(everything, a)
    self.assertEqual(everything, a)
    # `a[...]` was `a` in numpy <1.9.
    self.assertEqual(everything.data_ptr(), a.data_ptr())

    # Slicing with ellipsis can skip an arbitrary number of dimensions
    equivalences = [
        (a[0, ...], a[0]),
        (a[0, ...], a[0, :]),
        (a[..., 0], a[:, 0]),
    ]
    for lhs, rhs in equivalences:
        self.assertEqual(lhs, rhs)

    # In NumPy, slicing with ellipsis results in a 0-dim array. In PyTorch
    # we don't have separate 0-dim arrays and scalars.
    self.assertEqual(a[0, ..., 1], torch.tensor(2))

    # Assignment with `(Ellipsis,)` on 0-d arrays
    scalar = torch.tensor(1)
    scalar[(Ellipsis,)] = 2
    self.assertEqual(scalar, 2)
417 
def test_single_int_index(self):
    """A single int selects one row; huge or overflowing ints raise."""
    a = tensor([[1, 2, 3],
                [4, 5, 6],
                [7, 8, 9]])
    self.assertEqual(a[0], [1, 2, 3])
    self.assertEqual(a[-1], [7, 8, 9])

    # Index out of bounds produces IndexError
    with self.assertRaises(IndexError):
        a[1 << 30]
    # Index overflow produces Exception NB: different exception type
    with self.assertRaises(Exception):
        a[1 << 64]
431 
def test_single_bool_index(self):
    """`True` acts like `None` (new leading dim); `False` adds that dim with size 0."""
    a = tensor([[1, 2, 3],
                [4, 5, 6],
                [7, 8, 9]])
    with_new_axis = a[None]
    self.assertEqual(a[True], with_new_axis)
    self.assertEqual(a[False], with_new_axis[0:0])
440 
def test_boolean_shape_mismatch(self):
    """Masks whose shape does not match the indexed dims raise an IndexError mentioning 'mask'."""
    arr = torch.ones((5, 4, 3))
    bad_masks = [
        tensor([True]),
        tensor([False] * 6),
        torch.ByteTensor(4, 4).zero_(),
    ]
    for mask in bad_masks:
        with self.assertRaisesRegex(IndexError, 'mask'):
            arr[mask]

    # The 2-D byte mask is also rejected when it follows a leading full slice.
    with self.assertRaisesRegex(IndexError, 'mask'):
        arr[:, bad_masks[-1]]
454 
def test_boolean_indexing_onedim(self):
    """A length-1 bool mask on a 1-row 2-D tensor selects and assigns the whole row."""
    row = tensor([[0., 0., 0.]])
    keep = tensor([True])
    self.assertEqual(row[keep], row)
    # boolean assignment
    row[keep] = 1.
    self.assertEqual(row, tensor([[1., 1., 1.]]))
464 
def test_boolean_assignment_value_mismatch(self):
    """Mask assignment fails when the value cannot broadcast to the selection.

    See also gh-3458.
    """
    a = torch.arange(0, 4)

    def assign_all(target, values):
        target[target > -1] = tensor(values)

    for target, values in ((a, []), (a, [1, 2, 3]), (a[:1], [1, 2, 3])):
        self.assertRaisesRegex(Exception, 'shape mismatch', assign_all, target, values)
476 
def test_boolean_indexing_twodim(self):
    """A 2-D bool mask picks single elements; one of its rows used as a mask picks rows."""
    grid = tensor([[1, 2, 3],
                   [4, 5, 6],
                   [7, 8, 9]])
    pattern = tensor([[True, False, True],
                      [False, True, False],
                      [True, False, True]])
    self.assertEqual(grid[pattern], tensor([1, 3, 5, 7, 9]))
    first_row, middle_row, last_row = pattern
    self.assertEqual(grid[middle_row], tensor([[4, 5, 6]]))
    self.assertEqual(grid[first_row], grid[last_row])

    # boolean assignment
    grid[pattern] = 0
    self.assertEqual(grid, tensor([[0, 2, 0],
                                   [4, 0, 6],
                                   [0, 8, 0]]))
495 
def test_boolean_indexing_weirdness(self):
    """Python bools mixed with other indices: odd but well-defined results."""
    cube = torch.ones((2, 3, 4))
    self.assertEqual((0, 2, 3, 4), cube[False, True, ...].shape)
    self.assertEqual(torch.ones(1, 2), cube[True, [0, 1], True, True, [1], [[2]]])
    with self.assertRaises(IndexError):
        cube[False, [0, 1], ...]
502 
def test_boolean_indexing_weirdness_tensors(self):
    """The cases of test_boolean_indexing_weirdness, using 0-dim bool tensors as flags."""
    false = torch.tensor(False)
    true = torch.tensor(True)
    a = torch.ones((2, 3, 4))
    # Use the tensor flags in every assertion; Python-bool indexing is
    # already covered by test_boolean_indexing_weirdness.
    self.assertEqual((0, 2, 3, 4), a[false, true, ...].shape)
    self.assertEqual(torch.ones(1, 2), a[true, [0, 1], true, true, [1], [[2]]])
    self.assertRaises(IndexError, lambda: a[false, [0, 1], ...])
511 
def test_boolean_indexing_alldims(self):
    """Two true flags (Python bool or 0-dim bool tensor) add one leading dim of size 1."""
    true = torch.tensor(True)
    mat = torch.ones((2, 3))
    for flag in (True, true):
        self.assertEqual((1, 2, 3), mat[flag, flag].shape)
517 
def test_boolean_list_indexing(self):
    """Python lists of bools act as masks, alone or paired on two dims."""
    a = tensor([[1, 2, 3],
                [4, 5, 6],
                [7, 8, 9]])
    first_only = [True, False, False]
    first_two = [True, True, False]
    checks = [
        (a[first_only], tensor([[1, 2, 3]])),
        (a[first_only, first_only], tensor([1])),
        (a[first_two], tensor([[1, 2, 3], [4, 5, 6]])),
        (a[first_two, first_two], tensor([1, 5])),
    ]
    for got, want in checks:
        self.assertEqual(got, want)
530 
def test_everything_returns_views(self):
    """`()`, `...` and `:` indexing each return a new object, never `a` itself."""
    # Before `...` would return a itself.
    a = tensor([5])
    for result in (a[()], a[...], a[:]):
        self.assertIsNot(a, result)
538 
def test_broaderrors_indexing(self):
    """Index lists of incompatible lengths raise 'shape mismatch' on get and set."""
    a = torch.zeros(5, 5)
    mismatched = ([0, 1], [0, 1, 2])
    with self.assertRaisesRegex(IndexError, 'shape mismatch'):
        a[mismatched]
    with self.assertRaisesRegex(IndexError, 'shape mismatch'):
        a[mismatched] = 0
543 
def test_trivial_fancy_out_of_bounds(self):
    """One out-of-range entry anywhere in an index tensor makes get and set raise IndexError.

    `a` is always allocated on the CPU. The test is intentionally CPU-only
    because CUDA asserts instead of raising an exception.
    """
    a = torch.zeros(5)
    # Place the bad value at the last and then the first position of an
    # otherwise valid index tensor.
    for position, bad_value in ((-1, 10), (0, 11)):
        ind = torch.ones(20, dtype=torch.int64)
        ind[position] = bad_value
        self.assertRaises(IndexError, a.__getitem__, ind)
        self.assertRaises(IndexError, a.__setitem__, ind, 0)
556 
def test_index_is_larger(self):
    """A (3, 1) row index broadcasts against a (3,) column index on assignment."""
    a = torch.zeros((5, 5))
    rows = [[0], [1], [2]]
    a[rows, [0, 1, 2]] = tensor([2., 3., 4.])
    top_left = a[:3, :3]
    self.assertTrue(top_left.eq(tensor([2., 3., 4.])).all())
563 
def test_broadcast_subspace(self):
    """Assigning a column vector through a row index broadcasts across each row."""
    a = torch.zeros((100, 100))
    column = torch.arange(0., 100).unsqueeze(1)
    reversed_rows = torch.arange(99, -1, -1).long()
    a[reversed_rows] = column
    # Row reversed_rows[i] receives value i, which equals reversed_rows[i] here.
    expected = reversed_rows.double()[:, None].expand(100, 100)
    self.assertEqual(a, expected)
571 
572 
if __name__ == '__main__':
    # Hand control to the shared runner imported from common_utils.
    run_tests()
def assertEqual(self, x, y, prec=None, message='', allow_inf=False)
def is_available()
Definition: __init__.py:45
def assertNotEqual(self, x, y, prec=None, message='')