Caffe2 - C++ API
A deep-learning, cross-platform ML framework
ios_caffe.cc
1 
#include "ios_caffe.h"

#include <cstdint>
#include <stdexcept>
#include <string>
#include <vector>

#include "caffe2/core/predictor.h"
#include "caffe2/core/tensor.h"
#include "caffe2/mobile/contrib/ios/ios_caffe_predictor.h"
22 
23 Caffe2IOSPredictor* MakeCaffe2Predictor(const std::string& init_net_str,
24  const std::string& predict_net_str,
25  bool disableMultithreadProcessing,
26  bool allowMetalOperators,
27  std::string& errorMessage) {
28  caffe2::NetDef init_net, predict_net;
29  init_net.ParseFromString(init_net_str);
30  predict_net.ParseFromString(predict_net_str);
31 
32  Caffe2IOSPredictor* predictor = NULL;
33  try {
35  init_net, predict_net, disableMultithreadProcessing, allowMetalOperators);
36  } catch (const caffe2::EnforceNotMet& e) {
37  std::string error = e.msg();
38  errorMessage.swap(error);
39  return NULL;
40  } catch (const std::exception& e) {
41  std::string error = e.what();
42  errorMessage.swap(error);
43  return NULL;
44  }
45  return predictor;
46 }
47 
48 void GenerateStylizedImage(std::vector<float>& originalImage,
49  const std::string& init_net_str,
50  const std::string& predict_net_str,
51  int height,
52  int width,
53  std::vector<float>& dataOut) {
54  caffe2::NetDef init_net, predict_net;
55  init_net.ParseFromString(init_net_str);
56  predict_net.ParseFromString(predict_net_str);
57  caffe2::Predictor p(init_net, predict_net);
58 
59  std::vector<int> dims({1, 3, height, width});
60  caffe2::TensorCPU input;
61  input.Resize(dims);
62  input.ShareExternalPointer(originalImage.data());
63  caffe2::Predictor::TensorVector input_vec{&input};
64  caffe2::Predictor::TensorVector output_vec;
65  p.run(input_vec, &output_vec);
66  assert(output_vec.size() == 1);
67  caffe2::TensorCPU* output = output_vec.front();
68  // output is our styled image
69  float* outputArray = output->mutable_data<float>();
70  dataOut.assign(outputArray, outputArray + output->size());
71 }
static Caffe2IOSPredictor * NewCaffe2IOSPredictor(const caffe2::NetDef &init_net, const caffe2::NetDef &predict_net, bool disableMultithreadProcessing, bool allowMetalOperators)
Allow converting eligible operators to Metal GPU framework accelerated operators. ...