Caffe2 - C++ API
A deep learning, cross platform ML framework
subgraph_utils.cpp
1 #include <torch/csrc/jit/passes/utils/subgraph_utils.h>
2 
3 namespace torch {
4 namespace jit {
5 namespace SubgraphUtils {
6 namespace {
7 
8 bool hasSubgraph(Node* n) {
9  return n->hasAttribute(attr::Subgraph);
10 }
11 
12 // Combine the nodes in two subgraph together. The nodes will end up in
13 // `mergeTo`, and `mergeFrom` is destroyed.
14 void mergeSubgraph(Node* mergeTo, Node* mergeFrom) {
15  Node* nodeBeforeMergeFrom = mergeFrom->prev();
16  Node* nodeAfterMergeFrom = mergeFrom->next();
17  unmergeSubgraph(mergeFrom);
18  std::vector<Node*> nodes;
19  const auto end_it = nodeBeforeMergeFrom->reverseIterator();
20  auto it = nodeAfterMergeFrom->reverseIterator();
21  ++it;
22  while (it != end_it) {
23  // NB: mergeNodeIntoSubgraph destroys node, hence the complications
24  Node* node = *it;
25  ++it;
26  mergeNodeIntoSubgraph(node, mergeTo);
27  }
28 }
29 } // namespace
30 
31 std::shared_ptr<Graph> getSubgraph(Node* n) {
32  return n->g(attr::Subgraph);
33 }
34 
35 void unmergeSubgraph(Node* subgraphNode) {
36  AT_ASSERT(subgraphNode->kind() == prim::DifferentiableGraph);
37 
38  // Inline the graph, replace uses of node outputs and destroy the node
39  const auto subgraphOutputs = inlineGraph(
40  getSubgraph(subgraphNode), subgraphNode->inputs(), subgraphNode);
41  AT_ASSERT(subgraphOutputs.size() >= subgraphNode->outputs().size());
42  for (size_t i = 0; i < subgraphNode->outputs().size(); ++i) {
43  subgraphNode->outputs()[i]->replaceAllUsesWith(subgraphOutputs[i]);
44  }
45  subgraphNode->destroy();
46 }
47 
48 void mergeNodeIntoSubgraph(Node* toMerge, Node* subgraphNode) {
49  AT_ASSERT(hasSubgraph(subgraphNode));
50  if (hasSubgraph(toMerge)) {
51  return mergeSubgraph(subgraphNode, toMerge);
52  }
53 
54  auto subgraph = getSubgraph(subgraphNode);
55 
56  // Map from values in the surrounding graph to inputs in the subgraph
57  std::unordered_map<Value*, Value*> inputsMap;
58 
59  AT_ASSERT(subgraphNode->inputs().size() == subgraph->inputs().size());
60  size_t idx = 0;
61  for (auto input : subgraphNode->inputs()) {
62  inputsMap[input] = subgraph->inputs()[idx];
63  idx++;
64  }
65 
66  // Add n's inputs to the group's input list if we don't already have them
67  WithInsertPoint guard(*subgraph->nodes().begin());
68  for (auto input : toMerge->inputs()) {
69  if (inputsMap.count(input) == 0) {
70  // Clone constants inside the subgraph instead of referencing them, to
71  // enable more optimizations
72  if (auto value = toIValue(input)) {
73  auto nv = subgraph->insertConstant(*value);
74  nv->setType(input->type()); // Need to retain type information on Nones
75  inputsMap[input] = nv;
76  } else {
77  // The common case: this is a regular input, so just register it with
78  // the group node and inner subgraph
79  subgraphNode->addInput(input);
80  auto inputToGraph = subgraph->addInput();
81  inputToGraph->setType(input->type());
82  inputsMap[input] = inputToGraph;
83  }
84  }
85  }
86 
87  // Merge the node into the graph
88  auto mergedNode = subgraph->insertNode(
89  subgraph->createClone(toMerge, [&](Value* v) { return inputsMap[v]; }));
90 
91  // If n's outputs were inputs to `group`, remove them since we just merged
92  // n in.
93  //
94  // i.e.,
95  // x = f(w); group(x, y, z) becomes group(w, y, z).
96  // x, y, z = f(w); group(x, y, z) becomes group(w).
97  auto inputs = subgraphNode->inputs();
98  for (size_t i = 0; i < toMerge->outputs().size(); ++i) {
99  auto it = std::find(inputs.begin(), inputs.end(), toMerge->outputs()[i]);
100  if (it != inputs.end()) {
101  size_t p = it - inputs.begin();
102  subgraphNode->removeInput(p);
103  subgraph->inputs()[p]->replaceAllUsesWith(mergedNode->outputs()[i]);
104  subgraph->eraseInput(p);
105  }
106  }
107 
108  // Add n's outputs to the group node and inner subgraph outputs.
109  for (size_t i = 0; i < toMerge->outputs().size(); i++) {
110  auto oldOutput = toMerge->outputs()[i];
111 
112  // Only register the output in the group node if it's actually used
113  // outside the subgraph.
114  const auto hasUsesOutsideSubgraph = std::any_of(
115  oldOutput->uses().cbegin(),
116  oldOutput->uses().cend(),
117  [&](const Use& use) { return use.user->isAfter(subgraphNode); });
118 
119  if (hasUsesOutsideSubgraph) {
120  auto newOutput = mergedNode->outputs()[i];
121  subgraph->registerOutput(newOutput);
122  auto groupOutput = subgraphNode->addOutput();
123  groupOutput->copyMetadata(oldOutput);
124  oldOutput->replaceAllUsesWith(groupOutput);
125  }
126  }
127 
128  // Remove the original node now that the merge is complete
129  toMerge->destroy();
130 }
131 
132 // Invariant we depend on in mergeSubgraph: All inlined nodes are created
133 // between the node preceding insertBefore and insertBefore.
134 std::vector<Value*> inlineGraph(
135  const std::shared_ptr<Graph>& subgraph,
136  at::ArrayRef<Value*> outerInputs,
137  Node* insertBefore) {
138  auto outerGraph = insertBefore->owningGraph();
139 
140  // Initialize a map of inner graph values to outer graph values
141  std::unordered_map<const Value*, Value*> innerToOuter;
142  const auto innerInputs = subgraph->inputs();
143  AT_ASSERT(outerInputs.size() == innerInputs.size());
144  for (size_t i = 0; i < innerInputs.size(); ++i) {
145  innerToOuter[innerInputs[i]] = outerInputs[i];
146  }
147 
148  // Clone all nodes
149  for (auto inner : subgraph->nodes()) {
150  Node* outer = outerGraph->createClone(
151  inner, [&](Value* k) -> Value* { return innerToOuter.at(k); });
152  outer->insertBefore(insertBefore);
153  const auto innerOutputs = inner->outputs();
154  const auto outerOutputs = outer->outputs();
155  for (size_t i = 0; i < innerOutputs.size(); ++i) {
156  innerToOuter[innerOutputs[i]] = outerOutputs[i];
157  }
158  }
159 
160  return fmap(subgraph->outputs(), [&](Value* output) {
161  return innerToOuter.at(output);
162  });
163 }
164 
165 Node* createSingletonSubgraph(Node* n, Symbol subgraphKind) {
166  auto graph = n->owningGraph();
167  auto subgraph = graph->create(subgraphKind, 0);
168  subgraph->g_(attr::Subgraph, std::make_shared<Graph>(graph->current_scope()));
169  subgraph->insertBefore(n);
170  mergeNodeIntoSubgraph(n, subgraph);
171  return subgraph;
172 }
173 
174 } // namespace SubgraphUtils
175 } // namespace jit
176 } // namespace torch
constexpr size_t size() const
size - Get the array size.
Definition: ArrayRef.h:138
Definition: jit_type.h:17
ArrayRef - Represent a constant reference to an array (0 or more elements consecutively in memory), i.e. a start pointer and a length.
Definition: ArrayRef.h:41