1 #include <torch/csrc/jit/passes/create_autodiff_subgraphs.h> 3 #include <c10/util/Exception.h> 4 #include <torch/csrc/jit/autodiff.h> 5 #include <torch/csrc/jit/ir.h> 6 #include <torch/csrc/jit/passes/alias_analysis.h> 7 #include <torch/csrc/jit/passes/common_subexpression_elimination.h> 8 #include <torch/csrc/jit/passes/utils/subgraph_utils.h> 15 class SubgraphSlicer {
// Tail of the SubgraphSlicer constructor: remaining parameters and member
// initialisers. The first parameter and its initialiser are not visible in
// this listing; the call sites in this file pass a block as first argument.
19 std::shared_ptr<Graph> graph,
20 size_t minSubgraphSize)
// The by-value shared_ptr parameter is moved into graph_ rather than copied.
22 graph_(
std::move(graph)),
// Threshold later compared against a subgraph's node count in
// inlineIfTooSmall().
23 minSubgraphSize_(minSubgraphSize) {}
// Forms differentiable subgraphs among the nodes of block_ and appends every
// prim::DifferentiableGraph node that is kept to `diffGraphs`.
// NOTE(review): this listing has gaps (closing braces and some statements are
// not shown), so the comments below describe only the visible statements.
25 void run(std::vector<Node*>& diffGraphs) {
// Accumulates whether any scanNode() call below reported a change.
// NOTE(review): presumably drives a repeat-until-stable loop - confirm against
// the full file.
40 bool any_changed =
true;
// Alias information built from graph_; handed to scanNode() for its merge
// decisions.
43 AliasDb aliasDb(graph_);
// Reverse walk over the block's nodes. The for-header has no increment
// expression: the iterator to continue from is whatever scanNode() returns.
44 for (
auto it = block_->nodes().rbegin(); it != block_->nodes().rend();) {
46 std::tie(it, changed) = scanNode(*it, aliasDb);
47 any_changed |= changed;
// Post-processing pass: walk the nodes again from the last one backwards.
54 auto curNode = *block_->nodes().rbegin();
55 while (curNode != *block_->nodes().rend()) {
// Recurse into every nested block of the current node, sharing graph_, the
// size threshold and the output vector.
56 for (
auto subBlock : curNode->blocks()) {
57 SubgraphSlicer(subBlock, graph_, minSubgraphSize_).run(diffGraphs);
// Remember the predecessor before touching curNode.
// NOTE(review): presumably because inlineIfTooSmall() can un-merge (and so
// invalidate) curNode - confirm.
61 auto prevNode = curNode->prev();
62 if (curNode->kind() == prim::DifferentiableGraph) {
// Run common-subexpression elimination on the node's own subgraph first.
66 EliminateCommonSubexpression(curNode->g(attr::Subgraph));
// Only nodes that were NOT inlined back are reported to the caller.
68 if (!inlineIfTooSmall(curNode)) {
69 diffGraphs.push_back(curNode);
// Final common-subexpression elimination over the whole graph_.
76 EliminateCommonSubexpression(graph_);
// Un-merges the subgraph of a prim::DifferentiableGraph node when it is
// considered too small relative to minSubgraphSize_.
// NOTE(review): the return statements are not visible in this listing;
// run() keeps the node only when this returns false, so presumably
// false = "large enough, kept" and true = "inlined" - confirm.
84 bool inlineIfTooSmall(Node* n) {
// Precondition: only differentiable-graph nodes may be passed in.
85 AT_ASSERT(n->kind() == prim::DifferentiableGraph);
86 auto subgraph = SubgraphUtils::getSubgraph(n);
// Walk the subgraph's nodes, incrementing a counter per node and comparing it
// with the threshold (the counter's declaration is not visible here).
88 for (
auto it = subgraph->nodes().begin(); it != subgraph->nodes().end();
90 if (++i >= minSubgraphSize_) {
// Dissolve the subgraph node via SubgraphUtils.
95 SubgraphUtils::unmergeSubgraph(n);
// Returns values drawn from `inputs`, ordered so that a value whose producing
// node comes later precedes one whose producing node comes earlier.
// NOTE(review): the declaration and filling of `result` are not visible in
// this listing; presumably only inputs passing the block check below are
// collected - confirm.
99 value_list sortReverseTopological(ArrayRef<Value*> inputs) {
101 for (
auto i : inputs) {
// Tests whether the input's producing node is owned by the block being
// sliced (block_).
102 if (i->node()->owningBlock() == block_) {
// Comparator: `a` sorts before `b` when a's node is after b's node.
107 std::sort(result.begin(), result.end(), [&](Value* a, Value* b) {
108 return a->node()->isAfter(b->node());
// Decides whether `node` is a candidate for merging into a differentiable
// subgraph.
// NOTE(review): the bodies of the two kind checks below are not visible in
// this listing, so their outcomes are not documented here.
113 bool shouldConsiderForMerge(Node* node) {
// Special case: the node is already a differentiable subgraph.
115 if (node->kind() == prim::DifferentiableGraph) {
// Special case: constant nodes.
118 if (node->kind() == prim::Constant) {
// Every other node kind: defer to the autodiff predicate isDifferentiable().
121 return isDifferentiable(node);
// Tries to grow a differentiable subgraph around `consumer` by merging in the
// nodes that produce its inputs.
// Returns {reverse iterator the caller should continue from, whether a merge
// happened}. The parameter list (consumer, aliasDb) is not visible in this
// listing.
124 std::pair<graph_node_list::iterator, bool> scanNode(
127 if (shouldConsiderForMerge(consumer)) {
// A consumer that is not yet a subgraph node is first wrapped in a singleton
// prim::DifferentiableGraph, and `consumer` is rebound to that new node.
128 if (consumer->kind() != prim::DifferentiableGraph) {
129 consumer = SubgraphUtils::createSingletonSubgraph(
130 consumer, prim::DifferentiableGraph);
// Consider the inputs in the order given by sortReverseTopological().
132 auto inputs = sortReverseTopological(consumer->inputs());
133 for (
auto input : inputs) {
// tryMerge() yields an optional-like result; a set value is the merged group.
134 if (
auto group = tryMerge(consumer, input->node(), aliasDb)) {
// On success, report a change and hand back the iterator of the merged group
// itself, so the caller scans that node again.
137 return std::make_pair(group.value()->reverseIterator(),
true);
// No merge: advance past `consumer` in the reverse walk and report no change.
142 return std::make_pair(++consumer->reverseIterator(),
false);
// Visible part of the merge helper that scanNode() calls as
// tryMerge(consumer, producer, aliasDb); its header and return statements are
// not shown in this listing.
// Precondition: the consumer must already be a differentiable-subgraph node.
151 AT_ASSERT(consumer->kind() == prim::DifferentiableGraph);
// The producer qualifies only if shouldConsiderForMerge() accepts it and the
// alias database's moveBeforeTopologicallyValid(producer, consumer) succeeds.
// Because of &&, the alias-database call is skipped for ineligible producers.
// NOTE(review): judging by its name, that call may also reorder nodes -
// confirm against the AliasDb documentation.
152 bool canMerge = shouldConsiderForMerge(producer) &&
153 aliasDb.moveBeforeTopologicallyValid(producer, consumer);
// Fold the producer into the consumer's subgraph.
159 SubgraphUtils::mergeNodeIntoSubgraph(producer, consumer);
// Shared handle to the graph; used in run() to build the AliasDb and for the
// final common-subexpression elimination, and passed to nested slicers.
165 std::shared_ptr<Graph> graph_;
// Node-count threshold that inlineIfTooSmall() compares subgraph sizes with.
166 size_t minSubgraphSize_;
170 std::vector<Node*> CreateAutodiffSubgraphs(
171 const std::shared_ptr<Graph>& graph,
173 std::vector<Node*> diff_nodes;
174 SubgraphSlicer(graph->block(), graph, threshold).run(diff_nodes);