1 #include <gtest/gtest.h>     3 #include <torch/nn/modules/any.h>     4 #include <torch/nn/modules/linear.h>     5 #include <torch/utils.h>     7 #include <test/cpp/api/support.h>    24   ASSERT_EQ(any.forward<
int>(), 123);
    34   ASSERT_EQ(any.forward<
int>(5), 5);
    39     const char* forward(
const char* x) {
    44   ASSERT_EQ(any.forward<
const char*>(
"hello"), std::string(
"hello"));
    49     std::string forward(
int x, 
const double f) {
    50       return std::to_string(static_cast<int>(x + f));
    55   ASSERT_EQ(any.forward<std::string>(x, 3.14), std::string(
"7"));
    60     TensorReturnTypeAndStringArgumentsWithFunkyQualifications) {
    66       const auto s = a + b + c;
    67       return torch::ones({
static_cast<int64_t
>(s.size())});
    72       any.forward(std::string(
"a"), std::string(
"ab"), std::string(
"abc"))
    74           .item<int32_t>() == 6);
    79     int forward(
float x) {
    86       "Expected argument #0 to be of type float, "    87       "but received value of type double");
    92     int forward(
int a, 
int b) {
    99       "M's forward() method expects 2 arguments, but received 0");
   102       "M's forward() method expects 2 arguments, but received 1");
   104       any.forward(1, 2, 3),
   105       "M's forward() method expects 2 arguments, but received 3");
   111   int forward(
float x) {
   118   ASSERT_EQ(any.get<
M>().value, 5);
   128   ASSERT_THROWS_WITH(any.get<N>(), 
"Attempted to cast module");
   133   auto ptr = any.ptr();
   134   ASSERT_NE(ptr, 
nullptr);
   135   ASSERT_EQ(ptr->name(), 
"M");
   140   auto ptr = any.ptr<
M>();
   141   ASSERT_NE(ptr, 
nullptr);
   142   ASSERT_EQ(ptr->value, 5);
   152   ASSERT_THROWS_WITH(any.ptr<N>(), 
"Attempted to cast module");
   157     explicit M(
int value_) : value(value_) {}
   159     int forward(
float x) {
   165   any = std::make_shared<M>(5);
   166   ASSERT_FALSE(any.is_empty());
   167   ASSERT_EQ(any.get<
M>().value, 5);
   178   ASSERT_THROWS_WITH(any.
get<
M>(), 
"Cannot call get() on an empty AnyModule");
   179   ASSERT_THROWS_WITH(any.
ptr<
M>(), 
"Cannot call ptr() on an empty AnyModule");
   180   ASSERT_THROWS_WITH(any.
ptr(), 
"Cannot call ptr() on an empty AnyModule");
   182       any.
type_info(), 
"Cannot call type_info() on an empty AnyModule");
   184       any.
forward<
int>(5), 
"Cannot call forward() on an empty AnyModule");
   189     std::string forward(
int x) {
   190       return std::to_string(x);
   194     int forward(
float x) {
   200   any = std::make_shared<M>();
   201   ASSERT_FALSE(any.is_empty());
   202   ASSERT_EQ(any.forward<std::string>(5), 
"5");
   203   any = std::make_shared<N>();
   204   ASSERT_FALSE(any.is_empty());
   205   ASSERT_EQ(any.forward<
int>(5.0f), 8);
   212     int forward(
float x) {
   223   ASSERT_EQ(any.get<MImpl>().value, 5);
   224   ASSERT_EQ(any.get<
M>()->value, 5);
   227   std::shared_ptr<Module> ptr = module.
ptr();
   228   Linear linear(module.
get<Linear>());
   245           .item<
float>() == 5);
   247   ASSERT_EQ(any.forward(at::ones(5)).sum().item<
float>(), 5);
   253   template <
typename T>
   254   explicit TestValue(
T&& value) : value_(std::forward<T>(value)) {}
   256     return std::move(value_);
   260 template <
typename T>
   262   return TestValue(std::forward<T>(value))();
   269 TEST_F(
AnyValueTest, CorrectlyAccessesIntWhenCorrectType) {
   270   auto value = make_value(5);
   272   ASSERT_NE(value.try_get<
int>(), 
nullptr);
   273   ASSERT_NE(value.try_get<
const int>(), 
nullptr);
   274   ASSERT_EQ(value.get<
int>(), 5);
   276 TEST_F(
AnyValueTest, CorrectlyAccessesConstIntWhenCorrectType) {
   277   auto value = make_value(5);
   278   ASSERT_NE(value.try_get<
const int>(), 
nullptr);
   279   ASSERT_NE(value.try_get<
int>(), 
nullptr);
   280   ASSERT_EQ(value.get<
const int>(), 5);
   282 TEST_F(
AnyValueTest, CorrectlyAccessesStringLiteralWhenCorrectType) {
   283   auto value = make_value(
"hello");
   284   ASSERT_NE(value.try_get<
const char*>(), 
nullptr);
   285   ASSERT_EQ(value.get<
const char*>(), std::string(
"hello"));
   287 TEST_F(
AnyValueTest, CorrectlyAccessesStringWhenCorrectType) {
   288   auto value = make_value(std::string(
"hello"));
   289   ASSERT_NE(value.try_get<std::string>(), 
nullptr);
   290   ASSERT_EQ(value.get<std::string>(), 
"hello");
   292 TEST_F(
AnyValueTest, CorrectlyAccessesPointersWhenCorrectType) {
   293   std::string s(
"hello");
   295   auto value = make_value(p);
   296   ASSERT_NE(value.try_get<std::string*>(), 
nullptr);
   297   ASSERT_EQ(*value.get<std::string*>(), 
"hello");
   299 TEST_F(
AnyValueTest, CorrectlyAccessesReferencesWhenCorrectType) {
   300   std::string s(
"hello");
   301   const std::string& t = s;
   302   auto value = make_value(t);
   303   ASSERT_NE(value.try_get<std::string>(), 
nullptr);
   304   ASSERT_EQ(value.get<std::string>(), 
"hello");
   307 TEST_F(
AnyValueTest, TryGetReturnsNullptrForTheWrongType) {
   308   auto value = make_value(5);
   309   ASSERT_NE(value.try_get<
int>(), 
nullptr);
   310   ASSERT_EQ(value.try_get<
float>(), 
nullptr);
   311   ASSERT_EQ(value.try_get<
long>(), 
nullptr);
   312   ASSERT_EQ(value.try_get<std::string>(), 
nullptr);
   316   auto value = make_value(5);
   317   ASSERT_NE(value.try_get<
int>(), 
nullptr);
   320       "Attempted to cast Value to float, "   321       "but its actual type is int");
   324       "Attempted to cast Value to long, "   325       "but its actual type is int");
   329   auto value = make_value(5);
   330   auto copy = make_value(std::move(value));
   331   ASSERT_NE(copy.try_get<
int>(), 
nullptr);
   332   ASSERT_EQ(copy.get<
int>(), 5);
   336   auto value = make_value(5);
   337   auto copy = make_value(10);
   338   copy = std::move(value);
   339   ASSERT_NE(copy.try_get<
int>(), 
nullptr);
   340   ASSERT_EQ(copy.get<
int>(), 5);
   344   auto value = make_value(5);
   345   ASSERT_EQ(value.type_info().hash_code(), 
typeid(int).hash_code());
   348 TEST_F(
AnyValueTest, TypeInfoIsCorrectForStringLiteral) {
   349   auto value = make_value(
"hello");
   350   ASSERT_EQ(value.type_info().hash_code(), 
typeid(
const char*).hash_code());
   354   auto value = make_value(std::string(
"hello"));
   355   ASSERT_EQ(value.type_info().hash_code(), 
typeid(std::string).hash_code());
 const std::type_info & type_info() const 
Returns the type_info object of the contained value. 
std::shared_ptr< Module > ptr() const 
Returns a std::shared_ptr whose dynamic type is that of the underlying module. 
T & get()
Attempts to cast the underlying module to the given module type. 
A simplified implementation of std::any which stores a type erased object, whose concrete value can be retrieved at runtime.
A ModuleHolder is essentially a wrapper around std::shared_ptr<M> where M is an nn::Module subclass.
The base class for all modules in PyTorch. 
Variable: A Variable augments a Tensor with the ability to interact with our autograd machinery.
ReturnType forward(ArgumentTypes &&...arguments)
Invokes forward() on the contained module with the given arguments, and casts the returned Value to the requested ReturnType.
Stores a type erased Module. 
bool is_empty() const noexcept
Returns true if the AnyModule does not contain a module.