Caffe2 - C++ API
A deep learning, cross-platform ML framework
dataset_ops.h
1 
17 #ifndef CAFFE2_OPERATORS_DATASET_OPS_H_
18 #define CAFFE2_OPERATORS_DATASET_OPS_H_
19 
20 #include <memory>
21 #include <mutex>
22 #include <string>
23 #include <vector>
24 #include "caffe2/core/blob.h"
25 #include "caffe2/core/blob_serialization.h"
26 #include "caffe2/core/tensor.h"
27 
28 namespace caffe2 {
29 namespace dataset_ops {
30 
// Integer type used for all internal dataset bookkeeping (offsets, sizes to
// read, etc.).
using TOffset = int64_t;
// Integer type of the values held in the dataset's lengths tensors.
using TLength = int32_t;
35 
40 class TreeIterator {
41  public:
42  struct FieldDesc {
43  int id;
44  int lengthFieldId = -1;
45  std::string name;
46  };
47 
48  explicit TreeIterator(const std::vector<std::string>& fields);
49 
50  void advance(
51  const std::vector<const TLength*>& lengths,
52  std::vector<TOffset>& offsets,
53  std::vector<TOffset>& sizes,
54  std::vector<TOffset>& limits,
55  TOffset num);
56 
57  // Corresponds to the number of fields that have "length" as its last name
58  int numLengthFields() const {
59  return lengthFieldIds_.size();
60  }
61 
62  // Corresponds to the number of length fields + 1 (for the top-level domain)
63  int numOffsetFields() const {
64  return numLengthFields() + 1;
65  }
66 
67  // Get lengthField description for the given field
68  const FieldDesc* lengthFieldFor(const FieldDesc& desc) {
69  return (desc.lengthFieldId == -1)
70  ? nullptr
71  : &fields_.at(lengthFieldIds_.at(desc.lengthFieldId));
72  }
73 
74  // Get lengthField description for the given lengthFieldId, where
75  // 0 <= lengthFieldId < numLengthFields()
76  const FieldDesc& lengthField(int lengthFieldId) {
77  return fields_.at(lengthFieldIds_.at(lengthFieldId));
78  }
79 
80  // Returns the index into the 'offset' vector for the given field.
81  int offsetFieldIdFor(const FieldDesc& fieldDesc) {
82  return fieldDesc.lengthFieldId + 1;
83  }
84 
85  // Returns the field description for all fields.
86  const std::vector<FieldDesc>& fields() {
87  return fields_;
88  }
89 
90  const std::vector<int>& lengthFieldIds() const {
91  return lengthFieldIds_;
92  }
93 
94  private:
95  // Description of each field
96  std::vector<FieldDesc> fields_;
97  // Index into fields_ above for the fields that are lengths.
98  std::vector<int> lengthFieldIds_;
99 };
100 
101 class TreeCursor {
102  public:
103  explicit TreeCursor(const TreeIterator& iterator) : it(iterator) {}
104  std::vector<TOffset> offsets;
105  std::mutex mutex_;
106  TreeIterator it;
107 };
108 
113 class TreeWalker {
114  public:
115  TreeWalker(const vector<const Blob*>& inputs, TreeCursor& cursor);
116 
117  // Returns the number of records in a dataset
118  inline TOffset size() const {
119  return limits_.at(0);
120  }
121 
122  void advance();
123 
124  private:
125  inline const TensorCPU& input(int32_t idx) const {
126  return inputs_[idx]->Get<TensorCPU>();
127  }
128 
129  // TODO: Change to fieldDesc
130  inline const TreeIterator::FieldDesc& field(int idx) const {
131  return cursor_.it.fields().at(idx);
132  }
133 
134  inline int lengthIdx(int fieldId) const {
135  return field(fieldId).lengthFieldId + 1;
136  }
137 
138  inline TOffset offset(int fieldId) const {
139  return prevOffsets_[lengthIdx(fieldId)];
140  }
141 
142  std::vector<TIndex> fieldDim(int fieldId) const;
143 
144  void* fieldPtr(int fieldId) const;
145 
146  public:
147  // Simple Proxy class to expose nicer API for field access
148  class Field {
149  public:
150  Field(TreeWalker& walker, int fieldId)
151  : walker_(walker), fieldId_(fieldId) {}
152 
153  inline std::vector<TIndex> dim() const {
154  return walker_.fieldDim(fieldId_);
155  }
156 
157  inline TIndex size() const {
158  TIndex size = 1;
159  for (const auto d : dim()) {
160  size *= d;
161  }
162  return size;
163  }
164 
165  inline const TypeMeta& meta() const {
166  return walker_.input(fieldId_).meta();
167  }
168 
169  inline void* ptr() const {
170  return walker_.fieldPtr(fieldId_);
171  }
172 
173  int fieldId() const {
174  return fieldId_;
175  }
176 
177  inline TOffset offset() const {
178  return walker_.offset(fieldId_);
179  }
180 
181  private:
182  const TreeWalker& walker_;
183  const int fieldId_;
184  };
185 
186  // Notice that a reference is returned. If advance() is called the fields will
187  // be updated to represent the new state.
188  inline const std::vector<Field>& fields() const {
189  return fields_;
190  }
191 
192  private:
193  void gatherLengthData();
194 
195  void gatherSizeLimits();
196 
197  const vector<const Blob*>& inputs_;
198  TreeCursor& cursor_;
199  std::vector<Field> fields_;
200 
201  std::vector<const TLength*> lengths_;
202  std::vector<TOffset> limits_;
203  std::vector<TOffset> sizes_;
204  std::vector<TOffset> offsets_;
205  std::vector<TOffset> prevOffsets_;
206 };
207 
208 using SharedTensorVectorPtr = std::shared_ptr<std::vector<TensorCPU>>;
209 
210 template <class Context>
211 using TensorVectorPtr = std::unique_ptr<std::vector<Tensor<Context>>>;
212 
214  public:
// Serialization entry point of the enclosing class (whose head is not part of
// this listing). Declared `override`, so it replaces a virtual of the base
// class; the definition lives out of line.
//   blob     - the blob to serialize.
//   name     - name passed along with the blob.
//   acceptor - a BlobSerializerBase::SerializationAcceptor, taken by value.
// NOTE(review): presumably the serialized output is delivered through
// `acceptor` -- confirm against the definition.
215  void Serialize(
216  const Blob& blob,
217  const string& name,
218  BlobSerializerBase::SerializationAcceptor acceptor) override;
219 };
220 
222  public:
// Deserialization entry point of the enclosing class (whose head is not part
// of this listing). Declared `override`, so it replaces a virtual of the base
// class; the definition lives out of line.
//   proto - the BlobProto to read from.
//   blob  - non-const pointer to the destination blob.
// NOTE(review): presumably `blob` must be non-null and is overwritten with
// the decoded contents -- confirm against the definition.
223  void Deserialize(const BlobProto& proto, Blob* blob) override;
224 };
225 
226 } // namespace dataset_ops
227 } // namespace caffe2
228 
229 #endif // CAFFE2_OPERATORS_DATASET_OPS_H_
Blob is a general container that hosts a typed pointer.
Definition: blob.h:41
BlobDeserializerBase is an abstract class that deserializes a blob from a BlobProto or a TensorProto...
Copyright (c) 2016-present, Facebook, Inc.
Simple wrapper class allowing an easy traversal of the tensors representing the hierarchical structu...
Definition: dataset_ops.h:113
TypeMeta is a thin class that allows us to store the type of a container such as a blob...
Definition: typeid.h:104
BlobSerializerBase is an abstract class that serializes a blob to a string.
Provides functionality to iterate across a list of tensors where some of those tensors represent leng...
Definition: dataset_ops.h:40