1 #include <torch/csrc/jit/passes/utils/subgraph_utils.h> 5 namespace SubgraphUtils {
8 bool hasSubgraph(Node* n) {
9 return n->hasAttribute(attr::Subgraph);
// Combines two subgraph nodes so that everything ends up in `mergeTo`.
// The subgraph of `mergeFrom` is first inlined back into the enclosing graph
// via unmergeSubgraph (which destroys `mergeFrom`). The inlined nodes, found
// between the former neighbours of `mergeFrom`, are then each merged into
// `mergeTo` with mergeNodeIntoSubgraph.
// NOTE(review): several lines of this definition are not visible in this
// view, including the declaration of `node` and how `it` is advanced; the
// comments below describe only the visible statements.
14 void mergeSubgraph(Node* mergeTo, Node* mergeFrom) {
// Capture the neighbours before `mergeFrom` is destroyed; they bracket the
// nodes that unmergeSubgraph inserts in its place.
15 Node* nodeBeforeMergeFrom = mergeFrom->prev();
16 Node* nodeAfterMergeFrom = mergeFrom->next();
17 unmergeSubgraph(mergeFrom);
// NOTE(review): `nodes` is not referenced in any visible line.
18 std::vector<Node*> nodes;
// Walk with reverse iterators from the node after the old position towards
// the node before it (end_it). Whether `it` is adjusted before the loop is
// not visible here.
19 const auto end_it = nodeBeforeMergeFrom->reverseIterator();
20 auto it = nodeAfterMergeFrom->reverseIterator();
22 while (it != end_it) {
26 mergeNodeIntoSubgraph(node, mergeTo);
31 std::shared_ptr<Graph> getSubgraph(Node* n) {
32 return n->g(attr::Subgraph);
35 void unmergeSubgraph(Node* subgraphNode) {
36 AT_ASSERT(subgraphNode->kind() == prim::DifferentiableGraph);
39 const auto subgraphOutputs = inlineGraph(
40 getSubgraph(subgraphNode), subgraphNode->inputs(), subgraphNode);
41 AT_ASSERT(subgraphOutputs.size() >= subgraphNode->outputs().size());
42 for (
size_t i = 0; i < subgraphNode->outputs().size(); ++i) {
43 subgraphNode->outputs()[i]->replaceAllUsesWith(subgraphOutputs[i]);
45 subgraphNode->destroy();
// Moves `toMerge` into the subgraph owned by `subgraphNode`:
//  - if `toMerge` is itself a subgraph node, delegates to mergeSubgraph;
//  - otherwise clones `toMerge` into the subgraph, binding its inputs to
//    subgraph values (reusing existing subgraph inputs, cloning constant
//    inputs into the subgraph, or adding new inputs to both the node and
//    the subgraph);
//  - any input of `subgraphNode` that `toMerge` produced becomes internal to
//    the subgraph;
//  - outputs of `toMerge` used after `subgraphNode` are exposed as new
//    outputs of `subgraphNode`.
// NOTE(review): several lines are not visible in this view (e.g. the
// declaration/increment of `idx`, the branch structure between the constant
// and non-constant input cases, and the remainder of the function after the
// output loop). Whether `toMerge` is destroyed afterwards cannot be confirmed
// from here.
48 void mergeNodeIntoSubgraph(Node* toMerge, Node* subgraphNode) {
49 AT_ASSERT(hasSubgraph(subgraphNode));
50 if (hasSubgraph(toMerge)) {
51 return mergeSubgraph(subgraphNode, toMerge);
54 auto subgraph = getSubgraph(subgraphNode);
// Maps values of the enclosing graph to their counterparts inside `subgraph`.
57 std::unordered_map<Value*, Value*> inputsMap;
// The node's inputs and the subgraph's inputs are parallel lists.
59 AT_ASSERT(subgraphNode->inputs().size() == subgraph->inputs().size());
61 for (
auto input : subgraphNode->inputs()) {
62 inputsMap[input] = subgraph->inputs()[idx];
// Insertion point for new subgraph nodes (e.g. constants below) is set to the
// subgraph's first node for the lifetime of `guard`.
67 WithInsertPoint guard(*subgraph->nodes().begin());
68 for (
auto input : toMerge->inputs()) {
69 if (inputsMap.count(input) == 0) {
// Inputs with a constant value are re-materialized inside the subgraph,
// keeping the outer value's type.
72 if (
auto value = toIValue(input)) {
73 auto nv = subgraph->insertConstant(*value);
74 nv->setType(input->type());
75 inputsMap[input] = nv;
// Other unmapped inputs become a new input of both the group node and its
// subgraph, with matching types.
79 subgraphNode->addInput(input);
80 auto inputToGraph = subgraph->addInput();
81 inputToGraph->setType(input->type());
82 inputsMap[input] = inputToGraph;
// Clone `toMerge` into the subgraph. The loop above is expected to have
// mapped every input of `toMerge`.
// NOTE(review): `inputsMap[v]` uses operator[], which would silently insert a
// null mapping for an unmapped value; `.at(v)` would fail loudly instead.
88 auto mergedNode = subgraph->insertNode(
89 subgraph->createClone(toMerge, [&](Value* v) { return inputsMap[v]; }));
// If an output of `toMerge` was an input of `subgraphNode`, that edge is now
// internal: drop the input from node and subgraph and use the cloned output.
97 auto inputs = subgraphNode->inputs();
98 for (
size_t i = 0; i < toMerge->outputs().size(); ++i) {
99 auto it = std::find(inputs.begin(), inputs.end(), toMerge->outputs()[i]);
100 if (it != inputs.end()) {
101 size_t p = it - inputs.begin();
102 subgraphNode->removeInput(p);
103 subgraph->inputs()[p]->replaceAllUsesWith(mergedNode->outputs()[i]);
104 subgraph->eraseInput(p);
// Outputs of `toMerge` that still have users located after `subgraphNode`
// are registered as subgraph outputs and surfaced as new outputs of the group
// node; those users are rewired to the new group output.
109 for (
size_t i = 0; i < toMerge->outputs().size(); i++) {
110 auto oldOutput = toMerge->outputs()[i];
114 const auto hasUsesOutsideSubgraph = std::any_of(
115 oldOutput->uses().cbegin(),
116 oldOutput->uses().cend(),
117 [&](
const Use& use) {
return use.user->isAfter(subgraphNode); });
119 if (hasUsesOutsideSubgraph) {
120 auto newOutput = mergedNode->outputs()[i];
121 subgraph->registerOutput(newOutput);
122 auto groupOutput = subgraphNode->addOutput();
123 groupOutput->copyMetadata(oldOutput);
124 oldOutput->replaceAllUsesWith(groupOutput);
// Clones every node of `subgraph` into the graph that owns `insertBefore`,
// placing each clone immediately before `insertBefore`. Subgraph inputs are
// bound positionally to `outerInputs`. Returns the outer values corresponding
// to `subgraph`'s outputs, in the same order.
// NOTE(review): the declaration of the `outerInputs` parameter (between
// `subgraph` and `insertBefore`) is not visible in this view.
134 std::vector<Value*> inlineGraph(
135 const std::shared_ptr<Graph>& subgraph,
137 Node* insertBefore) {
138 auto outerGraph = insertBefore->owningGraph();
// Maps each value of `subgraph` to the value that replaces it outside.
141 std::unordered_map<const Value*, Value*> innerToOuter;
142 const auto innerInputs = subgraph->inputs();
143 AT_ASSERT(outerInputs.
size() == innerInputs.size());
144 for (
size_t i = 0; i < innerInputs.size(); ++i) {
145 innerToOuter[innerInputs[i]] = outerInputs[i];
// Clone nodes in subgraph order. `.at(k)` throws if an input has not been
// mapped yet. This presumably relies on each node appearing after the
// producers of its inputs; confirm that ordering.
149 for (
auto inner : subgraph->nodes()) {
150 Node* outer = outerGraph->createClone(
151 inner, [&](Value* k) -> Value* {
return innerToOuter.at(k); });
152 outer->insertBefore(insertBefore);
// Record the clone's outputs so later nodes (and the final result) resolve to
// them.
153 const auto innerOutputs = inner->outputs();
154 const auto outerOutputs = outer->outputs();
155 for (
size_t i = 0; i < innerOutputs.size(); ++i) {
156 innerToOuter[innerOutputs[i]] = outerOutputs[i];
160 return fmap(subgraph->outputs(), [&](Value* output) {
161 return innerToOuter.at(output);
// Wraps `n` in a new subgraph node of kind `subgraphKind`. The new node is
// given an empty Graph (attr::Subgraph) created with the owning graph's
// current scope, is inserted directly before `n`, and then `n` is merged into
// it via mergeNodeIntoSubgraph.
// NOTE(review): the end of this definition (presumably returning the new
// node) is not visible in this view.
165 Node* createSingletonSubgraph(Node* n, Symbol subgraphKind) {
166 auto graph = n->owningGraph();
// The second argument (0) presumably sets the initial number of outputs;
// outputs are added later by mergeNodeIntoSubgraph as needed.
167 auto subgraph = graph->create(subgraphKind, 0);
168 subgraph->g_(attr::Subgraph, std::make_shared<Graph>(graph->current_scope()));
169 subgraph->insertBefore(n);
170 mergeNodeIntoSubgraph(n, subgraph);
constexpr size_t size() const
size - Get the array size.
ArrayRef - Represent a constant reference to an array (0 or more elements consecutively in memory)...