Caffe2 - C++ API
A deep learning, cross platform ML framework
predictor.cc
1 #include "caffe2/predictor/predictor.h"
2 #include <unordered_set>
3 #include "caffe2/core/init.h"
4 
5 namespace caffe2 {
6 
7 namespace {
8 
9 void enforceIsTensor(Workspace* ws, const std::string& name) {
10  auto blob = ws->GetBlob(name);
11  CAFFE_ENFORCE(blob, "Blob does not exist: ", name);
12  CAFFE_ENFORCE(
13  BlobIsTensorType(*blob, CPU), "Blob is not a CPU Tensor: ", name);
14 }
15 
16 Blob* getBlob(Workspace* ws, const std::string& name) {
17  enforceIsTensor(ws, name);
18  auto* blob = ws->GetBlob(name);
19  CAFFE_ENFORCE(blob, "Blob: ", name, " does not exist");
20  return blob;
21 }
22 
23 const Tensor& getTensor(Workspace* ws, const std::string& name) {
24  return *BlobGetMutableTensor(getBlob(ws, name), CPU);
25 }
26 
27 } // namespace
28 
29 Predictor::Predictor(
30  const NetDef& init_net,
31  const NetDef& run_net,
32  Workspace* parent,
33  bool run_init,
34  int optimization)
35  : Predictor(makePredictorConfig(
36  init_net,
37  run_net,
38  parent,
39  run_init,
40  optimization)) {}
41 
42 Predictor::Predictor(PredictorConfig config) : config_(std::move(config)) {
43  const auto& initialized_vec = config_.ws->Blobs();
44  const std::unordered_set<std::string> initialized{initialized_vec.begin(),
45  initialized_vec.end()};
46  for (const auto& name : config_.predict_net->external_input()) {
47  if (!initialized.count(name)) {
48  auto* blob = config_.ws->CreateBlob(name);
49  BlobGetMutableTensor(blob, CPU);
50  }
51  }
52  CAFFE_ENFORCE(config_.ws->CreateNet(config_.predict_net));
53 }
54 
55 bool Predictor::operator()(const TensorList& inputs, TensorList* outputs) {
56  CAFFE_ENFORCE(
57  inputs.size() <=
58  static_cast<unsigned>(config_.predict_net->external_input_size()));
59  for (size_t i = 0; i < inputs.size(); ++i) {
60  // This is evil and shares the same underlying tensor
61  BlobSetTensor(
62  getBlob(config_.ws.get(), config_.predict_net->external_input(i)),
63  inputs[i].UnsafeSharedInstance());
64  }
65 
66  if (!config_.ws->RunNet(config_.predict_net->name())) {
67  return false;
68  }
69  outputs->clear();
70  for (size_t i = 0; i < config_.predict_net->external_output_size(); ++i) {
71  outputs->emplace_back(
72  getTensor(config_.ws.get(), config_.predict_net->external_output(i))
73  .UnsafeSharedInstance());
74  }
75  return true;
76 }
77 
78 bool Predictor::run_map_workspace(const TensorMap& inputs) {
79  if (!config_.input_names.empty()) {
80  CAFFE_ENFORCE_EQ(inputs.size(), input_names().size());
81  }
82  for (auto& input : inputs) {
83  if (!input_names().empty()) {
84  CAFFE_ENFORCE(
85  std::find(input_names().begin(), input_names().end(), input.first) !=
86  input_names().end(),
87  "Input can't be found: ",
88  input.first);
89  }
90  // This is evil and shares the same underlying tensor
91  BlobSetTensor(
92  getBlob(config_.ws.get(), input.first),
93  input.second.UnsafeSharedInstance());
94  }
95 
96  return config_.ws->RunNet(config_.predict_net->name());
97 }
98 
99 bool Predictor::operator()(const TensorMap& inputs, TensorList* outputs) {
100  if (!run_map_workspace(inputs)) {
101  return false;
102  }
103  outputs->clear();
104  for (size_t i = 0; i < config_.predict_net->external_output_size(); ++i) {
105  outputs->push_back(
106  getTensor(config_.ws.get(), config_.predict_net->external_output(i))
107  .UnsafeSharedInstance());
108  }
109  return true;
110 }
111 
112 bool Predictor::operator()(const TensorMap& inputs, TensorMap* outputs) {
113  if (!run_map_workspace(inputs)) {
114  return false;
115  }
116 
117  for (const std::string& outputName : output_names()) {
118  outputs->emplace(
119  outputName,
120  getTensor(config_.ws.get(), outputName).UnsafeSharedInstance());
121  }
122  return true;
123 }
124 
125 } // namespace caffe2
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13