Caffe2 - C++ API
A deep learning, cross-platform ML framework
common_subexpression_elimination.cpp
1 #include <torch/csrc/jit/passes/common_subexpression_elimination.h>
2 
3 #include <torch/csrc/jit/ir.h>
4 #include <torch/csrc/jit/node_hashing.h>
5 #include <torch/csrc/jit/passes/alias_analysis.h>
6 
7 #include <unordered_map>
8 
9 namespace torch {
10 namespace jit {
11 namespace {
12 // The function implements common subexpression elimination.
13 // Since the nodes are visited in topological order, one pass is enough.
14 void EliminateCommonSubexpression(
15  Block* block,
16  const AliasDb& aliasDb,
17  std::function<Node*(Node*)> parent_lookup_fn) {
18  std::unordered_set<Node*, HashNode, EqualNode> subexprs;
19  for (auto it = block->nodes().begin(); it != block->nodes().end(); ++it) {
20  auto node = *it;
21  if (node->hasSideEffects() || node->isNondeterministic() ||
22  aliasDb.hasWriters(node)) {
23  // Do NOT have enough information to do CSE on these nodes.
24  continue;
25  }
26 
27  if (!node->blocks().empty()) {
28  // Traverse sub-blocks.
29  for (auto block : node->blocks()) {
30  EliminateCommonSubexpression(block, aliasDb, [&](Node* n) {
31  auto existing = subexprs.find(n);
32  if (existing != subexprs.end()) {
33  return *existing;
34  }
35 
36  return parent_lookup_fn(n);
37  });
38  }
39 
40  continue;
41  }
42 
43  // Check for CSE opportunities in the parent block.
44  auto parent_lookup = parent_lookup_fn(node);
45  if (parent_lookup) {
46  node->replaceAllUsesWith(parent_lookup);
47  it.destroyCurrent();
48  continue;
49  }
50 
51  // Check whether the same subexpression already exists.
52  auto subit = subexprs.insert(node);
53  if (!subit.second) {
54  // Subexpression exists, replace the uses of node, and destroy it.
55  auto existing = *subit.first;
56  node->replaceAllUsesWith(existing);
57  // Destroy the node.
58  it.destroyCurrent();
59  }
60  }
61 }
62 } // namespace
63 
64 void EliminateCommonSubexpression(std::shared_ptr<Graph>& graph) {
65  AliasDb aliasDb(graph);
66  EliminateCommonSubexpression(
67  graph->block(), aliasDb, [](Node*) { return nullptr; });
68 }
69 } // namespace jit
70 } // namespace torch
Definition: jit_type.h:17