3 #include <unordered_map> 5 #include "onnx/onnx_pb.h" 7 #include "caffe2/core/context.h" 8 #include "caffe2/core/logging.h" 9 #include "caffe2/core/operator.h" 10 #include "caffe2/onnx/onnxifi_graph_info.h" 11 #include "caffe2/onnx/onnxifi_init.h" 12 #include "caffe2/utils/string_utils.h" 16 template <
typename T,
typename Context>
20 TensorInfo(TensorInfo&&) =
default;
21 TensorInfo& operator=(TensorInfo&&) =
default;
22 std::vector<uint64_t> dims;
23 uint64_t onnxifi_type;
// caffe2 helper macro, presumably from caffe2/core/operator.h. It is expected
// to bring the Operator<Context> helpers into this class's scope; confirm there.
27 USE_OPERATOR_CONTEXT_FUNCTIONS;
// Operator construction: loads the ONNXIFI library, records the input/output
// binding names, parses optional per-output shape hints, builds the backend
// property list and weight descriptors, then sets up the backend and graph
// via BuildBackendAndGraph.
30 lib_ = onnx::initOnnxifiLibrary();
31 backend_graph_map_ptr_ = onnx::getOnnxBackendGraphMap();
32 CAFFE_ENFORCE(lib_,
"Cannot initialize ONNXIFI library");
// Serialized ONNX model taken from the "onnx_model" argument (default "").
// An empty model is rejected.
34 this->
template GetSingleArgument<std::string>(
"onnx_model",
"");
35 CAFFE_ENFORCE(!onnx_model_str.empty(),
"onnx_model cannot be empty");
// Binding names for the op's inputs and outputs. There must be exactly one
// name per operator input and one per operator output.
39 this->
template GetRepeatedArgument<std::string>(
"input_names");
41 this->
template GetRepeatedArgument<std::string>(
"output_names");
42 CAFFE_ENFORCE_EQ(input_names_.size(), operator_def.input_size());
43 CAFFE_ENFORCE_EQ(output_names_.size(), operator_def.output_size());
// Build one ONNXIFI descriptor per binding name. `name` stores the raw
// c_str() pointer of the string owned by input_names_ / output_names_.
// Those vectors must therefore not be modified or reallocated while the
// descriptors are in use.
44 for (
const auto& input : input_names_) {
45 input_desc_.push_back(onnxTensorDescriptorV1());
46 input_desc_.back().name = input.c_str();
49 for (
const auto& output : output_names_) {
50 output_desc_.push_back(onnxTensorDescriptorV1());
51 output_desc_.back().name = output.c_str();
// Optional per-output hint from the repeated int argument
// "output_shape_hint_<output_idx>". The first value is the ONNXIFI data type
// and the remaining values are the dims. Hints are stored in
// output_shape_hints_, keyed by output index.
// NOTE(review): values are read as int and widened to uint64_t, so a negative
// entry would wrap around. Confirm callers never pass one.
54 const std::string key = c10::str(
"output_shape_hint_", output_idx);
55 auto output_shape_hint = this->
template GetRepeatedArgument<int>(key);
56 if (!output_shape_hint.empty()) {
58 info.onnxifi_type = output_shape_hint.front();
59 for (
size_t i = 1; i < output_shape_hint.size(); ++i) {
60 info.dims.push_back(output_shape_hint[i]);
62 output_shape_hints_.emplace(output_idx, std::move(info));
// Backend initialization properties, later passed to onnxInitBackend.
// The BuildPropertyList overload in this class ignores int_args and
// float_args. They presumably serve as backing storage for property values
// in other variants; confirm before removing.
68 std::vector<uint64_t> property_pointers;
69 std::vector<int64_t> int_args;
70 std::vector<float> float_args;
71 BuildPropertyList(operator_def, &property_pointers, &int_args, &float_args);
// Names from the "initializers" argument are collected into a set. The set
// is handed to buildInitializationList together with the workspace.
77 this->
template GetRepeatedArgument<std::string>(
"initializers");
78 std::unordered_set<std::string> initializer_set;
79 for (
auto it = initializers.begin(); it != initializers.end(); ++it) {
81 initializer_set.emplace(key);
// weight_names and weight_shapes stay alive until BuildBackendAndGraph
// returns. Presumably the descriptors in weight_descs point into them;
// confirm in buildInitializationList.
83 std::vector<std::string> weight_names;
84 std::vector<std::vector<uint64_t>> weight_shapes;
85 auto weight_descs = buildInitializationList(
86 ws, &initializer_set, &weight_names, &weight_shapes);
88 BuildBackendAndGraph(property_pointers, onnx_model_str, weight_descs);
// Teardown (presumably the destructor body). First release this op's
// reference to the shared backend/graph info, then remove the entry
// registered under op_id_string_.
// NOTE(review): the reset-before-remove order presumably lets remove() free
// the backend once its last user is gone. Confirm in onnxifi_graph_info.
92 backend_graph_shared_ptr_.reset();
93 backend_graph_map_ptr_->remove(op_id_string_);
// Operator run hook (overrides the base-class virtual). It is only declared
// here and defined out of line.
96 bool RunOnDevice()
override;
// Resolves the shape and type for output `output_idx`.
// - If output_shape_hints_ has an entry for this output, its dims are
//   appended to *dims through std::back_inserter (existing contents of *dims
//   are kept) and its ONNXIFI type is used.
// - Otherwise the type defaults to ONNXIFI_DATATYPE_FLOAT32 and *dims is not
//   touched here.
// Presumably returns `type`; confirm.
// NOTE(review): hinted dims are uint64_t but *dims holds size_t, which
// narrows on 32-bit targets.
99 uint64_t SetOutputShapeAndType(
int output_idx, std::vector<size_t>* dims) {
100 uint64_t type = ONNXIFI_DATATYPE_FLOAT32;
101 const auto it = output_shape_hints_.find(output_idx);
102 if (it != output_shape_hints_.end()) {
104 it->second.dims.begin(),
105 it->second.dims.end(),
106 std::back_inserter(*dims));
107 type = it->second.onnxifi_type;
// Fills the property list that is later passed to onnxInitBackend. This
// version appends only ONNXIFI_BACKEND_PROPERTY_NONE, which per the ONNXIFI
// API terminates the auxiliary property list, so no extra properties are
// requested. The int64 and float storage parameters are unnamed and unused.
112 void BuildPropertyList(
114 std::vector<uint64_t>* property_list,
115 std::vector<int64_t>* ,
116 std::vector<float>* ) {
117 property_list->push_back(ONNXIFI_BACKEND_PROPERTY_NONE);
// Backend and graph setup:
// 1. Select the ONNXIFI backend given by the "backend_id" argument.
// 2. Initialize it with `property_pointers`.
// 3. Initialize a graph from the serialized `onnx_model_str` bytes.
// 4. Register the result in backend_graph_map_ptr_ under op_id_string_.
// The resulting handles are cached in backend_id_, backend_ and graph_.
120 void BuildBackendAndGraph(
121 const std::vector<uint64_t>& property_pointers,
122 const std::string& onnx_model_str,
123 const std::vector<onnxTensorDescriptorV1>& weight_descs) {
// Identifier "<model_id>:<net_pos>" built from the op arguments.
// NOTE(review): presumably assigned to op_id_string_, the map key used
// below. Confirm.
125 this->
template GetSingleArgument<std::string>(
"model_id",
"") +
":" +
126 this->
template GetSingleArgument<std::string>(
"net_pos",
"");
// Index into the list of available backends (default 0).
129 auto backend_index = this->
template GetSingleArgument<int>(
"backend_id", 0);
// Local copy of the library pointer; presumably captured by the graph
// creator callable below.
130 onnxifi_library* lib = lib_;
// Query the backend count first (a null ID buffer is checked against
// ONNXIFI_STATUS_FALLBACK). Require at least one backend and an in-range
// index, then fetch the backend IDs into a buffer of that size.
136 std::vector<onnxBackendID> backend_ids;
137 size_t num_backends{0};
139 lib->onnxGetBackendIDs(
nullptr, &num_backends),
140 ONNXIFI_STATUS_FALLBACK);
142 num_backends, 0,
"At least 1 onnxifi backend should be available");
146 "Backend idx out of bound: ",
150 backend_ids.resize(num_backends);
152 lib->onnxGetBackendIDs(backend_ids.data(), &num_backends),
153 ONNXIFI_STATUS_SUCCESS);
// Initialize the selected backend with the property list.
155 onnxBackendID backend_id = backend_ids[backend_index];
156 onnxBackend backend{
nullptr};
159 lib->onnxInitBackend(backend_id, property_pointers.data(), &backend),
160 ONNXIFI_STATUS_SUCCESS);
// Release the backend IDs. The selected index is special-cased, presumably
// skipped because that ID stays in use.
// NOTE(review): `auto i = 0` deduces int while num_backends is size_t, so the
// loop condition is a signed/unsigned comparison. A size_t counter (with a
// checked conversion of backend_index) would be cleaner.
163 for (
auto i = 0; i < num_backends; ++i) {
164 if (i == backend_index) {
167 lib->onnxReleaseBackendID(backend_ids[i]);
// Graph initialized from the serialized model bytes (and presumably the
// weight descriptors); the call must return ONNXIFI_STATUS_SUCCESS.
169 onnxGraph graph{
nullptr};
174 onnx_model_str.size(),
175 (
const void*)(onnx_model_str.c_str()),
179 ONNXIFI_STATUS_SUCCESS);
// Bundle the handles into a shared BackendGraphInfo. This is presumably the
// return value of the `creator` passed to insert() below.
181 return std::make_shared<onnx::BackendGraphInfo>(
182 backend_id, backend, graph, lib);
// Register under op_id_string_ and cache the handles on the op.
// NOTE(review): whether insert() reuses an existing entry for the same key is
// defined in onnxifi_graph_info. Confirm before relying on it.
184 backend_graph_shared_ptr_ =
185 backend_graph_map_ptr_->insert(op_id_string_, creator);
187 backend_id_ = backend_graph_shared_ptr_->backend_id;
188 backend_ = backend_graph_shared_ptr_->backend;
189 graph_ = backend_graph_shared_ptr_->graph;
// With ONNXIFI_ENABLE_EXT, look up the optional
// "onnxSetIOAndRunGraphFunction" extension on the cached backend. The pointer
// is set to nullptr when the lookup does not return ONNXIFI_STATUS_SUCCESS.
192 #ifdef ONNXIFI_ENABLE_EXT 193 onnxExtensionFunctionPointer p;
194 if (lib_->onnxGetExtensionFunctionAddress(
195 backend_id_,
"onnxSetIOAndRunGraphFunction", &p) !=
196 ONNXIFI_STATUS_SUCCESS) {
197 onnxSetIOAndRunGraphPointer_ =
nullptr;
200 onnxSetIOAndRunGraphPointer_ =
201 reinterpret_cast<decltype(onnxSetIOAndRunGraphPointer_)
>(p);
// Declared here and defined out of line. The constructor calls it with the
// workspace, the set of initializer names, and out-pointers for weight names
// and shapes. It returns the weight tensor descriptors that are then passed
// to BuildBackendAndGraph.
// NOTE(review): presumably the returned descriptors point into
// weight_names/weight_shapes, so those must outlive graph initialization.
205 std::vector<onnxTensorDescriptorV1> buildInitializationList(
207 std::unordered_set<std::string>* initialization_list,
208 std::vector<std::string>* weight_names,
209 std::vector<std::vector<uint64_t>>* weight_shapes);
212 onnxifi_library* lib_{
nullptr};
213 onnx::OnnxBackendGraphMap* backend_graph_map_ptr_;
214 std::string op_id_string_;
216 onnxBackendID backend_id_{
nullptr};
217 onnxBackend backend_{
nullptr};
218 onnxGraph graph_{
nullptr};
219 onnx::SharedPtrBackendGraphInfo backend_graph_shared_ptr_;
222 std::vector<onnxTensorDescriptorV1> input_desc_;
223 std::vector<onnxTensorDescriptorV1> output_desc_;
// Optional extension entry point, resolved from the backend under the name
// "onnxSetIOAndRunGraphFunction". It is nullptr when the backend does not
// provide it, and the member exists only when ONNXIFI_ENABLE_EXT is defined.
225 #ifdef ONNXIFI_ENABLE_EXT 227 onnxStatus (*onnxSetIOAndRunGraphPointer_)(
230 const onnxTensorDescriptorV1*,
232 const onnxTensorDescriptorV1*,
234 onnxTraceEventList*);
// Binding names from the "input_names" / "output_names" arguments, one per op
// input/output. input_desc_ / output_desc_ store c_str() pointers into these
// strings, so the vectors must not be modified after construction.
241 std::vector<std::string> input_names_;
242 std::vector<std::string> output_names_;
// NOTE(review): not referenced in the code shown here. Presumably these are
// per-run shape buffers backing the tensor descriptors; confirm in
// RunOnDevice.
244 std::vector<std::vector<uint64_t>> input_shapes_;
245 std::vector<std::vector<uint64_t>> output_shapes_;
// Output shape hints keyed by output index. They are parsed from the
// "output_shape_hint_<idx>" arguments and consulted by SetOutputShapeAndType.
248 std::unordered_map<int, TensorInfo> output_shape_hints_;
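Workspace is a class that holds all the related objects created during runtime: (1) all blobs, and (2) all instantiated networks. It is the owner of all these objects and deals with the scaffolding logistics.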
A global dictionary that holds information about what Caffe2 modules have been loaded in the current runtime, and also utility functions to load modules.