Caffe2 - C++ API
A deep-learning, cross-platform ML framework
ios_caffe.cc
1 
#include "ios_caffe.h"

#include <stdexcept>
#include <string>
#include <vector>

#include "caffe2/core/tensor.h"
#include "caffe2/mobile/contrib/ios/ios_caffe_predictor.h"
#include "caffe2/predictor/predictor.h"
6 
7 Caffe2IOSPredictor* MakeCaffe2Predictor(const std::string& init_net_str,
8  const std::string& predict_net_str,
9  bool disableMultithreadProcessing,
10  bool allowMetalOperators,
11  std::string& errorMessage) {
12  caffe2::NetDef init_net, predict_net;
13  init_net.ParseFromString(init_net_str);
14  predict_net.ParseFromString(predict_net_str);
15 
16  Caffe2IOSPredictor* predictor = NULL;
17  try {
19  init_net, predict_net, disableMultithreadProcessing, allowMetalOperators);
20  } catch (const caffe2::EnforceNotMet& e) {
21  std::string error = e.msg();
22  errorMessage.swap(error);
23  return NULL;
24  } catch (const std::exception& e) {
25  std::string error = e.what();
26  errorMessage.swap(error);
27  return NULL;
28  }
29  return predictor;
30 }
31 
32 void GenerateStylizedImage(std::vector<float>& originalImage,
33  const std::string& init_net_str,
34  const std::string& predict_net_str,
35  int height,
36  int width,
37  std::vector<float>& dataOut) {
38  caffe2::NetDef init_net, predict_net;
39  init_net.ParseFromString(init_net_str);
40  predict_net.ParseFromString(predict_net_str);
41  caffe2::Predictor p(init_net, predict_net);
42 
43  std::vector<int> dims({1, 3, height, width});
44  caffe2::Tensor input(caffe2::CPU);
45  input.Resize(dims);
46  input.ShareExternalPointer(originalImage.data());
47  caffe2::Predictor::TensorList input_vec;
48  input_vec.emplace_back(std::move(input));
49  caffe2::Predictor::TensorList output_vec;
50  p(input_vec, &output_vec);
51  assert(output_vec.size() == 1);
52  caffe2::TensorCPU* output = &output_vec.front();
53  // output is our styled image
54  float* outputArray = output->mutable_data<float>();
55  dataOut.assign(outputArray, outputArray + output->size());
56 }
The primary ATen error class.
Definition: Exception.h:27
static Caffe2IOSPredictor * NewCaffe2IOSPredictor(const caffe2::NetDef &init_net, const caffe2::NetDef &predict_net, bool disableMultithreadProcessing, bool allowMetalOperators)
Allow converting eligible operators to Metal GPU framework accelerated operators. ...