Caffe2 - C++ API
A deep learning, cross platform ML framework
create_autodiff_subgraphs.cpp
1 #include <torch/csrc/jit/passes/create_autodiff_subgraphs.h>
2 
3 #include <c10/util/Exception.h>
4 #include <torch/csrc/jit/autodiff.h>
5 #include <torch/csrc/jit/ir.h>
6 #include <torch/csrc/jit/passes/alias_analysis.h>
7 #include <torch/csrc/jit/passes/common_subexpression_elimination.h>
8 #include <torch/csrc/jit/passes/utils/subgraph_utils.h>
9 
10 namespace torch {
11 namespace jit {
12 
13 namespace {
14 
15 class SubgraphSlicer {
16  public:
17  SubgraphSlicer(
18  Block* block,
19  std::shared_ptr<Graph> graph,
20  size_t minSubgraphSize)
21  : block_(block),
22  graph_(std::move(graph)),
23  minSubgraphSize_(minSubgraphSize) {}
24 
25  void run(std::vector<Node*>& diffGraphs) {
26  // We need to run the slicer multiple times in order to get all merge
27  // opportunities. This is because moveBeforeTopologicalValid may reorder
28  // nodes to be AFTER the current iteration point. In order to properly
29  // consider those nodes for merging, we need run the pass until no changes
30  // have been made.
31  //
32  // Example:
33  // c = f(a, b)
34  // d = f(c)
35  // e = f(d) <- iter is here, moving upward
36  // After c.moveBeforeTopologicallyValid(e), we have:
37  // c = f(a, b)
38  // e = f(d) <- iter still here
39  // d = f(c) <- this was node moved on the other side.
40  bool any_changed = true;
41  while (any_changed) {
42  any_changed = false;
43  AliasDb aliasDb(graph_);
44  for (auto it = block_->nodes().rbegin(); it != block_->nodes().rend();) {
45  bool changed;
46  std::tie(it, changed) = scanNode(*it, aliasDb);
47  any_changed |= changed;
48  }
49  }
50 
51  // Done constructing subgraphs. Do some post-processing cleanup:
52  // 1. Run CSE to delete redundanet constant nodes.
53  // 2. We may need to re-inline ones that are too small.
54  auto curNode = *block_->nodes().rbegin();
55  while (curNode != *block_->nodes().rend()) {
56  for (auto subBlock : curNode->blocks()) {
57  SubgraphSlicer(subBlock, graph_, minSubgraphSize_).run(diffGraphs);
58  }
59 
60  // Save the previous node, since we might delete `curNode` in next block
61  auto prevNode = curNode->prev();
62  if (curNode->kind() == prim::DifferentiableGraph) {
63  // Inlining nodes may cause some subexpression to come back in the
64  // subgraphs (for example, copying constants in repeatedly will generate
65  // redundant prim::Constants). Run CSE to clean them up.
66  EliminateCommonSubexpression(curNode->g(attr::Subgraph));
67 
68  if (!inlineIfTooSmall(curNode)) {
69  diffGraphs.push_back(curNode);
70  }
71  }
72  curNode = prevNode;
73  }
74  // Run CSE one more time to eliminate duplicates that may have occured
75  // while re-inlining subgraphs.
76  EliminateCommonSubexpression(graph_);
77  }
78 
79  private:
80  // Inline this node's group subgraph into the outer graph if it's smaller
81  // than the specified minimum size.
82  //
83  // Returns true if an inlining has occured, false otherwise.
84  bool inlineIfTooSmall(Node* n) {
85  AT_ASSERT(n->kind() == prim::DifferentiableGraph);
86  auto subgraph = SubgraphUtils::getSubgraph(n);
87  size_t i = 0;
88  for (auto it = subgraph->nodes().begin(); it != subgraph->nodes().end();
89  ++it) {
90  if (++i >= minSubgraphSize_) {
91  return false;
92  }
93  }
94 
95  SubgraphUtils::unmergeSubgraph(n);
96  return true;
97  }
98 
99  value_list sortReverseTopological(ArrayRef<Value*> inputs) {
100  value_list result;
101  for (auto i : inputs) {
102  if (i->node()->owningBlock() == block_) {
103  result.push_back(i);
104  }
105  }
106  // Sort in reverse topological order
107  std::sort(result.begin(), result.end(), [&](Value* a, Value* b) {
108  return a->node()->isAfter(b->node());
109  });
110  return result;
111  }
112 
113  bool shouldConsiderForMerge(Node* node) {
114  // if we're already in the process of merging
115  if (node->kind() == prim::DifferentiableGraph) {
116  return true;
117  }
118  if (node->kind() == prim::Constant) {
119  return false;
120  }
121  return isDifferentiable(node);
122  }
123 
124  std::pair<graph_node_list::iterator, bool> scanNode(
125  Node* consumer,
126  AliasDb& aliasDb) {
127  if (shouldConsiderForMerge(consumer)) {
128  if (consumer->kind() != prim::DifferentiableGraph) {
129  consumer = SubgraphUtils::createSingletonSubgraph(
130  consumer, prim::DifferentiableGraph);
131  }
132  auto inputs = sortReverseTopological(consumer->inputs());
133  for (auto input : inputs) {
134  if (auto group = tryMerge(consumer, input->node(), aliasDb)) {
135  // we successfully merged, so the new group's `inputs` may have
136  // changed. So rescan the new group for more merging opportunities.
137  return std::make_pair(group.value()->reverseIterator(), true);
138  }
139  }
140  }
141 
142  return std::make_pair(++consumer->reverseIterator(), false);
143  }
144 
145  // Try to merge `producer` into `consumer`. If successful, this destroys
146  // `producer` and returns the `consumer` group.
147  c10::optional<Node*> tryMerge(
148  Node* consumer,
149  Node* producer,
150  AliasDb& aliasDb) {
151  AT_ASSERT(consumer->kind() == prim::DifferentiableGraph);
152  bool canMerge = shouldConsiderForMerge(producer) &&
153  aliasDb.moveBeforeTopologicallyValid(producer, consumer);
154 
155  if (!canMerge) {
156  return c10::nullopt;
157  }
158 
159  SubgraphUtils::mergeNodeIntoSubgraph(producer, consumer);
160 
161  return consumer;
162  }
163 
164  Block* block_;
165  std::shared_ptr<Graph> graph_;
166  size_t minSubgraphSize_;
167 };
168 } // anonymous namespace
169 
170 std::vector<Node*> CreateAutodiffSubgraphs(
171  const std::shared_ptr<Graph>& graph,
172  size_t threshold) {
173  std::vector<Node*> diff_nodes;
174  SubgraphSlicer(graph->block(), graph, threshold).run(diff_nodes);
175  return diff_nodes;
176 }
177 
178 } // namespace jit
179 } // namespace torch
Definition: jit_type.h:17