#include <gtest/gtest.h>

#include <torch/nn/module.h>
#include <torch/nn/modules/functional.h>
#include <torch/nn/modules/linear.h>
#include <torch/nn/modules/sequential.h>
#include <torch/optim.h>
#include <torch/types.h>
#include <torch/utils.h>

#include <test/cpp/api/optim_baseline.h>
#include <test/cpp/api/support.h>

#include <iostream>
#include <string>
#include <vector>

using namespace torch::nn;
using namespace torch::optim;

template <typename OptimizerClass, typename Options>
bool test_optimizer_xor(Options options) {
  torch::manual_seed(0);

  // A tiny two-layer MLP for the XOR problem. The hidden width is an
  // assumption (any small width learns XOR); the input and output sizes are
  // fixed by the data below.
  Sequential model(
      Linear(2, 8),
      Functional(torch::sigmoid),
      Linear(8, 1),
      Functional(torch::sigmoid));

  const int64_t kBatchSize = 4;
  const int64_t kMaximumNumberOfEpochs = 3000;

  OptimizerClass optimizer(model->parameters(), options);

  float running_loss = 1;
  int epoch = 0;
  while (running_loss > 0.1) {
    auto inputs = torch::empty({kBatchSize, 2});
    auto labels = torch::empty({kBatchSize});
    for (size_t i = 0; i < kBatchSize; i++) {
      inputs[i] = torch::randint(2, {2}, torch::kInt64);
      labels[i] = inputs[i][0].item<int64_t>() ^ inputs[i][1].item<int64_t>();
    }
    inputs.set_requires_grad(true);
    optimizer.zero_grad();
    auto x = model->forward(inputs);
    // Binary cross-entropy (assumed here) matches the sigmoid output head.
    auto loss = torch::binary_cross_entropy(x.reshape({kBatchSize}), labels);
    loss.backward();

    optimizer.step();

    running_loss = running_loss * 0.99 + loss.item<float>() * 0.01;
    if (epoch > kMaximumNumberOfEpochs) {
      std::cout << "Loss is too high after epoch " << epoch << ": "
                << running_loss << std::endl;
      return false;
    }
    epoch++;
  }
  return true;
}

// Copies `new_tensor` into the named parameter without recording the copy in
// autograd, so a test can start optimization from a fixed set of weights.
template <typename Parameters>
void assign_parameter(
    const Parameters& parameters,
    const std::string& name,
    torch::Tensor new_tensor) {
  auto parameter = parameters[name];
  parameter.set_requires_grad(false);
  parameter.flatten().copy_(new_tensor);
  parameter.set_requires_grad(true);
}

template <typename OptimizerClass, typename Options>
void check_exact_values(
    Options options,
    std::vector<std::vector<torch::Tensor>> expected_parameters) {
  const size_t kIterations = 1001;
  const size_t kSampleEvery = 100;

  torch::manual_seed(0);

  // A 2-3-1 MLP with sigmoid activations; the layer sizes follow from the
  // shapes of the parameters assigned below.
  Sequential model(
      Linear(2, 3),
      Functional(torch::sigmoid),
      Linear(3, 1),
      Functional(torch::sigmoid));

  model->to(torch::kFloat64);

  // Start from the exact weights the reference implementation used.
  auto parameters = model->named_parameters();
  assign_parameter(
      parameters,
      "0.weight",
      torch::tensor({-0.2109, -0.4976, -0.1413, -0.3420, -0.2524, 0.6976}));
  assign_parameter(
      parameters, "0.bias", torch::tensor({-0.1085, -0.2979, 0.6892}));
  assign_parameter(
      parameters, "2.weight", torch::tensor({-0.0508, -0.3941, -0.2843}));
  assign_parameter(parameters, "2.bias", torch::tensor({-0.0711}));

  auto optimizer = OptimizerClass(parameters.values(), options);
  auto input =
      torch::tensor({0.1, 0.2, 0.3, 0.4, 0.5, 0.6}, torch::kFloat64)
          .reshape({3, 2});

  for (size_t i = 0; i < kIterations; ++i) {
    optimizer.zero_grad();
    auto output = model->forward(input);
    auto loss = output.sum();
    loss.backward();

    optimizer.step();

    if (i % kSampleEvery == 0) {
      for (size_t p = 0; p < parameters.size(); ++p) {
        ASSERT_TRUE(parameters[p]->defined());
        auto computed = parameters[p]->flatten();
        // Compare against the sampled baseline, in the model's dtype.
        auto expected =
            expected_parameters.at(i / kSampleEvery).at(p).to(torch::kFloat64);
        if (!computed.allclose(expected, /*rtol=*/1e-3, /*atol=*/5e-4)) {
          std::cout << "Iteration " << i << ": " << computed
                    << " != " << expected << " (parameter " << p << ")"
                    << std::endl;
          ASSERT_TRUE(false);
        }
      }
    }
  }
}

TEST(OptimTest, BasicInterface) {
  struct MyOptimizer : Optimizer {
    using Optimizer::Optimizer;
    void step() override {}
  };
  std::vector<torch::Tensor> parameters = {
      torch::ones({2, 3}), torch::zeros({2, 3}), torch::rand({2, 3})};
  {
    MyOptimizer optimizer(parameters);
    ASSERT_EQ(optimizer.size(), parameters.size());
  }
  {
    MyOptimizer optimizer;
    ASSERT_EQ(optimizer.size(), 0);
    optimizer.add_parameters(parameters);
    ASSERT_EQ(optimizer.size(), parameters.size());
    for (size_t p = 0; p < parameters.size(); ++p) {
      ASSERT_TRUE(optimizer.parameters()[p].allclose(parameters[p]));
    }
  }
  {
    // The layer shape is arbitrary; only the parameter count matters here.
    Linear linear(3, 4);
    MyOptimizer optimizer(linear->parameters());
    ASSERT_EQ(optimizer.size(), linear->parameters().size());
  }
}

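// Convergence tests: each optimizer must drive the running XOR loss below 0.1
// within kMaximumNumberOfEpochs. NOTE: in several tests below the options
// argument did not survive in this copy of the file; where it is missing, the
// values used are inferred from the test name and from neighboring tests and
// should be read as assumptions, not as the verbatim original settings.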
TEST(OptimTest, XORConvergence_SGD) {
  ASSERT_TRUE(test_optimizer_xor<SGD>(
      SGDOptions(0.1).momentum(0.9).nesterov(true).weight_decay(1e-6)));
}

TEST(OptimTest, XORConvergence_Adagrad) {
  ASSERT_TRUE(test_optimizer_xor<Adagrad>(
      AdagradOptions(1.0).weight_decay(1e-6).lr_decay(1e-3)));
}

TEST(OptimTest, XORConvergence_RMSprop) {
  ASSERT_TRUE(test_optimizer_xor<RMSprop>(RMSpropOptions(0.1).centered(true)));
}

TEST(OptimTest, XORConvergence_RMSpropWithMomentum) {
  ASSERT_TRUE(test_optimizer_xor<RMSprop>(
      RMSpropOptions(0.1).momentum(0.9).weight_decay(1e-6)));
}

TEST(OptimTest, XORConvergence_Adam) {
  ASSERT_TRUE(test_optimizer_xor<Adam>(AdamOptions(0.1).weight_decay(1e-6)));
}

TEST(OptimTest, XORConvergence_AdamWithAmsgrad) {
  ASSERT_TRUE(test_optimizer_xor<Adam>(
      AdamOptions(0.1).weight_decay(1e-6).amsgrad(true)));
}

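// Exact-value tests: the expected_parameters::* functions come from
// test/cpp/api/optim_baseline.h and, as the ProducesPyTorchValues names
// suggest, hold parameter snapshots (one per sampled iteration) recorded from
// the reference Python implementation of each optimizer configuration.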
TEST(OptimTest, ProducesPyTorchValues_Adam) {
  check_exact_values<Adam>(AdamOptions(1.0), expected_parameters::Adam());
}

TEST(OptimTest, ProducesPyTorchValues_AdamWithWeightDecay) {
  check_exact_values<Adam>(
      AdamOptions(1.0).weight_decay(1e-2),
      expected_parameters::Adam_with_weight_decay());
}

TEST(OptimTest, ProducesPyTorchValues_AdamWithWeightDecayAndAMSGrad) {
  check_exact_values<Adam>(
      AdamOptions(1.0).weight_decay(1e-6).amsgrad(true),
      expected_parameters::Adam_with_weight_decay_and_amsgrad());
}

TEST(OptimTest, ProducesPyTorchValues_Adagrad) {
  check_exact_values<Adagrad>(
      AdagradOptions(1.0), expected_parameters::Adagrad());
}

TEST(OptimTest, ProducesPyTorchValues_AdagradWithWeightDecay) {
  check_exact_values<Adagrad>(
      AdagradOptions(1.0).weight_decay(1e-2),
      expected_parameters::Adagrad_with_weight_decay());
}

TEST(OptimTest, ProducesPyTorchValues_AdagradWithWeightDecayAndLRDecay) {
  check_exact_values<Adagrad>(
      AdagradOptions(1.0).weight_decay(1e-6).lr_decay(1e-3),
      expected_parameters::Adagrad_with_weight_decay_and_lr_decay());
}

TEST(OptimTest, ProducesPyTorchValues_RMSprop) {
  check_exact_values<RMSprop>(
      RMSpropOptions(0.1), expected_parameters::RMSprop());
}

TEST(OptimTest, ProducesPyTorchValues_RMSpropWithWeightDecay) {
  check_exact_values<RMSprop>(
      RMSpropOptions(0.1).weight_decay(1e-2),
      expected_parameters::RMSprop_with_weight_decay());
}

TEST(OptimTest, ProducesPyTorchValues_RMSpropWithWeightDecayAndCentered) {
  check_exact_values<RMSprop>(
      RMSpropOptions(0.1).weight_decay(1e-6).centered(true),
      expected_parameters::RMSprop_with_weight_decay_and_centered());
}

TEST(
    OptimTest,
    ProducesPyTorchValues_RMSpropWithWeightDecayAndCenteredAndMomentum) {
  check_exact_values<RMSprop>(
      RMSpropOptions(0.1).weight_decay(1e-6).centered(true).momentum(0.9),
      expected_parameters::
          RMSprop_with_weight_decay_and_centered_and_momentum());
}

TEST(OptimTest, ProducesPyTorchValues_SGD) {
  check_exact_values<SGD>(SGDOptions(0.1), expected_parameters::SGD());
}

TEST(OptimTest, ProducesPyTorchValues_SGDWithWeightDecay) {
  check_exact_values<SGD>(
      SGDOptions(0.1).weight_decay(1e-2),
      expected_parameters::SGD_with_weight_decay());
}

TEST(OptimTest, ProducesPyTorchValues_SGDWithWeightDecayAndMomentum) {
  check_exact_values<SGD>(
      SGDOptions(0.1).weight_decay(1e-2).momentum(0.9),
      expected_parameters::SGD_with_weight_decay_and_momentum());
}

TEST(OptimTest, ProducesPyTorchValues_SGDWithWeightDecayAndNesterovMomentum) {
  check_exact_values<SGD>(
      SGDOptions(0.1).weight_decay(1e-6).momentum(0.9).nesterov(true),
      expected_parameters::SGD_with_weight_decay_and_nesterov_momentum());
}

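// The remaining tests exercise optimizer mechanics (zero_grad(), optimizing a
// plain std::vector of tensors, and adding parameters after construction)
// rather than training behavior.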
TEST(OptimTest, ZeroGrad) {
  torch::manual_seed(0);

  // The output width is an assumption; only in_features = 2 is fixed by the
  // input below.
  Linear model(2, 8);
  SGD optimizer(model->parameters(), 0.1);

  for (const auto& parameter : model->parameters()) {
    ASSERT_FALSE(parameter.grad().defined());
  }

  auto output = model->forward(torch::ones({5, 2}));
  auto loss = output.sum();
  loss.backward();

  for (const auto& parameter : model->parameters()) {
    ASSERT_TRUE(parameter.grad().defined());
    ASSERT_GT(parameter.grad().sum().item<float>(), 0);
  }

  optimizer.zero_grad();

  for (const auto& parameter : model->parameters()) {
    ASSERT_TRUE(parameter.grad().defined());
    ASSERT_EQ(parameter.grad().sum().item<float>(), 0);
  }
}

TEST(OptimTest, ExternalVectorOfParameters) {
  torch::manual_seed(0);

  std::vector<torch::Tensor> parameters = {
      torch::randn({2, 2}), torch::randn({3, 3}), torch::randn({4, 4})};
  std::vector<torch::Tensor> original_parameters = {
      parameters[0].clone(), parameters[1].clone(), parameters[2].clone()};

  // Set all gradients to one, so a single SGD step with lr = 1.0 subtracts
  // exactly one from every element.
  for (auto& parameter : parameters) {
    parameter.grad() = torch::ones_like(parameter);
  }

  SGD optimizer(parameters, 1.0);

  optimizer.step();

  ASSERT_TRUE(parameters[0].allclose(original_parameters[0] - 1.0));
  ASSERT_TRUE(parameters[1].allclose(original_parameters[1] - 1.0));
  ASSERT_TRUE(parameters[2].allclose(original_parameters[2] - 1.0));
}

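// LBFGS::step() takes a closure that re-evaluates the loss; the test below
// only checks that parameters can be added after construction and that a step
// with a dummy closure does not throw.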
TEST(OptimTest, AddParameter_LBFGS) {
  torch::manual_seed(0);

  std::vector<torch::Tensor> parameters = {torch::randn({5, 5})};
  std::vector<torch::Tensor> original_parameters = {parameters[0].clone()};

  for (auto& parameter : parameters) {
    parameter.grad() = torch::ones_like(parameter);
  }

  LBFGS optimizer(std::vector<torch::Tensor>{}, 1.0);
  optimizer.add_parameters(parameters);

  optimizer.step([]() { return torch::tensor(1); });

  // Just verify that this does not throw.
}