Caffe2 - C++ API
A deep learning, cross platform ML framework
dataset_ops.h
1 #ifndef CAFFE2_OPERATORS_DATASET_OPS_H_
2 #define CAFFE2_OPERATORS_DATASET_OPS_H_
3 
4 #include <memory>
5 #include <mutex>
6 #include <string>
7 #include <vector>
8 #include "caffe2/core/blob.h"
9 #include "caffe2/core/blob_serialization.h"
10 #include "caffe2/core/tensor.h"
11 
12 namespace caffe2 {
13 namespace dataset_ops {
14 
15 // Integer type of the entries in the dataset's lengths tensors.
16 using TLength = int32_t;
17 // Integer type used for all internal dataset operations (offsets, sizes to read, etc.).
18 using TOffset = int64_t;
19 
24 class TreeIterator {
25  public:
26  struct FieldDesc {
27  int id;
28  int lengthFieldId = -1;
29  std::string name;
30  };
31 
32  explicit TreeIterator(const std::vector<std::string>& fields);
33 
34  void advance(
35  const std::vector<const TLength*>& lengths,
36  std::vector<TOffset>& offsets,
37  std::vector<TOffset>& sizes,
38  std::vector<TOffset>& limits,
39  TOffset num);
40 
41  // Corresponds to the number of fields that have "length" as its last name
42  int numLengthFields() const {
43  return lengthFieldIds_.size();
44  }
45 
46  // Corresponds to the number of length fields + 1 (for the top-level domain)
47  int numOffsetFields() const {
48  return numLengthFields() + 1;
49  }
50 
51  // Get lengthField description for the given field
52  const FieldDesc* lengthFieldFor(const FieldDesc& desc) {
53  return (desc.lengthFieldId == -1)
54  ? nullptr
55  : &fields_.at(lengthFieldIds_.at(desc.lengthFieldId));
56  }
57 
58  // Get lengthField description for the given lengthFieldId, where
59  // 0 <= lengthFieldId < numLengthFields()
60  const FieldDesc& lengthField(int lengthFieldId) {
61  return fields_.at(lengthFieldIds_.at(lengthFieldId));
62  }
63 
64  // Returns the index into the 'offset' vector for the given field.
65  int offsetFieldIdFor(const FieldDesc& fieldDesc) {
66  return fieldDesc.lengthFieldId + 1;
67  }
68 
69  // Returns the field description for all fields.
70  const std::vector<FieldDesc>& fields() {
71  return fields_;
72  }
73 
74  const std::vector<int>& lengthFieldIds() const {
75  return lengthFieldIds_;
76  }
77 
78  private:
79  // Description of each field
80  std::vector<FieldDesc> fields_;
81  // Index into fields_ above for the fields that are lengths.
82  std::vector<int> lengthFieldIds_;
83 };
84 
85 class TreeCursor {
86  public:
87  explicit TreeCursor(const TreeIterator& iterator) : it(iterator) {}
88  std::vector<TOffset> offsets;
89  std::mutex mutex_;
90  TreeIterator it;
91 };
92 
97 class TreeWalker {
98  public:
99  TreeWalker(const vector<const Blob*>& inputs, TreeCursor& cursor);
100 
101  // Returns the number of records in a dataset
102  inline TOffset size() const {
103  return limits_.at(0);
104  }
105 
106  void advance();
107 
108  private:
109  inline const TensorCPU& input(int32_t idx) const {
110  return inputs_[idx]->Get<TensorCPU>();
111  }
112 
113  // TODO: Change to fieldDesc
114  inline const TreeIterator::FieldDesc& field(int idx) const {
115  return cursor_.it.fields().at(idx);
116  }
117 
118  inline int lengthIdx(int fieldId) const {
119  return field(fieldId).lengthFieldId + 1;
120  }
121 
122  inline TOffset offset(int fieldId) const {
123  return prevOffsets_[lengthIdx(fieldId)];
124  }
125 
126  std::vector<int64_t> fieldDim(int fieldId) const;
127 
128  void* fieldPtr(int fieldId) const;
129 
130  public:
131  // Simple Proxy class to expose nicer API for field access
132  class Field {
133  public:
134  Field(TreeWalker& walker, int fieldId)
135  : walker_(walker), fieldId_(fieldId) {}
136 
137  inline std::vector<int64_t> dim() const {
138  return walker_.fieldDim(fieldId_);
139  }
140 
141  inline int64_t size() const {
142  int64_t size = 1;
143  for (const auto d : dim()) {
144  size *= d;
145  }
146  return size;
147  }
148 
149  inline const TypeMeta& meta() const {
150  return walker_.input(fieldId_).dtype();
151  }
152 
153  inline void* ptr() const {
154  return walker_.fieldPtr(fieldId_);
155  }
156 
157  int fieldId() const {
158  return fieldId_;
159  }
160 
161  inline TOffset offset() const {
162  return walker_.offset(fieldId_);
163  }
164 
165  private:
166  const TreeWalker& walker_;
167  const int fieldId_;
168  };
169 
170  // Notice that a reference is returned. If advance() is called the fields will
171  // be updated to represent the new state.
172  inline const std::vector<Field>& fields() const {
173  return fields_;
174  }
175 
176  private:
177  void gatherLengthData();
178 
179  void gatherSizeLimits();
180 
181  const vector<const Blob*>& inputs_;
182  TreeCursor& cursor_;
183  std::vector<Field> fields_;
184 
185  std::vector<const TLength*> lengths_;
186  std::vector<TOffset> limits_;
187  std::vector<TOffset> sizes_;
188  std::vector<TOffset> offsets_;
189  std::vector<TOffset> prevOffsets_;
190 };
191 
// Shared-ownership pointer to a vector of TensorCPU.
192 using SharedTensorVectorPtr = std::shared_ptr<std::vector<TensorCPU>>;
193 
// Exclusively-owning pointer to a vector of Tensor.
194 using TensorVectorPtr = std::unique_ptr<std::vector<Tensor>>;
195 
197  public:
198  void Serialize(
199  const void* pointer,
200  TypeMeta typeMeta,
201  const string& name,
202  BlobSerializerBase::SerializationAcceptor acceptor) override;
203 };
204 
206  public:
207  void Deserialize(const BlobProto& proto, Blob* blob) override;
208 };
209 
210 } // namespace dataset_ops
211 } // namespace caffe2
212 
213 #endif // CAFFE2_OPERATORS_DATASET_OPS_H_
Blob is a general container that hosts a typed pointer.
Definition: blob.h:24
BlobDeserializerBase is an abstract class that deserializes a blob from a BlobProto or a TensorProto...
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13
Simple wrapper class allowing an easy traversal of the tensors representing the hierarchical structu...
Definition: dataset_ops.h:97
TypeMeta is a thin class that allows us to store the type of a container such as a blob...
Definition: typeid.h:324
BlobSerializerBase is an abstract class that serializes a blob to a string.
Provides functionality to iterate across a list of tensors where some of those tensors represent leng...
Definition: dataset_ops.h:24