Caffe2 - C++ API
A deep learning, cross-platform ML framework
onnx_exporter.h
1 #pragma once
2 
3 #include "caffe2/core/common.h"
4 #include "caffe2/core/tensor.h"
5 #include "caffe2/onnx/helper.h"
6 #include "caffe2/proto/caffe2_pb.h"
7 #include "onnx/onnx_pb.h"
8 
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
12 
13 namespace caffe2 {
14 namespace onnx {
15 
16 namespace {
17 using ::ONNX_NAMESPACE::AttributeProto;
18 using ::ONNX_NAMESPACE::GraphProto;
19 using ::ONNX_NAMESPACE::ModelProto;
20 using ::ONNX_NAMESPACE::NodeProto;
21 using ::ONNX_NAMESPACE::TensorProto;
22 } // namespace
23 
24 using ConvertedResult =
25  std::pair<std::vector<NodeProto>, std::vector<TensorProto>>;
26 
27 // Rewrite Caffe2 nets into SSA forms. Notice that we will preserve the external
28 // output names for predict net.
29 CAFFE2_API std::unordered_map<std::string, std::string> SsaRewrite(
30  caffe2::NetDef* init_net,
31  caffe2::NetDef* pred_net);
32 
33 ::ONNX_NAMESPACE::TensorProto::DataType Caffe2TypeToOnnxType(
34  caffe2::TensorProto::DataType t);
35 
36 class CAFFE2_API OnnxExporter {
37  using SpecialOpConverter = ConvertedResult (OnnxExporter::*)(
38  const caffe2::OperatorDef&,
39  const std::unordered_map<std::string, caffe2::TensorShape>&);
40 
41  public:
42  OnnxExporter(DummyName* dummy = nullptr) {
43  if (dummy) {
44  dummy_ = std::shared_ptr<DummyName>(dummy, [](DummyName*) {});
45  } else {
46  dummy_ = std::make_shared<DummyName>();
47  }
48  }
49 
50  ConvertedResult Caffe2OpToOnnxNodes(
51  const caffe2::OperatorDef& def,
52  const std::unordered_map<std::string, caffe2::TensorShape>& shapes);
53 
54  void InitOpToTensorProto(const caffe2::OperatorDef& def, TensorProto* tensor);
55  private:
56  ConvertedResult CommonCaffe2OpToOnnxNodes(const caffe2::OperatorDef& def);
57 
58  ConvertedResult CreateArgMaxMinOpNodes(
59  const caffe2::OperatorDef& def,
60  const std::unordered_map<std::string, caffe2::TensorShape>& shapes);
61 
62  ConvertedResult CreateBinaryElementwiseOpNodes(
63  const caffe2::OperatorDef& def,
64  const std::unordered_map<std::string, caffe2::TensorShape>& shapes);
65 
66  ConvertedResult CreateCastNodes(
67  const caffe2::OperatorDef& def,
68  const std::unordered_map<std::string, caffe2::TensorShape>& shapes);
69 
70  ConvertedResult CreateElementwiseLinearNodes(
71  const caffe2::OperatorDef& def,
72  const std::unordered_map<std::string, caffe2::TensorShape>& shapes);
73 
74  ConvertedResult CreateConvPoolNodes(
75  const caffe2::OperatorDef& def,
76  const std::unordered_map<std::string, caffe2::TensorShape>& shapes);
77 
78  ConvertedResult CreateGemmNodes(
79  const caffe2::OperatorDef& def,
80  const std::unordered_map<std::string, caffe2::TensorShape>& shapes);
81 
82  ConvertedResult CreateReshapeNodes(
83  const caffe2::OperatorDef& def,
84  const std::unordered_map<std::string, caffe2::TensorShape>& shapes);
85 
86  ConvertedResult CreateSliceNodes(
87  const caffe2::OperatorDef& def,
88  const std::unordered_map<std::string, caffe2::TensorShape>& shapes);
89 
90  ConvertedResult CreateChannelShuffleNodes(
91  const caffe2::OperatorDef& def,
92  const std::unordered_map<std::string, caffe2::TensorShape>& shapes);
93 
94  ConvertedResult CreateReduceMeanNodes(
95  const caffe2::OperatorDef& def,
96  const std::unordered_map<std::string, caffe2::TensorShape>& shapes);
97 
98  ConvertedResult CreateConcatNodes(
99  const caffe2::OperatorDef& def,
100  const std::unordered_map<std::string, caffe2::TensorShape>& shapes);
101 
102  ConvertedResult CreateMergeDimNodes(
103  const caffe2::OperatorDef& def,
104  const std::unordered_map<std::string, caffe2::TensorShape>& shapes);
105 
106  ConvertedResult CreateLrnNodes(
107  const caffe2::OperatorDef& def,
108  const std::unordered_map<std::string, caffe2::TensorShape>& shapes);
109 
110  ConvertedResult CreateUpsampleNodes(
111  const caffe2::OperatorDef& def,
112  const std::unordered_map<std::string, caffe2::TensorShape>& shapes);
113 
114  // \brief Check black listed arguemnts where we won't pass down when
115  // converting to ONNX node
116  bool IsBlackListed(const caffe2::Argument& arg);
117 
118  // \brief Convert Caffe2 argument to Onnx attribute
119  void CopyCaffe2ArgToOnnxAttr(
120  AttributeProto* attr,
121  const std::string& op_type,
122  const caffe2::Argument& arg);
123 
124  // LUT getters
125  const std::unordered_map<std::string, std::string>& get_renamed_operators()
126  const;
127  const std::unordered_map<std::string, std::string>& get_renamed_attrs() const;
128  const std::
129  unordered_map<std::string, std::unordered_map<std::string, std::string>>&
130  get_per_op_renamed_attrs() const;
131  const std::unordered_map<std::string, OnnxExporter::SpecialOpConverter>&
132  get_special_operators() const;
133 
134  // Dummy name generator
135  std::shared_ptr<DummyName> dummy_;
136 };
137 } // namespace onnx
138 } // namespace caffe2
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13