3 #include <torch/csrc/jit/ir.h> 4 #include <torch/csrc/jit/irparser.h> 5 #include "test/cpp/jit/test_base.h" 19 static void checkRoundtrip(
const std::string& s) {
20 auto graph = std::make_shared<Graph>();
21 script::parseIR(s, &*graph);
22 std::ostringstream ss;
24 std::string parsed = ss.str();
34 std::string original = s.substr(i, s.size());
35 if (original != parsed) {
36 std::cerr <<
"Input:" << std::endl << original << std::endl;
37 std::cerr <<
"Parsed:" << std::endl << parsed << std::endl;
39 AT_ASSERT(original == parsed);
44 auto graph = std::make_shared<Graph>();
47 graph(%0 : Tensor, %1 : Tensor): 48 %2 : Tensor = foo::add(%0, %1) 49 %res, %3 = foo::mul(%0, %2) 50 %x, %y = foo::combine(%res, %2, %3) 51 return (%x, %y, %res))IR", 54 AT_ASSERT(graph->inputs().size() == 2); 55 AT_ASSERT(graph->outputs().size() == 3); 56 Value* x = graph->outputs()[0]; 57 Value* y = graph->outputs()[1]; 58 Value* res = graph->outputs()[2]; 59 Value* t0 = graph->inputs()[0]; 60 Value* t1 = graph->inputs()[1]; 61 AT_ASSERT(x->node() == y->node()); 62 Node* comb = x->node(); 63 Value* t2 = comb->inputs()[1]; 64 Value* t3 = comb->inputs()[2]; 65 AT_ASSERT(comb->kind().toQualString() == std::string("foo::combine"));
66 AT_ASSERT(comb->outputs() == std::vector<Value*>({x, y}));
67 AT_ASSERT(comb->inputs() == std::vector<Value*>({res, t2, t3}));
68 Node* mul = res->node();
69 AT_ASSERT(mul->kind().toQualString() == std::string(
"foo::mul"));
70 AT_ASSERT(mul->inputs() == std::vector<Value*>({t0, t2}));
71 AT_ASSERT(mul->outputs() == std::vector<Value*>({res, t3}));
72 Node* add = t2->node();
73 AT_ASSERT(add->kind().toQualString() == std::string(
"foo::add"));
74 AT_ASSERT(add->inputs() == std::vector<Value*>({t0, t1}));
75 AT_ASSERT(add->outputs() == std::vector<Value*>({t2}));
96 %3 : int = prim::Constant[value=1]() 97 %4 : Tensor = aten::add(%0, %1, %3) 98 %5 : Tensor = prim::If(%2) 100 %6 : int = prim::Constant[value=1]() 101 %7 : Tensor = aten::add(%1, %3, %6) 102 %8 : int = prim::Constant[value=1]() 103 %9 : Tensor = aten::add(%7, %3, %8) 105 %10 : int = prim::Constant[value=1]() 106 %11 : Tensor = aten::add(%5, %3, %10) 111 auto graph = std::make_shared<Graph>();
117 graph->inputs()[0]->type()->expect<TensorType>(); 121 auto graph = std::make_shared<Graph>();
129 Value* x0 = graph->inputs()[0]; 130 Value* x2 = graph->outputs()[0]; 131 Node* b = x2->node(); 132 Value* x1 = b->inputs()[0]; 133 Node* a = x1->node(); 134 AT_ASSERT(a->inputs() == std::vector<Value*>({x0})); 135 AT_ASSERT(a->outputs() == std::vector<Value*>({x1})); 136 AT_ASSERT(b->inputs() == std::vector<Value*>({x1})); 137 AT_ASSERT(b->outputs() == std::vector<Value*>({x2})); 146 %3 : int, %4 : Tensor = qqq::qqq[i_asdf=2, f_asdf=3.14, s_asdf="hello", ss_asdf=["hello world", "bye bye"]](%0) 147 %5 : int, %6 : Tensor = ppp::ppp[i_asdf=2, f_asdf=3.14, s_asdf="\"\"\"\"\nhe\"llo", q=[3, 2, 4]](%0) 148 %7 : float = vvv::vvv[s_asdf="hello"](%0) 160 %3 : int? = prim::Constant() 171 %3 : Float(*, *, *) = prim::Constant() 182 %3 : Long() = prim::Constant() 193 %3 : Double(4, 4, 5) = prim::Constant() 199 bool error_thrown =
false;
206 %3 : Double(4!, 4, 5) = prim::Constant() 209 } catch (
const std::exception& error) {
212 AT_ASSERT(error_thrown);