// NOTE(review): this chunk is an extraction artifact. Stray source line
// numbers ("1", "8", "9", ...) are fused into the code, and original lines are
// missing (the numbering jumps from 1 to 8 on the first line, and everything
// after 26 is absent). It does not compile as shown; restore it from the
// upstream file rather than editing it here.
//
// Fragment of the caffe2::Argument -> AttributeKind classifier (presumably the
// body of getArgKind, which convertArg calls). It tests the argument's fields
// in a fixed order and returns the first kind that matches:
//   - scalar fields first: has_i -> i, has_f -> f, has_s -> s, has_t -> t,
//     has_n -> g (a nested net is reported as a graph attribute);
//   - then repeated fields, checked only for non-emptiness:
//     ints -> is, floats -> fs, strings -> ss, tensors -> ts, nets -> gs.
// An argument with no scalar set and only empty repeated fields matches none
// of the visible branches. What happens in that case is on the missing lines.
1 #include <torch/csrc/jit/netdef_converter.h> 8 return AttributeKind::i;
9 }
else if (arg.has_f()) {
10 return AttributeKind::f;
11 }
else if (arg.has_s()) {
12 return AttributeKind::s;
13 }
else if (arg.has_t()) {
14 return AttributeKind::t;
15 }
else if (arg.has_n()) {
// A single nested net maps to the graph kind `g`.
16 return AttributeKind::g;
17 }
// Repeated fields: the first non-empty list decides the list kind.
else if (arg.ints().size()) {
18 return AttributeKind::is;
19 }
else if (arg.floats().size()) {
20 return AttributeKind::fs;
21 }
else if (arg.strings().size()) {
22 return AttributeKind::ss;
23 }
else if (arg.tensors().size()) {
24 return AttributeKind::ts;
25 }
else if (arg.nets().size()) {
26 return AttributeKind::gs;
// NOTE(review): extraction artifact. Line numbers are embedded, and the
// function header, the switch header, the break statements, closing braces
// and lines 62-64/66/68+ are missing. This fragment does not compile as shown.
//
// Fragment of convertArg: copies one caffe2::Argument onto a JIT node as an
// attribute named "attr::<argument name>". It dispatches on the kind reported
// by getArgKind.
33 std::string attrName =
"attr::" + arg.name();
34 auto attrSymbol = Symbol::fromQualString(attrName);
35 AttributeKind kind = getArgKind(arg);
// Scalar kinds are stored directly; the integer is cast to int64_t.
37 case AttributeKind::i: {
38 node->i_(attrSymbol, (int64_t)arg.i());
41 case AttributeKind::f: {
42 node->f_(attrSymbol, arg.f());
45 case AttributeKind::s: {
46 node->s_(attrSymbol, arg.s());
// List kinds: the repeated protobuf fields are copied into std::vectors
// before being stored on the node.
49 case AttributeKind::is: {
50 std::vector<int64_t> is(arg.ints().begin(), arg.ints().end());
51 node->is_(attrSymbol, is);
54 case AttributeKind::fs: {
55 std::vector<double> fs(arg.floats().begin(), arg.floats().end());
56 node->fs_(attrSymbol, fs);
59 case AttributeKind::ss: {
60 std::vector<std::string> ss(arg.strings().begin(), arg.strings().end());
61 node->ss_(attrSymbol, ss);
// Fallback, presumably the `default:` case whose label is on the missing
// lines. getArgKind can also report t, g, ts and gs, and none of those has a
// visible case above.
// NOTE(review): this only writes a message to std::cout. Lines 66 and 68+ are
// missing, so it cannot be confirmed whether an error is raised afterwards.
// If nothing is raised, the attribute is silently dropped.
65 std::cout <<
"Unsupported type '" << toString(kind) <<
"' of attribute '" 67 <<
" in node:" << std::endl;
// NOTE(review): extraction artifact. Line numbers are embedded, and lines 76,
// 79, 81-82, 84, 87-88, 93-94, 97, 99-100, 104-105, 107-108, 112, 115-117,
// 122-127 and the closing braces are missing. It does not compile as shown.
//
// Fragment of convertNetDefToIR: builds JIT IR from a caffe2::NetDef.
//   net         - the NetDef to convert.
//   valueMapPtr - map from blob name to IR Value*. It is read to resolve op
//                 inputs and updated in place as blobs are defined.
//   prefix      - prepended to each op type to form the node's qualified kind.
// Line 76 is missing; it presumably declares the graph `g` used below.
74 void convertNetDefToIR(
75 const caffe2::NetDef& net,
77 std::unordered_map<std::string, Value*>* valueMapPtr,
78 const std::string& prefix) {
// If the caller supplies no map (presumably a null check on missing line 79),
// recurse with a function-local map.
80 std::unordered_map<std::string, Value*> localValueMap;
83 return convertNetDefToIR(net, g, &localValueMap, prefix);
85 std::unordered_map<std::string, Value*>& valueMap = *valueMapPtr;
// Reverse map Value* -> blob name, used at the end to name the IR values.
86 std::unordered_map<Value*, std::string> namesMap;
// Each external input becomes a new graph input. Each name may be declared
// only once.
89 for (
const auto& inputName : net.external_input()) {
90 AT_ASSERT(!valueMap.count(inputName));
91 valueMap[inputName] = g->addInput();
92 namesMap[valueMap.at(inputName)] = inputName;
// One node per operator. Its kind is prefix + op type, and it gets one output
// per output blob (the `node` declaration is on missing line 97).
95 for (
const auto& op : net.op()) {
96 std::string name = prefix + op.type();
98 g->create(Symbol::fromQualString(name), {}, op.output().size());
// Every input blob must already be defined, either as an external input or
// by an earlier op.
101 for (
const auto& input : op.input()) {
102 AT_ASSERT(valueMap.count(input));
103 node->addInput(valueMap[input]);
// Bind output blobs to node outputs in order (`idx` is declared on a missing
// line). Rebinding a blob name overwrites its valueMap entry, so later ops
// read the most recent producer.
106 for (
const auto& output : op.output()) {
109 Value* v = node->outputs()[idx++];
110 valueMap[output] = v;
111 namesMap[v] = output;
113 for (
const auto& arg : op.arg()) {
114 convertArg(arg, node);
// External outputs become graph outputs; each blob must have been defined.
118 for (
const auto& outputName : net.external_output()) {
119 AT_ASSERT(valueMap.count(outputName));
120 g->registerOutput(valueMap.at(outputName));
121 namesMap[valueMap.at(outputName)] = outputName;
// Name IR values after their blobs. Every node output, graph input and graph
// output must have a recorded name. setUniqueName is called only when
// Value::isValidName accepts that name.
128 for (Node* n : g->nodes()) {
129 for (Value* v : n->outputs()) {
130 AT_ASSERT(namesMap.count(v));
131 const std::string& name = namesMap.at(v);
132 if (Value::isValidName(name)) {
133 v->setUniqueName(name);
137 for (Value* v : g->inputs()) {
138 AT_ASSERT(namesMap.count(v));
139 const std::string& name = namesMap.at(v);
140 if (Value::isValidName(name)) {
141 v->setUniqueName(name);
144 for (Value* v : g->outputs()) {
145 AT_ASSERT(namesMap.count(v));
146 const std::string& name = namesMap.at(v);
147 if (Value::isValidName(name)) {
148 v->setUniqueName(name);
// NOTE(review): extraction artifact. The parameter list (lines 154-156), the
// loop bodies, break statements, closing braces and the fallback case label
// are missing. It does not compile as shown.
//
// Fragment of convertAttrToCaffe2Arg, the reverse direction of convertArg.
// It writes attribute `name` of `node` into the caffe2::Argument `arg`. The
// argument is named with the unqualified symbol string, presumably dropping
// the "attr::" namespace that convertArg adds.
153 static void convertAttrToCaffe2Arg(
157 arg->set_name(name.toUnqualString());
158 switch (node->kindOf(name)) {
// Scalar kinds map onto the argument's scalar fields.
159 case AttributeKind::i: {
160 arg->set_i(node->i(name));
163 case AttributeKind::f: {
164 arg->set_f(node->f(name));
167 case AttributeKind::s: {
168 arg->set_s(node->s(name));
// List kinds iterate over the node's list. The loop bodies (presumably
// appending to the argument's repeated field) are on missing lines.
171 case AttributeKind::is: {
172 for (int64_t i : node->is(name)) {
177 case AttributeKind::fs: {
178 for (
double f : node->fs(name)) {
183 case AttributeKind::ss: {
184 for (
const std::string& s : node->ss(name)) {
// Fallback for kinds with no visible case (e.g. t, g, ts, gs).
// NOTE(review): this only writes to std::cout. Whether an error follows is
// on missing lines 193+; if nothing follows, the attribute is silently
// omitted from the Argument.
190 std::cout <<
"Unsupported type '" << toString(node->kindOf(name))
191 <<
"' of attribute '" << name.toUnqualString() <<
"'" 192 <<
" in node:" << std::endl;
// Strips `prefix` from the front of `name` when `name` starts with it.
// std::string::compare(0, prefix.size(), prefix) returns 0 exactly when the
// first prefix.size() characters of `name` equal `prefix`. A shorter `name`
// compares unequal. Hence the negation. An empty prefix always matches.
// NOTE(review): the fall-through path (lines 203+) is missing from this
// fragment. It presumably returns `name` unchanged; confirm upstream.
199 static std::string removePrefixIfNeeded(
const std::string& name,
200 const std::string& prefix) {
201 if (!name.compare(0, prefix.size(), prefix)) {
202 return name.substr(prefix.size());
// NOTE(review): extraction artifact. Lines 214, 217, 220 and 222+ and the
// closing braces are missing. It does not compile as shown.
//
// Fragment of convertNodeToCaffe2Op: builds a caffe2::OperatorDef from one IR
// node.
//   node   - the node to convert.
//   net    - the destination NetDef. The code that stores `op` into it is on
//            the missing lines 222+.
//   prefix - stripped from the node's qualified kind to recover the op type.
//            Defaults to "", which strips nothing.
// Input and output blob names are the values' uniqueName()s.
208 static void convertNodeToCaffe2Op(
const Node* node, caffe2::NetDef* net,
209 const std::string& prefix =
"") {
210 caffe2::OperatorDef op;
211 op.set_type(removePrefixIfNeeded(node->kind().toQualString(), prefix));
212 for (
const Value* input : node->inputs()) {
213 op.add_input(input->uniqueName());
215 for (
const Value* output : node->outputs()) {
216 op.add_output(output->uniqueName());
// One Argument per node attribute. `arg` is created on missing line 220,
// presumably via op.add_arg(); confirm upstream.
218 std::vector<Symbol> names = node->attributeNames();
219 for (
const Symbol& name : names) {
221 convertAttrToCaffe2Arg(node, name, arg);
// NOTE(review): extraction artifact. Lines 229, 232-233, 236-237, 240+ and the
// closing braces are missing. It does not compile as shown.
//
// Fragment of convertIRToNetDef: serializes graph `g` into `*net`.
// Graph inputs and outputs become external inputs and outputs, named by
// uniqueName(). Each node, in g.nodes() order, becomes one op via
// convertNodeToCaffe2Op using the same `prefix`.
// NOTE(review): only the op list is visibly cleared, and line 229 is missing.
// If external_input/external_output are not also cleared there, calling this
// twice on the same NetDef appends duplicate entries. Confirm upstream.
226 void convertIRToNetDef(caffe2::NetDef* net,
const Graph& g,
227 const std::string& prefix) {
228 net->mutable_op()->Clear();
230 for (
const Value* value : g.inputs()) {
231 net->add_external_input(value->uniqueName());
234 for (
const Node* node : g.nodes()) {
235 convertNodeToCaffe2Op(node, net, prefix);
238 for (
const Value* value : g.outputs()) {
239 net->add_external_output(value->uniqueName());