Caffe2 - C++ API
A deep learning, cross-platform ML framework
index_ops.cc
1 
17 #include <atomic>
18 #include <limits>
19 #include <mutex>
20 #include <sstream>
21 #include <unordered_map>
22 #include <vector>
23 #include "caffe2/core/blob_serialization.h"
24 #include "caffe2/core/operator.h"
25 #include "caffe2/core/tensor.h"
26 
27 namespace caffe2 {
28 namespace {
29 using IndexKeyTypes = TensorTypes<int32_t, int64_t, std::string>;
30 using TIndexValue = int64_t;
31 } // namespace
32 
33 struct IndexBase {
34  public:
35  IndexBase(TIndexValue maxElements, const TypeMeta& type)
36  : maxElements_{maxElements}
37  , meta_(type)
38  , frozen_{false} {}
39 
40  void Freeze() { frozen_ = true; }
41 
42  bool isFrozen() const {
43  return frozen_;
44  }
45 
46  int64_t maxElements() const {
47  return maxElements_;
48  }
49 
50  virtual ~IndexBase() {}
51 
52  const TypeMeta& Type() const { return meta_; }
53 
54  TIndexValue Size() {
55  std::lock_guard<std::mutex> guard(dictMutex_);
56  return nextId_;
57  }
58 
59  protected:
60  int64_t maxElements_;
61  TypeMeta meta_;
62  TIndexValue nextId_{1}; // guarded by dictMutex_
63  std::atomic<bool> frozen_{false};
64  std::mutex dictMutex_;
65 };
66 
67 template<typename T>
68 struct Index: IndexBase {
69  explicit Index(TIndexValue maxElements)
70  : IndexBase(maxElements, TypeMeta::Make<T>()) {}
71 
72  void Get(const T* keys, TIndexValue* values, size_t numKeys) {
73  if (frozen_) {
74  FrozenGet(keys, values, numKeys);
75  return;
76  }
77  std::lock_guard<std::mutex> lock(dictMutex_);
78  for (int i = 0; i < numKeys; ++i) {
79  auto it = dict_.find(keys[i]);
80  if (it != dict_.end()) {
81  values[i] = it->second;
82  } else if (nextId_ < maxElements_) {
83  auto newValue = nextId_++;
84  dict_.insert({keys[i], newValue});
85  values[i] = newValue;
86  } else {
87  CAFFE_THROW("Dict max size reached");
88  }
89  }
90  }
91 
92  bool Load(const T* keys, size_t numKeys) {
93  CAFFE_ENFORCE(
94  numKeys <= maxElements_,
95  "Cannot load index: Tensor is larger than max_elements.");
96  decltype(dict_) dict;
97  for (int i = 0; i < numKeys; ++i) {
98  CAFFE_ENFORCE(
99  dict.insert({keys[i], i + 1}).second,
100  "Repeated elements found: cannot load into dictionary.");
101  }
102  // assume no `get` is inflight while this happens
103  {
104  std::lock_guard<std::mutex> lock(dictMutex_);
105  // let the old dict get destructed outside of the lock
106  dict_.swap(dict);
107  nextId_ = numKeys + 1;
108  }
109  return true;
110  }
111 
112  template<typename Ctx>
113  bool Store(Tensor<Ctx>* out) {
114  std::lock_guard<std::mutex> lock(dictMutex_);
115  out->Resize(nextId_ - 1);
116  auto outData = out->template mutable_data<T>();
117  for (const auto& entry : dict_) {
118  outData[entry.second - 1] = entry.first;
119  }
120  return true;
121  }
122 
123  private:
124  void FrozenGet(const T* keys, TIndexValue* values, size_t numKeys) {
125  for (int i = 0; i < numKeys; ++i) {
126  auto it = dict_.find(keys[i]);
127  values[i] = it != dict_.end() ? it->second : 0;
128  }
129  }
130 
131  std::unordered_map<T, TIndexValue> dict_;
132 };
133 
134 // TODO(azzolini): support sizes larger than int32
135 template<class T>
136 class IndexCreateOp: public Operator<CPUContext> {
137  public:
138  IndexCreateOp(const OperatorDef& operator_def, Workspace* ws)
139  : Operator(operator_def, ws),
140  maxElements_(OperatorBase::GetSingleArgument<int>(
141  "max_elements",
142  std::numeric_limits<int>::max())) {}
143 
144  bool RunOnDevice() override {
145  *OperatorBase::Output<std::unique_ptr<IndexBase>>(0) =
146  std::unique_ptr<IndexBase>(new Index<T>(maxElements_));
147  return true;
148  }
149 
150  private:
151  TIndexValue maxElements_;
152 };
153 
154 class IndexGetOp: public Operator<CPUContext> {
155  public:
156  IndexGetOp(const OperatorDef& operator_def, Workspace* ws)
157  : Operator(operator_def, ws) {}
158 
159  bool RunOnDevice() override {
160  return DispatchHelper<IndexKeyTypes>::call(this, Input(1));
161  }
162  template <typename T>
163  bool DoRunWithType() {
164  auto& base = OperatorBase::Input<std::unique_ptr<IndexBase>>(0);
165  auto* dict = dynamic_cast_if_rtti<Index<T>*>(base.get());
166  CAFFE_ENFORCE(dict, "Wrong dictionary type given input keys.");
167  const auto& keys = Input(1);
168  auto* values = Output(0);
169  values->ResizeLike(keys);
170  dict->Get(keys.data<T>(), values->mutable_data<TIndexValue>(), keys.size());
171  return true;
172  }
173 };
174 
175 class IndexLoadOp: public Operator<CPUContext> {
176  public:
177  IndexLoadOp(const OperatorDef& operator_def, Workspace* ws)
178  : Operator(operator_def, ws),
179  skipFirstEntry_(
180  OperatorBase::GetSingleArgument<int>("skip_first_entry", 0)) {}
181 
182  bool RunOnDevice() override {
183  return DispatchHelper<IndexKeyTypes>::call(this, Input(1));
184  }
185  template <typename T>
186  bool DoRunWithType() {
187  auto& base = OperatorBase::Input<std::unique_ptr<IndexBase>>(0);
188  auto* dict = dynamic_cast_if_rtti<Index<T>*>(base.get());
189  CAFFE_ENFORCE(dict, "Wrong dictionary type given input keys.");
190  const auto& keys = Input(1);
191  const auto* keys_data = keys.data<T>();
192  auto keys_size = keys.size();
193  if (skipFirstEntry_) {
194  CAFFE_ENFORCE(keys.size() > 0);
195  ++keys_data;
196  --keys_size;
197  }
198  return dict->Load(keys_data, keys_size);
199  }
200 
201  private:
202  bool skipFirstEntry_;
203 };
204 
205 class IndexStoreOp: public Operator<CPUContext> {
206  public:
207  IndexStoreOp(const OperatorDef& operator_def, Workspace* ws)
208  : Operator(operator_def, ws) {}
209 
210  bool RunOnDevice() override {
211  auto& base = OperatorBase::Input<std::unique_ptr<IndexBase>>(0);
212  return DispatchHelper<IndexKeyTypes>::call(this, base->Type());
213  }
214 
215  template <typename T>
216  bool DoRunWithType() {
217  auto& base = OperatorBase::Input<std::unique_ptr<IndexBase>>(0);
218  auto* dict = dynamic_cast_if_rtti<Index<T>*>(base.get());
219  CAFFE_ENFORCE(dict);
220  return dict->Store(Output(0));
221  }
222 };
223 
224 class IndexFreezeOp: public Operator<CPUContext> {
225  public:
226  IndexFreezeOp(const OperatorDef& operator_def, Workspace* ws)
227  : Operator(operator_def, ws) {}
228 
229  bool RunOnDevice() override {
230  auto& base = OperatorBase::Input<std::unique_ptr<IndexBase>>(0);
231  base->Freeze();
232  return true;
233  }
234 };
235 
236 class IndexSizeOp : public Operator<CPUContext> {
237  public:
238  IndexSizeOp(const OperatorDef& operator_def, Workspace* ws)
239  : Operator(operator_def, ws) {}
240 
241  bool RunOnDevice() override {
242  auto& base = OperatorBase::Input<std::unique_ptr<IndexBase>>(0);
243  auto* out = Output(0);
244  out->Resize(std::vector<TIndex>{});
245  *out->mutable_data<TIndexValue>() = base->Size();
246  return true;
247  }
248 };
249 
250 REGISTER_CPU_OPERATOR(IntIndexCreate, IndexCreateOp<int32_t>);
251 REGISTER_CPU_OPERATOR(LongIndexCreate, IndexCreateOp<int64_t>);
252 REGISTER_CPU_OPERATOR(StringIndexCreate, IndexCreateOp<std::string>);
253 
254 REGISTER_CPU_OPERATOR(IndexGet, IndexGetOp);
255 REGISTER_CPU_OPERATOR(IndexLoad, IndexLoadOp);
256 REGISTER_CPU_OPERATOR(IndexStore, IndexStoreOp);
257 REGISTER_CPU_OPERATOR(IndexFreeze, IndexFreezeOp);
258 REGISTER_CPU_OPERATOR(IndexSize, IndexSizeOp);
259 
260 OPERATOR_SCHEMA(IntIndexCreate)
261  .NumInputs(0)
262  .NumOutputs(1)
263  .SetDoc(R"DOC(
264 Creates a dictionary that maps int32 keys to consecutive integers
265 from 1 to max_elements. Zero is reserved for unknown keys.
266 )DOC")
267  .Arg("max_elements", "Max number of elements, including the zero entry.")
268  .Output(0, "handler", "Pointer to an Index instance.");
269 
270 OPERATOR_SCHEMA(LongIndexCreate)
271  .NumInputs(0)
272  .NumOutputs(1)
273  .SetDoc(R"DOC(
274 Creates a dictionary that maps int64 keys to consecutive integers
275 from 1 to max_elements. Zero is reserved for unknown keys.
276 )DOC")
277  .Arg("max_elements", "Max number of elements, including the zero entry.")
278  .Output(0, "handler", "Pointer to an Index instance.");
279 
280 OPERATOR_SCHEMA(StringIndexCreate)
281  .NumInputs(0)
282  .NumOutputs(1)
283  .SetDoc(R"DOC(
284 Creates a dictionary that maps string keys to consecutive integers
285 from 1 to max_elements. Zero is reserved for unknown keys.
286 )DOC")
287  .Arg("max_elements", "Max number of elements, including the zero entry.")
288  .Output(0, "handle", "Pointer to an Index instance.");
289 
290 OPERATOR_SCHEMA(IndexGet)
291  .NumInputs(2)
292  .NumOutputs(1)
293  .SetDoc(R"DOC(
294 Given an index handle and a tensor of keys, return an Int tensor of same shape
295 containing the indices for each of the keys. If the index is frozen, unknown
296 entries are given index 0. Otherwise, new entries are added into the index.
297 If an insert is necessary but max_elements has been reached, fail.
298 )DOC")
299  .Input(0, "handle", "Pointer to an Index instance.")
300  .Input(1, "keys", "Tensor of keys to be looked up.")
301  .Output(0, "indices", "Indices for each of the keys.");
302 
303 OPERATOR_SCHEMA(IndexFreeze)
304  .NumInputs(1)
305  .NumOutputs(1)
306  .SetDoc(R"DOC(
307 Freezes the given index, disallowing creation of new index entries.
308 Should not be called concurrently with IndexGet.
309 )DOC")
310  .Input(0, "handle", "Pointer to an Index instance.")
311  .Output(0, "handle", "The input handle.")
312  .EnforceInplace({{0, 0}});
313 
314 OPERATOR_SCHEMA(IndexLoad)
315  .NumInputs(2)
316  .NumOutputs(1)
317  .SetDoc(R"DOC(
318 Loads the index from the given 1-D tensor. Elements in the tensor will be given
319 consecutive indexes starting at 1. Fails if tensor contains repeated elements.
320 )DOC")
321  .Input(0, "handle", "Pointer to an Index instance.")
322  .Input(1, "items", "1-D tensor with elements starting with index 1.")
323  .Output(0, "handle", "The input handle.")
324  .EnforceInplace({{0, 0}})
325  .Arg(
326  "skip_first_entry",
327  "If set, skips the first entry of the tensor. This allows "
328  "to load tensors that are aligned with an embedding, where the first "
329  "entry corresponds to the default 0 index entry.");
330 
331 OPERATOR_SCHEMA(IndexStore)
332  .NumInputs(1)
333  .NumOutputs(1)
334  .SetDoc(R"DOC(
335 Stores the keys of this index in a 1-D tensor. Since element 0 is reserved
336 for unknowns, the first element of the output tensor will be element of index 1.
337 )DOC")
338  .Input(0, "handle", "Pointer to an Index instance.")
339  .Output(0, "items", "1-D tensor with elements starting with index 1.");
340 
341 OPERATOR_SCHEMA(IndexSize)
342  .NumInputs(1)
343  .NumOutputs(1)
344  .SetDoc(R"DOC(
345 Returns the number of entries currently present in the index.
346 )DOC")
347  .Input(0, "handle", "Pointer to an Index instance.")
348  .Output(0, "items", "Scalar int64 tensor with number of entries.");
349 
350 NO_GRADIENT(IndexGetOp);
351 NO_GRADIENT(IntIndexCreate);
352 NO_GRADIENT(LongIndexCreate);
353 NO_GRADIENT(StringIndexCreate);
354 SHOULD_NOT_DO_GRADIENT(IndexFreeze);
355 SHOULD_NOT_DO_GRADIENT(IndexLoad);
356 SHOULD_NOT_DO_GRADIENT(IndexStore);
357 SHOULD_NOT_DO_GRADIENT(IndexSize);
358 
360  public:
361  IndexSerializer() {}
362  ~IndexSerializer() {}
363 
364  void Serialize(
365  const Blob& blob,
366  const string& name,
367  SerializationAcceptor acceptor) override {
368  auto& base = blob.template Get<std::unique_ptr<IndexBase>>();
369  Blob tensor_blob;
370  auto* tensor_out = tensor_blob.template GetMutable<Tensor<CPUContext>>();
371 
372  if (base->Type().Match<std::string>()) {
373  doStore<std::string>(base, tensor_out);
374  } else if (base->Type().Match<int32_t>()) {
375  doStore<int32_t>(base, tensor_out);
376  } else if (base->Type().Match<int64_t>()) {
377  doStore<int64_t>(base, tensor_out);
378  } else {
379  CAFFE_THROW("Index of this type can't be serialized.");
380  }
381 
382  CAFFE_ENFORCE(
383  tensor_out->size() <= std::numeric_limits<int32_t>::max(),
384  "Index too large to be serialized.");
385  BlobProto blob_proto;
387  ser.Serialize(
388  *tensor_out, name, blob_proto.mutable_tensor(), 0, tensor_out->size());
389  blob_proto.set_name(name);
390  blob_proto.set_type("std::unique_ptr<caffe2::IndexBase>");
391 
392  std::ostringstream os;
393  os << base->maxElements() << " " << base->isFrozen();
394  blob_proto.set_content(os.str());
395 
396  acceptor(name, blob_proto.SerializeAsString());
397  }
398 
399  private:
400  template <typename T>
401  void doStore(
402  const std::unique_ptr<IndexBase>& base,
403  Tensor<CPUContext>* tensor_out) {
404  auto* dict = dynamic_cast_if_rtti<Index<T>*>(base.get());
405  CAFFE_ENFORCE(dict, "Wrong dictionary type.");
406  dict->Store(tensor_out);
407  }
408 };
409 
411  public:
412  void Deserialize(const BlobProto& proto, Blob* blob) override {
414  Blob tensor_blob;
415  deser.Deserialize(proto, &tensor_blob);
416 
417  std::istringstream is(proto.content());
418  int64_t maxElements{std::numeric_limits<int64_t>::max()};
419  bool isFrozen{false};
420  is >> maxElements >> isFrozen;
421 
422  auto& tensor_in = tensor_blob.template Get<Tensor<CPUContext>>();
423  auto* base = blob->template GetMutable<std::unique_ptr<IndexBase>>();
424 
425  if (tensor_in.IsType<std::string>()) {
426  doLoad<std::string>(base, maxElements, tensor_in);
427  } else if (tensor_in.IsType<int32_t>()) {
428  doLoad<int32_t>(base, maxElements, tensor_in);
429  } else if (tensor_in.IsType<int64_t>()) {
430  doLoad<int64_t>(base, maxElements, tensor_in);
431  } else {
432  CAFFE_THROW("Index of this type cannot be deserialized.");
433  }
434 
435  if (isFrozen) {
436  (*base)->Freeze();
437  }
438  }
439 
440  private:
441  template <typename T>
442  void doLoad(
443  std::unique_ptr<IndexBase>* base,
444  int64_t maxElements,
445  const Tensor<CPUContext>& tensor_in) {
446  base->reset(new Index<T>(maxElements));
447  auto* dict = dynamic_cast_if_rtti<Index<T>*>(base->get());
448  dict->Load(tensor_in.data<T>(), tensor_in.size());
449  }
450 };
451 
452 CAFFE_KNOWN_TYPE(std::unique_ptr<caffe2::IndexBase>);
453 
454 REGISTER_BLOB_SERIALIZER(
455  (TypeMeta::Id<std::unique_ptr<caffe2::IndexBase>>()),
457 REGISTER_BLOB_DESERIALIZER(
458  std::unique_ptr<caffe2::IndexBase>,
460 
461 } // namespace caffe2
Blob is a general container that hosts a typed pointer.
Definition: blob.h:41
const T * data() const
Returns a typed pointer of the underlying storage.
Definition: tensor.h:500
TensorSerializer is the serializer for Tensors.
BlobDeserializerBase is an abstract class that deserializes a blob from a BlobProto or a TensorProto...
Tensor is the basic class in Caffe2 that stores a contiguous memory with its shape information...
Definition: tensor.h:109
static CAFFE2_API CaffeTypeId Id()
Returns the unique id for the given type T.
TIndex size() const
Returns the size (i.e. the number of items) of the tensor.
Definition: tensor.h:609
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:63
void Serialize(const Blob &blob, const string &name, SerializationAcceptor acceptor) override
Serializes a Blob.
void Resize(Ts...dim_source)
Resizes a tensor.
Definition: tensor.h:304
TensorDeserializer is the deserializer for Tensors.
Copyright (c) 2016-present, Facebook, Inc.
TypeMeta is a thin class that allows us to store the type of a container such as a blob...
Definition: typeid.h:104
BlobSerializerBase is an abstract class that serializes a blob to a string.