Caffe2 - C++ API
A deep learning, cross platform ML framework
onnx_exporter.cc
1 #include "caffe2/onnx/onnx_exporter.h"
2 #include "caffe2/core/logging.h"
3 #include "caffe2/core/tensor_impl.h"
4 #include "caffe2/onnx/helper.h"
5 #include "caffe2/proto/caffe2_legacy.pb.h"
6 #include "caffe2/utils/map_utils.h"
7 #include "caffe2/utils/proto_utils.h"
8 
9 #include <numeric>
10 #include <unordered_set>
11 
12 namespace caffe2 {
13 namespace onnx {
14 
15 namespace {
16 // rewrite padding attributes
17 void ApplyTrans(
18  std::unordered_map<std::string, AttributeProto>* attrs,
19  bool global,
20  const std::string& k,
21  int dim = 2,
22  const std::string& ks = "") {
23  std::string ks2 = ks.empty() ? (k + "s") : ks;
24  std::string k_h, k_w, k_t, k_l, k_b, k_r;
25  if (dim == 2) {
26  k_h = k + "_h";
27  k_w = k + "_w";
28  } else {
29  k_t = k + "_t";
30  k_l = k + "_l";
31  k_b = k + "_b";
32  k_r = k + "_r";
33  }
34 
35  std::vector<int64_t> vals;
36  if (dim == 2 && attrs->count(k_h) && attrs->count(k_w)) {
37  auto it = attrs->find(k_h);
38  vals.push_back(it->second.i());
39  attrs->erase(it);
40  it = attrs->find(k_w);
41  vals.push_back(it->second.i());
42  attrs->erase(it);
43  } else if (
44  dim == 4 && attrs->count(k_t) && attrs->count(k_b) && attrs->count(k_l) &&
45  attrs->count(k_r)) {
46  auto it = attrs->find(k_t);
47  vals.push_back(it->second.i());
48  attrs->erase(it);
49  it = attrs->find(k_l);
50  vals.push_back(it->second.i());
51  attrs->erase(it);
52  it = attrs->find(k_b);
53  vals.push_back(it->second.i());
54  attrs->erase(it);
55  it = attrs->find(k_r);
56  vals.push_back(it->second.i());
57  attrs->erase(it);
58  } else if (attrs->count(k)) {
59  auto it = attrs->find(k);
60  auto tmp = it->second.i();
61  for (int i = 0; i < dim; ++i) {
62  vals.push_back(tmp);
63  }
64  attrs->erase(it);
65  }
66 
67  if (!vals.empty() && !global) {
68  attrs->emplace(ks2, MakeAttribute(ks2, vals));
69  }
70 }
71 
72 int64_t DimProd(const caffe2::TensorShape& shape, int start, int end) {
73  int64_t acc = 1;
74  for (int i = start; i < end; ++i) {
75  acc *= shape.dims(i);
76  }
77  return acc;
78 }
79 
80 TensorProto CreateOnnxShapeTensor(
81  std::shared_ptr<DummyName> dummy,
82  const std::vector<int64_t>& shape) {
83  TensorProto tensor;
84  tensor.set_name(dummy->NewDummyName());
85  tensor.set_data_type(TensorProto::INT64);
86  tensor.add_dims(shape.size());
87  tensor.mutable_raw_data()->assign(
88  reinterpret_cast<const char*>(shape.data()),
89  sizeof(int64_t) * shape.size());
90  return tensor;
91 }
92 
// Versioned SSA blob name: "<n>_<version>".
std::string SsaName(const std::string& n, int version) {
  return n + "_" + std::to_string(version);
}
96 
97 NodeProto AddShapeNode(const std::string& input, const std::string& output) {
98  NodeProto shape_node;
99  shape_node.set_op_type("Shape");
100  shape_node.add_input(input);
101  shape_node.add_output(output);
102  return shape_node;
103 }
104 
105 } // namespace
106 
107 ::ONNX_NAMESPACE::TensorProto::DataType Caffe2TypeToOnnxType(
108  caffe2::TensorProto::DataType t) {
109 #define CAFFE2_TO_ONNX_TYPE(x) \
110  case (caffe2::TensorProto::x): \
111  return ::ONNX_NAMESPACE::TensorProto::x
112  switch (t) {
113  CAFFE2_TO_ONNX_TYPE(FLOAT);
114  CAFFE2_TO_ONNX_TYPE(BOOL);
115  CAFFE2_TO_ONNX_TYPE(INT8);
116  CAFFE2_TO_ONNX_TYPE(UINT8);
117  CAFFE2_TO_ONNX_TYPE(UINT16);
118  CAFFE2_TO_ONNX_TYPE(INT16);
119  CAFFE2_TO_ONNX_TYPE(INT32);
120  CAFFE2_TO_ONNX_TYPE(INT64);
121  CAFFE2_TO_ONNX_TYPE(FLOAT16);
122  default:
123  LOG(WARNING) << "Unsupported Caffe2 tensor type: " << t
124  << ", fallback to FLOAT";
125  return ::ONNX_NAMESPACE::TensorProto::FLOAT;
126  }
127 #undef CAFFE2_TO_ONNX_TYPE
128 }
129 
130 std::unordered_map<std::string, std::string> SsaRewrite(
131  caffe2::NetDef* init_net,
132  caffe2::NetDef* pred_net) {
133  std::unordered_map<std::string, std::string> input_mapping;
134  std::unordered_map<std::string, int> blob_versions;
135 
136  if (init_net) {
137  // No ssa rewrite is done for init net. The reason being that the output
138  // blobs of init net are what becomes the input blobs of pred_net. Since
139  // inputs of pred_net are not renamed we are not renaming the output of
140  // init_net. Furthermore, the assumption made is that init_net is simple net
141  // with each operator producing the one output and thus not renaming
142  // translates to not renaming the outputs of the init_net. Create identical
143  // mapping for now. This shall be removed eventually.
144  for (const auto& name : init_net->external_input()) {
145  input_mapping.emplace(name, name);
146  }
147  blob_versions.clear();
148  }
149 
150  if (pred_net) {
151  std::unordered_set<std::string> external_outputs;
152  for (const auto& input : pred_net->external_input()) {
153  // Create identical mapping for now. This shall be removed eventually.
154  input_mapping.emplace(input, input);
155  }
156  for (const auto& output : pred_net->external_output()) {
157  external_outputs.emplace(output);
158  }
159  for (auto& op : *pred_net->mutable_op()) {
160  for (auto& input : *op.mutable_input()) {
161  const auto it = blob_versions.find(input);
162  if (it != blob_versions.end()) {
163  input = SsaName(input, it->second);
164  } else {
165  // Input blob is not versioned yet.
166  // If it is not versioned yet, it is assumed to be primary input,
167  // Thus skip renaming it.
168  continue;
169  }
170  }
171  for (auto& output : *op.mutable_output()) {
172  auto it = blob_versions.find(output);
173  if (it != blob_versions.end()) {
174  it->second += 1;
175  output = SsaName(output, it->second);
176  } else {
177  blob_versions.emplace(output, 0);
178  output = SsaName(output, 0);
179  }
180  }
181  }
182 
183  // For all the renamed blobs find if the blob is one of the external
184  // output. If so add a mapping from it's latest renamed version to its
185  // original name.
186  std::unordered_map<std::string, std::string> renamed_external_outputs;
187  for (const auto it : blob_versions) {
188  if (external_outputs.count(it.first)) {
189  renamed_external_outputs.emplace(
190  SsaName(it.first, it.second), it.first);
191  }
192  }
193 
194  // Use the mapping to find if the input or output of an op was a renamed
195  // external output. If so replace it with its original name.
196  for (auto& op : *pred_net->mutable_op()) {
197  for (auto& input : *op.mutable_input()) {
198  const auto it = renamed_external_outputs.find(input);
199  if (it != renamed_external_outputs.end()) {
200  input = it->second;
201  }
202  }
203  for (auto& output : *op.mutable_output()) {
204  const auto it = renamed_external_outputs.find(output);
205  if (it != renamed_external_outputs.end()) {
206  output = it->second;
207  }
208  }
209  }
210  }
211 
212  return input_mapping;
213 }
214 
215 const std::unordered_map<std::string, std::string>&
216 OnnxExporter::get_renamed_operators() const {
217  const static std::unordered_map<std::string, std::string> kRenamedOperators{
218  {"SpatialBN", "BatchNormalization"},
219  {"Conv1D", "Conv"},
220  {"Conv2D", "Conv"},
221  {"Conv3D", "Conv"},
222  {"ConvTranspose1D", "ConvTranspose"},
223  {"ConvTranspose2D", "ConvTranspose"},
224  {"ConvTranspose3D", "ConvTranspose"},
225  {"MaxPool1D", "MaxPool"},
226  {"MaxPool2D", "MaxPool"},
227  {"MaxPool3D", "MaxPool"},
228  {"AveragePool1D", "AveragePool"},
229  {"AveragePool2D", "AveragePool"},
230  {"AveragePool3D", "AveragePool"}};
231  return kRenamedOperators;
232 }
233 
234 const std::unordered_map<std::string, std::string>&
235 OnnxExporter::get_renamed_attrs() const {
236  const static std::unordered_map<std::string, std::string> kRenamedAttrs{
237  {"kernels", "kernel_shape"}};
238  return kRenamedAttrs;
239 }
240 
241 const std::
242  unordered_map<std::string, std::unordered_map<std::string, std::string>>&
243  OnnxExporter::get_per_op_renamed_attrs() const {
244  const static std::
245  unordered_map<std::string, std::unordered_map<std::string, std::string>>
246  kPerOpRenamedAttrs = {{"Squeeze", {{"dims", "axes"}}},
247  {"Unsqueeze", {{"dims", "axes"}}},
248  {"Transpose", {{"axes", "perm"}}},
249  {"ConvTranspose", {{"adjs", "output_padding"}}},
250  {"Selu", {{"scale", "gamma"}}}};
251 
252  return kPerOpRenamedAttrs;
253 }
254 
255 const std::unordered_map<std::string, OnnxExporter::SpecialOpConverter>&
256 OnnxExporter::get_special_operators() const {
257  const static std::unordered_map<std::string, OnnxExporter::SpecialOpConverter>
258  kSpecialOperators = {
259  {"ArgMax", &OnnxExporter::CreateArgMaxMinOpNodes},
260  {"ArgMin", &OnnxExporter::CreateArgMaxMinOpNodes},
261  {"Add", &OnnxExporter::CreateBinaryElementwiseOpNodes},
262  {"Sub", &OnnxExporter::CreateBinaryElementwiseOpNodes},
263  {"Mul", &OnnxExporter::CreateBinaryElementwiseOpNodes},
264  {"Div", &OnnxExporter::CreateBinaryElementwiseOpNodes},
265  {"Pow", &OnnxExporter::CreateBinaryElementwiseOpNodes},
266  {"And", &OnnxExporter::CreateBinaryElementwiseOpNodes},
267  {"Or", &OnnxExporter::CreateBinaryElementwiseOpNodes},
268  {"Xor", &OnnxExporter::CreateBinaryElementwiseOpNodes},
269  {"Equal", &OnnxExporter::CreateBinaryElementwiseOpNodes},
270  {"Greater", &OnnxExporter::CreateBinaryElementwiseOpNodes},
271  {"Less", &OnnxExporter::CreateBinaryElementwiseOpNodes},
272  {"Cast", &OnnxExporter::CreateCastNodes},
273  {"ElementwiseLinear", &OnnxExporter::CreateElementwiseLinearNodes},
274  {"Conv", &OnnxExporter::CreateConvPoolNodes},
275  {"ConvTranspose", &OnnxExporter::CreateConvPoolNodes},
276  {"MaxPool", &OnnxExporter::CreateConvPoolNodes},
277  {"AveragePool", &OnnxExporter::CreateConvPoolNodes},
278  {"FC", &OnnxExporter::CreateGemmNodes},
279  {"Concat", &OnnxExporter::CreateConcatNodes},
280  {"MergeDim", &OnnxExporter::CreateMergeDimNodes},
281  {"LRN", &OnnxExporter::CreateLrnNodes},
282  {"Reshape", &OnnxExporter::CreateReshapeNodes},
283  {"Slice", &OnnxExporter::CreateSliceNodes},
284  {"ChannelShuffle", &OnnxExporter::CreateChannelShuffleNodes},
285  {"ReduceMean", &OnnxExporter::CreateReduceMeanNodes},
286  {"ReduceFrontMean", &OnnxExporter::CreateReduceMeanNodes},
287  {"ReduceBackMean", &OnnxExporter::CreateReduceMeanNodes},
288  {"ResizeNearest", &OnnxExporter::CreateUpsampleNodes}};
289  return kSpecialOperators;
290 }
291 
292 void OnnxExporter::CopyCaffe2ArgToOnnxAttr(
293  AttributeProto* attr,
294  const std::string& op_type,
295  const caffe2::Argument& arg) {
296  std::string name =
297  caffe2::get_default(get_renamed_attrs(), arg.name(), arg.name());
298  const auto& per_op_renamed_attr_lut = get_per_op_renamed_attrs();
299  const auto it = per_op_renamed_attr_lut.find(op_type);
300  if (it != per_op_renamed_attr_lut.end()) {
301  // Per-op attribute renames override the global attribute renames
302  name = caffe2::get_default(it->second, arg.name(), name);
303  }
304  attr->set_name(name);
305 
306  if (arg.has_f()) {
307  attr->set_f(arg.f());
308  attr->set_type(AttributeProto::FLOAT);
309  } else if (arg.has_i()) {
310  attr->set_i(arg.i());
311  attr->set_type(AttributeProto::INT);
312  } else if (arg.has_s()) {
313  attr->set_s(arg.s());
314  attr->set_type(AttributeProto::STRING);
315  } else if (arg.floats_size()) {
316  attr->mutable_floats()->CopyFrom(arg.floats());
317  attr->set_type(AttributeProto::STRINGS);
318  } else if (arg.ints_size()) {
319  attr->mutable_ints()->CopyFrom(arg.ints());
320  attr->set_type(AttributeProto::INTS);
321  } else if (arg.strings_size()) {
322  attr->mutable_strings()->CopyFrom(arg.strings());
323  attr->set_type(AttributeProto::STRINGS);
324  } else {
325  CAFFE_THROW(c10::str("Unsupported Caffe2 argument: ", arg.name()));
326  }
327 }
328 
329 bool OnnxExporter::IsBlackListed(const caffe2::Argument& arg) {
330  const static std::unordered_map<std::string, std::unordered_set<std::string>>
331  kBlackListString = {{"order", {"NCHW"}}};
332  const static std::unordered_map<std::string, std::unordered_set<int64_t>>
333  kBlackListInt = {{"cudnn_exhaustive_search", {0, 1}},
334  {"use_cudnn", {0, 1}},
335  {"exhaustive_search", {0, 1}},
336  {"is_test", {0, 1}},
337  {"broadcast", {0, 1}}};
338 
339  if (arg.has_i()) {
340  const auto it = kBlackListInt.find(arg.name());
341  if (it != kBlackListInt.end()) {
342  return it->second.count(arg.i());
343  }
344  } else if (arg.has_s()) {
345  const auto it = kBlackListString.find(arg.name());
346  if (it != kBlackListString.end()) {
347  return it->second.count(arg.s());
348  }
349  }
350 
351  return false;
352 }
353 
354 ConvertedResult OnnxExporter::Caffe2OpToOnnxNodes(
355  const caffe2::OperatorDef& def,
356  const std::unordered_map<std::string, caffe2::TensorShape>& shapes) {
357  std::string type = def.type();
358  const auto& renamed_op_lut = get_renamed_operators();
359  const auto it = renamed_op_lut.find(type);
360  if (it != renamed_op_lut.end()) {
361  type = it->second;
362  }
363  const auto& special_op_lut = get_special_operators();
364  const auto it_op = get_special_operators().find(type);
365  if (it_op != special_op_lut.end()) {
366  return (this->*(it_op->second))(def, shapes);
367  } else {
368  return CommonCaffe2OpToOnnxNodes(def);
369  }
370 }
371 
372 ConvertedResult OnnxExporter::CommonCaffe2OpToOnnxNodes(
373  const caffe2::OperatorDef& def) {
374  ConvertedResult result;
375  auto& nodes = result.first;
376  nodes.emplace_back();
377  NodeProto& node = nodes.back();
378  node.set_name(def.name());
379  node.set_op_type(
380  caffe2::get_default(get_renamed_operators(), def.type(), def.type()));
381  for (const auto& i : def.input()) {
382  node.add_input(i);
383  }
384  for (const auto& o : def.output()) {
385  node.add_output(o);
386  }
387  for (const auto& a : def.arg()) {
388  if (!IsBlackListed(a)) {
389  auto* attr = node.add_attribute();
390  CopyCaffe2ArgToOnnxAttr(attr, def.type(), a);
391  }
392  }
393  return result;
394 }
395 
396 ConvertedResult OnnxExporter::CreateArgMaxMinOpNodes(
397  const caffe2::OperatorDef& def,
398  const std::unordered_map<std::string, caffe2::TensorShape>& shapes) {
399  auto result = CommonCaffe2OpToOnnxNodes(def);
400  auto& nodes = result.first;
401 
402  CAFFE_ENFORCE_EQ(nodes.size(), 1);
403  auto& node = nodes.back();
404 
405  if (!ArgumentHelper::HasArgument(def, "axis")) {
406  const auto& x = def.input(0);
407  const auto& x_shape = shapes.at(x);
408  node.add_attribute()->CopyFrom(
409  MakeAttribute("axis", x_shape.dims().size() - 1));
410  }
411 
412  return result;
413 }
414 
// Converts a caffe2 binary elementwise op (Add/Sub/Mul/...) to ONNX.
// caffe2's explicit "broadcast"/"axis" args are dropped; when the axis
// does not already align y with x's trailing dims, an Unsqueeze node is
// prepended to pad y with trailing singleton axes so ONNX's implicit
// (trailing-aligned) broadcasting matches caffe2's axis-aligned semantics.
ConvertedResult OnnxExporter::CreateBinaryElementwiseOpNodes(
    const caffe2::OperatorDef& def,
    const std::unordered_map<std::string, caffe2::TensorShape>& shapes) {
  caffe2::OperatorDef mdef(def); // The modified def without broadcast and axis
  const auto& x = mdef.input(0);
  const auto& y = def.input(1); // Refer to the old def, later won't change it.
  const auto& x_shape = shapes.at(x);
  const auto& y_shape = shapes.at(y);
  // Drop the "broadcast" flag: ONNX broadcasting is implicit.
  for (int i = 0; i < mdef.arg_size(); ++i) {
    const auto& arg = mdef.arg(i);
    if (arg.name() == "broadcast") {
      ArgumentHelper::RemoveArgument(mdef, i);
      break;
    }
  }
  std::vector<int64_t> axes;
  for (int i = 0; i < mdef.arg_size(); ++i) {
    const auto& arg = mdef.arg(i);
    if (arg.name() == "axis") {
      int64_t axis = arg.i();
      // If y already lines up with x's trailing dims, no Unsqueeze needed.
      if (x_shape.dims().size() - axis != y_shape.dims().size()) {
        // The upper bound (excluded) of expanded y.
        int64_t end_dim =
            y_shape.dims().size() - 1 - axis + x_shape.dims().size();
        // Trailing singleton axes to append: [y_ndim, end_dim).
        axes.resize(end_dim - y_shape.dims().size());
        std::iota(axes.begin(), axes.end(), y_shape.dims().size());
        // Route the op's second input through a fresh dummy blob that the
        // Unsqueeze node (added below) will produce.
        mdef.set_input(1, dummy_->NewDummyName());
      }
      ArgumentHelper::RemoveArgument(mdef, i);
      break;
    }
  }

  auto result = CommonCaffe2OpToOnnxNodes(mdef);
  if (axes.size() != 0) {
    // Prepend the Unsqueeze so it executes before the elementwise op.
    result.first.insert(
        result.first.begin(),
        MakeNode(
            "Unsqueeze", {y}, {mdef.input(1)}, {MakeAttribute("axes", axes)}));
  }
  return result;
}
457 
// Converts caffe2 Cast to ONNX Cast. caffe2 accepts the target dtype
// either as a string ("float", "int32", ...) or as a TensorProto::DataType
// enum int; ONNX requires an INT attribute holding the ONNX dtype enum, so
// the copied attribute is rewritten in place. Unsupported targets throw.
ConvertedResult OnnxExporter::CreateCastNodes(
    const caffe2::OperatorDef& def,
    const std::unordered_map<std::string, caffe2::TensorShape>& shapes) {
  auto result = CommonCaffe2OpToOnnxNodes(def);
  // The generic conversion copied the "to" argument verbatim; mutate it.
  auto* attr = result.first[0].mutable_attribute(0);
  auto onnx_dtype = ::ONNX_NAMESPACE::TensorProto::UNDEFINED;
  const auto& arg = def.arg(0);
  if (arg.has_s()) {
    // String form: normalize to upper case before matching.
    auto c2_dtype = arg.s();
    std::transform(
        c2_dtype.begin(), c2_dtype.end(), c2_dtype.begin(), ::toupper);
    if (c2_dtype == "FLOAT") {
      onnx_dtype = ::ONNX_NAMESPACE::TensorProto::FLOAT;
    } else if (c2_dtype == "INT32") {
      onnx_dtype = ::ONNX_NAMESPACE::TensorProto::INT32;
    } else if (c2_dtype == "BOOL") {
      onnx_dtype = ::ONNX_NAMESPACE::TensorProto::BOOL;
    } else if (c2_dtype == "UINT8") {
      onnx_dtype = ::ONNX_NAMESPACE::TensorProto::UINT8;
    } else if (c2_dtype == "INT8") {
      onnx_dtype = ::ONNX_NAMESPACE::TensorProto::INT8;
    } else if (c2_dtype == "UINT16") {
      onnx_dtype = ::ONNX_NAMESPACE::TensorProto::UINT16;
    } else if (c2_dtype == "INT16") {
      onnx_dtype = ::ONNX_NAMESPACE::TensorProto::INT16;
    } else if (c2_dtype == "INT64") {
      onnx_dtype = ::ONNX_NAMESPACE::TensorProto::INT64;
    } else if (c2_dtype == "FLOAT16") {
      onnx_dtype = ::ONNX_NAMESPACE::TensorProto::FLOAT16;
    } else if (c2_dtype == "DOUBLE") {
      onnx_dtype = ::ONNX_NAMESPACE::TensorProto::DOUBLE;
    } else {
      onnx_dtype = ::ONNX_NAMESPACE::TensorProto::UNDEFINED;
    }
    CAFFE_ENFORCE_NE(
        onnx_dtype,
        ::ONNX_NAMESPACE::TensorProto::UNDEFINED,
        "Casting to '",
        c2_dtype,
        "' dtype is not supported");
    // Replace the string payload with the INT payload set after this block.
    attr->clear_s();
    attr->set_type(AttributeProto::INT);
  } else if (arg.has_i()) {
    // Enum form: the int is a caffe2 TensorProto::DataType value.
    const auto& c2_dtype = arg.i();
    switch (c2_dtype) {
      case caffe2::TensorProto::FLOAT:
        onnx_dtype = ::ONNX_NAMESPACE::TensorProto::FLOAT;
        break;
      case caffe2::TensorProto::INT32:
        onnx_dtype = ::ONNX_NAMESPACE::TensorProto::INT32;
        break;
      case caffe2::TensorProto::BOOL:
        onnx_dtype = ::ONNX_NAMESPACE::TensorProto::BOOL;
        break;
      case caffe2::TensorProto::UINT8:
        onnx_dtype = ::ONNX_NAMESPACE::TensorProto::UINT8;
        break;
      case caffe2::TensorProto::INT8:
        onnx_dtype = ::ONNX_NAMESPACE::TensorProto::INT8;
        break;
      case caffe2::TensorProto::UINT16:
        onnx_dtype = ::ONNX_NAMESPACE::TensorProto::UINT16;
        break;
      case caffe2::TensorProto::INT16:
        onnx_dtype = ::ONNX_NAMESPACE::TensorProto::INT16;
        break;
      case caffe2::TensorProto::INT64:
        onnx_dtype = ::ONNX_NAMESPACE::TensorProto::INT64;
        break;
      case caffe2::TensorProto::FLOAT16:
        onnx_dtype = ::ONNX_NAMESPACE::TensorProto::FLOAT16;
        break;
      case caffe2::TensorProto::DOUBLE:
        onnx_dtype = ::ONNX_NAMESPACE::TensorProto::DOUBLE;
        break;

      // No ONNX counterpart; rejected by the enforce below.
      case caffe2::TensorProto::STRING:
      case caffe2::TensorProto::BYTE:
      case caffe2::TensorProto::UNDEFINED:
        onnx_dtype = ::ONNX_NAMESPACE::TensorProto::UNDEFINED;
        break;
    }
    CAFFE_ENFORCE_NE(
        onnx_dtype,
        ::ONNX_NAMESPACE::TensorProto::UNDEFINED,
        "Casting to '",
        c2_dtype,
        "' dtype is not supported");
  }
  attr->set_i(onnx_dtype);
  return result;
}
550 
// Converts caffe2 ElementwiseLinear (y = x * w + b along `axis`) into
// ONNX Mul + Add. When `axis` is not the second-to-last boundary of x
// (i.e. w does not span exactly the trailing dim), x is first flattened
// to 2-D {-1, inner}, the fused multiply-add is applied, and the result
// is reshaped back to x's original shape.
ConvertedResult OnnxExporter::CreateElementwiseLinearNodes(
    const caffe2::OperatorDef& def,
    const std::unordered_map<std::string, caffe2::TensorShape>& shapes) {
  CAFFE_ENFORCE_EQ(def.input_size(), 3);
  CAFFE_ENFORCE_GE(def.output_size(), 1);
  const auto& x = def.input(0);
  const auto& w = def.input(1);
  const auto& b = def.input(2);
  const auto& y = def.output(0);
  // w and b must be 1-D vectors.
  CAFFE_ENFORCE_EQ(shapes.at(w).dims().size(), 1);
  CAFFE_ENFORCE_EQ(shapes.at(b).dims().size(), 1);

  ConvertedResult result;
  auto& nodes = result.first;
  auto& const_tensors = result.second;
  std::unordered_map<std::string, const caffe2::Argument*> args;
  for (const auto& a : def.arg()) {
    args.emplace(a.name(), &a);
  }

  const auto& x_shape = shapes.at(x);
  const auto it = args.find("axis");
  // caffe2 default axis is 1.
  const int64_t axis = it == args.end() ? 1 : it->second->i();
  // No reshape needed when x is already {outer..., inner} with inner == the
  // single trailing dim at `axis`.
  const bool need_reshape = axis + 1 != x_shape.dims().size();

  auto fma_x_input = x;
  if (need_reshape) {
    // Flatten dims [axis, ndim) into one; w and b must match that size.
    const auto inner = DimProd(x_shape, axis, x_shape.dims().size());
    CAFFE_ENFORCE_EQ(shapes.at(w).dims(0), inner);
    CAFFE_ENFORCE_EQ(shapes.at(b).dims(0), inner);

    fma_x_input = dummy_->NewDummyName();
    const_tensors.emplace_back(CreateOnnxShapeTensor(
        dummy_, std::vector<int64_t>{-1, shapes.at(w).dims(0)}));
    nodes.emplace_back(
        MakeNode("Reshape", {x, const_tensors.back().name()}, {fma_x_input}));
  }

  const auto& mul_output = dummy_->NewDummyName();
  nodes.emplace_back(
      MakeNode("Mul", {fma_x_input, w}, {mul_output}, def.name()));

  // Write directly to y unless we still need to restore x's shape.
  const auto& fma_y_output = need_reshape ? dummy_->NewDummyName() : y;
  nodes.emplace_back(
      MakeNode("Add", {mul_output, b}, {fma_y_output}, def.name()));

  if (need_reshape) {
    // Reshape back using x's runtime shape.
    const auto shape = dummy_->NewDummyName();
    nodes.emplace_back(MakeNode("Shape", {x}, {shape}));
    nodes.emplace_back(MakeNode("Reshape", {fma_y_output, shape}, {y}));
  }

  return result;
}
605 
// Converts Conv/ConvTranspose/MaxPool/AveragePool. Rewrites per-dimension
// scalar attributes into ONNX vector attributes, turns global pooling into
// Global{Max,Average}Pool, and translates caffe2's legacy_pad modes into
// ONNX auto_pad or explicit pads.
ConvertedResult OnnxExporter::CreateConvPoolNodes(
    const caffe2::OperatorDef& def,
    const std::unordered_map<std::string, caffe2::TensorShape>& shapes) {
  auto result = CommonCaffe2OpToOnnxNodes(def);
  auto& nodes = result.first;
  auto& node = nodes.back();

  // Work on a name->attribute map; the node's attribute list is rebuilt
  // from this map at the end.
  std::unordered_map<std::string, AttributeProto> attrs;
  for (const auto& attr : node.attribute()) {
    attrs.emplace(attr.name(), attr);
  }

  // Handle global pooling
  bool global = false;
  if (node.op_type() == "MaxPool" || node.op_type() == "AveragePool") {
    auto it = attrs.find("global_pooling");
    if (it != attrs.end() && it->second.has_i() && it->second.i()) {
      node.set_op_type("Global" + node.op_type());
      global = true;
      attrs.erase(it);
    }
  }

  // Collapse per-dimension scalars (kernel_h/kernel_w, pad_t/l/b/r, ...)
  // into the vector attributes ONNX expects.
  ApplyTrans(&attrs, global, "kernel", 2, "kernel_shape");
  ApplyTrans(&attrs, global, "stride");
  ApplyTrans(&attrs, global, "dilation");
  ApplyTrans(&attrs, global, "adj");
  ApplyTrans(&attrs, global, "pad", 4);

  // Fix legacy pad attr
  auto it = attrs.find("legacy_pad");
  if (it != attrs.end()) {
    auto legacy_pad_attr = it->second;
    attrs.erase(it);
    // legacy_pad is only handled for *Pool ops (op type must end in "Pool").
    CAFFE_ENFORCE(
        node.op_type().size() >= 4 &&
        (node.op_type().rfind("Pool") == node.op_type().size() - 4));
    const auto& input_size = shapes.at(node.input(0));
    const auto& output_size = shapes.at(node.output(0));
    CAFFE_ENFORCE_EQ(output_size.dims().size(), 4);
    if (!global && // global pool does not care about legacy pad
        legacy_pad_attr.i() != static_cast<int64_t>(caffe2::LegacyPadding::NOTSET)) {
      if (legacy_pad_attr.i() ==
          static_cast<int64_t>(caffe2::LegacyPadding::VALID)) {
        CAFFE_ENFORCE(!attrs.count("pads"));
        attrs.emplace("auto_pad", MakeAttribute("auto_pad", "VALID"));
      } else if (legacy_pad_attr.i() ==
          static_cast<int64_t>(caffe2::LegacyPadding::SAME)) {
        CAFFE_ENFORCE(!attrs.count("pads"));
        // default behavior in Caffe2 is SAME_UPPER
        // https://github.com/caffe2/caffe2/blob/master/caffe2/operators/conv_pool_op_base.h#L39
        attrs.emplace("auto_pad", MakeAttribute("auto_pad", "SAME_UPPER"));
      } else if (legacy_pad_attr.i() ==
          static_cast<int64_t>(caffe2::LegacyPadding::CAFFE_LEGACY_POOLING)) {
        // The problem here is that, Pool op in Caffe may add an additional pixel,
        // if the last part is smaller than stride. So we use the explicit padding
        // to replace legacy_pad. pad[end] = output_size[start + 2] *
        // stride[start] - pad[start] - 1 + kernel[start] - input[start + 2] end =
        // start + len(pad) / 2
        LOG(WARNING) << "Converting legacy padding to explicit padding.";
        auto* pads_attr = attrs.at("pads").mutable_ints();
        auto& strides_attr = attrs.at("strides").ints();
        auto& kernel_shape_attr = attrs.at("kernel_shape").ints();
        // Only the two end-pads (entries 2 and 3 of "pads") are rewritten;
        // the begin-pads keep their original values.
        for (int i = 0; i < 2; ++i) {
          int64_t tmp_pad = output_size.dims(i + 2) * strides_attr.Get(i) -
              pads_attr->Get(i) - 1 + kernel_shape_attr.Get(i) -
              input_size.dims(i + 2);
          pads_attr->Set(i + 2, tmp_pad);
        }
      } else {
        LOG(ERROR) << "Don't know how to handle the legacy_pad:" << legacy_pad_attr.i();
        CAFFE_THROW("Failed to handle legacy padding in pool operator!");
      }
    }
  }

  // Rebuild the node's attribute list from the (possibly modified) map.
  node.clear_attribute();
  for (const auto& kv : attrs) {
    auto* attr = node.add_attribute();
    attr->CopyFrom(kv.second);
  }

  return result;
}
690 
691 ConvertedResult OnnxExporter::CreateLrnNodes(
692  const caffe2::OperatorDef& def,
693  const std::unordered_map<std::string, caffe2::TensorShape>& shapes) {
694  auto result = CommonCaffe2OpToOnnxNodes(def);
695  auto& nodes = result.first;
696 
697  CAFFE_ENFORCE_EQ(nodes.size(), 1);
698  auto& node = nodes.back();
699  if (node.output_size() == 2) {
700  node.mutable_output()->RemoveLast();
701  }
702 
703  return result;
704 }
705 
// Converts caffe2 Concat to ONNX Concat. Handles two caffe2 extensions:
// `add_axis` (stack along a brand new axis -> Concat + Reshape) and the
// optional second "split_info" output, which is statically computed from
// the input shapes and emitted as a Constant node.
ConvertedResult OnnxExporter::CreateConcatNodes(
    const caffe2::OperatorDef& def,
    const std::unordered_map<std::string, caffe2::TensorShape>& shapes) {
  caffe2::OperatorDef mdef(def); // The modified def without add_axis
  // In caffe2, we can optionally add an axis specified by `add_axis`
  int add_axis = 0;
  for (int i = 0; i < mdef.arg_size(); ++i) {
    const auto& arg = mdef.arg(i);
    if (arg.name() == "add_axis") {
      add_axis = arg.i();
      ArgumentHelper::RemoveArgument(mdef, i);
      break;
    }
  }

  auto result = CommonCaffe2OpToOnnxNodes(mdef);
  auto& nodes = result.first;
  nodes.reserve(nodes.size() + 3);
  auto& const_tensors = result.second;

  CAFFE_ENFORCE_EQ(nodes.size(), 1);
  auto& node = nodes.back();
  bool explicit_axis = false;
  int axis = -1;
  if (ArgumentHelper::HasArgument(mdef, "axis")) {
    axis = ArgumentHelper::GetSingleArgument(mdef, "axis", -1);
    explicit_axis = true;
  }
  if (!explicit_axis) {
    // caffe2's default concat axis is 1; ONNX Concat requires an axis.
    node.add_attribute()->CopyFrom(MakeAttribute("axis", 1));
  }

  // If we have add_axis, we need to add a reshape node
  auto final_output = node.output(0);
  if (add_axis > 0) {
    CAFFE_ENFORCE_GE(axis, 0);
    std::vector<int64_t> dims;
    const auto& shape0 = shapes.at(mdef.input(0));
    // All inputs must agree on the stacked dimension.
    for (int i = 1; i < mdef.input_size(); ++i) {
      const auto& shape = shapes.at(mdef.input(i));
      CAFFE_ENFORCE_EQ(shape.dims(axis), shape0.dims(axis));
    }
    for (const auto d : shape0.dims()) {
      dims.push_back(d);
    }
    // Target shape: input shape with a new dim of size input_size at axis.
    dims.insert(dims.begin() + axis, mdef.input_size());

    // Concat writes to a temp blob; Reshape produces the real output.
    auto concat_output = dummy_->NewDummyName();
    *node.mutable_output(0) = concat_output;
    const_tensors.emplace_back(CreateOnnxShapeTensor(dummy_, dims));
    nodes.emplace_back(MakeNode(
        "Reshape",
        {concat_output, const_tensors.back().name()},
        {final_output}));
  }

  // If we have two output, we need to output the split_info, which can be
  // statically inferred from the input shapes
  if (node.output_size() == 2) {
    std::string second_output = node.output(1);
    node.mutable_output()->RemoveLast();
    std::vector<int32_t> split_info;
    // With add_axis, the output gains one dim relative to the inputs.
    int adj_size = shapes.at(mdef.input(0)).dims_size() + (add_axis ? 1 : 0);
    int canonical_axis = canonical_axis_index_(axis, adj_size);
    CAFFE_ENFORCE_LT(canonical_axis, adj_size, "Axis not in input ndim range.");
    for (int i = 0; i < mdef.input_size(); ++i) {
      // Stacked inputs each contribute 1; concatenated inputs contribute
      // their extent along the concat axis.
      split_info.push_back(
          add_axis ? 1 : shapes.at(mdef.input(i)).dims(canonical_axis));
    }
    auto split_info_tensor =
        MakeTensor("split_info", split_info, TensorProto::INT32);
    auto cnode = MakeNode("Constant", {}, {second_output});
    cnode.add_attribute()->CopyFrom(MakeAttribute("value", split_info_tensor));
    nodes.emplace_back(std::move(cnode));
  }
  return result;
}
783 
784 ConvertedResult OnnxExporter::CreateMergeDimNodes(
785  const caffe2::OperatorDef& def,
786  const std::unordered_map<std::string, caffe2::TensorShape>& shapes) {
787  const auto& x = def.input(0);
788  const auto& y = def.output(0);
789 
790  ConvertedResult result;
791  auto& nodes = result.first;
792  auto& const_tensors = result.second;
793 
794  {
795  const auto ndim = shapes.at(x).dims().size();
796  CAFFE_ENFORCE_GE(ndim, 2, "No enough dims to merge.");
797  std::vector<int64_t> dims(ndim);
798  dims[0] = 1;
799  dims[1] = -1;
800  const_tensors.emplace_back(CreateOnnxShapeTensor(dummy_, dims));
801  }
802 
803  const auto reshaped = dummy_->NewDummyName();
804  nodes.emplace_back(MakeNode("Reshape",
805  { x, const_tensors.back().name() },
806  { reshaped }));
807 
808  nodes.emplace_back(MakeNode("Squeeze",
809  { reshaped },
810  { y },
811  std::vector<AttributeProto>{
812  MakeAttribute("axes", std::vector<int64_t>{ 0 }),
813  }));
814 
815  return result;
816 }
817 
818 ConvertedResult OnnxExporter::CreateChannelShuffleNodes(
819  const caffe2::OperatorDef& def,
820  const std::unordered_map<std::string, caffe2::TensorShape>& shapes) {
821  const auto& x = def.input(0);
822  const auto& y = def.output(0);
823  const auto& x_shape = shapes.at(x);
824  CAFFE_ENFORCE_EQ(
825  x_shape.dims().size(),
826  4,
827  "Input shape of ChannelShuffle needs to be in NCHW format");
828  auto n = x_shape.dims(0);
829  auto c = x_shape.dims(1);
830  auto h = x_shape.dims(2);
831  auto w = x_shape.dims(3);
832  int64_t g = 0;
833  for (const auto& arg : def.arg()) {
834  if (arg.name() == "group") {
835  g = arg.i();
836  break;
837  }
838  }
839  CAFFE_ENFORCE(g && c % g == 0);
840  ConvertedResult result;
841  auto& nodes = result.first;
842  auto& const_tensors = result.second;
843 
844  const auto reshape_output = dummy_->NewDummyName();
845  std::vector<int64_t> dims = {n, g, c / g, h, w};
846  const_tensors.emplace_back(CreateOnnxShapeTensor(dummy_, dims));
847  nodes.emplace_back(
848  MakeNode("Reshape", {x, const_tensors.back().name()}, {reshape_output}));
849 
850  const auto transpose_output = dummy_->NewDummyName();
851  dims = {0, 2, 1, 3, 4};
852  nodes.emplace_back(MakeNode(
853  "Transpose",
854  {reshape_output},
855  {transpose_output},
856  {MakeAttribute("perm", dims)}));
857 
858  dims = {n, c, h, w};
859  const_tensors.emplace_back(CreateOnnxShapeTensor(dummy_, dims));
860  nodes.emplace_back(MakeNode(
861  "Reshape", {transpose_output, const_tensors.back().name()}, {y}));
862 
863  return result;
864 }
865 
// Converts ReduceMean / ReduceFrontMean / ReduceBackMean to ONNX
// ReduceMean. ReduceMean passes `axes`/`keepdims` through (defaulting to
// all axes / keepdims=1); the Front/Back variants reduce the leading or
// trailing `num_reduce_dim` axes and never keep the reduced dims.
ConvertedResult OnnxExporter::CreateReduceMeanNodes(
    const caffe2::OperatorDef& def,
    const std::unordered_map<std::string, caffe2::TensorShape>& shapes) {
  CAFFE_ENFORCE_GE(def.input_size(), 1);
  CAFFE_ENFORCE_LE(def.input_size(), 2);
  // A second "lengths" input (variable-length reduction) has no ONNX
  // equivalent here.
  CAFFE_ENFORCE_EQ(def.input_size(), 1, "Input \"lengths\" is not supported.");
  CAFFE_ENFORCE_GE(def.output_size(), 1);
  const auto& x = def.input(0);
  const auto& y = def.output(0);
  const auto& dims = shapes.at(x).dims();

  ConvertedResult result;
  auto& nodes = result.first;
  auto& const_tensors = result.second;
  std::unordered_map<std::string, const caffe2::Argument*> args;
  for (const auto& a : def.arg()) {
    args.emplace(a.name(), &a);
  }

  std::vector<int64_t> axes;
  int64_t keepdims = 1;

  if (def.type() == "ReduceMean") {
    // axes: default is to reduce over every axis.
    auto it = args.find("axes");
    if (it == args.end()) {
      axes.resize(dims.size());
      std::iota(axes.begin(), axes.end(), 0);
    } else {
      axes.assign(it->second->ints().begin(), it->second->ints().end());
    }

    // keepdims
    it = args.find("keepdims");
    if (it != args.end()) {
      keepdims = it->second->i();
    }
  } else {
    // num_reduce_dim: how many leading (Front) or trailing (Back) axes to
    // reduce; defaults to 1.
    auto it = args.find("num_reduce_dim");
    const int64_t num_reduce_dim = it == args.end() ? 1 : it->second->i();
    CAFFE_ENFORCE_LE(num_reduce_dim, dims.size());
    axes.resize(num_reduce_dim);

    int64_t start_dim = 0;
    if (def.type() == "ReduceFrontMean") {
      start_dim = 0;
    } else if (def.type() == "ReduceBackMean") {
      start_dim = dims.size() - axes.size();
    }
    std::iota(axes.begin(), axes.end(), start_dim);

    // Front/Back variants drop the reduced dimensions.
    keepdims = 0;
  }

  nodes.emplace_back(MakeNode("ReduceMean",
      { x },
      { y },
      {
          MakeAttribute("axes", axes),
          MakeAttribute("keepdims", keepdims),
      },
      def.name()));

  return result;
}
932 
933 ConvertedResult OnnxExporter::CreateUpsampleNodes(
934  const caffe2::OperatorDef& def,
935  const std::unordered_map<std::string, caffe2::TensorShape>& shapes) {
936  ConvertedResult result;
937  //{H, W} => {1, 1, H, W}
938  auto& nodes = result.first;
939  auto resolved_scale = dummy_->NewDummyName();
940  if (def.input_size() == 1) {
941  float width_scale = 1.0;
942  float height_scale = 1.0;
943  for (const auto& a : def.arg()) {
944  if (a.name() == "width_scale") {
945  width_scale = a.f();
946  } else if (a.name() == "height_scale") {
947  height_scale = a.f();
948  }
949  }
950  CAFFE_ENFORCE_GT(width_scale, 0);
951  CAFFE_ENFORCE_GT(height_scale, 0);
952  std::vector<float> tmp_vector = {1, 1, height_scale, width_scale};
953  auto resolved_scale_tensor =
954  MakeTensor("resolved scale tensor", tmp_vector, TensorProto::FLOAT);
955 
956  auto node = MakeNode("Constant", {}, {resolved_scale});
957  node.add_attribute()->CopyFrom(
958  MakeAttribute("value", resolved_scale_tensor));
959  nodes.emplace_back(node);
960 
961  } else {
962  CAFFE_ENFORCE_EQ(def.input_size(), 2);
963  std::vector<float> tmp_vector = {1, 1};
964  auto scale_pads_tensor =
965  MakeTensor("scale pads", tmp_vector, TensorProto::FLOAT);
966  auto unresolved_scale_pads = dummy_->NewDummyName();
967 
968  auto node = MakeNode("Constant", {}, {unresolved_scale_pads});
969  node.add_attribute()->CopyFrom(MakeAttribute("value", scale_pads_tensor));
970  nodes.emplace_back(node);
971 
972  node = MakeNode(
973  "Concat", {unresolved_scale_pads, def.input(1)}, {resolved_scale});
974  node.add_attribute()->CopyFrom(MakeAttribute("axis", 0));
975  nodes.emplace_back(node);
976  }
977  std::vector<std::string> inputs = {def.input(0), resolved_scale};
978  std::vector<std::string> outputs(def.output().begin(), def.output().end());
979  auto node = MakeNode("Upsample", inputs, outputs, def.name());
980  node.add_attribute()->CopyFrom(MakeAttribute("mode", "nearest"));
981  nodes.emplace_back(node);
982  return result;
983 }
984 
985 ConvertedResult OnnxExporter::CreateSliceNodes(
986  const caffe2::OperatorDef& def,
987  const std::unordered_map<std::string, caffe2::TensorShape>& shapes) {
988  CAFFE_ENFORCE_EQ(
989  def.input_size(),
990  1,
991  "ONNX Slice operator does not support dynamic slice.");
992  auto result = CommonCaffe2OpToOnnxNodes(def);
993  auto& nodes = result.first;
994  CAFFE_ENFORCE_EQ(nodes.size(), 1);
995  auto& node = nodes.back();
996  const auto& shape = shapes.at(node.input(0));
997 
998  std::vector<int64_t> dims;
999  for (auto& attr : *node.mutable_attribute()) {
1000  if (attr.name() == "starts") {
1001  auto len = attr.ints_size();
1002  if (len) {
1003  dims.resize(len);
1004  std::iota(dims.begin(), dims.end(), 0);
1005  }
1006  } else if (attr.name() == "ends") {
1007  for (int i = 0; i < attr.ints_size(); ++i) {
1008  auto end = attr.ints(i);
1009  if (end >= 0) {
1010  continue;
1011  }
1012  if (end == -1) {
1013  end = shape.dims(i);
1014  } else {
1015  ++end;
1016  }
1017  attr.set_ints(i, end);
1018  }
1019  }
1020  }
1021  if (!dims.empty()) {
1022  node.add_attribute()->CopyFrom(MakeAttribute("axes", dims));
1023  }
1024 
1025  return result;
1026 }
1027 
1028 ConvertedResult OnnxExporter::CreateReshapeNodes(
1029  const caffe2::OperatorDef& def,
1030  const std::unordered_map<std::string, caffe2::TensorShape>& shapes) {
1031  auto result = CommonCaffe2OpToOnnxNodes(def);
1032  auto& nodes = result.first;
1033  auto& const_tensors = result.second;
1034  CAFFE_ENFORCE_EQ(nodes.size(), 1);
1035  auto& node = nodes.back();
1036 
1037  int i = 0;
1038  int attr_size = node.attribute_size();
1039  for (; i < attr_size; ++i) {
1040  const auto& attr = node.attribute(i);
1041  if (attr.name() == "shape") {
1042  std::vector<int64_t> shape;
1043  for (const auto k : attr.ints()) {
1044  shape.push_back(k);
1045  }
1046  const_tensors.emplace_back(CreateOnnxShapeTensor(dummy_, shape));
1047  node.add_input(const_tensors.back().name());
1048  break;
1049  }
1050  }
1051  if (i != attr_size) {
1052  if (i != attr_size - 1) {
1053  node.mutable_attribute()->SwapElements(i, attr_size - 1);
1054  }
1055  node.mutable_attribute()->RemoveLast();
1056  }
1057 
1058  if (node.output_size() == 2) {
1059  std::string shape_input = node.output(0);
1060  std::string shape_output = node.output(1);
1061  node.mutable_output()->RemoveLast();
1062  nodes.emplace_back(AddShapeNode(shape_input, shape_output));
1063  }
1064 
1065  return result;
1066 }
1067 
1068 ConvertedResult OnnxExporter::CreateGemmNodes(
1069  const caffe2::OperatorDef& def,
1070  const std::unordered_map<std::string, caffe2::TensorShape>& shapes) {
1071  CAFFE_ENFORCE_EQ(def.input_size(), 3);
1072  CAFFE_ENFORCE_GE(def.output_size(), 1);
1073  auto x = def.input(0);
1074  auto w = def.input(1);
1075  const auto& b = def.input(2);
1076  const auto& y = def.output(0);
1077  const auto& x_shape = shapes.at(x);
1078  const auto& w_shape = shapes.at(w);
1079  CAFFE_ENFORCE_GE(x_shape.dims().size(), 2);
1080  CAFFE_ENFORCE_GE(w_shape.dims().size(), 2);
1081 
1082  ConvertedResult result;
1083  auto& nodes = result.first;
1084  auto& const_tensors = result.second;
1085  std::unordered_map<std::string, const caffe2::Argument*> args;
1086  for (const auto& a : def.arg()) {
1087  args.emplace(a.name(), &a);
1088  }
1089 
1090  auto it = args.find("axis");
1091  int64_t axis = 1;
1092  bool has_axis = (it != args.end());
1093  if (has_axis) {
1094  axis = it->second->i();
1095  }
1096 
1097  auto gemm_x_input = x;
1098  if (x_shape.dims().size() > 2) {
1099  // we need to reshape only when dimension is higher than 2
1100  const auto inner = DimProd(x_shape, axis, x_shape.dims().size());
1101 
1102  gemm_x_input = dummy_->NewDummyName();
1103  const_tensors.emplace_back(CreateOnnxShapeTensor(dummy_,
1104  std::vector<int64_t>{ -1, inner }));
1105  nodes.emplace_back(MakeNode("Reshape",
1106  { x, const_tensors.back().name() },
1107  { gemm_x_input }));
1108  }
1109 
1110  it = args.find("axis_w");
1111  int64_t axis_w = 1;
1112  if (it != args.end()) {
1113  axis_w = it->second->i();
1114  }
1115  if (w_shape.dims().size() > 2) {
1116  // we need to reshape only when dimension is higher than 2
1117  auto outer = DimProd(w_shape, 0, axis_w);
1118  auto inner = DimProd(w_shape, axis_w, w_shape.dims().size());
1119  auto reshaped_w = dummy_->NewDummyName();
1120  const_tensors.emplace_back(CreateOnnxShapeTensor(dummy_,
1121  std::vector<int64_t>{ outer, inner }));
1122  nodes.emplace_back(MakeNode("Reshape",
1123  { w, const_tensors.back().name() },
1124  { reshaped_w }));
1125  w = reshaped_w;
1126  }
1127 
1128  auto gemm_y_output = axis > 1 ? dummy_->NewDummyName() : y;
1129  nodes.emplace_back(MakeNode("Gemm",
1130  { gemm_x_input, w, b },
1131  { gemm_y_output },
1132  { MakeAttribute("transB", 1L) },
1133  def.name()));
1134 
1135  // capture the outer shape if needed.
1136  if (axis > 1) {
1137  const auto x_shape = dummy_->NewDummyName();
1138  nodes.emplace_back(MakeNode("Shape", {x}, {x_shape}));
1139 
1140  const auto x_shape_outer = dummy_->NewDummyName();
1141  nodes.emplace_back(MakeNode("Slice",
1142  { x_shape },
1143  { x_shape_outer },
1144  std::vector<AttributeProto>{
1145  MakeAttribute("starts", std::vector<int64_t>{ 0 }),
1146  MakeAttribute("ends", std::vector<int64_t>{ axis }),
1147  }));
1148 
1149  const auto y_shape = dummy_->NewDummyName();
1150  const_tensors.emplace_back(CreateOnnxShapeTensor(dummy_, { -1 }));
1151  nodes.emplace_back(MakeNode("Concat",
1152  { x_shape_outer, const_tensors.back().name() },
1153  { y_shape },
1154  std::vector<AttributeProto>{
1155  MakeAttribute("axis", static_cast<int64_t>(0)),
1156  }));
1157 
1158  nodes.emplace_back(MakeNode("Reshape",
1159  { gemm_y_output, y_shape },
1160  { y }));
1161  }
1162 
1163  return result;
1164 }
1165 
1166 void OnnxExporter::InitOpToTensorProto(
1167  const caffe2::OperatorDef& op,
1168  TensorProto* tensor) {
1169  CAFFE_ENFORCE_EQ(op.input_size(), 0);
1170  CAFFE_ENFORCE_EQ(op.output_size(), 1);
1171 
1172  // Set name
1173  tensor->set_name(op.output(0));
1174 
1175  const Argument* values = nullptr;
1176  const Argument* shape = nullptr;
1177  for (const auto& arg: op.arg()) {
1178  if (arg.name() == "values") {
1179  values = &arg;
1180  } else if (arg.name() == "shape") {
1181  shape = &arg;
1182  }
1183  }
1184 
1185  CAFFE_ENFORCE(values);
1186  CAFFE_ENFORCE(shape);
1187 
1188  // Set dims
1189  for (const auto i: shape->ints()) {
1190  tensor->add_dims(i);
1191  }
1192 
1193  // Set value
1194  if (op.type() == "GivenTensorFill") {
1195  tensor->set_data_type(TensorProto::FLOAT);
1196  for (const auto i : values->floats()) {
1197  tensor->add_float_data(i);
1198  }
1199  } else if (op.type() == "GivenTensorInt64Fill") {
1200  tensor->set_data_type(TensorProto::INT64);
1201  for (const auto i : values->ints()) {
1202  tensor->add_int64_data(i);
1203  }
1204  } else if (op.type() == "GivenTensorIntFill") {
1205  tensor->set_data_type(TensorProto::INT32);
1206  for (const auto i : values->ints()) {
1207  tensor->add_int32_data(i);
1208  }
1209  } else if (op.type() == "GivenTensorBoolFill") {
1210  tensor->set_data_type(TensorProto::INT32);
1211  for (const auto i : values->ints()) {
1212  tensor->add_int32_data(i);
1213  }
1214  } else if (op.type() == "GivenTensorStringFill") {
1215  tensor->set_data_type(TensorProto::STRING);
1216  // TODO: we might need to do two pass to avoid adverse memory allocations
1217  for (const auto& s : values->strings()) {
1218  tensor->mutable_raw_data()->append(s);
1219  }
1220  } else {
1221  CAFFE_THROW(
1222  c10::str("Cannot convert C2 op ", op.type(), "to ONNX TensorProto"));
1223  }
1224 }
1225 
1226 } // namespace onnx
1227 } // namespace caffe2
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13