1 #include <gtest/gtest.h> 3 #include <torch/nn/modules/any.h> 4 #include <torch/nn/modules/linear.h> 5 #include <torch/utils.h> 7 #include <test/cpp/api/support.h> 24 ASSERT_EQ(any.forward<
int>(), 123);
34 ASSERT_EQ(any.forward<
int>(5), 5);
39 const char* forward(
const char* x) {
44 ASSERT_EQ(any.forward<
const char*>(
"hello"), std::string(
"hello"));
49 std::string forward(
int x,
const double f) {
50 return std::to_string(static_cast<int>(x + f));
55 ASSERT_EQ(any.forward<std::string>(x, 3.14), std::string(
"7"));
60 TensorReturnTypeAndStringArgumentsWithFunkyQualifications) {
66 const auto s = a + b + c;
67 return torch::ones({
static_cast<int64_t
>(s.size())});
72 any.forward(std::string(
"a"), std::string(
"ab"), std::string(
"abc"))
74 .item<int32_t>() == 6);
79 int forward(
float x) {
86 "Expected argument #0 to be of type float, " 87 "but received value of type double");
92 int forward(
int a,
int b) {
99 "M's forward() method expects 2 arguments, but received 0");
102 "M's forward() method expects 2 arguments, but received 1");
104 any.forward(1, 2, 3),
105 "M's forward() method expects 2 arguments, but received 3");
111 int forward(
float x) {
118 ASSERT_EQ(any.get<
M>().value, 5);
128 ASSERT_THROWS_WITH(any.get<N>(),
"Attempted to cast module");
133 auto ptr = any.ptr();
134 ASSERT_NE(ptr,
nullptr);
135 ASSERT_EQ(ptr->name(),
"M");
140 auto ptr = any.ptr<
M>();
141 ASSERT_NE(ptr,
nullptr);
142 ASSERT_EQ(ptr->value, 5);
152 ASSERT_THROWS_WITH(any.ptr<N>(),
"Attempted to cast module");
157 explicit M(
int value_) : value(value_) {}
159 int forward(
float x) {
165 any = std::make_shared<M>(5);
166 ASSERT_FALSE(any.is_empty());
167 ASSERT_EQ(any.get<
M>().value, 5);
178 ASSERT_THROWS_WITH(any.
get<
M>(),
"Cannot call get() on an empty AnyModule");
179 ASSERT_THROWS_WITH(any.
ptr<
M>(),
"Cannot call ptr() on an empty AnyModule");
180 ASSERT_THROWS_WITH(any.
ptr(),
"Cannot call ptr() on an empty AnyModule");
182 any.
type_info(),
"Cannot call type_info() on an empty AnyModule");
184 any.
forward<
int>(5),
"Cannot call forward() on an empty AnyModule");
189 std::string forward(
int x) {
190 return std::to_string(x);
194 int forward(
float x) {
200 any = std::make_shared<M>();
201 ASSERT_FALSE(any.is_empty());
202 ASSERT_EQ(any.forward<std::string>(5),
"5");
203 any = std::make_shared<N>();
204 ASSERT_FALSE(any.is_empty());
205 ASSERT_EQ(any.forward<
int>(5.0f), 8);
212 int forward(
float x) {
223 ASSERT_EQ(any.get<MImpl>().value, 5);
224 ASSERT_EQ(any.get<
M>()->value, 5);
227 std::shared_ptr<Module> ptr = module.
ptr();
228 Linear linear(module.
get<Linear>());
245 .item<
float>() == 5);
247 ASSERT_EQ(any.forward(at::ones(5)).sum().item<
float>(), 5);
253 template <
typename T>
254 explicit TestValue(
T&& value) : value_(std::forward<T>(value)) {}
256 return std::move(value_);
260 template <
typename T>
262 return TestValue(std::forward<T>(value))();
269 TEST_F(
AnyValueTest, CorrectlyAccessesIntWhenCorrectType) {
270 auto value = make_value(5);
272 ASSERT_NE(value.try_get<
int>(),
nullptr);
273 ASSERT_NE(value.try_get<
const int>(),
nullptr);
274 ASSERT_EQ(value.get<
int>(), 5);
276 TEST_F(
AnyValueTest, CorrectlyAccessesConstIntWhenCorrectType) {
277 auto value = make_value(5);
278 ASSERT_NE(value.try_get<
const int>(),
nullptr);
279 ASSERT_NE(value.try_get<
int>(),
nullptr);
280 ASSERT_EQ(value.get<
const int>(), 5);
282 TEST_F(
AnyValueTest, CorrectlyAccessesStringLiteralWhenCorrectType) {
283 auto value = make_value(
"hello");
284 ASSERT_NE(value.try_get<
const char*>(),
nullptr);
285 ASSERT_EQ(value.get<
const char*>(), std::string(
"hello"));
287 TEST_F(
AnyValueTest, CorrectlyAccessesStringWhenCorrectType) {
288 auto value = make_value(std::string(
"hello"));
289 ASSERT_NE(value.try_get<std::string>(),
nullptr);
290 ASSERT_EQ(value.get<std::string>(),
"hello");
292 TEST_F(
AnyValueTest, CorrectlyAccessesPointersWhenCorrectType) {
293 std::string s(
"hello");
295 auto value = make_value(p);
296 ASSERT_NE(value.try_get<std::string*>(),
nullptr);
297 ASSERT_EQ(*value.get<std::string*>(),
"hello");
299 TEST_F(
AnyValueTest, CorrectlyAccessesReferencesWhenCorrectType) {
300 std::string s(
"hello");
301 const std::string& t = s;
302 auto value = make_value(t);
303 ASSERT_NE(value.try_get<std::string>(),
nullptr);
304 ASSERT_EQ(value.get<std::string>(),
"hello");
307 TEST_F(
AnyValueTest, TryGetReturnsNullptrForTheWrongType) {
308 auto value = make_value(5);
309 ASSERT_NE(value.try_get<
int>(),
nullptr);
310 ASSERT_EQ(value.try_get<
float>(),
nullptr);
311 ASSERT_EQ(value.try_get<
long>(),
nullptr);
312 ASSERT_EQ(value.try_get<std::string>(),
nullptr);
316 auto value = make_value(5);
317 ASSERT_NE(value.try_get<
int>(),
nullptr);
320 "Attempted to cast Value to float, " 321 "but its actual type is int");
324 "Attempted to cast Value to long, " 325 "but its actual type is int");
329 auto value = make_value(5);
330 auto copy = make_value(std::move(value));
331 ASSERT_NE(copy.try_get<
int>(),
nullptr);
332 ASSERT_EQ(copy.get<
int>(), 5);
336 auto value = make_value(5);
337 auto copy = make_value(10);
338 copy = std::move(value);
339 ASSERT_NE(copy.try_get<
int>(),
nullptr);
340 ASSERT_EQ(copy.get<
int>(), 5);
344 auto value = make_value(5);
345 ASSERT_EQ(value.type_info().hash_code(),
typeid(int).hash_code());
348 TEST_F(
AnyValueTest, TypeInfoIsCorrectForStringLiteral) {
349 auto value = make_value(
"hello");
350 ASSERT_EQ(value.type_info().hash_code(),
typeid(
const char*).hash_code());
354 auto value = make_value(std::string(
"hello"));
355 ASSERT_EQ(value.type_info().hash_code(),
typeid(std::string).hash_code());
const std::type_info & type_info() const
Returns the type_info object of the contained value.
std::shared_ptr< Module > ptr() const
Returns a std::shared_ptr whose dynamic type is that of the underlying module.
T & get()
Attempts to cast the underlying module to the given module type.
A simplified implementation of std::any which stores a type erased object, whose concrete value can b...
A ModuleHolder is essentially a wrapper around std::shared_ptr<M> where M is an nn::Module subclass...
The base class for all modules in PyTorch.
Variable A Variable augments a Tensor with the ability to interact in our autograd machinery...
ReturnType forward(ArgumentTypes &&...arguments)
Invokes forward() on the contained module with the given arguments, and casts the returned Value to t...
Stores a type erased Module.
bool is_empty() const noexcept
Returns true if the AnyModule does not contain a module.