Caffe2 - C++ API
A deep learning, cross-platform ML framework
sequential.cpp
#include <gtest/gtest.h>

#include <torch/nn/modules.h>
#include <torch/nn/modules/batchnorm.h>
#include <torch/nn/modules/conv.h>
#include <torch/nn/modules/dropout.h>
#include <torch/nn/modules/linear.h>
#include <torch/nn/modules/rnn.h>
#include <torch/nn/modules/sequential.h>
#include <torch/types.h>
#include <torch/utils.h>

#include <algorithm>
#include <cassert>
#include <memory>
#include <vector>

#include <test/cpp/api/support.h>

using namespace torch::nn;
using namespace torch::test;
// TEST_F needs a fixture type; SeedingFixture (from support.h) seeds the RNG
// so tests are deterministic.
struct SequentialTest : torch::test::SeedingFixture {};

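// Sequential's variadic constructor accepts a std::shared_ptr<M> for any
// concrete Module subclass M.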
TEST_F(SequentialTest, ConstructsFromSharedPointer) {
  struct M : torch::nn::Module {
    explicit M(int value_) : value(value_) {}
    int value;
    int forward() {
      return value;
    }
  };
  Sequential sequential(
      std::make_shared<M>(1), std::make_shared<M>(2), std::make_shared<M>(3));
  ASSERT_EQ(sequential->size(), 3);
}

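// Modules can also be passed by value; Sequential moves each one onto the heap
// and stores it behind a shared_ptr internally.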
TEST_F(SequentialTest, ConstructsFromConcreteType) {
  struct M : torch::nn::Module {
    explicit M(int value_) : value(value_) {}
    int value;
    int forward() {
      return value;
    }
  };

  Sequential sequential(M(1), M(2), M(3));
  ASSERT_EQ(sequential->size(), 3);
}

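// ModuleHolder handles (the pattern behind built-in modules such as Linear)
// are accepted as well; Sequential stores the underlying implementation.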
TEST_F(SequentialTest, ConstructsFromModuleHolder) {
  struct MImpl : torch::nn::Module {
    explicit MImpl(int value_) : value(value_) {}
    int forward() {
      return value;
    }
    int value;
  };

  struct M : torch::nn::ModuleHolder<MImpl> {
    using torch::nn::ModuleHolder<MImpl>::ModuleHolder;
  };

  Sequential sequential(M(1), M(2), M(3));
  ASSERT_EQ(sequential->size(), 3);
}

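// push_back() accepts the same three forms as the constructor: a ModuleHolder,
// a shared_ptr, or a concrete module value.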
TEST_F(SequentialTest, PushBackAddsAnElement) {
  struct M : torch::nn::Module {
    explicit M(int value_) : value(value_) {}
    int forward() {
      return value;
    }
    int value;
  };
  Sequential sequential;
  ASSERT_EQ(sequential->size(), 0);
  ASSERT_TRUE(sequential->is_empty());
  sequential->push_back(Linear(3, 4));
  ASSERT_EQ(sequential->size(), 1);
  sequential->push_back(std::make_shared<M>(1));
  ASSERT_EQ(sequential->size(), 2);
  sequential->push_back(M(2));
  ASSERT_EQ(sequential->size(), 3);
}

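// at<T>(i) returns a reference of the requested concrete type and throws for
// an out-of-range index.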
TEST_F(SequentialTest, AccessWithAt) {
  struct M : torch::nn::Module {
    explicit M(int value_) : value(value_) {}
    int forward() {
      return value;
    }
    int value;
  };
  std::vector<std::shared_ptr<M>> modules = {
      std::make_shared<M>(1), std::make_shared<M>(2), std::make_shared<M>(3)};

  Sequential sequential;
  for (auto& module : modules) {
    sequential->push_back(module);
  }
  ASSERT_EQ(sequential->size(), 3);

  // returns the correct module for a given index
  for (size_t i = 0; i < modules.size(); ++i) {
    ASSERT_EQ(&sequential->at<M>(i), modules[i].get());
  }

  // throws for a bad index
  ASSERT_THROWS_WITH(
      sequential->at<M>(modules.size() + 1), "Index out of range");
  ASSERT_THROWS_WITH(
      sequential->at<M>(modules.size() + 1000000), "Index out of range");
}

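// ptr(i) returns the type-erased shared_ptr<Module>, ptr<T>(i) the downcast
// shared_ptr<T>, and operator[] behaves like ptr(i).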
TEST_F(SequentialTest, AccessWithPtr) {
  struct M : torch::nn::Module {
    explicit M(int value_) : value(value_) {}
    int forward() {
      return value;
    }
    int value;
  };
  std::vector<std::shared_ptr<M>> modules = {
      std::make_shared<M>(1), std::make_shared<M>(2), std::make_shared<M>(3)};

  Sequential sequential;
  for (auto& module : modules) {
    sequential->push_back(module);
  }
  ASSERT_EQ(sequential->size(), 3);

  // returns the correct module for a given index
  for (size_t i = 0; i < modules.size(); ++i) {
    ASSERT_EQ(sequential->ptr(i).get(), modules[i].get());
    ASSERT_EQ(sequential[i].get(), modules[i].get());
    ASSERT_EQ(sequential->ptr<M>(i).get(), modules[i].get());
  }

  // throws for a bad index
  ASSERT_THROWS_WITH(sequential->ptr(modules.size() + 1), "Index out of range");
  ASSERT_THROWS_WITH(
      sequential->ptr(modules.size() + 1000000), "Index out of range");
}

TEST_F(SequentialTest, CallingForwardOnEmptySequentialIsDisallowed) {
  Sequential empty;
  ASSERT_THROWS_WITH(
      empty->forward<int>(), "Cannot call forward() on an empty Sequential");
}

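// forward() feeds each module's return value into the next module. With input
// 1: MockModule{1} checks 1 and returns 2, MockModule{2} checks 2 and returns
// 3, MockModule{3} checks 3 and returns 4.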
TEST_F(SequentialTest, CallingForwardChainsCorrectly) {
  struct MockModule : torch::nn::Module {
    explicit MockModule(int value) : expected(value) {}
    int expected;
    int forward(int value) {
      assert(value == expected);
      return value + 1;
    }
  };

  Sequential sequential(MockModule{1}, MockModule{2}, MockModule{3});

  ASSERT_EQ(sequential->forward<int>(1), 4);
}

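// Return values are type-erased inside the Sequential, so forward<R>() checks
// the requested type R at runtime rather than at compile time.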
TEST_F(SequentialTest, CallingForwardWithTheWrongReturnTypeThrows) {
  struct M : public torch::nn::Module {
    int forward() {
      return 5;
    }
  };

  Sequential sequential(M{});
  ASSERT_EQ(sequential->forward<int>(), 5);
  ASSERT_THROWS_WITH(
      sequential->forward<float>(),
      "The type of the return value is int, but you asked for type float");
}

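// forward's template argument defaults to torch::Tensor, so the common
// tensor-in/tensor-out case needs no explicit type.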
TEST_F(SequentialTest, TheReturnTypeOfForwardDefaultsToTensor) {
  struct M : public torch::nn::Module {
    torch::Tensor forward(torch::Tensor v) {
      return v;
    }
  };

  Sequential sequential(M{});
  auto variable = torch::ones({3, 3}, torch::requires_grad());
  ASSERT_TRUE(sequential->forward(variable).equal(variable));
}

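// Shapes chain through the linear layers: [1000, 10] -> [1000, 3] ->
// [1000, 5] -> [1000, 100]; forward() returns the final module's output.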
TEST_F(SequentialTest, ForwardReturnsTheLastValue) {
  torch::manual_seed(0);
  Sequential sequential(Linear(10, 3), Linear(3, 5), Linear(5, 100));

  auto x = torch::randn({1000, 10}, torch::requires_grad());
  auto y = sequential->forward(x);
  ASSERT_EQ(y.ndimension(), 2);
  ASSERT_EQ(y.size(0), 1000);
  ASSERT_EQ(y.size(1), 100);
}

TEST_F(SequentialTest, SanityCheckForHoldingStandardModules) {
  Sequential sequential(
      Linear(10, 3),
      Conv2d(1, 2, 3),
      Dropout(0.5),
      BatchNorm(5),
      Embedding(4, 10),
      LSTM(4, 5));
}

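// extend() copies the modules of another Sequential (or of a vector of
// shared_ptrs) onto the end; the source sequential is left unchanged.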
TEST_F(SequentialTest, ExtendPushesModulesFromOtherSequential) {
  struct A : torch::nn::Module {
    int forward(int x) {
      return x;
    }
  };
  struct B : torch::nn::Module {
    int forward(int x) {
      return x;
    }
  };
  struct C : torch::nn::Module {
    int forward(int x) {
      return x;
    }
  };
  struct D : torch::nn::Module {
    int forward(int x) {
      return x;
    }
  };
  Sequential a(A{}, B{});
  Sequential b(C{}, D{});
  a->extend(*b);

  ASSERT_EQ(a->size(), 4);
  ASSERT_TRUE(a[0]->as<A>());
  ASSERT_TRUE(a[1]->as<B>());
  ASSERT_TRUE(a[2]->as<C>());
  ASSERT_TRUE(a[3]->as<D>());

  ASSERT_EQ(b->size(), 2);
  ASSERT_TRUE(b[0]->as<C>());
  ASSERT_TRUE(b[1]->as<D>());

  std::vector<std::shared_ptr<A>> c = {std::make_shared<A>(),
                                       std::make_shared<A>()};
  b->extend(c);

  ASSERT_EQ(b->size(), 4);
  ASSERT_TRUE(b[0]->as<C>());
  ASSERT_TRUE(b[1]->as<D>());
  ASSERT_TRUE(b[2]->as<A>());
  ASSERT_TRUE(b[3]->as<A>());
}

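// Sequential is a ModuleHolder, i.e. a wrapper around a shared_ptr to the
// implementation, so copies share the same underlying SequentialImpl.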
TEST_F(SequentialTest, HasReferenceSemantics) {
  Sequential first(Linear(2, 3), Linear(4, 4), Linear(4, 5));
  Sequential second(first);

  ASSERT_EQ(first.get(), second.get());
  ASSERT_EQ(first->size(), second->size());
  ASSERT_TRUE(std::equal(
      first->begin(),
      first->end(),
      second->begin(),
      [](const AnyModule& first, const AnyModule& second) {
        return &first == &second;
      }));
}

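// clone() performs a deep copy: the cloned modules have the same types and
// parameter values, but the parameters are distinct tensors.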
TEST_F(SequentialTest, IsCloneable) {
  Sequential sequential(Linear(3, 4), Functional(torch::relu), BatchNorm(3));
  Sequential clone =
      std::dynamic_pointer_cast<SequentialImpl>(sequential->clone());
  ASSERT_EQ(sequential->size(), clone->size());

  for (size_t i = 0; i < sequential->size(); ++i) {
    // The modules should be the same kind (type).
    ASSERT_EQ(sequential[i]->name(), clone[i]->name());
    // But not pointer-equal (distinct objects).
    ASSERT_NE(sequential[i], clone[i]);
  }

  // Verify that the clone is deep, i.e. parameters of modules are cloned too.

  torch::NoGradGuard no_grad;

  auto params1 = sequential->named_parameters();
  auto params2 = clone->named_parameters();
  ASSERT_EQ(params1.size(), params2.size());
  for (auto& param : params1) {
    ASSERT_FALSE(pointer_equal(param.value(), params2[param.key()]));
    ASSERT_EQ(param->device(), params2[param.key()].device());
    ASSERT_TRUE(param->allclose(params2[param.key()]));
    param->add_(2);
  }
  for (auto& param : params1) {
    ASSERT_FALSE(param->allclose(params2[param.key()]));
  }
}

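// Each element is registered as a submodule, so it shows up in children() and
// thus participates in parameter collection and recursive module traversal.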
TEST_F(SequentialTest, RegistersElementsAsSubmodules) {
  Sequential sequential(Linear(10, 3), Conv2d(1, 2, 3), FeatureDropout(0.5));

  auto modules = sequential->children();
  ASSERT_TRUE(modules[0]->as<Linear>());
  ASSERT_TRUE(modules[1]->as<Conv2d>());
  ASSERT_TRUE(modules[2]->as<FeatureDropout>());
}

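// The _CUDA suffix appears to be the test harness's convention for tests that
// require a GPU. clone(device) places the cloned parameters and buffers on the
// given device.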
TEST_F(SequentialTest, CloneToDevice_CUDA) {
  Sequential sequential(Linear(3, 4), Functional(torch::relu), BatchNorm(3));
  torch::Device device(torch::kCUDA, 0);
  Sequential clone =
      std::dynamic_pointer_cast<SequentialImpl>(sequential->clone(device));
  for (const auto& p : clone->parameters()) {
    ASSERT_EQ(p.device(), device);
  }
  for (const auto& b : clone->buffers()) {
    ASSERT_EQ(b.device(), device);
  }
}

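// Streaming a Sequential through c10::str() produces a Python-like summary,
// one indexed line per child module.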
TEST_F(SequentialTest, PrettyPrintSequential) {
  Sequential sequential(
      Linear(10, 3),
      Conv2d(1, 2, 3),
      Dropout(0.5),
      BatchNorm(5),
      Embedding(4, 10),
      LSTM(4, 5));
  ASSERT_EQ(
      c10::str(sequential),
      "torch::nn::Sequential(\n"
      "  (0): torch::nn::Linear(in=10, out=3, with_bias=true)\n"
      "  (1): torch::nn::Conv2d(input_channels=1, output_channels=2, kernel_size=[3, 3], stride=[1, 1])\n"
      "  (2): torch::nn::Dropout(rate=0.5)\n"
      "  (3): torch::nn::BatchNorm(features=5, eps=1e-05, momentum=0.1, affine=true, stateful=true)\n"
      "  (4): torch::nn::Embedding(count=4, dimension=10)\n"
      "  (5): torch::nn::LSTM(input_size=4, hidden_size=5, layers=1, dropout=0)\n"
      ")");
}