3 #include "caffe2/core/tensor.h" 4 #include "caffe2/mobile/contrib/ios/ios_caffe_predictor.h" 5 #include "caffe2/predictor/predictor.h" 8 const std::string& predict_net_str,
9 bool disableMultithreadProcessing,
10 bool allowMetalOperators,
11 std::string& errorMessage) {
12 caffe2::NetDef init_net, predict_net;
13 init_net.ParseFromString(init_net_str);
14 predict_net.ParseFromString(predict_net_str);
19 init_net, predict_net, disableMultithreadProcessing, allowMetalOperators);
21 std::string error = e.msg();
22 errorMessage.swap(error);
24 }
catch (
const std::exception& e) {
25 std::string error = e.what();
26 errorMessage.swap(error);
// Runs a Caffe2 predictor on `originalImage` and copies the single float
// output tensor into `dataOut`.
//
// originalImage   - float buffer whose storage is handed to the input tensor.
// init_net_str    - serialized NetDef parsed into `init_net`.
// predict_net_str - serialized NetDef parsed into `predict_net`.
// dataOut         - overwritten with the floats of the output tensor.
//
// NOTE(review): this listing skips some original lines (the embedded numbering
// jumps over 35-36, 41-42, 44-45 and 52-53), so the declarations of `height`,
// `width`, `p`, `input` and `output` are not visible here; remarks about them
// below are assumptions to confirm against the full file.
32 void GenerateStylizedImage(std::vector<float>& originalImage,
33 const std::string& init_net_str,
34 const std::string& predict_net_str,
37 std::vector<float>& dataOut) {
38 caffe2::NetDef init_net, predict_net;
// The results of ParseFromString are not checked, so a parse failure is not
// detected at this point.
39 init_net.ParseFromString(init_net_str);
40 predict_net.ParseFromString(predict_net_str);
// Four-entry shape {1, 3, height, width}; presumably batch, channels, rows,
// columns -- TODO confirm against the model's expected input layout.
43 std::vector<int> dims({1, 3, height, width});
// Points `input` at the caller's buffer. Presumably no copy is made, in which
// case originalImage must hold enough floats for the tensor's shape and must
// stay alive until the run completes -- verify against the Tensor API.
46 input.ShareExternalPointer(originalImage.data());
47 caffe2::Predictor::TensorList input_vec;
// `input` is moved into the list and must not be used after this line.
48 input_vec.emplace_back(std::move(input));
49 caffe2::Predictor::TensorList output_vec;
// Invokes `p` (presumably the caffe2::Predictor built from the two nets) with
// the input list; results are written to output_vec.
50 p(input_vec, &output_vec);
// Exactly one output tensor is expected. assert() expands to nothing when
// NDEBUG is defined, so this is not enforced in such builds.
51 assert(output_vec.size() == 1);
54 float* outputArray = output->mutable_data<
float>();
// Replaces dataOut's contents with output->size() floats read from the
// output tensor's buffer.
55 dataOut.assign(outputArray, outputArray + output->size());
The primary ATen error class.
static Caffe2IOSPredictor * NewCaffe2IOSPredictor(const caffe2::NetDef &init_net, const caffe2::NetDef &predict_net, bool disableMultithreadProcessing, bool allowMetalOperators)
Allow converting eligible operators to Metal GPU framework-accelerated operators. ...