Caffe2 - C++ API
A deep learning, cross-platform ML framework
ios_caffe_predictor.cc
1 #include "caffe2/mobile/contrib/ios/ios_caffe_predictor.h"
2 #include "caffe2/core/flags.h"
3 #include "caffe2/core/tensor.h"
4 
5 #if defined(CAFFE2_USE_MPSCNN) && defined(C10_MOBILE)
6 #include "caffe2/mobile/contrib/ios/mpscnn/mpscnn.h"
7 #endif
8 
9 C10_DECLARE_bool(caffe2_force_shared_col_buffer);
10 
12  const caffe2::NetDef& predict_net,
13  bool disableMultithreadProcessing,
14  bool allowMetalOperators) {
15  caffe2::NetDef metal_predict_net;
16  bool usingMetalOperators = false;
17 #if defined(CAFFE2_USE_MPSCNN) && defined(C10_MOBILE)
18  if (allowMetalOperators) {
19  caffe2::dumpDef(predict_net);
20  if (caffe2::tryConvertToMPSCNN(init_net, predict_net, &metal_predict_net)) {
21  LOG(INFO) << "Successfully converted to MPSCNN";
22  caffe2::dumpDef(metal_predict_net);
23  usingMetalOperators = true;
24  } else {
25  LOG(ERROR) << "Failed converting model to MPSCNN";
26  }
27  }
28 #endif
29 
30  return new Caffe2IOSPredictor(init_net,
31  usingMetalOperators ? metal_predict_net : predict_net,
32  disableMultithreadProcessing,
33  usingMetalOperators);
34 }
35 
36 Caffe2IOSPredictor::Caffe2IOSPredictor(const caffe2::NetDef& init_net,
37  const caffe2::NetDef& predict_net,
38  bool disableMultithreadProcessing,
39  bool usingMetalOperators)
40  : usingMetalOperators(usingMetalOperators), predictor_(init_net, predict_net) {
41 #ifdef C10_MOBILE
42  if (disableMultithreadProcessing) {
43  caffe2::ThreadPool* threadpool = predictor_.ws()->GetThreadPool();
44  if (threadpool != nullptr) {
45  threadpool->setMinWorkSize(std::numeric_limits<size_t>::max());
46  }
47  }
48 #endif
49 }
50 
51 void Caffe2IOSPredictor::run(const Tensor& inData, Tensor& outData, std::string& errorMessage) {
52  FLAGS_caffe2_force_shared_col_buffer = true;
53  caffe2::Tensor input = caffe2::empty(inData.dims, at::dtype<uint8_t>().device(caffe2::CPU));
54  input.ShareExternalPointer(inData.data);
55  caffe2::Predictor::TensorList input_vec;
56  input_vec.emplace_back(std::move(input));
57  caffe2::Predictor::TensorList output_vec;
58  try {
59  predictor_(input_vec, &output_vec);
60  } catch (const caffe2::EnforceNotMet& e) {
61  std::string error = e.msg();
62  errorMessage.swap(error);
63  return;
64  } catch (const std::exception& e) {
65  std::string error = e.what();
66  errorMessage.swap(error);
67  return;
68  }
69  caffe2::Tensor* output = &output_vec.front();
70  outData.data = output->mutable_data<uint8_t>();
71  outData.dims = output->sizes().vec();
72 }
The primary ATen error class.
Definition: Exception.h:27
static Caffe2IOSPredictor * NewCaffe2IOSPredictor(const caffe2::NetDef &init_net, const caffe2::NetDef &predict_net, bool disableMultithreadProcessing, bool allowMetalOperators)
Allow converting eligible operators to Metal GPU-framework-accelerated operators. ...