Caffe2 - C++ API
A deep learning, cross platform ML framework
inline_autodiff_subgraphs.cpp
1 #include <torch/csrc/jit/passes/inline_autodiff_subgraphs.h>
2 
3 #include <torch/csrc/jit/ir.h>
4 #include <torch/csrc/jit/passes/dead_code_elimination.h>
5 #include <torch/csrc/jit/passes/utils/subgraph_utils.h>
6 
7 namespace torch {
8 namespace jit {
9 
10 namespace {
11 
12 bool canRunWithAutograd(Node* node) {
13  return node->kind() != prim::FusionGroup;
14 }
15 
16 void InlineAutodiffSubgraphs(Block* block, size_t threshold);
17 
18 graph_node_list::iterator scanNode(Node* node, size_t threshold) {
19  auto next_node = ++node->iterator();
20 
21  for (Block* block : node->blocks()) {
22  InlineAutodiffSubgraphs(block, threshold);
23  }
24 
25  if (node->kind() != prim::DifferentiableGraph) {
26  return next_node;
27  }
28 
29  auto subgraph = node->g(attr::Subgraph);
30  int64_t subgraph_size =
31  std::distance(subgraph->nodes().begin(), subgraph->nodes().end());
32  if (subgraph_size >= static_cast<int64_t>(threshold)) {
33  return next_node;
34  }
35 
36  if (!std::all_of(
37  subgraph->nodes().begin(),
38  subgraph->nodes().end(),
39  canRunWithAutograd)) {
40  return next_node;
41  }
42 
43  SubgraphUtils::unmergeSubgraph(node);
44  return next_node;
45 }
46 
47 void InlineAutodiffSubgraphs(Block* block, size_t threshold) {
48  for (auto it = block->nodes().begin(); it != block->nodes().end();) {
49  it = scanNode(*it, threshold);
50  }
51 }
52 
53 } // anonymous namespace
54 
55 void InlineAutodiffSubgraphs(std::shared_ptr<Graph>& graph, size_t threshold) {
56  InlineAutodiffSubgraphs(graph->block(), threshold);
57  EliminateDeadCode(graph);
58 }
59 
60 } // namespace jit
61 } // namespace torch
Definition: jit_type.h:17