1 #include <torch/csrc/jit/passes/inline_autodiff_subgraphs.h> 3 #include <torch/csrc/jit/ir.h> 4 #include <torch/csrc/jit/passes/dead_code_elimination.h> 5 #include <torch/csrc/jit/passes/utils/subgraph_utils.h> 12 bool canRunWithAutograd(Node* node) {
13 return node->kind() != prim::FusionGroup;
16 void InlineAutodiffSubgraphs(Block* block,
size_t threshold);
18 graph_node_list::iterator scanNode(Node* node,
size_t threshold) {
19 auto next_node = ++node->iterator();
21 for (Block* block : node->blocks()) {
22 InlineAutodiffSubgraphs(block, threshold);
25 if (node->kind() != prim::DifferentiableGraph) {
29 auto subgraph = node->g(attr::Subgraph);
30 int64_t subgraph_size =
31 std::distance(subgraph->nodes().begin(), subgraph->nodes().end());
32 if (subgraph_size >= static_cast<int64_t>(threshold)) {
37 subgraph->nodes().begin(),
38 subgraph->nodes().end(),
39 canRunWithAutograd)) {
43 SubgraphUtils::unmergeSubgraph(node);
47 void InlineAutodiffSubgraphs(Block* block,
size_t threshold) {
48 for (
auto it = block->nodes().begin(); it != block->nodes().end();) {
49 it = scanNode(*it, threshold);
55 void InlineAutodiffSubgraphs(std::shared_ptr<Graph>& graph,
size_t threshold) {
56 InlineAutodiffSubgraphs(graph->block(), threshold);
57 EliminateDeadCode(graph);