Caffe2 - C++ API
A deep learning, cross platform ML framework
tensor_protos_db_input.h
1 
17 #ifndef CAFFE2_OPERATORS_TENSOR_PROTOS_DB_INPUT_H_
18 #define CAFFE2_OPERATORS_TENSOR_PROTOS_DB_INPUT_H_
19 
20 #include <iostream>
21 #include <mutex>
22 
23 #include "caffe2/core/db.h"
24 #include "caffe2/operators/prefetch_op.h"
25 
26 namespace caffe2 {
27 
// TensorProtosDBInput: a PrefetchOperator<Context> that overrides Prefetch()
// and CopyPrefetched(). Records are staged in CPU-side blobs first and then
// copied into the operator's outputs.
// NOTE(review): this listing skips source lines 32 and 34-35, so the closing
// brace on line 36 ends a member whose opening is not visible here.
28 template <class Context>
29 class TensorProtosDBInput final : public PrefetchOperator<Context> {
30  public:
31  using OperatorBase::OutputSize;
33  explicit TensorProtosDBInput(const OperatorDef& operator_def, Workspace* ws);
36  }
37 
38  bool Prefetch() override; // fills prefetched_blobs_ from the DBReader input
39  bool CopyPrefetched() override; // copies prefetched_blobs_ into the outputs
40 
41  private:
42  // Prefetch will always just happen on the CPU side.
43  vector<Blob> prefetched_blobs_; // constructor sizes this to operator_def.output_size()
44  int batch_size_; // "batch_size" argument (default 0); 0 selects the non-batched path in Prefetch
45  bool shape_inferred_ = false; // guards the Resize step in Prefetch; no visible code sets it to true
46  string key_; // receives the key from reader.Read(); not otherwise used in the visible code
47  string value_; // receives the serialized TensorProtos from reader.Read()
48 };
49 
// Constructor definition: forwards operator_def and ws to the
// PrefetchOperator<Context> base, creates one staging blob per declared
// output, and reads the "batch_size" argument (0 when it is not given).
// NOTE(review): source line 51 (the declarator that names the constructor) is
// missing from this listing.
50 template <class Context>
52  const OperatorDef& operator_def,
53  Workspace* ws)
54  : PrefetchOperator<Context>(operator_def, ws),
55  prefetched_blobs_(operator_def.output_size()), // one Blob per output
56  batch_size_(
57  OperatorBase::template GetSingleArgument<int>("batch_size", 0)) {} // second argument 0 is the default
58 
// Body of the prefetch step: reads serialized TensorProtos records from the
// db::DBReader at input 0 and deserializes them into prefetched_blobs_.
//  - batch_size_ == 0: one record is read and each proto is deserialized
//    directly into the matching staging blob.
//  - otherwise: batch_size_ records are read; each record's tensors are
//    deserialized into temporaries and copied into slot item_id of a staging
//    tensor whose leading dimension is batch_size_.
// Both paths enforce that the record parses and that it holds exactly
// OutputSize() protos. Returns true when it runs to the end.
// NOTE(review): source line 60 (the function signature) is missing from this
// listing; presumably this is Prefetch(), given the class declaration.
59 template <class Context>
61  const db::DBReader& reader = OperatorBase::Input<db::DBReader>(0);
62  TensorDeserializer<CPUContext> deserializer;
63  if (batch_size_ == 0) {
64  // We do not need to construct a batch. As a result, we will simply
65  // deserialize everything into the target prefetched blob.
66  reader.Read(&key_, &value_);
67  TensorProtos protos;
68  CAFFE_ENFORCE(protos.ParseFromString(value_));
69  CAFFE_ENFORCE(protos.protos_size() == OutputSize());
70  for (int i = 0; i < protos.protos_size(); ++i) {
71  if (protos.protos(i).has_device_detail()) {
72  protos.mutable_protos(i)->clear_device_detail(); // presumably so the CPU deserializer ignores the stored device info; confirm
73  }
74  deserializer.Deserialize(
75  protos.protos(i),
76  prefetched_blobs_[i].template GetMutable<TensorCPU>());
77  }
78  } else {
79  vector<TensorCPU> temp_tensors(OutputSize()); // per-output scratch tensors, reused for every item of the batch
80  for (int item_id = 0; item_id < batch_size_; ++item_id) {
81  reader.Read(&key_, &value_);
82  TensorProtos protos;
83  CAFFE_ENFORCE(protos.ParseFromString(value_));
84  CAFFE_ENFORCE(protos.protos_size() == OutputSize());
85  if (!shape_inferred_) { // NOTE(review): shape_inferred_ is never set to true in the visible code, so this runs for every item
86  // First, set the shape of all the blobs.
87  for (int i = 0; i < protos.protos_size(); ++i) {
88  vector<int> dims(
89  protos.protos(i).dims().begin(), protos.protos(i).dims().end());
90  dims.insert(dims.begin(), batch_size_); // batched shape = [batch_size_, item dims...]
91  prefetched_blobs_[i].template GetMutable<TensorCPU>()->Resize(dims);
92  }
93  }
94  for (int i = 0; i < protos.protos_size(); ++i) {
95  TensorCPU* dst = prefetched_blobs_[i].template GetMutable<TensorCPU>();
96  TensorCPU& src = temp_tensors[i];
97  if (protos.protos(i).has_device_detail()) {
98  protos.mutable_protos(i)->clear_device_detail();
99  }
100  deserializer.Deserialize(protos.protos(i), &src);
101  DCHECK_EQ(src.size() * batch_size_, dst->size()); // every item must have the element count of one batch slot
102  this->context_.template CopyItems<CPUContext, CPUContext>(
103  src.meta(),
104  src.size(),
105  src.raw_data(),
106  static_cast<char*>(dst->raw_mutable_data(src.meta())) +
107  src.nbytes() * item_id); // destination is item_id whole items (in bytes) past the start of dst
108  }
109  }
110  }
111  return true;
112 }
113 
// Body of the copy step: for every operator output i, copies the CPU tensor
// held in prefetched_blobs_[i] into output i (a Tensor<Context>), passing
// this->context_ to CopyFrom. Returns true after all outputs are copied.
// NOTE(review): source line 115 (the function signature) is missing from this
// listing; presumably this is CopyPrefetched(), given the class declaration.
114 template <class Context>
116  for (int i = 0; i < OutputSize(); ++i) {
117  OperatorBase::Output<Tensor<Context>>(i)->CopyFrom(
118  prefetched_blobs_[i].template Get<TensorCPU>(), &this->context_);
119  }
120  return true;
121 }
122 
123 } // namespace caffe2
124 
125 #endif // CAFFE2_OPERATORS_TENSOR_PROTOS_DB_INPUT_H_
void Read(string *key, string *value) const
Reads one key/value pair from the db and moves to the next entry.
Definition: db.h:238
A reader wrapper for DB that also allows us to serialize it.
Definition: db.h:160
TIndex size() const
Returns the size (i.e. the number of items) of the tensor.
Definition: tensor.h:609
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:63
TensorDeserializer is the deserializer for Tensors.
Copyright (c) 2016-present, Facebook, Inc.
void * raw_mutable_data(const TypeMeta &meta)
Returns a mutable raw pointer of the underlying storage.
Definition: tensor.h:526