// Header for the tensor-protos DB input operator (name taken from the include
// guard). NOTE(review): this listing is fragmentary -- the class head, access
// specifiers and several members are not visible, so the comments below only
// describe what the visible lines show.
1 #ifndef CAFFE2_OPERATORS_TENSOR_PROTOS_DB_INPUT_H_ 2 #define CAFFE2_OPERATORS_TENSOR_PROTOS_DB_INPUT_H_ 7 #include "caffe2/core/db.h" 8 #include "caffe2/operators/prefetch_op.h" 12 template <
class Context>
// Brings OperatorBase::OutputSize into scope; the member definitions below
// call OutputSize() unqualified.
15 using OperatorBase::OutputSize;
// Reads record(s) from the DBReader input and deserializes them into
// prefetched_blobs_. Declared `override`; presumably the hook comes from the
// prefetch base in prefetch_op.h -- confirm against the full file.
22 bool Prefetch()
override;
// Walks every output index and pairs the output tensor with the TensorCPU held
// in the matching prefetched blob (see the definition further down).
23 bool CopyPrefetched()
override;
// Staging blobs, one per operator output: the constructor sizes this vector
// with operator_def.output_size().
27 vector<Blob> prefetched_blobs_;
// Defaults to false. Not read by any code visible in this listing.
29 bool shape_inferred_ =
false;
// Constructor fragment. NOTE(review): the function head, the remaining
// parameters and part of the member-initializer list are missing from this
// listing; only the visible initializers are described.
34 template <
class Context>
// The operator definition; its output_size() determines how many staging blobs
// are allocated.
36 const OperatorDef& operator_def,
// One prefetch blob per declared output.
39 prefetched_blobs_(operator_def.output_size()),
// Reads the int argument "batch_size", defaulting to 0 when absent.
// NOTE(review): the member being initialized is on a missing line --
// presumably batch_size_, which Prefetch() compares against 0; confirm.
41 this->
template GetSingleArgument<int>(
"batch_size", 0)) {}
// Prefetch() definition fragment. NOTE(review): the function head, several
// local declarations (deserializer, protos, dims), the `else`, the closing
// braces and the return statement are missing from this listing.
//
// Visible behavior: records are pulled from the DBReader at input 0, parsed
// into a protos message and deserialized into prefetched_blobs_. Two paths
// exist, selected by batch_size_.
43 template <
class Context>
// The DB reader is supplied as operator input 0.
45 const db::DBReader& reader = this->
template Input<db::DBReader>(0);
// batch_size_ == 0: no batching. A single record is read and each contained
// proto is deserialized straight into its staging blob.
47 if (batch_size_ == 0) {
// Reads one key/value record into the key_/value_ members.
50 reader.
Read(&key_, &value_);
// The record must parse, and must carry exactly one proto per operator output.
52 CAFFE_ENFORCE(protos.ParseFromString(value_));
53 CAFFE_ENFORCE(protos.protos_size() == OutputSize());
54 for (
int i = 0; i < protos.protos_size(); ++i) {
// Drops the device_detail field, when present, before deserializing.
// NOTE(review): presumably so stored device placement is ignored -- confirm.
55 if (protos.protos(i).has_device_detail()) {
56 protos.mutable_protos(i)->clear_device_detail();
// Arguments of a call whose name is on a missing line: it receives the i-th
// staging blob and the tensor deserialized from the i-th proto.
59 &prefetched_blobs_[i], deserializer.Deserialize(protos.protos(i)));
// Batched path (the `else` itself is on a missing line): reads batch_size_
// records, one per item_id.
65 for (
int item_id = 0; item_id < batch_size_; ++item_id) {
66 reader.
Read(&key_, &value_);
// Same validation as the unbatched path, applied to every record.
68 CAFFE_ENFORCE(protos.ParseFromString(value_));
69 CAFFE_ENFORCE(protos.protos_size() == OutputSize());
72 for (
int i = 0; i < protos.protos_size(); ++i) {
// `dims` starts as the proto's own dims (the declaration is on a missing
// line), then batch_size_ is inserted at the front.
74 protos.protos(i).dims().begin(), protos.protos(i).dims().end());
75 dims.insert(dims.begin(), batch_size_);
// Drops the device_detail field, when present, before deserializing.
76 if (protos.protos(i).has_device_detail()) {
77 protos.mutable_protos(i)->clear_device_detail();
// src: the single item deserialized from the i-th proto.
79 Tensor src = deserializer.Deserialize(protos.protos(i));
// dst: the staging tensor in the i-th prefetch blob, requested with the
// batched dims, src's dtype and the CPU device.
80 Tensor* dst = BlobGetMutableTensor(
81 &prefetched_blobs_[i], dims, at::dtype(src.dtype()).device(CPU));
// Debug-only check that dst holds exactly batch_size_ items of src's size.
82 DCHECK_EQ(src.numel() * batch_size_, dst->numel());
// Copies the item into dst. The leading arguments are on missing lines; the
// visible destination pointer is dst's raw buffer offset by
// src.nbytes() * item_id bytes, i.e. the slot for this item.
83 this->context_.CopyItemsSameDevice(
87 static_cast<char*
>(dst->raw_mutable_data(src.dtype())) +
88 src.nbytes() * item_id);
// CopyPrefetched() definition fragment. NOTE(review): the function head, the
// name of the call that consumes the arguments below, the closing braces and
// the return statement are missing from this listing.
95 template <
class Context>
// One iteration per operator output.
97 for (
int i = 0; i < OutputSize(); ++i) {
// Fetches the i-th output as a Tensor on the Context's device type.
98 OperatorBase::template Output<Tensor>(i, Context::GetDeviceType())
// Arguments passed to a member call on that output (call name not visible --
// presumably a copy; confirm): the TensorCPU stored in the i-th prefetch blob,
// followed by `true`. NOTE(review): the meaning of the `true` flag cannot be
// determined from this listing.
100 prefetched_blobs_[i].
template Get<TensorCPU>(),
true);
107 #endif // CAFFE2_OPERATORS_TENSOR_PROTOS_DB_INPUT_H_ void Read(string *key, string *value) const
Reads one key/value pair from the DB and moves to the next record.
A reader wrapper for DB that also allows us to serialize it.
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
TensorDeserializer is the deserializer for Tensors.
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...