1 #include <torch/csrc/jit/passes/loop_unrolling.h> 3 #include <c10/util/Exception.h> 4 #include <ATen/core/interned_strings.h> 5 #include <torch/csrc/jit/symbolic_variable.h> 7 #include <torch/csrc/jit/constants.h> 8 #include <torch/csrc/jit/passes/dead_code_elimination.h> 15 static constexpr int64_t kUnrollFactor = 8;
16 static constexpr int64_t kMaxBodySize = 32;
17 static constexpr int64_t kMaxBodyRepeats = 64;
19 bool isTrueConstant(Value* val) {
21 return maybe_value && *maybe_value;
24 bool isForLoop(Node* node) {
25 if (node->kind() != prim::Loop)
27 Value* start_cond = node->inputs().at(1);
28 Value* continue_cond = node->blocks().at(0)->outputs().at(0);
29 return isTrueConstant(start_cond) && isTrueConstant(continue_cond);
34 int64_t limitedBlockSize(Block* body, int64_t limit) {
35 auto it = body->nodes().begin();
36 auto end = body->nodes().end();
37 for (int64_t i = 0; i < limit; ++i, ++it) {
38 for (Block* subblock : it->blocks()) {
39 i += limitedBlockSize(subblock, limit - i);
48 bool isSmallBlock(Block* body) {
49 return limitedBlockSize(body, kMaxBodySize + 1) <= kMaxBodySize;
54 void inlineBody(Node* loop) {
55 auto graph = loop->owningGraph();
56 auto body = loop->blocks().at(0);
57 WithInsertPoint insert_point_guard{loop};
59 std::unordered_map<Value*, Value*> value_map;
60 auto get_value = [&](Value* v) {
61 auto it = value_map.find(v);
62 if (it != value_map.end())
69 for (
size_t i = 2; i < loop->inputs().size(); ++i) {
70 value_map[body->inputs()[i - 1]] = loop->inputs()[i];
73 for (Node* orig : body->nodes()) {
74 Node* clone = graph->insertNode(graph->createClone(orig, get_value));
75 for (
size_t i = 0; i < orig->outputs().size(); ++i) {
76 value_map[orig->outputs()[i]] = clone->outputs()[i];
79 for (
size_t i = 0; i < loop->outputs().size(); ++i) {
80 loop->outputs().at(i)->replaceAllUsesWith(
81 get_value(body->outputs().at(i + 1)));
89 void repeatBody(Block* body, int64_t times) {
94 auto body_start = body->nodes().begin();
95 auto body_end = std::prev(body->nodes().end());
96 auto graph = body->owningGraph();
97 WithInsertPoint insert_point_guard{body};
99 std::unordered_map<Value*, Value*> value_map;
100 auto get_value = [&](Value* v) {
101 auto it = value_map.find(v);
102 if (it != value_map.end())
107 for (int64_t i = 1; i < times; ++i) {
111 AT_ASSERT(body->inputs().size() == body->outputs().size());
112 for (
size_t i = 1; i < body->inputs().size(); ++i) {
113 value_map[body->inputs()[i]] = get_value(body->outputs()[i]);
117 for (
auto it = body_start; it != std::next(body_end); ++it) {
119 Node* clone = graph->insertNode(graph->createClone(orig, get_value));
120 for (
size_t i = 0; i < orig->outputs().size(); ++i) {
121 value_map[orig->outputs()[i]] = clone->outputs()[i];
127 const std::vector<Value*> new_outputs = fmap(body->outputs(), get_value);
128 for (int64_t i = new_outputs.size() - 1; i >= 0; --i) {
129 body->eraseOutput(i);
131 for (Value* output : new_outputs) {
132 body->registerOutput(output);
139 EliminateDeadCode(body,
false);
144 void replaceLoopCounter(Node* loop) {
145 Graph* graph = loop->owningGraph();
146 Block* body = loop->blocks().at(0);
147 WithInsertPoint guard(loop);
148 Value* init_counter = graph->insertConstant(0);
150 loop->insertInput(2, init_counter);
151 loop->insertOutput(0)->setType(IntType::get());
153 Value* internal_counter = body->insertInput(1)->setType(init_counter->type());
154 body->inputs()[0]->replaceAllUsesWith(internal_counter);
156 WithInsertPoint insertPointGuard{body->return_node()};
157 Value* result = graph->insert(aten::add, {internal_counter, 1});
158 body->insertOutput(1, result);
161 void unroll(Node* loop) {
162 Graph* graph = loop->owningGraph();
163 Block* body = loop->blocks().at(0);
164 if (!isSmallBlock(body))
171 if (body->inputs()[0]->uses().size() > 0)
172 replaceLoopCounter(loop);
176 Value* trip_count = loop->inputs().at(0);
177 int64_t const_len = constant_as<int64_t>(trip_count).value_or(-1);
178 if (const_len != -1 && const_len < kMaxBodyRepeats) {
179 repeatBody(body, const_len);
184 WithInsertPoint insert_point_guard{loop};
187 Node* loop_epilogue =
188 graph->createClone(loop, [](Value* v) {
return v; })->insertAfter(loop);
189 for (
size_t i = 0; i < loop->outputs().size(); ++i) {
190 loop->outputs()[i]->replaceAllUsesWith(loop_epilogue->outputs()[i]);
191 loop_epilogue->replaceInput(i + 2, loop->outputs()[i]);
194 repeatBody(body, kUnrollFactor);
197 Value* iter_count = loop->inputs().at(0);
198 Value* unrolled_iter_count = graph->insert(
199 aten::__round_to_zero_floordiv, {iter_count, kUnrollFactor});
200 loop->replaceInput(0, unrolled_iter_count);
201 loop_epilogue->replaceInput(
206 graph->insert(aten::mul, {unrolled_iter_count, kUnrollFactor})}));
209 void UnrollLoops(Block* block) {
210 for (
auto it = block->nodes().begin(); it != block->nodes().end();) {
215 for (Block* subblock : node->blocks()) {
216 UnrollLoops(subblock);
218 if (isForLoop(node)) {
226 void UnrollLoops(std::shared_ptr<Graph>& graph) {
227 UnrollLoops(graph->block());
228 EliminateDeadCode(graph);