Caffe2 - C++ API
A deep learning, cross platform ML framework
predictor.cc
1 
17 #include "caffe2/core/predictor.h"
18 
19 #include <unordered_set>
20 
21 namespace caffe2 {
22 
23 namespace {
24 
25 void enforceIsTensor(Workspace* ws, const std::string& name) {
26  auto blob = ws->GetBlob(name);
27  CAFFE_ENFORCE(blob, "Blob does not exist: ", name);
28  CAFFE_ENFORCE(
29  blob->template IsType<TensorCPU>(), "Blob is not a CPU Tensor: ", name);
30 }
31 
32 void shareInputTensor(
33  Workspace* ws,
34  const std::string& name,
35  TensorCPU* input) {
36  enforceIsTensor(ws, name);
37  auto* blob = ws->GetBlob(name);
38  CAFFE_ENFORCE(blob, "Blob: ", name, " does not exist");
39  auto* tensor = blob->template GetMutable<TensorCPU>();
40  tensor->ResizeLike(*input);
41  tensor->ShareData(*input);
42 }
43 
44 TensorCPU* extractOutputTensor(Workspace* ws, const std::string& name) {
45  enforceIsTensor(ws, name);
46  auto* blob = ws->GetBlob(name);
47  CAFFE_ENFORCE(blob, "Blob: ", name, " does not exist");
48  return blob->template GetMutable<TensorCPU>();
49 }
50 
51 const NetDef& getNet(const MetaNetDef& def, const std::string& name) {
52  for (const auto& n : def.nets()) {
53  if (n.key() == name) {
54  return n.value();
55  }
56  }
57  CAFFE_THROW("Net not found: ", name);
58 }
59 
60 const ::google::protobuf::RepeatedPtrField<::std::string>& getBlobs(
61  const MetaNetDef& def,
62  const std::string& name) {
63  for (const auto& b : def.blobs()) {
64  if (b.key() == name) {
65  return b.value();
66  }
67  }
68  CAFFE_THROW("Blob not found: ", name);
69 }
70 } // namespace
71 
72 Predictor::Predictor(const MetaNetDef& def, Workspace* parent)
73  : Predictor(
74  getNet(
75  def,
76  PredictorConsts::default_instance().global_init_net_type()),
77  getNet(def, PredictorConsts::default_instance().predict_net_type()),
78  parent) {
79  const auto& inputs =
80  getBlobs(def, PredictorConsts::default_instance().inputs_blob_type());
81  for (const auto& input : inputs) {
82  inputNames_.insert(input);
83  }
84 }
85 
86 Predictor::Predictor(
87  const NetDef& init_net,
88  const NetDef& run_net,
89  Workspace* parent)
90  : run_net_(run_net), ws_(parent) {
91  CAFFE_ENFORCE(ws_.RunNetOnce(init_net));
92 
93  // real model inputs can be fed later in run* functions
94  const auto& initialized_vec = ws_.Blobs();
95  const std::unordered_set<std::string> initialized{initialized_vec.begin(),
96  initialized_vec.end()};
97  for (const auto& name : run_net.external_input()) {
98  if (!initialized.count(name)) {
99  auto* blob = ws_.CreateBlob(name);
100  blob->template GetMutable<TensorCPU>();
101  }
102  }
103  CAFFE_ENFORCE(ws_.CreateNet(run_net));
104 }
105 
106 Predictor::~Predictor() {}
107 
108 bool Predictor::run(const TensorVector& inputs, TensorVector* outputs) {
109  CAFFE_ENFORCE(inputs.size() <= run_net_.external_input_size());
110  for (auto i = 0; i < inputs.size(); ++i) {
111  shareInputTensor(&ws_, run_net_.external_input(i), inputs[i]);
112  }
113 
114  if (!ws_.RunNet(run_net_.name())) {
115  return false;
116  }
117 
118  outputs->resize(run_net_.external_output_size());
119  for (auto i = 0; i < outputs->size(); ++i) {
120  (*outputs)[i] = extractOutputTensor(&ws_, run_net_.external_output(i));
121  }
122  return true;
123 }
124 
125 bool Predictor::run_map(const TensorMap& inputs, TensorVector* outputs) {
126  if (!inputNames_.empty()) {
127  CAFFE_ENFORCE_EQ(inputs.size(), inputNames_.size());
128  }
129  for (auto input : inputs) {
130  if (!inputNames_.empty()) {
131  CAFFE_ENFORCE_GT(inputNames_.count(input.first), 0);
132  }
133  shareInputTensor(&ws_, input.first, input.second);
134  }
135 
136  if (!ws_.RunNet(run_net_.name())) {
137  return false;
138  }
139 
140  outputs->resize(run_net_.external_output_size());
141  for (auto i = 0; i < outputs->size(); ++i) {
142  (*outputs)[i] = extractOutputTensor(&ws_, run_net_.external_output(i));
143  }
144  return true;
145 }
146 } // namespace caffe2
Copyright (c) 2016-present, Facebook, Inc.