Caffe2 - C++ API
A deep learning, cross platform ML framework
any.cpp
1 #include <gtest/gtest.h>
2 
3 #include <torch/nn/modules/any.h>
4 #include <torch/nn/modules/linear.h>
5 #include <torch/utils.h>
6 
7 #include <test/cpp/api/support.h>
8 
9 #include <algorithm>
10 #include <string>
11 
12 using namespace torch::nn;
13 using namespace torch::detail;
14 
16 
17 TEST_F(AnyModuleTest, SimpleReturnType) {
18  struct M : torch::nn::Module {
19  int forward() {
20  return 123;
21  }
22  };
23  AnyModule any(M{});
24  ASSERT_EQ(any.forward<int>(), 123);
25 }
26 
27 TEST_F(AnyModuleTest, SimpleReturnTypeAndSingleArgument) {
28  struct M : torch::nn::Module {
29  int forward(int x) {
30  return x;
31  }
32  };
33  AnyModule any(M{});
34  ASSERT_EQ(any.forward<int>(5), 5);
35 }
36 
37 TEST_F(AnyModuleTest, StringLiteralReturnTypeAndArgument) {
38  struct M : torch::nn::Module {
39  const char* forward(const char* x) {
40  return x;
41  }
42  };
43  AnyModule any(M{});
44  ASSERT_EQ(any.forward<const char*>("hello"), std::string("hello"));
45 }
46 
47 TEST_F(AnyModuleTest, StringReturnTypeWithConstArgument) {
48  struct M : torch::nn::Module {
49  std::string forward(int x, const double f) {
50  return std::to_string(static_cast<int>(x + f));
51  }
52  };
53  AnyModule any(M{});
54  int x = 4;
55  ASSERT_EQ(any.forward<std::string>(x, 3.14), std::string("7"));
56 }
57 
58 TEST_F(
60  TensorReturnTypeAndStringArgumentsWithFunkyQualifications) {
61  struct M : torch::nn::Module {
62  torch::Tensor forward(
63  std::string a,
64  const std::string& b,
65  std::string&& c) {
66  const auto s = a + b + c;
67  return torch::ones({static_cast<int64_t>(s.size())});
68  }
69  };
70  AnyModule any(M{});
71  ASSERT_TRUE(
72  any.forward(std::string("a"), std::string("ab"), std::string("abc"))
73  .sum()
74  .item<int32_t>() == 6);
75 }
76 
77 TEST_F(AnyModuleTest, WrongArgumentType) {
78  struct M : torch::nn::Module {
79  int forward(float x) {
80  return x;
81  }
82  };
83  AnyModule any(M{});
84  ASSERT_THROWS_WITH(
85  any.forward(5.0),
86  "Expected argument #0 to be of type float, "
87  "but received value of type double");
88 }
89 
90 TEST_F(AnyModuleTest, WrongNumberOfArguments) {
91  struct M : torch::nn::Module {
92  int forward(int a, int b) {
93  return a + b;
94  }
95  };
96  AnyModule any(M{});
97  ASSERT_THROWS_WITH(
98  any.forward(),
99  "M's forward() method expects 2 arguments, but received 0");
100  ASSERT_THROWS_WITH(
101  any.forward(5),
102  "M's forward() method expects 2 arguments, but received 1");
103  ASSERT_THROWS_WITH(
104  any.forward(1, 2, 3),
105  "M's forward() method expects 2 arguments, but received 3");
106 }
107 
109  explicit M(int value_) : torch::nn::Module("M"), value(value_) {}
110  int value;
111  int forward(float x) {
112  return x;
113  }
114 };
115 
116 TEST_F(AnyModuleTest, GetWithCorrectTypeSucceeds) {
117  AnyModule any(M{5});
118  ASSERT_EQ(any.get<M>().value, 5);
119 }
120 
121 TEST_F(AnyModuleTest, GetWithIncorrectTypeThrows) {
122  struct N : torch::nn::Module {
123  torch::Tensor forward(torch::Tensor input) {
124  return input;
125  }
126  };
127  AnyModule any(M{5});
128  ASSERT_THROWS_WITH(any.get<N>(), "Attempted to cast module");
129 }
130 
131 TEST_F(AnyModuleTest, PtrWithBaseClassSucceeds) {
132  AnyModule any(M{5});
133  auto ptr = any.ptr();
134  ASSERT_NE(ptr, nullptr);
135  ASSERT_EQ(ptr->name(), "M");
136 }
137 
138 TEST_F(AnyModuleTest, PtrWithGoodDowncastSuccceeds) {
139  AnyModule any(M{5});
140  auto ptr = any.ptr<M>();
141  ASSERT_NE(ptr, nullptr);
142  ASSERT_EQ(ptr->value, 5);
143 }
144 
145 TEST_F(AnyModuleTest, PtrWithBadDowncastThrows) {
146  struct N : torch::nn::Module {
147  torch::Tensor forward(torch::Tensor input) {
148  return input;
149  }
150  };
151  AnyModule any(M{5});
152  ASSERT_THROWS_WITH(any.ptr<N>(), "Attempted to cast module");
153 }
154 
155 TEST_F(AnyModuleTest, DefaultStateIsEmpty) {
156  struct M : torch::nn::Module {
157  explicit M(int value_) : value(value_) {}
158  int value;
159  int forward(float x) {
160  return x;
161  }
162  };
163  AnyModule any;
164  ASSERT_TRUE(any.is_empty());
165  any = std::make_shared<M>(5);
166  ASSERT_FALSE(any.is_empty());
167  ASSERT_EQ(any.get<M>().value, 5);
168 }
169 
170 TEST_F(AnyModuleTest, AllMethodsThrowForEmptyAnyModule) {
171  struct M : torch::nn::Module {
172  int forward(int x) {
173  return x;
174  }
175  };
176  AnyModule any;
177  ASSERT_TRUE(any.is_empty());
178  ASSERT_THROWS_WITH(any.get<M>(), "Cannot call get() on an empty AnyModule");
179  ASSERT_THROWS_WITH(any.ptr<M>(), "Cannot call ptr() on an empty AnyModule");
180  ASSERT_THROWS_WITH(any.ptr(), "Cannot call ptr() on an empty AnyModule");
181  ASSERT_THROWS_WITH(
182  any.type_info(), "Cannot call type_info() on an empty AnyModule");
183  ASSERT_THROWS_WITH(
184  any.forward<int>(5), "Cannot call forward() on an empty AnyModule");
185 }
186 
187 TEST_F(AnyModuleTest, CanMoveAssignDifferentModules) {
188  struct M : torch::nn::Module {
189  std::string forward(int x) {
190  return std::to_string(x);
191  }
192  };
193  struct N : torch::nn::Module {
194  int forward(float x) {
195  return 3 + x;
196  }
197  };
198  AnyModule any;
199  ASSERT_TRUE(any.is_empty());
200  any = std::make_shared<M>();
201  ASSERT_FALSE(any.is_empty());
202  ASSERT_EQ(any.forward<std::string>(5), "5");
203  any = std::make_shared<N>();
204  ASSERT_FALSE(any.is_empty());
205  ASSERT_EQ(any.forward<int>(5.0f), 8);
206 }
207 
208 TEST_F(AnyModuleTest, ConstructsFromModuleHolder) {
209  struct MImpl : torch::nn::Module {
210  explicit MImpl(int value_) : torch::nn::Module("M"), value(value_) {}
211  int value;
212  int forward(float x) {
213  return x;
214  }
215  };
216 
217  struct M : torch::nn::ModuleHolder<MImpl> {
220  };
221 
222  AnyModule any(M{5});
223  ASSERT_EQ(any.get<MImpl>().value, 5);
224  ASSERT_EQ(any.get<M>()->value, 5);
225 
226  AnyModule module(Linear(3, 4));
227  std::shared_ptr<Module> ptr = module.ptr();
228  Linear linear(module.get<Linear>());
229 }
230 
231 TEST_F(AnyModuleTest, ConvertsVariableToTensorCorrectly) {
232  struct M : torch::nn::Module {
233  torch::Tensor forward(torch::Tensor input) {
234  return input;
235  }
236  };
237 
238  // When you have an autograd::Variable, it should be converted to a
239  // torch::Tensor before being passed to the function (to avoid a type
240  // mismatch).
241  AnyModule any(M{});
242  ASSERT_TRUE(
243  any.forward(torch::autograd::Variable(torch::ones(5)))
244  .sum()
245  .item<float>() == 5);
246  // at::Tensors that are not variables work too.
247  ASSERT_EQ(any.forward(at::ones(5)).sum().item<float>(), 5);
248 }
249 
250 namespace torch {
251 namespace nn {
252 struct TestValue {
253  template <typename T>
254  explicit TestValue(T&& value) : value_(std::forward<T>(value)) {}
255  AnyModule::Value operator()() {
256  return std::move(value_);
257  }
258  AnyModule::Value value_;
259 };
260 template <typename T>
261 AnyModule::Value make_value(T&& value) {
262  return TestValue(std::forward<T>(value))();
263 }
264 } // namespace nn
265 } // namespace torch
266 
268 
269 TEST_F(AnyValueTest, CorrectlyAccessesIntWhenCorrectType) {
270  auto value = make_value(5);
271  // const and non-const types have the same typeid()
272  ASSERT_NE(value.try_get<int>(), nullptr);
273  ASSERT_NE(value.try_get<const int>(), nullptr);
274  ASSERT_EQ(value.get<int>(), 5);
275 }
276 TEST_F(AnyValueTest, CorrectlyAccessesConstIntWhenCorrectType) {
277  auto value = make_value(5);
278  ASSERT_NE(value.try_get<const int>(), nullptr);
279  ASSERT_NE(value.try_get<int>(), nullptr);
280  ASSERT_EQ(value.get<const int>(), 5);
281 }
282 TEST_F(AnyValueTest, CorrectlyAccessesStringLiteralWhenCorrectType) {
283  auto value = make_value("hello");
284  ASSERT_NE(value.try_get<const char*>(), nullptr);
285  ASSERT_EQ(value.get<const char*>(), std::string("hello"));
286 }
287 TEST_F(AnyValueTest, CorrectlyAccessesStringWhenCorrectType) {
288  auto value = make_value(std::string("hello"));
289  ASSERT_NE(value.try_get<std::string>(), nullptr);
290  ASSERT_EQ(value.get<std::string>(), "hello");
291 }
292 TEST_F(AnyValueTest, CorrectlyAccessesPointersWhenCorrectType) {
293  std::string s("hello");
294  std::string* p = &s;
295  auto value = make_value(p);
296  ASSERT_NE(value.try_get<std::string*>(), nullptr);
297  ASSERT_EQ(*value.get<std::string*>(), "hello");
298 }
299 TEST_F(AnyValueTest, CorrectlyAccessesReferencesWhenCorrectType) {
300  std::string s("hello");
301  const std::string& t = s;
302  auto value = make_value(t);
303  ASSERT_NE(value.try_get<std::string>(), nullptr);
304  ASSERT_EQ(value.get<std::string>(), "hello");
305 }
306 
307 TEST_F(AnyValueTest, TryGetReturnsNullptrForTheWrongType) {
308  auto value = make_value(5);
309  ASSERT_NE(value.try_get<int>(), nullptr);
310  ASSERT_EQ(value.try_get<float>(), nullptr);
311  ASSERT_EQ(value.try_get<long>(), nullptr);
312  ASSERT_EQ(value.try_get<std::string>(), nullptr);
313 }
314 
315 TEST_F(AnyValueTest, GetThrowsForTheWrongType) {
316  auto value = make_value(5);
317  ASSERT_NE(value.try_get<int>(), nullptr);
318  ASSERT_THROWS_WITH(
319  value.get<float>(),
320  "Attempted to cast Value to float, "
321  "but its actual type is int");
322  ASSERT_THROWS_WITH(
323  value.get<long>(),
324  "Attempted to cast Value to long, "
325  "but its actual type is int");
326 }
327 
328 TEST_F(AnyValueTest, MoveConstructionIsAllowed) {
329  auto value = make_value(5);
330  auto copy = make_value(std::move(value));
331  ASSERT_NE(copy.try_get<int>(), nullptr);
332  ASSERT_EQ(copy.get<int>(), 5);
333 }
334 
335 TEST_F(AnyValueTest, MoveAssignmentIsAllowed) {
336  auto value = make_value(5);
337  auto copy = make_value(10);
338  copy = std::move(value);
339  ASSERT_NE(copy.try_get<int>(), nullptr);
340  ASSERT_EQ(copy.get<int>(), 5);
341 }
342 
343 TEST_F(AnyValueTest, TypeInfoIsCorrectForInt) {
344  auto value = make_value(5);
345  ASSERT_EQ(value.type_info().hash_code(), typeid(int).hash_code());
346 }
347 
348 TEST_F(AnyValueTest, TypeInfoIsCorrectForStringLiteral) {
349  auto value = make_value("hello");
350  ASSERT_EQ(value.type_info().hash_code(), typeid(const char*).hash_code());
351 }
352 
353 TEST_F(AnyValueTest, TypeInfoIsCorrectForString) {
354  auto value = make_value(std::string("hello"));
355  ASSERT_EQ(value.type_info().hash_code(), typeid(std::string).hash_code());
356 }
const std::type_info & type_info() const
Returns the type_info object of the contained value.
Definition: any.h:501
Definition: any.cpp:108
std::shared_ptr< Module > ptr() const
Returns a std::shared_ptr whose dynamic type is that of the underlying module.
Definition: any.h:488
T & get()
Attempts to cast the underlying module to the given module type.
Definition: any.h:472
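A simplified implementation of std::any which stores a type-erased object, whose concrete value can be retrieved at runtime by checking whether the typeid() of a requested type matches the typeid() of the stored object.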
Definition: any.h:235
A ModuleHolder is essentially a wrapper around std::shared_ptr<M>, where M is an nn::Module subclass, with convenience constructors for the kinds of construction allowed for modules.
Definition: pimpl.h:26
The base class for all modules in PyTorch.
Definition: module.h:62
Variable A Variable augments a Tensor with the ability to interact in our autograd machinery...
Definition: variable.h:85
Definition: jit_type.h:17
ReturnType forward(ArgumentTypes &&...arguments)
Invokes forward() on the contained module with the given arguments, and casts the returned Value to t...
Definition: any.h:466
Stores a type erased Module.
Definition: any.h:108
bool is_empty() const noexcept
Returns true if the AnyModule does not contain a module.
Definition: any.h:506