1 from common_utils
import TestCase, run_tests
4 from torch
import tensor
9 def test_single_int(self):
10 v = torch.randn(5, 7, 3)
13 def test_multiple_int(self):
14 v = torch.randn(5, 7, 3)
19 v = torch.randn(5, 7, 3)
22 self.
assertEqual(v[:,
None,
None].shape, (5, 1, 1, 7, 3))
33 def test_step_assignment(self):
39 def test_byte_mask(self):
40 v = torch.randn(5, 7, 3)
41 mask = torch.ByteTensor([1, 0, 1, 1, 0])
43 self.
assertEqual(v[mask], torch.stack([v[0], v[2], v[3]]))
48 def test_byte_mask_accumulate(self):
49 mask = torch.zeros(size=(10, ), dtype=torch.uint8)
50 y = torch.ones(size=(10, 10))
51 y.index_put_((mask, ), y[mask], accumulate=
True)
54 def test_multiple_byte_mask(self):
55 v = torch.randn(5, 7, 3)
57 mask1 = torch.ByteTensor([1, 0, 1, 1, 0])
58 mask2 = torch.ByteTensor([1, 1, 1])
61 def test_byte_mask2d(self):
62 v = torch.randn(5, 7, 3)
64 num_ones = (c > 0).sum()
68 def test_int_indices(self):
69 v = torch.randn(5, 7, 3)
72 self.
assertEqual(v[:, [[0, 1], [4, 3]]].shape, (5, 2, 2, 3))
74 def test_int_indices2d(self):
76 x = torch.arange(0, 12).view(4, 3)
79 self.
assertEqual(x[rows, columns].tolist(), [[0, 2], [9, 11]])
81 def test_int_indices_broadcast(self):
83 x = torch.arange(0, 12).view(4, 3)
86 result = x[rows[:,
None], columns]
87 self.
assertEqual(result.tolist(), [[0, 2], [9, 11]])
89 def test_empty_index(self):
90 x = torch.arange(0, 12).view(4, 3)
99 mask = torch.zeros(4, 3).byte()
103 def test_empty_ndim_index(self):
105 for device
in devices:
106 x = torch.randn(5, device=device)
107 self.
assertEqual(torch.empty(0, 2, device=device), x[torch.empty(0, 2, dtype=torch.int64, device=device)])
109 x = torch.randn(2, 3, 4, 5, device=device)
110 self.
assertEqual(torch.empty(2, 0, 6, 4, 5, device=device),
111 x[:, torch.empty(0, 6, dtype=torch.int64, device=device)])
113 x = torch.empty(10, 0)
119 def test_empty_ndim_index_bool(self):
121 for device
in devices:
122 x = torch.randn(5, device=device)
123 self.assertRaises(IndexError,
lambda: x[torch.empty(0, 2, dtype=torch.uint8, device=device)])
125 def test_empty_slice(self):
127 for device
in devices:
128 x = torch.randn(2, 3, 4, 5, device=device)
134 self.assertTrue(z.is_contiguous())
136 def test_index_getitem_copy_bools_slices(self):
144 self.
assertEqual(torch.empty(0, *a.shape), a[
False])
146 self.
assertEqual(torch.empty(0, *a.shape), a[false])
150 def test_index_setitem_bools_slices(self):
159 neg_ones = torch.ones_like(a) * -1
160 neg_ones_expanded = neg_ones.unsqueeze(0).unsqueeze(0)
161 a[
True] = neg_ones_expanded
165 a[true] = neg_ones_expanded * 2
169 a[
None] = neg_ones_expanded * 3
171 a[...] = neg_ones_expanded * 4
174 with self.assertRaises(IndexError):
175 a[:] = neg_ones_expanded * 5
177 def test_setitem_expansion_error(self):
179 a = torch.randn(2, 3)
181 a_expanded = a.expand(torch.Size([5, 1]) + a.size())
183 with self.assertRaises(RuntimeError):
185 with self.assertRaises(RuntimeError):
188 def test_getitem_scalars(self):
193 a = torch.randn(2, 3)
200 self.
assertEqual(a[0, 1].data_ptr(), a[zero, one].data_ptr())
201 self.
assertEqual(a[1].data_ptr(), a[one.int()].data_ptr())
202 self.
assertEqual(a[1].data_ptr(), a[one.short()].data_ptr())
206 with self.assertRaises(IndexError):
208 with self.assertRaises(IndexError):
212 def test_setitem_scalars(self):
216 a = torch.randn(2, 3)
217 a_set_with_number = a.clone()
218 a_set_with_scalar = a.clone()
221 a_set_with_number[0] = b
222 a_set_with_scalar[zero] = b
223 self.
assertEqual(a_set_with_number, a_set_with_scalar)
229 with self.assertRaises(IndexError):
231 with self.assertRaises(IndexError):
236 def test_basic_advanced_combined(self):
238 x = torch.arange(0, 12).view(4, 3)
243 unmodified = x.clone()
244 x[1:2, [1, 2]].zero_()
248 unmodified = x.clone()
252 def test_int_assignment(self):
253 x = torch.arange(0, 4).view(2, 2)
257 x = torch.arange(0, 4).view(2, 2)
258 x[1] = torch.arange(5, 7)
261 def test_byte_tensor_assignment(self):
262 x = torch.arange(0., 16).view(4, 4)
263 b = torch.ByteTensor([
True,
False,
True,
False])
271 def test_variable_slicing(self):
272 x = torch.arange(0, 16).view(4, 4)
273 indices = torch.IntTensor([0, 1])
277 def test_ellipsis_tensor(self):
278 x = torch.arange(0, 9).view(3, 3)
286 def test_invalid_index(self):
287 x = torch.arange(0, 16).view(4, 4)
290 def test_out_of_bound_index(self):
291 x = torch.arange(0, 100).view(2, 5, 10)
292 self.
assertRaisesRegex(IndexError,
'index 5 is out of bounds for dimension 1 with size 5',
lambda: x[0, 5])
293 self.
assertRaisesRegex(IndexError,
'index 4 is out of bounds for dimension 0 with size 2',
lambda: x[4, 5])
294 self.
assertRaisesRegex(IndexError,
'index 15 is out of bounds for dimension 2 with size 10',
296 self.
assertRaisesRegex(IndexError,
'index 12 is out of bounds for dimension 2 with size 10',
299 def test_zero_dim_index(self):
346 def test_index_no_floats(self):
349 self.assertRaises(IndexError,
lambda: a[0.0])
350 self.assertRaises(IndexError,
lambda: a[0, 0.0])
351 self.assertRaises(IndexError,
lambda: a[0.0, 0])
352 self.assertRaises(IndexError,
lambda: a[0.0, :])
353 self.assertRaises(IndexError,
lambda: a[:, 0.0])
354 self.assertRaises(IndexError,
lambda: a[:, 0.0, :])
355 self.assertRaises(IndexError,
lambda: a[0.0, :, :])
356 self.assertRaises(IndexError,
lambda: a[0, 0, 0.0])
357 self.assertRaises(IndexError,
lambda: a[0.0, 0, 0])
358 self.assertRaises(IndexError,
lambda: a[0, 0.0, 0])
359 self.assertRaises(IndexError,
lambda: a[-1.4])
360 self.assertRaises(IndexError,
lambda: a[0, -1.4])
361 self.assertRaises(IndexError,
lambda: a[-1.4, 0])
362 self.assertRaises(IndexError,
lambda: a[-1.4, :])
363 self.assertRaises(IndexError,
lambda: a[:, -1.4])
364 self.assertRaises(IndexError,
lambda: a[:, -1.4, :])
365 self.assertRaises(IndexError,
lambda: a[-1.4, :, :])
366 self.assertRaises(IndexError,
lambda: a[0, 0, -1.4])
367 self.assertRaises(IndexError,
lambda: a[-1.4, 0, 0])
368 self.assertRaises(IndexError,
lambda: a[0, -1.4, 0])
372 def test_none_index(self):
374 a = tensor([1, 2, 3])
377 def test_empty_tuple_index(self):
379 a = tensor([1, 2, 3])
383 def test_empty_fancy_index(self):
385 a = tensor([1, 2, 3])
388 b = tensor([]).long()
391 b = tensor([]).float()
392 self.assertRaises(IndexError,
lambda: a[b])
394 def test_ellipsis_index(self):
395 a = tensor([[1, 2, 3],
398 self.assertIsNot(a[...], a)
418 def test_single_int_index(self):
420 a = tensor([[1, 2, 3],
428 self.assertRaises(IndexError, a.__getitem__, 1 << 30)
430 self.assertRaises(Exception, a.__getitem__, 1 << 64)
432 def test_single_bool_index(self):
434 a = tensor([[1, 2, 3],
441 def test_boolean_shape_mismatch(self):
442 arr = torch.ones((5, 4, 3))
444 index = tensor([
True])
447 index = tensor([
False] * 6)
450 index = torch.ByteTensor(4, 4).zero_()
455 def test_boolean_indexing_onedim(self):
458 a = tensor([[0., 0., 0.]])
465 def test_boolean_assignment_value_mismatch(self):
468 a = torch.arange(0, 4)
471 a[a > -1] = tensor(v)
477 def test_boolean_indexing_twodim(self):
480 a = tensor([[1, 2, 3],
483 b = tensor([[
True,
False,
True],
484 [
False,
True,
False],
485 [
True,
False,
True]])
496 def test_boolean_indexing_weirdness(self):
498 a = torch.ones((2, 3, 4))
499 self.
assertEqual((0, 2, 3, 4), a[
False,
True, ...].shape)
500 self.
assertEqual(torch.ones(1, 2), a[
True, [0, 1],
True,
True, [1], [[2]]])
501 self.assertRaises(IndexError,
lambda: a[
False, [0, 1], ...])
503 def test_boolean_indexing_weirdness_tensors(self):
507 a = torch.ones((2, 3, 4))
508 self.
assertEqual((0, 2, 3, 4), a[
False,
True, ...].shape)
509 self.
assertEqual(torch.ones(1, 2), a[true, [0, 1], true, true, [1], [[2]]])
510 self.assertRaises(IndexError,
lambda: a[false, [0, 1], ...])
512 def test_boolean_indexing_alldims(self):
514 a = torch.ones((2, 3))
518 def test_boolean_list_indexing(self):
521 a = tensor([[1, 2, 3],
524 b = [
True,
False,
False]
525 c = [
True,
True,
False]
528 self.
assertEqual(a[c], tensor([[1, 2, 3], [4, 5, 6]]))
531 def test_everything_returns_views(self):
535 self.assertIsNot(a, a[()])
536 self.assertIsNot(a, a[...])
537 self.assertIsNot(a, a[:])
539 def test_broaderrors_indexing(self):
540 a = torch.zeros(5, 5)
541 self.
assertRaisesRegex(IndexError,
'shape mismatch', a.__getitem__, ([0, 1], [0, 1, 2]))
542 self.
assertRaisesRegex(IndexError,
'shape mismatch', a.__setitem__, ([0, 1], [0, 1, 2]), 0)
544 def test_trivial_fancy_out_of_bounds(self):
546 ind = torch.ones(20, dtype=torch.int64)
548 raise unittest.SkipTest(
'CUDA asserts instead of raising an exception')
550 self.assertRaises(IndexError, a.__getitem__, ind)
551 self.assertRaises(IndexError, a.__setitem__, ind, 0)
552 ind = torch.ones(20, dtype=torch.int64)
554 self.assertRaises(IndexError, a.__getitem__, ind)
555 self.assertRaises(IndexError, a.__setitem__, ind, 0)
557 def test_index_is_larger(self):
559 a = torch.zeros((5, 5))
560 a[[[0], [1], [2]], [0, 1, 2]] = tensor([2., 3., 4.])
562 self.assertTrue((a[:3, :3] == tensor([2., 3., 4.])).all())
564 def test_broadcast_subspace(self):
565 a = torch.zeros((100, 100))
566 v = torch.arange(0., 100)[:,
None]
567 b = torch.arange(99, -1, -1).long()
569 expected = b.double().unsqueeze(1).expand(100, 100)
573 if __name__ ==
'__main__':
def assertEqual(self, x, y, prec=None, message='', allow_inf=False)
def assertNotEqual(self, x, y, prec=None, message='')