Caffe2 - C++ API
A deep learning, cross-platform ML framework
load_save_op.h
#ifndef CAFFE2_OPERATORS_LOAD_SAVE_OP_H_
#define CAFFE2_OPERATORS_LOAD_SAVE_OP_H_

#include <cstdio>
#include <map>
#include <set>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>

#include "caffe2/core/blob_serialization.h"
#include "caffe2/core/context.h"
#include "caffe2/core/db.h"
#include "caffe2/core/logging.h"
#include "caffe2/core/operator.h"
#include "caffe2/utils/math.h"
#include "caffe2/utils/proto_utils.h"

namespace caffe2 {

namespace {
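// BlobState tracks how much of a blob has been materialized while loading.
// For chunked non-tensor blobs, total_size counts chunks
// (content_num_chunks) and seen_chunks_ids records which chunk ids have
// been read; for tensors, total_size counts elements and current_size
// accumulates the lengths of the deserialized segments.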
struct BlobState {
  int64_t total_size;
  int64_t current_size;
  bool is_tensor;
  std::set<int32_t> seen_chunks_ids;

  explicit BlobState(
      int64_t total_size = 0,
      int64_t current_size = 0,
      bool is_tensor = false)
      : total_size(total_size),
        current_size(current_size),
        is_tensor(is_tensor) {}
};
} // namespace

using db::Cursor;
using db::DB;
using db::Transaction;

template <class Context>
class DBExistsOp final : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;
  DBExistsOp(const OperatorDef& operator_def, Workspace* ws)
      : Operator<Context>(operator_def, ws),
        ws_(ws),
        absolute_path_(
            OperatorBase::GetSingleArgument<int>("absolute_path", false)),
        db_name_(OperatorBase::GetSingleArgument<string>("db_name", "")),
        db_type_(OperatorBase::GetSingleArgument<string>("db_type", "")) {}

  bool RunOnDevice() override {
    string full_db_name =
        absolute_path_ ? db_name_ : (ws_->RootFolder() + "/" + db_name_);
    auto* output = Output(0);
    output->Resize();
    bool* exists = output->template mutable_data<bool>();

    *exists = caffe2::db::DBExists(db_type_, full_db_name);
    return true;
  }

 private:
  Workspace* ws_;
  bool absolute_path_;
  std::string db_name_;
  std::string db_type_;
};
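
// A minimal usage sketch (the db name and type below are hypothetical; in
// Caffe2 this class is registered under the "DBExists" operator name):
//
//   OperatorDef def;
//   def.set_type("DBExists");
//   def.add_output("exists");
//   def.add_arg()->CopyFrom(MakeArgument<string>("db_name", "my.minidb"));
//   def.add_arg()->CopyFrom(MakeArgument<string>("db_type", "minidb"));
//   // Running the resulting operator writes a single bool into "exists".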

template <class Context>
class LoadOp final : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;
  LoadOp(const OperatorDef& operator_def, Workspace* ws)
      : Operator<Context>(operator_def, ws),
        ws_(ws),
        absolute_path_(
            OperatorBase::GetSingleArgument<int>("absolute_path", false)),
        add_prefix_(OperatorBase::GetSingleArgument<string>("add_prefix", "")),
        strip_prefix_(
            OperatorBase::GetSingleArgument<string>("strip_prefix", "")),
        db_name_(OperatorBase::GetSingleArgument<string>("db", "")),
        db_names_(OperatorBase::GetRepeatedArgument<string>("dbs")),
        db_type_(OperatorBase::GetSingleArgument<string>("db_type", "")),
        keep_device_(OperatorBase::GetSingleArgument<int>("keep_device", 0)),
        load_all_(OperatorBase::GetSingleArgument<int>("load_all", 0)),
        allow_incomplete_(
            OperatorBase::GetSingleArgument<bool>("allow_incomplete", false)),
        blob_names_(
            OperatorBase::GetRepeatedArgument<string>("source_blob_names")) {
    if (InputSize() == 0) {
      CAFFE_ENFORCE_GT(db_type_.size(), 0, "Must specify a db type.");
      if (db_names_.empty()) {
        CAFFE_ENFORCE_GT(db_name_.size(), 0, "Must specify a db name.");
        db_names_.push_back(db_name_);
        db_name_ = "";
      } else {
        std::set<std::string> db_name_set;
        for (const string& db_name : db_names_) {
          CAFFE_ENFORCE_GT(db_name.size(), 0, "Db name should not be empty.");
          CAFFE_ENFORCE(
              db_name_set.insert(db_name).second,
              "Duplicated db name: ",
              db_name);
        }
        db_name_ = "";
      }
    }
    CAFFE_ENFORCE(
        blob_names_.empty() || blob_names_.size() == OutputSize(),
        "Number of output blobs and source_blob_names mismatch.");
    CAFFE_ENFORCE(
        blob_names_.empty() || strip_prefix_.empty(),
        "strip_prefix and source_blob_names are mutually exclusive.");
    CAFFE_ENFORCE(
        blob_names_.empty() || !load_all_,
        "cannot load_all_ while using source_blob_names.");
    if (!load_all_) {
      // blob_names_ holds the source blob names in the file/db. If the
      // source_blob_names argument is not given, it is inferred from the
      // operator outputs.
      if (blob_names_.empty()) {
        for (const string& name : operator_def.output()) {
          blob_names_.push_back(name);
        }
      }
      int idx = 0;
      std::set<std::string> name_set;
      for (const string& name : blob_names_) {
        CAFFE_ENFORCE(
            name_set.insert(name).second,
            "Duplicated source blob name: ",
            name);
        output_indices_[name] = idx++;
      }
    }
  }

  void SetCurrentDevice(BlobProto* proto);

  bool RunOnDevice() override {
    int total_loaded_blobs = 0;
    std::unordered_map<string, BlobState> blob_states;
    if (InputSize() > 0) {
      for (int i = 0; i < InputSize(); ++i) {
        const db::DBReader& reader = OperatorBase::Input<db::DBReader>(i);
        extract(i, reader.cursor(), &blob_states, &total_loaded_blobs);
      }
    } else {
      for (int i = 0; i < db_names_.size(); ++i) {
        string full_db_name = absolute_path_
            ? db_names_[i]
            : (ws_->RootFolder() + "/" + db_names_[i]);
        std::unique_ptr<DB> in_db(
            caffe2::db::CreateDB(db_type_, full_db_name, caffe2::db::READ));
        CAFFE_ENFORCE(in_db.get(), "Cannot open db: ", full_db_name);
        std::unique_ptr<Cursor> cursor(in_db->NewCursor());
        extract(i, cursor.get(), &blob_states, &total_loaded_blobs);
      }
    }

    validateBlobStates(blob_states);
    // Loaded all the needed blobs.
    if (load_all_ || total_loaded_blobs == OutputSize()) {
      VLOG(1) << "Loaded " << total_loaded_blobs << " blobs fully from db(s)";
      return true;
    }

    // Only loaded a subset of the blobs.
    if (allow_incomplete_) {
      VLOG(1) << "Loaded " << total_loaded_blobs << " blobs out of "
              << OutputSize() << " blobs from db(s).";
    } else {
      for (const string& output_name : this->debug_def().output()) {
        if (blob_states.count(output_name) == 0) {
          LOG(ERROR) << "Failed to load blob: " << output_name;
        }
      }
      CAFFE_THROW(
          "Expected to load ",
          OutputSize(),
          " blobs, got ",
          total_loaded_blobs,
          " only.\n");
    }

    return true;
  }

 private:
  void extract(
      int db_id,
      Cursor* cursor,
      std::unordered_map<string, BlobState>* blob_states,
      int* total_loaded_blobs) {
    if (load_all_) {
      extractAll(db_id, cursor, blob_states, total_loaded_blobs);
    } else {
      extractFrom(
          db_id,
          cursor,
          OperatorBase::Outputs(),
          blob_states,
          total_loaded_blobs);
    }
  }

  void extractAll(
      int db_id,
      Cursor* cursor,
      std::unordered_map<string, BlobState>* blob_states,
      int* total_loaded_blobs) {
    CAFFE_ENFORCE(cursor, "cursor is not valid");
    int loaded_blobs = 0;
    for (; cursor->Valid(); cursor->Next()) {
      const auto key = buildBlobNameFromDbKey(cursor->key());
      if (key_to_dbid_.count(key) && key_to_dbid_[key] != db_id) {
        CAFFE_THROW("Duplicate key ", key, " found!\n");
      } else {
        key_to_dbid_[key] = db_id;
      }

      BlobProto proto;
      CAFFE_ENFORCE(
          proto.ParseFromString(cursor->value()), "Couldn't parse Proto");
      if (!keep_device_) {
        // If we are not keeping the device as the one specified in the
        // proto, we will set the current device.
        SetCurrentDevice(&proto);
      }
      Blob* blob = ws_->CreateBlob(key);
      ProcessBlob(blob, proto, blob_states, key, &loaded_blobs);
    }
    *total_loaded_blobs += loaded_blobs;
  }

  void extractFrom(
      int db_id,
      Cursor* cursor,
      const vector<Blob*>& outputs,
      std::unordered_map<string, BlobState>* blob_states,
      int* total_loaded_blobs) {
    CAFFE_ENFORCE(cursor);
    int loaded_blobs = 0;
    for (; cursor->Valid(); cursor->Next()) {
      const auto key = buildBlobNameFromDbKey(cursor->key());
      if (!output_indices_.count(key)) {
        VLOG(1) << "Key " << key << " not used. Skipping.";
      } else {
        if (key_to_dbid_.count(key) && key_to_dbid_[key] != db_id) {
          CAFFE_THROW("Duplicate key ", key, " found!\n");
        } else {
          key_to_dbid_[key] = db_id;
        }

        VLOG(2) << "Deserializing blob " << key;
        BlobProto proto;
        CAFFE_ENFORCE(proto.ParseFromString(cursor->value()));
        if (!keep_device_) {
          // If we are not keeping the device as the one specified in the
          // proto, we will set the current device.
          SetCurrentDevice(&proto);
        }
        auto blobIndex = output_indices_[key];
        Blob* blob = outputs.at(blobIndex);
        ProcessBlob(blob, proto, blob_states, key, &loaded_blobs);

        if (*total_loaded_blobs + loaded_blobs == OutputSize()) {
          break;
        }
      }
    }

    *total_loaded_blobs += loaded_blobs;
  }

  string buildBlobNameFromDbKey(const string& dbKey) {
    string key = dbKey.substr(0, dbKey.find(kChunkIdSeparator));
    if (!strip_prefix_.empty()) {
      auto match_pos = key.find(strip_prefix_);
      if (match_pos != string::npos) {
        key = key.substr(match_pos + strip_prefix_.size());
      }
    }
    key = add_prefix_ + key;
    return key;
  }
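
  // Example of the key rewriting above (illustrative values): with
  // strip_prefix_ = "model/" and add_prefix_ = "init/", a db key
  // "model/fc1_w" (after any chunk-id suffix has been cut off) becomes
  // "fc1_w" and is then prefixed, yielding "init/fc1_w".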

 private:
  // We track the sizes of already-read tensor parts while reading data
  // chunks, so that at the end we can verify that all chunks were loaded.
  void ProcessBlob(
      Blob* blob,
      const BlobProto& proto,
      std::unordered_map<string, BlobState>* blob_states_ptr,
      const string& key,
      int* loaded_blobs) {
    auto& blob_states = *blob_states_ptr;
    if (blob_states.count(key) == 0) {
      // We reset the blob so that any existing content is destroyed. This
      // is to guarantee correct device placement: if we are deserializing
      // into a TensorCUDA, without an explicit Reset we might be loading
      // data into an existing TensorCUDA that has pre-allocated memory on
      // a different GPU.
      blob->Reset();
    }
    blob->Deserialize(proto);
    if (proto.has_content_num_chunks()) {
      if (!blob_states.count(key)) {
        blob_states[key] = BlobState(proto.content_num_chunks());
      }
      CAFFE_ENFORCE(
          blob_states[key]
              .seen_chunks_ids.insert(proto.content_chunk_id())
              .second,
          "Chunk with the same id has occurred twice for: ",
          key);
      CAFFE_ENFORCE(
          proto.content_chunk_id() >= 0 &&
              proto.content_chunk_id() < blob_states[key].total_size,
          "Chunk id has to be not less than 0 and "
          "less than content_num_chunks for key: ",
          key);
      blob_states[key].current_size++;
      CAFFE_ENFORCE(
          !blob_states[key].is_tensor,
          "Proto with content_chunks cannot store tensor: ",
          key);
      CAFFE_ENFORCE(
          blob_states[key].current_size <= blob_states[key].total_size,
          "Found an extra part for an already filled blob: ",
          key);
      if (blob_states[key].current_size == blob_states[key].total_size) {
        (*loaded_blobs)++;
      }
      return;
    }
    if (!proto.has_tensor()) {
      // If a blob is divided into chunks, the content_num_chunks field has
      // to be set; otherwise only tensors can be seen multiple times as
      // chunks.
      CAFFE_ENFORCE(blob_states.count(key) == 0, "Blob duplicated: ", key);
      blob_states[key] = BlobState();
      (*loaded_blobs)++;
      return;
    }
    CAFFE_ENFORCE(proto.has_tensor());
    if (blob_states.count(key)) {
      CAFFE_ENFORCE(blob_states[key].is_tensor, "Must be tensor ", key);
      CAFFE_ENFORCE(
          blob_states[key].current_size < blob_states[key].total_size,
          "Found an extra part for an already filled tensor: ",
          key);
      CAFFE_ENFORCE(
          proto.tensor().has_segment(),
          "Partial tensor must have a segment: ",
          key);
      blob_states[key].current_size +=
          proto.tensor().segment().end() - proto.tensor().segment().begin();
      CAFFE_ENFORCE(
          blob_states[key].current_size <= blob_states[key].total_size,
          "Tensor parts are bigger than target size for tensor: ",
          key);
    } else {
      const auto& dims = proto.tensor().dims();
      int64_t total_size = 1;
      for (const auto& dim : dims) {
        total_size *= dim;
      }
      auto current_size = total_size;
      if (proto.tensor().has_segment()) {
        current_size =
            proto.tensor().segment().end() - proto.tensor().segment().begin();
      }
      blob_states[key] =
          BlobState(total_size, current_size, true /* is_tensor */);
    }

    if (blob_states[key].current_size == blob_states[key].total_size) {
      (*loaded_blobs)++;
    }
  }

  void validateBlobStates(
      const std::unordered_map<string, BlobState>& blob_states) {
    for (const auto& iter : blob_states) {
      const BlobState& blob_state = iter.second;
      CAFFE_ENFORCE(
          blob_state.current_size == blob_state.total_size,
          "Data size mismatch for blob ",
          iter.first,
          ". Expected: ",
          blob_state.total_size,
          " Read: ",
          blob_state.current_size);
    }
  }

  Workspace* ws_;
  bool absolute_path_;
  string add_prefix_;
  string strip_prefix_;
  string db_name_;
  std::vector<std::string> db_names_;
  string db_type_;
  bool keep_device_;
  bool load_all_;
  bool allow_incomplete_;
  std::map<string, int> output_indices_;
  std::map<string, int> key_to_dbid_;
  std::vector<std::string> blob_names_;
};
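
// A minimal usage sketch (hypothetical db path; this class is registered
// under the "Load" operator name in Caffe2):
//
//   OperatorDef def;
//   def.set_type("Load");
//   def.add_output("fc1_w");
//   def.add_output("fc1_b");
//   def.add_arg()->CopyFrom(MakeArgument<string>("db", "init.minidb"));
//   def.add_arg()->CopyFrom(MakeArgument<string>("db_type", "minidb"));
//   // With the load_all argument set to 1 (and no outputs), every blob
//   // found in the db would be created in the workspace instead.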

template <class Context>
class SaveOp final : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;
  SaveOp(const OperatorDef& operator_def, Workspace* ws)
      : Operator<Context>(operator_def, ws),
        ws_(ws),
        absolute_path_(
            OperatorBase::GetSingleArgument<int>("absolute_path", false)),
        strip_prefix_(
            OperatorBase::GetSingleArgument<string>("strip_prefix", "")),
        db_name_(OperatorBase::GetSingleArgument<string>("db", "")),
        db_type_(OperatorBase::GetSingleArgument<string>("db_type", "")),
        blob_names_(
            OperatorBase::GetRepeatedArgument<string>("blob_name_overrides")) {
    CAFFE_ENFORCE_GT(db_name_.size(), 0, "Must specify a db name.");
    CAFFE_ENFORCE_GT(db_type_.size(), 0, "Must specify a db type.");
    CAFFE_ENFORCE(
        blob_names_.empty() ||
            blob_names_.size() == OperatorBase::Inputs().size(),
        "Number of blobs and blob_name_overrides mismatch.");
    CAFFE_ENFORCE(
        blob_names_.empty() || strip_prefix_.empty(),
        "strip_prefix and blob_name_overrides are mutually exclusive.");

    if (blob_names_.empty()) {
      std::set<std::string> input_names;
      blob_names_.resize(OperatorBase::Inputs().size());
      for (int i = 0; i < blob_names_.size(); ++i) {
        std::string name;
        if (strip_prefix_.empty()) {
          name = operator_def.input(i);
        } else {
          auto match_pos = operator_def.input(i).find(strip_prefix_);
          if (match_pos == string::npos) {
            name = operator_def.input(i);
          } else {
            name = operator_def.input(i).substr(
                match_pos + strip_prefix_.size(), string::npos);
          }
        }
        CAFFE_ENFORCE(
            input_names.insert(name).second, "Duplicated input: ", name);
        blob_names_[i] = name;
      }
    }
  }

  bool RunOnDevice() override {
    string full_db_name =
        absolute_path_ ? db_name_ : (ws_->RootFolder() + "/" + db_name_);
    std::unique_ptr<DB> out_db(
        caffe2::db::CreateDB(db_type_, full_db_name, caffe2::db::NEW));
    CAFFE_ENFORCE(out_db.get(), "Cannot open db for writing: ", full_db_name);

    BlobSerializerBase::SerializationAcceptor acceptor =
        [&](const std::string& blobName, const std::string& data) {
          // transaction should take care of locking
          VLOG(2) << "Sending " << blobName << " blob's data of size "
                  << data.size() << " to db";
          auto transaction = out_db->NewTransaction();
          transaction->Put(blobName, data);
          transaction->Commit();
        };

    const vector<const Blob*>& inputs = OperatorBase::Inputs();
    for (int i = 0; i < inputs.size(); ++i) {
      inputs[i]->Serialize(blob_names_[i], acceptor);
    }
    out_db->Close();
    return true;
  }

 private:
  Workspace* ws_;
  bool absolute_path_;
  string strip_prefix_;
  string db_name_;
  string db_type_;
  std::vector<std::string> blob_names_;
};
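
// A minimal usage sketch mirroring the Load example above (hypothetical db
// path; this class is registered under the "Save" operator name):
//
//   OperatorDef def;
//   def.set_type("Save");
//   def.add_input("fc1_w");
//   def.add_input("fc1_b");
//   def.add_arg()->CopyFrom(MakeArgument<string>("db", "trained.minidb"));
//   def.add_arg()->CopyFrom(MakeArgument<string>("db_type", "minidb"));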

template <typename... Ts>
string FormatString(const string& pattern, Ts... values) {
  // Note(Yangqing): We believe that 1024 is enough, but who are we to assert
  // that? As a result, if things go wrong, we'll just throw in the towel and
  // quit loudly. Yeah, I know that there is snprintf, but it is not present
  // on *some* platforms unfortunately.
  char buffer[1024];
  int written = sprintf(buffer, pattern.c_str(), values...);
  if (written < 0 || written + 1 > 1024) {
    LOG(FATAL) << "FormatString fails: total bytes written " << written;
  }
  return string(buffer);
  /*
   * The following is the snprintf version that is safe; enable it one day?
   * (Note: a variable-length array is not standard C++, hence the vector.)
   *
   *   const int required =
   *       std::snprintf(nullptr, 0, pattern.c_str(), values...) + 1;
   *   std::vector<char> bytes(required);
   *   std::snprintf(bytes.data(), required, pattern.c_str(), values...);
   *   return string(bytes.data());
   */
}
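
// For example, FormatString("/tmp/checkpoint_at_%d.pb", 100) returns
// "/tmp/checkpoint_at_100.pb"; CheckpointOp below uses exactly this to
// derive a per-iteration db name from its "db" pattern argument.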

// CheckpointOp is a wrapper over a SaveOp that allows flexible naming over
// iterations. The file pattern in db_name should be a format string that
// can be passed into sprintf with an int argument specifying the current
// iteration. An example: "/path/to/my/checkpoint/checkpoint_at_%d.pb"
template <class Context>
class CheckpointOp final : public Operator<Context> {
 public:
  CheckpointOp(const OperatorDef& operator_def, Workspace* ws)
      : Operator<Context>(operator_def, ws),
        db_pattern_(OperatorBase::GetSingleArgument<string>("db", "")),
        every_(OperatorBase::GetSingleArgument<int>("every", 1)),
        ws_(ws),
        save_op_def_(operator_def) {
    CAFFE_ENFORCE_GT(
        db_pattern_.size(), 0, "Must specify a checkpoint file pattern.");
    CAFFE_ENFORCE_GT(every_, 0, "Checkpoint interval should be positive.");
    if (every_ == 1) {
      // Just issue a warning, but it's totally legal so we don't do anything.
      LOG(WARNING) << "It seems that we are checkpointing every iteration. "
                   << "Is that intended?";
    }
    save_op_def_.set_type("Save");
  }

  bool RunOnDevice() override {
    int64_t iter =
        OperatorBase::Input<TensorCPU>(0).template data<int64_t>()[0];
    if (iter % every_ == 0) {
      GetMutableArgument("db", true, &save_op_def_)
          ->set_s(FormatString(db_pattern_, iter));
      SaveOp<Context> sub_op(save_op_def_, ws_);
      return sub_op.Run();
    } else {
      return true;
    }
  }

 private:
  string db_pattern_;
  int every_;
  Workspace* ws_;
  OperatorDef save_op_def_;
};
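
// A minimal usage sketch (hypothetical pattern; this class is registered
// under the "Checkpoint" operator name, and input 0 must be the iteration
// counter blob):
//
//   OperatorDef def;
//   def.set_type("Checkpoint");
//   def.add_input("iter");
//   def.add_arg()->CopyFrom(MakeArgument<string>("db", "ckpt_at_%d.pb"));
//   def.add_arg()->CopyFrom(MakeArgument<string>("db_type", "minidb"));
//   def.add_arg()->CopyFrom(MakeArgument<int>("every", 1000));
//   // Every 1000 iterations this rewrites the "db" argument via
//   // FormatString and runs a Save operator over the op's inputs.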

} // namespace caffe2

#endif // CAFFE2_OPERATORS_LOAD_SAVE_OP_H_