Caffe2 - C++ API
A deep-learning, cross-platform ML framework
ios_caffe_predictor.h
1 
2 #pragma once
3 
4 #include <string>
5 #include "caffe2/core/net.h"
6 #include "caffe2/mobile/contrib/ios/ios_caffe_defines.h"
7 #include "caffe2/predictor/predictor.h"
8 
/// Minimal tensor descriptor passed to and from Caffe2IOSPredictor::run().
///
/// Plain aggregate: construct with brace-init, e.g. `Tensor t{{1, 3}, ptr};`.
struct Tensor {
  /// Size of each dimension. NOTE(review): dimension order / memory layout
  /// is not specified by this header — confirm against the implementation.
  std::vector<int64_t> dims;

  /// Raw pointer to the tensor's bytes. Indeterminate when the struct is
  /// default-initialized (`Tensor t;`); zero when value-initialized (`Tensor t{};`).
  /// NOTE(review): element type and buffer ownership are not expressed here;
  /// presumably the caller owns the buffer — confirm.
  uint8_t* data;
};
13 
14 class IOS_CAFFE_EXPORT Caffe2IOSPredictor final {
15  public:
21  static Caffe2IOSPredictor* NewCaffe2IOSPredictor(const caffe2::NetDef& init_net,
22  const caffe2::NetDef& predict_net,
23  bool disableMultithreadProcessing,
24  bool allowMetalOperators);
25  void run(const Tensor& inData, Tensor& outData, std::string& errorMessage);
26  ~Caffe2IOSPredictor(){};
27 
28  const bool usingMetalOperators;
29 
30  private:
31  Caffe2IOSPredictor(const caffe2::NetDef& init_net,
32  const caffe2::NetDef& predict_net,
33  bool disableMultithreadProcessing,
34  bool usingMetalOperators);
35  caffe2::Predictor predictor_;
36 };