3 #include "test/cpp/jit/test_base.h" 4 #include "test/cpp/jit/test_utils.h" 6 #include "torch/csrc/jit/passes/common_subexpression_elimination.h" 7 #include "torch/csrc/jit/passes/utils/subgraph_utils.h" 13 void testSubgraphUtils() {
14 auto graph = build_lstm();
15 EliminateCommonSubexpression(graph);
17 std::vector<Node*> originalNodes(
18 graph->nodes().begin(), graph->nodes().end());
23 for (
auto it = graph->nodes().rbegin(); it != graph->nodes().rend();) {
25 subgraph = SubgraphUtils::createSingletonSubgraph(
26 *it, prim::DifferentiableGraph);
27 it = ++subgraph->reverseIterator();
31 SubgraphUtils::mergeNodeIntoSubgraph(*it, subgraph);
32 it = ++subgraph->reverseIterator();
36 SubgraphUtils::unmergeSubgraph(subgraph);
37 EliminateCommonSubexpression(graph);
39 std::vector<Node*> newNodes(graph->nodes().begin(), graph->nodes().end());
40 ASSERT_EQ(originalNodes.size(), newNodes.size());