Caffe2 - C++ API
A deep learning, cross-platform ML framework
netdef_converter.cpp
1 #include <torch/csrc/jit/netdef_converter.h>
2 
3 namespace torch {
4 namespace jit {
5 
6 static AttributeKind getArgKind(const caffe2::Argument& arg) {
7  if (arg.has_i()) {
8  return AttributeKind::i;
9  } else if (arg.has_f()) {
10  return AttributeKind::f;
11  } else if (arg.has_s()) {
12  return AttributeKind::s;
13  } else if (arg.has_t()) {
14  return AttributeKind::t;
15  } else if (arg.has_n()) {
16  return AttributeKind::g;
17  } else if (arg.ints().size()) {
18  return AttributeKind::is;
19  } else if (arg.floats().size()) {
20  return AttributeKind::fs;
21  } else if (arg.strings().size()) {
22  return AttributeKind::ss;
23  } else if (arg.tensors().size()) {
24  return AttributeKind::ts;
25  } else if (arg.nets().size()) {
26  return AttributeKind::gs;
27  }
28  // Unknown type.
29  abort();
30 }
31 
32 static void convertArg(const caffe2::Argument& arg, Node* node) {
33  std::string attrName = "attr::" + arg.name();
34  auto attrSymbol = Symbol::fromQualString(attrName);
35  AttributeKind kind = getArgKind(arg);
36  switch (kind) {
37  case AttributeKind::i: {
38  node->i_(attrSymbol, (int64_t)arg.i());
39  break;
40  }
41  case AttributeKind::f: {
42  node->f_(attrSymbol, arg.f());
43  break;
44  }
45  case AttributeKind::s: {
46  node->s_(attrSymbol, arg.s());
47  break;
48  }
49  case AttributeKind::is: {
50  std::vector<int64_t> is(arg.ints().begin(), arg.ints().end());
51  node->is_(attrSymbol, is);
52  break;
53  }
54  case AttributeKind::fs: {
55  std::vector<double> fs(arg.floats().begin(), arg.floats().end());
56  node->fs_(attrSymbol, fs);
57  break;
58  }
59  case AttributeKind::ss: {
60  std::vector<std::string> ss(arg.strings().begin(), arg.strings().end());
61  node->ss_(attrSymbol, ss);
62  break;
63  }
64  default: {
65  std::cout << "Unsupported type '" << toString(kind) << "' of attribute '"
66  << attrName << "'"
67  << " in node:" << std::endl;
68  node->dump();
69  abort();
70  }
71  }
72 }
73 
74 void convertNetDefToIR(
75  const caffe2::NetDef& net,
76  Graph* g,
77  std::unordered_map<std::string, Value*>* valueMapPtr,
78  const std::string& prefix) {
79  if (!valueMapPtr) {
80  std::unordered_map<std::string, Value*> localValueMap;
81  // If valueMapPtr is null, we just use a local map since we don't need
82  // to return the valueMap to the caller.
83  return convertNetDefToIR(net, g, &localValueMap, prefix);
84  }
85  std::unordered_map<std::string, Value*>& valueMap = *valueMapPtr;
86  std::unordered_map<Value*, std::string> namesMap;
87  valueMap.clear();
88 
89  for (const auto& inputName : net.external_input()) {
90  AT_ASSERT(!valueMap.count(inputName));
91  valueMap[inputName] = g->addInput();
92  namesMap[valueMap.at(inputName)] = inputName;
93  }
94 
95  for (const auto& op : net.op()) {
96  std::string name = prefix + op.type();
97  Node* node =
98  g->create(Symbol::fromQualString(name), {}, op.output().size());
99  g->insertNode(node);
100 
101  for (const auto& input : op.input()) {
102  AT_ASSERT(valueMap.count(input));
103  node->addInput(valueMap[input]);
104  }
105  int idx = 0;
106  for (const auto& output : op.output()) {
107  // If output already exists in valueMap, overwrite it. This way we will
108  // have the last definition of a value named 'output' in valueMap.
109  Value* v = node->outputs()[idx++];
110  valueMap[output] = v;
111  namesMap[v] = output;
112  }
113  for (const auto& arg : op.arg()) {
114  convertArg(arg, node);
115  }
116  }
117 
118  for (const auto& outputName : net.external_output()) {
119  AT_ASSERT(valueMap.count(outputName));
120  g->registerOutput(valueMap.at(outputName));
121  namesMap[valueMap.at(outputName)] = outputName;
122  }
123 
124  // Set proper unique names for all values.
125  // We will set the names for external inputs and outputs last, so that if the
126  // names are reused, then intermediate values will be renamed and the external
127  // values will keep the original names.
128  for (Node* n : g->nodes()) {
129  for (Value* v : n->outputs()) {
130  AT_ASSERT(namesMap.count(v));
131  const std::string& name = namesMap.at(v);
132  if (Value::isValidName(name)) {
133  v->setUniqueName(name);
134  }
135  }
136  }
137  for (Value* v : g->inputs()) {
138  AT_ASSERT(namesMap.count(v));
139  const std::string& name = namesMap.at(v);
140  if (Value::isValidName(name)) {
141  v->setUniqueName(name);
142  }
143  }
144  for (Value* v : g->outputs()) {
145  AT_ASSERT(namesMap.count(v));
146  const std::string& name = namesMap.at(v);
147  if (Value::isValidName(name)) {
148  v->setUniqueName(name);
149  }
150  }
151 }
152 
153 static void convertAttrToCaffe2Arg(
154  const Node* node,
155  const Symbol& name,
156  caffe2::Argument* arg) {
157  arg->set_name(name.toUnqualString());
158  switch (node->kindOf(name)) {
159  case AttributeKind::i: {
160  arg->set_i(node->i(name));
161  break;
162  }
163  case AttributeKind::f: {
164  arg->set_f(node->f(name));
165  break;
166  }
167  case AttributeKind::s: {
168  arg->set_s(node->s(name));
169  break;
170  }
171  case AttributeKind::is: {
172  for (int64_t i : node->is(name)) {
173  arg->add_ints(i);
174  }
175  break;
176  }
177  case AttributeKind::fs: {
178  for (double f : node->fs(name)) {
179  arg->add_floats(f);
180  }
181  break;
182  }
183  case AttributeKind::ss: {
184  for (const std::string& s : node->ss(name)) {
185  arg->add_strings(s);
186  }
187  break;
188  }
189  default: {
190  std::cout << "Unsupported type '" << toString(node->kindOf(name))
191  << "' of attribute '" << name.toUnqualString() << "'"
192  << " in node:" << std::endl;
193  node->dump();
194  abort();
195  }
196  }
197 }
198 
// Returns `name` with a leading `prefix` stripped.
// If `name` does not start with `prefix`, it is returned unchanged.
static std::string removePrefixIfNeeded(const std::string& name,
                                        const std::string& prefix) {
  // rfind anchored at position 0 can only match at the very beginning of
  // `name`, so this is a "starts with" test.
  const bool hasPrefix = name.rfind(prefix, 0) == 0;
  return hasPrefix ? name.substr(prefix.size()) : name;
}
207 
208 static void convertNodeToCaffe2Op(const Node* node, caffe2::NetDef* net,
209  const std::string& prefix = "") {
210  caffe2::OperatorDef op;
211  op.set_type(removePrefixIfNeeded(node->kind().toQualString(), prefix));
212  for (const Value* input : node->inputs()) {
213  op.add_input(input->uniqueName());
214  }
215  for (const Value* output : node->outputs()) {
216  op.add_output(output->uniqueName());
217  }
218  std::vector<Symbol> names = node->attributeNames();
219  for (const Symbol& name : names) {
220  caffe2::Argument* arg = op.add_arg();
221  convertAttrToCaffe2Arg(node, name, arg);
222  }
223  *net->add_op() = op;
224 }
225 
226 void convertIRToNetDef(caffe2::NetDef* net, const Graph& g,
227  const std::string& prefix) {
228  net->mutable_op()->Clear();
229 
230  for (const Value* value : g.inputs()) {
231  net->add_external_input(value->uniqueName());
232  }
233 
234  for (const Node* node : g.nodes()) {
235  convertNodeToCaffe2Op(node, net, prefix);
236  }
237 
238  for (const Value* value : g.outputs()) {
239  net->add_external_output(value->uniqueName());
240  }
241 }
242 
243 } // namespace jit
244 } // namespace torch
Definition: jit_type.h:17