5 #include <unordered_map> 8 #include "onnx/onnx_pb.h" 10 #include "caffe2/core/operator.h" 11 #include "caffe2/onnx/onnxifi_init.h" 12 #include "caffe2/opt/backend_transformer_base.h" 32 bool add_adjust_batch_ops{
true};
// add_adjust_batch_ops (above): boolean option, defaults to true.
// NOTE(review): its enclosing struct/class (original lines 13-31) is not in
// this extract; from the name it presumably toggles insertion of
// batch-adjusting ops around lowered subgraphs — confirm at its use sites.
// NOTE(review): the opening of this declaration (original line 49: return
// type, function name and any leading parameters) is missing from this
// extract. The visible tail is an `override` that receives a list of weight
// names, a string-keyed map of TensorShape hints (keys presumably blob
// names — confirm), and a set of blacklisted op ids (presumably positions of
// ops in the net — confirm).
50 const std::vector<std::string>& weight_names,
51 const std::unordered_map<std::string, TensorShape>& shape_hints,
52 const std::unordered_set<int>& blacklisted_ops)
override;
// Per its name: turns the subgraph `net` into an Onnxifi op by going through
// ONNX (it takes an onnx::OnnxExporter). Returns a caffe2::NetDef.
// `weights_in_ws` is a set of weight names. `shape_hints` is passed by
// non-const pointer, so this call may read and update the shape map.
// NOTE(review): original line 63 (a parameter between `weights_in_ws` and
// `exporter`) is missing from this extract.
60 caffe2::NetDef SubnetToOnnxifiOpViaOnnx(
61 const caffe2::NetDef& net,
62 const std::unordered_set<std::string>& weights_in_ws,
64 onnx::OnnxExporter* exporter,
65 ShapeInfoMap* shape_hints);
68 caffe2::NetDef SubnetToOnnxifiOpViaC2(
69 const caffe2::NetDef& net,
70 const std::unordered_set<std::string>& weights_in_ws,
71 const ShapeInfoMap& shape_hints);
74 OperatorDef BuildOnnxifiOp(
75 const std::string& onnx_model_str,
76 const std::unordered_map<std::string, TensorShape>& output_size_hints,
77 const std::unordered_set<std::string>& initialization_list,
78 const std::vector<std::string>& external_inputs,
79 const std::vector<std::string>& external_outputs);
// Per its name: the transform path that goes through Caffe2 protos. It
// receives a set of weight names, the blacklisted op ids and read-only shape
// hints, and returns a NetDef.
// NOTE(review): original line 83 (the first parameter) is missing from this
// extract.
82 NetDef TransformViaC2(
84 const std::unordered_set<std::string>& weights,
85 const std::unordered_set<int>& blacklisted_ops,
86 const ShapeInfoMap& shape_hints);
// Per its name: the transform path that goes through ONNX. It receives a set
// of weight names and the blacklisted op ids, and returns a NetDef. Unlike
// TransformViaC2, `shape_hints` is a non-const pointer here, so the map may
// be updated by this call.
// NOTE(review): original lines 90-91 (the leading parameters) are missing
// from this extract.
89 NetDef TransformViaOnnx(
92 const std::unordered_set<std::string>& weights,
93 const std::unordered_set<int>& blacklisted_ops,
94 ShapeInfoMap* shape_hints);
// NOTE(review): original line 97, which carries this declaration's return
// type and name, is missing from this extract. Visible part: a const member
// function taking one OperatorDef, read-only shape hints, the blacklisted op
// id set and an onnxBackendID. It is presumably a per-op, per-backend
// capability query on the Caffe2-proto path — confirm against the full
// header.
98 const caffe2::OperatorDef& op,
99 const ShapeInfoMap& shape_hints,
100 const std::unordered_set<int>& blacklisted_ops,
101 onnxBackendID backend_id)
const;
// NOTE(review): original line 104, which carries this declaration's return
// type and name, is missing from this extract. Visible part: a const member
// function taking one OperatorDef, an onnx::OnnxExporter pointer, the
// blacklisted op id set and an onnxBackendID. It is presumably the ONNX-path
// counterpart of the preceding per-op query — confirm against the full
// header.
105 const caffe2::OperatorDef& op,
106 onnx::OnnxExporter* exporter,
107 const std::unordered_set<int>& blacklisted_ops,
108 onnxBackendID backend_id)
const;
// Const member function that reads shape hints and may modify
// *blacklisted_ops (non-const pointer). Per the name, it couples Gather ops
// with SparseLengthsWeightedSum ops, presumably so that related ops end up
// blacklisted (or not) together — confirm against the definition.
// NOTE(review): original line 114 (the first parameter) is missing from this
// extract.
113 void tieGatherAndSparseLengthsWeightedSumOps(
115 const ShapeInfoMap& shape_hints,
116 std::unordered_set<int>* blacklisted_ops)
const;
// Const member function that, per its name, applies filtering rules. It reads
// shape hints and may update *blacklisted_ops (non-const pointer). Which
// rules are applied is not visible here.
// NOTE(review): original line 120 (the first parameter) is missing from this
// extract.
119 void applyFilteringRules(
121 const ShapeInfoMap& shape_hints,
122 std::unordered_set<int>* blacklisted_ops)
const;
131 onnxifi_library* lib_{
nullptr};
134 size_t num_backends_{0};
140 int onnxifi_op_id_{0};
143 std::string model_id_;
146 std::vector<onnxBackendID> backend_ids_;
149 std::unordered_map<std::string, TensorShape> shape_hints_onnx_;
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...