Caffe2 - C++ API
A deep learning, cross platform ML framework
tensor_protos_db_input.h
1 #ifndef CAFFE2_OPERATORS_TENSOR_PROTOS_DB_INPUT_H_
2 #define CAFFE2_OPERATORS_TENSOR_PROTOS_DB_INPUT_H_
3 
4 #include <iostream>
5 #include <mutex>
6 
7 #include "caffe2/core/db.h"
8 #include "caffe2/operators/prefetch_op.h"
9 
10 namespace caffe2 {
11 
12 template <class Context>
13 class TensorProtosDBInput final : public PrefetchOperator<Context> {
14  public:
15  using OperatorBase::OutputSize;
17  explicit TensorProtosDBInput(const OperatorDef& operator_def, Workspace* ws);
20  }
21 
22  bool Prefetch() override;
23  bool CopyPrefetched() override;
24 
25  private:
26  // Prefetch will always just happen on the CPU side.
27  vector<Blob> prefetched_blobs_;
28  int batch_size_;
29  bool shape_inferred_ = false;
30  string key_;
31  string value_;
32 };
33 
34 template <class Context>
36  const OperatorDef& operator_def,
37  Workspace* ws)
38  : PrefetchOperator<Context>(operator_def, ws),
39  prefetched_blobs_(operator_def.output_size()),
40  batch_size_(
41  this->template GetSingleArgument<int>("batch_size", 0)) {}
42 
43 template <class Context>
45  const db::DBReader& reader = this->template Input<db::DBReader>(0);
46  TensorDeserializer deserializer;
47  if (batch_size_ == 0) {
48  // We do not need to construct a batch. As a result, we will simply
49  // deserialize everything into the target prefetched blob.
50  reader.Read(&key_, &value_);
51  TensorProtos protos;
52  CAFFE_ENFORCE(protos.ParseFromString(value_));
53  CAFFE_ENFORCE(protos.protos_size() == OutputSize());
54  for (int i = 0; i < protos.protos_size(); ++i) {
55  if (protos.protos(i).has_device_detail()) {
56  protos.mutable_protos(i)->clear_device_detail();
57  }
58  BlobSetTensor(
59  &prefetched_blobs_[i], deserializer.Deserialize(protos.protos(i)));
60  // deserializer.Deserialize(
61  // protos.protos(i), BlobGetMutableTensor(&prefetched_blobs_[i],
62  // CPU));
63  }
64  } else {
65  for (int item_id = 0; item_id < batch_size_; ++item_id) {
66  reader.Read(&key_, &value_);
67  TensorProtos protos;
68  CAFFE_ENFORCE(protos.ParseFromString(value_));
69  CAFFE_ENFORCE(protos.protos_size() == OutputSize());
70  // Note: shape_inferred_ is ignored, we'll always get dimensions from
71  // proto
72  for (int i = 0; i < protos.protos_size(); ++i) {
73  vector<int64_t> dims(
74  protos.protos(i).dims().begin(), protos.protos(i).dims().end());
75  dims.insert(dims.begin(), batch_size_);
76  if (protos.protos(i).has_device_detail()) {
77  protos.mutable_protos(i)->clear_device_detail();
78  }
79  Tensor src = deserializer.Deserialize(protos.protos(i));
80  Tensor* dst = BlobGetMutableTensor(
81  &prefetched_blobs_[i], dims, at::dtype(src.dtype()).device(CPU));
82  DCHECK_EQ(src.numel() * batch_size_, dst->numel());
83  this->context_.CopyItemsSameDevice(
84  src.dtype(),
85  src.numel(),
86  src.raw_data(),
87  static_cast<char*>(dst->raw_mutable_data(src.dtype())) +
88  src.nbytes() * item_id);
89  }
90  }
91  }
92  return true;
93 }
94 
95 template <class Context>
97  for (int i = 0; i < OutputSize(); ++i) {
98  OperatorBase::template Output<Tensor>(i, Context::GetDeviceType())
99  ->CopyFrom(
100  prefetched_blobs_[i].template Get<TensorCPU>(), /* async */ true);
101  }
102  return true;
103 }
104 
105 } // namespace caffe2
106 
107 #endif // CAFFE2_OPERATORS_TENSOR_PROTOS_DB_INPUT_H_
void Read(string *key, string *value) const
Read a set of key and value from the db and move to next.
Definition: db.h:228
A reader wrapper for DB that also allows us to serialize it.
Definition: db.h:144
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:47
TensorDeserializer is the deserializer for Tensors.
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13