Caffe2 - C++ API
A deep learning, cross platform ML framework
predictor.h
1 #pragma once
2 
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>

#include "caffe2/core/net.h"
#include "caffe2/core/tensor.h"
#include "caffe2/predictor/predictor_config.h"
7 
8 namespace caffe2 {
9 
10 class CAFFE2_API Predictor {
11  public:
12  using TensorList = std::vector<TensorCPU>;
13  using TensorMap = std::unordered_map<std::string, TensorCPU>;
14 
15  Predictor(
16  const NetDef& init_net,
17  const NetDef& run_net,
18  Workspace* parent = nullptr,
19  bool run_init = true,
20  int optimization = 1);
21 
22  Predictor(PredictorConfig config);
23 
24  virtual ~Predictor() {}
25 
26  // Executes `run_net` on the inputs.
27  // The first `inputs.size()` inputs from run_net::external_inputs
28  // are shared with the data in `inputs`.
29 
30  // Precondition:
31  // inputs.size() <= run_net_.external_inputs.size()
32 
33  // Postcondition:
34  // outputs->size() == run_net.external_inputs.size()
35 
36  // NOTE: output is a part of thread local workspace
37  // and is only valid until the next predictor execution.
38 
39  // Returns true on success
40  virtual bool operator()(const TensorList& inputs, TensorList* outputs);
41 
42  // Similar to run, but consumes a map of name to tensor as input
43  bool operator()(const TensorMap& inputs, TensorList* outputs);
44 
45  // Similar to the other run fns, except inputs and outputs are both maps of
46  // string name to tensor.
47  bool operator()(const TensorMap& inputs, TensorMap* outputs);
48 
49  const NetDef& def() const {
50  return *config_.predict_net;
51  };
52 
53  Workspace* ws() {
54  return config_.ws.get();
55  };
56 
57  const std::vector<std::string>& input_names() const {
58  return config_.input_names;
59  }
60 
61  const std::vector<std::string>& output_names() const {
62  return config_.output_names;
63  }
64 
65  private:
66  bool run_map_workspace(const TensorMap& inputs);
67 
68  protected:
69  PredictorConfig config_;
70 };
71 } // namespace caffe2
Stores parameters necessary for creating a PredictorInterface object.
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:47
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13