Caffe2 - C++ API
A deep learning, cross platform ML framework
misc.cpp
1 #include <gtest/gtest.h>
2 
3 #include <torch/nn/init.h>
4 #include <torch/nn/modules/linear.h>
5 #include <torch/types.h>
6 #include <torch/utils.h>
7 
8 #include <test/cpp/api/support.h>
9 
10 TEST(NoGradTest, SetsGradModeCorrectly) { // Runs a Linear forward/backward pass while a NoGradGuard is alive and expects no gradient on the weight.
11  torch::manual_seed(0); // fixed seed; presumably keeps the random init and input reproducible
12  torch::NoGradGuard guard; // stays in scope until the end of the test body; presumably disables gradient recording (confirm against the torch docs)
13  torch::nn::Linear model(5, 2);
14  auto x = torch::randn({10, 5}, torch::requires_grad()); // input is created with the requires_grad option
15  auto y = model->forward(x);
16  torch::Tensor s = y.sum();
17 // backward() computes gradients w.r.t. graph leaves; the check below expects none to have reached the weight.
18  s.backward();
19  ASSERT_FALSE(model->weight.grad().defined()); // the weight's gradient tensor must still be undefined
20 }
21 
23  AutogradTest() {
24  x = torch::randn({3, 3}, torch::requires_grad());
25  y = torch::randn({3, 3});
26  z = x * y;
27  }
28  torch::Tensor x, y, z;
29 };
30 
31 TEST_F(AutogradTest, CanTakeDerivatives) { // Calls backward() directly on the fixture's z (built as x * y) and checks x's gradient.
32  z.backward(); // no explicit gradient argument is passed
33  ASSERT_TRUE(x.grad().allclose(y)); // since z = x * y, the gradient w.r.t. x is expected to match y
34 }
35 
36 TEST_F(AutogradTest, CanTakeDerivativesOfZeroDimTensors) { // Same check as above, but backward() starts from the reduced value z.sum().
37  z.sum().backward(); // presumably sum() yields the zero-dim tensor the test name refers to
38  ASSERT_TRUE(x.grad().allclose(y)); // gradient w.r.t. x is again expected to match y
39 }
40 
41 TEST_F(AutogradTest, CanPassCustomGradientInputs) { // Passes an explicit gradient tensor to backward() and checks that x's gradient is scaled by it.
42  z.sum().backward(torch::ones({}) * 2); // the supplied gradient is torch::ones({}) * 2, i.e. a value of 2
43  ASSERT_TRUE(x.grad().allclose(y * 2)); // expected gradient is y scaled by the same factor of 2
44 }
void backward(c10::optional< Tensor > gradient=c10::nullopt, bool keep_graph=false, bool create_graph=false)
Computes the gradient of current tensor w.r.t. graph leaves.
Definition: TensorMethods.h:49