// NOTE(review): this file is a garbled extraction of PyTorch's C++ API
// Sequential tests (test/cpp/api/sequential.cpp). The numerals fused into
// each line are the ORIGINAL file's line numbers; gaps in that numbering
// mean lines were dropped (TEST_F headers, `struct M : torch::nn::Module`
// declarations, closing braces). The fragments below do NOT compile as-is
// and should be restored from the upstream file rather than edited here.
1 #include <gtest/gtest.h> 3 #include <torch/nn/modules.h> 4 #include <torch/nn/modules/batchnorm.h> 5 #include <torch/nn/modules/conv.h> 6 #include <torch/nn/modules/dropout.h> 7 #include <torch/nn/modules/linear.h> 8 #include <torch/nn/modules/rnn.h> 9 #include <torch/nn/modules/sequential.h> 10 #include <torch/types.h> 11 #include <torch/utils.h> 17 #include <test/cpp/api/support.h> 26 explicit M(
int value_) : value(value_) {}
// Fragment: builds a Sequential from three std::shared_ptr<M> and checks the
// size — presumably the "ConstructsFromSharedPointer" test; TODO confirm
// against the upstream file.
32 Sequential sequential(
33 std::make_shared<M>(1), std::make_shared<M>(2), std::make_shared<M>(3));
34 ASSERT_EQ(sequential->size(), 3);
// Fragment: builds a Sequential from module VALUES M(1..3) rather than
// shared_ptrs — presumably the "ConstructsFromConcreteType" test; the
// enclosing TEST_F header and struct M declaration were lost in extraction.
39 explicit M(
int value_) : value(value_) {}
46 Sequential sequential(
M(1),
M(2),
M(3));
47 ASSERT_EQ(sequential->size(), 3);
// Fragment: the `MImpl` name suggests the TORCH_MODULE(M) holder pattern —
// presumably the "ConstructsFromModuleHolder" test; the TORCH_MODULE macro
// line and TEST_F header are missing from this extraction. TODO confirm.
51 explicit MImpl(
int value_) : value(value_) {}
63 Sequential sequential(
M(1),
M(2),
M(3));
64 ASSERT_EQ(sequential->size(), 3);
// Fragment: exercises Sequential::push_back with a built-in module (Linear),
// a shared_ptr, and a module value, checking size after each insertion —
// presumably the "PushBackAddsAnElement" test; TEST_F header missing.
69 explicit M(
int value_) : value(value_) {}
75 Sequential sequential;
76 ASSERT_EQ(sequential->size(), 0);
77 ASSERT_TRUE(sequential->is_empty());
78 sequential->push_back(Linear(3, 4));
79 ASSERT_EQ(sequential->size(), 1);
80 sequential->push_back(std::make_shared<M>(1));
81 ASSERT_EQ(sequential->size(), 2);
82 sequential->push_back(
M(2));
83 ASSERT_EQ(sequential->size(), 3);
// Fragment: pushes three modules, then checks typed element access via
// Sequential::at<M>(i) and that out-of-range indices throw
// "Index out of range" — presumably the "AccessWithAt" test. The
// ASSERT_THROWS_WITH( opening lines for the two throw checks were lost.
88 explicit M(
int value_) : value(value_) {}
94 std::vector<std::shared_ptr<M>> modules = {
95 std::make_shared<M>(1), std::make_shared<M>(2), std::make_shared<M>(3)};
97 Sequential sequential;
98 for (
auto& module : modules) {
99 sequential->push_back(module);
101 ASSERT_EQ(sequential->size(), 3);
104 for (
size_t i = 0; i < modules.size(); ++i) {
105 ASSERT_EQ(&sequential->at<
M>(i), modules[i].get());
110 sequential->at<
M>(modules.size() + 1),
"Index out of range");
112 sequential->at<
M>(modules.size() + 1000000),
"Index out of range");
// Fragment: same setup as the at<> test but checks the three equivalent
// pointer accessors — ptr(i), operator[], and ptr<M>(i) — and the
// out-of-range throw for ptr() — presumably the "AccessWithPtr" test.
117 explicit M(
int value_) : value(value_) {}
123 std::vector<std::shared_ptr<M>> modules = {
124 std::make_shared<M>(1), std::make_shared<M>(2), std::make_shared<M>(3)};
126 Sequential sequential;
127 for (
auto& module : modules) {
128 sequential->push_back(module);
130 ASSERT_EQ(sequential->size(), 3);
133 for (
size_t i = 0; i < modules.size(); ++i) {
134 ASSERT_EQ(sequential->ptr(i).get(), modules[i].get());
135 ASSERT_EQ(sequential[i].
get(), modules[i].
get());
136 ASSERT_EQ(sequential->ptr<
M>(i).get(), modules[i].get());
140 ASSERT_THROWS_WITH(sequential->ptr(modules.size() + 1),
"Index out of range");
142 sequential->ptr(modules.size() + 1000000),
"Index out of range");
// Fragment: the only TEST_F header that survived extraction. Verifies that
// forward<int>() on an empty Sequential throws
// "Cannot call forward() on an empty Sequential"; the `Sequential empty;`
// declaration and the ASSERT_THROWS_WITH( opening line are missing.
145 TEST_F(
SequentialTest, CallingForwardOnEmptySequentialIsDisallowed) {
148 empty->forward<
int>(),
"Cannot call forward() on an empty Sequential");
// Fragment: MockModule asserts it receives the expected value and
// (per the missing body, presumably) returns value + 1, so chaining
// 1 -> 2 -> 3 yields forward<int>(1) == 4 — likely the
// "CallingForwardChainsCorrectly" test. The `return` line was lost.
153 explicit MockModule(
int value) : expected(value) {}
155 int forward(
int value) {
156 assert(value == expected);
161 Sequential sequential(MockModule{1}, MockModule{2}, MockModule{3});
163 ASSERT_EQ(sequential->forward<
int>(1), 4);
// Fragment: the module (declaration lost) evidently returns int 5;
// asking forward<float>() for the wrong return type must throw the
// message asserted below. The ASSERT_THROWS_WITH( opening line is missing.
166 TEST_F(
SequentialTest, CallingForwardWithTheWrongReturnTypeThrows) {
173 Sequential sequential(
M{});
174 ASSERT_EQ(sequential->forward<
int>(), 5);
176 sequential->forward<
float>(),
177 "The type of the return value is int, but you asked for type float");
// Two fragments. First: an identity-like module M (declaration lost)
// whose forward returns its tensor argument unchanged — presumably the
// "forward returns the last value / defaults to Tensor" test.
187 Sequential sequential(
M{});
188 auto variable = torch::ones({3, 3}, torch::requires_grad());
189 ASSERT_TRUE(sequential->forward(variable).equal(variable));
// Second: chains three Linear layers (10->3->5->100) on a seeded random
// batch and checks the output shape is [1000, 100].
193 torch::manual_seed(0);
194 Sequential sequential(Linear(10, 3), Linear(3, 5), Linear(5, 100));
196 auto x = torch::randn({1000, 10}, torch::requires_grad());
197 auto y = sequential->forward(x);
198 ASSERT_EQ(y.ndimension(), 2);
199 ASSERT_EQ(y.size(0), 1000);
200 ASSERT_EQ(y.size(1), 100);
// Fragment: struct A/B/C/D declarations and the extend() calls were lost.
// From the assertions, `a` (initially {A, B}) was extended with `b`'s
// contents to become {A, B, C, D}, and `b` ({C, D}) was extended with the
// vector `c` of shared_ptr<A> to become {C, D, A, A} — presumably the
// "ExtendPushesModulesFromOtherSequential" test. TODO confirm upstream.
204 Sequential sequential(
234 Sequential a(
A{},
B{});
235 Sequential b(
C{},
D{});
238 ASSERT_EQ(a->size(), 4);
239 ASSERT_TRUE(a[0]->as<A>());
240 ASSERT_TRUE(a[1]->as<B>());
241 ASSERT_TRUE(a[2]->as<C>());
242 ASSERT_TRUE(a[3]->as<D>());
244 ASSERT_EQ(b->size(), 2);
245 ASSERT_TRUE(b[0]->as<C>());
246 ASSERT_TRUE(b[1]->as<D>());
248 std::vector<std::shared_ptr<A>> c = {std::make_shared<A>(),
249 std::make_shared<A>()};
252 ASSERT_EQ(b->size(), 4);
253 ASSERT_TRUE(b[0]->as<C>());
254 ASSERT_TRUE(b[1]->as<D>());
255 ASSERT_TRUE(b[2]->as<A>());
256 ASSERT_TRUE(b[3]->as<A>());
// Fragment: copying a Sequential shares the underlying impl (same get()
// pointer, same size, element-wise identical) — presumably the
// "HasReferenceSemantics" test; the std::equal() argument lines are missing.
260 Sequential first(Linear(2, 3), Linear(4, 4), Linear(4, 5));
261 Sequential second(first);
263 ASSERT_EQ(first.get(), second.get());
264 ASSERT_EQ(first->size(), second->size());
265 ASSERT_TRUE(std::equal(
270 return &first == &second;
// Fragment: clone() must produce same-sized, same-named but DISTINCT
// submodules, with parameters that are equal in value and device but not
// pointer-equal; after some mutation (lost lines ~298-299, presumably an
// in-place update of one side) the values must diverge — likely the
// "IsCloneable" test. The `clone` declaration line is missing.
275 Sequential sequential(Linear(3, 4), Functional(torch::relu), BatchNorm(3));
278 ASSERT_EQ(sequential->size(), clone->size());
280 for (
size_t i = 0; i < sequential->size(); ++i) {
282 ASSERT_EQ(sequential[i]->name(), clone[i]->name());
284 ASSERT_NE(sequential[i], clone[i]);
291 auto params1 = sequential->named_parameters();
292 auto params2 = clone->named_parameters();
293 ASSERT_EQ(params1.size(), params2.size());
294 for (
auto& param : params1) {
295 ASSERT_FALSE(pointer_equal(param.value(), params2[param.key()]));
296 ASSERT_EQ(param->device(), params2[param.key()].device());
297 ASSERT_TRUE(param->allclose(params2[param.key()]));
300 for (
auto& param : params1) {
301 ASSERT_FALSE(param->allclose(params2[param.key()]));
// Fragment: children() preserves insertion order and each child downcasts
// via as<T>() to its concrete module type — presumably the
// "RegistersElementsAsSubmodules" (or similar) test; TEST_F header missing.
306 Sequential sequential(Linear(10, 3), Conv2d(1, 2, 3), FeatureDropout(0.5));
308 auto modules = sequential->children();
309 ASSERT_TRUE(modules[0]->as<Linear>());
310 ASSERT_TRUE(modules[1]->as<Conv2d>());
311 ASSERT_TRUE(modules[2]->as<FeatureDropout>());
// Fragment: clone(device) must move every parameter AND buffer of the
// clone to the target device — presumably the "CloneToDevice" test; the
// `torch::Device device(...)` declaration and the `auto clone =` opening
// of the dynamic_pointer_cast line are missing.
315 Sequential sequential(Linear(3, 4), Functional(torch::relu), BatchNorm(3));
318 std::dynamic_pointer_cast<
SequentialImpl>(sequential->clone(device));
319 for (
const auto& p : clone->parameters()) {
320 ASSERT_EQ(p.device(), device);
322 for (
const auto& b : clone->buffers()) {
323 ASSERT_EQ(b.device(), device);
// Fragment: c10::str(sequential) pretty-prints the container with one
// indexed line per submodule — presumably the "PrettyPrintSequential"
// test. The module list (original lines ~329-335: Linear, Conv2d, Dropout,
// BatchNorm, Embedding, LSTM per the expected string) and the ASSERT_EQ(
// opening line were lost in extraction.
328 Sequential sequential(
336 c10::str(sequential),
337 "torch::nn::Sequential(\n" 338 " (0): torch::nn::Linear(in=10, out=3, with_bias=true)\n" 339 " (1): torch::nn::Conv2d(input_channels=1, output_channels=2, kernel_size=[3, 3], stride=[1, 1])\n" 340 " (2): torch::nn::Dropout(rate=0.5)\n" 341 " (3): torch::nn::BatchNorm(features=5, eps=1e-05, momentum=0.1, affine=true, stateful=true)\n" 342 " (4): torch::nn::Embedding(count=4, dimension=10)\n" 343 " (5): torch::nn::LSTM(input_size=4, hidden_size=5, layers=1, dropout=0)\n"
Represents a compute device on which a tensor is located.
A ModuleHolder is essentially a wrapper around std::shared_ptr<M> where M is an nn::Module subclass...
Does bound shape inference given a Caffe2 (C2) net.
The base class for all modules in PyTorch.
Stores a type erased Module.
A list of Modules that acts as a Module itself.