1 #include <gtest/gtest.h> 3 #include <torch/nn/init.h> 4 #include <torch/nn/modules/linear.h> 5 #include <torch/types.h> 6 #include <torch/utils.h> 8 #include <test/cpp/api/support.h> 10 TEST(NoGradTest, SetsGradModeCorrectly) {
11 torch::manual_seed(0);
13 torch::nn::Linear model(5, 2);
14 auto x = torch::randn({10, 5}, torch::requires_grad());
15 auto y = model->forward(x);
19 ASSERT_FALSE(model->weight.grad().defined());
24 x = torch::randn({3, 3}, torch::requires_grad());
25 y = torch::randn({3, 3});
33 ASSERT_TRUE(x.grad().allclose(y));
36 TEST_F(
AutogradTest, CanTakeDerivativesOfZeroDimTensors) {
38 ASSERT_TRUE(x.grad().allclose(y));
42 z.sum().
backward(torch::ones({}) * 2);
43 ASSERT_TRUE(x.grad().allclose(y * 2));
void backward(c10::optional< Tensor > gradient=c10::nullopt, bool keep_graph=false, bool create_graph=false)
Computes the gradient of current tensor w.r.t. graph leaves.