4 #include <torch/csrc/autograd/variable.h> 21 template <
typename... Args>
28 template <
typename T,
typename... Args>
29 inline F& apply(
T&& arg, Args&&... args) {
30 self()(std::forward<T>(arg));
31 if (
self().short_circuit()) {
34 return apply(std::forward<Args>(args)...);
51 for (
const auto& arg : args) {
61 void operator()(
const std::vector<T>& args) {
65 bool short_circuit() {
71 return *
static_cast<F*
>(
this);
85 template <
typename... Args>
86 size_t count_tensors(Args&&... args) {
87 return CountTensors().apply(std::forward<Args>(args)...).out;
100 template <
typename... Args>
101 inline size_t count_variables(Args&&... args) {
110 template <
size_t... Is>
115 template <
size_t N,
size_t... Is>
122 template <
size_t... Is>
// Shorthand for `typename std::enable_if<value, T>::type`.
//
// When `value` is true the alias names `T` (which defaults to `void`).
// When `value` is false `std::enable_if` has no nested `type`, so the alias is
// ill-formed; used in a template signature this removes the template from
// overload resolution (SFINAE) instead of producing a hard error.
template <bool value, typename T = void>
using enable_if_t = typename std::enable_if<value, T>::type;
// Logical inverse of an enable-if: names `T` (default `void`) only when
// `value` is false. When `value` is true the alias is ill-formed, which
// disables the enclosing template through SFINAE.
//
// Spelled directly in terms of `std::enable_if` so the alias does not rely on
// any other helper being declared before it; the resulting type is the same
// as `enable_if_t<!value, T>`.
template <bool value, typename T = void>
using disable_if_t = typename std::enable_if<!value, T>::type;
// Shorthand for `typename std::decay<T>::type`: strips references and
// top-level cv-qualifiers, and converts array and function types to the
// corresponding pointer types (the same adjustments applied to a by-value
// function argument).
template <typename T>
using decay_t = typename std::decay<T>::type;
145 template <
bool... values>
147 detail::pack<values..., true>,
148 detail::pack<true, values...>> {};
156 template <
bool head,
bool... tail>
158 static constexpr
bool value = head ||
any_of<tail...>::value;
161 template <
bool... values>
163 static constexpr
bool value = !
any_of<values...>::value;
166 template <
bool... values>
167 using enable_if_all_of_t = enable_if_t<
all_of<values...>::value>;
169 template <
typename T,
typename... Ts>
170 using disable_if_contains_t =
171 enable_if_all_of_t<(!std::is_same<T, decay_t<Ts>>::value)...>;
173 template <
typename Function,
typename... Ts>
174 void apply(Function
function, Ts&&... ts) {
181 int _[]{0, (
function(std::forward<Ts>(ts)), 0)...};
185 template <
typename ReturnType,
typename... Ts,
typename Function,
typename Accessor>
186 ReturnType unpack(Function
function, Accessor accessor) {
187 return ReturnType(unpack<ReturnType, Ts...>(
193 template <
typename ReturnType,
typename... Ts,
typename Function,
typename Accessor,
size_t... Is>
194 ReturnType unpack(Function
function, Accessor accessor,
Indices<Is...>) {
195 return ReturnType(
function(accessor.template
operator()<Ts>(Is)...));
constexpr size_t size() const
size - Get the array size.
Variable - A Variable augments a Tensor with the ability to interact with our autograd machinery...
ArrayRef - Represent a constant reference to an array (0 or more elements consecutively in memory)...