1 #include "caffe2/core/logging.h" 2 #include "caffe2/core/operator.h" 3 #include "caffe2/onnx/backend.h" 4 #include "caffe2/onnx/device.h" 5 #include "caffe2/onnx/helper.h" 6 #include "caffe2/utils/map_utils.h" 7 #include "caffe2/utils/proto_utils.h" 10 #include "onnx/checker.h" 11 #include "onnx/optimizer/optimize.h" 14 #include "google/protobuf/io/coded_stream.h" 15 #include "google/protobuf/io/zero_copy_stream_impl_lite.h" 21 #include <unordered_map> 22 #include <unordered_set> 29 bool AlmostEqual(
double a,
double b) {
30 constexpr
static double kEps = 1e-15;
31 return (fabs(a - b) < kEps);
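
// Copies the contents of an ONNX tensor's raw_data blob into a typed protobuf
// RepeatedField. Returns false when the tensor carries no raw data, so callers
// can fall back to the typed *_data() fields instead.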
template <class T>
bool TryConvertingTensorRawValues(
    const TensorProto& onnx_tensor,
    ::google::protobuf::RepeatedField<T>* field) {
  if (!onnx_tensor.has_raw_data()) {
    return false;
  }

  size_t raw_size = onnx_tensor.raw_data().size();
  CAFFE_ENFORCE_EQ(raw_size % sizeof(T), 0);

  size_t num_elements = raw_size / sizeof(T);
  const void* src_ptr = static_cast<const void*>(onnx_tensor.raw_data().data());
  field->Resize(num_elements, 0);
  void* target_ptr = static_cast<void*>(field->mutable_data());
  memcpy(target_ptr, src_ptr, raw_size);

  return true;
}
bool IsOperator(const std::string& op_type) {
  // Pull in the registered Caffe2 operators once; intentionally leaked.
  static std::set<std::string>* ops_ =
      new std::set<std::string>(caffe2::GetRegisteredOperators());
  return ops_->count(caffe2::OpRegistryKey(op_type, "DEFAULT"));
}
caffe2::DeviceOption GetDeviceOption(const Device& onnx_device) {
  static const std::unordered_map<DeviceType, caffe2::DeviceType> m = {
      {DeviceType::CPU, caffe2::DeviceType::CPU},
      {DeviceType::CUDA, caffe2::DeviceType::CUDA}};
  caffe2::DeviceOption d;
  d.set_device_type(static_cast<int32_t>(m.at(onnx_device.type)));
  d.set_device_id(onnx_device.device_id);
  return d;
}
ModelProto OptimizeOnnx(const ModelProto& input, bool init) {
  std::vector<std::string> passes{"fuse_consecutive_transposes",
                                  "eliminate_nop_transpose",
                                  "fuse_transpose_into_gemm"};

  if (init) {
    passes.emplace_back("split_init");
  } else {
    passes.emplace_back("split_predict");
  }
  return ::ONNX_NAMESPACE::optimization::Optimize(input, passes);
}
// Returns map.at(key) if present, otherwise default_value.
template <class T, class U>
U LookUpWithDefault(
    const std::unordered_map<T, U>& map,
    const T& key,
    const U& default_value) {
  const auto it = map.find(key);
  if (it == map.end()) {
    return default_value;
  }
  return it->second;
}
void UpdateNames(std::shared_ptr<DummyName> dummy, const caffe2::OperatorDef& op) {
  for (const auto& n : op.input()) {
    dummy->AddName(n);
  }
  for (const auto& n : op.output()) {
    dummy->AddName(n);
  }
}
void BuildOperator(
    caffe2::OperatorDef* c2_op,
    const std::string& op_type,
    const std::vector<std::string>& inputs,
    const std::vector<std::string>& outputs,
    const std::vector<caffe2::Argument>& args) {
  c2_op->set_type(op_type);
  for (const auto& input : inputs) {
    c2_op->add_input(input);
  }
  for (const auto& output : outputs) {
    c2_op->add_output(output);
  }
  for (const auto& arg : args) {
    auto* tmp = c2_op->add_arg();
    tmp->CopyFrom(arg);
  }
}
void BuildOperator(
    caffe2::OperatorDef* c2_op,
    const std::string& op_type,
    const std::vector<std::string>& inputs,
    const std::vector<std::string>& outputs) {
  std::vector<caffe2::Argument> empty;
  BuildOperator(c2_op, op_type, inputs, outputs, empty);
}
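
// Translates a single ONNX AttributeProto into a Caffe2 Argument, matching on
// whichever payload field the attribute actually carries (float, int, string,
// tensor, or the repeated variants). Unknown payloads raise an error.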
void CopyOnnxAttrValueToCaffe2Arg(
    caffe2::Argument* arg,
    const AttributeProto& attr) {
  if (attr.has_f()) {
    arg->set_f(attr.f());
  } else if (attr.has_i()) {
    arg->set_i(attr.i());
  } else if (attr.has_s()) {
    arg->set_s(attr.s());
  } else if (attr.has_t()) {
    // Serialize the tensor attribute and carry it as a string argument.
    std::string buffer;
    attr.t().SerializeToString(&buffer);
    arg->set_s(buffer);
  } else if (attr.floats_size()) {
    arg->mutable_floats()->CopyFrom(attr.floats());
  } else if (attr.ints_size()) {
    arg->mutable_ints()->CopyFrom(attr.ints());
  } else if (attr.strings_size()) {
    arg->mutable_strings()->CopyFrom(attr.strings());
  } else {
    CAFFE_THROW("Unsupported ONNX attribute: ", attr.name());
  }
}
OnnxAttributes::OnnxAttributes(const NodeProto& node) {
  for (const auto& attr : node.attribute()) {
    onnx_attrs_.emplace(attr.name(), &attr);
  }
}
template <>
int64_t OnnxAttributes::get(const std::string& key) const {
  int64_t value = 0;
  const auto it = onnx_attrs_.find(key);
  if (it != onnx_attrs_.end()) {
    const AttributeProto& attr = *it->second;
    value = attr.i();
  }
  return value;
}

template <>
float OnnxAttributes::get(const std::string& key) const {
  float value = 0.0;
  const auto it = onnx_attrs_.find(key);
  if (it != onnx_attrs_.end()) {
    const AttributeProto& attr = *it->second;
    value = attr.f();
  }
  return value;
}

template <>
::google::protobuf::RepeatedPtrField<std::string> OnnxAttributes::get(
    const std::string& key) const {
  ::google::protobuf::RepeatedPtrField<std::string> value;
  const auto it = onnx_attrs_.find(key);
  if (it != onnx_attrs_.end()) {
    const AttributeProto& attr = *it->second;
    value.CopyFrom(attr.strings());
  }
  return value;
}

template <>
::google::protobuf::RepeatedField<::google::protobuf::int64>
OnnxAttributes::get(const std::string& key) const {
  ::google::protobuf::RepeatedField<::google::protobuf::int64> value;
  const auto it = onnx_attrs_.find(key);
  if (it != onnx_attrs_.end()) {
    const AttributeProto& attr = *it->second;
    value.CopyFrom(attr.ints());
  }
  return value;
}

template <>
::google::protobuf::RepeatedField<float>
OnnxAttributes::get(const std::string& key) const {
  ::google::protobuf::RepeatedField<float> value;
  const auto it = onnx_attrs_.find(key);
  if (it != onnx_attrs_.end()) {
    const AttributeProto& attr = *it->second;
    value.CopyFrom(attr.floats());
  }
  return value;
}

template <>
const TensorProto* OnnxAttributes::get(const std::string& key) const {
  const TensorProto* value = nullptr;
  const auto it = onnx_attrs_.find(key);
  if (it != onnx_attrs_.end()) {
    const AttributeProto& attr = *it->second;
    value = &attr.t();
  }
  return value;
}
::google::protobuf::RepeatedPtrField<caffe2::Argument>
OnnxAttributes::OnnxAttrToCaffe2Arg(
    std::function<std::string(const std::string&)> mapper) const {
  ::google::protobuf::RepeatedPtrField<caffe2::Argument> args;
  for (const auto& kv : onnx_attrs_) {
    // If the attribute has been rewritten, use the rewritten version instead.
    const auto& attr = rewritten_onnx_attrs_.count(kv.first)
        ? rewritten_onnx_attrs_.at(kv.first)
        : (*kv.second);
    auto* arg = args.Add();
    arg->set_name(mapper(attr.name()));
    CopyOnnxAttrValueToCaffe2Arg(arg, attr);
  }
  for (const auto& kv : rewritten_onnx_attrs_) {
    // Also emit rewritten attributes that do not shadow an original one.
    if (!onnx_attrs_.count(kv.first)) {
      const auto& attr = kv.second;
      auto* arg = args.Add();
      arg->set_name(mapper(attr.name()));
      CopyOnnxAttrValueToCaffe2Arg(arg, attr);
    }
  }

  return args;
}
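
// The tables below drive the generic conversion path: operators known to be
// broken at a given opset, RNN operators that must be pre-converted by the
// caller, and the operator/attribute renamings between ONNX and Caffe2.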
const std::unordered_map<std::string, int>&
Caffe2Backend::get_broken_operators() const {
  const static std::unordered_map<std::string, int> kBrokenOperators{};
  return kBrokenOperators;
}

const std::unordered_set<std::string>& Caffe2Backend::get_rnn_operators()
    const {
  const static std::unordered_set<std::string> kRNNOperators{
      "LSTM", "GRU", "RNN"};
  return kRNNOperators;
}
const std::unordered_map<std::string, std::string>&
Caffe2Backend::get_renamed_operators() const {
  const static std::unordered_map<std::string, std::string> kRenamedOperators{
      {"Caffe2ConvTranspose", "ConvTranspose"},
      {"GlobalMaxPool", "MaxPool"},
      {"GlobalAveragePool", "AveragePool"},
      {"BatchNormalization", "SpatialBN"},
      {"InstanceNormalization", "InstanceNorm"},
      {"MatMul", "BatchMatMul"},
      {"Upsample", "ResizeNearest"},
      {"Identity", "Copy"},
      {"Unsqueeze", "ExpandDims"},
      {"Tile", "NumpyTile"},
      {"DynamicSlice", "Slice"},
      {"ConstantOfShape", "ConstantFill"},
      {"RandomNormal", "GaussianFill"}};
  return kRenamedOperators;
}
const std::unordered_map<std::string, std::string>&
Caffe2Backend::get_renamed_attrs() const {
  const static std::unordered_map<std::string, std::string> kRenamedAttrs{
      {"kernel_shape", "kernels"}};
  return kRenamedAttrs;
}
const std::unordered_map<std::string, std::unordered_map<std::string, std::string>>&
Caffe2Backend::get_per_op_renamed_attrs() const {
  const static std::unordered_map<std::string, std::unordered_map<std::string, std::string>>
      kPerOpRenamedAttrs = {{"Squeeze", {{"axes", "dims"}}},
                            {"Unsqueeze", {{"axes", "dims"}}},
                            {"Transpose", {{"perm", "axes"}}},
                            {"ConvTranspose", {{"output_padding", "adjs"}}},
                            {"Selu", {{"gamma", "scale"}}}};

  return kPerOpRenamedAttrs;
}
const std::unordered_map<std::string, Caffe2Backend::SpecialOpConverter>&
Caffe2Backend::get_special_operators() const {
  const static std::unordered_map<std::string, Caffe2Backend::SpecialOpConverter>
      kSpecialOperators = {
          {"ArgMax", &Caffe2Backend::CreateArgMaxMin},
          {"ArgMin", &Caffe2Backend::CreateArgMaxMin},
          {"Cast", &Caffe2Backend::CreateCast},
          {"Constant", &Caffe2Backend::CreateConstant},
          {"ConstantOfShape", &Caffe2Backend::CreateConstantOfShape},
          {"Conv", &Caffe2Backend::CreateConvPoolOpBase},
          {"AveragePool", &Caffe2Backend::CreateConvPoolOpBase},
          {"GlobalAveragePool", &Caffe2Backend::CreateConvPoolOpBase},
          {"GlobalMaxPool", &Caffe2Backend::CreateConvPoolOpBase},
          {"MaxPool", &Caffe2Backend::CreateConvPoolOpBase},
          {"Reshape", &Caffe2Backend::CreateReshape},
          {"Gather", &Caffe2Backend::CreateGather},
          {"Gemm", &Caffe2Backend::CreateGemm},
          {"Pad", &Caffe2Backend::CreatePad},
          {"Concat", &Caffe2Backend::CreateConcat},
          {"LogSoftmax", &Caffe2Backend::CreateLogSoftmax},
          {"Slice", &Caffe2Backend::CreateSlice},
          {"Split", &Caffe2Backend::CreateSplit},
          {"Reciprocal", &Caffe2Backend::CreateReciprocal},
          {"BatchNormalization", &Caffe2Backend::CreateBatchNormalization},
          {"MatMul", &Caffe2Backend::CreateMatMul},
          {"Upsample", &Caffe2Backend::CreateUpsample},
          {"Dropout", &Caffe2Backend::CreateDropout},
          {"LRN", &Caffe2Backend::CreateLRN},
          {"DynamicSlice", &Caffe2Backend::CreateDynamicSlice},
          {"RandomNormal", &Caffe2Backend::CreateRandomNormal}};
  return kSpecialOperators;
}
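
// Each Create* converter below handles one ONNX op that cannot be translated
// by a simple rename: it rewrites attributes, emits extra Caffe2 ops, or both,
// and usually finishes by delegating to CommonOnnxNodeToCaffe2Ops.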
Caffe2Ops Caffe2Backend::CreateArgMaxMin(
    OnnxNode* onnx_node,
    const ConversionContext& ctx) {
  auto& attributes = onnx_node->attributes;
  if (!attributes.HasAttribute("axis")) {
    auto* attr = attributes.AddRewrittenAttribute("axis");
    attr->set_i(0);
  }
  return CommonOnnxNodeToCaffe2Ops(onnx_node, ctx);
}
Caffe2Ops Caffe2Backend::CreateCast(
    OnnxNode* onnx_node,
    const ConversionContext& ctx) {
  auto c2_op = CommonOnnxNodeToCaffe2Ops(onnx_node, ctx);

  auto onnx_dtype =
      onnx_node->attributes.get<int64_t>("to", TensorProto::UNDEFINED);
  auto c2_dtype = caffe2::TensorProto::UNDEFINED;
  switch (onnx_dtype) {
    case ::ONNX_NAMESPACE::TensorProto::FLOAT:
      c2_dtype = caffe2::TensorProto::FLOAT;
      break;
    case ::ONNX_NAMESPACE::TensorProto::UINT8:
      c2_dtype = caffe2::TensorProto::UINT8;
      break;
    case ::ONNX_NAMESPACE::TensorProto::INT8:
      c2_dtype = caffe2::TensorProto::INT8;
      break;
    case ::ONNX_NAMESPACE::TensorProto::UINT16:
      c2_dtype = caffe2::TensorProto::UINT16;
      break;
    case ::ONNX_NAMESPACE::TensorProto::INT16:
      c2_dtype = caffe2::TensorProto::INT16;
      break;
    case ::ONNX_NAMESPACE::TensorProto::INT32:
      c2_dtype = caffe2::TensorProto::INT32;
      break;
    case ::ONNX_NAMESPACE::TensorProto::INT64:
      c2_dtype = caffe2::TensorProto::INT64;
      break;
    case ::ONNX_NAMESPACE::TensorProto::STRING:
      c2_dtype = caffe2::TensorProto::STRING;
      break;
    case ::ONNX_NAMESPACE::TensorProto::BOOL:
      c2_dtype = caffe2::TensorProto::BOOL;
      break;
    case ::ONNX_NAMESPACE::TensorProto::FLOAT16:
      c2_dtype = caffe2::TensorProto::FLOAT16;
      break;
    case ::ONNX_NAMESPACE::TensorProto::DOUBLE:
      c2_dtype = caffe2::TensorProto::DOUBLE;
      break;
    case ::ONNX_NAMESPACE::TensorProto::UINT32:
    case ::ONNX_NAMESPACE::TensorProto::UINT64:
    case ::ONNX_NAMESPACE::TensorProto::COMPLEX64:
    case ::ONNX_NAMESPACE::TensorProto::COMPLEX128:
    case ::ONNX_NAMESPACE::TensorProto::UNDEFINED:
      c2_dtype = caffe2::TensorProto::UNDEFINED;
      break;
  }

  CAFFE_ENFORCE_NE(
      c2_dtype,
      caffe2::TensorProto::UNDEFINED,
      "Casting to '",
      onnx_dtype,
      "' dtype is not supported");

  CAFFE_ENFORCE_EQ(
      c2_op.ops.Get(0).arg().size(),
      1,
      "Unexpected number of attributes in 'Cast'");
  c2_op.ops.Mutable(0)->mutable_arg(0)->set_i(c2_dtype);

  return c2_op;
}
Caffe2Ops Caffe2Backend::CreateConstant(
    OnnxNode* onnx_node,
    const ConversionContext& ctx) {
  CAFFE_ENFORCE_EQ(onnx_node->node.output_size(), 1);

  Caffe2Ops ret;
  auto* c2_op = ret.ops.Add();
  const auto* value = onnx_node->attributes.get<const TensorProto*>("value");
  BuildTensorFillingOp(c2_op, *value, onnx_node->node.output(0));

  return ret;
}
Caffe2Ops Caffe2Backend::CreateConstantOfShape(
    OnnxNode* onnx_node,
    const ConversionContext& ctx) {
  CAFFE_ENFORCE_EQ(onnx_node->node.input_size(), 1);
  CAFFE_ENFORCE_EQ(onnx_node->node.output_size(), 1);

  Caffe2Ops ret;
  auto* c2_op = ret.ops.Add();
  const auto* value = onnx_node->attributes.get<const TensorProto*>("value");
  if (value) {
    BuildTensorFillingOp(
        c2_op, *value, onnx_node->node.output(0), onnx_node->node.input(0));
  } else {
    c2_op->set_type("ConstantFill");
    c2_op->add_input(onnx_node->node.input(0));
    c2_op->add_output(onnx_node->node.output(0));
    auto c2_input_as_shape = c2_op->add_arg();
    c2_input_as_shape->set_name("input_as_shape");
    c2_input_as_shape->set_i(1);
  }

  return ret;
}
Caffe2Ops Caffe2Backend::CreateConvPoolOpBase(
    OnnxNode* onnx_node,
    const ConversionContext& ctx) {
  const auto& node = onnx_node->node;
  auto& attributes = onnx_node->attributes;
  if (node.op_type().find("Global") == 0) {
    auto* attr = attributes.AddRewrittenAttribute("global_pooling");
    attr->set_i(1);
  }

  if (attributes.HasAttribute("kernel_shape") &&
      attributes.HasAttribute("pads")) {
    auto kernel_shape =
        attributes
            .get<::google::protobuf::RepeatedField<::google::protobuf::int64>>(
                "kernel_shape");
    auto pads =
        attributes
            .get<::google::protobuf::RepeatedField<::google::protobuf::int64>>(
                "pads");
    if (kernel_shape.size() == pads.size()) {
      // Caffe2 requires pads to be twice the size of kernels.
      auto* attr = attributes.AddRewrittenAttribute("pads");
      attr->mutable_ints()->CopyFrom(pads);
      attr->mutable_ints()->MergeFrom(pads);
    }
  }

  return CommonOnnxNodeToCaffe2Ops(onnx_node, ctx);
}
Caffe2Ops Caffe2Backend::CreateReshape(
    OnnxNode* onnx_node,
    const ConversionContext& ctx) {
  auto c2_op = CommonOnnxNodeToCaffe2Ops(onnx_node, ctx);
  CAFFE_ENFORCE_EQ(c2_op.ops.size(), 1);
  auto* op = c2_op.ops.Mutable(0);
  // Caffe2 Reshape also emits the old shape as a second output.
  op->add_output(dummy_->NewDummyName());

  return c2_op;
}
Caffe2Ops Caffe2Backend::CreateRandomNormal(
    OnnxNode* onnx_node,
    const ConversionContext& ctx) {
  auto& attributes = onnx_node->attributes;

  if (attributes.HasAttribute("seed")) {
    CAFFE_THROW("Caffe2 GaussianFill does not support random seed");
  }

  if (attributes.HasAttribute("dtype")) {
    if (attributes.get<int64_t>("dtype") != TensorProto::FLOAT) {
      CAFFE_THROW("Caffe2 GaussianFill only support FLOAT dtype");
    }
    attributes.remove("dtype");
  }
  if (attributes.HasAttribute("scale")) {
    auto scale = attributes.get<float>("scale");
    auto* attr = attributes.AddRewrittenAttribute("std");
    attr->set_f(scale);
    attributes.remove("scale");
  }
  return CommonOnnxNodeToCaffe2Ops(onnx_node, ctx);
}
Caffe2Ops Caffe2Backend::CreateReciprocal(
    OnnxNode* onnx_node,
    const ConversionContext& ctx) {
  const auto& node = onnx_node->node;
  if (node.input_size() != 1 || node.output_size() != 1) {
    CAFFE_THROW("Caffe2 Reciprocal should have 1 input and 1 output");
  }

  Caffe2Ops ret;
  auto* c2_op = ret.ops.Add();

  caffe2::Argument exponent;
  exponent.set_name("exponent");
  exponent.set_f(-1.0);
  BuildOperator(c2_op, "Pow", {node.input(0)}, {node.output(0)}, {exponent});
  return ret;
}
Caffe2Ops Caffe2Backend::CreateGather(
    OnnxNode* onnx_node,
    const ConversionContext& ctx) {
  const auto& node = onnx_node->node;
  if (node.input_size() < 2 || node.output_size() < 1) {
    CAFFE_THROW("Caffe2 Gather should have 2 inputs and 1 output");
  }

  Caffe2Ops ret;
  auto* c2_op = ret.ops.Add();

  std::vector<std::string> inputs;
  inputs.emplace_back(node.input(0));
  inputs.emplace_back(node.input(1));
  std::vector<std::string> outputs;
  outputs.emplace_back(node.output(0));

  auto axis = onnx_node->attributes.get<int64_t>("axis", 0L);
  if (axis == 0) {
    BuildOperator(c2_op, "Gather", inputs, outputs);
  } else if (axis == 1) {
    BuildOperator(c2_op, "BatchGather", inputs, outputs);
  } else {
    CAFFE_THROW(
        "Caffe2 only supports Gather with axis being 0 or 1, ",
        "axis is ", axis);
  }

  return ret;
}
Caffe2Ops Caffe2Backend::CreateGemm(
    OnnxNode* onnx_node,
    const ConversionContext& ctx) {
  const auto& node = onnx_node->node;
  if (node.input_size() < 3 || node.output_size() < 1) {
    CAFFE_THROW("Caffe2 Gemm should have 3 inputs and 1 output");
  }

  Caffe2Ops ret;
  auto input_a = node.input(0);
  auto input_b = node.input(1);
  auto input_c = node.input(2);
  auto output = node.output(0);

  auto alpha = onnx_node->attributes.get<float>("alpha", 1.0);
  auto beta = onnx_node->attributes.get<float>("beta", 1.0);
  if (!AlmostEqual(alpha, 1)) {
    auto scaled_a = dummy_->NewDummyName();
    caffe2::Argument scale;
    scale.set_name("scale");
    scale.set_f(alpha);

    auto* c2_op = ret.ops.Add();
    BuildOperator(c2_op, "Scale", {input_a}, {scaled_a}, {scale});
    input_a = scaled_a;
  }
  if (!AlmostEqual(beta, 1)) {
    auto scaled_c = dummy_->NewDummyName();
    caffe2::Argument scale;
    scale.set_name("scale");
    scale.set_f(beta);

    auto* c2_op = ret.ops.Add();
    BuildOperator(c2_op, "Scale", {input_c}, {scaled_c}, {scale});
    input_c = scaled_c;
  }

  auto trans_a = onnx_node->attributes.get<int64_t>("transA", 0L);
  auto trans_b = onnx_node->attributes.get<int64_t>("transB", 0L);
  auto broadcast =
      onnx_node->attributes.get<int64_t>("broadcast",
          (ctx.opset_version() > 6) ? 1L : 0L);

  // Check whether we can use the FC operator: C must be a 1-D bias.
  auto check_fc = [&]() -> bool {
    const auto input_c_vi_iter = ctx.value_infos().find(node.input(2));

    if (input_c_vi_iter == ctx.value_infos().end()) {
      return false;
    }

    const auto input_c_shape =
        input_c_vi_iter->second.type().tensor_type().shape();

    if (input_c_shape.dim_size() != 1) {
      return false;
    }

    if (input_c_shape.dim(0).dim_value() == 1) {
      const auto input_b_vi_iter = ctx.value_infos().find(node.input(1));

      if (input_b_vi_iter == ctx.value_infos().end()) {
        return false;
      }
      const auto input_b_shape =
          input_b_vi_iter->second.type().tensor_type().shape();
      int input_b_last_dim_index = (trans_b) ? 0 : 1;

      if (input_b_shape.dim_size() <= input_b_last_dim_index ||
          input_b_shape.dim(input_b_last_dim_index).dim_value() != 1) {
        return false;
      }
    }

    return true;
  };

  if (!trans_a && broadcast && check_fc()) {
    auto* c2_op = ret.ops.Add();
    if (trans_b) {
      BuildOperator(c2_op, "FC", {input_a, input_b, input_c}, {output});
    } else {
      BuildOperator(c2_op, "FCTransposed", {input_a, input_b, input_c}, {output});
    }
  } else {
    auto ab = dummy_->NewDummyName();
    caffe2::Argument arg_trans_a;
    arg_trans_a.set_name("trans_a");
    arg_trans_a.set_i(trans_a);
    caffe2::Argument arg_trans_b;
    arg_trans_b.set_name("trans_b");
    arg_trans_b.set_i(trans_b);

    auto* c2_op = ret.ops.Add();
    BuildOperator(
        c2_op, "MatMul", {input_a, input_b}, {ab}, {arg_trans_a, arg_trans_b});
    c2_op = ret.ops.Add();
    if (ctx.opset_version() >= 7) {
      BuildOperator(c2_op, "Add", {ab, input_c}, {output});
    } else {
      caffe2::Argument arg_broadcast;
      arg_broadcast.set_name("broadcast");
      arg_broadcast.set_i(broadcast);
      BuildOperator(c2_op, "Add", {ab, input_c}, {output}, {arg_broadcast});
    }
  }

  return ret;
}
Caffe2Ops Caffe2Backend::CreatePad(
    OnnxNode* onnx_node,
    const ConversionContext& ctx) {
  auto& attributes = onnx_node->attributes;
  ::google::protobuf::RepeatedField<::google::protobuf::int64> pads;
  std::string pad_name = ctx.opset_version() < 2 ? "paddings" : "pads";
  pads = attributes
             .get<::google::protobuf::RepeatedField<::google::protobuf::int64>>(
                 pad_name);
  // Render the pads for error messages.
  std::stringstream ss;
  for (const auto& i : pads) {
    ss << i << " ";
  }
  std::string str = ss.str();

  for (const auto i : pads) {
    if (i < 0) {
      CAFFE_THROW("ONNX does not support negative pads in Pad, but get ", str);
    }
  }

  // Only padding on the last two (H, W) dimensions of a 4-D tensor is allowed.
  if (!(pads.size() == 8 &&
        (pads.Get(0) + pads.Get(1) + pads.Get(4) + pads.Get(5) == 0))) {
    CAFFE_THROW(
        "Caffe2 only supports padding 2D Tensor, whereas padding is ", str);
  }

  // Rewrite the pads in Caffe2 order: pad_t, pad_l, pad_b, pad_r.
  auto* attr = attributes.AddRewrittenAttribute(pad_name);
  attr->add_ints(pads.Get(2));
  attr->add_ints(pads.Get(3));
  attr->add_ints(pads.Get(6));
  attr->add_ints(pads.Get(7));

  return CommonOnnxNodeToCaffe2Ops(onnx_node, ctx);
}
Caffe2Ops Caffe2Backend::CreateConcat(
    OnnxNode* onnx_node,
    const ConversionContext& ctx) {
  auto c2_op = CommonOnnxNodeToCaffe2Ops(onnx_node, ctx);
  CAFFE_ENFORCE_EQ(c2_op.ops.size(), 1);
  auto* op = c2_op.ops.Mutable(0);
  // Caffe2 Concat has a second output (split info); give it a dummy name.
  op->add_output(dummy_->NewDummyName());

  return c2_op;
}
Caffe2Ops Caffe2Backend::CreateLogSoftmax(
    OnnxNode* onnx_node,
    const ConversionContext& ctx) {
  const auto& node = onnx_node->node;
  if (node.input_size() < 1 || node.output_size() < 1) {
    CAFFE_THROW("LogSoftmax should have 1 input and 1 output");
  }
  auto axis = onnx_node->attributes.get<int64_t>("axis", 1L);
  caffe2::Argument arg_axis;
  arg_axis.set_name("axis");
  arg_axis.set_i(axis);
  auto softmax_a = dummy_->NewDummyName();

  Caffe2Ops ret;
  auto* c2_op = ret.ops.Add();
  BuildOperator(c2_op, "Softmax", {node.input(0)}, {softmax_a}, {arg_axis});
  c2_op = ret.ops.Add();
  BuildOperator(c2_op, "Log", {softmax_a}, {node.output(0)});

  return ret;
}
Caffe2Ops Caffe2Backend::CreateSlice(
    OnnxNode* onnx_node,
    const ConversionContext& ctx) {
  auto op_tmp = CommonOnnxNodeToCaffe2Ops(onnx_node, ctx);
  CAFFE_ENFORCE_EQ(op_tmp.ops.size(), 1);
  auto* op = op_tmp.ops.Mutable(0);
  std::unordered_map<std::string, caffe2::Argument*> args;
  for (auto& arg : *op->mutable_arg()) {
    args.emplace(arg.name(), &arg);
  }

  caffe2::Argument starts_vals;
  starts_vals.set_name("values");
  auto pos = args.find("starts");
  if (pos != args.end()) {
    for (auto i : pos->second->ints()) {
      starts_vals.add_ints(i < 0 ? i - 1 : i);
    }
    args.erase(pos);
  }

  caffe2::Argument ends_vals;
  ends_vals.set_name("values");
  pos = args.find("ends");
  if (pos != args.end()) {
    for (auto i : pos->second->ints()) {
      if (i == std::numeric_limits<int64_t>::max()) {
        ends_vals.add_ints(-1);
      } else {
        ends_vals.add_ints(i < 0 ? i - 1 : i);
      }
    }
    args.erase(pos);
  }

  caffe2::Argument axes_vals;
  axes_vals.set_name("values");
  pos = args.find("axes");
  if (pos != args.end()) {
    for (auto i : pos->second->ints()) {
      axes_vals.add_ints(i);
    }
    args.erase(pos);
  } else {
    auto ndim = starts_vals.ints_size();
    for (int64_t i = 0; i < ndim; ++i) {
      axes_vals.add_ints(i);
    }
  }

  CAFFE_ENFORCE_GE(op->input_size(), 1);
  auto data = op->input(0);
  auto shape_tensor = dummy_->NewDummyName();
  Caffe2Ops ret;

  auto* c2_op = ret.ops.Add();
  BuildOperator(c2_op, "Shape", {data}, {shape_tensor});

  auto axes_tensor = dummy_->NewDummyName();
  c2_op = ret.ops.Add();
  caffe2::Argument shape;
  shape.set_name("shape");
  shape.add_ints(axes_vals.ints_size());
  BuildOperator(
      c2_op, "GivenTensorIntFill", {}, {axes_tensor}, {shape, axes_vals});

  auto starts_vals_tensor = dummy_->NewDummyName();
  auto starts_tensor = dummy_->NewDummyName();
  c2_op = ret.ops.Add();
  caffe2::Argument shape_starts;
  shape_starts.set_name("shape");
  shape_starts.add_ints(starts_vals.ints_size());
  BuildOperator(
      c2_op,
      "GivenTensorInt64Fill",
      {},
      {starts_vals_tensor},
      {shape_starts, starts_vals});

  caffe2::Argument dtype;
  dtype.set_name("dtype");
  dtype.set_i(static_cast<int64_t>(caffe2::TensorProto::INT64));
  caffe2::Argument constant;
  constant.set_name("value");
  constant.set_i(0);
  c2_op = ret.ops.Add();
  BuildOperator(
      c2_op,
      "ConstantFill",
      {shape_tensor},
      {starts_tensor},
      {dtype, constant});
  c2_op = ret.ops.Add();
  BuildOperator(
      c2_op,
      "ScatterAssign",
      {starts_tensor, axes_tensor, starts_vals_tensor},
      {starts_tensor});
  // The Slice below takes int32 starts/ends, so cast the int64 tensors down.
  caffe2::Argument to;
  to.set_name("to");
  to.set_i(static_cast<int64_t>(caffe2::TensorProto::INT32));
  c2_op = ret.ops.Add();
  BuildOperator(c2_op, "Cast", {starts_tensor}, {starts_tensor}, {to});

  auto ends_vals_tensor = dummy_->NewDummyName();
  auto ends_tensor = dummy_->NewDummyName();
  c2_op = ret.ops.Add();
  caffe2::Argument shape_ends;
  shape_ends.set_name("shape");
  shape_ends.add_ints(ends_vals.ints_size());
  BuildOperator(
      c2_op,
      "GivenTensorInt64Fill",
      {},
      {ends_vals_tensor},
      {shape_ends, ends_vals});

  c2_op = ret.ops.Add();
  BuildOperator(
      c2_op, "ConstantFill", {shape_tensor}, {ends_tensor}, {dtype, constant});
  c2_op = ret.ops.Add();
  BuildOperator(
      c2_op,
      "ScatterAssign",
      {ends_tensor, axes_tensor, ends_vals_tensor},
      {ends_tensor});
  c2_op = ret.ops.Add();
  BuildOperator(c2_op, "Cast", {ends_tensor}, {ends_tensor}, {to});

  // Re-issue the Slice with the materialized starts/ends tensors as inputs.
  c2_op = ret.ops.Add();
  c2_op->CopyFrom(*op);
  c2_op->mutable_input()->Clear();
  c2_op->add_input(data);
  c2_op->add_input(starts_tensor);
  c2_op->add_input(ends_tensor);
  c2_op->mutable_arg()->Clear();
  for (const auto& kv : args) {
    c2_op->add_arg()->CopyFrom(*kv.second);
  }

  return ret;
}
std::string Caffe2Backend::PreprocessSliceIndexTensor(
    OnnxNode* onnx_node,
    Caffe2Ops& ret,
    std::string indices_tensor,
    std::string axes_tensor,
    std::string rank_tensor,
    std::string zero_tensor,
    std::string one_tensor,
    int default_value) {
  auto indices_tensor_full = dummy_->NewDummyName();

  {
    // Build a full-rank index tensor filled with the default value.
    caffe2::Argument value;
    value.set_name("value");
    value.set_i(default_value);
    caffe2::Argument dtype;
    dtype.set_name("dtype");
    dtype.set_i(static_cast<int64_t>(caffe2::TensorProto::INT64));
    caffe2::Argument input_as_shape;
    input_as_shape.set_name("input_as_shape");
    input_as_shape.set_i(1);
    auto c2_op = ret.ops.Add();
    BuildOperator(c2_op, "ConstantFill", {rank_tensor}, {indices_tensor_full},
                  {value, dtype, input_as_shape});
  }

  // Mark the negative indices so they can be adjusted below.
  auto lt_tensor = dummy_->NewDummyName();
  {
    caffe2::Argument broadcast;
    broadcast.set_name("broadcast");
    broadcast.set_i(1);
    auto c2_op = ret.ops.Add();
    BuildOperator(c2_op, "LT", {indices_tensor, zero_tensor}, {lt_tensor}, {broadcast});
  }

  auto sub_one_tensor = dummy_->NewDummyName();
  {
    caffe2::Argument broadcast;
    broadcast.set_name("broadcast");
    broadcast.set_i(1);
    auto c2_op = ret.ops.Add();
    BuildOperator(c2_op, "Sub", {indices_tensor, one_tensor}, {sub_one_tensor}, {broadcast});
  }

  auto indices_tensor_adjusted = dummy_->NewDummyName();
  auto c2_op = ret.ops.Add();
  BuildOperator(c2_op, "Conditional", {lt_tensor, sub_one_tensor, indices_tensor}, {indices_tensor_adjusted}, {});

  // Scatter the adjusted indices into the full-rank tensor at the given axes.
  c2_op = ret.ops.Add();
  BuildOperator(c2_op, "ScatterAssign",
                {indices_tensor_full, axes_tensor, indices_tensor_adjusted},
                {indices_tensor_full});

  return indices_tensor_full;
}
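
// DynamicSlice receives starts/ends as graph inputs instead of attributes, so
// the index tensors are expanded to full rank with PreprocessSliceIndexTensor
// and the integer constants it needs are built inline as ConstantFill ops.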
Caffe2Ops Caffe2Backend::CreateDynamicSlice(
    OnnxNode* onnx_node,
    const ConversionContext& ctx) {
  auto op_tmp = CommonOnnxNodeToCaffe2Ops(onnx_node, ctx);
  CAFFE_ENFORCE_EQ(op_tmp.ops.size(), 1);
  auto* op = op_tmp.ops.Mutable(0);
  std::unordered_map<std::string, caffe2::Argument*> args;
  for (auto& arg : *op->mutable_arg()) {
    args.emplace(arg.name(), &arg);
  }

  CAFFE_ENFORCE_GE(op->input_size(), 1);
  auto data = op->input(0);
  Caffe2Ops ret;

  // Shape of the data tensor.
  auto* c2_op = ret.ops.Add();
  auto size_tensor = dummy_->NewDummyName();
  BuildOperator(c2_op, "Shape", {data}, {size_tensor});

  // Rank of the data tensor (shape of its shape).
  c2_op = ret.ops.Add();
  auto rank_tensor = dummy_->NewDummyName();
  BuildOperator(c2_op, "Shape", {size_tensor}, {rank_tensor});

  // Axes: either the optional fourth input or the full range [0, rank).
  std::string axes_tensor;
  if (onnx_node->node.input_size() > 3) {
    axes_tensor = onnx_node->node.input(3);
  } else {
    axes_tensor = dummy_->NewDummyName();
    auto* c2_op = ret.ops.Add();
    BuildOperator(c2_op, "Range", {rank_tensor}, {axes_tensor}, {});
  }

  auto define_integer_constant = [this, &ret](int val) {
    caffe2::Argument value;
    value.set_name("value");
    value.set_i(val);
    caffe2::Argument dtype;
    dtype.set_name("dtype");
    dtype.set_i(static_cast<int64_t>(caffe2::TensorProto::INT64));
    caffe2::Argument shape;
    shape.set_name("shape");
    shape.add_ints(1);
    auto c2_op = ret.ops.Add();
    auto name = dummy_->NewDummyName();
    BuildOperator(c2_op, "ConstantFill", {}, {name},
                  {value, dtype, shape});
    return name;
  };

  auto zero_tensor = define_integer_constant(0);
  auto one_tensor = define_integer_constant(1);

  auto starts_tensor_full = PreprocessSliceIndexTensor(onnx_node,
                                                       ret,
                                                       onnx_node->node.input(1),
                                                       axes_tensor,
                                                       rank_tensor,
                                                       zero_tensor,
                                                       one_tensor,
                                                       0);
  auto ends_tensor_full = PreprocessSliceIndexTensor(onnx_node,
                                                     ret,
                                                     onnx_node->node.input(2),
                                                     axes_tensor,
                                                     rank_tensor,
                                                     zero_tensor,
                                                     one_tensor,
                                                     -1);

  c2_op = ret.ops.Add();
  c2_op->CopyFrom(*op);
  c2_op->mutable_input()->Clear();
  c2_op->add_input(data);
  c2_op->add_input(starts_tensor_full);
  c2_op->add_input(ends_tensor_full);
  c2_op->mutable_arg()->Clear();
  for (const auto& kv : args) {
    c2_op->add_arg()->CopyFrom(*kv.second);
  }

  return ret;
}
Caffe2Ops Caffe2Backend::CreateBatchNormalization(
    OnnxNode* onnx_node,
    const ConversionContext& ctx) {
  auto& attributes = onnx_node->attributes;

  if (ctx.opset_version() < 6) {
    attributes.remove("consumed_inputs");
  }

  if (ctx.opset_version() >= 7) {
    auto* attr = attributes.AddRewrittenAttribute("is_test");
    attr->set_i(1);
  }

  if (attributes.HasAttribute("spatial") && attributes.get<int64_t>("spatial") == 1) {
    attributes.remove("spatial");
  }

  return CommonOnnxNodeToCaffe2Ops(onnx_node, ctx);
}
Caffe2Ops Caffe2Backend::CreateSplit(
    OnnxNode* onnx_node,
    const ConversionContext& ctx) {
  auto& attributes = onnx_node->attributes;
  if (!attributes.HasAttribute("axis")) {
    auto* attr = attributes.AddRewrittenAttribute("axis");
    attr->set_i(0);
  }

  return CommonOnnxNodeToCaffe2Ops(onnx_node, ctx);
}
Caffe2Ops Caffe2Backend::CreateMatMul(
    OnnxNode* onnx_node,
    const ConversionContext& ctx) {
  const auto& node = onnx_node->node;
  if (node.input_size() != 2) {
    CAFFE_THROW("MatMul should have 2 inputs");
  }

  auto c2_op = CommonOnnxNodeToCaffe2Ops(onnx_node, ctx);
  CAFFE_ENFORCE_EQ(c2_op.ops.size(), 1);
  auto* op = c2_op.ops.Mutable(0);
  auto* broadcast_arg = op->add_arg();
  broadcast_arg->set_name("broadcast");
  broadcast_arg->set_i(1);

  return c2_op;
}
Caffe2Ops Caffe2Backend::CreateUpsample(
    OnnxNode* onnx_node,
    const ConversionContext& ctx) {
  auto& attributes = onnx_node->attributes;
  attributes.remove("mode");

  if (ctx.opset_version() >= 7 && ctx.opset_version() < 9) {
    const auto& scales = attributes.get<::google::protobuf::RepeatedField<float>>("scales");
    if (scales.size() != 4) {
      CAFFE_THROW("The scales argument should have size 4");
    } else if (!AlmostEqual(scales.Get(0), 1) || !AlmostEqual(scales.Get(1), 1)) {
      CAFFE_THROW("The first two elements in the scales argument must be 1");
    }
    attributes.remove("scales");
    auto c2_op = CommonOnnxNodeToCaffe2Ops(onnx_node, ctx);
    auto* op = c2_op.ops.Mutable(0);
    auto* c2_height = op->add_arg();
    c2_height->set_name("height_scale");
    c2_height->set_f(scales.Get(2));
    auto* c2_width = op->add_arg();
    c2_width->set_name("width_scale");
    c2_width->set_f(scales.Get(3));
    return c2_op;
  } else if (ctx.opset_version() >= 9) {
    const auto& node = onnx_node->node;
    if (node.input_size() != 2) {
      CAFFE_THROW("Expects 2 input in upsample after onnx version 9");
    }

    Caffe2Ops ret;
    // Slice the scales input down to the height/width scales (indices 2 and 3).
    auto* c2_op = ret.ops.Add();
    auto sliced_input = dummy_->NewDummyName();
    caffe2::Argument arg_starts;
    caffe2::Argument arg_ends;
    arg_starts.set_name("starts");
    arg_starts.add_ints(2);
    arg_ends.set_name("ends");
    arg_ends.add_ints(-1);
    BuildOperator(
        c2_op,
        "Slice",
        {node.input(1)},
        {sliced_input},
        {arg_starts, arg_ends});

    c2_op = ret.ops.Add();
    BuildOperator(
        c2_op,
        "ResizeNearest",
        {node.input(0), sliced_input},
        {node.output(0)});
    return ret;
  }
  return CommonOnnxNodeToCaffe2Ops(onnx_node, ctx);
}
Caffe2Ops Caffe2Backend::CreateDropout(
    OnnxNode* onnx_node,
    const ConversionContext& ctx) {
  if (ctx.opset_version() >= 7) {
    auto& attributes = onnx_node->attributes;
    auto* attr = attributes.AddRewrittenAttribute("is_test");
    attr->set_i(1);
  }

  return CommonOnnxNodeToCaffe2Ops(onnx_node, ctx);
}
Caffe2Ops Caffe2Backend::CreateLRN(
    OnnxNode* onnx_node,
    const ConversionContext& ctx) {
  auto c2_op = CommonOnnxNodeToCaffe2Ops(onnx_node, ctx);
  const auto& attributes = onnx_node->attributes;
  // Fill in the ONNX default values when the attributes are absent.
  if (!attributes.HasAttribute("alpha")) {
    auto* arg = c2_op.ops.Mutable(0)->add_arg();
    arg->set_name("alpha");
    arg->set_f(1e-4);
  }
  if (!attributes.HasAttribute("beta")) {
    auto* arg = c2_op.ops.Mutable(0)->add_arg();
    arg->set_name("beta");
    arg->set_f(0.75);
  }
  return c2_op;
}
std::unordered_set<std::string>
Caffe2Backend::AllNamesInGraph(const GraphProto& graph) {
  std::unordered_set<std::string> names;

  for (const auto& input : graph.input()) {
    names.emplace(input.name());
  }
  for (const auto& output : graph.output()) {
    names.emplace(output.name());
  }
  for (const auto& node : graph.node()) {
    for (const auto& n : node.input()) {
      names.emplace(n);
    }
    for (const auto& n : node.output()) {
      names.emplace(n);
    }
  }

  return names;
}
Caffe2Ops Caffe2Backend::CommonOnnxNodeToCaffe2Ops(
    OnnxNode* onnx_node,
    const ConversionContext& ctx) {
  Caffe2Ops ret;
  auto* c2_op = ret.ops.Add();

  const auto& node = onnx_node->node;
  c2_op->mutable_input()->MergeFrom(node.input());
  c2_op->mutable_output()->MergeFrom(node.output());
  c2_op->set_name(node.name());

  const auto onnx_op_type = node.op_type();
  auto broken_version = caffe2::get_default(
      get_broken_operators(), onnx_op_type, std::numeric_limits<int>::max());
  if (broken_version <= ctx.opset_version()) {
    CAFFE_THROW(
        "Don't know how to translate op ",
        onnx_op_type,
        " in ONNX operator set v",
        ctx.opset_version(),
        " (I only support prior to v",
        broken_version, ")");
  }

  c2_op->set_type(
      caffe2::get_default(get_renamed_operators(), onnx_op_type, onnx_op_type));
  if (!IsOperator(c2_op->type())) {
    CAFFE_THROW(
        "Don't know how to translate op ", onnx_op_type);
  }

  // Map attribute names: per-operator renames take precedence over globals.
  auto mapper = [&, this](const std::string& k) {
    const auto it = get_per_op_renamed_attrs().find(onnx_op_type);
    if (it != get_per_op_renamed_attrs().end()) {
      const auto it_op = it->second.find(k);
      if (it_op != it->second.end()) {
        return it_op->second;
      }
    }
    const auto it_global = get_renamed_attrs().find(k);
    if (it_global != get_renamed_attrs().end()) {
      return it_global->second;
    }
    return k;
  };
  c2_op->mutable_arg()->MergeFrom(
      onnx_node->attributes.OnnxAttrToCaffe2Arg(mapper));

  return ret;
}
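
// ConvertNode is the single-node entry point: it parses a serialized NodeProto
// and runs it through OnnxNodeToCaffe2Ops with empty init/pred models. A
// minimal usage sketch (the variable names are hypothetical, not part of this
// file):
//
//   Caffe2Backend backend;
//   ValueInfoMap value_infos;  // shape/type info for the node's inputs
//   Caffe2Ops ops = backend.ConvertNode(serialized_node, {value_infos, 9});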
Caffe2Ops Caffe2Backend::ConvertNode(
    const std::string& node_str,
    const ConversionContext& ctx) {
  ::google::protobuf::RepeatedPtrField<NodeProto> nodes;
  auto* n = nodes.Add();
  ParseProtoFromLargeString(node_str, n);
  ModelProto init_model;
  ModelProto pred_model;
  OnnxNode onnx_node = OnnxNode(nodes.Get(0));
  return OnnxNodeToCaffe2Ops(init_model, pred_model, ctx, &onnx_node);
}
void Caffe2Backend::CheckOpSchemaArguments(
    const caffe2::OpSchema& schema,
    const caffe2::OperatorDef& op) {
  const auto& schema_args = schema.args();
  if (schema_args.size() > 0) {
    std::vector<std::string> argnames;
    std::transform(
        schema_args.begin(),
        schema_args.end(),
        std::back_inserter(argnames),
        [](const caffe2::OpSchema::Argument& x) { return x.name(); });

    for (const auto& arg : op.arg()) {
      if (std::count(argnames.begin(), argnames.end(), arg.name()) == 0) {
        CAFFE_THROW(
            "Don't know how to map unexpected argument ",
            arg.name());
      }
    }
  } else {
    VLOG(2) << "Operator " << op.type()
            << " does not declare arguments in its schema. Please file a Caffe2 issue.";
  }
}
Caffe2Ops Caffe2Backend::OnnxNodeToCaffe2Ops(
    const ModelProto& init_model,
    const ModelProto& pred_model,
    const ConversionContext& ctx,
    OnnxNode* onnx_node) {
  Caffe2Ops res;
  if (get_special_operators().count(onnx_node->node.op_type())) {
    res = (this->*get_special_operators().at(onnx_node->node.op_type()))(
        onnx_node, ctx);
  } else {
    res = CommonOnnxNodeToCaffe2Ops(onnx_node, ctx);
  }

  for (const auto& result_op : res.ops) {
    const auto* schema = OpSchemaRegistry::Schema(result_op.type());
    if (schema) {
      CheckOpSchemaArguments(*schema, result_op);
    } else {
      CAFFE_THROW(
          "Caffe2 has no such operator, could not find schema for ",
          result_op.type());
    }
  }
  return res;
}
void Caffe2Backend::OnnxToCaffe2(
    caffe2::NetDef* init_net,
    caffe2::NetDef* pred_net,
    const ModelProto& onnx_model,
    const std::string& device,
    int32_t opset_version,
    bool include_initializers,
    const std::vector<Caffe2Ops>& extras) {
  auto device_option = GetDeviceOption(Device(device));

#if !CAFFE2_MOBILE
  ModelProto init_model = OptimizeOnnx(onnx_model, true);
  ModelProto pred_model = OptimizeOnnx(onnx_model, false);
#else
  // The ONNX optimizer is not available in mobile builds.
  ModelProto init_model = ModelProto();
  ModelProto pred_model = onnx_model;
  pred_model.mutable_graph()->mutable_initializer()->Clear();
#endif

  init_net->set_name(onnx_model.graph().name() + "_init");
  pred_net->set_name(onnx_model.graph().name() + "_predict");

  // Convert the initializers into tensor-filling ops on the init net.
  if (include_initializers) {
    for (const auto& tp : onnx_model.graph().initializer()) {
      auto* c2_op = init_net->add_op();
      BuildTensorFillingOp(c2_op, tp);
    }
  }

  auto name_set = AllNamesInGraph(init_model.graph());
  auto name_set_pred = AllNamesInGraph(pred_model.graph());
  name_set.insert(name_set_pred.begin(), name_set_pred.end());
  dummy_->Reset(name_set);

  ValueInfoMap graph_value_infos{};
  for (const auto& vi : pred_model.graph().input()) {
    graph_value_infos[vi.name()].CopyFrom(vi);
  }
  for (const auto& vi : pred_model.graph().output()) {
    graph_value_infos[vi.name()].CopyFrom(vi);
  }
  for (const auto& vi : pred_model.graph().value_info()) {
    graph_value_infos[vi.name()].CopyFrom(vi);
  }

  size_t idx_extra = 0;
  auto converter = [&](const ModelProto& model, caffe2::NetDef* net) mutable {
    net->mutable_device_option()->CopyFrom(device_option);
    for (const auto& node : model.graph().node()) {
      auto* init_net_tmp = include_initializers ? init_net : net;
      // For RNN operators we rely on pre-converted Caffe2 nets handed in
      // through `extras` instead of converting the node here.
      if (get_rnn_operators().count(node.op_type())) {
        if (idx_extra < extras.size()) {
          const auto& c2ops = extras[idx_extra++];
          for (const auto& op : c2ops.init_ops) {
            UpdateNames(dummy_, op);
          }
          init_net_tmp->mutable_op()->MergeFrom(c2ops.init_ops);
          for (const auto& op : c2ops.ops) {
            UpdateNames(dummy_, op);
          }
          net->mutable_op()->MergeFrom(c2ops.ops);
          for (const auto& input : c2ops.interface_blobs) {
            dummy_->AddName(input);
          }
          net->mutable_external_input()->MergeFrom(c2ops.interface_blobs);
        } else {
          CAFFE_THROW(
              "Don't know how to convert ",
              node.op_type(),
              " without enough extra preconverted string");
        }
      } else {
        ValueInfoMap value_infos{};
        for (const auto& name : node.input()) {
          auto iter = graph_value_infos.find(name);
          if (iter != graph_value_infos.end()) {
            value_infos[name].CopyFrom(iter->second);
          }
        }
        auto onnx_node = OnnxNode(node);
        auto c2ops = OnnxNodeToCaffe2Ops(
            init_model, pred_model, {value_infos, opset_version}, &onnx_node);
        init_net_tmp->mutable_op()->MergeFrom(c2ops.init_ops);
        net->mutable_op()->MergeFrom(c2ops.ops);
        net->mutable_external_input()->MergeFrom(c2ops.interface_blobs);
      }
    }

    for (const auto& value : model.graph().output()) {
      net->add_external_output(value.name());
    }
    for (const auto& value : model.graph().input()) {
      net->add_external_input(value.name());
    }
  };

  converter(init_model, init_net);
  converter(pred_model, pred_net);
}
Caffe2BackendRep* Caffe2Backend::Prepare(
    const std::string& onnx_model_str,
    const std::string& device,
    const std::vector<Caffe2Ops>& extras) {
  Caffe2BackendRep* rep = new Caffe2BackendRep();
  ModelProto onnx_model;
  ParseProtoFromLargeString(onnx_model_str, &onnx_model);

  ::ONNX_NAMESPACE::checker::check_model(onnx_model);

  int opset_version = -1;
  for (const auto& imp : onnx_model.opset_import()) {
    if ((!imp.has_domain()) || imp.domain().empty()) {
      opset_version = imp.version();
      if (opset_version > kKnownOpsetVersion) {
        std::cout
            << "This version of onnx-caffe2 targets ONNX operator set version "
            << kKnownOpsetVersion
            << ", but the model we are trying to import uses version "
            << opset_version << ". We will try to import it anyway, "
            << "but if the model uses operators which had BC-breaking changes "
               "in the intervening versions, import will fail."
            << std::endl;
      }
    } else {
      std::cout << "Unrecognized operator set " << opset_version << std::endl;
    }
  }
  if (opset_version < 0) {
    if (onnx_model.ir_version() >= 0x00000003) {
      CAFFE_THROW(
          "Model with IR version >= 3 did not specify ONNX operator set "
          "version (onnx-caffe2 requires it)");
    } else {
      opset_version = 1;
    }
  }

  // Build the init and predict nets on the backend rep.
  OnnxToCaffe2(
      &rep->init_net(),
      &rep->pred_net(),
      onnx_model,
      device,
      opset_version,
      true,
      extras);

  // Record the graph inputs that are not covered by an initializer; these are
  // the blobs the caller has to feed at run time.
  auto& uninitialized_inputs = rep->uninitialized_inputs();
  std::unordered_set<std::string> initialized_inputs;
  for (const auto& tp : onnx_model.graph().initializer()) {
    initialized_inputs.emplace(tp.name());
  }
  for (const auto& input : onnx_model.graph().input()) {
    if (!initialized_inputs.count(input.name())) {
      uninitialized_inputs.emplace_back(input.name());
    }
  }

  return rep;
}
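
// Helpers below turn ONNX integral tensors into the matching GivenTensor*Fill
// arguments. They prefer the packed raw_data encoding and fall back to the
// typed repeated fields when raw data is absent.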
template <typename T>
void ConvertIntegralValueToCaffe2(caffe2::OperatorDef* c2_op,
                                  caffe2::Argument* c2_values,
                                  const TensorProto& onnx_tensor) {
  c2_op->set_type(
      onnx_tensor.data_type() == TensorProto::BOOL ? "GivenTensorBoolFill"
                                                   : "GivenTensorIntFill");
  ::google::protobuf::RepeatedField<T> tmp;
  const ::google::protobuf::RepeatedField<T>* src = &tmp;
  bool converted = TryConvertingTensorRawValues<T>(onnx_tensor, &tmp);
  if (converted) {
    for (const auto i : *src) {
      c2_values->add_ints(i);
    }
  } else {
    const ::google::protobuf::RepeatedField<::google::protobuf::int32>* int32_src =
        &onnx_tensor.int32_data();
    for (const auto i : *int32_src) {
      c2_values->add_ints(i);
    }
  }
}
template <>
void ConvertIntegralValueToCaffe2<::google::protobuf::int64>(caffe2::OperatorDef* c2_op,
                                                             caffe2::Argument* c2_values,
                                                             const TensorProto& onnx_tensor) {
  c2_op->set_type("GivenTensorInt64Fill");
  auto* ints = c2_values->mutable_ints();
  if (!TryConvertingTensorRawValues<::google::protobuf::int64>(
          onnx_tensor, ints)) {
    ints->CopyFrom(onnx_tensor.int64_data());
  }
}
template <>
void ConvertIntegralValueToCaffe2<::google::protobuf::uint64>(caffe2::OperatorDef* c2_op,
                                                              caffe2::Argument* c2_values,
                                                              const TensorProto& onnx_tensor) {
  c2_op->set_type("GivenTensorInt64Fill");
  ::google::protobuf::RepeatedField<::google::protobuf::uint64> tmp;
  const ::google::protobuf::RepeatedField<::google::protobuf::uint64>* src = &tmp;
  if (!TryConvertingTensorRawValues<::google::protobuf::uint64>(
          onnx_tensor, &tmp)) {
    src = &onnx_tensor.uint64_data();
  }
  for (const auto i : *src) {
    c2_values->add_ints(i);
  }
}
void Caffe2Backend::BuildTensorFillingOp(
    caffe2::OperatorDef* c2_op,
    const TensorProto& onnx_tensor,
    const std::string& output_name,
    const std::string& shape_name) {
  auto fill_name = output_name.empty() ? onnx_tensor.name() : output_name;
  CAFFE_ENFORCE(!fill_name.empty());

  if (onnx_tensor.has_segment()) {
    CAFFE_THROW("Currently not supporting loading segments.");
  }

  auto* c2_values = c2_op->add_arg();
  // With an empty shape_name we emit a GivenTensor*Fill op that carries all
  // the values; otherwise we emit a ConstantFill that broadcasts a single
  // value over the shape provided by the shape_name input.
  if (shape_name.empty()) {
    c2_values->set_name("values");
    if (onnx_tensor.data_type() == TensorProto::FLOAT) {
      c2_op->set_type("GivenTensorFill");
      auto* floats = c2_values->mutable_floats();
      if (!TryConvertingTensorRawValues<float>(onnx_tensor, floats)) {
        floats->CopyFrom(onnx_tensor.float_data());
      }
    } else if (onnx_tensor.data_type() == TensorProto::DOUBLE) {
      c2_op->set_type("GivenTensorDoubleFill");
      ::google::protobuf::RepeatedField<double> tmp;
      const ::google::protobuf::RepeatedField<double>* src = &tmp;
      if (!TryConvertingTensorRawValues<double>(onnx_tensor, &tmp)) {
        src = &onnx_tensor.double_data();
      }
      for (const auto i : *src) {
        c2_values->add_floats(i);
      }
    } else if (onnx_tensor.data_type() == TensorProto::INT64) {
      ConvertIntegralValueToCaffe2<::google::protobuf::int64>(c2_op, c2_values, onnx_tensor);
    } else if (onnx_tensor.data_type() == TensorProto::UINT32) {
      ConvertIntegralValueToCaffe2<::google::protobuf::uint64>(c2_op, c2_values, onnx_tensor);
    } else if (onnx_tensor.data_type() == TensorProto::BOOL) {
      ConvertIntegralValueToCaffe2<::google::protobuf::int8>(c2_op, c2_values, onnx_tensor);
    } else if (onnx_tensor.data_type() == TensorProto::UINT8) {
      ConvertIntegralValueToCaffe2<::google::protobuf::uint8>(c2_op, c2_values, onnx_tensor);
    } else if (onnx_tensor.data_type() == TensorProto::INT8) {
      ConvertIntegralValueToCaffe2<::google::protobuf::int8>(c2_op, c2_values, onnx_tensor);
    } else if (onnx_tensor.data_type() == TensorProto::UINT16) {
      ConvertIntegralValueToCaffe2<::google::protobuf::uint16>(c2_op, c2_values, onnx_tensor);
    } else if (onnx_tensor.data_type() == TensorProto::INT16) {
      ConvertIntegralValueToCaffe2<::google::protobuf::int16>(c2_op, c2_values, onnx_tensor);
    } else if (onnx_tensor.data_type() == TensorProto::INT32) {
      ConvertIntegralValueToCaffe2<::google::protobuf::int32>(c2_op, c2_values, onnx_tensor);
    } else if (onnx_tensor.data_type() == TensorProto::STRING) {
      c2_op->set_type("GivenTensorStringFill");
      auto* strings = c2_values->mutable_strings();
      strings->CopyFrom(onnx_tensor.string_data());
    } else {
      CAFFE_THROW("unrecognized tensor type: ", onnx_tensor.data_type());
    }
    auto* c2_shape = c2_op->add_arg();
    c2_shape->set_name("shape");
    for (const auto d : onnx_tensor.dims()) {
      c2_shape->add_ints(d);
    }
  } else {
    int value_size = 1;
    for (const auto d : onnx_tensor.dims()) {
      value_size *= d;
    }
    CAFFE_ENFORCE(value_size == 1);
    auto c2_input_as_shape = c2_op->add_arg();
    c2_input_as_shape->set_name("input_as_shape");
    c2_input_as_shape->set_i(1);
    c2_values->set_name("value");
    auto* c2_dtype = c2_op->add_arg();
    c2_dtype->set_name("dtype");
    if (onnx_tensor.data_type() == TensorProto::FLOAT) {
      c2_dtype->set_i(caffe2::TensorProto::FLOAT);
      if (onnx_tensor.float_data_size() > 0) {
        c2_values->set_f(onnx_tensor.float_data(0));
      } else {
        CAFFE_ENFORCE(onnx_tensor.raw_data().size() == sizeof(float));
        float f;
        memcpy(&f, onnx_tensor.raw_data().c_str(), sizeof(float));
        c2_values->set_f(f);
      }
    } else if (onnx_tensor.data_type() == TensorProto::DOUBLE) {
      c2_dtype->set_i(caffe2::TensorProto::DOUBLE);
      if (onnx_tensor.double_data_size() > 0) {
        c2_values->set_f(static_cast<float>(onnx_tensor.double_data(0)));
      } else {
        CAFFE_ENFORCE(onnx_tensor.raw_data().size() == sizeof(double));
        double d;
        memcpy(&d, onnx_tensor.raw_data().c_str(), sizeof(double));
        c2_values->set_f(static_cast<float>(d));
      }
    } else if (onnx_tensor.data_type() == TensorProto::INT64) {
      c2_dtype->set_i(caffe2::TensorProto::INT64);
      if (onnx_tensor.int64_data_size() > 0) {
        c2_values->set_i(onnx_tensor.int64_data(0));
      } else {
        CAFFE_ENFORCE(onnx_tensor.raw_data().size() == sizeof(int64_t));
        int64_t i;
        memcpy(&i, onnx_tensor.raw_data().c_str(), sizeof(int64_t));
        c2_values->set_i(i);
      }
    } else if (onnx_tensor.data_type() == TensorProto::INT32) {
      c2_dtype->set_i(caffe2::TensorProto::INT32);
      if (onnx_tensor.int32_data_size() > 0) {
        c2_values->set_i(onnx_tensor.int32_data(0));
      } else {
        CAFFE_ENFORCE(onnx_tensor.raw_data().size() == sizeof(int32_t));
        int32_t i;
        memcpy(&i, onnx_tensor.raw_data().c_str(), sizeof(int32_t));
        c2_values->set_i(i);
      }
    } else {
      std::stringstream oss;
      oss << "Unsupported dtype: " << onnx_tensor.data_type();
      CAFFE_THROW(oss.str());
    }
    c2_op->set_type("ConstantFill");
    c2_op->add_input(shape_name);
  }

  c2_op->add_output(fill_name);
}

bool Caffe2Backend::SupportOp(const std::string type) const {
  return get_special_operators().count(type);
}
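
// A minimal end-to-end sketch of how this backend is typically driven (the
// model string variable and the ownership handling are assumptions, not part
// of this file):
//
//   Caffe2Backend backend;
//   std::unique_ptr<Caffe2BackendRep> rep(
//       backend.Prepare(serialized_onnx_model, "CPU", {}));
//   // rep now holds the converted Caffe2 init and predict nets.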