from copy import deepcopy
from collections import OrderedDict
from itertools import product
from operator import mul, itemgetter
from functools import reduce, wraps
from common_utils import (TEST_MKL, TestCase, run_tests, skipIfNoLapack,
                          suppress_warnings, skipIfRocm,
                          prod_single_zero, random_square_matrix_of_rank,
                          random_symmetric_matrix, random_symmetric_psd_matrix,
                          random_symmetric_pd_matrix, make_nonzero_det,
                          random_fullrank_matrix_distinct_singular_value, load_tests)
from common_cuda import TEST_CUDA
from common_methods_invocations import (method_tests,
                                        create_input, unpack_variables,
                                        EXCLUDE_FUNCTIONAL, EXCLUDE_GRADCHECK,
                                        EXCLUDE_GRADGRADCHECK,
                                        EXCLUDE_GRADGRADCHECK_BY_TEST_NAME,
                                        exclude_tensor_method)

# re-export so unittest's load_tests protocol picks up the sharding filter from common_utils
load_tests = load_tests

if sys.version_info[0] == 2:
    import cPickle as pickle
@contextlib.contextmanager
def backward_engine(engine):
    _prev_engine = Variable._execution_engine
    Variable._execution_engine = engine()
    try:
        yield
    finally:
        Variable._execution_engine = _prev_engine


def graph_desc(fn):
    result = type(fn).__name__ + '('
    next_functions = fn.next_functions
    for next_fn, _ in next_functions:
        result += graph_desc(next_fn)
def _function_test(self, cls):
    x = torch.randn(5, 5, requires_grad=True)
    y = torch.randn(5, 5, requires_grad=True)
    result = cls.apply(x, 2, y)
    go = torch.ones((), requires_grad=True)
    result.sum().backward(go, create_graph=True)

    self.assertEqual(x.grad.data, y.data + torch.ones(5, 5))
    self.assertEqual(y.grad.data, x.data + torch.ones(5, 5) * 2)
    self.assertIsNotNone(x.grad.grad_fn)
    self.assertIsNotNone(y.grad.grad_fn)

    return x, y
def test_function(self):
    class MyFunction(Function):

        @staticmethod
        def forward(ctx, tensor1, pyscalar, tensor2):
            ctx.pyscalar = pyscalar
            ctx.save_for_backward(tensor1, tensor2)
            return tensor1 + pyscalar * tensor2 + tensor1 * tensor2

        @staticmethod
        def backward(ctx, grad_output):
            var1, var2 = ctx.saved_tensors
            self.assertIsInstance(var1, torch.Tensor)
            self.assertIsInstance(var2, torch.Tensor)
            self.assertIsInstance(grad_output, torch.Tensor)
            return (grad_output + grad_output * var2, None,
                    grad_output * ctx.pyscalar + grad_output * var1)

    x, y = self._function_test(MyFunction)

    x_grad_desc = graph_desc(x.grad.grad_fn)
    y_grad_desc = graph_desc(y.grad.grad_fn)
    self.assertExpected(x_grad_desc, "x_grad_desc")
    self.assertExpected(y_grad_desc, "y_grad_desc")
def test_once_differentiable(self):
    class MyFunction(Function):

        @staticmethod
        def forward(ctx, tensor1, pyscalar, tensor2):
            ctx.pyscalar = pyscalar
            ctx.save_for_backward(tensor1, tensor2)
            return tensor1 + pyscalar * tensor2 + tensor1 * tensor2

        @staticmethod
        @once_differentiable
        def backward(ctx, grad_output):
            self.assertFalse(torch.is_grad_enabled())
            t1, t2 = ctx.saved_tensors
            return (grad_output + grad_output * t2, None,
                    grad_output * ctx.pyscalar + grad_output * t1)

    x, y = self._function_test(MyFunction)
    self.assertEqual(graph_desc(x.grad.grad_fn),
                     'CloneBackward(Error(AccumulateGrad(), None, AccumulateGrad()))')
    self.assertEqual(graph_desc(y.grad.grad_fn),
                     'CloneBackward(Error(AccumulateGrad(), None, AccumulateGrad()))')
def test_function_returns_input(self):
    class MyFunction(Function):
        @staticmethod
        def forward(ctx, x):
            return x

        @staticmethod
        def backward(ctx, grad):
            return grad * 2

    for shape in [(1,), ()]:
        v = torch.ones(shape, requires_grad=True)
        MyFunction.apply(v).backward()
        self.assertEqual(v.grad, torch.full(shape, 2))

        with torch.no_grad():
            v.grad.zero_()
        MyFunction.apply(v.clone()).backward()
        self.assertEqual(v.grad, torch.full(shape, 2))
def test_legacy_function_none_grad(self):
    class MyFunction(Function):
        def forward(self, x):
            return torch.zeros(2, 2, 2)

        def backward(self, grad_output):
            return None

    shape = (2, 3)
    v = torch.ones(shape, requires_grad=True)
    y = v[0, 0].expand(3, 5).t().sum()
    MyFunction()(y).sum().backward()
    self.assertEqual(v.grad.data, torch.zeros(shape))
def test_invalid_gradients(self):
    class MyFunction(Function):
        @staticmethod
        def backward(ctx, grad_output):
            return torch.randn(10, dtype=torch.float)

    with self.assertRaisesRegex(RuntimeError, 'expected shape'):
        input = torch.randn(5, 5, dtype=torch.float, requires_grad=True)
        MyFunction.apply(input).sum().backward()
    with self.assertRaisesRegex(RuntimeError, 'expected type'):
        input = torch.randn(10, dtype=torch.double, requires_grad=True)
        MyFunction.apply(input).sum().backward()
def test_accumulate_grad(self):
    grad_output = torch.ones(5, 5)

    def compute_grad(create_graph):
        x = torch.randn(5, 5, requires_grad=True)
        y = x + 2
        y.backward(grad_output, retain_graph=True)
        x_grad = x.grad
        x_grad_clone = x.grad.clone()
        y.backward(grad_output, create_graph=create_graph)
        return x_grad, x_grad_clone

    # gradients accumulate in-place when create_graph is False
    x_grad, x_grad_clone = compute_grad(create_graph=False)
    self.assertEqual(x_grad, x_grad_clone * 2)

    # gradients accumulate out-of-place when create_graph is True
    x_grad, x_grad_clone = compute_grad(create_graph=True)
    self.assertEqual(x_grad, x_grad_clone)
def test_slogdet_sign(self):
    a = torch.randn(3, 3, requires_grad=True)
    s, logdet = a.slogdet()

    # the sign output should not require grad
    self.assertFalse(s.requires_grad)

    def sign_mul_logdet(mat):
        s, logdet = mat.slogdet()
        return s * logdet

    u, s, v = a.detach().svd()
    s.abs_().clamp_(0.0001)
    mat = torch.chain_matmul(u, s.diag(), v.t()).requires_grad_()
    gradcheck(sign_mul_logdet, mat)
    gradgradcheck(sign_mul_logdet, mat)
def test_sum_to_with_empty_dim_grad(self):
    a = torch.rand(4, 0, requires_grad=True)
    b = torch.rand(4, 1, requires_grad=True)
    c = a + b
    assert c.shape == (4, 0)
    c.sum().backward()

    self.assertEqual(b.grad, torch.zeros(4, 1))
    self.assertEqual(a.grad, torch.zeros(4, 0))
def test_hessian_vector(self):
    x = torch.randn(2, 2, requires_grad=True)
    y = torch.randn(2, 2, requires_grad=True)

    z = x ** 2 + y * x + y ** 2
    z.backward(torch.ones(2, 2), create_graph=True)

    x_grad = 2 * x.data + y.data
    y_grad = x.data + 2 * y.data
    self.assertEqual(x.grad.data, x_grad)
    self.assertEqual(y.grad.data, y_grad)

    grad_sum = 2 * x.grad + y.grad
    grad_sum.backward(torch.ones(2, 2))
    x_hv = torch.ones(2, 2) * 5
    y_hv = torch.ones(2, 2) * 4
    self.assertEqual(x.grad.data, x_grad + x_hv)
    self.assertEqual(y.grad.data, y_grad + y_hv)

def test_grad(self):
    x = torch.randn(2, 2, requires_grad=True)
    y = torch.randn(2, 2, requires_grad=True)
    z = x ** 2 + y * x + y ** 2
    z.backward(torch.ones(2, 2), create_graph=True)

    x_grad = 2 * x.data + y.data
    y_grad = x.data + 2 * y.data
    self.assertEqual(x.grad.data, x_grad)
    self.assertEqual(y.grad.data, y_grad)

    grad_sum = 2 * x.grad + y.grad
    x_hv = torch.autograd.grad(
        outputs=[grad_sum], grad_outputs=[torch.ones(2, 2)],
        inputs=[x], create_graph=True)
    expected_x_hv = torch.ones(2, 2) * 5
    expected_y_hv = torch.ones(2, 2) * 4

    self.assertEqual(x_hv[0].data, expected_x_hv)
    self.assertEqual(x.grad.data, x_grad)
    self.assertEqual(y.grad.data, y_grad)
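# Illustrative sketch, not part of the original test file: the double-backward
# pattern used above, via torch.autograd.grad with create_graph=True so the
# gradient itself is differentiable. The helper name is hypothetical.
def _sketch_hessian_vector_product():
    import torch

    x = torch.randn(2, 2, requires_grad=True)
    y = (x ** 3).sum()
    (grad_x,) = torch.autograd.grad(y, x, create_graph=True)  # 3 * x ** 2
    v = torch.ones_like(x)
    (hvp,) = torch.autograd.grad(grad_x, x, grad_outputs=v)
    return hvp  # equals 6 * x for this cubic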
def test_grad_nonleaf(self):
    x_init = torch.randn(2, 2, requires_grad=True)
    x = x_init
    y = torch.randn(2, 2, requires_grad=True)
    grad_output = torch.ones(2, 2)

    def fn(x):
        return x ** 2 + y * x + y ** 2

    grad_x, = torch.autograd.grad(
        fn(x), x, grad_outputs=grad_output, create_graph=True)

    grad_x_expected = 2 * x.data + y.data
    self.assertIsNone(y.grad)
    self.assertIsNone(x.grad)
    self.assertEqual(grad_x.data, grad_x_expected)

    x = x + 0.05 * grad_x

    val_init = fn(x_init).data.sum()
    val_final = fn(x).data.sum()
    self.assertGreater(val_final, val_init)

    x.backward(grad_output)
    self.assertIsNotNone(y.grad)
    self.assertIsNotNone(x_init.grad)
def test_grad_nonleaf_many_outputs(self):
    x = torch.randn(4, 2, requires_grad=True)
    a, b = x.chunk(2)

    def hook(*grads):
        hook_called[0] = True
    hook_called = [False]
    x.register_hook(hook)

    go = torch.randn(2, 2)
    grad_a, grad_b = torch.autograd.grad(
        (a + 2 * b), [a, b], grad_outputs=go, create_graph=True)

    self.assertEqual(grad_a.data, go)
    self.assertEqual(grad_b.data, go * 2)
    self.assertFalse(hook_called[0])
    self.assertIsNone(x.grad)
def test_grad_nonleaf_register_hook(self):
    x = torch.randn(5, requires_grad=True)
    x_list = x.unbind()

    x0 = x_list[0]
    hook_results = [None]

    def hook(grad):
        hook_results[0] = grad
    x0.register_hook(hook)

    self.assertEqual(x.grad, expected_grad)
    self.assertIsNone(x_list[0].grad)

    for i in range(1, 5, 1):
        x_list[i].backward()
        self.assertEqual(hook_results[0], None)
        expected_grad[i] = 1.0
        self.assertEqual(x.grad, expected_grad)
        self.assertIsNone(x_list[i].grad)
def test_sharded_grad(self):
    leaves = [torch.zeros(5, 5, requires_grad=True) for _ in range(10)]
    intermediates = [l * i + l * l for i, l in enumerate(leaves)]
    loss = sum(v * i for i, v in enumerate(intermediates)).sum()

    # helper for dividing intermediates into groups
    def group(l, group_size):
        return (l[i:i + group_size] for i in range(0, len(l), group_size))

    # compute d loss / d intermediates in shards
    d_intermediates = [d_i for intermediates_batch in group(intermediates, shard_size)
                       for d_i in torch.autograd.grad(loss, intermediates_batch)]

    for i, l in enumerate(leaves):
        self.assertEqual(l.grad.data, i * i * (1 + l.data))
def test_backward_badcalls(self):
    x = torch.ones(1)
    with self.assertRaisesRegex(RuntimeError, 'does not require grad'):
        x.backward()

def test_grad_badcalls(self):
    x = torch.ones(1)
    y = x ** 2
    with self.assertRaisesRegex(RuntimeError, 'does not require grad'):
        torch.autograd.grad(x, y)
    with self.assertRaisesRegex(RuntimeError, 'does not require grad'):
        torch.autograd.grad(y, x)

    x = torch.ones(1, requires_grad=True)
    y = x ** 2
    torch.autograd.grad(y, x)  # this should succeed now

def test_grad_fn_badcalls(self):
    error_regex = 'expected .* arguments, got .* instead'
    x = torch.ones(1, requires_grad=True)
    y = x ** 2
    with self.assertRaisesRegex(TypeError, error_regex):
        y.grad_fn(x.detach(), x.detach())  # too many
    with self.assertRaisesRegex(TypeError, error_regex):
        y.grad_fn()  # too few

    y.grad_fn(x.detach())  # this should succeed
def test_grad_unreachable(self):
    x = torch.ones(1, requires_grad=True)
    y = torch.ones(1, requires_grad=True)
    grad_x, grad_y = torch.autograd.grad(x * 2, [x, y], allow_unused=True)
    self.assertEqual(grad_x, x * 2)
    self.assertIsNone(grad_y)

    # z does not even have a grad accumulator allocated
    z = torch.ones(1, requires_grad=True)
    grad_x, grad_z = torch.autograd.grad(x * 2, [x, z], allow_unused=True)
    self.assertEqual(grad_x, x * 2)
    self.assertIsNone(grad_z)
def test_hooks(self):
    x = torch.ones(5, 5, requires_grad=True)
    y = Variable(torch.ones(5, 5) * 4, requires_grad=True)

    counter = [0]

    def bw_hook(inc, grad):
        self.assertIsInstance(grad, torch.Tensor)
        counter[0] += inc

    z = x ** 2 + x * 2 + x * y + y
    x.register_hook(lambda *args: bw_hook(0, *args))
    test = z.register_hook(lambda *args: bw_hook(1, *args))
    z.backward(torch.ones(5, 5), retain_graph=True)
    self.assertEqual(counter[0], 1)

    test2 = z.register_hook(lambda *args: bw_hook(2, *args))
    z.backward(torch.ones(5, 5), retain_graph=True)
    self.assertEqual(counter[0], 4)

    test2.remove()
    z.backward(torch.ones(5, 5), retain_graph=True)
    self.assertEqual(counter[0], 5)

    def bw_hook_modify(grad):
        return grad.mul(2)

    z.register_hook(bw_hook_modify)
    z.backward(torch.ones(5, 5), retain_graph=True)
    self.assertEqual(y.grad.data, (x.data + 1) * 2)

    y.register_hook(bw_hook_modify)
    z.backward(torch.ones(5, 5))
    self.assertEqual(y.grad.data, (x.data + 1) * 4)
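# Illustrative sketch, not part of the original test file: tensor backward
# hooks as exercised above; register_hook returns a handle whose remove()
# detaches the hook. The helper name is hypothetical.
def _sketch_tensor_hook():
    import torch

    x = torch.ones(3, requires_grad=True)
    seen = []
    handle = x.register_hook(lambda grad: seen.append(grad.clone()))
    (x * 2).sum().backward()
    handle.remove()               # later backward passes skip the hook
    (x * 3).sum().backward()
    return len(seen)              # 1: only the first backward was observed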
def test_hooks_cpp(self):
    # Tests hooks for autograd functions implemented in C++
    bn = torch.nn.BatchNorm1d(5, affine=False)
    bn.eval()

    counter = [0]

    def bw_hook(grad):
        counter[0] += 1
        return grad * 2

    x = torch.ones(5, 5, requires_grad=True)
    z = bn(x)
    z.register_hook(bw_hook)
    z.sum().backward()

    self.assertEqual(counter[0], 1, 'bw_hook not called')
    self.assertEqual(x.grad.data, torch.ones(5, 5) * 2)
def test_hook_none(self):
    class NoneGradientFunction(Function):
        def forward(self, x, y):
            assert self.needs_input_grad[0]
            assert not self.needs_input_grad[1]
            return x, y

        def backward(self, grad_x, grad_y):
            return grad_x, None

    fn = NoneGradientFunction()
    was_called = [False]

    def hook(grad_input, grad_output):
        self.assertIsInstance(grad_input, tuple)
        self.assertIsInstance(grad_output, tuple)
        self.assertIsNotNone(grad_input[0])
        self.assertIsNotNone(grad_input[1])
        self.assertIsNotNone(grad_output[0])
        self.assertIsNotNone(grad_output[1])
        was_called[0] = True
    fn.register_hook(hook)

    x = torch.randn(5, 5, requires_grad=True)
    y = torch.randn(5, 5)
    sum(fn(x, y)).sum().backward()
    self.assertTrue(was_called[0])
def test_retain_grad(self):
    input = torch.rand(1, 3, requires_grad=True)
    h1 = input * 3
    out = (h1 * h1).sum()

    # retain_grad() makes autograd populate .grad on the non-leaf h1
    h1.retain_grad()

    # gradients accumulate across repeated backward calls
    out.backward(retain_graph=True)
    self.assertEqual(h1.data * 2, h1.grad.data)
    out.backward(retain_graph=True)
    self.assertEqual(h1.data * 4, h1.grad.data)

    input.grad.data.zero_()
    # retain_grad() is a no-op for leaves
    input.retain_grad()
    out.backward(retain_graph=True)
    self.assertEqual(input.data * 18, input.grad.data)
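# Illustrative sketch, not part of the original test file: retain_grad() as
# exercised above, which makes autograd keep .grad on a non-leaf tensor.
# The helper name is hypothetical.
def _sketch_retain_grad():
    import torch

    x = torch.randn(3, requires_grad=True)
    h = x * 2          # non-leaf: .grad is normally not populated
    h.retain_grad()
    h.sum().backward()
    return h.grad      # torch.ones(3) instead of None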
def test_retain_grad_cycle(self):
    x = torch.ones(5, 5, requires_grad=True)

    refs[0] = weakref.ref(y, inc)

    self.assertIsNone(refs[0]())
    self.assertEqual(counter[0], 1)
def test_backward(self):
    v_t = torch.randn(5, 5)
    x_t = torch.randn(5, 5)
    y_t = torch.rand(5, 5) + 0.1
    z_t = torch.randn(5, 5)
    grad_output = torch.randn(5, 5)
    v = Variable(v_t, requires_grad=True)
    x = Variable(x_t, requires_grad=True)
    y = Variable(y_t, requires_grad=True)
    z = Variable(z_t, requires_grad=True)

    v.backward(grad_output)
    self.assertEqual(v.grad.data, grad_output)

    a = x + (y * z) + 4 * z ** 2 * x / y
    a.backward(grad_output)
    x_grad = 4 * z_t.pow(2) / y_t + 1
    y_grad = z_t - 4 * x_t * z_t.pow(2) / y_t.pow(2)
    z_grad = 8 * x_t * z_t / y_t + y_t
    self.assertEqual(x.grad.data, x_grad * grad_output)
    self.assertEqual(y.grad.data, y_grad * grad_output)
    self.assertEqual(z.grad.data, z_grad * grad_output)
def test_sparse_backward(self):
    class FixedGradientFunction(Function):
        def __init__(self, grad):
            self.grad = grad

        def forward(self, x):
            return x

        def backward(self, grad_x):
            return self.grad

    size = torch.Size([6, 3, 2])
    i1 = torch.LongTensor([
    v1 = torch.DoubleTensor([[1, 2], [4, 5], [7, 8]])
    sparse_grad1 = torch.sparse.DoubleTensor(i1, v1, size)
    i2 = torch.LongTensor([
    v2 = torch.DoubleTensor([[1, 2], [4, 3], [4, 5], [7, 8]])
    sparse_grad2 = torch.sparse.DoubleTensor(i2, v2, size)
    dense_grad = torch.rand(size).double()
    sparse_fn1 = FixedGradientFunction(sparse_grad1)
    sparse_fn2 = FixedGradientFunction(sparse_grad2)
    dense_fn = FixedGradientFunction(dense_grad)

    # sparse first
    x = torch.randn(size, requires_grad=True)
    (sparse_fn1(x) + dense_fn(x) + sparse_fn2(x)).sum().backward()
    self.assertEqual(x.grad, dense_grad + sparse_grad1 + sparse_grad2)
    # dense first
    x = torch.randn(size, requires_grad=True)
    (dense_fn(x) + sparse_fn1(x) + sparse_fn2(x)).sum().backward()
    self.assertEqual(x.grad, dense_grad + sparse_grad1 + sparse_grad2)
    # sparse only
    x = torch.randn(size, requires_grad=True)
    (sparse_fn1(x) + sparse_fn2(x)).sum().backward()
    self.assertEqual(x.grad, sparse_grad1 + sparse_grad2)
def test_sparse_mm_backward(self):
    sparse = torch.sparse_coo_tensor(size, requires_grad=True)
    dense = torch.randn(size, requires_grad=True)

    z = sparse.mm(dense)
    with self.assertRaisesRegex(RuntimeError,
                                "calculating the gradient of a sparse Tensor argument to mm is not supported."):
        z.sum().backward()

    z = dense.addmm(sparse, dense)
    with self.assertRaisesRegex(RuntimeError,
                                "calculating the gradient of a sparse Tensor argument to mm is not supported."):
        z.sum().backward()
def test_sparse_ctor_getter_backward(self):
    def test(size, sparse_dim, nnz, device):
        v_size = [nnz] + list(size[sparse_dim:])
        i = torch.rand(sparse_dim, nnz)
        i.mul_(torch.tensor(size[:sparse_dim]).unsqueeze(1).to(i))

        inp = torch.randn(v_size, requires_grad=True)
        other = self.genSparseTensor(size, sparse_dim, nnz, is_uncoalesced=True)[0]
        other = other.to(device)

        def fn(v):
            x = torch.sparse_coo_tensor(i, v, size, device=device)
            y = (x + other).coalesce()
            z = torch.sparse_coo_tensor(y.indices(), new_v, y.size())
            return z.coalesce().values()

        gradcheck(fn, (inp,))

        with self.assertRaisesRegex(RuntimeError, "does not have a grad_fn"):
            other.detach().requires_grad_()._values().backward(torch.ones_like(other._values()))

    devices = ['cpu']
    if torch.cuda.is_available():
        devices.append('cuda')

    for empty_i, empty_v, empty_nnz in product([True, False], repeat=3):
        sparse_size = [] if empty_i else [2, 1]
        dense_size = [1, 0, 2] if empty_v else [1, 2]
        nnz = 0 if empty_nnz else 5
        for device in devices:
            test(sparse_size + dense_size, len(sparse_size), nnz, device)
def test_multi_backward(self):
    x = torch.randn(5, 5, requires_grad=True)
    y = torch.randn(5, 5, requires_grad=True)

    q = torch.randn(5, 5, requires_grad=True)

    a = torch.randn(5, 5, requires_grad=True)
    b = torch.randn(5, 5, requires_grad=True)

    q2 = q * 2
    z = x + y + q2
    c = a * b + q2
    grad_z = torch.randn(5, 5)
    grad_c = torch.randn(5, 5)
    torch.autograd.backward([z, c], [grad_z, grad_c])

    self.assertEqual(x.grad.data, grad_z)
    self.assertEqual(y.grad.data, grad_z)
    self.assertEqual(a.grad.data, grad_c * b.data)
    self.assertEqual(b.grad.data, grad_c * a.data)
    self.assertEqual(q.grad.data, (grad_c + grad_z) * 2)
def test_multi_backward_no_grad(self):
    x = torch.randn(5, 5, requires_grad=True)
    y = torch.randn(5, 5, requires_grad=False)

    z = x + y
    q = y * 2

    def call_backwards():
        torch.autograd.backward([z, q], [torch.ones(5, 5), torch.ones(5, 5)])
    self.assertRaises(RuntimeError, call_backwards)
def test_dependent_backward(self):
    x = torch.randn(10, requires_grad=True)
    y = x ** 2
    z = y ** 3

    go_y = torch.randn(10)
    go_z = torch.randn(10)
    torch.autograd.backward([y, z], [go_y, go_z])

    xd = x.data
    self.assertEqual(x.grad.data, 2 * xd * go_y + 6 * xd.pow(5) * go_z)
def test_save_output_nr(self):
    x = torch.randn(10, requires_grad=True)

    class MultiOutputFn(Function):
        @staticmethod
        def backward(ctx, *grad):
            return torch.cat(grad)

    a, b = MultiOutputFn.apply(x)
    self.assertEqual(b.output_nr, 1)

    class TestFn(Function):
        @staticmethod
        def forward(ctx, b):
            ctx.save_for_backward(b)

        @staticmethod
        def backward(ctx, grad_b):
            b, = ctx.saved_tensors
            self.assertEqual(b.output_nr, 1)

    TestFn.apply(b).sum().backward()
def test_free_deep_graph(self):
    x = torch.randn(1, requires_grad=True)
    y = x.clone()

    # build a long chain graph
    for _ in range(depth):
        y = y + y * 0.000001

def test_free_deep_graph_complicated(self):
    randchoice = torch.randint(2, [depth, 2])
    x = torch.randn(1, requires_grad=True)
    y = x.clone()

    prev_values = [None, None]

    for _ in range(depth):
        prev_tensors = [tensor for tensor in prev_values[:-1]
                        if tensor is not None]
        prev_values.append(y)

        nprev = len(prev_tensors)
        y += randchoice[depth].mul(torch.cat(prev_tensors)).sum()
def test_free_deep_graph_pyfunction(self):
    class MyOp(Function):
        @staticmethod
        def forward(ctx, tensor1, tensor2):
            return tensor1 + tensor2

        @staticmethod
        def backward(ctx, grad_output):
            return grad_output, grad_output

    x = torch.randn(1, requires_grad=True)
    y = x.clone()

    # build a deeply nested graph of Python functions
    for _ in range(depth):
        y = MyOp.apply(y, y)

@unittest.skipIf(not TEST_CUDA, "need CUDA memory stats")
def test_free_unneeded_tensor(self):
    x = torch.randn(2, 3, 10, 10, device='cuda', requires_grad=True)
    m = torch.randn(1, 3, 1, 1, device='cuda')

    z = ((x + 2) * m).sum()
    base_mem = torch.cuda.memory_allocated()
    z = ((x + 2) * m).sum()
    end_mem = torch.cuda.memory_allocated()

    self.assertEqual(base_mem, end_mem)

def test_no_unnecessary_save(self):
    mu = torch.ones(1, requires_grad=True)
def test_no_grad(self):
    x = torch.ones(5, 5, requires_grad=True)
    y = Variable(torch.ones(5, 5) * 4)
    with torch.no_grad():
        w = x + y

    self.assertFalse(w.requires_grad)
    self.assertRaises(RuntimeError, lambda: w.backward(torch.ones(5, 5)))
    self.assertIsNone(w.grad_fn)
    self.assertFalse(z.requires_grad)
    self.assertRaises(RuntimeError, lambda: z.backward(torch.ones(5, 5)))
    self.assertIsNone(z.grad_fn)

    with torch.no_grad():
        self.assertFalse(torch.is_grad_enabled())
        self.assertFalse(torch.is_grad_enabled())

def test_no_grad_python_function(self):
    """Python Functions should respect grad mode."""
    x = torch.ones(5, 5, requires_grad=True)

    class MyOp(Function):
        def forward(self, x):
            return x + 1

        def backward(self, dy):
            return dy

    with torch.no_grad():
        y = MyOp.apply(x)
    self.assertFalse(y.requires_grad)
def test_indexing(self):
    x = torch.arange(1., 17).view(4, 4)
    y = Variable(x, requires_grad=True)

    def compare(x, y, idx, indexed_tensor, indexed_var):
        indexed_var_t = indexed_var.data
        if not isinstance(indexed_tensor, torch.Tensor):
            indexed_var_t = indexed_var_t[0]
        self.assertEqual(indexed_tensor, indexed_var_t)

        indexed_var.sum().backward()
        expected_grad = torch.Tensor(x.size()).fill_(0)
        expected_grad[idx] = 1
        self.assertEqual(y.grad.data, expected_grad)

    def check_index(x, y, idx):
        if y.grad is not None:
            y.grad.data.zero_()
        indexed_tensor = x[idx]
        indexed_var = y[idx]
        compare(x, y, idx, indexed_tensor, indexed_var)

    check_index(x, y, (1, 1))
    check_index(x, y, slice(1, None))
    check_index(x, y, slice(None, 2))
    check_index(x, y, (slice(None, 2), 2))
    check_index(x, y, (slice(1, 2), 2))
    check_index(x, y, (1, slice(2, None)))
    check_index(x, y, (slice(None, None), slice(2, None)))
    check_index(x, y, torch.LongTensor([0, 2]))
    check_index(x, y, torch.rand(4, 4).bernoulli().byte())
    check_index(x, y, (Ellipsis, slice(2, None)))
    check_index(x, y, ([0], [0]))
    check_index(x, y, ([1, 2, 3], [0]))
    check_index(x, y, ([1, 2], [2, 1]))
    check_index(x, y, ([[1, 2], [3, 0]], [[0, 1], [2, 3]]))
    check_index(x, y, ([slice(None), [2, 3]]))
    check_index(x, y, ([[2, 3], slice(None)]))

    # advanced indexing, with fewer dims or an ellipsis
    check_index(x, y, ([0]))
    check_index(x, y, ([0], ))

    x = torch.arange(1., 49).view(4, 3, 4)
    y = Variable(x, requires_grad=True)

    check_index(x, y, (slice(None), [0], [0]))
    check_index(x, y, ([0], [0], slice(None)))
    check_index(x, y, (slice(None), [0, 1, 2], [0]))
    check_index(x, y, ([0, 1, 2], [0], slice(None)))
    check_index(x, y, (slice(None), [1, 2], [2, 1]))
    check_index(x, y, ([1, 2], [2, 1], slice(None)))
    check_index(x, y, (slice(None), [[1, 2], [2, 0]], [[0, 1], [2, 3]]))
    check_index(x, y, ([[1, 2], [3, 0]], [[0, 1], [2, 2]], slice(None)))
    check_index(x, y, (slice(None), slice(None), [2, 1]))
    check_index(x, y, (slice(None), [2, 1], slice(None)))
    check_index(x, y, ([2, 1], slice(None), slice(None)))

    # advanced indexing, with fewer dims or an ellipsis
    check_index(x, y, ([0], ))
    check_index(x, y, ([0], slice(None)))
    check_index(x, y, ([0], Ellipsis))
    check_index(x, y, ([1, 2], [0, 1]))
    check_index(x, y, ([1, 2], [0, 1], Ellipsis))
    check_index(x, y, (Ellipsis, [1, 2], [0, 1]))

    # advanced indexing, with a tensor wrapped in a variable
    z = torch.LongTensor([0, 1])
    zv = Variable(z, requires_grad=False)
    seq = [z, Ellipsis]
    seqv = [zv, Ellipsis]

    if y.grad is not None:
        y.grad.data.zero_()
    indexed_tensor = x[seq]
    indexed_var = y[seqv]
    compare(x, y, seq, indexed_tensor, indexed_var)
def test_indexing_duplicates(self):
    x = torch.arange(1., 17).view(4, 4)
    y = Variable(x, requires_grad=True)

    idx = torch.LongTensor([1, 1, 3, 2, 1, 2])
    y[idx].sum().backward()
    expected_grad = torch.zeros(4, 4)
    for i in idx:
        expected_grad[i] += 1
    self.assertEqual(y.grad.data, expected_grad)

    # with advanced indexing
    x = torch.arange(1., 17).view(4, 4)
    y = Variable(x, requires_grad=True)

    idx = [[1, 1, 3, 2, 1, 2], [0]]
    y[idx].sum().backward()
    expected_grad = torch.zeros(4, 4)
    for i in idx[0]:
        for j in idx[1]:
            expected_grad[i][j] += 1

    self.assertEqual(y.grad.data, expected_grad)

    x = torch.arange(1., 17).view(4, 4)
    y = Variable(x, requires_grad=True)
    idx = [[[1, 2], [0, 0]], [[0, 1], [1, 1]]]
    y[idx].sum().backward()
    expected_grad = torch.Tensor([[0, 2, 0, 0],
    self.assertEqual(y.grad.data, expected_grad)

    x = torch.arange(1., 65).view(4, 4, 4)
    y = Variable(x, requires_grad=True)

    idx = [[1, 1, 1], slice(None), slice(None)]
    y[idx].sum().backward()
    expected_grad = torch.Tensor(4, 4, 4).zero_()
    expected_grad[1].fill_(3)
    self.assertEqual(y.grad.data, expected_grad)
def test_volatile_deprecated(self):
    v = torch.autograd.torch.randn(3, 3)
    with warnings.catch_warnings(record=True) as w:
        self.assertFalse(v.volatile)
    self.assertIn('volatile', str(w[0].message))

def test_saved_variables_deprecated(self):
    class MyFunction(Function):
        @staticmethod
        def forward(ctx, tensor1, tensor2):
            ctx.save_for_backward(tensor1, tensor2)
            return tensor1 + tensor2

        @staticmethod
        def backward(ctx, grad_output):
            var1, var2 = ctx.saved_variables
            return (grad_output, grad_output)

    with warnings.catch_warnings(record=True) as warns:
        warnings.simplefilter("always")
        x = torch.randn((3, 3), requires_grad=True)
        y = torch.randn((3, 3), requires_grad=True)
        model = MyFunction()
        model.apply(x, y).sum().backward()

        has_deprecated = map(lambda warn: 'deprecated' in str(warn) and
                             'saved_variables' in str(warn),
                             warns)
        has_deprecated = reduce(lambda x, y: x or y, has_deprecated)
        self.assertTrue(has_deprecated)
def test_requires_grad(self):
    x = torch.randn(5, 5)
    y = torch.randn(5, 5)
    z = torch.randn(5, 5, requires_grad=True)
    a = x + y
    self.assertFalse(a.requires_grad)
    b = a + z
    self.assertTrue(b.requires_grad)

    def error():
        raise RuntimeError
    # make sure backward isn't called on these
    a._backward_hooks = OrderedDict()
    x._backward_hooks = OrderedDict()
    y._backward_hooks = OrderedDict()
    a._backward_hooks['test'] = error
    x._backward_hooks['test'] = error
    y._backward_hooks['test'] = error
    b.backward(torch.ones(5, 5))
def test_requires_grad_(self):
    x = torch.randn(5, 5)
    y = torch.randn(5, 5, requires_grad=True)
    self.assertIs(x, x.requires_grad_())
    self.assertTrue(x.requires_grad)
    self.assertIs(y, y.requires_grad_())
    self.assertTrue(y.requires_grad)
    self.assertIs(x, x.requires_grad_(True))
    self.assertTrue(x.requires_grad)
    self.assertIs(y, y.requires_grad_(True))
    self.assertTrue(y.requires_grad)

    z = x * y
    self.assertRaises(RuntimeError, lambda: z.requires_grad_(False))
    self.assertIs(z, z.requires_grad_())
    self.assertTrue(z.requires_grad)
    self.assertIs(z, z.requires_grad_(True))
    self.assertTrue(z.requires_grad)

    self.assertIs(x, x.requires_grad_(False))
    self.assertFalse(x.requires_grad)
    self.assertIs(y, y.requires_grad_(False))
    self.assertFalse(y.requires_grad)
def test_requires_grad_inplace(self):
    a = torch.randn(5, 5)
    b = torch.randn(5, 5, requires_grad=True)
    a += b
    self.assertTrue(a.requires_grad)

    # same, but starting from a non-leaf tensor
    a = torch.randn(5, 5) + 0
    b = torch.randn(5, 5, requires_grad=True)
    a += b
    self.assertTrue(a.requires_grad)

def test_no_requires_grad_inplace(self):
    a = torch.randn(2, 3)
    a.requires_grad = True
    a.sum().backward()
    self.assertEqual(a.grad.data, torch.ones(2, 3))

    a = torch.randn(2, 3)
    a.requires_grad = True
    a.sum().backward()
    self.assertEqual(a.grad.data, torch.ones(2, 3))

    a = torch.randn(2, 3)
    a.requires_grad = True
    with self.assertRaises(RuntimeError):
    with self.assertRaises(RuntimeError):
def test_requires_grad_factory(self):
    x = torch.randn(2, 3)
    fns = [torch.ones_like, torch.testing.randn_like]
    dtypes = [torch.float32, torch.float64]
    for fn in fns:
        for requires_grad in [True, False]:
            for dtype in dtypes:
                for use_cuda in [True, False]:
                    if not use_cuda:
                        output = fn(x, dtype=dtype, requires_grad=requires_grad)
                        self.assertEqual(requires_grad, output.requires_grad)
                        self.assertIs(dtype, output.dtype)
                    else:
                        output = fn(x, dtype=dtype, device=1, requires_grad=requires_grad)
                        self.assertEqual(requires_grad, output.requires_grad)
                        self.assertIs(dtype, output.dtype)
                        self.assertEqual(1, output.get_device())
def test_attribute_deletion(self):
    x = torch.randn((5, 5), requires_grad=True)
    del x.grad
    self.assertIsNone(x.grad)
    with self.assertRaises(RuntimeError):
        del x.data
    with self.assertRaises(TypeError):
        x.data = None
    with self.assertRaises(RuntimeError):
        del x.requires_grad
    with self.assertRaises(RuntimeError):
        del x._grad_fn
    with self.assertRaises(RuntimeError):
        del x._backward_hooks
def test_grad_assignment(self):
    x = torch.randn(5, 5)

    with self.assertRaises(RuntimeError):
        x.grad = torch.randn(2, 2)
    with self.assertRaises(RuntimeError):
        x.grad = Variable(torch.randn(5, 5).long())
    with self.assertRaises(RuntimeError):

    if not torch.cuda.is_available():
        raise unittest.SkipTest("CUDA not available")
    with self.assertRaises(RuntimeError):
        x.grad = Variable(torch.randn(5, 5).cuda())

    x.grad = torch.zeros_like(x)

    if torch.cuda.device_count() < 2:
        raise unittest.SkipTest("At least 2 CUDA devices needed")
    x = Variable(torch.randn(5, 5).cuda(0))
    with self.assertRaises(RuntimeError):
        x.grad = Variable(torch.randn(5, 5).cuda(1))
def test_duplicate_backward_root(self):
    a = torch.randn(5, 5, requires_grad=True)
    b = torch.randn(5, 5, requires_grad=True)

    x = a * b
    grad_output = torch.randn_like(x)
    torch.autograd.backward([x, x], [grad_output, grad_output])

    self.assertEqual(a.grad.data, b.data * grad_output * 2)
    self.assertEqual(b.grad.data, a.data * grad_output * 2)

def test_backward_no_grad(self):
    a = torch.randn(5, 5, requires_grad=True)
    with self.assertRaises(RuntimeError):

def test_next_functions(self):
    x = torch.randn(5, 5, requires_grad=True)
    y = torch.randn(5, 5, requires_grad=True)

    a = x + y
    self.assertIsNotNone(a.grad_fn)
    next_functions = a.grad_fn.next_functions
    self.assertEqual(len(next_functions), 2)
    self.assertIsInstance(next_functions[0][0], torch._C._functions.AccumulateGrad)
    self.assertEqual(next_functions[0][1], 0)
    self.assertIsInstance(next_functions[1][0], torch._C._functions.AccumulateGrad)
    self.assertEqual(next_functions[1][1], 0)

    b = a + 5
    next_functions = b.grad_fn.next_functions
    self.assertEqual(len(next_functions), 2)
    self.assertIs(next_functions[0][0], a.grad_fn)
    self.assertIs(next_functions[1][0], None)
def test_inplace(self):
    x = torch.ones(5, 5, requires_grad=True)
    y = Variable(torch.ones(5, 5) * 4, requires_grad=True)

    q.backward(torch.ones(5, 5), retain_graph=True)
    self.assertRaises(RuntimeError, lambda: w.backward(torch.ones(5, 5)))

    w.backward(torch.ones(5, 5), retain_graph=True)
    r.backward(torch.ones(5, 5), retain_graph=True)
    self.assertRaises(RuntimeError, lambda: q.backward(torch.ones(5, 5)))

    prev_version = z._version
    self.assertNotEqual(z._version, prev_version)
    r.backward(torch.ones(5, 5), retain_graph=True)
    self.assertEqual(x.grad.data, torch.ones(5, 5) / 2)
    w.backward(torch.ones(5, 5), retain_graph=True)
    self.assertEqual(x.grad.data, torch.Tensor(5, 5).fill_((1 + math.e) / 2))
    self.assertRaises(RuntimeError, lambda: q.backward(torch.ones(5, 5)))

    leaf = torch.ones(5, 5, requires_grad=True)
    self.assertEqual(x.data, torch.ones(5, 5) * 11)
    y.backward(torch.ones(5, 5))
    self.assertEqual(leaf.grad.data, torch.ones(5, 5))
    self.assertRaises(RuntimeError, lambda: z.backward(torch.ones(5, 5)))
def test_mark_non_differentiable(self):
    class MyFunction(Function):
        @staticmethod
        def forward(ctx, input):
            output = input > 0
            ctx.mark_non_differentiable(output)
            return output

        @staticmethod
        def backward(ctx, grad_output):
            return (grad_output * 0).type(torch.DoubleTensor)

    x = torch.randn(5, 5, requires_grad=True)
    mask = MyFunction.apply(x)
    self.assertFalse(mask.requires_grad)
    y = x.masked_fill(mask, 0)
    y.sum().backward()

def test_mark_non_differentiable_mixed(self):
    class MyFunction(Function):
        @staticmethod
        def forward(ctx, input):
            a = input + 1
            b = input + 2
            ctx.mark_non_differentiable(a)
            return a, b

        @staticmethod
        def backward(ctx, grad_a, grad_b):
            self.assertTrue((grad_a == 0).all())
            self.assertTrue((grad_b == 1).all())
            return grad_b

    x = torch.randn(5, 5, requires_grad=True)
    a, b = MyFunction.apply(x)
    self.assertFalse(a.requires_grad)
    self.assertTrue(b.requires_grad)
    b.sum().backward()
    self.assertEqual(x.grad.data, torch.ones(5, 5))

def test_mark_non_differentiable_none(self):
    class MyFunction(Function):
        @staticmethod
        def forward(ctx, input):
            output = input.clone()
            ctx.mark_non_differentiable(output)
            return output

        @staticmethod
        def backward(ctx, grad_output):
            return None

    x = torch.randn(5, 5, requires_grad=True)
    r = MyFunction.apply(x * x)
    (r * x).sum().backward()
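# Illustrative sketch, not part of the original test file: the
# ctx.mark_non_differentiable pattern checked above. The SortWithIndices
# example and helper name are hypothetical; marked outputs do not require
# grad and receive a zero gradient in backward.
def _sketch_mark_non_differentiable():
    import torch
    from torch.autograd import Function

    class SortWithIndices(Function):
        @staticmethod
        def forward(ctx, x):
            values, indices = x.sort()
            ctx.mark_non_differentiable(indices)
            ctx.save_for_backward(indices)
            return values, indices

        @staticmethod
        def backward(ctx, grad_values, grad_indices):
            indices, = ctx.saved_tensors
            # scatter the gradient back to the original positions
            return grad_values.new_zeros(grad_values.size()).scatter_(0, indices, grad_values)

    x = torch.randn(5, requires_grad=True)
    values, indices = SortWithIndices.apply(x)
    values.sum().backward()
    return x.grad, indices.requires_grad  # (all ones, False)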
def test_return_duplicate(self):
    class DoubleDuplicate(Function):
        @staticmethod
        def forward(ctx, x):
            return output, output

        @staticmethod
        def backward(ctx, grad1, grad2):
            return grad1 * 2 + grad2 * 2

    def fn(x):
        a, b = DoubleDuplicate.apply(x)

    x = torch.randn(5, 5, requires_grad=True)
    gradcheck(fn, [x])
    gradgradcheck(fn, [x])

def test_return_duplicate_inplace(self):
    class DoubleInplace(Function):
        @staticmethod
        def forward(ctx, x):

        @staticmethod
        def backward(ctx, grad1, grad2):
            return grad1 * 2 + grad2 * 2

    def inplace_fn(x):
        a, b = DoubleInplace.apply(x.clone())

    x = torch.randn(5, 5, requires_grad=True)
    gradcheck(inplace_fn, [x])
    gradgradcheck(inplace_fn, [x])

    # can't modify leaf variables in-place
    self.assertRaises(RuntimeError, lambda: InplaceFunction.apply(x))
    # functions which modify views in-place must return only one output
    self.assertRaises(RuntimeError, lambda: InplaceFunction.apply(x.clone()[0]))

def test_resize(self):
    x = torch.ones(2, 3)
    self.assertTrue(x.resize(3, 2).size() == (3, 2))
def _test_setitem(self, size, index):
    x = torch.ones(*size, requires_grad=True)
    y = x + 2
    y_version = y._version
    y[index] = 2
    self.assertNotEqual(y._version, y_version)
    y.backward(torch.ones(*size))
    expected_grad = torch.ones(*size)
    expected_grad[index] = 0
    self.assertEqual(x.grad, expected_grad)

def _test_setitem_tensor(self, size, index):
    x = torch.ones(*size, requires_grad=True)
    y = x + 2
    y_version = y._version
    value = x.new(x[index].size()).fill_(7)
    value.requires_grad = True
    y[index] = value
    self.assertNotEqual(y._version, y_version)
    y.backward(torch.ones(*size))
    expected_grad_input = torch.ones(*size)
    expected_grad_input[index] = 0
    self.assertEqual(x.grad, expected_grad_input)
    self.assertEqual(value.grad, torch.ones_like(value))

    x = torch.randn(4, requires_grad=True)
    y = torch.zeros(2, 3, 4)
    y.backward(torch.randn(2, 3, 4))
    self.assertEqual(x.size(), x.grad.size())

def test_setitem(self):
    self._test_setitem((5, 5, 5), [slice(None), slice(None), [1, 3]])
    self._test_setitem((5, 5, 5), [slice(None), [1, 3], slice(None)])
    self._test_setitem((5, 5, 5), [[1, 3], slice(None), slice(None)])
    self._test_setitem((5, 5, 5), [slice(None), [2, 4], [1, 3]])
    self._test_setitem((5, 5, 5), [[1, 3], [2, 4], slice(None)])
    self._test_setitem_tensor((5, 5, 5), [Variable(torch.LongTensor([1,
                                          3]), requires_grad=False), [2, 4], slice(None)])

def test_setitem_mask(self):
    mask = torch.ByteTensor(5, 5).bernoulli_()
def test_select_sum(self):
    x = torch.randn(10, requires_grad=True)

    def func(x):
        return x.select(0, 1).sum()

    gradcheck(func, [x])
    gradgradcheck(func, [x])
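# Illustrative sketch, not part of the original test file: the gradcheck /
# gradgradcheck pattern used throughout this suite, comparing analytical
# gradients against finite differences on double-precision inputs. The helper
# name is hypothetical.
def _sketch_gradcheck():
    import torch
    from torch.autograd import gradcheck, gradgradcheck

    x = torch.randn(4, 4, dtype=torch.double, requires_grad=True)

    def func(t):
        return (t * t).sum()

    assert gradcheck(func, (x,), eps=1e-6, atol=1e-4)
    assert gradgradcheck(func, (x,))
    return True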
def test_stack(self):
    x = torch.randn(10, 10, requires_grad=True)
    y = torch.randn(10, 10, requires_grad=True)
    z = torch.randn(10, 10, requires_grad=True)
    stacked = torch.stack([x, y, z], 0)
    grad = torch.randn(3, 10, 10)
    stacked.backward(grad)
    self.assertEqual(x.grad.data, grad[0])
    self.assertEqual(y.grad.data, grad[1])
    self.assertEqual(z.grad.data, grad[2])

def test_unbind(self):
    stacked = torch.randn(3, 10, 10, requires_grad=True)
    x, y, z = stacked.unbind()
    grad = torch.randn(3, 10, 10)
    torch.autograd.backward([x, y, z], grad.unbind())
    self.assertEqual(stacked.grad.data, grad)

    # check that it works with only one gradient provided
    for i in range(3):
        stacked = torch.randn(3, 10, 10, requires_grad=True)
        outs = stacked.unbind()
        gi = grad.unbind()[i]
        g, = torch.autograd.grad(outs[i], stacked, gi)
        g_expected = torch.stack([gi if j == i else torch.zeros_like(gi)
                                  for j in range(3)], dim=0)
        self.assertEqual(g, g_expected)
def test_put(self):
    root = torch.randn(4, 5, requires_grad=True)
    values = torch.randn(6, requires_grad=True)
    idx = Variable(torch.LongTensor([1, 2, 3, -1, -2, -3]))

    def func(root, values):
        x = root.clone()
        x.put_(idx, values)
        return x

    gradcheck(func, [root, values])
    gradgradcheck(func, [root, values])

def test_put_accumulate(self):
    root = torch.randn(4, 5, requires_grad=True)
    values = torch.randn(6, requires_grad=True)
    idx = Variable(torch.LongTensor([1, 2, 3, 1, 2, 3]))

    def func(root, values):
        x = root.clone()
        x.put_(idx, values, accumulate=True)
        return x

    gradcheck(func, [root, values])
    gradgradcheck(func, [root, values])

def test_fill(self):
    root = torch.randn(4, 5, requires_grad=True)

    gradcheck(func, [root])
    gradgradcheck(func, [root])
def test_unused_output(self):
    x = torch.randn(10, 10, requires_grad=True)
    outputs = x.chunk(5)
    expected_grad = torch.zeros(10, 10)
    expected_grad[4:6] = 4
    self.assertEqual(x.grad.data, expected_grad)

    grad_output = torch.randn(2, 10)
    outputs = x.chunk(5)
    outputs[0].backward(grad_output)
    expected_grad = torch.zeros(10, 10)
    expected_grad[:2] = grad_output
    self.assertEqual(x.grad.data, expected_grad)
def test_ctc_loss(self):
    gradcheck_input_size = 10

    tests = [('cpu', 150, False),
             ('cpu', 150, True)]
    if TEST_CUDA:
        tests += [('cuda', 50, False),
                  ('cuda', 150, False),
                  ('cuda', 150, True)]

    for device, input_length, vary_lengths in tests:
        targets = torch.randint(1, num_labels, (batch_size, target_length),
                                device=device, dtype=torch.long)
        x = torch.randn(gradcheck_input_size, device=device, requires_grad=True)
        tile_factors = torch.randn(input_length * batch_size * num_labels // gradcheck_input_size + 1,
                                   device=device)
        input_lengths = [(torch.randint(input_length // 2, input_length + 1, ()).item()
                          if vary_lengths or i == 0 else input_length) for i in range(batch_size)]
        target_lengths = [(torch.randint(target_length // 2, target_length + 1, ()).item()
                           if vary_lengths or i == 0 else target_length) for i in range(batch_size)]

        def ctc_after_softmax(x):
            x_full = ((x[:, None] * tile_factors[None, :]).view(-1)[:input_length * batch_size * num_labels]
                      .view(input_length, batch_size, num_labels))
            log_probs = torch.log_softmax(x_full, 2)
            return torch.nn.functional.ctc_loss(log_probs, targets, input_lengths, target_lengths)

        gradcheck(ctc_after_softmax, [x])
def _test_sparse_gather(self, size_x, size_ind, dim):
    x = torch.randn(size_x, requires_grad=True)
    if len(size_ind) > 0 and len(size_x) > 0:
        ind = torch.randint(x.size(dim), size_ind)
    else:
        ind = torch.zeros(size_ind, dtype=torch.int64)
    out = torch.gather(x, dim, ind, sparse_grad=False)
    grad = torch.rand_like(out)
    out.backward(grad)
    grad_dense = x.grad.clone()
    x.grad = None
    out = torch.gather(x, dim, ind, sparse_grad=True)
    out.backward(grad)
    self.assertEqual(grad_dense, x.grad.to_dense())
def test_sparse_gather_dim0(self):

def test_sparse_gather_dim1(self):

def test_sparse_gather_dim_neg(self):

def test_sparse_gather_ind_scalar(self):

def test_sparse_gather_x_scalar(self):

def test_sparse_gather_both_scalar(self):

def test_gc_in_destructor(self):
    """
    Previously, if a Function destructor triggered a garbage collection,
    the Variable's tp_dealloc handler would get called twice leading to a
    segfault.
    """
    class CollectOnDelete(Function):

    Variable(torch.randn(10, 10), _grad_fn=CollectOnDelete())
def test_unused_output_gpu(self):
    x = Variable(torch.randn(5, 5).float().cuda(), requires_grad=True)
    self.assertEqual(x.grad.data, torch.ones(5, 5) * 2)

def test_backward_device(self):
    class Identity(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x):
            return x.clone()

        @staticmethod
        def backward(ctx, grad_output):
            device[0] = torch.cuda.current_device()
            return grad_output.clone()

    v = Variable(torch.randn(1).cuda(1), requires_grad=True)
    Identity.apply(v).backward()
    self.assertEqual(device[0], 1)

def test_inputbuffer_add_multigpu(self):
    input = torch.randn(1).cuda(0).requires_grad_()
    output = input.cuda(1) + input.cuda(1)
def test_detach(self):
    x = torch.randn(10, 10, requires_grad=True)
    self.assertFalse(y.requires_grad)
    self.assertFalse(z.requires_grad)

    x = torch.randn(10, 10, requires_grad=True)
    self.assertFalse(y.requires_grad)
    self.assertIsNone(y.grad_fn)
    self.assertEqual(x.grad.data, torch.ones(10, 10))

    # in-place detach
    x = torch.randn(10, 10, requires_grad=True)
    y = torch.randn(10, 10, requires_grad=True)
    a = x * 2
    (y + a).sum().backward(retain_graph=True)
    a.detach_()
    self.assertFalse(a.requires_grad)
    (y + a).sum().backward()  # this won't backprop to x
    self.assertEqual(x.grad.data, torch.ones(10, 10) * 2)
    self.assertEqual(y.grad.data, torch.ones(10, 10) * 2)

    # in-place detach on a view raises an exception
    view = x.narrow(0, 1, 4)
    self.assertRaisesRegex(RuntimeError, 'view', lambda: view.detach_())

def test_detach_base(self):
    "detaching base does not detach view"
    x = torch.randn(10, 10, requires_grad=True)
    view = x.narrow(0, 1, 4)
    x.detach_()
    self.assertFalse(x.requires_grad)
    self.assertTrue(view.requires_grad)
    self.assertIsNotNone(view.grad_fn)
    self.assertIs(view._base, x)
def _test_type_conversion_backward(self, t):
    fvar = Variable(t(torch.randn(5, 5).float()), requires_grad=True)
    fvar.double().sum().backward()
    self.assertEqual(fvar.grad, torch.ones_like(fvar))
    self.assertEqual(type(fvar.grad.data), type(fvar.data))
    dvar = Variable(t(torch.randn(5, 5).double()), requires_grad=True)
    dvar.float().sum().backward()
    self.assertEqual(dvar.grad, torch.ones_like(dvar))
    self.assertEqual(type(dvar.grad.data), type(dvar.data))

def test_type_conversions(self):
    x = torch.randn(5, 5)
    self.assertIsInstance(x.float(), torch.FloatTensor)
    self.assertIsInstance(x.int(), torch.IntTensor)
    if torch.cuda.is_available():
        self.assertIsInstance(x.float().cuda(), torch.cuda.FloatTensor)
        self.assertIsInstance(x.int().cuda(), torch.cuda.IntTensor)
        self.assertIsInstance(x.int().cuda().cpu(), torch.IntTensor)
        if torch.cuda.device_count() >= 2:
            x2 = x.float().cuda(1)
            self.assertIsInstance(x2, torch.cuda.FloatTensor)
            self.assertIs(x2.get_device(), 1)
            x2 = x.float().cuda()
            self.assertIsInstance(x2.data, torch.cuda.FloatTensor)
            self.assertIs(x2.get_device(), 0)
            x2 = x2.cuda(1)
            self.assertIsInstance(x2, torch.cuda.FloatTensor)
            self.assertIs(x2.get_device(), 1)
            y = Variable(torch.randn(5).cuda(1), requires_grad=True)
            y.cpu().sum().backward()
            self.assertIs(y.grad.get_device(), 1)
            self.assertIs(y.long().data.get_device(), 1)

    for t in [torch.DoubleTensor, torch.FloatTensor, torch.IntTensor, torch.ByteTensor]:
        for y_var in (True, False):
            y = torch.randint(5, (5, 5), dtype=t.dtype)
            y = Variable(y) if y_var else y
            self.assertIsInstance(x.type(t), t)
            self.assertIsInstance(x.type_as(y), t)
            self.assertIsInstance(x.type(t_dtype), t)
            self.assertIs(t_dtype, x.type(t_dtype).dtype)
            self.assertEqual(y.data_ptr(), y.type(t).data_ptr())
            if torch.cuda.is_available():
                for x_cuda in (True, False):
                    for y_cuda in (True, False):
                        x_c = x.cuda() if x_cuda else x
                        y_c = y.cuda() if y_cuda else y
                        _, y_type = y_c.type().rsplit('.', 1)
                        y_typestr = ('torch.cuda.' if y_cuda else 'torch.') + y_type
                        self.assertEqual(y_c.type(), x_c.type(y_typestr).type())
                        self.assertIs(y_c.dtype, x_c.type(y_c.dtype).dtype)
                        self.assertEqual(y_c.data_ptr(), y_c.cuda().data_ptr() if y_cuda else y_c.data_ptr())
def _test_pyscalar_conversions(self, t, integral_conv):
    # integral -> integral
    l = t(torch.zeros(1, 1, 1, dtype=torch.long))
    self.assertEqual(integral_conv(l), pyscalar)

    # floating point -> floating point
    f = Variable(t(torch.randn(1, 1)))
    self.assertEqual(float(f), pyscalar)
    f[0] = nan
    self.assertTrue(math.isnan(float(f)))
    f[0] = inf
    self.assertEqual(float(f), inf, allow_inf=True)
    f[0] = -inf
    self.assertEqual(float(f), -inf, allow_inf=True)

    # integral -> floating point
    # check we can convert something that loses precision
    pyscalar = 1234567890123456789
    self.assertNotEqual(pyscalar, integral_conv(float(pyscalar)))
    l[0] = pyscalar
    self.assertEqual(float(l), float(pyscalar))

    # floating point -> integral
    f[0] = nan
    self.assertRaises(ValueError, lambda: integral_conv(f[0]))
    f[0] = inf
    self.assertRaises(OverflowError, lambda: integral_conv(f[0]))
    f[0] = -inf
    self.assertRaises(OverflowError, lambda: integral_conv(f[0]))
    f[0] = sys.float_info.max
    self.assertEqual(integral_conv(f), sys.float_info.max)

    # bool, nonzero
    def test_nonzero(tensor, value, expected):
        tensor[0] = value
        self.assertEqual(expected, bool(tensor))
        self.assertEqual(expected, True if tensor else False)

    test_nonzero(l, 0, False)
    test_nonzero(l, -2, True)
    test_nonzero(f, 0.0, False)
    test_nonzero(f, sys.float_info.min, True)
    test_nonzero(f, nan, bool(nan))
    test_nonzero(f, inf, bool(inf))
    test_nonzero(f, -inf, bool(-inf))
def test_pyscalar_conversions(self):
    self._test_pyscalar_conversions(lambda x: x, lambda x: int(x))
    if sys.version_info[0] == 2:
        self._test_pyscalar_conversions(lambda x: x, lambda x: long(x))
    if sys.version_info[0] == 2:
        self._test_pyscalar_conversions(lambda x: x.cuda(), lambda x: long(x))
def test_pin_memory(self):
    x = torch.randn(2, 2, requires_grad=True)
    self.assertEqual(x, x.pin_memory())
    self.assertIsNot(x, x.pin_memory())
    self.assertTrue(x.pin_memory().requires_grad)
    gradcheck(lambda x: x.pin_memory(), [x])
    gradgradcheck(lambda x: x.pin_memory(), [x])

def test_isolated_node(self):
    x = torch.randn(5, 5, requires_grad=True)
    y = torch.randn(5, 5, requires_grad=True)

    a = x + y
    b = torch.max(a, 1, True)[1].repeat(1, 5).double()

def test_shape(self):
    x = torch.randn(3, 4)
    self.assertEqual(2, len(x.shape))
    self.assertEqual(x.shape[0], 3)
    self.assertEqual(x.shape[1], 4)

def test_numpy_requires_grad(self):
    x = torch.randn(2, 2, requires_grad=True)
    self.assertRaisesRegex(RuntimeError, 'requires grad', lambda: x.numpy())
def test_return_leaf(self):
    class Identity(Function):
        def forward(self, a, b):
            return a, a + b

        def backward(self, grad_a, grad_b):
            return grad_a + grad_b, grad_b

    hook_called = [False]
    x = torch.randn(5, 5, requires_grad=True)
    y = torch.randn(5, 5, requires_grad=True)

    q, p = Identity()(x, y)

    # make sure hooks only receive grad from usage of q, not x
    def hook(grad):
        hook_called[0] = True
        self.assertEqual(grad.data, torch.ones(5, 5))

    q.register_hook(hook)
    (q + p + x).sum().backward()
    self.assertEqual(x.grad.data, torch.ones(5, 5) * 3)
    self.assertEqual(y.grad.data, torch.ones(5, 5))
    self.assertTrue(hook_called[0])

def test_return_leaf_inplace(self):
    class Inplace(InplaceFunction):
        def forward(self, a, b):
            return a.add_(b), b + 2

        def backward(self, grad_a, grad_b):
            return grad_a, grad_a + grad_b

    x = torch.randn(5, 5)
    y = torch.randn(5, 5, requires_grad=True)

    self.assertIs(q.grad_fn, fn)
    self.assertTrue(q.requires_grad)
    self.assertEqual(y.grad.data, torch.ones(5, 5))

def test_leaf_assignment(self):
    x = torch.randn(5, 5)
    y = torch.randn(5, requires_grad=True)
    z = torch.randn(5, requires_grad=True)

    x[0] = y
    x[1] = 2 * z
    self.assertTrue(x.requires_grad)
    self.assertIsNot(x.grad_fn, None)
    x.sum().backward()
    self.assertEqual(y.grad.data, torch.ones(5))
    self.assertEqual(z.grad.data, torch.ones(5) * 2)

def test_no_grad_assignment(self):
    x = torch.randn(5, 5, requires_grad=True)
    y = torch.randn(5)
    with torch.no_grad():
        x[0] = y

    self.assertTrue(x.requires_grad)
    self.assertIsNone(x.grad_fn)
def test_no_grad_modifies_version(self):
    x = torch.randn(5, requires_grad=True)
    y = torch.randn(5, requires_grad=True)
    z = (x * y).sum()
    with torch.no_grad():
        x *= 2
    self.assertRaisesRegex(RuntimeError, 'modified by an inplace operation',
                           lambda: z.backward())

def test_no_grad_input(self):
    class MyFunction(Function):
        @staticmethod
        def forward(self, x):
            return x

        @staticmethod
        def backward(self, grad_output):
            return grad_output

    x = torch.randn(5, requires_grad=True)
    with torch.no_grad():
        y = MyFunction.apply(x)

    self.assertTrue(x.requires_grad)
    self.assertIsNone(y.grad_fn)
def test_backward_copy(self):
    x = torch.ones(5, 5, requires_grad=True)
    y = torch.ones(5, 5, requires_grad=True)
    grad_output = torch.ones(5, 5)
    out.backward(grad_output)
    self.assertEqual(x.grad, torch.ones(5, 5) * 34)
    self.assertEqual(y.grad, torch.ones(5, 5) * 17)

def test_save_none_for_backward(self):
    test_case = self

    class MyFn(Function):
        def forward(self, input):
            self.save_for_backward(None, input, None)
            return input * input

        def backward(self, grad_output):
            n1, input, n2 = self.saved_tensors
            test_case.assertIsNone(n1)
            test_case.assertIsNone(n2)
            return 2 * input * grad_output

    x = torch.randn(5, 5, requires_grad=True)
    MyFn()(x).sum().backward()
    self.assertEqual(x.grad, 2 * x)

def test_too_many_grads(self):
    class MyFn(Function):
        def forward(self, input):
            return input

        def backward(self, grad_output):
            return grad_output, None, None

    x = torch.randn(5, 5, requires_grad=True)
    MyFn()(x).sum().backward()
    self.assertEqual(x.grad, torch.ones_like(x))
def test_pickle(self):
    x = torch.randn(10, 10, requires_grad=True)
    y = torch.randn(10, 10, requires_grad=False)

    def assert_strict_equal(var1, var2):
        self.assertEqual(var1.data, var2.data)
        self.assertEqual(var1.requires_grad, var2.requires_grad)

    serialized = [pickle.dumps([x, y], protocol=p) for p in range(3)]
    for dump in serialized:
        xc, yc = pickle.loads(dump)
        assert_strict_equal(xc, x)
        assert_strict_equal(yc, y)
def test_dep_nograd(self):
    class F1(Function):
        def forward(self, input):
            out = torch.randn(input.size())
            self.mark_non_differentiable(out)
            return input, out

        def backward(self, grad_output, ignored):
            return grad_output

    class F2(Function):
        def forward(self, input, ignored):
            return input

        def backward(self, grad_output):
            return grad_output, None

    x = torch.randn(5, requires_grad=True)
    a, b = F1()(x)
    self.assertTrue(a.requires_grad)
    self.assertFalse(b.requires_grad)
    c = F2()(a, b)
    c.backward(torch.ones(c.size()))
    self.assertEqual(x.grad.data, torch.ones(x.size()))
def test_set_grad_enabled(self):
    x = torch.tensor([1.], requires_grad=True)
    with torch.set_grad_enabled(False):
        y = x * 2
    self.assertFalse(y.requires_grad)
    with torch.set_grad_enabled(True):
        y = x * 2
    self.assertTrue(y.requires_grad)
    with torch.set_grad_enabled(False):
        torch.set_grad_enabled(True)
        y = x * 2
    self.assertTrue(y.requires_grad)
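# Illustrative sketch, not part of the original test file: grad-mode toggling
# as exercised above; torch.set_grad_enabled works both as a context manager
# and as a plain call. The helper name is hypothetical.
def _sketch_grad_mode():
    import torch

    x = torch.randn(3, requires_grad=True)
    with torch.set_grad_enabled(False):
        y = x * 2                  # not recorded: y.requires_grad is False
    torch.set_grad_enabled(True)   # plain-call form sets grad mode globally
    z = x * 2
    return y.requires_grad, z.requires_grad   # (False, True)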
def test_reentrant(self):
    y_data = torch.randn(2, 2)

    class Reenter(Function):
        @staticmethod
        def forward(ctx, x):
            with torch.enable_grad():
                ctx.x = Variable(x.data, requires_grad=True)
                ctx.y = Variable(y_data, requires_grad=True)
                ctx.output_var = ctx.x * ctx.y
            return ctx.output_var.detach()

        @staticmethod
        def backward(ctx, grad_output):
            with torch.enable_grad():
                ctx.output_var.sum().backward()
            return ctx.x.grad * grad_output

    x = torch.randn(2, 2, requires_grad=True)
    out = Reenter.apply(x)
    out.sum().backward()
    self.assertEqual(x.grad.data, y_data)
def test_broadcast_tensors(self):
    f_args_variable = (torch.randn(3, requires_grad=True),
                       torch.randn(1, 2, 1, requires_grad=True),
                       torch.randn(1, 1, requires_grad=True),
                       torch.randn(5, 1, 1, requires_grad=True))
    f_args_tensor = deepcopy(unpack_variables(f_args_variable))
    run_functional_checks(self, "test_broadcast_tensors", "broadcast",
                          lambda a, b, c, d: torch.broadcast_tensors(a, b, c, d),
                          True, f_args_variable, f_args_tensor)

def test_cat(self):
    f_args_variable = (torch.randn(1, S, S, requires_grad=True),
                       torch.randn(2, S, S, requires_grad=True),
                       torch.randn(3, S, S, requires_grad=True),
                       0)
    f_args_tensor = deepcopy(unpack_variables(f_args_variable))
    run_functional_checks(self, "test_cat", "cat",
                          lambda a, b, c, dim: torch.cat((a, b, c), dim),
                          True, f_args_variable, f_args_tensor)

def test_cat_negdim_1(self):
    f_args_variable = (torch.randn(S, S, 1, requires_grad=True),
                       torch.randn(S, S, 2, requires_grad=True),
                       torch.randn(S, S, 3, requires_grad=True),
                       -1)
    f_args_tensor = deepcopy(unpack_variables(f_args_variable))
    run_functional_checks(self, "test_cat_negdim_1", "cat",
                          lambda a, b, c, dim: torch.cat((a, b, c), dim),
                          True, f_args_variable, f_args_tensor)

def test_cat_negdim_2(self):
    f_args_variable = (torch.randn(S, 1, S, requires_grad=True),
                       torch.randn(S, 2, S, requires_grad=True),
                       torch.randn(S, 3, S, requires_grad=True),
                       -2)
    f_args_tensor = deepcopy(unpack_variables(f_args_variable))
    run_functional_checks(self, "test_cat_negdim_2", "cat",
                          lambda a, b, c, dim: torch.cat((a, b, c), dim),
                          True, f_args_variable, f_args_tensor)

def test_cat_empty_legacy(self):
    f_args_variable = (torch.randn(0, requires_grad=True),
                       torch.randn(S, S, requires_grad=True))
    f_args_tensor = deepcopy(unpack_variables(f_args_variable))
    run_functional_checks(self, "test_cat_empty_legacy", "cat",
                          lambda a, b: torch.cat((a, b)),
                          False, f_args_variable, f_args_tensor)
    self.assertTrue(gradcheck(lambda a, b: torch.cat((a, b)), f_args_variable, eps=1e-6, atol=PRECISION))

def test_cat_empty(self):
    f_args_variable = (torch.randn(0, S, requires_grad=True),
                       torch.randn(S, S, requires_grad=True))
    f_args_tensor = deepcopy(unpack_variables(f_args_variable))
    run_functional_checks(self, "test_cat_empty", "cat",
                          lambda a, b: torch.cat((a, b)),
                          True, f_args_variable, f_args_tensor)
def test_cdist(self):
    for p in [0, 1, 2, 3, 1.5, 2.5, float('inf')]:
        f_args_variable = (torch.randn(S, S, requires_grad=True),
                           torch.randn(S, S, requires_grad=True))
        f = lambda a, b: torch.cdist(a, b, p)
        f_args_tensor = deepcopy(unpack_variables(f_args_variable))
        run_functional_checks(self, "test_cdist", "cdist", f,
                              True, f_args_variable, f_args_tensor)

def test_cholesky(self):
    def func(root, upper):
        x = torch.matmul(root, root.transpose(-1, -2)) + 1e-05
        return torch.cholesky(x, upper)

    def run_test(upper, dims):
        root = torch.rand(*dims)
        indices = torch.ones(dims[-1], dims[-1], dtype=torch.uint8).tril()
        indices = indices.expand_as(root)
        root.requires_grad_()

        gradcheck(func, [root])
        gradgradcheck(func, [root])

    for upper, dims in product([True, False], [(3, 3), (4, 3, 2, 2)]):
        run_test(upper, dims)
        run_test(upper, dims)

def test_trtrs(self):
    def _test_with_size(A_dims, B_dims):
        A = torch.rand(*A_dims).requires_grad_()
        b = torch.rand(*B_dims).requires_grad_()

        for upper, transpose, unitriangular in product((True, False), repeat=3):
            def func(A, b):
                return torch.trtrs(b, A, upper, transpose, unitriangular)

            gradcheck(func, [A, b])
            gradgradcheck(func, [A, b])

    _test_with_size((3, 3), (3, 4))
    _test_with_size((3, 3), (3, 2))
    _test_with_size((2, 3, 3), (2, 3, 4))
    _test_with_size((2, 3, 3), (2, 3, 2))
@unittest.skipIf(not TEST_MKL, "PyTorch is built without MKL support")
def test_fft_ifft_rfft_irfft(self):
    def _test_complex(sizes, signal_ndim):
        x = torch.randn(sizes, requires_grad=True, dtype=torch.double)

        for normalized in (True, False):
            def fft(x):
                return x.fft(signal_ndim, normalized=normalized)

            gradcheck(fft, [x])
            gradgradcheck(fft, [x], gen_non_contig_grad_outputs=True)

            def ifft(fx):
                return fx.ifft(signal_ndim, normalized=normalized)

            fx = fft(x).detach()
            fx.requires_grad = True
            gradcheck(ifft, [fx])
            gradgradcheck(ifft, [fx], gen_non_contig_grad_outputs=True)

    def _test_real(sizes, signal_ndim):
        x = torch.randn(sizes, requires_grad=True, dtype=torch.double)
        if x.dim() == signal_ndim:
            start_dim = 0
        else:
            start_dim = 1
        signal_sizes = x.size()[start_dim:start_dim + signal_ndim]

        for normalized, onesided in product((True, False), repeat=2):
            def rfft(x):
                return x.rfft(signal_ndim, normalized=normalized, onesided=onesided)

            gradcheck(rfft, [x])
            gradgradcheck(rfft, [x], gen_non_contig_grad_outputs=True)

            def irfft(fx):
                if signal_ndim == 1 and onesided:
                return fx.irfft(signal_ndim, normalized=normalized,
                                onesided=onesided, signal_sizes=signal_sizes)

            fx = rfft(x).detach()
            fx.requires_grad = True
            gradcheck(irfft, [fx])
            gradgradcheck(irfft, [fx], gen_non_contig_grad_outputs=True)

            z = torch.randn(sizes, dtype=torch.double)
            fz = z.rfft(signal_ndim, normalized=normalized, onesided=onesided)

            def rfft_irfft(x):
                fx = x.rfft(signal_ndim, normalized=normalized, onesided=onesided)
                return y.irfft(signal_ndim, normalized=normalized,
                               onesided=onesided, signal_sizes=signal_sizes)

            gradcheck(rfft_irfft, [x])
            gradgradcheck(rfft_irfft, [x], gen_non_contig_grad_outputs=True)

    _test_real((2, 10), 1)
    _test_real((2, 3, 4), 2)
    _test_real((2, 3, 4, 3), 3)

    _test_complex((2, 2, 10, 2), 1)
    _test_complex((1, 2, 3, 4, 2), 2)
    _test_complex((2, 1, 3, 4, 3, 2), 3)
def test_variable_traverse(self):
    def get_out_and_unrefed_cycle():
        inp = torch.randn(10, requires_grad=True)
        tmp = inp.view(10, 1)
        my_list.append(my_list)

    out = get_out_and_unrefed_cycle()
    out.backward(torch.randn(out.size()))

def test_norm_subgradient(self):
    def run_test(input_size, norm_deg):
        input = torch.zeros(*input_size, requires_grad=True)
        input.norm(norm_deg).backward()
        self.assertEqual(input.grad.data.abs().sum(), 0)

    run_test((10, 10), 2)
    run_test((10,), 1.5)

def test_pow_zero_tensor_gradient(self):
    def run_test(input_size, exponent):
        input = torch.zeros(*input_size, requires_grad=True)
        input.pow(exponent).sum().backward()
        self.assertEqual(input.grad.data.abs().sum(), 0)

    run_test((10,), torch.zeros(10))
    run_test((10, 10), torch.zeros(10, 10))

def test_pow_scalar_base(self):
    a = torch.arange(1, 13, dtype=torch.double).view(3, 4).requires_grad_()
    gradcheck(lambda a: torch.pow(2, a), (a,))
2334 def test_pdist_large(self):
2336 return torch.pdist(x, p=2)
2339 for device
in devices:
2344 x = torch.randn(shape, device=device).requires_grad_()
2345 output = torch.pdist(x, p=2)
2347 output.sum().backward()
2350 def test_pinverse(self):
2360 U = torch.randn(n, m).qr()[0].t()
2361 V = torch.randn(n, m).qr()[0].t()
2364 S = torch.cat([x, torch.zeros(n - m)], 0)
2365 M = U.mm(torch.diag(S)).mm(V.t())
2368 gradcheck(func, [torch.rand(m).add_(1).requires_grad_()])
2369 gradcheck(func, [torch.rand(m).add_(10).requires_grad_()])
2370 gradgradcheck(func, [torch.rand(m).add_(1).requires_grad_()])
2371 gradgradcheck(func, [torch.rand(m).add_(10).requires_grad_()])
    def test_chain_matmul(self):
        def gen_matrices(p):
            matrices = []
            for (pi, pi_1) in zip(p[:-1], p[1:]):
                matrices.append(torch.randn(pi, pi_1).requires_grad_())
            return matrices

        gradcheck(torch.chain_matmul, gen_matrices([5, 10, 15, 5]))
        gradcheck(torch.chain_matmul, gen_matrices([3, 5, 2, 6]))
        gradcheck(torch.chain_matmul, gen_matrices([6, 2, 4, 8, 10]))
        gradgradcheck(torch.chain_matmul, gen_matrices([5, 10, 15, 5]))
        gradgradcheck(torch.chain_matmul, gen_matrices([3, 5, 2, 6]))
        gradgradcheck(torch.chain_matmul, gen_matrices([6, 2, 4, 8, 10]))

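    # Illustrative sketch, not part of the upstream suite: torch.chain_matmul
    # multiplies a chain of matrices, which is why the generated inputs above
    # must have chain-compatible shapes. The helper name `_chain_matmul_sketch`
    # is ours, not PyTorch API.
    @staticmethod
    def _chain_matmul_sketch():
        ms = [torch.randn(5, 10), torch.randn(10, 15), torch.randn(15, 5)]
        expected = ms[0].mm(ms[1]).mm(ms[2])
        assert torch.allclose(torch.chain_matmul(*ms), expected, atol=1e-4)
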
    def test_profiler(self):
        x = torch.randn(10, 10)

        with profile() as p:
            y = x * 2 + 4

        last_end = 0
        names = ['mul', 'add']
        self.assertEqual(len(p.function_events), len(names))
        for info, expected_name in zip(p.function_events, names):
            self.assertGreater(info.cpu_interval.start, last_end)
            self.assertEqual(info.name, expected_name)
            last_end = info.cpu_interval.end

    def test_dir(self):
        x = torch.randn(10, 10)
        keys = dir(x)
        self.assertIn('shape', keys)

        for key in keys:
            self.assertTrue(hasattr(x, key))

    def test_as_strided(self):

        def test(x, prepro_fn, size, strides, offset=None):
            x = x.to(torch.double).detach().requires_grad_()

            # Verify that the test case never needs to resize the underlying
            # storage.
            with torch.no_grad():
                y = prepro_fn(x) if prepro_fn is not None else x
                max_offset = sum((si - 1) * st for si, st in zip(size, strides))
                max_offset += offset if offset is not None else y.storage_offset()
                assert max_offset < len(y.storage()), "test case resizes storage"

            def closure(x):
                if prepro_fn is not None:
                    x = prepro_fn(x)
                return x.as_strided(size, strides, offset)

            gradcheck(closure, [x])
            gradgradcheck(closure, [x])

        # contiguous and non-contiguous inputs
        test(torch.arange(0, 25), lambda x: x.view(5, 5), [3, 3], [6, 2], 2)
        test(torch.randn(12), None, [1, 2, 1, 5], [0, 5, 100, 1], 2)

        # overlapping and zero-stride outputs
        test(torch.randn(5), None, [3, 3, 3], [0, 1, 0], 2)
        test(torch.randn(5), None, [3, 3, 3], [0, 0, 0], 4)
        test(torch.randn(5), lambda x: x.expand(5, 5), [5, 5], [0, 1], 0)
        test(torch.randn(35), None, [6, 6], [5, 1], 2)
        test(torch.randn(15), None, [3, 2], [3, 6], 2)

        # without an explicit storage offset
        test(torch.randn(3, 4), None, [4, 3], [1, 4])

        # input that is itself a view with a nonzero storage offset
        x = torch.randn(6, 2)
        test(x[3:], None, [3, 2], [2, 1], 0)
        self.assertEqual(x[3:].as_strided([3, 2], [2, 1], 0), x[:3])

        # input with an expanded (zero-stride) base
        test(torch.randn(2, 3), lambda x: x.expand(10, 2, 3), [2, 3], [3, 1], 0)

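    # Illustrative sketch, not part of the upstream suite: as_strided can
    # produce overlapping views, and backward accumulates gradient into each
    # storage location once per use. The helper name `_as_strided_grad_sketch`
    # is ours, not PyTorch API.
    @staticmethod
    def _as_strided_grad_sketch():
        x = torch.arange(4.0, requires_grad=True)
        # Rows [x0, x1], [x1, x2], [x2, x3]: interior elements appear twice.
        y = x.as_strided([3, 2], [1, 1])
        y.sum().backward()
        assert x.grad.tolist() == [1.0, 2.0, 2.0, 1.0]
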
    def _test_where_functional(self, t):
        x = Variable(t(torch.randn(5, 5)), requires_grad=True)
        y = Variable(t(torch.randn(5, 5)), requires_grad=True)
        cond = Variable(t(mask_not_all_zeros((5, 5))), requires_grad=False)

        def where(cond, x, y):
            return torch.where(cond, x, y)

        gradcheck(where, [cond, x, y], raise_exception=True)
        gradgradcheck(where, [cond, x, y], [Variable(t(torch.randn(5, 5)))])

        x = Variable(t(torch.randn(5, 1, 5)), requires_grad=True)
        y = Variable(t(torch.randn(5, 5, 1)), requires_grad=True)
        gradcheck(where, [cond, x, y], raise_exception=True)
        gradgradcheck(where, [cond, x, y], [Variable(t(torch.randn(5, 5, 5)))])

    def test_where_functional(self):
        self._test_where_functional(lambda t: t)

    @unittest.skipIf(not TEST_CUDA, "CUDA not available")
    def test_where_functional_cuda(self):
        self._test_where_functional(lambda t: t.cuda())

    def _test_lerp_tensor_weights(self, cast):
        def construct_inputs(*shapes):
            start = cast(torch.randn(shapes[0])).requires_grad_()
            end = cast(torch.randn(shapes[1])).requires_grad_()
            weight = cast(torch.randn(shapes[2]))
            return [start, end, weight]

        all_test_shapes = [((3, 3, 3), (3, 3, 3), (3, 3, 3)),  # no broadcasting
                           ((3,), (3, 3, 3), (3, 3, 3)),       # start broadcasts
                           ((3, 3, 3), (3,), (3, 3, 3)),       # end broadcasts
                           ((3, 3, 3), (3, 3, 3), (3,)),       # weight broadcasts
                           ((), (3, 3, 3), (3, 3, 3)),         # scalar start
                           ((3, 3, 3), (), (3, 3, 3)),         # scalar end
                           ((3, 3, 3), (3, 3, 3), ()),         # scalar weight
                           ((3, 3), (3, 3, 3), (3,))]          # mixed broadcasting

        for shapes in all_test_shapes:
            cur_inputs = construct_inputs(*shapes)
            gradcheck(torch.lerp, cur_inputs)
            gradgradcheck(torch.lerp, cur_inputs)

    def test_lerp_tensor_weights(self):
        self._test_lerp_tensor_weights(lambda t: t)

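    # Illustrative sketch, not part of the upstream suite: torch.lerp computes
    # start + weight * (end - start), so d/dstart = 1 - weight and
    # d/dend = weight, which is what the gradcheck above verifies numerically.
    # The helper name `_lerp_grad_sketch` is ours, not PyTorch API.
    @staticmethod
    def _lerp_grad_sketch():
        start = torch.randn(3, requires_grad=True)
        end = torch.randn(3, requires_grad=True)
        weight = torch.rand(3)
        torch.lerp(start, end, weight).sum().backward()
        assert torch.allclose(start.grad, 1 - weight)
        assert torch.allclose(end.grad, weight)
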
    def test_reduce_dtype(self):
        def test_reduction(op, has_no_dim):
            x = torch.randn(3, 3, dtype=torch.float, requires_grad=True)

            if has_no_dim:
                grad1, = torch.autograd.grad([op(x)], [x])
                grad2, = torch.autograd.grad([op(x, dtype=torch.double)], [x])
                self.assertEqual(grad1, grad2)
                self.assertEqual(grad2.dtype, torch.float)

            gi = torch.randn(op(x, dim=0).shape, dtype=torch.float)
            grad1, = torch.autograd.grad([op(x, dim=0)], [x], gi)
            grad2, = torch.autograd.grad([op(x, dim=0, dtype=torch.double)], [x], gi.double())
            self.assertEqual(grad1, grad2)
            self.assertEqual(grad2.dtype, torch.float)

        test_reduction(torch.sum, True)
        test_reduction(torch.prod, True)
        test_reduction(torch.cumsum, False)
        test_reduction(torch.cumprod, False)

    def test_inplace_view_backprop_base(self):
        # modify a view in place and backprop through the base
        root = torch.randn(2, 2, requires_grad=True)
        x = root.clone()
        v1 = x.narrow(0, 0, 1)
        v1.mul_(2)
        x.sum().backward()
        self.assertEqual(root.grad.data.tolist(), [[2, 2], [1, 1]])

    def test_inplace_view_backprop_view_of_view(self):
        # modify a view in place and backprop through another view of the same region
        root = torch.randn(2, 2, requires_grad=True)
        x = root.clone()
        v1 = x.narrow(0, 0, 1)
        v2 = x.narrow(0, 0, 1)
        v1.mul_(2)
        v2.sum().backward()
        self.assertEqual(root.grad.data.tolist(), [[2, 2], [0, 0]])

    def test_inplace_view_of_view(self):
        # modify a view of a view in place and backprop through the base
        root = torch.randn(2, 2, requires_grad=True)
        x = root.clone()
        v1 = x.narrow(0, 0, 1)
        v2 = v1.narrow(1, 1, 1)
        v2.mul_(2)
        x.sum().backward()
        self.assertEqual(root.grad.data.tolist(), [[1, 2], [1, 1]])

    def test_inplace_view_gradcheck(self):
        # gradcheck in-place modifications made through views
        a = torch.randn(4, 4, requires_grad=True)
        b = torch.randn(2, 2, requires_grad=True)

        def func(root, b):
            x = root.clone()
            x.narrow(1, 2, 2).narrow(0, 1, 2).mul_(b)
            x.narrow(1, 0, 2).narrow(0, 1, 2).mul_(b)
            return x

        gradcheck(func, [a, b], raise_exception=True)
        go = torch.randn(a.size(), requires_grad=True)
        gradgradcheck(func, (a, b), (go,))

    def test_inplace_view_makes_base_require_grad(self):
        # an in-place operation on a view must make the base require grad
        a = torch.randn(4, 4, requires_grad=False)
        b = torch.randn(4, 2, requires_grad=True)

        def func(root, b):
            x = root.clone()
            self.assertFalse(x.requires_grad)
            x.narrow(1, 2, 2).mul_(b)
            self.assertTrue(x.requires_grad)
            return x

        gradcheck(func, [a, b], raise_exception=True)
        go = torch.randn(a.size(), requires_grad=True)
        gradgradcheck(func, (a, b), (go,))

    def test_inplace_view_backprop_view(self):
        # modify a view in place and backprop through the view itself
        a = Variable(torch.Tensor([2, 5]), requires_grad=False)
        b = Variable(torch.Tensor([3]), requires_grad=True)
        res = a.narrow(0, 1, 1).mul_(b)
        res.sum().backward()
        self.assertEqual(b.grad.data.tolist(), [5])
        self.assertIsNone(a.grad)

    def test_inplace_view_modify_base(self):
        # An in-place operation that makes the base require grad must also make
        # previously created views require grad and backprop correctly.
        r = torch.ones(1, requires_grad=True)

        def fn(r):
            x = torch.ones(5)
            v = x.select(0, 1)
            self.assertFalse(v.requires_grad)
            self.assertIsNone(v.grad_fn)
            x.add_(r)  # v now depends on r because of the in-place op on its base
            self.assertTrue(v.requires_grad)
            return v

        gradcheck(fn, [r])
        gradgradcheck(fn, [r])

    def test_inplace_view_python(self):
        # in-place modification of views through a Python autograd Function
        a = torch.randn(4, 4, requires_grad=True)
        b = torch.randn(2, 2, requires_grad=True)

        class PyAdd(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x, y):
                ctx.mark_dirty(x)
                x.add_(y)
                return x

            @staticmethod
            def backward(ctx, grad):
                return grad, grad

        def func(root, b):
            x = root.clone()
            PyAdd.apply(x.narrow(1, 2, 2).narrow(0, 1, 2), b)
            PyAdd.apply(x.narrow(1, 0, 2).narrow(0, 1, 2), b)
            return x

        gradcheck(func, [a, b], raise_exception=True)
        go = torch.randn(a.size(), requires_grad=True)
        gradgradcheck(func, (a, b), (go,))

    def test_inplace_view_non_contig(self):
        data = torch.ones(2, 3, 2).select(2, 1).t()
        root = Variable(data, requires_grad=True)
        x = root.clone()
        v1 = x.narrow(0, 0, 1)
        v2 = v1.narrow(1, 1, 1)
        v2.mul_(2)
        x.sum().backward()
        self.assertEqual(root.grad.data.tolist(), [[1, 2], [1, 1], [1, 1]])

    def test_inplace_view_saved_output(self):
        # An in-place operation on a view whose grad_fn saves its output used
        # to create a reference cycle; check that the graph is freed promptly.
        dealloc = [0]

        class IncrementOnDelete(object):
            def __del__(self):
                dealloc[0] += 1

        def test():
            root = torch.randn(3, 3, requires_grad=True)
            copy = root.clone()
            copy.grad_fn.register_hook(IncrementOnDelete())
            view = copy.view(9)
            torch.nn.functional.relu(view, inplace=True)

        test()
        self.assertEqual(dealloc[0], 1)

    def test_mul_out(self):
        a = torch.randn(2, 2, requires_grad=True)
        b = torch.randn(2, 2, requires_grad=True)
        x = torch.zeros_like(a)

        # out= operations do not support automatic differentiation when the
        # inputs require grad...
        self.assertRaisesRegex(RuntimeError, 'out=', lambda: torch.mul(a, b, out=x))

        # ...but they are allowed when grad mode is disabled.
        with torch.no_grad():
            torch.mul(a, b, out=x)
            self.assertEqual(x, a * b)

    def test_mul_out_result_requires_grad(self):
        a = torch.randn(2, 2)
        b = torch.randn(2, 2)
        x = torch.zeros(2, 2, requires_grad=True)
        # Writing into a tensor that requires grad must raise even if the
        # operands themselves do not require grad.
        self.assertRaisesRegex(RuntimeError, 'out=', lambda: torch.mul(a, b, out=x))

    def test_diagonal_derivative_requires_grad(self):
        # The derivative of diagonal must itself be differentiable so that a
        # double-backward graph can be built through it.
        a = torch.randn(5, 6, requires_grad=True)
        b = torch.diagonal(a)**2
        c = b.sum()
        d, = torch.autograd.grad(c, a, retain_graph=True, create_graph=True)
        self.assertTrue(d.requires_grad)

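    # Illustrative sketch, not part of the upstream suite: torch.diagonal
    # returns a view of the diagonal, so the gradient of its sum with respect
    # to the base matrix is the identity. The helper name
    # `_diagonal_grad_sketch` is ours, not PyTorch API.
    @staticmethod
    def _diagonal_grad_sketch():
        a = torch.randn(3, 3, requires_grad=True)
        torch.diagonal(a).sum().backward()
        assert torch.equal(a.grad, torch.eye(3))
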
    def _test_set_requires_grad_only_for_floats(self, cuda):
        dtypes = [torch.int64, torch.int32, torch.int16, torch.int8,
                  torch.float, torch.double]
        if cuda:
            dtypes.append(torch.half)

        def f1(dt):
            a = torch.ones(1, dtype=dt, device='cuda' if cuda else 'cpu')
            a.requires_grad_()

        def f2(dt):
            a = torch.ones(1, dtype=dt, device='cuda' if cuda else 'cpu')
            a.requires_grad = True

        def f3(dt):
            torch.ones(1, dtype=dt, device='cuda' if cuda else 'cpu', requires_grad=True)

        for dt in dtypes:
            a = torch.ones(1, dtype=dt, device='cuda' if cuda else 'cpu')
            # Setting requires_grad to False is always allowed.
            a.requires_grad = False
            a.requires_grad_(False)

            for f in [f1, f2, f3]:
                if dt.is_floating_point:
                    f(dt)
                else:
                    with self.assertRaisesRegex(RuntimeError, 'floating point',
                                                msg="dt: {} device: {}".format(a.dtype, a.device)):
                        f(dt)

    @unittest.skipIf(not TEST_CUDA, "CUDA not available")
    def test_set_requires_grad_only_for_floats_cuda(self):
        self._test_set_requires_grad_only_for_floats(True)

    def test_set_requires_grad_only_for_floats(self):
        self._test_set_requires_grad_only_for_floats(False)

    @unittest.skipIf(not TEST_CUDA, "CUDA not available")
    def test_rnn_backward_to_input_but_not_parameters_cuda(self):
        # It should be possible to freeze all parameters and still get a
        # gradient with respect to the input.
        dev = torch.device('cuda')
        l = torch.nn.LSTM(2, 3).to(dev)
        for p in l.parameters():
            p.requires_grad = False
        s = torch.randn(1, 1, 2, requires_grad=True, device=dev)
        out, _ = l(s)
        out.sum().backward()
        self.assertFalse(s.grad is None or s.grad.abs().sum().item() == 0)

    @unittest.skipIf(not TEST_CUDA, "CUDA not available")
    def test_lstmcell_backward_only_one_output_grad(self):
        # Backward through only one of LSTMCell's outputs must not be hampered
        # by the other output's undefined gradient.
        dev = torch.device('cuda')
        l = torch.nn.LSTMCell(2, 3).to(dev).double()
        s = torch.randn(1, 2, device=dev, dtype=torch.double, requires_grad=True)
        for i in range(2):
            out = l(s)[i]
            out.sum().backward()
            self.assertFalse(s.grad is None or s.grad.abs().sum().item() == 0)

    def test_anomaly_detect_nan(self):
        size = 10

        class MyFunc(Function):
            @staticmethod
            def forward(ctx, inp1, inp2, fail_0th):
                ctx.fail_0th = fail_0th
                return inp1.sum(0, keepdim=True)

            @staticmethod
            def backward(ctx, gO):
                gI = gO.clone().expand(size)
                gI[0] = 0
                gI[0] /= 0  # generate a NaN in the returned gradient
                if ctx.fail_0th:
                    return gI, None, None
                else:
                    return None, gI, None

        inp = torch.rand(size, requires_grad=True)
        out = MyFunc.apply(inp, inp, True)
        out.backward()  # should not fail without anomaly detection

        inp = torch.rand(size, requires_grad=True)
        out = MyFunc.apply(inp, inp, True)
        with self.assertRaisesRegex(RuntimeError, "Function 'MyFuncBackward' returned nan values in its 0th output."):
            with warnings.catch_warnings(record=True) as w:
                with detect_anomaly():
                    out.backward()
            self.assertIn('No forward pass information', str(w[0].message))

        inp = torch.rand(size, requires_grad=True)
        with self.assertRaisesRegex(RuntimeError, "Function 'MyFuncBackward' returned nan values in its 1th output."):
            with warnings.catch_warnings(record=True) as w:
                with detect_anomaly():
                    out = MyFunc.apply(inp, inp, False)
                    out.backward()
            self.assertIn('MyFunc.apply', str(w[0].message))

    def test_symeig_no_eigenvectors(self):
        A = torch.tensor([[1., 2.], [2., 4.]], dtype=torch.float32, requires_grad=True)
        w, v = torch.symeig(A, eigenvectors=False)
        with self.assertRaisesRegex(RuntimeError, 'cannot compute backward'):
            torch.autograd.backward([w, v], [torch.ones_like(w), torch.ones_like(v)])

    def test_svd_no_singularvectors(self):
        A = torch.randn(2, 2, dtype=torch.float32, requires_grad=True)
        u, s, v = torch.svd(A, compute_uv=False)
        with self.assertRaisesRegex(RuntimeError, 'cannot compute backward'):
            torch.autograd.backward([u, s, v], [torch.ones_like(u), torch.ones_like(s), torch.ones_like(v)])

    def test_no_grad_copy(self):
        # Autograd should hand the gradient buffer straight through to .grad
        # when it can, instead of making a copy.
        class MyFunc(Function):
            static_grad_ptr = None

            @staticmethod
            def forward(ctx, inp1, inp2):
                return inp1 + inp2

            @staticmethod
            def backward(ctx, grad):
                MyFunc.static_grad_ptr = grad.data_ptr()
                return grad, grad

        class NonContGradFunc(Function):
            @staticmethod
            def forward(ctx, inp1):
                ctx.size = inp1.size()
                return torch.tensor([1.])

            @staticmethod
            def backward(ctx, grad):
                return torch.ones(1).expand(ctx.size)

        a = torch.randn(5, 6, requires_grad=True)
        b = torch.randn(5, 6, requires_grad=True)
        # A non-contiguous gradient must be copied before being accumulated.
        NonContGradFunc.apply(MyFunc.apply(a, b)).backward()
        self.assertFalse(a.grad.data_ptr() == MyFunc.static_grad_ptr)
        self.assertFalse(b.grad.data_ptr() == MyFunc.static_grad_ptr)

        # With a contiguous gradient, one of the two accumulations can reuse
        # the buffer produced by backward.
        a.grad = b.grad = None
        MyFunc.apply(a, b)[1][0].backward()
        p_g = MyFunc.static_grad_ptr
        p_a = a.grad.data_ptr()
        p_b = b.grad.data_ptr()
        # a and b must not share a grad buffer, but one of them should reuse
        # the gradient computed in backward.
        self.assertFalse(p_a == p_b)
        self.assertTrue(p_a == p_g or p_b == p_g)

    def test_gradcheck_single_input(self):
        def f(inp):
            return inp.mul(5)

        gradcheck(f, torch.rand(10, dtype=torch.float64, requires_grad=True))
        gradgradcheck(f, torch.rand(10, dtype=torch.float64, requires_grad=True))

    def test_gradcheck_sparse_input(self):
        def fn(sparse):
            return torch.sparse.sum(sparse)

        gradcheck(fn, torch.rand(10).to_sparse().requires_grad_(True), check_sparse_nnz=True)
        with self.assertRaisesRegex(RuntimeError, 'gradcheck expects all tensor inputs are dense'):
            gradcheck(fn, torch.rand(10).to_sparse().requires_grad_(True), check_sparse_nnz=False)

    @unittest.skipIf(not TEST_CUDA, "Requires cuda for multi device")
    def test_multi_device_reentrant_autograd(self):
        from torch.utils.checkpoint import checkpoint

        # Output on gpu so that this task will be associated with the gpu thread.
        def fn_on_gpu(inp):
            # Pad with extra ops so the gpu work queue does not empty before
            # the checkpointed op is reached.
            dummy = inp * 2 * 2 * 2 * 2
            return inp.cuda()

        def parent_on_cpu(inp):
            # Slow branch of ops on gpu so that the work queue for the gpu
            # thread does not empty too quickly.
            branch1 = inp.cuda()
            branch1 = branch1 / branch1
            branch1 = branch1 / branch1
            branch1 = branch1 / branch1
            # Checkpointing runs fn_on_gpu under a reentrant backward on the
            # cpu thread.
            branch2 = checkpoint(fn_on_gpu, inp)
            out = branch2 + branch1
            return out

        inp = torch.rand(2, requires_grad=True)
        out = parent_on_cpu(inp)
        # This would fail if reentrant autograd across devices were handled
        # incorrectly.
        out.sum().backward()


def index_variable(shape, max_indices):
    if not isinstance(shape, tuple):
        shape = (shape,)
    index = torch.rand(*shape).mul_(max_indices).floor_().long()
    return index


def index_perm_variable(shape, max_indices):
    if not isinstance(shape, tuple):
        shape = (shape,)
    index = torch.randperm(max_indices).narrow(0, 0, reduce(mul, shape)).view(shape)
    return index


def gather_variable(shape, index_dim, max_indices, duplicate=False):
    assert len(shape) == 2
    assert index_dim < 2
    batch_dim = 1 - index_dim
    index = torch.LongTensor(*shape)
    for i in range(shape[index_dim]):
        index.select(index_dim, i).copy_(
            torch.randperm(max_indices)[:shape[batch_dim]])
    if duplicate:
        index.select(batch_dim, 0).copy_(index.select(batch_dim, 1))
    return index

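
# Illustrative sketch, not part of the upstream suite: gather_variable builds
# a LongTensor index whose entries are valid positions into a source whose
# size along the gathered dimension is `max_indices`, so it can be fed to
# torch.gather. The function name `_gather_variable_example` is ours, not
# PyTorch API.
def _gather_variable_example():
    src = torch.randn(2, 5, requires_grad=True)
    idx = gather_variable((2, 3), 1, 5)
    out = src.gather(1, idx)  # shape (2, 3)
    out.sum().backward()
    return out, src.grad
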

def bernoulli_scalar():
    return torch.tensor(0, dtype=torch.uint8).bernoulli_()


def gradgradcheck_method_precision_override(test_name):
    # These tolerances are empirical.
    gradgradcheck_precision_override = {
        'test_norm': {'atol': 2e-2, 'rtol': 1e-2},
        'test_norm_1_5': {'atol': 1.5e-2, 'rtol': 1e-2},
        'test_norm_3': {'atol': 5e-2, 'rtol': 1e-2},
        'test_dist': {'atol': 5e-2, 'rtol': 1e-2},
        'test_dist_4': {'atol': 8e-2, 'rtol': 1e-2},
    }
    non_broadcasted_test_name = test_name.split("_broadcast")[0]
    override = gradgradcheck_precision_override.get(non_broadcasted_test_name)
    if override:
        if 'broadcast_lhs' in test_name or 'broadcast_rhs' in test_name:
            # errors accumulated across one broadcast dimension
            override = {'atol': override['atol'] * S, 'rtol': override['atol'] * S}
        elif 'broadcast_all' in test_name:
            # errors accumulated across multiple broadcast dimensions
            override = {'atol': override['atol'] * S * S, 'rtol': override['atol'] * S * S}
    return override


def run_grad_and_gradgrad_checks(test_case, name, test_name, apply_method, output_variable,
                                 input_variables, run_gradgradcheck=True):
    test_case.assertTrue(gradcheck(apply_method, input_variables, eps=1e-6, atol=PRECISION))
    if name in EXCLUDE_GRADGRADCHECK or test_name in EXCLUDE_GRADGRADCHECK_BY_TEST_NAME:
        return
    gradgradcheck_precision_override = gradgradcheck_method_precision_override(test_name)
    if gradgradcheck_precision_override is not None:
        atol = gradgradcheck_precision_override['atol']
        rtol = gradgradcheck_precision_override['rtol']
        test_case.assertTrue(gradgradcheck(apply_method, input_variables, None, atol=atol, rtol=rtol,
                                           gen_non_contig_grad_outputs=True))
    else:
        test_case.assertTrue(gradgradcheck(apply_method, input_variables, gen_non_contig_grad_outputs=True))


def run_functional_checks(test_case, test_name, name, apply_fn, run_grad_checks,
                          f_args_variable, f_args_tensor):
    output_variable = apply_fn(*f_args_variable)

    if run_grad_checks:
        run_grad_and_gradgrad_checks(test_case, name, test_name, apply_fn,
                                     output_variable, f_args_variable)

    self_variable = f_args_variable[0]
    if isinstance(output_variable, torch.Tensor) and output_variable.requires_grad and self_variable is not None:
        output_variable.backward(randn_like(output_variable))
        test_case.assertEqual(self_variable.type(), self_variable.grad.type())
        test_case.assertEqual(self_variable.size(), self_variable.grad.size())


def add_test(
        name,
        self_size,
        args,
        variant_name='',
        dim_args_idx=(),
        skipTestIf=(),
        output_process_fn=lambda x: x,
        kwargs=None):
    kwargs = kwargs if kwargs else {}
    basic_test_name = 'test_' + name
    if variant_name != '':
        basic_test_name += '_' + variant_name

    for dim_perm in product([-1, 1], repeat=len(dim_args_idx)):
        test_name = basic_test_name
        new_args = [arg * dim_perm[dim_args_idx.index(i)] if i in dim_args_idx else arg
                    for i, arg in enumerate(args)]
        test_name = basic_test_name + ''.join('_neg' + str(i) for i, idx in enumerate(dim_perm) if idx < 0)
        new_args = tuple(new_args)
        # for-loops don't create new scopes, so capture the loop variables as
        # default arguments of do_test.
        def do_test(self, name=name, self_size=self_size, args=new_args, test_name=test_name,
                    output_process_fn=output_process_fn):
            def check(name):
                is_magic_method = name[:2] == '__' and name[-2:] == '__'
                is_inplace = name[-1] == "_" and not is_magic_method
                self_variable = create_input((self_size,))[0][0]
                if is_inplace:
                    self_variable.requires_grad = False
                args_variable, kwargs_variable = create_input(args, requires_grad=not is_inplace, call_kwargs=kwargs)
                self_tensor = deepcopy(self_variable.data)
                args_tensor = deepcopy(unpack_variables(args_variable))
                output_variable = getattr(self_variable, name)(*args_variable, **kwargs_variable)

                if not exclude_tensor_method(name, test_name):
                    output_tensor = getattr(self_tensor, name)(*args_tensor, **kwargs_variable)
                    if not isinstance(output_tensor, torch.Tensor) and not istuple(output_tensor):
                        output_tensor = torch.DoubleTensor((output_tensor,))
                    self.assertEqual(unpack_variables(output_variable), output_tensor)

                    def fn(*inputs):
                        output = getattr(inputs[0], name)(*inputs[1:], **kwargs)
                        return output_process_fn(output)

                    if not is_inplace and name not in EXCLUDE_GRADCHECK:
                        run_grad_and_gradgrad_checks(self, name, test_name, fn,
                                                     output_variable, (self_variable,) + args_variable)

                # functional interface tests
                if hasattr(torch, name) and name not in EXCLUDE_FUNCTIONAL:
                    def fn(*inputs):
                        output = getattr(torch, name)(*inputs)
                        return output_process_fn(output)

                    f_args_variable = (self_variable,) + args_variable
                    f_args_tensor = (self_tensor,) + args_tensor
                    # gradchecks already ran above for the method variant
                    run_functional_checks(self, test_name, name, fn,
                                          False, f_args_variable, f_args_tensor)

                # check that input.data and input.grad.data have matching type and size
                if not is_inplace:
                    self_variable = create_input((self_size,), requires_grad=True)[0][0]
                    args_variable, kwargs_variable = create_input(args, requires_grad=False, call_kwargs=kwargs)
                    output_variable = getattr(self_variable, name)(*args_variable, **kwargs_variable)
                    if isinstance(output_variable, torch.autograd.Variable):
                        output_variable.backward(randn_like(output_variable))
                        self.assertTrue(type(self_variable.data) == type(self_variable.grad.data))
                        self.assertTrue(self_variable.size() == self_variable.grad.size())
                    # compare grads of the out-of-place op with its in-place variant
                    inplace_name = name + '_'
                    # in-place ops cannot broadcast the left-hand side
                    skip_inplace = ('broadcast_lhs' in test_name or
                                    'broadcast_all' in test_name)
                    if hasattr(torch.ones(1), inplace_name) and not skip_inplace:
                        output_variable = getattr(self_variable, name)(*args_variable, **kwargs_variable)
                        if not isinstance(output_variable, tuple):
                            output_variable = (output_variable,)
                        inplace_self_variable = deepcopy(self_variable)
                        inplace_self_variable_copy = tuple(i.clone() if isinstance(i, torch.Tensor) else i
                                                           for i in (inplace_self_variable,))
                        inplace_args_variable = deepcopy(args_variable)
                        inplace_args_variable_copy = tuple(i.clone() if isinstance(i, torch.Tensor) else i
                                                           for i in inplace_args_variable)

                        inplace_output_variable = (
                            getattr(inplace_self_variable_copy[0], inplace_name)(*inplace_args_variable_copy,
                                                                                 **kwargs_variable))
                        if not isinstance(inplace_output_variable, tuple):
                            inplace_output_variable = (inplace_output_variable,)
                        self.assertEqual(inplace_output_variable, output_variable)

                        # zero out any existing grads before comparing
                        for inp_i, i in zip((inplace_self_variable,) + inplace_args_variable,
                                            (self_variable,) + args_variable):
                            if not isinstance(inp_i, torch.Tensor):
                                assert not isinstance(i, torch.Tensor)
                                continue
                            if inp_i.grad is not None:
                                inp_i.grad.data.zero_()
                            if i.grad is not None:
                                i.grad.data.zero_()
                        for io, o in zip(inplace_output_variable, output_variable):
                            grad = randn_like(io).double()
                            io.backward(grad)
                            o.backward(grad)
                        for inp_i, i in zip((inplace_self_variable,) + inplace_args_variable,
                                            (self_variable,) + args_variable):
                            if not isinstance(inp_i, torch.Tensor):
                                continue
                            self.assertEqual(inp_i.grad, i.grad)
            check(name)
            inplace_name = name + '_'
            # in-place ops cannot broadcast the left-hand side
            broadcast_skip_inplace = 'broadcast_lhs' in test_name or 'broadcast_all' in test_name
            if hasattr(torch.ones(1), inplace_name) and not broadcast_skip_inplace:
                check(inplace_name)

        assert not hasattr(TestAutograd, test_name), 'Two tests have the same name: ' + test_name

        for skip in skipTestIf:
            do_test = skip(do_test)

        setattr(TestAutograd, test_name, do_test)


for test in method_tests():
    add_test(*test)

if __name__ == '__main__':
    run_tests()