Caffe2 - C++ API
A deep learning, cross platform ML framework
static.cpp
1 #include <gtest/gtest.h>
2 
3 #include <torch/detail/static.h>
4 #include <torch/nn/module.h>
5 #include <torch/nn/modules/any.h>
6 #include <torch/nn/modules/linear.h>
7 
8 #include <torch/csrc/utils/variadic.h>
9 
10 #include <string>
11 #include <vector>
12 
13 template <
14  typename T,
15  typename = torch::enable_if_t<!torch::detail::is_module<T>::value>>
16 bool f(T&& m) {
17  return false;
18 }
19 
20 template <typename T>
21 torch::detail::enable_if_module_t<T, bool> f(T&& m) {
22  return true;
23 }
24 
25 TEST(TestStatic, AllOf) {
26  ASSERT_TRUE(torch::all_of<>::value);
27  ASSERT_TRUE(torch::all_of<true>::value);
29  ASSERT_FALSE(torch::all_of<false>::value);
32 }
33 
34 TEST(TestStatic, AnyOf) {
35  ASSERT_FALSE(torch::any_of<>::value);
36  ASSERT_TRUE(bool((torch::any_of<true>::value)));
37  ASSERT_TRUE(bool((torch::any_of<true, true, true>::value)));
38  ASSERT_FALSE(bool((torch::any_of<false>::value)));
39 }
40 
41 TEST(TestStatic, EnableIfModule) {
42  ASSERT_TRUE(f(torch::nn::LinearImpl(1, 2)));
43  ASSERT_FALSE(f(5));
44  ASSERT_TRUE(torch::detail::check_not_lvalue_references<int>());
45  ASSERT_TRUE((torch::detail::check_not_lvalue_references<float, int, char>()));
46  ASSERT_FALSE(
47  (torch::detail::check_not_lvalue_references<float, int&, char>()));
48  ASSERT_TRUE(torch::detail::check_not_lvalue_references<std::string>());
49  ASSERT_FALSE(torch::detail::check_not_lvalue_references<std::string&>());
50 }
51 
52 struct A : torch::nn::Module {
53  int forward() {
54  return 5;
55  }
56 };
57 
58 struct B : torch::nn::Module {
59  std::string forward(torch::Tensor tensor) {
60  return "";
61  }
62 };
63 
64 struct C : torch::nn::Module {
65  float forward(torch::Tensor& tensor) {
66  return 5.0;
67  }
68 };
69 
70 struct D : torch::nn::Module {
71  char forward(torch::Tensor&& tensor) {
72  return 'x';
73  }
74 };
75 
76 struct E : torch::nn::Module {};
77 
78 // Put in a function because macros don't handle the comma between arguments to
79 // is_same well ...
80 template <typename Module, typename ExpectedType, typename... Args>
81 void assert_has_expected_type() {
82  using ReturnType =
83  typename torch::detail::return_type_of_forward<Module, Args...>::type;
84  constexpr bool is_expected_type =
85  std::is_same<ReturnType, ExpectedType>::value;
86  ASSERT_TRUE(is_expected_type) << Module().name();
87 }
88 
89 TEST(TestStatic, ReturnTypeOfForward) {
90  assert_has_expected_type<A, int>();
91  assert_has_expected_type<B, std::string, torch::Tensor>();
92  assert_has_expected_type<C, float, torch::Tensor&>();
93  assert_has_expected_type<D, char, torch::Tensor&&>();
94  assert_has_expected_type<E, void>();
95 }
96 
97 TEST(TestStatic, Apply) {
98  std::vector<int> v;
99  torch::apply([&v](int x) { v.push_back(x); }, 1, 2, 3, 4, 5);
100  ASSERT_EQ(v.size(), 5);
101  for (size_t i = 0; i < v.size(); ++i) {
102  ASSERT_EQ(v.at(i), i + 1);
103  }
104 }
struct E: Definition: static.cpp:76
torch::nn::LinearImpl: Applies a linear transformation with optional bias.
Definition: linear.h:25
struct A: Definition: static.cpp:52
torch::nn::Module: The base class for all modules in PyTorch.
Definition: module.h:62
struct C: Definition: static.cpp:64
struct B: Definition: static.cpp:58
struct D: Definition: static.cpp:70