1 #include "caffe2/mobile/contrib/ios/ios_caffe_predictor.h" 2 #include "caffe2/core/flags.h" 3 #include "caffe2/core/tensor.h" 5 #if defined(CAFFE2_USE_MPSCNN) && defined(C10_MOBILE) 6 #include "caffe2/mobile/contrib/ios/mpscnn/mpscnn.h" 9 C10_DECLARE_bool(caffe2_force_shared_col_buffer);
// NOTE(review): fragment of NewCaffe2IOSPredictor. Its parameters match the
// declared signature NewCaffe2IOSPredictor(init_net, predict_net,
// disableMultithreadProcessing, allowMetalOperators). The return type, the
// name and the init_net parameter, the else/#endif lines and the start of the
// final call are not visible in this view.
12 const caffe2::NetDef& predict_net,
13 bool disableMultithreadProcessing,
14 bool allowMetalOperators) {
// Factory that builds a predictor from init_net / predict_net.
// If the build defines CAFFE2_USE_MPSCNN and C10_MOBILE, and the caller
// allows Metal operators, it first tries to rewrite predict_net into
// metal_predict_net. The Metal net is passed on only if that conversion
// succeeded; otherwise the original predict_net is used.
// Receives the converted net from tryConvertToMPSCNN, if conversion is attempted.
15 caffe2::NetDef metal_predict_net;
// Set to true only after a successful MPSCNN conversion (the only visible
// assignment below).
16 bool usingMetalOperators =
false;
17 #if defined(CAFFE2_USE_MPSCNN) && defined(C10_MOBILE) 18 if (allowMetalOperators) {
// Dumps the original net definition (presumably to the log) before conversion.
19 caffe2::dumpDef(predict_net);
20 if (caffe2::tryConvertToMPSCNN(init_net, predict_net, &metal_predict_net)) {
21 LOG(INFO) <<
"Successfully converted to MPSCNN";
22 caffe2::dumpDef(metal_predict_net);
23 usingMetalOperators =
true;
// Conversion failed (the enclosing `else` is not visible here): log the
// failure. usingMetalOperators stays false, so predict_net is used below.
25 LOG(ERROR) <<
"Failed converting model to MPSCNN";
// Arguments to the final call (presumably the Caffe2IOSPredictor constructor,
// whose opening line is not visible): the converted net is chosen only when
// conversion succeeded.
31 usingMetalOperators ? metal_predict_net : predict_net,
32 disableMultithreadProcessing,
// Constructs the predictor wrapper.
//   init_net / predict_net: forwarded to the predictor_ member.
//   disableMultithreadProcessing: when true and a thread pool is available,
//     raises the pool's minimum work size to the largest size_t. Presumably
//     this keeps work from being split across threads.
//   usingMetalOperators: stored as-is. It records whether predict_net was
//     converted to MPSCNN operators (presumably computed by
//     NewCaffe2IOSPredictor).
// NOTE(review): any preprocessor guards, the line that defines `threadpool`,
// and the closing braces are not visible in this view.
36 Caffe2IOSPredictor::Caffe2IOSPredictor(
const caffe2::NetDef& init_net,
37 const caffe2::NetDef& predict_net,
38 bool disableMultithreadProcessing,
39 bool usingMetalOperators)
// The parameter shadows the member of the same name. In a mem-initializer the
// name outside the parentheses is the member and the one inside is the
// parameter, so the value is copied correctly.
40 : usingMetalOperators(usingMetalOperators), predictor_(init_net, predict_net) {
42 if (disableMultithreadProcessing) {
44 if (threadpool !=
nullptr) {
// NOTE(review): std::numeric_limits comes from <limits>, which is not among
// the visible includes. It is presumably pulled in transitively; confirm.
45 threadpool->setMinWorkSize(std::numeric_limits<size_t>::max());
// Runs the wrapped predictor on a single input tensor.
//   inData: input buffer and dims. It is wrapped as a uint8 CPU tensor
//     without copying (ShareExternalPointer), so the buffer presumably must
//     stay valid for the duration of the call.
//   outData: receives the output tensor's uint8 data pointer and its sizes
//     (presumably only when no exception occurred, since the handlers' early
//     returns are not visible here). The pointer is not owned by outData. Its
//     lifetime is tied to the tensor's storage; TODO confirm what keeps it
//     alive after return.
//   errorMessage: if the predictor throws, receives e.msg() or e.what().
// NOTE(review): this view omits the `try {` line, the first catch clause, the
// returns after each handler, the declaration of `output`, and the closing
// brace.
51 void Caffe2IOSPredictor::run(
const Tensor& inData,
Tensor& outData, std::string& errorMessage) {
// Sets the global flag declared at the top of the file on every call. This is
// a process-wide side effect, not scoped to this predictor instance.
52 FLAGS_caffe2_force_shared_col_buffer =
true;
// Creates a uint8 CPU tensor shaped like inData.dims, then points it at the
// caller's buffer instead of using tensor-owned memory.
53 caffe2::Tensor input = caffe2::empty(inData.dims, at::dtype<uint8_t>().device(caffe2::CPU));
54 input.ShareExternalPointer(inData.data);
55 caffe2::Predictor::TensorList input_vec;
56 input_vec.emplace_back(std::move(input));
57 caffe2::Predictor::TensorList output_vec;
// Runs the net; output_vec is filled by the predictor. Exceptions are
// handled below.
59 predictor_(input_vec, &output_vec);
// First handler. Its catch clause is not visible; it is presumably
// caffe2::EnforceNotMet, given the use of e.msg(). swap hands the text to
// errorMessage without a second copy.
61 std::string error = e.msg();
62 errorMessage.swap(error);
64 }
catch (
const std::exception& e) {
65 std::string error = e.what();
66 errorMessage.swap(error);
// NOTE(review): mutable_data<uint8_t>() assumes the output tensor already
// holds uint8 data. If the model emits another dtype, this may reallocate the
// buffer rather than reinterpret it; confirm.
70 outData.data = output->mutable_data<uint8_t>();
71 outData.dims = output->sizes().vec();
The primary ATen error class.
static Caffe2IOSPredictor * NewCaffe2IOSPredictor(const caffe2::NetDef &init_net, const caffe2::NetDef &predict_net, bool disableMultithreadProcessing, bool allowMetalOperators)
Allow converting eligible operators to Metal GPU framework accelerated operators. ...