1 #include <gtest/gtest.h> 3 #include <torch/nn/module.h> 4 #include <torch/nn/modules/linear.h> 5 #include <torch/nn/modules/rnn.h> 6 #include <torch/nn/modules/sequential.h> 7 #include <torch/types.h> 8 #include <torch/utils.h> 10 #include <test/cpp/api/support.h> 26 TEST_F(
ModuleTest, CanEnableAndDisableTrainingMode) {
28 ASSERT_TRUE(module->is_training());
31 ASSERT_FALSE(module->is_training());
34 ASSERT_TRUE(module->is_training());
39 auto weight = torch::ones({8, 3}, torch::requires_grad());
40 auto loss = module(weight).sum();
42 for (
auto& parameter : module->parameters()) {
43 auto grad = parameter.grad();
44 ASSERT_TRUE(grad.defined());
45 ASSERT_NE(grad.sum().item<
float>(), 0);
48 for (
auto& parameter : module->parameters()) {
49 auto grad = parameter.grad();
50 ASSERT_TRUE(grad.defined());
51 ASSERT_EQ(grad.sum().item<
float>(), 0);
58 x = register_parameter(
"x", torch::ones(5, torch::requires_grad()));
59 y = register_parameter(
"y", torch::ones(5, torch::requires_grad()));
65 auto z = module.x * 2;
68 ASSERT_TRUE(module.x.grad().defined());
69 ASSERT_FALSE(module.y.grad().defined());
73 ASSERT_TRUE(module.x.grad().defined());
74 ASSERT_FALSE(module.y.grad().defined());
76 ASSERT_EQ(module.x.grad().sum().item<
float>(), 0);
79 TEST_F(
ModuleTest, RegisterModuleThrowsForEmptyOrDottedName) {
85 "Submodule name must not contain a dot (got 'name.with.dot')");
88 "Submodule name must not be empty");
91 TEST_F(
ModuleTest, RegisterModuleThrowsForDuplicateModuleName) {
99 "Submodule 'linear' already defined");
102 TEST_F(
ModuleTest, RegisterParameterThrowsForEmptyOrDottedName) {
108 "Parameter name must not contain a dot (got 'name.with.dot')");
111 "Parameter name must not be empty");
114 TEST_F(
ModuleTest, RegisterParameterThrowsForDuplicateModuleName) {
122 "Parameter 'p' already defined");
125 TEST_F(
ModuleTest, RegisterBufferThrowsForEmptyOrDottedName) {
131 "Buffer name must not contain a dot (got 'name.with.dot')");
134 "Buffer name must not be empty");
137 TEST_F(
ModuleTest, RegisterBufferThrowsForDuplicateModuleName) {
144 model.
register_buffer(
"p", torch::ones(5)),
"Buffer 'p' already defined");
152 EXPECT_EQ(agi.
name(),
"AGIUnit");
153 EXPECT_EQ(agi.
name(),
"AGIUnit");
160 ASSERT_EQ(module->as<Linear>(), module.get());
161 ASSERT_EQ(module->as<
LinearImpl>(), module.get());
162 ASSERT_EQ(module->as<
Module>(), module.get());
163 ASSERT_EQ(module->as<
AGIUnit>(),
nullptr);
165 std::shared_ptr<Module> raw = module.ptr();
166 ASSERT_EQ(raw->as<Linear>(), module.get());
167 ASSERT_EQ(raw->as<
LinearImpl>(), module.get());
168 ASSERT_EQ(raw->as<
Module>(), module.get());
169 ASSERT_EQ(raw->as<
AGIUnit>(),
nullptr);
171 Module& raw_ref = *raw.get();
172 ASSERT_EQ(raw_ref.
as<Linear>(), module.get());
174 ASSERT_EQ(raw_ref.
as<
Module>(), module.get());
175 ASSERT_EQ(raw_ref.
as<
AGIUnit>(),
nullptr);
176 if (
auto* linear = raw_ref.
as<Linear>()) {
177 ASSERT_EQ(linear->weight.ndimension(), 2);
181 ASSERT_EQ(unit.as<Linear>(),
nullptr);
183 ASSERT_EQ(unit.as<
AGIUnit>(), &unit);
187 Linear module(128, 64);
188 for (
auto& parameter : module->parameters()) {
190 ASSERT_EQ(parameter.dtype(), torch::kFloat32);
193 module->to({torch::kCUDA, 0});
194 for (
auto& parameter : module->parameters()) {
195 ASSERT_EQ(parameter.device().type(), torch::Device::Type::CUDA);
196 ASSERT_EQ(parameter.device().index(), 0);
198 module->to({torch::kCUDA, 1});
199 for (
auto& parameter : module->parameters()) {
200 ASSERT_EQ(parameter.device().type(), torch::Device::Type::CUDA);
201 ASSERT_EQ(parameter.device().index(), 1);
206 for (
auto& parameter : module->parameters()) {
207 ASSERT_EQ(parameter.device().type(), torch::Device::Type::CPU);
211 module->to(torch::kInt32);
212 for (
auto& parameter : module->parameters()) {
213 ASSERT_EQ(parameter.dtype(), torch::kInt32);
217 module->to(torch::kFloat64);
218 for (
auto& parameter : module->parameters()) {
219 ASSERT_EQ(parameter.dtype(), torch::kFloat64);
224 for (
auto& parameter : module->parameters()) {
225 ASSERT_EQ(parameter.device().type(), torch::Device::Type::CUDA);
226 ASSERT_EQ(parameter.device().index(), 1);
228 for (
auto& parameter : module->parameters()) {
229 ASSERT_EQ(parameter.dtype(), torch::kUInt8);
234 TEST_F(
ModuleTest, CallingCloneOnModuleThatDoesNotOverrideCloneThrows) {
235 struct UnCloneable :
Module {};
237 ASSERT_THROWS_WITH(module.clone(),
"clone() has not been implemented");
240 TEST_F(
ModuleTest, CallingCloneOnModuleThatDoesOverrideCloneDoesNotThrow) {
242 std::shared_ptr<Module> clone(
244 torch::nullopt)
const override {
249 ASSERT_NO_THROW({ module.
clone(); });
252 TEST_F(
ModuleTest, CloneCreatesDistinctParameters) {
257 void reset()
override {
258 l1 = register_module(
"l1", Linear(10, 3));
259 l2 = register_module(
"l2", Linear(3, 5));
260 l3 = register_module(
"l3", Linear(5, 100));
261 buffer = register_buffer(
"buf", torch::ones({2, 2}));
264 Linear l1{
nullptr}, l2{
nullptr}, l3{
nullptr};
268 auto module = std::make_shared<TestModule>();
272 auto module2 = module->clone();
273 auto params1 = module->named_parameters();
274 auto params2 = module2->named_parameters();
275 ASSERT_EQ(params1.size(), 6);
276 ASSERT_EQ(params2.size(), 6);
277 for (
auto& param : params1) {
278 ASSERT_FALSE(pointer_equal(param.value(), params2[param.key()]));
279 ASSERT_TRUE(param->allclose(params2[param.key()]));
282 for (
auto& param : params1) {
283 ASSERT_FALSE(param->allclose(params2[param.key()]));
286 auto buffers1 = module->named_buffers();
287 auto buffers2 = module2->named_buffers();
288 ASSERT_EQ(buffers1.size(), 1);
289 ASSERT_EQ(buffers2.size(), 1);
290 for (
auto& buffer : buffers1) {
291 ASSERT_FALSE(pointer_equal(buffer.value(), buffers2[buffer.key()]));
292 ASSERT_TRUE(buffer->allclose(buffers2[buffer.key()]));
295 for (
auto& buffer : buffers1) {
296 ASSERT_FALSE(buffer->allclose(buffers2[buffer.key()]));
300 TEST_F(
ModuleTest, ClonePreservesExternalReferences) {
305 void reset()
override {
306 weight = register_parameter(
"weight", torch::ones({4, 4}));
310 auto module = std::make_shared<TestModule>();
316 pointer_equal(module->weight, module->named_parameters()[
"weight"]));
317 ASSERT_TRUE(module->weight.allclose(module->named_parameters()[
"weight"]));
319 auto module2 = std::dynamic_pointer_cast<
TestModule>(
320 std::shared_ptr<Module>(module->clone()));
321 ASSERT_FALSE(pointer_equal(module2->weight, module->weight));
323 pointer_equal(module2->weight, module2->named_parameters()[
"weight"]));
324 ASSERT_TRUE(module2->weight.allclose(module2->named_parameters()[
"weight"]));
325 ASSERT_TRUE(module2->weight.allclose(module->weight));
327 pointer_equal(module2->weight, module->named_parameters()[
"weight"]));
330 TEST_F(
ModuleTest, CloneCopiesTheValuesOfVariablesOfSubmodules) {
335 void reset()
override {
336 weight = register_parameter(
"weight", torch::ones({4, 4}));
342 struct NestedModule :
public Cloneable<NestedModule> {
346 void reset()
override {
347 module = register_module(
"module", std::make_shared<TestModule>());
349 std::shared_ptr<TestModule> module;
352 auto a = std::make_shared<NestedModule>();
355 a->module->weight += 1;
356 a->module->value = 123;
359 auto b = std::dynamic_pointer_cast<NestedModule>(a->clone());
361 ASSERT_FALSE(pointer_equal(b->module->weight, a->module->weight));
362 ASSERT_TRUE(pointer_equal(
363 b->module->weight, b->module->named_parameters()[
"weight"]));
365 b->module->named_parameters()[
"weight"].allclose(a->module->weight));
366 ASSERT_TRUE(b->module->weight.allclose(a->module->weight));
367 ASSERT_EQ(b->module->value, a->module->value);
370 TEST_F(
ModuleTest, CloneToDevicePreservesTheDeviceOfParameters_CUDA) {
375 void reset()
override {
376 l1 = register_module(
"l1", Linear(10, 3));
377 l2 = register_module(
"l2", Linear(3, 5));
378 l3 = register_module(
"l3", Linear(5, 100));
379 buffer = register_buffer(
"buf", torch::ones({2, 2}));
382 Linear l1{
nullptr}, l2{
nullptr}, l3{
nullptr};
391 auto clone = m.
clone();
392 for (
const auto& parameter : clone->parameters()) {
393 ASSERT_EQ(parameter.device().type(), device.
type());
394 ASSERT_EQ(parameter.device().index(), device.
index());
396 for (
const auto& buffer : clone->buffers()) {
397 ASSERT_EQ(buffer.device().type(), device.
type());
398 ASSERT_EQ(buffer.device().index(), device.
index());
404 CloningToAParticularDevicePlacesAllParametersThere_MultiCUDA) {
409 void reset()
override {
410 l1 = register_module(
"l1", Linear(10, 3));
411 l2 = register_module(
"l2", Linear(3, 5));
412 l3 = register_module(
"l3", Linear(5, 100));
413 buffer = register_buffer(
"buf", torch::ones({2, 2}));
416 Linear l1{
nullptr}, l2{
nullptr}, l3{
nullptr};
423 auto clone = m.
clone(device);
424 for (
const auto& parameter : clone->parameters()) {
425 ASSERT_EQ(parameter.device().type(), device.
type());
426 ASSERT_EQ(parameter.device().index(), device.
index());
428 for (
const auto& buffer : clone->buffers()) {
429 ASSERT_EQ(buffer.device().type(), device.
type());
430 ASSERT_EQ(buffer.device().index(), device.
index());
436 a = register_parameter(
"a", torch::zeros({2, 2}));
437 b = register_parameter(
"b", torch::ones({2, 2}));
438 c = register_parameter(
"c", torch::ones({2, 2}) * 2);
444 TEST_F(
ModuleTest, HasCorrectNumberOfParameters) {
450 TEST_F(
ModuleTest, ContainsParametersWithTheCorrectName) {
453 ASSERT_TRUE(parameters.contains(
"a"));
454 ASSERT_TRUE(parameters.contains(
"b"));
455 ASSERT_TRUE(parameters.contains(
"c"));
460 a = register_buffer(
"a", torch::zeros({2, 2}));
461 b = register_buffer(
"b", torch::ones({2, 2}));
462 c = register_buffer(
"c", torch::ones({2, 2}) * 2);
468 TEST_F(
ModuleTest, HasCorrectNumberOfBuffers) {
470 ASSERT_EQ(module.
buffers().size(), 3);
474 TEST_F(
ModuleTest, ContainsBuffersWithTheCorrectName) {
477 ASSERT_TRUE(buffers.contains(
"a"));
478 ASSERT_TRUE(buffers.contains(
"b"));
479 ASSERT_TRUE(buffers.contains(
"c"));
484 AImpl(
int x) : x_(x) {}
491 DefaultConstructorOfModuleHolderCallsDefaultConstructorOfImpl) {
494 ASSERT_FALSE(a.is_empty());
495 ASSERT_EQ(a->x_, 123);
500 ValueConstructorOfModuleHolderCallsCorrectConstructorInImpl) {
503 ASSERT_FALSE(a.is_empty());
507 TEST_F(
ModuleTest, NullptrConstructorLeavesTheModuleHolderInEmptyState) {
510 ASSERT_TRUE(a.is_empty());
511 ASSERT_THROWS_WITH(a->x_,
"Accessing empty ModuleHolder");
516 p1 = register_parameter(
"p1", torch::randn({size}));
517 p2 = register_parameter(
"p2", torch::randn({size}));
518 b1 = register_buffer(
"b1", torch::randn({size}));
519 b2 = register_buffer(
"b2", torch::randn({size}));
529 TEST_F(
ModuleTest, ModulesReturnsExpectedSubmodulesForFlatModel) {
531 std::vector<std::shared_ptr<torch::nn::Module>> modules = model->modules();
532 std::vector<std::shared_ptr<torch::nn::Module>> expected = {
533 model.ptr(), model[0], model[1], model[2]};
534 ASSERT_EQ(modules.size(), expected.size());
535 for (
size_t i = 0; i < expected.size(); ++i) {
537 ASSERT_EQ(modules[i].
get(), expected[i].
get());
541 TEST_F(
ModuleTest, ModulesExcludesSelfWhenIncludeSelfSetToFalse) {
543 std::vector<std::shared_ptr<torch::nn::Module>> modules =
544 model->modules(
false);
545 std::vector<std::shared_ptr<torch::nn::Module>> expected = {
546 model[0], model[1], model[2]};
547 ASSERT_EQ(modules.size(), expected.size());
548 for (
size_t i = 0; i < expected.size(); ++i) {
550 ASSERT_EQ(modules[i].
get(), expected[i].
get());
554 TEST_F(
ModuleTest, NamedModulesReturnsExpectedNamedSubmodulesForFlatModel) {
557 model->named_modules();
558 std::vector<std::shared_ptr<torch::nn::Module>> expected = {
559 model.ptr(), model[0], model[1], model[2]};
560 ASSERT_EQ(modules.
size(), expected.size());
561 for (
size_t i = 0; i < expected.size(); ++i) {
563 ASSERT_EQ(modules[i].key(), i ? std::to_string(i - 1) : std::string());
564 ASSERT_EQ(modules[i].value().
get(), expected[i].
get());
568 TEST_F(
ModuleTest, NamedModulesExcludesSelfWhenIncludeSelfSetToFalse) {
571 model->named_modules(
572 std::string(),
false);
573 std::vector<std::shared_ptr<torch::nn::Module>> expected = {
574 model[0], model[1], model[2]};
575 ASSERT_EQ(modules.
size(), expected.size());
576 for (
size_t i = 0; i < expected.size(); ++i) {
578 ASSERT_EQ(modules[i].key(), std::to_string(i));
579 ASSERT_EQ(modules[i].value().
get(), expected[i].
get());
583 TEST_F(
ModuleTest, ChildrenReturnsExpectedSubmodulesForFlatModel) {
585 std::vector<std::shared_ptr<torch::nn::Module>> modules = model->children();
586 std::vector<std::shared_ptr<torch::nn::Module>> expected = {
587 model[0], model[1], model[2]};
588 ASSERT_EQ(modules.size(), expected.size());
589 for (
size_t i = 0; i < expected.size(); ++i) {
591 ASSERT_EQ(modules[i].
get(), expected[i].
get());
595 ASSERT_EQ(modules, model->modules(
false));
598 TEST_F(
ModuleTest, NamedChildrenReturnsExpectedNamedSubmodulesForFlatModel) {
601 model->named_children();
602 std::vector<std::shared_ptr<torch::nn::Module>> expected = {
603 model[0], model[1], model[2]};
604 ASSERT_EQ(modules.
size(), expected.size());
605 for (
size_t i = 0; i < expected.size(); ++i) {
607 ASSERT_EQ(modules[i].key(), std::to_string(i));
608 ASSERT_EQ(modules[i].value().
get(), expected[i].
get());
612 TEST_F(
ModuleTest, ParametersReturnsExpectedTensorsForFlatModel) {
614 std::vector<torch::Tensor> parameters = module.
parameters();
615 ASSERT_EQ(parameters.size(), 2);
616 ASSERT_EQ(parameters[0].data<float>(), module.p1.data<
float>());
617 ASSERT_EQ(parameters[1].data<float>(), module.p2.data<
float>());
620 TEST_F(
ModuleTest, NamedParametersReturnsExpectedTensorsForFlatModel) {
624 ASSERT_EQ(parameters.
size(), 2);
625 ASSERT_EQ(parameters[0].key(),
"p1");
626 ASSERT_EQ(parameters[0]->data<float>(), module.p1.data<
float>());
627 ASSERT_EQ(parameters[1].key(),
"p2");
628 ASSERT_EQ(parameters[1]->data<float>(), module.p2.data<
float>());
631 TEST_F(
ModuleTest, BuffersReturnsExpectedTensorsForFlatModel) {
633 std::vector<torch::Tensor> buffers = module.
buffers();
634 ASSERT_EQ(buffers.size(), 2);
635 ASSERT_EQ(buffers[0].data<float>(), module.b1.data<
float>());
636 ASSERT_EQ(buffers[1].data<float>(), module.b2.data<
float>());
639 TEST_F(
ModuleTest, NamedBuffersReturnsExpectedTensorsForFlatModel) {
643 ASSERT_EQ(buffers.
size(), 2);
644 ASSERT_EQ(buffers[0].key(),
"b1");
645 ASSERT_EQ(buffers[0]->data<float>(), module.b1.data<
float>());
646 ASSERT_EQ(buffers[1].key(),
"b2");
647 ASSERT_EQ(buffers[1]->data<float>(), module.b2.data<
float>());
651 TestContainer(int64_t number, std::vector<TestContainer> modules = {})
652 : tensor(torch::tensor(number)) {
653 for (
size_t i = 0; i < modules.size(); ++i) {
656 std::make_shared<TestContainer>(std::move(modules[i])));
662 int64_t get_test_container_item(std::shared_ptr<torch::nn::Module> module) {
664 ->tensor.item<int64_t>();
667 std::shared_ptr<TestContainer> make_deeply_nested_test_container() {
/// Returns the expected (qualified name, tensor value) pairs produced by
/// calling `named_modules("test_prefix")` on the deeply nested test
/// container: the root under "test_prefix", with dotted child indices
/// appended for each nesting level, paired with each module's payload 0-9.
std::vector<std::pair<std::string, int64_t>>
make_key_value_pairs_for_deeply_nested_container() {
  std::vector<std::pair<std::string, int64_t>> pairs;
  pairs.reserve(10);
  pairs.emplace_back("test_prefix", 0);
  pairs.emplace_back("test_prefix.0", 1);
  pairs.emplace_back("test_prefix.0.0", 2);
  pairs.emplace_back("test_prefix.0.1", 3);
  pairs.emplace_back("test_prefix.1", 4);
  pairs.emplace_back("test_prefix.2", 5);
  pairs.emplace_back("test_prefix.2.0", 6);
  pairs.emplace_back("test_prefix.2.1", 7);
  pairs.emplace_back("test_prefix.2.1.0", 8);
  pairs.emplace_back("test_prefix.2.1.1", 9);
  return pairs;
}
692 TEST_F(
ModuleTest, ModulesReturnsExpectedSubmodulesForDeepModel) {
693 auto model = make_deeply_nested_test_container();
694 std::vector<std::shared_ptr<torch::nn::Module>> modules = model->modules();
696 ASSERT_EQ(modules.size(), 10);
697 for (
size_t i = 0; i < modules.size(); ++i) {
698 ASSERT_EQ(get_test_container_item(modules[i]), i);
702 TEST_F(
ModuleTest, NamedModulesReturnsExpectedNamedSubmodulesForDeepModel) {
703 auto model = make_deeply_nested_test_container();
705 model->named_modules(
"test_prefix");
706 auto expected = make_key_value_pairs_for_deeply_nested_container();
708 ASSERT_EQ(modules.
size(), expected.size());
710 for (
size_t i = 0; i < expected.size(); ++i) {
711 ASSERT_EQ(modules[i].key(), expected[i].first);
712 ASSERT_EQ(get_test_container_item(modules[i].value()), expected[i].second);
716 TEST_F(
ModuleTest, ChildrensReturnsExpectedSubmodulesForDeepModel) {
717 auto model = make_deeply_nested_test_container();
718 std::vector<std::shared_ptr<torch::nn::Module>> modules = model->children();
720 ASSERT_EQ(modules.size(), 3);
721 ASSERT_EQ(get_test_container_item(modules[0]), 1);
722 ASSERT_EQ(get_test_container_item(modules[1]), 4);
723 ASSERT_EQ(get_test_container_item(modules[2]), 5);
726 TEST_F(
ModuleTest, NamedChildrensReturnsExpectedNamedSubmodulesForDeepModel) {
727 auto model = make_deeply_nested_test_container();
729 model->named_children();
731 ASSERT_EQ(modules.
size(), 3);
733 ASSERT_EQ(get_test_container_item(modules[0].value()), 1);
734 ASSERT_EQ(modules[0].key(),
"0");
736 ASSERT_EQ(get_test_container_item(modules[1].value()), 4);
737 ASSERT_EQ(modules[1].key(),
"1");
739 ASSERT_EQ(get_test_container_item(modules[2].value()), 5);
740 ASSERT_EQ(modules[2].key(),
"2");
743 TEST_F(
ModuleTest, ModuleApplyIteratesCorreclty) {
744 auto model = make_deeply_nested_test_container();
747 ASSERT_EQ(module.
as<
TestContainer>()->tensor.item<int64_t>(), index++);
749 ASSERT_EQ(index, 10);
752 TEST_F(
ModuleTest, ConstModuleApplyIteratesCorreclty) {
753 std::shared_ptr<const TestContainer> model =
754 make_deeply_nested_test_container();
757 ASSERT_EQ(module.
as<
TestContainer>()->tensor.item<int64_t>(), index++);
759 ASSERT_EQ(index, 10);
762 TEST_F(
ModuleTest, NamedModuleApplyIteratesCorreclty) {
763 auto model = make_deeply_nested_test_container();
764 auto expected = make_key_value_pairs_for_deeply_nested_container();
768 ASSERT_EQ(name, expected[index].first);
771 expected[index++].second);
774 ASSERT_EQ(index, 10);
777 TEST_F(
ModuleTest, ConstNamedModuleApplyIteratesCorreclty) {
778 std::shared_ptr<const TestContainer> model =
779 make_deeply_nested_test_container();
780 auto expected = make_key_value_pairs_for_deeply_nested_container();
785 ASSERT_EQ(name, expected[index].first);
788 expected[index++].second);
791 ASSERT_EQ(index, 10);
794 TEST_F(
ModuleTest, ModulePointerApplyIteratesCorreclty) {
795 auto model = make_deeply_nested_test_container();
797 model->apply([&index](
const std::shared_ptr<torch::nn::Module>& module) {
798 ASSERT_EQ(get_test_container_item(module), index++);
800 ASSERT_EQ(index, 10);
803 TEST_F(
ModuleTest, NamedModulePointerApplyIteratesCorreclty) {
804 auto model = make_deeply_nested_test_container();
805 auto expected = make_key_value_pairs_for_deeply_nested_container();
809 const std::string& name,
810 const std::shared_ptr<torch::nn::Module>& module) {
811 ASSERT_EQ(name, expected[index].first);
812 ASSERT_EQ(get_test_container_item(module), expected[index++].second);
815 ASSERT_EQ(index, 10);
818 TEST_F(
ModuleTest, ThrowsWhenAttemptingtoGetTopLevelModuleAsSharedPtr) {
823 "It looks like you attempted to retrieve " 824 "your top-level module as a shared_ptr")
828 ASSERT_NO_THROW(module.
modules(
false));
831 auto module = std::make_shared<TestModule>(1);
832 ASSERT_NO_THROW(module->modules());
842 void pretty_print(std::ostream& stream)
const override {
843 stream <<
"TestModule(x=" << x_ <<
", y=" << y_ <<
")";
853 ASSERT_EQ(c10::str(
TestModule(1, 3.14)),
"TestModule(x=1, y=3.14)");
861 TORCH_MODULE(ModuleWithNonTensorForward);
863 TEST_F(
ModuleTest, CanCallForwardOnNonTensorForwardThroughPimpl) {
864 ModuleWithNonTensorForward m;
865 ASSERT_EQ(m(torch::ones(123)), 123);
Applies a linear transformation with optional bias.
size_t size() const noexcept
Returns the number of items currently stored in the OrderedDict.
std::vector< Tensor > buffers(bool recurse=true) const
Returns the buffers of this Module and if recurse is true, also recursively of every submodule...
Tensor & register_parameter(std::string name, Tensor tensor, bool requires_grad=true)
Registers a parameter with this Module.
OrderedDict< std::string, Tensor > named_parameters(bool recurse=true) const
Returns an OrderedDict with the parameters of this Module along with their keys, and if recurse is tr...
const std::string & name() const noexcept
Returns the name of the Module.
virtual void zero_grad()
Recursively zeros out the grad value of each registered parameter.
std::vector< std::shared_ptr< Module > > modules(bool include_self=true) const
Returns the submodules of this Module (the entire submodule hierarchy) and if include_self is true...
Represents a compute device on which a tensor is located.
std::vector< Tensor > parameters(bool recurse=true) const
Returns the parameters of this Module and if recurse is true, also recursively of every submodule...
std::shared_ptr< Module > clone(const optional< Device > &device=nullopt) const override
Performs a recursive "deep copy" of the Module, such that all parameters and submodules in the cloned...
Tensor & register_buffer(std::string name, Tensor tensor)
Registers a buffer with this Module.
virtual std::shared_ptr< Module > clone(const optional< Device > &device=nullopt) const
Performs a recursive deep copy of the module and all its registered parameters, buffers and submodule...
virtual void to(torch::Device device, torch::Dtype dtype, bool non_blocking=false)
Recursively casts all parameters to the given dtype and device.
Performs bound shape inference given a C2 net.
The clone() method in the base Module class does not have knowledge of the concrete runtime type of i...
The base class for all modules in PyTorch.
OrderedDict< std::string, Tensor > named_buffers(bool recurse=true) const
Returns an OrderedDict with the buffers of this Module along with their keys, and if recurse is true ...
DeviceIndex index() const noexcept
Returns the optional index.
ModuleType::ContainedType * as() noexcept
Attempts to cast this Module to the given ModuleType.
std::shared_ptr< ModuleType > register_module(std::string name, std::shared_ptr< ModuleType > module)
Registers a submodule with this Module.
An ordered dictionary implementation, akin to Python's OrderedDict.
DeviceType type() const noexcept
Returns the type of device this is.