Caffe2 - C++ API
A deep learning, cross-platform ML framework
test_irparser.h
#pragma once

#include <torch/csrc/jit/ir.h>
#include <torch/csrc/jit/irparser.h>
#include "test/cpp/jit/test_base.h"

#include <sstream>
#include <string>

namespace torch {
namespace jit {

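// Parses the given IR string into a Graph, prints the graph back out, and
// asserts that the printed form matches the input (ignoring any leading
// whitespace in the input string).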
static void checkRoundtrip(const std::string& s) {
  auto graph = std::make_shared<Graph>();
  script::parseIR(s, &*graph);
  std::ostringstream ss;
  ss << *graph;
  std::string parsed = ss.str();

  // Skip whitespace at the beginning of the input string.
  int i = 0;
  for (char c : s) {
    if (!isspace(c)) {
      break;
    }
    i++;
  }
  std::string original = s.substr(i, s.size());
  if (original != parsed) {
    std::cerr << "Input:" << std::endl << original << std::endl;
    std::cerr << "Parsed:" << std::endl << parsed << std::endl;
  }
  AT_ASSERT(original == parsed);
}

void testIRParser() {
  {
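    // Parse a small graph directly and verify its structure: inputs,
    // outputs, node kinds, and the def-use relationships between values.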
    auto graph = std::make_shared<Graph>();
    script::parseIR(
        R"IR(
graph(%0 : Tensor, %1 : Tensor):
  %2 : Tensor = foo::add(%0, %1)
  %res, %3 = foo::mul(%0, %2)
  %x, %y = foo::combine(%res, %2, %3)
  return (%x, %y, %res))IR",
        &*graph);

    AT_ASSERT(graph->inputs().size() == 2);
    AT_ASSERT(graph->outputs().size() == 3);
    Value* x = graph->outputs()[0];
    Value* y = graph->outputs()[1];
    Value* res = graph->outputs()[2];
    Value* t0 = graph->inputs()[0];
    Value* t1 = graph->inputs()[1];
    AT_ASSERT(x->node() == y->node());
    Node* comb = x->node();
    Value* t2 = comb->inputs()[1];
    Value* t3 = comb->inputs()[2];
    AT_ASSERT(comb->kind().toQualString() == std::string("foo::combine"));
    AT_ASSERT(comb->outputs() == std::vector<Value*>({x, y}));
    AT_ASSERT(comb->inputs() == std::vector<Value*>({res, t2, t3}));
    Node* mul = res->node();
    AT_ASSERT(mul->kind().toQualString() == std::string("foo::mul"));
    AT_ASSERT(mul->inputs() == std::vector<Value*>({t0, t2}));
    AT_ASSERT(mul->outputs() == std::vector<Value*>({res, t3}));
    Node* add = t2->node();
    AT_ASSERT(add->kind().toQualString() == std::string("foo::add"));
    AT_ASSERT(add->inputs() == std::vector<Value*>({t0, t1}));
    AT_ASSERT(add->outputs() == std::vector<Value*>({t2}));
  }
  {
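    // Round-trip a graph containing nested blocks to make sure the parser
    // and the printer agree on block formatting.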
    checkRoundtrip(R"IR(
graph():
  %0 : Tensor = a::a()
    block0():
      %1 : Tensor = b::b()
        block0():
          %2 : Tensor = c::c()
          -> ()
      -> ()
  %3 : Tensor = d::d()
  return (%3)
)IR");
  }
  {
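    // Round-trip a graph with a prim::If node whose block contains several
    // constants and aten::add calls.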
    checkRoundtrip(R"IR(
graph(%0 : Tensor,
      %1 : Tensor,
      %2 : Tensor):
  %3 : int = prim::Constant[value=1]()
  %4 : Tensor = aten::add(%0, %1, %3)
  %5 : Tensor = prim::If(%2)
    block0():
      %6 : int = prim::Constant[value=1]()
      %7 : Tensor = aten::add(%1, %3, %6)
      %8 : int = prim::Constant[value=1]()
      %9 : Tensor = aten::add(%7, %3, %8)
      -> (%9)
  %10 : int = prim::Constant[value=1]()
  %11 : Tensor = aten::add(%5, %3, %10)
  return (%11)
)IR");
  }
  {
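    // An input declared without an explicit type should default to Tensor.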
    auto graph = std::make_shared<Graph>();
    script::parseIR(
        R"IR(
graph(%a):
  return (%a))IR",
        &*graph);
    graph->inputs()[0]->type()->expect<TensorType>();
  }
  {
    // Check that the parser correctly handles values reusing the same name.
    auto graph = std::make_shared<Graph>();
    script::parseIR(
        R"IR(
graph(%x):
  %x = a::a(%x)
  %x = b::b(%x)
  return (%x))IR",
        &*graph);
    Value* x0 = graph->inputs()[0];
    Value* x2 = graph->outputs()[0];
    Node* b = x2->node();
    Value* x1 = b->inputs()[0];
    Node* a = x1->node();
    AT_ASSERT(a->inputs() == std::vector<Value*>({x0}));
    AT_ASSERT(a->outputs() == std::vector<Value*>({x1}));
    AT_ASSERT(b->inputs() == std::vector<Value*>({x1}));
    AT_ASSERT(b->outputs() == std::vector<Value*>({x2}));
  }
  {
    // Check that the parser handles attributes and types.
    checkRoundtrip(
        R"IR(
graph(%0 : Tensor,
      %1 : Tensor,
      %2 : Tensor):
  %3 : int, %4 : Tensor = qqq::qqq[i_asdf=2, f_asdf=3.14, s_asdf="hello", ss_asdf=["hello world", "bye bye"]](%0)
  %5 : int, %6 : Tensor = ppp::ppp[i_asdf=2, f_asdf=3.14, s_asdf="\"\"\"\"\nhe\"llo", q=[3, 2, 4]](%0)
  %7 : float = vvv::vvv[s_asdf="hello"](%0)
  %8 : string = z::z()
  return (%7)
)IR");
  }

  {
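    // Round-trip an optional type annotation: int?.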
    checkRoundtrip(
        R"IR(
graph(%0 : Tensor,
      %1 : Tensor,
      %2 : Tensor):
  %3 : int? = prim::Constant()
  return (%3)
)IR");
  }

  {
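    // Round-trip a tensor type with wildcard dimensions: Float(*, *, *).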
    checkRoundtrip(
        R"IR(
graph(%0 : Tensor,
      %1 : Tensor,
      %2 : Tensor):
  %3 : Float(*, *, *) = prim::Constant()
  return (%3)
)IR");
  }

  {
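    // Round-trip a tensor type written with no dimensions: Long().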
    checkRoundtrip(
        R"IR(
graph(%0 : Tensor,
      %1 : Tensor,
      %2 : Tensor):
  %3 : Long() = prim::Constant()
  return (%3)
)IR");
  }

  {
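    // Round-trip a tensor type with explicit sizes: Double(4, 4, 5).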
    checkRoundtrip(
        R"IR(
graph(%0 : Tensor,
      %1 : Tensor,
      %2 : Tensor):
  %3 : Double(4, 4, 5) = prim::Constant()
  return (%3)
)IR");
  }

  {
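    // checkRoundtrip is expected to throw on the '!' dimension annotation in
    // Double(4!, 4, 5), which the parser does not accept.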
    bool error_thrown = false;
    try {
      checkRoundtrip(
          R"IR(
graph(%0 : Tensor,
      %1 : Tensor,
      %2 : Tensor):
  %3 : Double(4!, 4, 5) = prim::Constant()
  return (%3)
)IR");
    } catch (const std::exception& error) {
      error_thrown = true;
    }
    AT_ASSERT(error_thrown);
  }
}
} // namespace jit
} // namespace torch