// NOTE(review): the bare leading numbers on these lines look like line numbers
// carried over from the original file; the gaps between them (e.g. 17, 19-20,
// 23-25) suggest that statements and closing braces are missing from this
// copy. Confirm against the upstream file before relying on it.
3 #include <torch/csrc/jit/netdef_converter.h> 4 #include "test/cpp/jit/test_base.h" 12 static caffe2::OperatorDef createOperator(
// Test helper that assembles a caffe2::OperatorDef from plain strings.
//   name    - presumably the operator type; its use is not visible on the
//             lines present here (TODO confirm).
//   inputs  - input blob names; the loop over them has no visible body here.
//   outputs - output blob names, appended in order via add_output().
// The return statement is not visible in this copy; the declared return type
// is caffe2::OperatorDef.
13 const std::string& name,
14 const std::vector<std::string>& inputs,
15 const std::vector<std::string>& outputs) {
16 caffe2::OperatorDef op;
18 for (
const auto& input : inputs) {
21 for (
const auto& output : outputs) {
22 op.add_output(output);
// Exercises the IR <-> NetDef converter in both directions.
// Scenario 1: a two-input graph computing c = mul(a, b) and d = add(a, c),
// with both c and d registered as graph outputs.
27 void testNetDefConverter() {
39 auto graph = std::make_shared<Graph>();
40 auto a = graph->addInput();
41 auto b = graph->addInput();
42 auto c = graph->insert(aten::mul, {a, b});
43 auto d = graph->insert(aten::add, {a, c});
44 graph->registerOutput(c);
45 graph->registerOutput(d);
// IR -> NetDef. NOTE(review): the declaration of `net` is not visible in this
// copy. Three operators are expected even though only two were inserted above;
// the extra one is the constant checked below.
49 convertIRToNetDef(&net, *graph);
50 AT_ASSERT(net.op().size() == 3);
51 AT_ASSERT(net.external_input().size() == 2);
52 AT_ASSERT(net.external_output().size() == 2);
// Operator 0 is the multiply: it reads both external inputs, in order, and
// produces one output.
54 const caffe2::OperatorDef& MulOp = net.op().Get(0);
55 AT_ASSERT(MulOp.input().size() == 2);
56 AT_ASSERT(MulOp.input().Get(0) == net.external_input().Get(0));
57 AT_ASSERT(MulOp.input().Get(1) == net.external_input().Get(1));
58 AT_ASSERT(MulOp.output().size() == 1);
// Operator 1 is a constant with no inputs and a single integer argument
// "value" == 1. Presumably this is the implicit scalar operand that
// graph->insert adds for aten::add -- TODO confirm.
60 const caffe2::OperatorDef& ConstNode = net.op().Get(1);
61 AT_ASSERT(ConstNode.input().size() == 0);
62 AT_ASSERT(ConstNode.output().size() == 1);
63 AT_ASSERT(ConstNode.arg().size() == 1);
64 AT_ASSERT(ConstNode.arg().Get(0).name() ==
"value");
65 AT_ASSERT(ConstNode.arg().Get(0).i() == 1);
// Operator 2 is the add: it reads the first external input, the multiply
// result and the constant.
67 const caffe2::OperatorDef& AddOp = net.op().Get(2);
68 AT_ASSERT(AddOp.input().size() == 3);
69 AT_ASSERT(AddOp.input().Get(0) == net.external_input().Get(0));
70 AT_ASSERT(AddOp.input().Get(1) == MulOp.output().Get(0));
71 AT_ASSERT(AddOp.input().Get(2) == ConstNode.output().Get(0));
// The external outputs follow the order in which c and d were registered.
73 AT_ASSERT(net.external_output().Get(0) == MulOp.output().Get(0));
74 AT_ASSERT(net.external_output().Get(1) == AddOp.output().Get(0));
// NetDef -> IR. NOTE(review): the declaration of `graph2` is not visible in
// this copy. The rebuilt graph must have the same node kinds and the same
// wiring as the original.
78 std::unordered_map<std::string, Value*> vmap;
79 convertNetDefToIR(net, &graph2, &vmap);
81 Node* mul = graph2.outputs()[0]->node();
82 Node* add = graph2.outputs()[1]->node();
83 AT_ASSERT(mul->kind() == c->node()->kind());
84 AT_ASSERT(add->kind() == d->node()->kind());
85 AT_ASSERT(mul->inputs()[0] == graph2.inputs()[0]);
86 AT_ASSERT(mul->inputs()[1] == graph2.inputs()[1]);
87 AT_ASSERT(add->inputs()[0] == graph2.inputs()[0]);
88 AT_ASSERT(add->inputs()[1] == graph2.outputs()[0]);
// Scenario 2 (forward direction): a custom two-output node carrying one
// attribute of every supported kind is exported to a NetDef.
92 auto graph = std::make_shared<Graph>();
93 auto a = graph->addInput();
94 auto b = graph->addInput();
// NOTE(review): the result of create() is not bound on any visible line, yet
// `node` is used below -- its declaration appears to be missing from this copy.
96 graph->create(Symbol::fromQualString(
"test::some_op"), {a, b}, 2);
97 graph->insertNode(node);
// Scalar attributes: int, float, string.
99 node->i_(Symbol::fromQualString(
"attr::i_attr"), 42);
100 node->f_(Symbol::fromQualString(
"attr::f_attr"), 3.0);
101 node->s_(Symbol::fromQualString(
"attr::s_attr"),
"Hello!");
// List attributes: ints, floats, strings.
103 node->is_(Symbol::fromQualString(
"attr::is_attr"), {14, 18, 7});
104 node->fs_(Symbol::fromQualString(
"attr::fs_attr"), {2.72, 3.14});
105 node->ss_(Symbol::fromQualString(
"attr::ss_attr"), {
"Winter",
"Summer"});
107 graph->registerOutput(node->outputs()[0]);
108 graph->registerOutput(node->outputs()[1]);
// IR -> NetDef. NOTE(review): the declaration of `net` is not visible in this
// copy. The assertions expect the arguments in the same order the attributes
// were set above, with the "attr::" qualifier dropped from the names.
112 convertIRToNetDef(&net, *graph);
113 const caffe2::OperatorDef& Op = net.op().Get(0);
114 AT_ASSERT(Op.arg().Get(0).name() ==
"i_attr");
115 AT_ASSERT(Op.arg().Get(0).i() == 42);
116 AT_ASSERT(Op.arg().Get(1).name() ==
"f_attr");
117 AT_ASSERT(Op.arg().Get(1).f() == 3.0);
118 AT_ASSERT(Op.arg().Get(2).name() ==
"s_attr");
119 AT_ASSERT(Op.arg().Get(2).s() ==
"Hello!");
121 AT_ASSERT(Op.arg().Get(3).name() ==
"is_attr");
122 AT_ASSERT(Op.arg().Get(3).ints().size() == 3);
123 AT_ASSERT(Op.arg().Get(3).ints().Get(0) == 14);
124 AT_ASSERT(Op.arg().Get(3).ints().Get(1) == 18);
125 AT_ASSERT(Op.arg().Get(3).ints().Get(2) == 7);
// Float list values are compared with a tolerance rather than exactly; only
// the first element is checked on the lines visible here.
127 AT_ASSERT(Op.arg().Get(4).name() ==
"fs_attr");
128 AT_ASSERT(Op.arg().Get(4).floats().size() == 2);
129 AT_ASSERT(fabs(Op.arg().Get(4).floats().Get(0) - 2.72) < 0.001);
// For the string list, the size and the second element are checked.
131 AT_ASSERT(Op.arg().Get(5).name() ==
"ss_attr");
132 AT_ASSERT(Op.arg().Get(5).strings().size() == 2);
133 AT_ASSERT(Op.arg().Get(5).strings().Get(1) ==
"Summer");
// Both node outputs must be exported as external outputs, in order.
135 AT_ASSERT(net.external_output().Get(0) == Op.output().Get(0));
136 AT_ASSERT(net.external_output().Get(1) == Op.output().Get(1));
140 std::unordered_map<std::string, Value*> vmap;
141 convertNetDefToIR(net, &graph2, &vmap);
143 AT_ASSERT(graph2.outputs()[0]->node() == graph2.outputs()[0]->node());
144 Node* n = graph2.outputs()[0]->node();
145 AT_ASSERT(n->i(Symbol::fromQualString(
"attr::i_attr")) == 42);
146 AT_ASSERT(n->f(Symbol::fromQualString(
"attr::f_attr")) == 3.0);
147 AT_ASSERT(n->s(Symbol::fromQualString(
"attr::s_attr")) ==
"Hello!");
149 n->is(Symbol::fromQualString(
"attr::is_attr")) ==
150 std::vector<int64_t>({14, 18, 7}));
152 fabs(n->fs(Symbol::fromQualString(
"attr::fs_attr"))[0] - 2.72) < 0.001);
154 fabs(n->fs(Symbol::fromQualString(
"attr::fs_attr"))[1] - 3.14) < 0.001);
156 n->ss(Symbol::fromQualString(
"attr::ss_attr")) ==
157 std::vector<std::string>({
"Winter",
"Summer"}));
// Scenario 3: a hand-built net in which blob names are reused. "a" is both
// read and written by foo::bar, and "x" is written by foo::qux and again by
// foo::quux. NOTE(review): the declarations of `net`, `graph` and `net2` are
// not visible in this copy.
174 *net.add_op() = createOperator(
"foo::bar", {
"a",
"b"}, {
"a"});
175 *net.add_op() = createOperator(
"foo::baz", {
"b",
"c"}, {
"u"});
176 *net.add_op() = createOperator(
"foo::qux", {
"u",
"a"}, {
"x"});
177 *net.add_op() = createOperator(
"foo::quux", {
"a",
"x",
"u"}, {
"x"});
178 net.add_external_input(
"a");
179 net.add_external_input(
"b");
180 net.add_external_input(
"c");
181 net.add_external_output(
"x");
// NetDef -> IR: graph inputs and the graph output keep the plain blob names.
194 std::unordered_map<std::string, Value*> vmap;
195 convertNetDefToIR(net, &graph, &vmap);
196 AT_ASSERT(graph.inputs().size() == 3);
197 AT_ASSERT(graph.inputs()[0]->uniqueName() ==
"a");
198 AT_ASSERT(graph.inputs()[1]->uniqueName() ==
"b");
199 AT_ASSERT(graph.inputs()[2]->uniqueName() ==
"c");
201 AT_ASSERT(graph.outputs().size() == 1);
202 AT_ASSERT(graph.outputs()[0]->uniqueName() ==
"x");
// The node producing the graph output is the last operator (foo::quux). The
// "a" and "x" it reads are the intermediate redefinitions, so they are
// expected under a different unique name that still has the original base
// name. "u" was written only once and keeps its name.
204 Node* quux = graph.outputs()[0]->node();
205 Value* a0 = quux->inputs()[0];
206 Value* x0 = quux->inputs()[1];
207 Value* u = quux->inputs()[2];
208 AT_ASSERT(a0->uniqueName() !=
"a" && a0->uniqueNameBase() ==
"a");
209 AT_ASSERT(x0->uniqueName() !=
"x" && x0->uniqueNameBase() ==
"x");
210 AT_ASSERT(u->uniqueName() ==
"u");
// IR -> NetDef: the external input/output names must come back unchanged.
216 convertIRToNetDef(&net2, graph);
217 AT_ASSERT(net2.external_input().Get(0) ==
"a");
218 AT_ASSERT(net2.external_input().Get(1) ==
"b");
219 AT_ASSERT(net2.external_input().Get(2) ==
"c");
220 AT_ASSERT(net2.external_output().Get(0) ==
"x");
// Scenario 4: the optional operator-name prefix argument of the converters.
// NOTE(review): the declarations of `net`, `graph`, `net2`, `net3` and
// `graph2` are not visible in this copy.
226 *net.add_op() = createOperator(
"MatMul", {
"a",
"b"}, {
"c"});
227 net.add_external_input(
"a");
228 net.add_external_input(
"b");
229 net.add_external_output(
"c");
// NetDef -> IR with the prefix "caffe2::"; the value map must expose blob "a"
// under its own name.
231 std::unordered_map<std::string, Value*> vmap;
232 convertNetDefToIR(net, &graph, &vmap,
"caffe2::");
234 AT_ASSERT(vmap[
"a"]->uniqueName() ==
"a");
// Converting back with the same prefix is expected to yield the bare
// operator type again.
237 convertIRToNetDef(&net2, graph,
"caffe2::");
239 AT_ASSERT(net2.op(0).type() ==
"MatMul");
// Converting back with a different prefix is expected to leave the fully
// qualified name in place.
242 convertIRToNetDef(&net3, graph,
"foo::");
244 AT_ASSERT(net3.op(0).type() ==
"caffe2::MatMul");
// Blob names are expected to be unaffected by the prefix in either case.
247 AT_ASSERT(net2.op(0).input(0) ==
"a");
248 AT_ASSERT(net2.external_input(0) ==
"a");
249 AT_ASSERT(net2.external_output(0) ==
"c");
250 AT_ASSERT(net3.external_input(0) ==
"a");
// Conversion is also called with a null value map; no result is checked on
// the lines visible here.
254 convertNetDefToIR(net, &graph2,
nullptr,
"caffe2::");