Caffe2 - C++ API
A deep learning, cross platform ML framework
module.cpp
1 #include <gtest/gtest.h>
2 
3 #include <torch/nn/module.h>
4 #include <torch/nn/modules/linear.h>
5 #include <torch/nn/modules/rnn.h>
6 #include <torch/nn/modules/sequential.h>
7 #include <torch/types.h>
8 #include <torch/utils.h>
9 
10 #include <test/cpp/api/support.h>
11 
12 using namespace torch::nn;
13 using namespace torch::test;
14 
16 
17 namespace test {
20  AGIUnit2() : torch::nn::Module("Foo") {}
21 };
22 } // namespace test
23 
25 
26 TEST_F(ModuleTest, CanEnableAndDisableTrainingMode) {
27  Linear module(3, 4);
28  ASSERT_TRUE(module->is_training());
29 
30  module->eval();
31  ASSERT_FALSE(module->is_training());
32 
33  module->train();
34  ASSERT_TRUE(module->is_training());
35 }
36 
// zero_grad() must reset every parameter's gradient to zero (but keep the
// grad tensors defined) after a backward pass has populated them.
37 TEST_F(ModuleTest, ZeroGrad) {
38  Linear module(3, 4);
39  auto weight = torch::ones({8, 3}, torch::requires_grad());
40  auto loss = module(weight).sum();
41  loss.backward();
// After backward, every parameter should have a non-zero gradient.
42  for (auto& parameter : module->parameters()) {
43  auto grad = parameter.grad();
44  ASSERT_TRUE(grad.defined());
45  ASSERT_NE(grad.sum().item<float>(), 0);
46  }
47  module->zero_grad();
// After zero_grad, gradients remain defined but are all zeros.
48  for (auto& parameter : module->parameters()) {
49  auto grad = parameter.grad();
50  ASSERT_TRUE(grad.defined());
51  ASSERT_EQ(grad.sum().item<float>(), 0);
52  }
53 }
54 
// zero_grad() must zero gradients that exist, while leaving parameters whose
// gradient was never populated (undefined grad) untouched rather than crashing.
55 TEST_F(ModuleTest, ZeroGradWithUndefined) {
56  struct TestModule : torch::nn::Module {
57  TestModule() {
58  x = register_parameter("x", torch::ones(5, torch::requires_grad()));
59  y = register_parameter("y", torch::ones(5, torch::requires_grad()));
60  }
61  torch::Tensor x, y;
62  };
63 
64  TestModule module;
// Only `x` participates in the graph, so only x.grad() becomes defined.
65  auto z = module.x * 2;
66  z.sum().backward();
67 
68  ASSERT_TRUE(module.x.grad().defined());
69  ASSERT_FALSE(module.y.grad().defined());
70 
71  module.zero_grad();
72 
// zero_grad() keeps the defined/undefined status of each grad intact.
73  ASSERT_TRUE(module.x.grad().defined());
74  ASSERT_FALSE(module.y.grad().defined());
75 
76  ASSERT_EQ(module.x.grad().sum().item<float>(), 0);
77 }
78 
// register_module() must reject names that are empty or contain dots (dots
// are reserved as the hierarchy separator in fully qualified module names).
79 TEST_F(ModuleTest, RegisterModuleThrowsForEmptyOrDottedName) {
80  struct TestModel : public torch::nn::Module {
82  };
83  ASSERT_THROWS_WITH(
84  TestModel{}.register_module("name.with.dot", torch::nn::Linear(3, 4)),
85  "Submodule name must not contain a dot (got 'name.with.dot')");
86  ASSERT_THROWS_WITH(
87  TestModel{}.register_module("", torch::nn::Linear(3, 4)),
88  "Submodule name must not be empty");
89 }
90 
// Registering two submodules under the same name must throw.
91 TEST_F(ModuleTest, RegisterModuleThrowsForDuplicateModuleName) {
92  struct TestModel : public torch::nn::Module {
94  };
95  TestModel model;
96  model.register_module("linear", torch::nn::Linear(3, 4));
97  ASSERT_THROWS_WITH(
98  model.register_module("linear", torch::nn::Linear(3, 4)),
99  "Submodule 'linear' already defined");
100 }
101 
// register_parameter() mirrors register_module(): empty or dotted parameter
// names are rejected with a descriptive error.
102 TEST_F(ModuleTest, RegisterParameterThrowsForEmptyOrDottedName) {
103  struct TestModel : public torch::nn::Module {
105  };
106  ASSERT_THROWS_WITH(
107  TestModel{}.register_parameter("name.with.dot", torch::ones(5)),
108  "Parameter name must not contain a dot (got 'name.with.dot')");
109  ASSERT_THROWS_WITH(
110  TestModel{}.register_parameter("", torch::ones(5)),
111  "Parameter name must not be empty");
112 }
113 
// Registering two parameters under the same name must throw.
114 TEST_F(ModuleTest, RegisterParameterThrowsForDuplicateModuleName) {
115  struct TestModel : public torch::nn::Module {
117  };
118  TestModel model;
119  model.register_parameter("p", torch::ones(5));
120  ASSERT_THROWS_WITH(
121  model.register_parameter("p", torch::ones(5)),
122  "Parameter 'p' already defined");
123 }
124 
// register_buffer() enforces the same naming rules as parameters/submodules:
// no empty names, no dots.
125 TEST_F(ModuleTest, RegisterBufferThrowsForEmptyOrDottedName) {
126  struct TestModel : public torch::nn::Module {
128  };
129  ASSERT_THROWS_WITH(
130  TestModel{}.register_buffer("name.with.dot", torch::ones(5)),
131  "Buffer name must not contain a dot (got 'name.with.dot')");
132  ASSERT_THROWS_WITH(
133  TestModel{}.register_buffer("", torch::ones(5)),
134  "Buffer name must not be empty");
135 }
136 
// Registering two buffers under the same name must throw.
137 TEST_F(ModuleTest, RegisterBufferThrowsForDuplicateModuleName) {
138  struct TestModel : public torch::nn::Module {
140  };
141  TestModel model;
142  model.register_buffer("p", torch::ones(5));
143  ASSERT_THROWS_WITH(
144  model.register_buffer("p", torch::ones(5)), "Buffer 'p' already defined");
145 }
146 
// Module::name() should lazily derive a (demangled) class name by default and
// honor an explicit name passed to the Module constructor (AGIUnit2 -> "Foo").
147 TEST_F(ModuleTest, CanGetName) {
148  // CHECK instead of REQUIRE because demangling may fail.
149  AGIUnit agi;
150  // Call it twice just to make sure there are no bugs in the lazy
151  // initialization semantics.
152  EXPECT_EQ(agi.name(), "AGIUnit");
153  EXPECT_EQ(agi.name(), "AGIUnit");
// The namespaced type keeps its qualification in the derived name.
154  EXPECT_EQ(test::AGIUnit().name(), "test::AGIUnit");
155  EXPECT_EQ(test::AGIUnit2().name(), "Foo");
156 }
157 
// as<T>() should behave like a checked downcast: it returns the underlying
// impl pointer for any type in the module's hierarchy (holder type, impl
// type, or base Module) and nullptr for unrelated types — whether invoked
// through the ModuleHolder, a shared_ptr<Module>, or a Module reference.
158 TEST_F(ModuleTest, AsCastsModulesCorrectly) {
159  Linear module(3, 4);
160  ASSERT_EQ(module->as<Linear>(), module.get());
161  ASSERT_EQ(module->as<LinearImpl>(), module.get());
162  ASSERT_EQ(module->as<Module>(), module.get());
163  ASSERT_EQ(module->as<AGIUnit>(), nullptr);
164 
// Same behavior through a type-erased shared_ptr<Module>.
165  std::shared_ptr<Module> raw = module.ptr();
166  ASSERT_EQ(raw->as<Linear>(), module.get());
167  ASSERT_EQ(raw->as<LinearImpl>(), module.get());
168  ASSERT_EQ(raw->as<Module>(), module.get());
169  ASSERT_EQ(raw->as<AGIUnit>(), nullptr);
170 
// Same behavior through a plain Module reference.
171  Module& raw_ref = *raw.get();
172  ASSERT_EQ(raw_ref.as<Linear>(), module.get());
173  ASSERT_EQ(raw_ref.as<LinearImpl>(), module.get());
174  ASSERT_EQ(raw_ref.as<Module>(), module.get());
175  ASSERT_EQ(raw_ref.as<AGIUnit>(), nullptr);
176  if (auto* linear = raw_ref.as<Linear>()) {
177  ASSERT_EQ(linear->weight.ndimension(), 2);
178  }
179 
// A module not in the Linear hierarchy only casts to itself.
180  AGIUnit unit;
181  ASSERT_EQ(unit.as<Linear>(), nullptr);
182  ASSERT_EQ(unit.as<LinearImpl>(), nullptr);
183  ASSERT_EQ(unit.as<AGIUnit>(), &unit);
184 }
185 
// Module::to() must recursively move/convert all parameters: across CUDA
// devices, back to CPU, between dtypes, and device+dtype in a single call.
// Requires at least two CUDA devices (MultiCUDA suffix).
186 TEST_F(ModuleTest, Conversion_MultiCUDA) {
187  Linear module(128, 64);
// Parameters start on CPU in the default float32 dtype.
188  for (auto& parameter : module->parameters()) {
189  ASSERT_EQ(parameter.device(), torch::Device(torch::kCPU));
190  ASSERT_EQ(parameter.dtype(), torch::kFloat32);
191  }
192  {
// Device-only conversion, including device-to-device moves.
193  module->to({torch::kCUDA, 0});
194  for (auto& parameter : module->parameters()) {
195  ASSERT_EQ(parameter.device().type(), torch::Device::Type::CUDA);
196  ASSERT_EQ(parameter.device().index(), 0);
197  }
198  module->to({torch::kCUDA, 1});
199  for (auto& parameter : module->parameters()) {
200  ASSERT_EQ(parameter.device().type(), torch::Device::Type::CUDA);
201  ASSERT_EQ(parameter.device().index(), 1);
202  }
203  }
204  {
205  module->to(torch::Device(torch::kCPU));
206  for (auto& parameter : module->parameters()) {
207  ASSERT_EQ(parameter.device().type(), torch::Device::Type::CPU);
208  }
209  }
210  {
// Dtype-only conversion.
211  module->to(torch::kInt32);
212  for (auto& parameter : module->parameters()) {
213  ASSERT_EQ(parameter.dtype(), torch::kInt32);
214  }
215  }
216  {
217  module->to(torch::kFloat64);
218  for (auto& parameter : module->parameters()) {
219  ASSERT_EQ(parameter.dtype(), torch::kFloat64);
220  }
221  }
222  {
// Combined device + dtype conversion in one call.
223  module->to(torch::Device(torch::kCUDA, 1), torch::kUInt8);
224  for (auto& parameter : module->parameters()) {
225  ASSERT_EQ(parameter.device().type(), torch::Device::Type::CUDA);
226  ASSERT_EQ(parameter.device().index(), 1);
227  }
228  for (auto& parameter : module->parameters()) {
229  ASSERT_EQ(parameter.dtype(), torch::kUInt8);
230  }
231  }
232 }
233 
// The base Module::clone() has no knowledge of the concrete type and must
// throw a clear error when a subclass does not override it.
234 TEST_F(ModuleTest, CallingCloneOnModuleThatDoesNotOverrideCloneThrows) {
235  struct UnCloneable : Module {};
236  UnCloneable module;
237  ASSERT_THROWS_WITH(module.clone(), "clone() has not been implemented");
238 }
239 
// Conversely, a module that does override clone() must be callable without
// throwing (the return value itself is irrelevant to this test).
240 TEST_F(ModuleTest, CallingCloneOnModuleThatDoesOverrideCloneDoesNotThrow) {
241  struct Cloneable : Module {
242  std::shared_ptr<Module> clone(
243  const torch::optional<torch::Device>& device =
244  torch::nullopt) const override {
245  return nullptr;
246  }
247  };
248  Cloneable module;
249  ASSERT_NO_THROW({ module.clone(); });
250 }
251 
// clone() must deep-copy parameters and buffers: the clone's tensors start
// equal in value but share no storage, so mutating the original afterwards
// must not affect the clone.
252 TEST_F(ModuleTest, CloneCreatesDistinctParameters) {
253  struct TestModule : public Cloneable<TestModule> {
254  TestModule() {
255  reset();
256  }
257  void reset() override {
258  l1 = register_module("l1", Linear(10, 3));
259  l2 = register_module("l2", Linear(3, 5));
260  l3 = register_module("l3", Linear(5, 100));
261  buffer = register_buffer("buf", torch::ones({2, 2}));
262  }
263 
264  Linear l1{nullptr}, l2{nullptr}, l3{nullptr};
265  torch::Tensor buffer;
266  };
267 
268  auto module = std::make_shared<TestModule>();
269 
// In-place mutation below must not be recorded by autograd.
270  torch::NoGradGuard no_grad;
271 
272  auto module2 = module->clone();
273  auto params1 = module->named_parameters();
274  auto params2 = module2->named_parameters();
// Three Linear submodules with weight + bias each = 6 parameters.
275  ASSERT_EQ(params1.size(), 6);
276  ASSERT_EQ(params2.size(), 6);
// Equal values, distinct storage; then diverge the original in place.
277  for (auto& param : params1) {
278  ASSERT_FALSE(pointer_equal(param.value(), params2[param.key()]));
279  ASSERT_TRUE(param->allclose(params2[param.key()]));
280  param->add_(2);
281  }
282  for (auto& param : params1) {
283  ASSERT_FALSE(param->allclose(params2[param.key()]));
284  }
285 
// Same deep-copy guarantee for buffers.
286  auto buffers1 = module->named_buffers();
287  auto buffers2 = module2->named_buffers();
288  ASSERT_EQ(buffers1.size(), 1);
289  ASSERT_EQ(buffers2.size(), 1);
290  for (auto& buffer : buffers1) {
291  ASSERT_FALSE(pointer_equal(buffer.value(), buffers2[buffer.key()]));
292  ASSERT_TRUE(buffer->allclose(buffers2[buffer.key()]));
293  buffer->add_(2);
294  }
295  for (auto& buffer : buffers1) {
296  ASSERT_FALSE(buffer->allclose(buffers2[buffer.key()]));
297  }
298 }
299 
// After clone(), member tensor handles inside the clone (e.g. `weight`) must
// still alias the clone's own registered parameters — not the original's —
// while carrying equal values.
300 TEST_F(ModuleTest, ClonePreservesExternalReferences) {
301  struct TestModule : public Cloneable<TestModule> {
302  TestModule() {
303  reset();
304  }
305  void reset() override {
306  weight = register_parameter("weight", torch::ones({4, 4}));
307  }
308  torch::Tensor weight;
309  };
310  auto module = std::make_shared<TestModule>();
311  {
312  torch::NoGradGuard no_grad;
313  module->weight += 1;
314  }
// Sanity: the member and the registered parameter are the same tensor.
315  ASSERT_TRUE(
316  pointer_equal(module->weight, module->named_parameters()["weight"]));
317  ASSERT_TRUE(module->weight.allclose(module->named_parameters()["weight"]));
318 
319  auto module2 = std::dynamic_pointer_cast<TestModule>(
320  std::shared_ptr<Module>(module->clone()));
// The clone's member aliases its own parameter, not the original's.
321  ASSERT_FALSE(pointer_equal(module2->weight, module->weight));
322  ASSERT_TRUE(
323  pointer_equal(module2->weight, module2->named_parameters()["weight"]));
324  ASSERT_TRUE(module2->weight.allclose(module2->named_parameters()["weight"]));
325  ASSERT_TRUE(module2->weight.allclose(module->weight));
326  ASSERT_FALSE(
327  pointer_equal(module2->weight, module->named_parameters()["weight"]));
328 }
329 
// clone() must recurse into submodules: nested parameter values (and plain
// non-tensor members like `value`) are copied, while storage stays distinct.
330 TEST_F(ModuleTest, CloneCopiesTheValuesOfVariablesOfSubmodules) {
331  struct TestModule : public Cloneable<TestModule> {
332  TestModule() {
333  reset();
334  }
335  void reset() override {
336  weight = register_parameter("weight", torch::ones({4, 4}));
337  }
338 
339  torch::Tensor weight;
// Plain data member — copied via the copy constructor during clone().
340  int value = 0;
341  };
342  struct NestedModule : public Cloneable<NestedModule> {
343  NestedModule() {
344  reset();
345  }
346  void reset() override {
347  module = register_module("module", std::make_shared<TestModule>());
348  }
349  std::shared_ptr<TestModule> module;
350  };
351 
352  auto a = std::make_shared<NestedModule>();
353  {
354  torch::NoGradGuard no_grad;
355  a->module->weight += 1;
356  a->module->value = 123;
357  }
358 
359  auto b = std::dynamic_pointer_cast<NestedModule>(a->clone());
360 
// Distinct storage, equal values, and the submodule's member still aliases
// the submodule's own registered parameter.
361  ASSERT_FALSE(pointer_equal(b->module->weight, a->module->weight));
362  ASSERT_TRUE(pointer_equal(
363  b->module->weight, b->module->named_parameters()["weight"]));
364  ASSERT_TRUE(
365  b->module->named_parameters()["weight"].allclose(a->module->weight));
366  ASSERT_TRUE(b->module->weight.allclose(a->module->weight));
367  ASSERT_EQ(b->module->value, a->module->value);
368 }
369 
// clone() with no device argument must keep parameters and buffers on the
// device the original module already lives on (here: CUDA:0).
370 TEST_F(ModuleTest, CloneToDevicePreservesTheDeviceOfParameters_CUDA) {
371  struct TestModule : public Cloneable<TestModule> {
372  TestModule() {
373  reset();
374  }
375  void reset() override {
376  l1 = register_module("l1", Linear(10, 3));
377  l2 = register_module("l2", Linear(3, 5));
378  l3 = register_module("l3", Linear(5, 100));
379  buffer = register_buffer("buf", torch::ones({2, 2}));
380  }
381 
382  Linear l1{nullptr}, l2{nullptr}, l3{nullptr};
383  torch::Tensor buffer;
384  };
385 
386  TestModule m;
387  torch::Device device(torch::kCUDA, 0);
388 
// Move the original to CUDA first; the clone should follow suit.
389  m.to(device);
390 
391  auto clone = m.clone();
392  for (const auto& parameter : clone->parameters()) {
393  ASSERT_EQ(parameter.device().type(), device.type());
394  ASSERT_EQ(parameter.device().index(), device.index());
395  }
396  for (const auto& buffer : clone->buffers()) {
397  ASSERT_EQ(buffer.device().type(), device.type());
398  ASSERT_EQ(buffer.device().index(), device.index());
399  }
400 }
401 
// clone(device) must place every parameter and buffer of the clone on the
// requested device, even though the original stays on CPU.
402 TEST_F(
403  ModuleTest,
404  CloningToAParticularDevicePlacesAllParametersThere_MultiCUDA) {
405  struct TestModule : public Cloneable<TestModule> {
406  TestModule() {
407  reset();
408  }
409  void reset() override {
410  l1 = register_module("l1", Linear(10, 3));
411  l2 = register_module("l2", Linear(3, 5));
412  l3 = register_module("l3", Linear(5, 100));
413  buffer = register_buffer("buf", torch::ones({2, 2}));
414  }
415 
416  Linear l1{nullptr}, l2{nullptr}, l3{nullptr};
417  torch::Tensor buffer;
418  };
419 
420  TestModule m;
421  torch::Device device(torch::kCUDA, 1);
422  // everything is on CPU here
423  auto clone = m.clone(device);
424  for (const auto& parameter : clone->parameters()) {
425  ASSERT_EQ(parameter.device().type(), device.type());
426  ASSERT_EQ(parameter.device().index(), device.index());
427  }
428  for (const auto& buffer : clone->buffers()) {
429  ASSERT_EQ(buffer.device().type(), device.type());
430  ASSERT_EQ(buffer.device().index(), device.index());
431  }
432 }
433 
436  a = register_parameter("a", torch::zeros({2, 2}));
437  b = register_parameter("b", torch::ones({2, 2}));
438  c = register_parameter("c", torch::ones({2, 2}) * 2);
439  }
440 
441  torch::Tensor a, b, c;
442 };
443 
// parameters() and named_parameters() both report the three parameters
// registered by ParameterTestModule (a, b, c).
444 TEST_F(ModuleTest, HasCorrectNumberOfParameters) {
445  ParameterTestModule module;
446  ASSERT_EQ(module.parameters().size(), 3);
447  ASSERT_EQ(module.named_parameters().size(), 3);
448 }
449 
// named_parameters() keys must match the names passed to register_parameter().
450 TEST_F(ModuleTest, ContainsParametersWithTheCorrectName) {
451  ParameterTestModule module;
452  auto parameters = module.named_parameters();
453  ASSERT_TRUE(parameters.contains("a"));
454  ASSERT_TRUE(parameters.contains("b"));
455  ASSERT_TRUE(parameters.contains("c"));
456 }
457 
459  BufferTestModule() {
460  a = register_buffer("a", torch::zeros({2, 2}));
461  b = register_buffer("b", torch::ones({2, 2}));
462  c = register_buffer("c", torch::ones({2, 2}) * 2);
463  }
464 
465  torch::Tensor a, b, c;
466 };
467 
// buffers() and named_buffers() both report the three buffers registered by
// BufferTestModule (a, b, c).
468 TEST_F(ModuleTest, HasCorrectNumberOfBuffers) {
469  BufferTestModule module;
470  ASSERT_EQ(module.buffers().size(), 3);
471  ASSERT_EQ(module.named_buffers().size(), 3);
472 }
473 
// named_buffers() keys must match the names passed to register_buffer().
474 TEST_F(ModuleTest, ContainsBuffersWithTheCorrectName) {
475  BufferTestModule module;
476  auto buffers = module.named_buffers();
477  ASSERT_TRUE(buffers.contains("a"));
478  ASSERT_TRUE(buffers.contains("b"));
479  ASSERT_TRUE(buffers.contains("c"));
480 }
481 
483  AImpl() : x_(123) {}
484  AImpl(int x) : x_(x) {}
485  int x_;
486 };
487 TORCH_MODULE(A);
488 
// Default-constructing a TORCH_MODULE holder must default-construct the Impl
// (AImpl's default ctor sets x_ to 123), leaving the holder non-empty.
489 TEST_F(
490  ModuleTest,
491  DefaultConstructorOfModuleHolderCallsDefaultConstructorOfImpl) {
492  A a;
493  ASSERT_TRUE(a);
494  ASSERT_FALSE(a.is_empty());
495  ASSERT_EQ(a->x_, 123);
496 }
497 
// Constructor arguments given to the holder must be forwarded to the
// matching Impl constructor (AImpl(int)).
498 TEST_F(
499  ModuleTest,
500  ValueConstructorOfModuleHolderCallsCorrectConstructorInImpl) {
501  A a(5);
502  ASSERT_TRUE(a);
503  ASSERT_FALSE(a.is_empty());
504  ASSERT_EQ(a->x_, 5);
505 }
506 
// Constructing a holder from nullptr yields an empty holder; dereferencing it
// must throw rather than crash.
507 TEST_F(ModuleTest, NullptrConstructorLeavesTheModuleHolderInEmptyState) {
508  A a = nullptr;
509  ASSERT_FALSE(a);
510  ASSERT_TRUE(a.is_empty());
511  ASSERT_THROWS_WITH(a->x_, "Accessing empty ModuleHolder");
512 }
513 
// Fixture module used by the flat-model tests below: two parameters (p1, p2)
// and two buffers (b1, b2), all of the given size, with an identity forward.
514 struct TestModule : public torch::nn::Module {
515  TestModule(int64_t size) {
516  p1 = register_parameter("p1", torch::randn({size}));
517  p2 = register_parameter("p2", torch::randn({size}));
518  b1 = register_buffer("b1", torch::randn({size}));
519  b2 = register_buffer("b2", torch::randn({size}));
520  }
521 
// Identity forward — the tests only inspect registration, not computation.
522  torch::Tensor forward(torch::Tensor input) {
523  return input;
524  }
525 
526  torch::Tensor p1, p2, b1, b2;
527 };
528 
// modules() defaults to include_self=true: the root Sequential comes first,
// followed by its three children.
529 TEST_F(ModuleTest, ModulesReturnsExpectedSubmodulesForFlatModel) {
530  torch::nn::Sequential model(TestModule(1), TestModule(2), TestModule(3));
531  std::vector<std::shared_ptr<torch::nn::Module>> modules = model->modules();
532  std::vector<std::shared_ptr<torch::nn::Module>> expected = {
533  model.ptr(), model[0], model[1], model[2]};
534  ASSERT_EQ(modules.size(), expected.size());
535  for (size_t i = 0; i < expected.size(); ++i) {
536  // Assert pointer equality.
537  ASSERT_EQ(modules[i].get(), expected[i].get());
538  }
539 }
540 
541 TEST_F(ModuleTest, ModulesExcludesSelfWhenIncludeSelfSetToFalse) {
542  torch::nn::Sequential model(TestModule(1), TestModule(2), TestModule(3));
543  std::vector<std::shared_ptr<torch::nn::Module>> modules =
544  model->modules(/*include_self=*/false);
545  std::vector<std::shared_ptr<torch::nn::Module>> expected = {
546  model[0], model[1], model[2]};
547  ASSERT_EQ(modules.size(), expected.size());
548  for (size_t i = 0; i < expected.size(); ++i) {
549  // Assert pointer equality.
550  ASSERT_EQ(modules[i].get(), expected[i].get());
551  }
552 }
553 
554 TEST_F(ModuleTest, NamedModulesReturnsExpectedNamedSubmodulesForFlatModel) {
555  torch::nn::Sequential model(TestModule(1), TestModule(2), TestModule(3));
557  model->named_modules();
558  std::vector<std::shared_ptr<torch::nn::Module>> expected = {
559  model.ptr(), model[0], model[1], model[2]};
560  ASSERT_EQ(modules.size(), expected.size());
561  for (size_t i = 0; i < expected.size(); ++i) {
562  // Assert pointer equality.
563  ASSERT_EQ(modules[i].key(), i ? std::to_string(i - 1) : std::string());
564  ASSERT_EQ(modules[i].value().get(), expected[i].get());
565  }
566 }
567 
568 TEST_F(ModuleTest, NamedModulesExcludesSelfWhenIncludeSelfSetToFalse) {
569  torch::nn::Sequential model(TestModule(1), TestModule(2), TestModule(3));
571  model->named_modules(
572  /*name_prefix=*/std::string(), /*include_self=*/false);
573  std::vector<std::shared_ptr<torch::nn::Module>> expected = {
574  model[0], model[1], model[2]};
575  ASSERT_EQ(modules.size(), expected.size());
576  for (size_t i = 0; i < expected.size(); ++i) {
577  // Assert pointer equality.
578  ASSERT_EQ(modules[i].key(), std::to_string(i));
579  ASSERT_EQ(modules[i].value().get(), expected[i].get());
580  }
581 }
582 
// children() returns only the immediate submodules; for a flat model this is
// identical to modules(include_self=false).
583 TEST_F(ModuleTest, ChildrenReturnsExpectedSubmodulesForFlatModel) {
584  torch::nn::Sequential model(TestModule(1), TestModule(2), TestModule(3));
585  std::vector<std::shared_ptr<torch::nn::Module>> modules = model->children();
586  std::vector<std::shared_ptr<torch::nn::Module>> expected = {
587  model[0], model[1], model[2]};
588  ASSERT_EQ(modules.size(), expected.size());
589  for (size_t i = 0; i < expected.size(); ++i) {
590  // Assert pointer equality.
591  ASSERT_EQ(modules[i].get(), expected[i].get());
592  }
593 
594  // For this flat model, this should be true.
595  ASSERT_EQ(modules, model->modules(/*include_self=*/false));
596 }
597 
598 TEST_F(ModuleTest, NamedChildrenReturnsExpectedNamedSubmodulesForFlatModel) {
599  torch::nn::Sequential model(TestModule(1), TestModule(2), TestModule(3));
601  model->named_children();
602  std::vector<std::shared_ptr<torch::nn::Module>> expected = {
603  model[0], model[1], model[2]};
604  ASSERT_EQ(modules.size(), expected.size());
605  for (size_t i = 0; i < expected.size(); ++i) {
606  // Assert pointer equality.
607  ASSERT_EQ(modules[i].key(), std::to_string(i));
608  ASSERT_EQ(modules[i].value().get(), expected[i].get());
609  }
610 }
611 
// parameters() must return the exact registered tensors (compared here by
// underlying data pointer), not copies.
612 TEST_F(ModuleTest, ParametersReturnsExpectedTensorsForFlatModel) {
613  TestModule module(1);
614  std::vector<torch::Tensor> parameters = module.parameters();
615  ASSERT_EQ(parameters.size(), 2);
616  ASSERT_EQ(parameters[0].data<float>(), module.p1.data<float>());
617  ASSERT_EQ(parameters[1].data<float>(), module.p2.data<float>());
618 }
619 
620 TEST_F(ModuleTest, NamedParametersReturnsExpectedTensorsForFlatModel) {
621  TestModule module(1);
623  module.named_parameters();
624  ASSERT_EQ(parameters.size(), 2);
625  ASSERT_EQ(parameters[0].key(), "p1");
626  ASSERT_EQ(parameters[0]->data<float>(), module.p1.data<float>());
627  ASSERT_EQ(parameters[1].key(), "p2");
628  ASSERT_EQ(parameters[1]->data<float>(), module.p2.data<float>());
629 }
630 
// buffers() must likewise return the exact registered buffer tensors
// (compared by underlying data pointer).
631 TEST_F(ModuleTest, BuffersReturnsExpectedTensorsForFlatModel) {
632  TestModule module(1);
633  std::vector<torch::Tensor> buffers = module.buffers();
634  ASSERT_EQ(buffers.size(), 2);
635  ASSERT_EQ(buffers[0].data<float>(), module.b1.data<float>());
636  ASSERT_EQ(buffers[1].data<float>(), module.b2.data<float>());
637 }
638 
639 TEST_F(ModuleTest, NamedBuffersReturnsExpectedTensorsForFlatModel) {
640  TestModule module(1);
642  module.named_buffers();
643  ASSERT_EQ(buffers.size(), 2);
644  ASSERT_EQ(buffers[0].key(), "b1");
645  ASSERT_EQ(buffers[0]->data<float>(), module.b1.data<float>());
646  ASSERT_EQ(buffers[1].key(), "b2");
647  ASSERT_EQ(buffers[1]->data<float>(), module.b2.data<float>());
648 }
649 
651  TestContainer(int64_t number, std::vector<TestContainer> modules = {})
652  : tensor(torch::tensor(number)) {
653  for (size_t i = 0; i < modules.size(); ++i) {
654  register_module(
655  std::to_string(i),
656  std::make_shared<TestContainer>(std::move(modules[i])));
657  }
658  }
659  torch::Tensor tensor;
660 };
661 
662 int64_t get_test_container_item(std::shared_ptr<torch::nn::Module> module) {
663  return std::dynamic_pointer_cast<TestContainer>(module)
664  ->tensor.item<int64_t>();
665 }
666 
667 std::shared_ptr<TestContainer> make_deeply_nested_test_container() {
668  return std::make_shared<TestContainer>(TestContainer(
669  0,
671  TestContainer(4),
673  5,
674  {TestContainer(6),
675  TestContainer(7, {TestContainer(8), TestContainer(9)})})}));
676 }
677 
// Expected (qualified name, payload) pairs for walking the deeply nested test
// container depth-first with name_prefix="test_prefix".
std::vector<std::pair<std::string, int64_t>>
make_key_value_pairs_for_deeply_nested_container() {
  std::vector<std::pair<std::string, int64_t>> pairs;
  pairs.reserve(10);
  pairs.emplace_back("test_prefix", 0);
  pairs.emplace_back("test_prefix.0", 1);
  pairs.emplace_back("test_prefix.0.0", 2);
  pairs.emplace_back("test_prefix.0.1", 3);
  pairs.emplace_back("test_prefix.1", 4);
  pairs.emplace_back("test_prefix.2", 5);
  pairs.emplace_back("test_prefix.2.0", 6);
  pairs.emplace_back("test_prefix.2.1", 7);
  pairs.emplace_back("test_prefix.2.1.0", 8);
  pairs.emplace_back("test_prefix.2.1.1", 9);
  return pairs;
}
691 
// modules() on a nested model must return all 10 containers in depth-first
// order; each container's stored integer encodes its expected position.
692 TEST_F(ModuleTest, ModulesReturnsExpectedSubmodulesForDeepModel) {
693  auto model = make_deeply_nested_test_container();
694  std::vector<std::shared_ptr<torch::nn::Module>> modules = model->modules();
695 
696  ASSERT_EQ(modules.size(), 10);
697  for (size_t i = 0; i < modules.size(); ++i) {
698  ASSERT_EQ(get_test_container_item(modules[i]), i);
699  }
700 }
701 
702 TEST_F(ModuleTest, NamedModulesReturnsExpectedNamedSubmodulesForDeepModel) {
703  auto model = make_deeply_nested_test_container();
705  model->named_modules(/*name_prefix=*/"test_prefix");
706  auto expected = make_key_value_pairs_for_deeply_nested_container();
707 
708  ASSERT_EQ(modules.size(), expected.size());
709 
710  for (size_t i = 0; i < expected.size(); ++i) {
711  ASSERT_EQ(modules[i].key(), expected[i].first);
712  ASSERT_EQ(get_test_container_item(modules[i].value()), expected[i].second);
713  }
714 }
715 
// children() must return only the three direct children of the root
// (payloads 1, 4, 5), not the deeper descendants.
716 TEST_F(ModuleTest, ChildrensReturnsExpectedSubmodulesForDeepModel) {
717  auto model = make_deeply_nested_test_container();
718  std::vector<std::shared_ptr<torch::nn::Module>> modules = model->children();
719 
720  ASSERT_EQ(modules.size(), 3);
721  ASSERT_EQ(get_test_container_item(modules[0]), 1);
722  ASSERT_EQ(get_test_container_item(modules[1]), 4);
723  ASSERT_EQ(get_test_container_item(modules[2]), 5);
724 }
725 
726 TEST_F(ModuleTest, NamedChildrensReturnsExpectedNamedSubmodulesForDeepModel) {
727  auto model = make_deeply_nested_test_container();
729  model->named_children();
730 
731  ASSERT_EQ(modules.size(), 3);
732 
733  ASSERT_EQ(get_test_container_item(modules[0].value()), 1);
734  ASSERT_EQ(modules[0].key(), "0");
735 
736  ASSERT_EQ(get_test_container_item(modules[1].value()), 4);
737  ASSERT_EQ(modules[1].key(), "1");
738 
739  ASSERT_EQ(get_test_container_item(modules[2].value()), 5);
740  ASSERT_EQ(modules[2].key(), "2");
741 }
742 
// apply(Module&) must visit all 10 modules depth-first, matching the order
// encoded in each container's payload.
743 TEST_F(ModuleTest, ModuleApplyIteratesCorreclty) {
744  auto model = make_deeply_nested_test_container();
745  int64_t index = 0;
746  model->apply([&index](torch::nn::Module& module) {
747  ASSERT_EQ(module.as<TestContainer>()->tensor.item<int64_t>(), index++);
748  });
749  ASSERT_EQ(index, 10);
750 }
751 
// The const overload of apply() must iterate in the same depth-first order.
752 TEST_F(ModuleTest, ConstModuleApplyIteratesCorreclty) {
753  std::shared_ptr<const TestContainer> model =
754  make_deeply_nested_test_container();
755  int64_t index = 0;
756  model->apply([&index](const torch::nn::Module& module) {
757  ASSERT_EQ(module.as<TestContainer>()->tensor.item<int64_t>(), index++);
758  });
759  ASSERT_EQ(index, 10);
760 }
761 
762 TEST_F(ModuleTest, NamedModuleApplyIteratesCorreclty) {
763  auto model = make_deeply_nested_test_container();
764  auto expected = make_key_value_pairs_for_deeply_nested_container();
765  int64_t index = 0;
766  model->apply(
767  [&index, expected](const std::string& name, torch::nn::Module& module) {
768  ASSERT_EQ(name, expected[index].first);
769  ASSERT_EQ(
770  module.as<TestContainer>()->tensor.item<int64_t>(),
771  expected[index++].second);
772  },
773  /*name_prefix=*/"test_prefix");
774  ASSERT_EQ(index, 10);
775 }
776 
// Const variant of the named apply() test: same qualified names and
// depth-first order through a shared_ptr<const TestContainer>.
777 TEST_F(ModuleTest, ConstNamedModuleApplyIteratesCorreclty) {
778  std::shared_ptr<const TestContainer> model =
779  make_deeply_nested_test_container();
780  auto expected = make_key_value_pairs_for_deeply_nested_container();
781  int64_t index = 0;
782  model->apply(
783  [&index, &expected](
784  const std::string& name, const torch::nn::Module& module) {
785  ASSERT_EQ(name, expected[index].first);
786  ASSERT_EQ(
787  module.as<const TestContainer>()->tensor.item<int64_t>(),
788  expected[index++].second);
789  },
790  /*name_prefix=*/"test_prefix");
791  ASSERT_EQ(index, 10);
792 }
793 
// The shared_ptr overload of apply() must visit all 10 modules depth-first.
794 TEST_F(ModuleTest, ModulePointerApplyIteratesCorreclty) {
795  auto model = make_deeply_nested_test_container();
796  int64_t index = 0;
797  model->apply([&index](const std::shared_ptr<torch::nn::Module>& module) {
798  ASSERT_EQ(get_test_container_item(module), index++);
799  });
800  ASSERT_EQ(index, 10);
801 }
802 
// Named shared_ptr overload of apply(): fully qualified names plus the
// module pointer, in depth-first order.
803 TEST_F(ModuleTest, NamedModulePointerApplyIteratesCorreclty) {
804  auto model = make_deeply_nested_test_container();
805  auto expected = make_key_value_pairs_for_deeply_nested_container();
806  int64_t index = 0;
807  model->apply(
808  [&index, &expected](
809  const std::string& name,
810  const std::shared_ptr<torch::nn::Module>& module) {
811  ASSERT_EQ(name, expected[index].first);
812  ASSERT_EQ(get_test_container_item(module), expected[index++].second);
813  },
814  /*name_prefix=*/"test_prefix");
815  ASSERT_EQ(index, 10);
816 }
817 
// modules() with include_self=true needs shared_from_this() on the root, so
// calling it on a stack-allocated module must throw a helpful error; it works
// when self is excluded, or when the module is owned by a shared_ptr.
818 TEST_F(ModuleTest, ThrowsWhenAttemptingtoGetTopLevelModuleAsSharedPtr) {
819  {
// Stack instance + include_self (the default) -> throws.
820  TestModule module(1);
821  ASSERT_THROWS_WITH(
822  module.modules(),
823  "It looks like you attempted to retrieve "
824  "your top-level module as a shared_ptr")
825  }
826  {
// Excluding self avoids the shared_from_this() requirement.
827  TestModule module(1);
828  ASSERT_NO_THROW(module.modules(/*include_self=*/false));
829  }
830  {
// A heap (shared_ptr-owned) instance can always include itself.
831  auto module = std::make_shared<TestModule>(1);
832  ASSERT_NO_THROW(module->modules());
833  }
834 }
835 
837 
// Streaming a module must use its pretty_print() override when present and
// fall back to the module's name for modules that don't override it.
838 TEST_F(ModuleTest, PrettyPrint) {
839  struct TestModule : torch::nn::Module {
840  TestModule(int x, float y) : x_(x), y_(y) {}
841 
842  void pretty_print(std::ostream& stream) const override {
843  stream << "TestModule(x=" << x_ << ", y=" << y_ << ")";
844  }
845 
846  int x_;
847  float y_;
848  };
849 
850  using namespace torch::nn;
851 
// EmptyModule has no pretty_print() override, so only its name is printed.
852  ASSERT_EQ(c10::str(EmptyModule{}), "EmptyModule");
853  ASSERT_EQ(c10::str(TestModule(1, 3.14)), "TestModule(x=1, y=3.14)");
854 }
855 
857  int64_t forward(torch::Tensor x) {
858  return x.numel();
859  }
860 };
861 TORCH_MODULE(ModuleWithNonTensorForward);
862 
// The ModuleHolder call operator must forward to Impl::forward even when
// forward() returns a non-Tensor type (here: int64_t element count).
863 TEST_F(ModuleTest, CanCallForwardOnNonTensorForwardThroughPimpl) {
864  ModuleWithNonTensorForward m;
865  ASSERT_EQ(m(torch::ones(123)), 123);
866 }
Applies a linear transformation with optional bias.
Definition: linear.h:25
size_t size() const noexcept
Returns the number of items currently stored in the OrderedDict.
Definition: ordered_dict.h:419
std::vector< Tensor > buffers(bool recurse=true) const
Returns the buffers of this Module and if recurse is true, also recursively of every submodule...
Definition: module.cpp:166
Tensor & register_parameter(std::string name, Tensor tensor, bool requires_grad=true)
Registers a parameter with this Module.
Definition: module.cpp:301
OrderedDict< std::string, Tensor > named_parameters(bool recurse=true) const
Returns an OrderedDict with the parameters of this Module along with their keys, and if recurse is tr...
Definition: module.cpp:153
Definition: module.cpp:17
const std::string & name() const noexcept
Returns the name of the Module.
Definition: module.cpp:53
virtual void zero_grad()
Recursively zeros out the grad value of each registered parameter.
Definition: module.cpp:260
std::vector< std::shared_ptr< Module > > modules(bool include_self=true) const
Returns the submodules of this Module (the entire submodule hierarchy) and if include_self is true...
Definition: module.cpp:187
Represents a compute device on which a tensor is located.
Definition: Device.h:30
std::vector< Tensor > parameters(bool recurse=true) const
Returns the parameters of this Module and if recurse is true, also recursively of every submodule...
Definition: module.cpp:143
std::shared_ptr< Module > clone(const optional< Device > &device=nullopt) const override
Performs a recursive "deep copy" of the Module, such that all parameters and submodules in the cloned...
Definition: cloneable.h:34
Tensor & register_buffer(std::string name, Tensor tensor)
Registers a buffer with this Module.
Definition: module.cpp:315
virtual std::shared_ptr< Module > clone(const optional< Device > &device=nullopt) const
Performs a recursive deep copy of the module and all its registered parameters, buffers and submodule...
Definition: module.cpp:78
virtual void to(torch::Device device, torch::Dtype dtype, bool non_blocking=false)
Recursively casts all parameters to the given dtype and device.
Definition: module.cpp:244
does bound shape inference given a C2 net.
The clone() method in the base Module class does not have knowledge of the concrete runtime type of i...
Definition: cloneable.h:23
The base class for all modules in PyTorch.
Definition: module.h:62
OrderedDict< std::string, Tensor > named_buffers(bool recurse=true) const
Returns an OrderedDict with the buffers of this Module along with their keys, and if recurse is true ...
Definition: module.cpp:174
DeviceIndex index() const noexcept
Returns the optional index.
Definition: Device.h:70
ModuleType::ContainedType * as() noexcept
Attempts to cast this Module to the given ModuleType.
Definition: module.h:532
std::shared_ptr< ModuleType > register_module(std::string name, std::shared_ptr< ModuleType > module)
Registers a submodule with this Module.
Definition: module.h:556
An ordered dictionary implementation, akin to Python's OrderedDict.
Definition: ordered_dict.h:16
DeviceType type() const noexcept
Returns the type of device this is.
Definition: Device.h:65