Caffe2 - C++ API
A deep learning, cross-platform ML framework
load_save_op.h
#ifndef CAFFE2_OPERATORS_LOAD_SAVE_OP_H_
#define CAFFE2_OPERATORS_LOAD_SAVE_OP_H_

#include <cstdio>
#include <map>
#include <set>
#include <unordered_map>
#include <unordered_set>

#include "caffe2/core/blob_serialization.h"
#include "caffe2/core/context.h"
#include "caffe2/core/db.h"
#include "caffe2/core/logging.h"
#include "caffe2/core/operator.h"
#include "caffe2/utils/math.h"
#include "caffe2/utils/proto_utils.h"

namespace caffe2 {

namespace {
struct BlobState {
  int64_t total_size;
  int64_t current_size;
  bool is_tensor;
  std::set<int32_t> seen_chunks_ids;

  explicit BlobState(
      int64_t total_size = 0,
      int64_t current_size = 0,
      bool is_tensor = false)
      : total_size(total_size),
        current_size(current_size),
        is_tensor(is_tensor) {}
};
} // namespace

using db::Cursor;
using db::DB;
using db::Transaction;
template <class Context>
class DBExistsOp final : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;
  explicit DBExistsOp(const OperatorDef& operator_def, Workspace* ws)
      : Operator<Context>(operator_def, ws),
        ws_(ws),
        absolute_path_(
            this->template GetSingleArgument<int>("absolute_path", false)),
        db_name_(this->template GetSingleArgument<string>("db_name", "")),
        db_type_(this->template GetSingleArgument<string>("db_type", "")) {}

  bool RunOnDevice() override {
    string full_db_name =
        absolute_path_ ? db_name_ : (ws_->RootFolder() + "/" + db_name_);
    auto* output = Output(0);
    output->Resize();
    bool* exists = output->template mutable_data<bool>();

    *exists = caffe2::db::DBExists(db_type_, full_db_name);
    return true;
  }

 private:
  Workspace* ws_;
  bool absolute_path_;
  std::string db_name_;
  std::string db_type_;
};

template <class Context>
class LoadOp final : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;
  explicit LoadOp(const OperatorDef& operator_def, Workspace* ws)
      : Operator<Context>(operator_def, ws),
        ws_(ws),
        absolute_path_(
            this->template GetSingleArgument<int>("absolute_path", false)),
        add_prefix_(this->template GetSingleArgument<string>("add_prefix", "")),
        strip_prefix_(
            this->template GetSingleArgument<string>("strip_prefix", "")),
        db_name_(this->template GetSingleArgument<string>("db", "")),
        db_names_(this->template GetRepeatedArgument<string>("dbs")),
        db_type_(this->template GetSingleArgument<string>("db_type", "")),
        keep_device_(this->template GetSingleArgument<int>("keep_device", 0)),
        load_all_(this->template GetSingleArgument<int>("load_all", 0)),
        allow_incomplete_(
            this->template GetSingleArgument<bool>("allow_incomplete", false)),
        blob_names_(
            this->template GetRepeatedArgument<string>("source_blob_names")) {
    if (InputSize() == 0) {
      CAFFE_ENFORCE_GT(db_type_.size(), 0, "Must specify a db type.");
      if (db_names_.empty()) {
        CAFFE_ENFORCE_GT(db_name_.size(), 0, "Must specify a db name.");
        db_names_.push_back(db_name_);
        db_name_ = "";
      } else {
        std::set<std::string> db_name_set;
        for (const string& db_name : db_names_) {
          CAFFE_ENFORCE_GT(db_name.size(), 0, "Db name should not be empty.");
          CAFFE_ENFORCE(
              db_name_set.insert(db_name).second,
              "Duplicated db name: ",
              db_name);
        }
        db_name_ = "";
      }
    }
    CAFFE_ENFORCE(
        blob_names_.empty() || blob_names_.size() == OutputSize(),
        "Number of output blobs and source_blob_names mismatch.");
    CAFFE_ENFORCE(
        blob_names_.empty() || strip_prefix_.empty(),
        "strip_prefix and source_blob_names are mutually exclusive.");
    CAFFE_ENFORCE(
        blob_names_.empty() || !load_all_,
        "cannot load_all_ while using source_blob_names.");
    if (!load_all_) {
      // blob_names_ will be filled with the "source blob names" found in the
      // file/db. If the argument source_blob_names is not given, blob_names_
      // is inferred from the operator's outputs.
      if (blob_names_.empty()) {
        for (const string& name : operator_def.output()) {
          blob_names_.push_back(name);
        }
      }
      int idx = 0;
      std::set<std::string> name_set;
      for (const string& name : blob_names_) {
        CAFFE_ENFORCE(
            name_set.insert(name).second,
            "Duplicated source blob name: ",
            name);
        output_indices_[name] = idx++;
      }
    }
  }
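  // Illustrative example (blob names hypothetical): with outputs
  // {"fc1_w", "fc1_b"} and no source_blob_names argument, blob_names_
  // becomes {"fc1_w", "fc1_b"} and output_indices_ maps each name to its
  // output slot (0 and 1).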

  void SetCurrentDevice(BlobProto* proto);

  bool RunOnDevice() override {
    int total_loaded_blobs = 0;
    std::unordered_map<string, BlobState> blob_states;
    if (InputSize() > 0) {
      for (int i = 0; i < InputSize(); ++i) {
        const db::DBReader& reader = this->template Input<db::DBReader>(i);
        extract(i, reader.cursor(), &blob_states, &total_loaded_blobs);
      }
    } else {
      for (int i = 0; i < db_names_.size(); ++i) {
        string full_db_name = absolute_path_
            ? db_names_[i]
            : (ws_->RootFolder() + "/" + db_names_[i]);
        std::unique_ptr<DB> in_db(
            caffe2::db::CreateDB(db_type_, full_db_name, caffe2::db::READ));
        CAFFE_ENFORCE(
            in_db.get(),
            "Cannot find db implementation of type ",
            db_type_,
            " (while trying to open ",
            full_db_name,
            ")");
        std::unique_ptr<Cursor> cursor(in_db->NewCursor());
        extract(i, cursor.get(), &blob_states, &total_loaded_blobs);
      }
    }

    validateBlobStates(blob_states);
    // Loaded all the needed blobs.
    if (load_all_ || total_loaded_blobs == OutputSize()) {
      VLOG(1) << "Loaded " << total_loaded_blobs << " blobs fully from db(s)";
      return true;
    }

    // Only loaded a subset of the blobs.
    if (allow_incomplete_) {
      VLOG(1) << "Loaded " << total_loaded_blobs << " blobs out of "
              << OutputSize() << " blobs from db(s).";
    } else {
      for (const string& output_name : this->debug_def().output()) {
        if (blob_states.count(output_name) == 0) {
          LOG(ERROR) << "Failed to load blob: " << output_name;
        }
      }
      CAFFE_THROW(
          "Expected to load ",
          OutputSize(),
          " blobs, got ",
          total_loaded_blobs,
          " only.\n");
    }

    return true;
  }

 private:
  void extract(
      int db_id,
      Cursor* cursor,
      std::unordered_map<string, BlobState>* blob_states,
      int* total_loaded_blobs) {
    if (load_all_) {
      extractAll(db_id, cursor, blob_states, total_loaded_blobs);
    } else {
      extractFrom(
          db_id,
          cursor,
          OperatorBase::Outputs(),
          blob_states,
          total_loaded_blobs);
    }
  }

  void extractAll(
      int db_id,
      Cursor* cursor,
      std::unordered_map<string, BlobState>* blob_states,
      int* total_loaded_blobs) {
    CAFFE_ENFORCE(cursor, "cursor is not valid");
    int loaded_blobs = 0;
    for (; cursor->Valid(); cursor->Next()) {
      const auto key = buildBlobNameFromDbKey(cursor->key());
      if (key_to_dbid_.count(key) && key_to_dbid_[key] != db_id) {
        CAFFE_THROW("Duplicate Key ", key, " is found!\n");
      } else {
        key_to_dbid_[key] = db_id;
      }

      BlobProto proto;
      CAFFE_ENFORCE(
          proto.ParseFromString(cursor->value()), "Couldn't parse Proto");
      if (!keep_device_) {
        // If we are not keeping the device as the one specified in the
        // proto, we will set the current device.
        SetCurrentDevice(&proto);
      }
      Blob* blob = ws_->CreateBlob(key);
      ProcessBlob(blob, proto, blob_states, key, &loaded_blobs);
    }
    *total_loaded_blobs += loaded_blobs;
  }

  void extractFrom(
      int db_id,
      Cursor* cursor,
      const vector<Blob*>& outputs,
      std::unordered_map<string, BlobState>* blob_states,
      int* total_loaded_blobs) {
    CAFFE_ENFORCE(cursor);
    int loaded_blobs = 0;
    for (; cursor->Valid(); cursor->Next()) {
      const auto key = buildBlobNameFromDbKey(cursor->key());
      if (!output_indices_.count(key)) {
        VLOG(1) << "Key " << key << " not used. Skipping.";
      } else {
        if (key_to_dbid_.count(key) && key_to_dbid_[key] != db_id) {
          CAFFE_THROW("Duplicate Key ", key, " is found!\n");
        } else {
          key_to_dbid_[key] = db_id;
        }

        VLOG(2) << "Deserializing blob " << key;
        BlobProto proto;
        CAFFE_ENFORCE(proto.ParseFromString(cursor->value()));
        if (!keep_device_) {
          // If we are not keeping the device as the one specified in the
          // proto, we will set the current device.
          SetCurrentDevice(&proto);
        }
        auto blobIndex = output_indices_[key];
        Blob* blob = outputs.at(blobIndex);
        ProcessBlob(blob, proto, blob_states, key, &loaded_blobs);

        if (*total_loaded_blobs + loaded_blobs == OutputSize()) {
          break;
        }
      }
    }

    *total_loaded_blobs += loaded_blobs;
  }

  string buildBlobNameFromDbKey(const string& dbKey) {
    string key = dbKey.substr(0, dbKey.find(kChunkIdSeparator));
    if (!strip_prefix_.empty()) {
      auto match_pos = key.find(strip_prefix_);
      if (match_pos != string::npos) {
        key = key.substr(match_pos + strip_prefix_.size());
      }
    }
    key = add_prefix_ + key;
    return key;
  }
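  // Illustrative example (names hypothetical): with strip_prefix_ = "model/"
  // and add_prefix_ = "init/", a chunked db key of the form
  // "model/fc1_w" + kChunkIdSeparator + "<chunk id>" is first truncated at
  // the chunk separator and then rewritten to "init/fc1_w". Note that find()
  // matches strip_prefix_ anywhere in the key, not only at its start.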

 private:
  // We are tracking sizes of already read tensor parts while reading data
  // chunks. This way we can make sure that all chunks were loaded in the end.
  void ProcessBlob(
      Blob* blob,
      const BlobProto& proto,
      std::unordered_map<string, BlobState>* blob_states_ptr,
      const string& key,
      int* loaded_blobs) {
    auto& blob_states = *blob_states_ptr;
    if (blob_states.count(key) == 0) {
      // We reset the blob so that any existing content is destroyed. This
      // is to guarantee correct device placement: if we are deserializing
      // into a TensorCUDA, without explicit Reset we might be loading data
      // into an existing TensorCUDA that has pre-allocated memory on a
      // different GPU.
      blob->Reset();
    }
    DeserializeBlob(proto, blob);
    if (proto.has_content_num_chunks()) {
      if (!blob_states.count(key)) {
        blob_states[key] = BlobState(proto.content_num_chunks());
      }
      CAFFE_ENFORCE(
          blob_states[key]
              .seen_chunks_ids.insert(proto.content_chunk_id())
              .second,
          "Chunk with the same id has occurred twice for: ",
          key);
      CAFFE_ENFORCE(
          proto.content_chunk_id() >= 0 &&
              proto.content_chunk_id() < blob_states[key].total_size,
          "Chunk id has to be at least 0 and "
          "less than content_num_chunks for key: ",
          key);
      blob_states[key].current_size++;
      CAFFE_ENFORCE(
          !blob_states[key].is_tensor,
          "Proto with content_chunks cannot store tensor: ",
          key);
      CAFFE_ENFORCE(
          blob_states[key].current_size <= blob_states[key].total_size,
          "Found an extra part for an already filled blob: ",
          key);
      if (blob_states[key].current_size == blob_states[key].total_size) {
        (*loaded_blobs)++;
      }
      return;
    }
    if (!proto.has_tensor()) {
      // If a blob is divided into chunks, the field content_chunks has to be
      // set; otherwise only tensors can be seen multiple times as chunks.
      CAFFE_ENFORCE(blob_states.count(key) == 0, "Blob duplicated: ", key);
      blob_states[key] = BlobState();
      (*loaded_blobs)++;
      return;
    }
    CAFFE_ENFORCE(proto.has_tensor());
    if (blob_states.count(key)) {
      CAFFE_ENFORCE(blob_states[key].is_tensor, "Must be tensor ", key);
      CAFFE_ENFORCE(
          blob_states[key].current_size < blob_states[key].total_size,
          "Found an extra part for an already filled tensor: ",
          key);
      CAFFE_ENFORCE(
          proto.tensor().has_segment(),
          "Partial tensor must have a segment: ",
          key);
      blob_states[key].current_size +=
          proto.tensor().segment().end() - proto.tensor().segment().begin();
      CAFFE_ENFORCE(
          blob_states[key].current_size <= blob_states[key].total_size,
          "Tensor parts are bigger than target size for tensor: ",
          key);
    } else {
      const auto& dims = proto.tensor().dims();
      int64_t total_size = 1;
      for (const auto& dim : dims) {
        total_size *= dim;
      }
      auto current_size = total_size;
      if (proto.tensor().has_segment()) {
        current_size =
            proto.tensor().segment().end() - proto.tensor().segment().begin();
      }
      blob_states[key] =
          BlobState(total_size, current_size, true /* is_tensor */);
    }

    if (blob_states[key].current_size == blob_states[key].total_size) {
      (*loaded_blobs)++;
    }
  }
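  // Illustrative walkthrough (a sketch, not from the original source): a
  // tensor with dims {2, 3} saved as two chunks of three elements arrives as
  // two protos with segments [0, 3) and [3, 6). The first proto creates
  // BlobState(6, 3, /*is_tensor=*/true); the second advances current_size to
  // 6, at which point the blob counts as fully loaded.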

  void validateBlobStates(
      const std::unordered_map<string, BlobState>& blob_states) {
    for (const auto& iter : blob_states) {
      const BlobState& blob_state = iter.second;
      CAFFE_ENFORCE(
          blob_state.current_size == blob_state.total_size,
          "Data size mismatch for blob ",
          iter.first,
          ". Expected: ",
          blob_state.total_size,
          " Read: ",
          blob_state.current_size);
    }
  }

  Workspace* ws_;
  bool absolute_path_;
  string add_prefix_;
  string strip_prefix_;
  string db_name_;
  std::vector<std::string> db_names_;
  string db_type_;
  bool keep_device_;
  bool load_all_;
  bool allow_incomplete_;
  std::map<string, int> output_indices_;
  std::map<string, int> key_to_dbid_;
  std::vector<std::string> blob_names_;
};

template <class Context>
class SaveOp final : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;
  explicit SaveOp(const OperatorDef& operator_def, Workspace* ws)
      : Operator<Context>(operator_def, ws),
        ws_(ws),
        absolute_path_(
            this->template GetSingleArgument<int>("absolute_path", false)),
        strip_prefix_(
            this->template GetSingleArgument<string>("strip_prefix", "")),
        db_name_(this->template GetSingleArgument<string>("db", "")),
        db_type_(this->template GetSingleArgument<string>("db_type", "")),
        blob_names_(
            this->template GetRepeatedArgument<string>("blob_name_overrides")),
        chunk_size_(this->template GetSingleArgument<int>(
            "chunk_size",
            kDefaultChunkSize)) {
    CAFFE_ENFORCE_GT(db_name_.size(), 0, "Must specify a db name.");
    CAFFE_ENFORCE_GT(db_type_.size(), 0, "Must specify a db type.");
    CAFFE_ENFORCE(
        blob_names_.empty() ||
            blob_names_.size() == OperatorBase::Inputs().size(),
        "Number of blobs and blob_name_overrides mismatch.");
    CAFFE_ENFORCE(
        blob_names_.empty() || strip_prefix_.empty(),
        "strip_prefix and blob_name_overrides are mutually exclusive.");

    if (blob_names_.empty()) {
      std::set<std::string> input_names;
      blob_names_.resize(OperatorBase::Inputs().size());
      for (int i = 0; i < blob_names_.size(); ++i) {
        std::string name;
        if (strip_prefix_.empty()) {
          name = operator_def.input(i);
        } else {
          auto match_pos = operator_def.input(i).find(strip_prefix_);
          if (match_pos == string::npos) {
            name = operator_def.input(i);
          } else {
            name = operator_def.input(i).substr(
                match_pos + strip_prefix_.size(), string::npos);
          }
        }
        CAFFE_ENFORCE(
            input_names.insert(name).second, "Duplicated input: ", name);
        blob_names_[i] = name;
      }
    }
  }

  bool RunOnDevice() override {
    string full_db_name =
        absolute_path_ ? db_name_ : (ws_->RootFolder() + "/" + db_name_);
    std::unique_ptr<DB> out_db(
        caffe2::db::CreateDB(db_type_, full_db_name, caffe2::db::NEW));
    CAFFE_ENFORCE(
        out_db.get(),
        "Cannot find db implementation of type ",
        db_type_,
        " (while trying to open ",
        full_db_name,
        ")");

    BlobSerializerBase::SerializationAcceptor acceptor =
        [&](const std::string& blobName, const std::string& data) {
          // transaction should take care of locking
          VLOG(2) << "Sending " << blobName << " blob's data of size "
                  << data.size() << " to db";
          auto transaction = out_db->NewTransaction();
          transaction->Put(blobName, data);
          transaction->Commit();
        };

    const vector<const Blob*>& inputs = OperatorBase::Inputs();
    for (int i = 0; i < inputs.size(); ++i) {
      SerializeBlob(*inputs[i], blob_names_[i], acceptor, chunk_size_);
    }
    out_db->Close();
    return true;
  }

 private:
  Workspace* ws_;
  bool absolute_path_;
  string strip_prefix_;
  string db_name_;
  string db_type_;
  std::vector<std::string> blob_names_;
  int chunk_size_;
};

template <typename... Ts>
string FormatString(const string& pattern, Ts... values) {
  // Note(Yangqing): We believe that 1024 is enough, but who are we to assert
  // that? As a result, if things go wrong, we'll just throw in the towel and
  // fail loudly. Yeah, I know that there is snprintf, but unfortunately it is
  // not present on *some* platforms.
  char buffer[1024];
  int written = sprintf(buffer, pattern.c_str(), values...);
  if (written < 0 || written + 1 > 1024) {
    LOG(FATAL) << "FormatString fails: total bytes written " << written;
  }
  return string(buffer);
  /*
   * The following is the snprintf version that is safe; enable it one day?
  unsigned int required =
      std::snprintf(nullptr, 0, pattern.c_str(), values...) + 1;
  char bytes[required];
  std::snprintf(bytes, required, pattern.c_str(), values...);
  return string(bytes);
  */
}
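// Illustrative usage (a sketch, not part of the header):
//   FormatString("checkpoint_at_%05d.pb", 20) -> "checkpoint_at_00020.pb"
// Standard printf semantics apply, so the pattern and the argument types must
// agree; a mismatch is undefined behavior, exactly as with sprintf itself.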

// CheckpointOp is a wrapper over a SaveOp that basically allows flexible
// naming over iterations.
// The file pattern in db_name should be a format string that can be passed
// into sprintf with an int argument specifying the current iteration. An
// example: "/path/to/my/checkpoint/checkpoint_at_%d.pb"
template <class Context>
class CheckpointOp final : public Operator<Context> {
 public:
  explicit CheckpointOp(const OperatorDef& operator_def, Workspace* ws)
      : Operator<Context>(operator_def, ws),
        db_pattern_(this->template GetSingleArgument<string>("db", "")),
        every_(this->template GetSingleArgument<int>("every", 1)),
        ws_(ws),
        save_op_def_(operator_def) {
    CAFFE_ENFORCE_GT(
        db_pattern_.size(), 0, "Must specify a checkpoint file pattern.");
    CAFFE_ENFORCE_GT(every_, 0, "Checkpoint interval should be positive.");
    if (every_ == 1) {
      // Just issue a warning, but it's totally legal so we don't do anything.
      LOG(WARNING) << "It seems that we are checkpointing every iteration. "
                   << "Is that intended?";
    }
    save_op_def_.set_type("Save");
  }

  USE_OPERATOR_CONTEXT_FUNCTIONS;

  bool RunOnDevice() override {
    int64_t iter =
        this->template Input<Tensor>(0, CPU).template data<int64_t>()[0];
    if (iter % every_ == 0) {
      GetMutableArgument("db", true, &save_op_def_)
          ->set_s(FormatString(db_pattern_, iter));
      SaveOp<Context> sub_op(save_op_def_, ws_);
      return sub_op.Run();
    } else {
      return true;
    }
  }

 private:
  string db_pattern_;
  int every_;
  Workspace* ws_;
  OperatorDef save_op_def_;
};

} // namespace caffe2

#endif // CAFFE2_OPERATORS_LOAD_SAVE_OP_H_
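Usage sketch (not part of the header): the snippet below shows one plausible way to drive SaveOp and LoadOp through the generic operator interface, using CreateOperatorDef and MakeArgument from caffe2/utils/proto_utils.h. It assumes the "minidb" backend is registered and that a blob named "W" already exists in the workspace; the blob and file names are illustrative.

#include "caffe2/core/operator.h"
#include "caffe2/core/workspace.h"
#include "caffe2/utils/proto_utils.h"

namespace caffe2 {

void SaveThenLoad(Workspace* ws) {
  // Serialize blob "W" into <workspace root folder>/params.minidb.
  OperatorDef save_def = CreateOperatorDef(
      "Save",
      "",
      {"W"}, // inputs: the blobs to serialize
      {},    // Save produces no outputs
      {MakeArgument<string>("db", "params.minidb"),
       MakeArgument<string>("db_type", "minidb")});
  CAFFE_ENFORCE(CreateOperator(save_def, ws)->Run());

  // Deserialize it back; output names double as db keys unless
  // source_blob_names is given.
  OperatorDef load_def = CreateOperatorDef(
      "Load",
      "",
      {},    // no db::DBReader inputs; open the db named in the arguments
      {"W"}, // outputs: the blobs to fill
      {MakeArgument<string>("db", "params.minidb"),
       MakeArgument<string>("db_type", "minidb")});
  CAFFE_ENFORCE(CreateOperator(load_def, ws)->Run());
}

} // namespace caffe2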