// check_exact_values (its signature begins at the end of the include line
// below): walks `parameters` one tensor at a time and compares each tensor's
// entries along dimension 0 with the corresponding expected tensors, writing
// a diagnostic to std::cout whenever a size or a value does not match.
1 #include <gtest/gtest.h> 3 #include <torch/nn/init.h> 4 #include <torch/nn/modules/linear.h> 6 #include <test/cpp/api/init_baseline.h> 7 #include <test/cpp/api/support.h> 12 void check_exact_values(
13 const std::vector<torch::Tensor>& parameters,
// Outer loop: one iteration per tensor in `parameters`; `i` is also the layer
// number used in the diagnostics.
17 for (
size_t i = 0; i < parameters.size(); i++) {
18 auto layerParameters = parameters[i];
// The extent of dimension 0 of this layer's tensor must equal the number of
// expected entries; on a mismatch both sizes are logged.
// NOTE(review): expectedLayerParameters is presumably the expected entry for
// layer i -- confirm against its definition.
21 if (layerParameters.size(0) != expectedLayerParameters.size()) {
22 std::cout <<
"layer #" << i
23 <<
" layerParameters size: " << layerParameters.size(0)
25 <<
" expectedLayerParameters size: " 26 << expectedLayerParameters.size() << std::endl;
// Inner loop: compare the p-th slice along dimension 0 with the p-th expected
// tensor.
30 for (
size_t p = 0; p < layerParameters.size(0); p++) {
31 auto tensor = layerParameters[p];
32 auto expectedTensor = expectedLayerParameters[p];
// Approximate comparison: the numeric arguments are the rtol (1e-3) and
// atol (5e-4) of Tensor::allclose. Both tensors are printed when they differ.
34 if (!tensor.allclose(expectedTensor, 1e-3, 5e-4)) {
35 std::cout <<
"layer " << i <<
": " << tensor <<
" != " << expectedTensor
36 <<
" (parameter " << p <<
")" << std::endl;
// check_initializer_against_baseline: applies `initializer` to the weight of
// three freshly constructed Linear layers and forwards the collected tensors,
// together with `expected`, to check_exact_values.
43 void check_initializer_against_baseline(
45 std::vector<std::vector<torch::Tensor>> expected) {
// Seed torch's random number generator with a fixed value before any layer is
// built (presumably so the generated values are reproducible -- confirm).
46 torch::manual_seed(0);
// Layer 1 is Linear(7, 15). The initializer runs on the weight first; the
// layer is converted to float64 afterwards.
48 auto layer1 = torch::nn::Linear(7, 15);
49 initializer(layer1->weight);
50 layer1->to(torch::kFloat64);
// Layer 2 is Linear(15, 15), handled in the same order.
52 auto layer2 = torch::nn::Linear(15, 15);
53 initializer(layer2->weight);
54 layer2->to(torch::kFloat64);
// Layer 3 is Linear(15, 2), handled in the same order.
56 auto layer3 = torch::nn::Linear(15, 2);
57 initializer(layer3->weight);
58 layer3->to(torch::kFloat64);
// Gather the tensors to verify into `parameters` and compare them against
// `expected`.
60 auto parameters = std::vector<torch::Tensor>{
66 check_exact_values(parameters, expected);
// Checks torch::nn::init::xavier_uniform_ against the baseline returned by
// expected_parameters::Xavier_Uniform().
// NOTE(review): per the test name the baseline presumably holds values
// produced by PyTorch -- confirm in init_baseline.h.
69 TEST(InitTest, ProducesPyTorchValues_XavierUniform) {
70 auto expected = expected_parameters::Xavier_Uniform();
72 torch::nn::init::xavier_uniform_(tensor);
// Run the shared baseline comparison with this initializer.
74 check_initializer_against_baseline(initializer, expected);
// Checks torch::nn::init::xavier_normal_ against the baseline returned by
// expected_parameters::Xavier_Normal().
// NOTE(review): per the test name the baseline presumably holds values
// produced by PyTorch -- confirm in init_baseline.h.
77 TEST(InitTest, ProducesPyTorchValues_XavierNormal) {
78 auto expected = expected_parameters::Xavier_Normal();
80 torch::nn::init::xavier_normal_(tensor);
// Run the shared baseline comparison with this initializer.
82 check_initializer_against_baseline(initializer, expected);
// Checks torch::nn::init::kaiming_normal_ against the baseline returned by
// expected_parameters::Kaiming_Normal().
// NOTE(review): per the test name the baseline presumably holds values
// produced by PyTorch -- confirm in init_baseline.h.
85 TEST(InitTest, ProducesPyTorchValues_KaimingNormal) {
86 auto expected = expected_parameters::Kaiming_Normal();
88 torch::nn::init::kaiming_normal_(tensor);
// Run the shared baseline comparison with this initializer.
90 check_initializer_against_baseline(initializer, expected);
// Checks torch::nn::init::kaiming_uniform_ against the baseline returned by
// expected_parameters::Kaiming_Uniform().
// NOTE(review): per the test name the baseline presumably holds values
// produced by PyTorch -- confirm in init_baseline.h.
93 TEST(InitTest, ProducesPyTorchValues_KaimingUniform) {
94 auto expected = expected_parameters::Kaiming_Uniform();
96 torch::nn::init::kaiming_uniform_(tensor);
// Run the shared baseline comparison with this initializer.
98 check_initializer_against_baseline(initializer, expected);
// Verifies that torch::nn::init::ones_ can be applied to a tensor created
// with requires_grad set.
// NOTE(review): the message literal in this test looks like the expected
// error text for a direct in-place operation on such a tensor -- confirm
// against the assertion that consumes it.
101 TEST(InitTest, CanInitializeTensorThatRequiresGrad) {
102 auto tensor = torch::empty({3, 4}, torch::requires_grad());
105 "a leaf Variable that requires grad " 106 "has been used in an in-place operation");
// ones_ must succeed on the 3x4 tensor: the sum of its 12 elements, read back
// as int32, is expected to be 12.
107 ASSERT_EQ(torch::nn::init::ones_(tensor).sum().item<int32_t>(), 12);
// Checks the gain reported by calculate_gain for Nonlinearity::Tanh.
110 TEST(InitTest, CalculateGainWithTanh) {
112 torch::nn::init::calculate_gain(torch::nn::init::Nonlinearity::Tanh);
// Expected gain for Tanh is 5/3 (compared with GoogleTest's double equality).
113 ASSERT_DOUBLE_EQ(gain, 5.0 / 3.0);
// Checks the gain reported by calculate_gain for Nonlinearity::ReLU.
116 TEST(InitTest, CalculateGainWithRelu) {
118 torch::nn::init::calculate_gain(torch::nn::init::Nonlinearity::ReLU);
// Expected gain for ReLU is sqrt(2).
119 ASSERT_DOUBLE_EQ(gain, std::sqrt(2.0));
// Checks the gain reported by calculate_gain for Nonlinearity::LeakyReLU when
// no further argument is passed.
122 TEST(InitTest, CalculateGainWithLeakyRelu) {
124 torch::nn::init::calculate_gain(torch::nn::init::Nonlinearity::LeakyReLU);
// Expected gain is sqrt(2 / (1 + s^2)) with s = 0.01.
// NOTE(review): 0.01 is presumably calculate_gain's default negative slope --
// confirm against torch/nn/init.h.
125 ASSERT_DOUBLE_EQ(gain, std::sqrt(2.0 / (1 + pow(0.01, 2))));