Caffe2 - C++ API
A deep learning, cross-platform ML framework
test_constant_pooling.h
1 #pragma once
2 
3 #include <torch/csrc/jit/ir.h>
4 #include <torch/csrc/jit/irparser.h>
5 #include <torch/csrc/jit/passes/constant_pooling.h>
6 #include <torch/csrc/jit/passes/constant_propagation.h>
7 #include <torch/csrc/jit/testing/file_check.h>
8 #include "test/cpp/jit/test_base.h"
9 
10 #include <sstream>
11 #include <string>
12 
13 namespace torch {
14 namespace jit {
15 
16 void testConstantPooling() {
17  {
18  auto graph = std::make_shared<Graph>();
19  script::parseIR(
20  R"IR(
21 graph():
22  %8 : int = prim::Constant[value=1]()
23  %10 : int = prim::Constant[value=1]()
24  return (%8, %10)
25  )IR",
26  &*graph);
27  ConstantPooling(graph);
28  testing::FileCheck()
29  .check_count("prim::Constant", 1, /*exactly*/ true)
30  ->run(*graph);
31  }
32  {
33  auto graph = std::make_shared<Graph>();
34  script::parseIR(
35  R"IR(
36 graph(%cond : Tensor):
37  %a : string = prim::Constant[value="bcd"]()
38  %3 : bool = prim::Bool(%cond)
39  %b : string = prim::If(%3)
40  block0():
41  %b.1 : string = prim::Constant[value="abc"]()
42  -> (%b.1)
43  block1():
44  %b.2 : string = prim::Constant[value="abc"]()
45  -> (%b.2)
46  %7 : (string, string) = prim::TupleConstruct(%a, %b)
47  return (%7)
48  )IR",
49  &*graph);
50  ConstantPooling(graph);
51  testing::FileCheck()
52  .check_count("prim::Constant[value=\"abc\"]", 1, /*exactly*/ true)
53  ->check_count("prim::Constant[value=\"bcd\"]", 1, /*exactly*/ true)
54  ->run(*graph);
55  }
56  {
57  auto graph = std::make_shared<Graph>();
58  script::parseIR(
59  R"IR(
60 graph():
61  %2 : int = prim::Constant[value=2]()
62  %1 : int = prim::Constant[value=1]()
63  %5 : int? = prim::Constant()
64  %7 : Device? = prim::Constant()
65  %10 : int = prim::Constant[value=6]()
66  %3 : int[] = prim::ListConstruct(%1, %2)
67  %x : Tensor = aten::tensor(%3, %5, %7)
68  %y : Tensor = aten::tensor(%3, %10, %7)
69  %9 : int[] = prim::ListConstruct(%1, %2)
70  %z : Tensor = aten::tensor(%9, %10, %7)
71  %14 : (Tensor, Tensor) = prim::TupleConstruct(%x, %y)
72  return (%14)
73  )IR",
74  &*graph);
75  // three tensors created - two different devices among the three
76  // don't have good support for parsing tensor constants
77  ConstantPropagation(graph);
78  ConstantPooling(graph);
79  testing::FileCheck()
80  .check_count("Float(2) = prim::Constant", 1, /*exactly*/ true)
81  ->check_count("Long(2) = prim::Constant", 1, /*exactly*/ true)
82  ->run(*graph);
83  }
84 }
85 
86 } // namespace jit
87 } // namespace torch
Definition: jit_type.h:17