Caffe2 - C++ API
A deep learning, cross platform ML framework
stack.h
1 #pragma once
2 
3 #include <ATen/core/ivalue.h>
4 
5 // TODO move this to c10 namespace
6 
7 namespace torch {
8 namespace jit {
9 
// Make c10::IValue nameable as plain IValue inside torch::jit.
10 using c10::IValue;
// The operand stack: a std::vector of IValue whose back() is the top of stack.
11 using Stack = std::vector<IValue>;
// An operation is any callable that takes the stack by mutable reference and
// returns an int.
12 using Operation = std::function<int(Stack&)>;
13 
14 // An operation with N inputs and M outputs pops the last N inputs off
15 // the stack and pushes its M outputs onto the stack
16 // before: <other stack items> I0, I1, ... IN <- stack.back()
17 // after: <other stack items> O0, O1, ... OM
18 // operations are defined this way so that ownership of inputs can be
19 // transferred to the operation and it can incrementally drop ownership of
20 // tensors when they become unneeded. For large operations, like 'run an entire
21 // subgraph', this functionality is very important for minimizing GPU memory
22 // usage. The return value is the relative 'offset' to jump to for the next
23 // operation:
24 // pc += 1 + offset
25 // so a return value of 0 goes to the next instruction
26 
27 // treat the last N elements of the stack as a list, looking up
28 // element i
29 static inline IValue& peek(Stack& stack, size_t i, size_t N) {
30  return *(stack.end() - N + i);
31 }
32 static inline const IValue& peek(const Stack& stack, size_t i, size_t N) {
33  return *(stack.end() - N + i);
34 }
35 // treat the last N elements of the stack as a list, looking up the
36 // slice starting at index i and having length len
37 static inline at::ArrayRef<IValue> peekSlice(
38  const Stack& stack,
39  size_t i,
40  size_t len,
41  size_t N) {
42  return at::ArrayRef<IValue>(stack).slice(stack.size() - N + i, len);
43 }
44 static inline at::ArrayRef<IValue> last(const Stack& stack, size_t N) {
45  return peekSlice(stack, 0, N, N);
46 }
47 static inline void drop(Stack& stack, size_t n) {
48  stack.erase(stack.end() - n, stack.end());
49 }
50 static inline IValue pop(Stack& stack) {
51  auto r = std::move(stack.back());
52  stack.pop_back();
53  return r;
54 }
55 static inline std::vector<IValue> pop(Stack& stack, size_t n) {
56  std::vector<IValue> result;
57  result.reserve(n);
58  for (size_t i = 0; i < n; ++i) {
59  result.push_back(std::move(peek(stack, i, n)));
60  }
61  drop(stack, n);
62  return result;
63 }
64 
65 // variadic pop:
66 // int64_t a; at::Tensor b;
67 // pop(stack, a, b);
68 // equivalent to:
69 // b = pop(stack).toTensor();
70 // a = pop(stack).toInt();
71 template <typename... Types>
72 static inline void pop(Stack& stack, Types&... args) {
73  size_t i = 0;
74  constexpr size_t N = sizeof...(args);
75  int result[N] = {
76  (args = std::move(peek(stack, i++, N)).template to<Types>(), 0)...};
77  (void)result;
78  drop(stack, N);
79 }
80 template <typename... Types>
81 static inline void push(Stack& stack, Types&&... args) {
82  (void)std::initializer_list<int>{(stack.emplace_back(std::forward<Types>(args)), 0)...};
83 }
84 
85 // The packer here is carefully written not to make any unnecessary
86 // copies.
87 
88 // pack takes the return values of aten functions pushes them onto the stack
89 template <typename T>
90 inline void pack(Stack& stack, T&& v) {
91  stack.emplace_back(std::forward<T>(v));
92 }
93 
94 template <std::size_t remaining, typename... Args>
95 struct TuplePacker {
96  // NB: *Not* a universal reference.
97  static void execute(Stack& stack, std::tuple<Args...>&& t) {
98  // NB: The move here does not "destroy" the entire tuple, that is
99  // not what std::move does; only the particular tuple index
100  // processed here gets stolen.
101  pack(stack, std::get<sizeof...(Args) - remaining>(std::move(t)));
103  }
104 };
105 
106 template <typename... Args>
107 struct TuplePacker<0, Args...> {
108  static void execute(Stack& stack, std::tuple<Args...>&& t){};
109 };
110 
111 template <typename... Args>
112 inline void pack(Stack& stack, std::tuple<Args...>&& t) {
113  TuplePacker<sizeof...(Args), Args...>::execute(stack, std::move(t));
114 }
115 
116 } // namespace jit
117 } // namespace torch
Definition: jit_type.h:17
ArrayRef - Represent a constant reference to an array (0 or more elements consecutively in memory)...
Definition: ArrayRef.h:41