Caffe2 - C++ API
A deep learning, cross-platform ML framework
dead_code_elimination.cpp
1 #include <torch/csrc/jit/passes/dead_code_elimination.h>
2 
3 #include <torch/csrc/jit/ir_views.h>
4 #include <torch/csrc/jit/passes/alias_analysis.h>
5 #include <torch/csrc/utils/memory.h>
6 
7 #include <unordered_map>
8 
9 namespace torch {
10 namespace jit {
11 
// Nested inside torch::jit: pulls every name from ::c10::prim into this
// namespace so the code below can write prim::Loop, prim::If, prim::GradOf
// and prim::Reverse.
12 namespace prim {
13 using namespace ::c10::prim;
14 }
15 
17  public:
// Builds an eliminator with alias information: an AliasDb is constructed from
// `graph` and stored in aliasDb_.
18  explicit DeadCodeEliminator(std::shared_ptr<Graph> graph)
19  : aliasDb_(torch::make_unique<AliasDb>(std::move(graph))) {}
// Builds an eliminator without alias information: aliasDb_ keeps its null
// default, so markIfLive skips the alias check and hasUntrackedMutation uses
// the schema-based fallback.
20  DeadCodeEliminator() = default;
21 
22  // The algorithm is an inverse mark-and-sweep. Starting from the return node,
23  // we mark "live" nodes that are necessary for the output. Nodes that have
24  // side effects are also marked.
25  void run(Block* block, bool recurse) {
26  // Initialize by marking the return node and all its consumed values as live
27  mark(block->return_node());
28 
29  mark(block);
30 
31  deleteCallback_(liveValues_);
32 
33  sweep(block, recurse);
34  }
35 
36  void setDeleteCallback(
37  std::function<void(const std::unordered_set<const Value*>&)>
38  deleteCallback) {
39  deleteCallback_ = std::move(deleteCallback);
40  }
41 
42  private:
43  // Special handling for block return nodes. Unlike other nodes, the block
44  // return node doesn't really "use" its inputs. Consider:
45  //
46  // %a0 = aten::foo()
47  // %b = aten::foo()
48  // %a2, %b2 = prim::If(%cond) {
49  // block0() {
50  // %a1 = aten::foo(%a0)
51  // %b1 = aten::foo(%b)
52  // } -> (%a1, %b1)
53  // }
54  // return (%a2)
55  //
56  // We want to be able to DCE all the %b stuff. So when processing block
57  // returns, we only mark producers for values that are "live" (i.e. used
58  // outside the block).
59  void markReturnNode(Node* node) {
60  if (marked_.count(node)) {
61  return;
62  }
63 
64  AT_ASSERT(node->owningBlock()->return_node() == node);
65  auto outerNode = node->owningBlock()->owningNode();
66  if (outerNode == nullptr || outerNode->kind() == prim::Reverse) {
67  // If there's no outer node, we're looking at the graph's top-level
68  // return block. We consider all graph outputs to be "used", so just mark
69  // this node normally.
70  return mark(node);
71  }
72 
73  // Collect all inputs that are actually live
74  if (outerNode->kind() == prim::Loop ||
75  outerNode->kind() == c10::onnx::Loop) {
76  // Special handling to deal with loop carried dependencies.
77  auto loop = LoopView(outerNode);
78  for (size_t i = 0; i < loop.carriedOutputs().size(); i++) {
79  auto innerInput = loop.bodyCarriedInputs().at(i);
80  auto innerOutput = loop.bodyCarriedOutputs().at(i);
81  auto outerOutput = loop.carriedOutputs().at(i);
82  if (liveValues_.count(outerOutput) || innerInput->hasUses()) {
83  liveValues_.insert(innerOutput);
84  }
85  }
86 
87  // Also mark the loop next condition as live, since it will be used inside
88  // the loop body.
89  liveValues_.insert(loop.nextCond());
90  } else {
91  AT_ASSERT(outerNode->outputs().size() == node->inputs().size());
92  for (size_t i = 0; i < outerNode->outputs().size(); i++) {
93  auto innerOutput = node->inputs()[i];
94  auto outerOutput = outerNode->outputs()[i];
95  if (liveValues_.count(outerOutput)) {
96  liveValues_.insert(innerOutput);
97  }
98  }
99  }
100 
101  marked_.insert(node);
102  }
103 
104  void mark(Block* block) {
105  // Mark all nodes with side effects.
106  for (auto node : block->nodes()) {
107  if (hasSideEffects(node)) {
108  mark(node);
109  }
110  }
111 
112  // Initialize by marking the return node
113  markReturnNode(block->return_node());
114 
115  for (auto it = block->nodes().rbegin(); it != block->nodes().rend(); ++it) {
116  auto node = *it;
117  for (auto subBlock : node->blocks()) {
118  mark(subBlock);
119  }
120  markIfLive(node);
121  }
122  }
123 
124  // If we output or write to a live memory location, mark this node
125  void markIfLive(Node* node) {
126  for (const auto output : node->outputs()) {
127  if (liveValues_.count(output)) {
128  return mark(node);
129  }
130  }
131 
132  if (aliasDb_) {
133  if (aliasDb_->writesToAlias(node, liveValues_, /*recurseBlocks=*/false)) {
134  return mark(node);
135  }
136  }
137  }
138 
139  // Mark this node as live and add this node's inputs and aliases to the live
140  // value sets.
141  void mark(Node* node) {
142  if (marked_.count(node)) {
143  return;
144  }
145 
146  marked_.insert(node);
147 
148  // Mark all nodes in this node's blockchain (since owning nodes are
149  // considered live if they contain a live node)
150  auto curNode = node;
151  while (curNode) {
152  if (!curNode->owningBlock()) {
153  break;
154  }
155 
156  mark(curNode);
157  curNode = curNode->owningBlock()->owningNode();
158  }
159 
160  for (const auto input : node->inputs()) {
161  if (liveValues_.count(input)) {
162  continue;
163  }
164  liveValues_.insert(input);
165  }
166  }
167 
168  // Delete all unmarked nodes.
169  void sweep(Block* block, bool recurse) {
170  auto nodes = block->nodes().reverse();
171  for (auto it = nodes.begin(); it != nodes.end(); it++) {
172  auto node = *it;
173  // note these occur before the recursion because we want to uncover
174  // dead code in the blocks used to calculate the output
175  removeDeadBlockOutputs(node);
176  removeDeadLoopOutputs(node);
177  if (recurse) {
178  for (Block* block : node->blocks()) {
179  sweep(block, true);
180  }
181  }
182  // NB: Checking hasUses() is required. AD graphs are not perfectly
183  // valid, as a node in grad_desc.f might be used in reverse_block.
184  // Reverse_block is inlined in grad_desc.f before it's separated
185  // to grad_desc.df.
186  if (!(marked_.count(node) || node->hasUses())) {
187  it.destroyCurrent();
188  }
189  }
190  }
191 
192  bool hasUntrackedMutation(Node* node) {
193  if (!aliasDb_) {
194  // If we don't have alias information, all mutable ops have unknown
195  // effects and can't be considered for elimination.
196  if (!node->kind().is_aten() && !node->kind().is_prim()) {
197  return false;
198  }
199  // onnx export calls EliminateDeadCode but sometimes passes invalid
200  // aten operators. So we call maybeSchema so we handle the cases when
201  // there is no valid schema for a node
202  auto schema = node->maybeSchema();
203  return schema && schema->is_mutable();
204  } else {
205  return aliasDb_->hasUntrackedEffects(node);
206  }
207  }
208 
209  bool hasSideEffects(Node* node) {
210  auto it = memo_.find(node);
211  if (it != memo_.end())
212  return it->second;
213  bool has_side_effects = node->hasSideEffects() ||
214  std::any_of(node->blocks().begin(),
215  node->blocks().end(),
216  [&](Block* b) {
217  return std::any_of(
218  b->nodes().begin(), b->nodes().end(), [&](Node* n) {
219  return hasSideEffects(n);
220  });
221  }) ||
222  hasUntrackedMutation(node);
223 
224  memo_.emplace(node, has_side_effects);
225  return has_side_effects;
226  }
227 
228  void removeDeadBlockOutputs(Node* node) {
229  if (node->kind() != prim::If && node->kind() != prim::GradOf) {
230  return;
231  }
232 
233  for (size_t i_1 = node->outputs().size(); i_1 > 0; --i_1) {
234  size_t i = i_1 - 1;
235  if (!node->outputs().at(i)->hasUses()) {
236  node->eraseOutput(i);
237  for (Block* b : node->blocks()) {
238  b->eraseOutput(i);
239  }
240  }
241  }
242  }
243 
244  void removeDeadLoopOutputs(Node* node) {
245  if (node->kind() != prim::Loop)
246  return;
247  auto loop_body = node->blocks().at(0);
248  auto loop_input_offset = 2; // offset of loop carried deps in input list
249  auto loop_body_offset =
250  1; // offset to the loop carried dependencies in block inputs/outputs
251 
252  for (size_t i_1 = node->outputs().size(); i_1 > 0; --i_1) {
253  size_t i = i_1 - 1;
254  if (!node->outputs().at(i)->hasUses() &&
255  !loop_body->inputs().at(loop_body_offset + i)->hasUses()) {
256  node->eraseOutput(i);
257  node->removeInput(loop_input_offset + i);
258  loop_body->eraseInput(loop_body_offset + i);
259  loop_body->eraseOutput(loop_body_offset + i);
260  }
261  }
262  }
263 
// Alias information; set by the constructor that takes a graph, null
// otherwise.
264  std::unique_ptr<AliasDb> aliasDb_ = nullptr;
// Cache of hasSideEffects() results, keyed by node.
265  std::unordered_map<Node*, bool> memo_;
// Nodes already processed by mark(Node*) or markReturnNode(); sweep() keeps
// every node found here.
266  std::unordered_set<Node*> marked_;
// Values found to be live during marking.
267  std::unordered_set<const Value*> liveValues_;
// Called by run() with liveValues_ after marking and before sweeping.
// Defaults to a lambda that does nothing.
268  std::function<void(const std::unordered_set<const Value*>&)> deleteCallback_ =
269  [](const std::unordered_set<const Value*>&) {};
270 };
271 
272 void EliminateDeadCode(const std::shared_ptr<Graph>& graph) {
273  DeadCodeEliminator(graph).run(graph->block(), /*recurse=*/true);
274 }
275 
276 void EliminateDeadCode(Block* block, bool recurse) {
277  DeadCodeEliminator().run(block, recurse);
278 }
279 
280 void EliminateDeadCode(
281  Block* block,
282  std::function<void(const std::unordered_set<const Value*>&)> cb) {
283  DeadCodeEliminator eliminator;
284  eliminator.setDeleteCallback(std::move(cb));
285  eliminator.run(block, /*recurse=*/true);
286 }
287 
288 } // namespace jit
289 } // namespace torch
Definition: jit_type.h:17