// NOTE(review): this file appears to be a damaged extraction of
// caffe2/operators/load_save_op.h -- original line numbers are fused into the
// text and many lines are missing (numbering jumps). Comments below describe
// only what the visible fragments establish; everything else is hedged.
//
// Fragment: include-guard + includes, then the tail of a BlobState struct
// (the struct header itself is not visible here). BlobState is bookkeeping
// for chunked blob deserialization: how many chunks/bytes are expected
// (total_size), how many have arrived (current_size), and whether the blob
// is a tensor (tensors are chunked by segments, see ProcessBlob below).
1 #ifndef CAFFE2_OPERATORS_LOAD_SAVE_OP_H_ 2 #define CAFFE2_OPERATORS_LOAD_SAVE_OP_H_ 6 #include <unordered_set> 8 #include "caffe2/core/blob_serialization.h" 9 #include "caffe2/core/context.h" 10 #include "caffe2/core/db.h" 11 #include "caffe2/core/logging.h" 12 #include "caffe2/core/operator.h" 13 #include "caffe2/core/math.h" 14 #include "caffe2/utils/proto_utils.h" 23 std::set<int32_t> seen_chunks_ids;
// Constructor: defaults describe an "empty, nothing loaded yet" state.
// total_size: expected chunk count (or tensor element count -- see the
// two call sites in ProcessBlob); current_size: amount received so far.
26 int64_t total_size = 0,
27 int64_t current_size = 0,
28 bool is_tensor =
false)
29 : total_size(total_size),
30 current_size(current_size),
31 is_tensor(is_tensor) {}
37 using db::Transaction;
// Fragment: DBExistsOp<Context> (class declaration line not visible).
// Reads operator arguments "absolute_path", "db_name", "db_type" in its
// constructor; RunOnDevice() writes a single bool to Output(0) saying
// whether the named db exists.
39 template <
class Context>
42 USE_OPERATOR_CONTEXT_FUNCTIONS;
// Constructor initializer-list fragment: pull the three arguments.
47 this->
template GetSingleArgument<int>(
"absolute_path",
false)),
48 db_name_(this->
template GetSingleArgument<string>(
"db_name",
"")),
49 db_type_(this->
template GetSingleArgument<string>(
"db_type",
"")) {}
51 bool RunOnDevice()
override {
// Resolve the db path: absolute, or relative to the workspace root.
53 absolute_path_ ? db_name_ : (ws_->RootFolder() +
"/" + db_name_);
54 auto* output = Output(0);
56 bool* exists = output->template mutable_data<bool>();
// Delegate the actual existence check to the db layer.
58 *exists = caffe2::db::DBExists(db_type_, full_db_name);
// Fragment: LoadOp<Context> constructor (class declaration line not
// visible). Parses the operator arguments:
//   absolute_path, add_prefix, strip_prefix, db, dbs, db_type,
//   keep_device, load_all, allow_incomplete, source_blob_names
// then validates them:
//   - with no DBReader inputs, a db type and at least one db name are
//     required (single "db" is folded into db_names_);
//   - db names must be non-empty and unique;
//   - source_blob_names must match OutputSize() and is mutually exclusive
//     with both strip_prefix and load_all;
//   - when source_blob_names is absent, blob names default to the output
//     names of the operator def; names must be unique and are indexed
//     into output_indices_ for lookup during extraction.
69 template <
class Context>
72 USE_OPERATOR_CONTEXT_FUNCTIONS;
77 this->
template GetSingleArgument<int>(
"absolute_path",
false)),
78 add_prefix_(this->
template GetSingleArgument<string>(
"add_prefix",
"")),
80 this->
template GetSingleArgument<string>(
"strip_prefix",
"")),
81 db_name_(this->
template GetSingleArgument<string>(
"db",
"")),
82 db_names_(this->
template GetRepeatedArgument<string>(
"dbs")),
83 db_type_(this->
template GetSingleArgument<string>(
"db_type",
"")),
84 keep_device_(this->
template GetSingleArgument<int>(
"keep_device", 0)),
85 load_all_(this->
template GetSingleArgument<int>(
"load_all", 0)),
87 this->
template GetSingleArgument<bool>(
"allow_incomplete",
false)),
89 this->
template GetRepeatedArgument<string>(
"source_blob_names")) {
// No DBReader inputs => the db(s) must be named via arguments.
90 if (InputSize() == 0) {
91 CAFFE_ENFORCE_GT(db_type_.size(), 0,
"Must specify a db type.");
92 if (db_names_.empty()) {
93 CAFFE_ENFORCE_GT(db_name_.size(), 0,
"Must specify a db name.");
94 db_names_.push_back(db_name_);
// Reject duplicate db names (insert().second is false on duplicate).
97 std::set<std::string> db_name_set;
98 for (
const string& db_name : db_names_) {
99 CAFFE_ENFORCE_GT(db_name.size(), 0,
"Db name should not be empty.");
101 db_name_set.insert(db_name).second,
102 "Duplicated db name: ",
// source_blob_names interacts with other arguments; enforce exclusivity.
108 CAFFE_ENFORCE(blob_names_.empty() || blob_names_.size() == OutputSize(),
109 "Number of output blobs and source_blob_names mismatch.");
110 CAFFE_ENFORCE(blob_names_.empty() || strip_prefix_.empty(),
111 "strip_prefix and source_blob_names are mutually exclusive.");
112 CAFFE_ENFORCE(blob_names_.empty() || !load_all_,
113 "cannot load_all_ while using source_blob_names.");
// Default blob names to the operator's declared outputs.
118 if(blob_names_.empty()) {
119 for (
const string& name : operator_def.output()) {
120 blob_names_.push_back(name);
// Build name -> output index map; names must be unique.
124 std::set<std::string> name_set;
125 for (
const string& name : blob_names_) {
126 CAFFE_ENFORCE(name_set.insert(name).second,
127 "Duplicated source blob name: ", name);
128 output_indices_[name] = idx++;
// Declared here, defined elsewhere (per-Context specialization not visible):
// presumably rewrites the proto's device option so blobs land on the
// current device unless keep_device_ is set -- TODO confirm at definition.
133 void SetCurrentDevice(BlobProto* proto);
// RunOnDevice fragment: iterate either over DBReader inputs (InputSize()>0)
// or over the configured db_names_, extract blobs from each cursor, then
// validate that every tracked blob arrived complete. If not all outputs
// were loaded, allow_incomplete_ downgrades the failure to a log + error
// listing of the missing blobs (enforcement lines not visible here).
135 bool RunOnDevice()
override {
136 int total_loaded_blobs = 0;
137 std::unordered_map<string, BlobState> blob_states;
138 if (InputSize() > 0) {
139 for (
int i = 0; i < InputSize(); ++i) {
140 const db::DBReader& reader = this->
template Input<db::DBReader>(i);
// The db index `i` is threaded through so duplicate keys across dbs
// can be detected (see key_to_dbid_ in extractAll/extractFrom).
141 extract(i, reader.
cursor(), &blob_states, &total_loaded_blobs);
144 for (
int i = 0; i < db_names_.size(); ++i) {
145 string full_db_name = absolute_path_
147 : (ws_->RootFolder() +
"/" + db_names_[i]);
148 std::unique_ptr<DB> in_db(
149 caffe2::db::CreateDB(db_type_, full_db_name, caffe2::db::READ));
152 "Cannot find db implementation of type ",
154 " (while trying to open ",
157 std::unique_ptr<Cursor> cursor(in_db->NewCursor());
158 extract(i, cursor.get(), &blob_states, &total_loaded_blobs);
// Every partially-loaded blob must have received all its parts.
162 validateBlobStates(blob_states);
164 if (load_all_ || total_loaded_blobs == OutputSize()) {
165 VLOG(1) <<
"Loaded " << total_loaded_blobs <<
" blobs fully from db(s)";
// Partial load: report which declared outputs were never seen.
170 if (allow_incomplete_) {
171 VLOG(1) <<
"Loaded " << total_loaded_blobs <<
" blobs out of " 172 << OutputSize() <<
" blobs from db(s).";
174 for (
const string& output_name : this->debug_def().output()) {
175 if (blob_states.count(output_name) == 0) {
176 LOG(ERROR) <<
"Failed to load blob: " << output_name;
// Fragment: extract() -- dispatcher between "load everything in the db"
// (extractAll) and "load only the declared outputs" (extractFrom, which
// receives OperatorBase::Outputs()). The branch condition (presumably on
// load_all_) is among the missing lines -- TODO confirm.
194 std::unordered_map<string, BlobState>* blob_states,
195 int* total_loaded_blobs) {
197 extractAll(db_id, cursor, blob_states, total_loaded_blobs);
202 OperatorBase::Outputs(),
// Fragment: extractAll() -- walk the whole cursor, deserialize every
// key into a workspace blob (created on demand via ws_->CreateBlob).
// A key seen from two different dbs is a hard error.
211 std::unordered_map<string, BlobState>* blob_states,
212 int* total_loaded_blobs) {
213 CAFFE_ENFORCE(cursor,
"cursor is not valid");
214 int loaded_blobs = 0;
215 for (; cursor->Valid(); cursor->Next()) {
// Strip chunk-id suffix / apply prefix rewriting to get the blob name.
216 const auto key = buildBlobNameFromDbKey(cursor->key());
217 if (key_to_dbid_.count(key) && key_to_dbid_[key] != db_id) {
218 CAFFE_THROW(
"Duplicate Key ", key,
" is found!\n");
220 key_to_dbid_[key] = db_id;
225 proto.ParseFromString(cursor->value()),
"Couldn't parse Proto");
// Rewrite device info unless keep_device_ asks to preserve it
// (the keep_device_ check itself is among the missing lines).
229 SetCurrentDevice(&proto);
231 Blob* blob = ws_->CreateBlob(key);
232 ProcessBlob(blob, proto, blob_states, key, &loaded_blobs);
234 *total_loaded_blobs += loaded_blobs;
// Fragment: extractFrom() -- like extractAll(), but only deserializes
// keys that map to one of this operator's declared outputs
// (output_indices_); other keys are skipped. Stops early once every
// output has been filled.
240 const vector<Blob*>& outputs,
241 std::unordered_map<string, BlobState>* blob_states,
242 int* total_loaded_blobs) {
243 CAFFE_ENFORCE(cursor);
244 int loaded_blobs = 0;
245 for (; cursor->Valid(); cursor->Next()) {
246 const auto key = buildBlobNameFromDbKey(cursor->key());
// Not one of our outputs: skip (continue presumably follows; the line
// is missing from this extraction).
247 if (!output_indices_.count(key)) {
248 VLOG(1) <<
"Key " << key <<
" not used. Skipping.";
// Same cross-db duplicate-key guard as extractAll().
250 if (key_to_dbid_.count(key) && key_to_dbid_[key] != db_id) {
251 CAFFE_THROW(
"Duplicate Key ", key,
" is found!\n");
253 key_to_dbid_[key] = db_id;
256 VLOG(2) <<
"Deserializing blob " << key;
258 CAFFE_ENFORCE(proto.ParseFromString(cursor->value()));
262 SetCurrentDevice(&proto);
// Deserialize directly into the pre-bound output blob.
264 auto blobIndex = output_indices_[key];
265 Blob* blob = outputs.at(blobIndex);
266 ProcessBlob(blob, proto, blob_states, key, &loaded_blobs);
// All outputs satisfied -- no need to scan the rest of the db.
268 if (*total_loaded_blobs + loaded_blobs == OutputSize()) {
274 *total_loaded_blobs += loaded_blobs;
// Map a raw db key to a blob name:
// 1. drop everything from the chunk-id separator onward (chunked blobs
//    share a key prefix, suffixed per chunk);
// 2. if strip_prefix_ matches anywhere in the key, keep only what follows
//    the FIRST occurrence (note: find(), not a prefix check);
// 3. prepend add_prefix_.
277 string buildBlobNameFromDbKey(
const string& dbKey) {
278 string key = dbKey.substr(0, dbKey.find(kChunkIdSeparator));
279 if (!strip_prefix_.empty()) {
280 auto match_pos = key.find(strip_prefix_);
281 if (match_pos != string::npos) {
282 key = key.substr(match_pos + strip_prefix_.size());
285 key = add_prefix_ + key;
// Fragment: ProcessBlob() -- deserialize one BlobProto into `blob` and
// update the chunk-accounting BlobState for `key`. Three cases are
// visible, each with its own bookkeeping:
//   a) proto.has_content_num_chunks(): generic chunked (non-tensor) blob;
//      total_size = chunk count, current_size = chunks seen; duplicate
//      chunk ids and out-of-range ids are rejected via seen_chunks_ids.
//   b) !proto.has_tensor(): plain single-part blob; must not repeat.
//   c) tensor: total_size = product of dims, current_size advances by the
//      segment length (end - begin); partial tensors must carry a segment.
// In every case, a blob whose current_size reaches total_size counts as
// fully loaded (the loaded_blobs increment lines are not visible here).
294 const BlobProto& proto,
295 std::unordered_map<string, BlobState>* blob_states_ptr,
298 auto& blob_states = *blob_states_ptr;
299 if (blob_states.count(key) == 0) {
// Case (a): chunk-count-based accounting.
308 if (proto.has_content_num_chunks()) {
309 if (!blob_states.count(key)) {
310 blob_states[key] = BlobState(proto.content_num_chunks());
// insert().second is false when this chunk id was already seen.
314 .seen_chunks_ids.insert(proto.content_chunk_id())
316 "Chunk with the same id has occured twice for: ",
319 proto.content_chunk_id() >= 0 &&
320 proto.content_chunk_id() < blob_states[key].total_size,
321 "Chunk id has to be not less than 0 and " 322 "less than content_num_chunks for key: ",
324 blob_states[key].current_size++;
326 !blob_states[key].is_tensor,
327 "Proto with content_chunks can not store tensor: ",
330 blob_states[key].current_size <= blob_states[key].total_size,
331 "Found an extra part for an already filled blob: ",
333 if (blob_states[key].current_size == blob_states[key].total_size) {
// Case (b): non-tensor, non-chunked blob -- exactly one part allowed.
338 if (!proto.has_tensor()) {
341 CAFFE_ENFORCE(blob_states.count(key) == 0,
"Blob duplicated: ", key);
342 blob_states[key] = BlobState();
// Case (c): tensor, accounted by element count via segments.
346 CAFFE_ENFORCE(proto.has_tensor());
347 if (blob_states.count(key)) {
348 CAFFE_ENFORCE(blob_states[key].is_tensor,
"Must be tensor ", key);
350 blob_states[key].current_size < blob_states[key].total_size,
"Found an extra part for an already filled tensor: ",
354 proto.tensor().has_segment(),
355 "Partial tensor must have a segment: ",
357 blob_states[key].current_size +=
358 proto.tensor().segment().end() - proto.tensor().segment().begin();
360 blob_states[key].current_size <= blob_states[key].total_size,
361 "Tensor parts are bigger than target size for tensor: ",
// First part of this tensor: total_size = product of dims (the
// multiply inside the loop is among the missing lines).
364 const auto& dims = proto.tensor().dims();
365 int64_t total_size = 1;
366 for (
const auto& dim : dims) {
369 auto current_size = total_size;
370 if (proto.tensor().has_segment()) {
372 proto.tensor().segment().end() - proto.tensor().segment().begin();
375 BlobState(total_size, current_size,
true );
378 if (blob_states[key].current_size == blob_states[key].total_size) {
// After all dbs are drained, every tracked blob must be complete:
// current_size == total_size (chunks for generic blobs, elements for
// tensors). The enforce's message lines reference both sizes.
383 void validateBlobStates(
384 const std::unordered_map<string, BlobState>& blob_states) {
385 for (
const auto& iter : blob_states) {
386 const BlobState& blob_state = iter.second;
388 blob_state.current_size == blob_state.total_size,
389 "Data size mismatch for blob ",
392 blob_state.total_size,
394 blob_state.current_size);
// LoadOp member fragment (other members missing from this extraction).
// output_indices_: blob name -> output slot; key_to_dbid_: blob name ->
// db index it was first seen in (duplicate-key detection across dbs).
401 string strip_prefix_;
403 std::vector<std::string> db_names_;
407 bool allow_incomplete_;
408 std::map<string, int> output_indices_;
409 std::map<string, int> key_to_dbid_;
410 std::vector<std::string> blob_names_;
// Fragment: SaveOp<Context> constructor (class declaration not visible).
// Arguments: absolute_path, strip_prefix, db (required), db_type
// (required), blob_name_overrides, chunk_size (defaults to
// kDefaultChunkSize). blob_name_overrides must match the input count and
// is mutually exclusive with strip_prefix. Without overrides, each saved
// blob's name is its input name, optionally with everything up to and
// including the first strip_prefix_ occurrence removed; names must be
// unique.
413 template <
class Context>
416 USE_OPERATOR_CONTEXT_FUNCTIONS;
421 this->
template GetSingleArgument<int>(
"absolute_path",
false)),
423 this->
template GetSingleArgument<string>(
"strip_prefix",
"")),
424 db_name_(this->
template GetSingleArgument<string>(
"db",
"")),
425 db_type_(this->
template GetSingleArgument<string>(
"db_type",
"")),
427 this->
template GetRepeatedArgument<string>(
"blob_name_overrides")),
428 chunk_size_(this->
template GetSingleArgument<int>(
430 kDefaultChunkSize)) {
431 CAFFE_ENFORCE_GT(db_name_.size(), 0,
"Must specify a db name.");
432 CAFFE_ENFORCE_GT(db_type_.size(), 0,
"Must specify a db type.");
434 blob_names_.empty() ||
435 blob_names_.size() == OperatorBase::Inputs().size(),
436 "Number of blobs and blob_name_overrides mismatch.");
438 blob_names_.empty() || strip_prefix_.empty(),
439 "strip_prefix and blob_name_overrides are mutually exclusive.");
// Derive saved-blob names from input names when no overrides given.
441 if (blob_names_.empty()) {
442 std::set<std::string> input_names;
443 blob_names_.resize(OperatorBase::Inputs().size());
444 for (
int i = 0; i < blob_names_.size(); ++i) {
446 if (strip_prefix_.empty()) {
447 name = operator_def.input(i);
449 auto match_pos = operator_def.input(i).find(strip_prefix_);
450 if (match_pos == string::npos) {
451 name = operator_def.input(i);
// Strip through the FIRST occurrence of strip_prefix_ (find, not
// a prefix match) -- mirrors LoadOp::buildBlobNameFromDbKey.
453 name = operator_def.input(i).substr(
454 match_pos + strip_prefix_.size(), string::npos);
458 input_names.insert(name).second,
"Duplicated input: ", name);
459 blob_names_[i] = name;
// SaveOp::RunOnDevice fragment: open (create) the output db, then
// serialize every input blob through an acceptor that commits each
// serialized chunk in its own transaction. SerializeBlob splits large
// tensors into chunks of chunk_size_.
464 bool RunOnDevice()
override {
465 string full_db_name =
466 absolute_path_ ? db_name_ : (ws_->RootFolder() +
"/" + db_name_);
467 std::unique_ptr<DB> out_db(
468 caffe2::db::CreateDB(db_type_, full_db_name, caffe2::db::NEW));
471 "Cannot find db implementation of type ",
473 " (while trying to open ",
// Acceptor is called once per serialized chunk; each Put is committed
// immediately. Captures out_db by reference -- safe because
// SerializeBlob below runs synchronously within this scope.
477 BlobSerializerBase::SerializationAcceptor acceptor = [&](
478 const std::string& blobName,
const std::string& data) {
480 VLOG(2) <<
"Sending " << blobName <<
" blob's data of size " 481 << data.size() <<
" to db";
482 auto transaction = out_db->NewTransaction();
483 transaction->Put(blobName, data);
484 transaction->Commit();
487 const vector<const Blob*>& inputs = OperatorBase::Inputs();
488 for (
int i = 0; i < inputs.size(); ++i) {
489 SerializeBlob(*inputs[i], blob_names_[i], acceptor, chunk_size_);
// SaveOp member fragment (other members missing from this extraction).
498 string strip_prefix_;
501 std::vector<std::string> blob_names_;
505 template <
typename... Ts>
506 string FormatString(
const string& pattern, Ts... values) {
513 int written = sprintf(buffer, pattern.c_str(), values...);
514 if (written < 0 || written + 1 > 1024) {
515 LOG(FATAL) <<
"FormatString fails: total bytes written " << written;
517 return string(buffer);
// Fragment: CheckpointOp<Context> (class declaration not visible).
// Arguments: "db" -- a FormatString pattern for checkpoint file names
// (required), "every" -- checkpoint interval in iterations (must be
// positive; every == 1 draws a warning since checkpointing each iteration
// is rarely intended). The wrapped save_op_def_ is a copy of this
// operator's def retyped to "Save"; each time the iteration counter
// (read from input 0, a CPU int64 tensor) hits a multiple of `every`,
// its "db" argument is rewritten via FormatString and the Save op is run
// (the run call itself is among the missing lines -- TODO confirm).
538 db_pattern_(this->
template GetSingleArgument<string>(
"db",
"")),
539 every_(this->
template GetSingleArgument<int>(
"every", 1)),
541 save_op_def_(operator_def) {
543 db_pattern_.size(), 0,
"Must specify a checkpoint file pattern.");
544 CAFFE_ENFORCE_GT(every_, 0,
"Checkpoint interval should be positive.");
547 LOG(WARNING) <<
"It seems that we are checkpointting every iteration. " 548 <<
"Is that intended?";
550 save_op_def_.set_type(
"Save");
553 USE_OPERATOR_CONTEXT_FUNCTIONS;
555 bool RunOnDevice()
override {
// Iteration counter comes from a CPU tensor regardless of Context.
557 this->
template Input<Tensor>(0, CPU).
template data<int64_t>()[0];
558 if (iter % every_ == 0) {
// Rewrite the Save op's "db" argument with the formatted name.
559 GetMutableArgument(
"db",
true, &save_op_def_)
560 ->set_s(FormatString(db_pattern_, iter));
572 OperatorDef save_op_def_;
577 #endif // CAFFE2_OPERATORS_LOAD_SAVE_OP_H_
Blob is a general container that hosts a typed pointer.
A reader wrapper for DB that also allows us to serialize it.
void DeserializeBlob(const string &content, Blob *result)
Deserializes from a string containing either BlobProto or TensorProto.
Cursor * cursor() const
Returns the underlying cursor of the db reader.
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
T * Reset(T *allocated)
Sets the underlying object to the allocated one.
void SerializeBlob(const Blob &blob, const string &name, BlobSerializerBase::SerializationAcceptor acceptor, int chunk_size)
Serializes the given blob, if possible.