Caffe2 - C++ API
A deep learning, cross-platform ML framework
test_subgraph_utils.h
1 #pragma once
2 
3 #include "test/cpp/jit/test_base.h"
4 #include "test/cpp/jit/test_utils.h"
5 
6 #include "torch/csrc/jit/passes/common_subexpression_elimination.h"
7 #include "torch/csrc/jit/passes/utils/subgraph_utils.h"
8 
9 namespace torch {
10 namespace jit {
11 namespace test {
12 
13 void testSubgraphUtils() {
14  auto graph = build_lstm();
15  EliminateCommonSubexpression(graph);
16 
17  std::vector<Node*> originalNodes(
18  graph->nodes().begin(), graph->nodes().end());
19 
20  // Merge everything into a single subgraph
21  bool first = true;
22  Node* subgraph;
23  for (auto it = graph->nodes().rbegin(); it != graph->nodes().rend();) {
24  if (first) {
25  subgraph = SubgraphUtils::createSingletonSubgraph(
26  *it, prim::DifferentiableGraph);
27  it = ++subgraph->reverseIterator();
28  first = false;
29  }
30 
31  SubgraphUtils::mergeNodeIntoSubgraph(*it, subgraph);
32  it = ++subgraph->reverseIterator();
33  }
34 
35  // Unmerge and compare with original node listing
36  SubgraphUtils::unmergeSubgraph(subgraph);
37  EliminateCommonSubexpression(graph);
38 
39  std::vector<Node*> newNodes(graph->nodes().begin(), graph->nodes().end());
40  ASSERT_EQ(originalNodes.size(), newNodes.size());
41 }
42 
43 } // namespace test
44 } // namespace jit
45 } // namespace torch
Definition: module.cpp:17
Definition: jit_type.h:17