1 #include "caffe2/predictor/predictor.h"     2 #include <unordered_set>     3 #include "caffe2/core/init.h"     9 void enforceIsTensor(Workspace* ws, 
const std::string& name) {
    10   auto blob = ws->GetBlob(name);
    11   CAFFE_ENFORCE(blob, 
"Blob does not exist: ", name);
    13       BlobIsTensorType(*blob, CPU), 
"Blob is not a CPU Tensor: ", name);
    16 Blob* getBlob(Workspace* ws, 
const std::string& name) {
    17   enforceIsTensor(ws, name);
    18   auto* blob = ws->GetBlob(name);
    19   CAFFE_ENFORCE(blob, 
"Blob: ", name, 
" does not exist");
    23 const Tensor& getTensor(Workspace* ws, 
const std::string& name) {
    24   return *BlobGetMutableTensor(getBlob(ws, name), CPU);
    30     const NetDef& init_net,
    31     const NetDef& run_net,
    35     : Predictor(makePredictorConfig(
    42 Predictor::Predictor(PredictorConfig config) : config_(
std::move(config)) {
    43   const auto& initialized_vec = config_.ws->Blobs();
    44   const std::unordered_set<std::string> initialized{initialized_vec.begin(),
    45                                                     initialized_vec.end()};
    46   for (
const auto& name : config_.predict_net->external_input()) {
    47     if (!initialized.count(name)) {
    48       auto* blob = config_.ws->CreateBlob(name);
    49       BlobGetMutableTensor(blob, CPU);
    52   CAFFE_ENFORCE(config_.ws->CreateNet(config_.predict_net));
    55 bool Predictor::operator()(
const TensorList& inputs, TensorList* outputs) {
    58       static_cast<unsigned>(config_.predict_net->external_input_size()));
    59   for (
size_t i = 0; i < inputs.size(); ++i) {
    62         getBlob(config_.ws.get(), config_.predict_net->external_input(i)),
    63         inputs[i].UnsafeSharedInstance());
    66   if (!config_.ws->RunNet(config_.predict_net->name())) {
    70   for (
size_t i = 0; i < config_.predict_net->external_output_size(); ++i) {
    71     outputs->emplace_back(
    72         getTensor(config_.ws.get(), config_.predict_net->external_output(i))
    73             .UnsafeSharedInstance());
    78 bool Predictor::run_map_workspace(
const TensorMap& inputs) {
    79   if (!config_.input_names.empty()) {
    80     CAFFE_ENFORCE_EQ(inputs.size(), input_names().size());
    82   for (
auto& input : inputs) {
    83     if (!input_names().empty()) {
    85           std::find(input_names().begin(), input_names().end(), input.first) !=
    87           "Input can't be found: ",
    92         getBlob(config_.ws.get(), input.first),
    93         input.second.UnsafeSharedInstance());
    96   return config_.ws->RunNet(config_.predict_net->name());
    99 bool Predictor::operator()(
const TensorMap& inputs, TensorList* outputs) {
   100   if (!run_map_workspace(inputs)) {
   104   for (
size_t i = 0; i < config_.predict_net->external_output_size(); ++i) {
   106         getTensor(config_.ws.get(), config_.predict_net->external_output(i))
   107             .UnsafeSharedInstance());
   112 bool Predictor::operator()(
const TensorMap& inputs, TensorMap* outputs) {
   113   if (!run_map_workspace(inputs)) {
   117   for (
const std::string& outputName : output_names()) {
   120         getTensor(config_.ws.get(), outputName).UnsafeSharedInstance());
 
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...