3 #include <torch/csrc/jit/ir.h> 4 #include <torch/csrc/jit/irparser.h> 5 #include <torch/csrc/jit/passes/constant_pooling.h> 6 #include <torch/csrc/jit/passes/constant_propagation.h> 7 #include <torch/csrc/jit/testing/file_check.h> 8 #include "test/cpp/jit/test_base.h" 16 void testConstantPooling() {
18 auto graph = std::make_shared<Graph>();
22 %8 : int = prim::Constant[value=1]() 23 %10 : int = prim::Constant[value=1]() 27 ConstantPooling(graph); 29 .check_count("prim::Constant", 1,
true)
33 auto graph = std::make_shared<Graph>();
36 graph(%cond : Tensor): 37 %a : string = prim::Constant[value="bcd"]() 38 %3 : bool = prim::Bool(%cond) 39 %b : string = prim::If(%3) 41 %b.1 : string = prim::Constant[value="abc"]() 44 %b.2 : string = prim::Constant[value="abc"]() 46 %7 : (string, string) = prim::TupleConstruct(%a, %b) 50 ConstantPooling(graph); 52 .check_count("prim::Constant[value=\"abc\"]", 1,
true)
53 ->check_count(
"prim::Constant[value=\"bcd\"]", 1,
true)
57 auto graph = std::make_shared<Graph>();
61 %2 : int = prim::Constant[value=2]() 62 %1 : int = prim::Constant[value=1]() 63 %5 : int? = prim::Constant() 64 %7 : Device? = prim::Constant() 65 %10 : int = prim::Constant[value=6]() 66 %3 : int[] = prim::ListConstruct(%1, %2) 67 %x : Tensor = aten::tensor(%3, %5, %7) 68 %y : Tensor = aten::tensor(%3, %10, %7) 69 %9 : int[] = prim::ListConstruct(%1, %2) 70 %z : Tensor = aten::tensor(%9, %10, %7) 71 %14 : (Tensor, Tensor) = prim::TupleConstruct(%x, %y) 77 ConstantPropagation(graph);
78 ConstantPooling(graph);
80 .check_count(
"Float(2) = prim::Constant", 1,
true)
81 ->check_count(
"Long(2) = prim::Constant", 1,
true)