3 #include "caffe2/onnx/backend_rep.h" 4 #include "caffe2/onnx/device.h" 5 #include "caffe2/onnx/helper.h" 6 #include "caffe2/proto/caffe2_pb.h" 7 #include "onnx/onnx_pb.h" 11 #include <unordered_map> 12 #include <unordered_set> 14 constexpr
int kKnownOpsetVersion = 9;  // ONNX opset version constant; presumably the newest opset this backend is known to handle — confirm against its uses.
19 using ::ONNX_NAMESPACE::AttributeProto;
20 using ::ONNX_NAMESPACE::GraphProto;
21 using ::ONNX_NAMESPACE::ModelProto;
22 using ::ONNX_NAMESPACE::NodeProto;
23 using ::ONNX_NAMESPACE::TensorProto;
24 using ::ONNX_NAMESPACE::ValueInfoProto;
// Hash map from a string key to an ONNX ValueInfoProto. The key is presumably
// the value (tensor) name — confirm against the code that populates it.
26 using ValueInfoMap = std::unordered_map<std::string, ValueInfoProto>;
31 : value_infos_(value_infos), opset_version_(opset_version) {}
32 const ValueInfoMap& value_infos()
const {
35 int opset_version()
const {
36 return opset_version_;
40 const ValueInfoMap& value_infos_;
41 const int opset_version_;
48 ::google::protobuf::RepeatedPtrField<caffe2::OperatorDef> init_ops;
49 ::google::protobuf::RepeatedPtrField<caffe2::OperatorDef> ops;
50 ::google::protobuf::RepeatedPtrField<std::string> interface_blobs;
59 bool HasAttribute(
const std::string& key)
const {
60 return onnx_attrs_.count(key);
63 AttributeProto* AddRewrittenAttribute(
const std::string& key) {
64 auto tmp = rewritten_onnx_attrs_.emplace(key, AttributeProto());
65 auto& attr = tmp.first->second;
// Builds a list of Caffe2 `Argument`s from ONNX attributes. This is presumably
// the attributes stored in this object (`onnx_attrs_` / `rewritten_onnx_attrs_`).
// The definition is not visible here — confirm.
// @param mapper  string -> string callback; presumably translates an ONNX
//                attribute name into the Caffe2 argument name.
// @return        a RepeatedPtrField of caffe2::Argument, returned by value.
// Declared `const`: does not modify this object's non-mutable state.
70 ::google::protobuf::RepeatedPtrField<caffe2::Argument> OnnxAttrToCaffe2Arg(
71 std::function<std::string(
const std::string&)> mapper)
const;
76 T get(
const std::string& key)
const;
79 T get(
const std::string& key,
const T& default_value)
const {
80 if (onnx_attrs_.count(key)) {
87 const AttributeProto*
remove(
const std::string& key) {
88 const AttributeProto* result =
nullptr;
89 auto iter = onnx_attrs_.find(key);
90 if (iter != onnx_attrs_.end()) {
91 result = iter->second;
92 onnx_attrs_.erase(iter);
98 std::unordered_map<std::string, const AttributeProto*> onnx_attrs_;
99 std::unordered_map<std::string, AttributeProto> rewritten_onnx_attrs_;
103 int64_t OnnxAttributes::get(
const std::string& key)
const;
105 float OnnxAttributes::get(
const std::string& key)
const;
108 ::google::protobuf::RepeatedPtrField<std::string> OnnxAttributes::get(
109 const std::string& key)
const;
112 ::google::protobuf::RepeatedField<::google::protobuf::int64>
113 OnnxAttributes::get(
const std::string& key)
const;
116 ::google::protobuf::RepeatedField<float>
117 OnnxAttributes::get(
const std::string& key)
const;
120 const TensorProto* OnnxAttributes::get(
const std::string& key)
const;
124 OnnxNode(
const NodeProto& node_in) : node(node_in), attributes(node_in) {}
126 const NodeProto& node;
140 dummy_ = std::shared_ptr<DummyName>(dummy, [](
DummyName *){});
142 dummy_ = std::make_shared<DummyName>();
147 const std::string& onnx_model_str,
148 const std::string& device,
149 const std::vector<Caffe2Ops>& extras);
// Presumably reports whether an op of the given ONNX type can be handled by
// this backend. The definition is outside this view — confirm.
// NOTE(review): `type` is passed by value. Switching to `const std::string&`
// would also require updating the out-of-line definition, so the signature
// is intentionally left unchanged; only the misspelled parameter name
// (`tyep`) was corrected, which does not affect callers or linkage.
151 bool SupportOp(
const std::string type)
const;
154 const std::string& node_str,
// Builds, into `c2_op`, an operator derived from `onnx_tensor`. This is
// presumably a Caffe2 op that materializes the tensor's contents; the
// definition is not visible here — confirm.
// @param c2_op        operator to build; non-const pointer, presumably
//                     written in place.
// @param onnx_tensor  source ONNX tensor (read-only).
// @param output_name  defaults to ""; presumably an empty value means a
//                     name taken from the tensor is used — TODO confirm.
// @param shape_name   defaults to ""; meaning of an empty vs. non-empty
//                     value is not visible here — TODO confirm.
157 void BuildTensorFillingOp(
158 caffe2::OperatorDef* c2_op,
159 const TensorProto& onnx_tensor,
160 const std::string& output_name =
"",
161 const std::string& shape_name =
"");
164 using SpecialOpConverter =
168 caffe2::NetDef* init_net,
169 caffe2::NetDef* pred_net,
170 const ModelProto& onnx_model,
171 const std::string& device,
173 bool include_initializers,
174 const std::vector<Caffe2Ops>& extras);
// Presumably checks the arguments set on `op` against what `schema` declares.
// Both inputs are read-only, and there is no return value. The definition is
// not visible here — confirm what is checked and how a failure is reported.
176 void CheckOpSchemaArguments(
const caffe2::OpSchema& schema,
const caffe2::OperatorDef& op);
179 const ModelProto& init_model,
180 const ModelProto& pred_model,
// Returns the set of distinct names found in `graph`, returned by value.
// This presumably includes graph inputs/outputs and node input/output names.
// Confirm the exact coverage against the definition.
184 std::unordered_set<std::string> AllNamesInGraph(
const GraphProto& graph);
220 std::string PreprocessSliceIndexTensor(
OnnxNode* onnx_node,
222 std::string indices_tensor,
223 std::string axes_tensor,
224 std::string rank_tensor,
225 std::string zero_tensor,
226 std::string one_tensor,
252 const std::unordered_map<std::string, std::string>& get_renamed_operators()
// Const accessors for backend lookup tables. Each returns a const reference,
// so the referenced table must stay alive after the call returns. That
// presumably means function-local statics in the definition — confirm.
// Set of operator type names; presumably the ops handled as RNNs.
254 const std::unordered_set<std::string>& get_rnn_operators()
const;
// Map from operator type name to an int; presumably an opset version at
// which the op is treated as broken/unsupported — TODO confirm.
255 const std::unordered_map<std::string, int>& get_broken_operators()
const;
// String -> string map; presumably renames ONNX attribute names to their
// Caffe2 equivalents for all ops — TODO confirm.
256 const std::unordered_map<std::string, std::string>& get_renamed_attrs()
const;
258 unordered_map<std::string, std::unordered_map<std::string, std::string>>&
259 get_per_op_renamed_attrs()
const;
// Const accessor for a table mapping a name to the SpecialOpConverter that
// handles it. The name is presumably an ONNX op type — confirm. Returned by
// const reference, so the table must stay alive after the call returns.
260 const std::unordered_map<std::string, Caffe2Backend::SpecialOpConverter>&
261 get_special_operators()
const;
// DummyName instance, presumably used to generate fresh blob names. It is held
// through a shared_ptr because the assignments to `dummy_` earlier in this
// file use it two ways:
// - wrap a caller-supplied pointer with a no-op deleter (non-owning), or
// - create an owned instance with std::make_shared.
264 std::shared_ptr<DummyName> dummy_;
A class to record the schema of an op.
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...