Caffe2 - C++ API
A deep learning, cross platform ML framework
test_ir.h
1 #pragma once
2 
3 #include "test/cpp/jit/test_base.h"
4 #include "test/cpp/jit/test_utils.h"
5 
6 namespace torch {
7 namespace jit {
8 namespace test {
9 
10 void testAttributes() {
11  Graph g;
12  auto one = attr::alpha;
13  auto two = attr::device;
14  auto three = attr::end;
15  auto four = attr::perm;
16  Node* n = g.create(Symbol::fromQualString("foo::bar"));
17  Node& attr = *n;
18  attr.f_(one, 3.4)->i_(two, 5)->s_(three, "what");
19  ASSERT_EQ(attr.f(one), 3.4);
20  ASSERT_EQ(attr.s(three), "what");
21  ASSERT_EQ(attr.i(two), 5);
22  attr.s_(one, "no");
23  ASSERT_EQ(attr.s(one), "no");
24  ASSERT_TRUE(attr.hasAttribute(three));
25  ASSERT_TRUE(!attr.hasAttribute(four));
26  attr.ss_(two, {"hi", "now"});
27  ASSERT_EQ(attr.ss(two).at(1), "now");
28 
29  Node* n2 = g.create(Symbol::fromQualString("foo::baz"));
30  Node& attr2 = *n2;
31  attr2.copyAttributes(attr);
32  ASSERT_EQ(attr2.s(one), "no");
33  attr2.f_(one, 5);
34  ASSERT_EQ(attr.s(one), "no");
35  ASSERT_EQ(attr2.f(one), 5);
36 }
37 
38 void testBlocks(std::ostream& out = std::cout) {
39  auto g = std::make_shared<Graph>();
40  // auto g = *graph;
41  auto a = Var::asNewInput(*g, "a");
42  auto b = Var::asNewInput(*g, "b");
43  auto c = a + b;
44  auto r =
45  g->appendNode(g->create(prim::If, {Var::asNewInput(*g, "c").value()}));
46  auto then_block = r->addBlock();
47  auto else_block = r->addBlock();
48  {
49  WithInsertPoint guard(then_block);
50  auto t = c + c;
51  then_block->registerOutput(t.value());
52  }
53  {
54  WithInsertPoint guard(else_block);
55  auto d = b + c;
56  auto e = d + c;
57  else_block->registerOutput(e.value());
58  }
59  g->registerOutput((Var(r->output()) + c).value());
60  g->lint();
61  testing::FileCheck()
62  .check("add")
63  ->check("prim::If")
64  ->check("block0")
65  ->check("aten::add")
66  ->check("block1")
67  ->check_count("aten::add", 3)
68  ->run(*g);
69  r->eraseBlock(0);
70  testing::FileCheck()
71  .check("add")
72  ->check("prim::If")
73  ->check("block0")
74  ->check_not("block")
75  ->run(*g);
76  g->lint();
77  // test recursive copy of blocks works
78  auto g2 = g->copy();
79  testing::FileCheck()
80  .check("add")
81  ->check("prim::If")
82  ->check("block0")
83  ->check_not("block")
84  ->run(*g2);
85 }
86 
87 } // namespace test
88 } // namespace jit
89 } // namespace torch
Definition: module.cpp:17
Definition: jit_type.h:17