Caffe2 - C++ API
A deep learning, cross platform ML framework
predictor_utils.cc
1 #include "caffe2/predictor/predictor_utils.h"
2 
3 #include "caffe2/core/blob.h"
4 #include "caffe2/core/logging.h"
5 #include "caffe2/proto/caffe2_pb.h"
6 #include "caffe2/proto/predictor_consts.pb.h"
7 #include "caffe2/utils/proto_utils.h"
8 
9 namespace caffe2 {
10 namespace predictor_utils {
11 
12 CAFFE2_API const NetDef& getNet(
13  const MetaNetDef& def,
14  const std::string& name) {
15  for (const auto& n : def.nets()) {
16  if (n.key() == name) {
17  return n.value();
18  }
19  }
20  CAFFE_THROW("Net not found: ", name);
21 }
22 
23 std::unique_ptr<MetaNetDef> extractMetaNetDef(
24  db::Cursor* cursor,
25  const std::string& key) {
26  CAFFE_ENFORCE(cursor);
27  if (cursor->SupportsSeek()) {
28  cursor->Seek(key);
29  }
30  for (; cursor->Valid(); cursor->Next()) {
31  if (cursor->key() != key) {
32  continue;
33  }
34  // We've found a match. Parse it out.
35  BlobProto proto;
36  CAFFE_ENFORCE(proto.ParseFromString(cursor->value()));
37  Blob blob;
38  DeserializeBlob(proto, &blob);
39  CAFFE_ENFORCE(blob.template IsType<string>());
40  auto def = caffe2::make_unique<MetaNetDef>();
41  CAFFE_ENFORCE(def->ParseFromString(blob.template Get<string>()));
42  return def;
43  }
44  CAFFE_THROW("Failed to find in db the key: ", key);
45 }
46 
47 std::unique_ptr<MetaNetDef> runGlobalInitialization(
48  std::unique_ptr<db::DBReader> db,
49  Workspace* master) {
50  CAFFE_ENFORCE(db.get());
51  auto* cursor = db->cursor();
52 
53  auto metaNetDef = extractMetaNetDef(
54  cursor, PredictorConsts::default_instance().meta_net_def());
55  if (metaNetDef->has_modelinfo()) {
56  CAFFE_ENFORCE(
57  metaNetDef->modelinfo().predictortype() ==
58  PredictorConsts::default_instance().single_predictor(),
59  "Can only load single predictor");
60  }
61  VLOG(1) << "Extracted meta net def";
62 
63  const auto globalInitNet = getNet(
64  *metaNetDef, PredictorConsts::default_instance().global_init_net_type());
65  VLOG(1) << "Global init net: " << ProtoDebugString(globalInitNet);
66 
67  // Now, pass away ownership of the DB into the master workspace for
68  // use by the globalInitNet.
69  master->CreateBlob(PredictorConsts::default_instance().predictor_dbreader())
70  ->Reset(db.release());
71 
72  // Now, with the DBReader set, we can run globalInitNet.
73  CAFFE_ENFORCE(
74  master->RunNetOnce(globalInitNet),
75  "Failed running the globalInitNet: ",
76  ProtoDebugString(globalInitNet));
77 
78  return metaNetDef;
79 }
80 
81 } // namespace predictor_utils
82 } // namespace caffe2
void DeserializeBlob(const string &content, Blob *result)
Deserializes from a string containing either BlobProto or TensorProto.
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13