Caffe2 - C++ API
A deep learning, cross platform ML framework
variadic.h
1 #pragma once
2 
3 #include <ATen/ATen.h>
4 #include <torch/csrc/autograd/variable.h>
5 
6 #include <cstdint>
7 #include <tuple>
8 #include <type_traits>
9 #include <utility>
10 
11 namespace torch {
12 
13 // This class allows you to write variadic functions which
14 // call a (possibly overloaded) function on each argument,
15 // in order. This is most commonly used in autogenerated code,
16 // where it is convenient to have a function that can uniformly
17 // take arguments of different types. If your arguments
18 // are homogenous consider using a std::initializer_list instead.
19 template <typename F>
20 struct IterArgs {
21  template <typename... Args>
22  inline F& apply() {
23  return self();
24  }
25 
26  // NB: Use perfect forwarding here, otherwise we'll make value
27  // copies of all arguments!
28  template <typename T, typename... Args>
29  inline F& apply(T&& arg, Args&&... args) {
30  self()(std::forward<T>(arg));
31  if (self().short_circuit()) {
32  return self();
33  } else {
34  return apply(std::forward<Args>(args)...);
35  }
36  }
37 
38  // Here are some handy overloads which provide sensible
39  // defaults for container-like structures that one might
40  // be interested in recursing into. You can enable them
41  // by adding:
42  //
43  // using IterArgs<YourStructName>::operator()
44  //
45  // to your struct. These are not enabled by default because
46  // you may be able to process these structures more efficiently
47  // than handling them one-by-one.
48 
49  template <typename T>
50  void operator()(at::ArrayRef<T> args) {
51  for (const auto& arg : args) {
52  self()(arg);
53  if (short_circuit())
54  return;
55  }
56  }
57 
58  // NB: we need to specify std::vector manually as C++ won't
59  // do an implicit conversion to make a template deduction go through.
60  template <typename T>
61  void operator()(const std::vector<T>& args) {
62  self()(at::ArrayRef<T>{args});
63  }
64 
65  bool short_circuit() {
66  return false;
67  }
68 
69  private:
70  inline F& self() {
71  return *static_cast<F*>(this);
72  }
73 };
74 
75 struct CountTensors : IterArgs<CountTensors> {
76  size_t out = 0;
77  void operator()(const at::Tensor& x) {
78  out += 1;
79  }
80  void operator()(at::ArrayRef<at::Tensor> xs) {
81  out += xs.size();
82  }
83 };
84 
85 template <typename... Args>
86 size_t count_tensors(Args&&... args) {
87  return CountTensors().apply(std::forward<Args>(args)...).out;
88 }
89 
90 struct CountVariables : IterArgs<CountVariables> {
91  size_t out = 0;
92  void operator()(const autograd::Variable& x) {
93  out += 1;
94  }
95  void operator()(at::ArrayRef<autograd::Variable> xs) {
96  out += xs.size();
97  }
98 };
99 
100 template <typename... Args>
101 inline size_t count_variables(Args&&... args) {
102  return CountVariables().apply(std::forward<Args>(args)...).out;
103 }
104 
105 //===----------------------------------------------------------------------===//
106 // std::index_sequence shim for C++11
107 //===----------------------------------------------------------------------===//
108 
// Indices<Is...> carries a compile-time list of size_t values in its template
// arguments. It has no members and is used purely as a tag type.
template <size_t... Is>
struct Indices {};

// MakeIndices<N>::indices names Indices<0, 1, ..., N-1>.
//
// Recursive step: decrement the counter and prepend its new value to the
// values accumulated so far. Each step inherits from the next one, so the
// `indices` alias declared by the base case is reachable from MakeIndices<N>.
template <size_t Remaining, size_t... Accumulated>
struct MakeIndices
    : MakeIndices<Remaining - 1, Remaining - 1, Accumulated...> {};

// Base case: the counter has reached zero, so the accumulated pack holds
// exactly 0 through (original N) - 1, in ascending order.
template <size_t... Accumulated>
struct MakeIndices<0, Accumulated...> {
  using indices = Indices<Accumulated...>;
};
126 
127 //===----------------------------------------------------------------------===//
128 // Utilities
129 //===----------------------------------------------------------------------===//
130 
// Stand-ins for the C++14 std::enable_if_t / std::decay_t aliases.

// Names `Result` (void by default) when `Cond` is true; otherwise the
// substitution fails, which removes the enclosing template from overload
// resolution (SFINAE).
template <bool Cond, typename Result = void>
using enable_if_t = typename std::enable_if<Cond, Result>::type;

// The inverse of enable_if_t: names `Result` only when `Cond` is false.
template <bool Cond, typename Result = void>
using disable_if_t = typename std::enable_if<!Cond, Result>::type;

// The type produced by applying std::decay to T.
template <typename T>
using decay_t = typename std::decay<T>::type;
139 
namespace detail {
// Never defined: only used to carry a bool pack so that two packs can be
// compared as types with std::is_same.
template <bool...>
struct pack;
} // namespace detail

// all_of<values...>: true iff every value is true (true for an empty pack).
// Shifting the pack by one position against a `true` sentinel yields the
// same sequence exactly when every element is `true`.
template <bool... values>
struct all_of : std::is_same<
                    detail::pack<values..., true>,
                    detail::pack<true, values...>> {};

// any_of<values...>: true iff at least one value is true (false for an empty
// pack). "Some value is true" is "not all values are false".
//
// Like all_of, this is a std::integral_constant, so it also exposes ::type,
// value_type and a constexpr conversion to bool. It inherits `value` instead
// of declaring its own in-class static constexpr member, which under C++11/14
// would need an out-of-line definition whenever it is odr-used (e.g. bound
// to a const bool&). The evaluation is also non-recursive.
template <bool... values>
struct any_of : std::integral_constant<bool, !all_of<!values...>::value> {};

// none_of<values...>: true iff no value is true (true for an empty pack).
template <bool... values>
struct none_of : std::integral_constant<bool, !any_of<values...>::value> {};
165 
166 template <bool... values>
167 using enable_if_all_of_t = enable_if_t<all_of<values...>::value>;
168 
169 template <typename T, typename... Ts>
170 using disable_if_contains_t =
171  enable_if_all_of_t<(!std::is_same<T, decay_t<Ts>>::value)...>;
172 
// Calls `function` once on each of `ts`, strictly left to right, forwarding
// each argument with its original value category. Any value `function`
// returns is discarded.
//
// Without fold expressions (pre-C++17), the calls are sequenced by expanding
// them inside a braced array initializer, whose initializer-clauses are
// evaluated in order. See
// https://stackoverflow.com/questions/13978916/inserting-a-variadic-argument-list-into-a-vector
// Each element is `((void)function(t), 0)`: the call runs, its result is
// discarded, and 0 becomes the array element. The cast to void ensures the
// built-in comma operator is used even when the result type overloads
// operator, (which would otherwise hijack the expression). The leading 0
// keeps the array non-empty when `ts` is empty.
template <typename Function, typename... Ts>
void apply(Function function, Ts&&... ts) {
  int dummy[]{0, ((void)function(std::forward<Ts>(ts)), 0)...};
  (void)dummy;
}
184 
185 template <typename ReturnType, typename... Ts, typename Function, typename Accessor>
186 ReturnType unpack(Function function, Accessor accessor) {
187  return ReturnType(unpack<ReturnType, Ts...>(
188  std::move(function),
189  std::move(accessor),
190  typename MakeIndices<sizeof...(Ts)>::indices()));
191 }
192 
193 template <typename ReturnType, typename... Ts, typename Function, typename Accessor, size_t... Is>
194 ReturnType unpack(Function function, Accessor accessor, Indices<Is...>) {
195  return ReturnType(function(accessor.template operator()<Ts>(Is)...));
196 }
197 } // namespace torch
constexpr size_t size() const
size - Get the array size.
Definition: ArrayRef.h:138
Variable A Variable augments a Tensor with the ability to interact in our autograd machinery...
Definition: variable.h:85
Definition: jit_type.h:17
ArrayRef - Represent a constant reference to an array (0 or more elements consecutively in memory)...
Definition: ArrayRef.h:41