Caffe2 - C++ API
A deep learning, cross-platform ML framework
loop_unrolling.cpp
1 #include <torch/csrc/jit/passes/loop_unrolling.h>
2 
3 #include <c10/util/Exception.h>
4 #include <ATen/core/interned_strings.h>
5 #include <torch/csrc/jit/symbolic_variable.h>
6 
7 #include <torch/csrc/jit/constants.h>
8 #include <torch/csrc/jit/passes/dead_code_elimination.h>
9 
10 namespace torch {
11 namespace jit {
12 
13 namespace {
14 
// Number of copies of the body placed in a partially unrolled loop; the
// leftover iterations are run by the epilogue loop.
constexpr int64_t kUnrollFactor{8};
// Loops whose body (nested blocks included) has more nodes than this are not
// unrolled.
constexpr int64_t kMaxBodySize{32};
// Loops with a constant trip count below this bound are unrolled completely.
constexpr int64_t kMaxBodyRepeats{64};
18 
19 bool isTrueConstant(Value* val) {
20  c10::optional<bool> maybe_value = constant_as<bool>(val);
21  return maybe_value && *maybe_value;
22 }
23 
24 bool isForLoop(Node* node) {
25  if (node->kind() != prim::Loop)
26  return false;
27  Value* start_cond = node->inputs().at(1);
28  Value* continue_cond = node->blocks().at(0)->outputs().at(0);
29  return isTrueConstant(start_cond) && isTrueConstant(continue_cond);
30 }
31 
32 // Counts the size of this block, stopping and returning once reaches limit
33 // instructions.
34 int64_t limitedBlockSize(Block* body, int64_t limit) {
35  auto it = body->nodes().begin();
36  auto end = body->nodes().end();
37  for (int64_t i = 0; i < limit; ++i, ++it) {
38  for (Block* subblock : it->blocks()) {
39  i += limitedBlockSize(subblock, limit - i);
40  }
41  if (it == end) {
42  return i;
43  }
44  }
45  return limit;
46 }
47 
48 bool isSmallBlock(Block* body) {
49  return limitedBlockSize(body, kMaxBodySize + 1) <= kMaxBodySize;
50 }
51 
52 // XXX: This function can only be called with a loop that is guaranteed to
53 // execute EXACTLY ONCE.
54 void inlineBody(Node* loop) {
55  auto graph = loop->owningGraph();
56  auto body = loop->blocks().at(0);
57  WithInsertPoint insert_point_guard{loop};
58 
59  std::unordered_map<Value*, Value*> value_map;
60  auto get_value = [&](Value* v) {
61  auto it = value_map.find(v);
62  if (it != value_map.end())
63  return it->second;
64  return v;
65  };
66 
67  // Loop node has extra (max_iters, initial_cond) inputs,
68  // body has an extra (loop_counter) input.
69  for (size_t i = 2; i < loop->inputs().size(); ++i) {
70  value_map[body->inputs()[i - 1]] = loop->inputs()[i];
71  }
72 
73  for (Node* orig : body->nodes()) {
74  Node* clone = graph->insertNode(graph->createClone(orig, get_value));
75  for (size_t i = 0; i < orig->outputs().size(); ++i) {
76  value_map[orig->outputs()[i]] = clone->outputs()[i];
77  }
78  }
79  for (size_t i = 0; i < loop->outputs().size(); ++i) {
80  loop->outputs().at(i)->replaceAllUsesWith(
81  get_value(body->outputs().at(i + 1)));
82  }
83  // XXX: it is extremely important to destroy the loop in here. DCE might not
84  // be able to conclude that it's safe, because the loop might contain side
85  // effects.
86  loop->destroy();
87 }
88 
89 void repeatBody(Block* body, int64_t times) {
90  // We will be adding nodes to the body, so cache the initial start and end.
91  // XXX: they are both inclusive, because the exclusive body_end would point to
92  // return_node, which would move further away if we were to add nodes,
93  // and we would enter an infinite loop.
94  auto body_start = body->nodes().begin();
95  auto body_end = std::prev(body->nodes().end());
96  auto graph = body->owningGraph();
97  WithInsertPoint insert_point_guard{body};
98 
99  std::unordered_map<Value*, Value*> value_map;
100  auto get_value = [&](Value* v) {
101  auto it = value_map.find(v);
102  if (it != value_map.end())
103  return it->second;
104  return v;
105  };
106 
107  for (int64_t i = 1; i < times; ++i) {
108  // Update loop-carried values
109  // NB: note that we don't need to worry about the loop counter, because
110  // we've replaced it with a loop-carried variable
111  AT_ASSERT(body->inputs().size() == body->outputs().size());
112  for (size_t i = 1; i < body->inputs().size(); ++i) {
113  value_map[body->inputs()[i]] = get_value(body->outputs()[i]);
114  }
115 
116  // Clone the nodes
117  for (auto it = body_start; it != std::next(body_end); ++it) {
118  Node* orig = *it;
119  Node* clone = graph->insertNode(graph->createClone(orig, get_value));
120  for (size_t i = 0; i < orig->outputs().size(); ++i) {
121  value_map[orig->outputs()[i]] = clone->outputs()[i];
122  }
123  }
124  }
125 
126  // Update outputs of the body
127  const std::vector<Value*> new_outputs = fmap(body->outputs(), get_value);
128  for (int64_t i = new_outputs.size() - 1; i >= 0; --i) {
129  body->eraseOutput(i);
130  }
131  for (Value* output : new_outputs) {
132  body->registerOutput(output);
133  }
134 
135  // It's likely that we have some dead nodes now - for example the "true"
136  // constant that prevents the loop from breaking. We shouldn't wait too long
137  // before removing them because they might artificially increase the loop size
138  // and prevent outer loop unrolling.
139  EliminateDeadCode(body, false);
140 }
141 
142 // Replaces the builtin loop counter with a "mutable" variable outside of the
143 // loop.
144 void replaceLoopCounter(Node* loop) {
145  Graph* graph = loop->owningGraph();
146  Block* body = loop->blocks().at(0);
147  WithInsertPoint guard(loop);
148  Value* init_counter = graph->insertConstant(0);
149 
150  loop->insertInput(2, init_counter);
151  loop->insertOutput(0)->setType(IntType::get());
152 
153  Value* internal_counter = body->insertInput(1)->setType(init_counter->type());
154  body->inputs()[0]->replaceAllUsesWith(internal_counter);
155 
156  WithInsertPoint insertPointGuard{body->return_node()};
157  Value* result = graph->insert(aten::add, {internal_counter, 1});
158  body->insertOutput(1, result);
159 }
160 
161 void unroll(Node* loop) {
162  Graph* graph = loop->owningGraph();
163  Block* body = loop->blocks().at(0);
164  if (!isSmallBlock(body))
165  return;
166 
167  // We will be using a "mutable" counter outside of the loop instead of the
168  // default one, because this will allow us to share it between the unrolled
169  // loop and its epilogue. This is necessary only if the loop counter is
170  // actually used in the body.
171  if (body->inputs()[0]->uses().size() > 0)
172  replaceLoopCounter(loop);
173 
174  // Some optimization for constant-length loops. If we know they won't run too
175  // many times, then we can unroll them entirely.
176  Value* trip_count = loop->inputs().at(0);
177  int64_t const_len = constant_as<int64_t>(trip_count).value_or(-1);
178  if (const_len != -1 && const_len < kMaxBodyRepeats) {
179  repeatBody(body, const_len);
180  inlineBody(loop);
181  return;
182  }
183 
184  WithInsertPoint insert_point_guard{loop};
185 
186  // Clone the loop before we unroll it. The clone will become the epilogue.
187  Node* loop_epilogue =
188  graph->createClone(loop, [](Value* v) { return v; })->insertAfter(loop);
189  for (size_t i = 0; i < loop->outputs().size(); ++i) {
190  loop->outputs()[i]->replaceAllUsesWith(loop_epilogue->outputs()[i]);
191  loop_epilogue->replaceInput(i + 2, loop->outputs()[i]);
192  }
193 
194  repeatBody(body, kUnrollFactor);
195 
196  // Change the iteration counts of both loops
197  Value* iter_count = loop->inputs().at(0);
198  Value* unrolled_iter_count = graph->insert(
199  aten::__round_to_zero_floordiv, {iter_count, kUnrollFactor});
200  loop->replaceInput(0, unrolled_iter_count);
201  loop_epilogue->replaceInput(
202  0,
203  graph->insert(
204  aten::sub,
205  {iter_count,
206  graph->insert(aten::mul, {unrolled_iter_count, kUnrollFactor})}));
207 }
208 
209 void UnrollLoops(Block* block) {
210  for (auto it = block->nodes().begin(); it != block->nodes().end();) {
211  // XXX: unroll might destroy the current node, so we need to pre-increment
212  // the iterator
213  Node* node = *it;
214  ++it;
215  for (Block* subblock : node->blocks()) {
216  UnrollLoops(subblock);
217  }
218  if (isForLoop(node)) {
219  unroll(node);
220  }
221  }
222 }
223 
224 } // anonymous namespace
225 
226 void UnrollLoops(std::shared_ptr<Graph>& graph) {
227  UnrollLoops(graph->block());
228  EliminateDeadCode(graph);
229 }
230 
231 } // namespace jit
232 } // namespace torch
Definition: jit_type.h:17