Caffe2 - C++ API
A deep-learning, cross-platform ML framework
test_netdef_converter.h
1 #pragma once
2 
#include <torch/csrc/jit/netdef_converter.h>
#include "test/cpp/jit/test_base.h"

#include <cmath>
#include <sstream>
#include <string>
#include <unordered_map>
#include <vector>
8 
9 namespace torch {
10 namespace jit {
11 
12 static caffe2::OperatorDef createOperator(
13  const std::string& name,
14  const std::vector<std::string>& inputs,
15  const std::vector<std::string>& outputs) {
16  caffe2::OperatorDef op;
17  op.set_type(name);
18  for (const auto& input : inputs) {
19  op.add_input(input);
20  }
21  for (const auto& output : outputs) {
22  op.add_output(output);
23  }
24  return op;
25 }
26 
27 void testNetDefConverter() {
28  {
29  // Check a simple net conversion back and forth.
30 
31  // Create a simple graph:
32  // graph(%0 : Tensor
33  // %1 : Tensor) {
34  // %2 : Tensor = aten::mul(%0, %1)
35  // %3 : int = prim::Constant[value=1]()
36  // %4 : Tensor = aten::add(%0, %2, %3)
37  // return (%2, %4);
38  // }
39  auto graph = std::make_shared<Graph>();
40  auto a = graph->addInput();
41  auto b = graph->addInput();
42  auto c = graph->insert(aten::mul, {a, b});
43  auto d = graph->insert(aten::add, {a, c});
44  graph->registerOutput(c);
45  graph->registerOutput(d);
46 
47  // Convert it to netdef and check the result
48  caffe2::NetDef net;
49  convertIRToNetDef(&net, *graph);
50  AT_ASSERT(net.op().size() == 3);
51  AT_ASSERT(net.external_input().size() == 2);
52  AT_ASSERT(net.external_output().size() == 2);
53 
54  const caffe2::OperatorDef& MulOp = net.op().Get(0);
55  AT_ASSERT(MulOp.input().size() == 2);
56  AT_ASSERT(MulOp.input().Get(0) == net.external_input().Get(0));
57  AT_ASSERT(MulOp.input().Get(1) == net.external_input().Get(1));
58  AT_ASSERT(MulOp.output().size() == 1);
59 
60  const caffe2::OperatorDef& ConstNode = net.op().Get(1);
61  AT_ASSERT(ConstNode.input().size() == 0);
62  AT_ASSERT(ConstNode.output().size() == 1);
63  AT_ASSERT(ConstNode.arg().size() == 1);
64  AT_ASSERT(ConstNode.arg().Get(0).name() == "value");
65  AT_ASSERT(ConstNode.arg().Get(0).i() == 1);
66 
67  const caffe2::OperatorDef& AddOp = net.op().Get(2);
68  AT_ASSERT(AddOp.input().size() == 3);
69  AT_ASSERT(AddOp.input().Get(0) == net.external_input().Get(0));
70  AT_ASSERT(AddOp.input().Get(1) == MulOp.output().Get(0));
71  AT_ASSERT(AddOp.input().Get(2) == ConstNode.output().Get(0));
72 
73  AT_ASSERT(net.external_output().Get(0) == MulOp.output().Get(0));
74  AT_ASSERT(net.external_output().Get(1) == AddOp.output().Get(0));
75 
76  // Convert NetDef back to IR and check if we get the original.
77  Graph graph2;
78  std::unordered_map<std::string, Value*> vmap;
79  convertNetDefToIR(net, &graph2, &vmap);
80 
81  Node* mul = graph2.outputs()[0]->node();
82  Node* add = graph2.outputs()[1]->node();
83  AT_ASSERT(mul->kind() == c->node()->kind());
84  AT_ASSERT(add->kind() == d->node()->kind());
85  AT_ASSERT(mul->inputs()[0] == graph2.inputs()[0]);
86  AT_ASSERT(mul->inputs()[1] == graph2.inputs()[1]);
87  AT_ASSERT(add->inputs()[0] == graph2.inputs()[0]);
88  AT_ASSERT(add->inputs()[1] == graph2.outputs()[0]);
89  }
90  {
91  // Check attributes conversion
92  auto graph = std::make_shared<Graph>();
93  auto a = graph->addInput();
94  auto b = graph->addInput();
95  Node* node =
96  graph->create(Symbol::fromQualString("test::some_op"), {a, b}, 2);
97  graph->insertNode(node);
98 
99  node->i_(Symbol::fromQualString("attr::i_attr"), 42);
100  node->f_(Symbol::fromQualString("attr::f_attr"), 3.0);
101  node->s_(Symbol::fromQualString("attr::s_attr"), "Hello!");
102 
103  node->is_(Symbol::fromQualString("attr::is_attr"), {14, 18, 7});
104  node->fs_(Symbol::fromQualString("attr::fs_attr"), {2.72, 3.14});
105  node->ss_(Symbol::fromQualString("attr::ss_attr"), {"Winter", "Summer"});
106 
107  graph->registerOutput(node->outputs()[0]);
108  graph->registerOutput(node->outputs()[1]);
109 
110  // Convert it to netdef and check the result
111  caffe2::NetDef net;
112  convertIRToNetDef(&net, *graph);
113  const caffe2::OperatorDef& Op = net.op().Get(0);
114  AT_ASSERT(Op.arg().Get(0).name() == "i_attr");
115  AT_ASSERT(Op.arg().Get(0).i() == 42);
116  AT_ASSERT(Op.arg().Get(1).name() == "f_attr");
117  AT_ASSERT(Op.arg().Get(1).f() == 3.0);
118  AT_ASSERT(Op.arg().Get(2).name() == "s_attr");
119  AT_ASSERT(Op.arg().Get(2).s() == "Hello!");
120 
121  AT_ASSERT(Op.arg().Get(3).name() == "is_attr");
122  AT_ASSERT(Op.arg().Get(3).ints().size() == 3);
123  AT_ASSERT(Op.arg().Get(3).ints().Get(0) == 14);
124  AT_ASSERT(Op.arg().Get(3).ints().Get(1) == 18);
125  AT_ASSERT(Op.arg().Get(3).ints().Get(2) == 7);
126 
127  AT_ASSERT(Op.arg().Get(4).name() == "fs_attr");
128  AT_ASSERT(Op.arg().Get(4).floats().size() == 2);
129  AT_ASSERT(fabs(Op.arg().Get(4).floats().Get(0) - 2.72) < 0.001);
130 
131  AT_ASSERT(Op.arg().Get(5).name() == "ss_attr");
132  AT_ASSERT(Op.arg().Get(5).strings().size() == 2);
133  AT_ASSERT(Op.arg().Get(5).strings().Get(1) == "Summer");
134 
135  AT_ASSERT(net.external_output().Get(0) == Op.output().Get(0));
136  AT_ASSERT(net.external_output().Get(1) == Op.output().Get(1));
137 
138  // Convert NetDef back to IR and check if we get the original.
139  Graph graph2;
140  std::unordered_map<std::string, Value*> vmap;
141  convertNetDefToIR(net, &graph2, &vmap);
142 
143  AT_ASSERT(graph2.outputs()[0]->node() == graph2.outputs()[0]->node());
144  Node* n = graph2.outputs()[0]->node();
145  AT_ASSERT(n->i(Symbol::fromQualString("attr::i_attr")) == 42);
146  AT_ASSERT(n->f(Symbol::fromQualString("attr::f_attr")) == 3.0);
147  AT_ASSERT(n->s(Symbol::fromQualString("attr::s_attr")) == "Hello!");
148  AT_ASSERT(
149  n->is(Symbol::fromQualString("attr::is_attr")) ==
150  std::vector<int64_t>({14, 18, 7}));
151  AT_ASSERT(
152  fabs(n->fs(Symbol::fromQualString("attr::fs_attr"))[0] - 2.72) < 0.001);
153  AT_ASSERT(
154  fabs(n->fs(Symbol::fromQualString("attr::fs_attr"))[1] - 3.14) < 0.001);
155  AT_ASSERT(
156  n->ss(Symbol::fromQualString("attr::ss_attr")) ==
157  std::vector<std::string>({"Winter", "Summer"}));
158  }
159  {
160  // Check how value names are preserved in conversion. They naturally might
161  // change as IR is in SSA form, but we should try not to change names of
162  // external inputs and outputs.
163 
164  // Create a simple net:
165  // net(ext_inputs = {a, b, c})
166  // a = foo::bar(a, b)
167  // u = foo::baz(b, c)
168  // x = foo::qux(u, a)
169  // x = foo::quux(a, x)
170  // -> (ext_outputs = {x})
171  //
172  caffe2::NetDef net;
173 
174  *net.add_op() = createOperator("foo::bar", {"a", "b"}, {"a"});
175  *net.add_op() = createOperator("foo::baz", {"b", "c"}, {"u"});
176  *net.add_op() = createOperator("foo::qux", {"u", "a"}, {"x"});
177  *net.add_op() = createOperator("foo::quux", {"a", "x", "u"}, {"x"});
178  net.add_external_input("a");
179  net.add_external_input("b");
180  net.add_external_input("c");
181  net.add_external_output("x");
182 
183  // Expect the following graph to be generated:
184  // graph(%a : Tensor,
185  // %b : Tensor,
186  // %c : Tensor) {
187  // %a.1 : Tensor = foo::bar(%a, %b)
188  // %u : Tensor = foo::baz(%b, %c)
189  // %x.1 : Tensor = foo::qux(%u, %a.1)
190  // %x : Tensor = foo::quux(%a.1, %x.1, u)
191  // return (%x)
192  // }
193  Graph graph;
194  std::unordered_map<std::string, Value*> vmap;
195  convertNetDefToIR(net, &graph, &vmap);
196  AT_ASSERT(graph.inputs().size() == 3);
197  AT_ASSERT(graph.inputs()[0]->uniqueName() == "a");
198  AT_ASSERT(graph.inputs()[1]->uniqueName() == "b");
199  AT_ASSERT(graph.inputs()[2]->uniqueName() == "c");
200 
201  AT_ASSERT(graph.outputs().size() == 1);
202  AT_ASSERT(graph.outputs()[0]->uniqueName() == "x");
203 
204  Node* quux = graph.outputs()[0]->node();
205  Value* a0 = quux->inputs()[0];
206  Value* x0 = quux->inputs()[1];
207  Value* u = quux->inputs()[2];
208  AT_ASSERT(a0->uniqueName() != "a" && a0->uniqueNameBase() == "a");
209  AT_ASSERT(x0->uniqueName() != "x" && x0->uniqueNameBase() == "x");
210  AT_ASSERT(u->uniqueName() == "u");
211 
212  // Convert back to netdef and check if the names are preserved.
213  // We still expect them to be in SSA form, but we should preserve names for
214  // external inputs and outputs.
215  caffe2::NetDef net2;
216  convertIRToNetDef(&net2, graph);
217  AT_ASSERT(net2.external_input().Get(0) == "a");
218  AT_ASSERT(net2.external_input().Get(1) == "b");
219  AT_ASSERT(net2.external_input().Get(2) == "c");
220  AT_ASSERT(net2.external_output().Get(0) == "x");
221  }
222 
223  {
224  // Test that prefix is removed when converting from NetDef to IR and back.
225  caffe2::NetDef net;
226  *net.add_op() = createOperator("MatMul", {"a", "b"}, {"c"});
227  net.add_external_input("a");
228  net.add_external_input("b");
229  net.add_external_output("c");
230  Graph graph;
231  std::unordered_map<std::string, Value*> vmap;
232  convertNetDefToIR(net, &graph, &vmap, "caffe2::");
233  // Sanity check that value map is returned and it works.
234  AT_ASSERT(vmap["a"]->uniqueName() == "a");
235 
236  caffe2::NetDef net2;
237  convertIRToNetDef(&net2, graph, "caffe2::");
238  // The conversion should remove the prefix if it maches.
239  AT_ASSERT(net2.op(0).type() == "MatMul");
240 
241  caffe2::NetDef net3;
242  convertIRToNetDef(&net3, graph, "foo::");
243  // The conversion should still work if the prefix does not match.
244  AT_ASSERT(net3.op(0).type() == "caffe2::MatMul");
245 
246  // Prefix shouldn't affect blob names.
247  AT_ASSERT(net2.op(0).input(0) == "a");
248  AT_ASSERT(net2.external_input(0) == "a");
249  AT_ASSERT(net2.external_output(0) == "c");
250  AT_ASSERT(net3.external_input(0) == "a");
251 
252  Graph graph2;
253  // Test that conversion works without passing in a valueMap.
254  convertNetDefToIR(net, &graph2, nullptr, "caffe2::");
255  }
256 }
257 
258 } // namespace jit
259 } // namespace torch
Definition: jit_type.h:17