1 #include "caffe2/predictor/predictor.h" 2 #include <unordered_set> 3 #include "caffe2/core/init.h" 9 void enforceIsTensor(Workspace* ws,
const std::string& name) {
// enforceIsTensor: looks up the blob called `name` in workspace `ws` and
// checks that (1) the lookup produced a blob and (2) that blob is a CPU tensor
// according to BlobIsTensorType.
// NOTE(review): the leading numbers on these lines look like line numbers left
// over from extraction, and some original lines (for example the closing
// brace) are not present in this view.
10 auto blob = ws->GetBlob(name);
// Existence check; the message names the blob that could not be found.
11 CAFFE_ENFORCE(blob,
"Blob does not exist: ", name);
// Type check against CPU. The call that consumes this condition and message
// opens on a line that is not visible here -- presumably another
// CAFFE_ENFORCE; confirm against the full file.
13 BlobIsTensorType(*blob, CPU),
"Blob is not a CPU Tensor: ", name);
// getBlob: validates the blob called `name` in `ws` with enforceIsTensor, then
// fetches it again with Workspace::GetBlob and enforces that it is non-null.
// The return statement is not visible in this view; given the Blob* return
// type it presumably returns `blob` -- confirm against the full file.
16 Blob* getBlob(Workspace* ws,
const std::string& name) {
17 enforceIsTensor(ws, name);
18 auto* blob = ws->GetBlob(name);
// NOTE(review): enforceIsTensor above already enforces that this blob exists,
// so this second existence check looks redundant.
19 CAFFE_ENFORCE(blob,
"Blob: ", name,
" does not exist");
// getTensor: returns, as a const reference, the tensor held by the blob called
// `name` in `ws`. The blob is obtained through getBlob (which validates it),
// and the tensor pointer from BlobGetMutableTensor(..., CPU) is dereferenced.
// NOTE(review): a mutable accessor is used even though the result is exposed
// as const; whether that accessor can modify the blob is not visible here.
23 const Tensor& getTensor(Workspace* ws,
const std::string& name) {
24 return *BlobGetMutableTensor(getBlob(ws, name), CPU);
// NOTE(review): fragment of a definition whose opening line is not visible in
// this view. What can be seen: two `const NetDef&` parameters (init_net,
// run_net) and a member-initializer that delegates to
// Predictor(makePredictorConfig(...)). Presumably this is a Predictor
// constructor overload that builds a PredictorConfig and forwards to the
// constructor taking PredictorConfig -- confirm against the full file. The
// remaining parameters and the arguments passed to makePredictorConfig are
// not visible.
30 const NetDef& init_net,
31 const NetDef& run_net,
35 : Predictor(makePredictorConfig(
// Constructs a Predictor from `config`, which is moved into config_.
// Every external input of the predict net that is not already among the
// workspace's blobs gets a freshly created blob, after which the predict net
// is created in the workspace (enforced via CAFFE_ENFORCE).
42 Predictor::Predictor(PredictorConfig config) : config_(
std::move(config)) {
// Copy the workspace's current blob list into a string set for membership
// tests in the loop below.
43 const auto& initialized_vec = config_.ws->Blobs();
44 const std::unordered_set<std::string> initialized{initialized_vec.begin(),
45 initialized_vec.end()};
46 for (
const auto& name : config_.predict_net->external_input()) {
// Only inputs that are absent from the workspace are created here.
47 if (!initialized.count(name)) {
48 auto* blob = config_.ws->CreateBlob(name);
// The return value is discarded; presumably this is called for its side
// effect of making the new blob hold a CPU tensor -- confirm.
49 BlobGetMutableTensor(blob, CPU);
// Net creation must succeed; CreateNet's result is enforced.
52 CAFFE_ENFORCE(config_.ws->CreateNet(config_.predict_net));
// Runs the predict net with positional inputs.
// inputs[i] is paired with the predict net's external_input(i); after the net
// runs, one entry per external output is appended to *outputs in
// external_output order.
// NOTE(review): several original lines are missing from this view (the start
// of the size check, the call that receives each input blob/tensor pair, the
// body of the RunNet failure branch and the final return), so the exact
// return values cannot be confirmed from here.
55 bool Predictor::operator()(
const TensorList& inputs, TensorList* outputs) {
// Tail of an expression whose beginning is not visible; it involves the
// number of external inputs cast to unsigned -- presumably a bound on
// inputs.size(); confirm.
58 static_cast<unsigned>(config_.predict_net->external_input_size()));
59 for (
size_t i = 0; i < inputs.size(); ++i) {
// Arguments of a call whose name is on a missing line: the validated blob for
// the i-th external input, and inputs[i].UnsafeSharedInstance().
62 getBlob(config_.ws.get(), config_.predict_net->external_input(i)),
63 inputs[i].UnsafeSharedInstance());
// The net is run by name; the body of this failure branch is not visible.
66 if (!config_.ws->RunNet(config_.predict_net->name())) {
// NOTE(review): `i` is size_t while external_output_size() is a separate
// accessor whose return type is not visible here; check for a signed/unsigned
// comparison.
70 for (
size_t i = 0; i < config_.predict_net->external_output_size(); ++i) {
// Append the UnsafeSharedInstance() of each external output tensor.
71 outputs->emplace_back(
72 getTensor(config_.ws.get(), config_.predict_net->external_output(i))
73 .UnsafeSharedInstance());
// Feeds the named tensors in `inputs` into the workspace and runs the predict
// net, returning RunNet's result.
// When config_.input_names is non-empty, the number of inputs must equal
// input_names().size() and each key is searched for in input_names().
// NOTE(review): the lines that open the per-input check and the call that
// receives each blob/tensor pair are missing from this view.
78 bool Predictor::run_map_workspace(
const TensorMap& inputs) {
79 if (!config_.input_names.empty()) {
80 CAFFE_ENFORCE_EQ(inputs.size(), input_names().size());
82 for (
auto& input : inputs) {
83 if (!input_names().empty()) {
// Linear search for this input's key among the declared input names; the
// surrounding call is not visible -- presumably an enforce that reports the
// message below when the key is absent; confirm.
85 std::find(input_names().begin(), input_names().end(), input.first) !=
87 "Input can't be found: ",
// Arguments of a call whose name is on a missing line: the validated blob
// named by the map key, and the mapped tensor's UnsafeSharedInstance().
92 getBlob(config_.ws.get(), input.first),
93 input.second.UnsafeSharedInstance());
96 return config_.ws->RunNet(config_.predict_net->name());
// Runs the predict net with named inputs and reads its external outputs by
// position.
// Inputs are applied and the net is run through run_map_workspace; afterwards
// each external output tensor is fetched with getTensor.
// NOTE(review): the body of the failure branch, the call that receives each
// output tensor (presumably an append to *outputs) and the final return are
// on lines missing from this view -- confirm against the full file.
99 bool Predictor::operator()(
const TensorMap& inputs, TensorList* outputs) {
100 if (!run_map_workspace(inputs)) {
104 for (
size_t i = 0; i < config_.predict_net->external_output_size(); ++i) {
// UnsafeSharedInstance() of the i-th external output tensor.
106 getTensor(config_.ws.get(), config_.predict_net->external_output(i))
107 .UnsafeSharedInstance());
// Runs the predict net with named inputs and reads its outputs by name.
// Inputs are applied and the net is run through run_map_workspace; afterwards
// the tensor for every name in output_names() is fetched with getTensor.
// NOTE(review): the body of the failure branch, the destination of each
// output tensor (presumably an insertion into *outputs keyed by outputName)
// and the final return are on lines missing from this view -- confirm.
112 bool Predictor::operator()(
const TensorMap& inputs, TensorMap* outputs) {
113 if (!run_map_workspace(inputs)) {
117 for (
const std::string& outputName : output_names()) {
// UnsafeSharedInstance() of the tensor stored under this output name.
120 getTensor(config_.ws.get(), outputName).UnsafeSharedInstance());
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...