Caffe2 - C++ API
A deep-learning, cross-platform ML framework
index_ops.h
1 #ifndef CAFFE2_OPERATORS_INDEX_OPS_H_
2 #define CAFFE2_OPERATORS_INDEX_OPS_H_
3 
#include <atomic>
#include <cstdint>
#include <limits>
#include <mutex>
#include <sstream>
#include <unordered_map>
#include <vector>

#include "caffe2/core/blob_serialization.h"
#include "caffe2/core/operator.h"
#include "caffe2/core/tensor.h"
12 
13 namespace caffe2 {
14 namespace {
15 using IndexKeyTypes = TensorTypes<int32_t, int64_t, std::string>;
16 using int64_tValue = int64_t;
17 } // namespace
18 
19 struct IndexBase {
20  public:
21  IndexBase(int64_tValue maxElements, const TypeMeta& type)
22  : maxElements_{maxElements}, meta_(type), frozen_{false} {}
23 
24  void Freeze() {
25  frozen_ = true;
26  }
27 
28  bool isFrozen() const {
29  return frozen_;
30  }
31 
32  int64_t maxElements() const {
33  return maxElements_;
34  }
35 
36  virtual ~IndexBase() {}
37 
38  const TypeMeta& Type() const {
39  return meta_;
40  }
41 
42  int64_tValue Size() {
43  std::lock_guard<std::mutex> guard(dictMutex_);
44  return nextId_;
45  }
46 
47  protected:
48  int64_t maxElements_;
49  TypeMeta meta_;
50  int64_tValue nextId_{1}; // guarded by dictMutex_
51  std::atomic<bool> frozen_{false};
52  std::mutex dictMutex_;
53 };
54 
55 template <typename T>
56 struct Index : IndexBase {
57  explicit Index(int64_tValue maxElements)
58  : IndexBase(maxElements, TypeMeta::Make<T>()) {}
59 
60  void Get(const T* keys, int64_tValue* values, size_t numKeys) {
61  if (frozen_) {
62  FrozenGet(keys, values, numKeys);
63  return;
64  }
65  std::lock_guard<std::mutex> lock(dictMutex_);
66  for (int i = 0; i < numKeys; ++i) {
67  auto it = dict_.find(keys[i]);
68  if (it != dict_.end()) {
69  values[i] = it->second;
70  } else if (nextId_ < maxElements_) {
71  auto newValue = nextId_++;
72  dict_.insert({keys[i], newValue});
73  values[i] = newValue;
74  } else {
75  CAFFE_THROW("Dict max size reached");
76  }
77  }
78  }
79 
80  bool Load(const T* keys, size_t numKeys) {
81  CAFFE_ENFORCE(
82  numKeys <= maxElements_,
83  "Cannot load index: Tensor is larger than max_elements.");
84  decltype(dict_) dict;
85  for (int i = 0; i < numKeys; ++i) {
86  CAFFE_ENFORCE(
87  dict.insert({keys[i], i + 1}).second,
88  "Repeated elements found: cannot load into dictionary.");
89  }
90  // assume no `get` is inflight while this happens
91  {
92  std::lock_guard<std::mutex> lock(dictMutex_);
93  // let the old dict get destructed outside of the lock
94  dict_.swap(dict);
95  nextId_ = numKeys + 1;
96  }
97  return true;
98  }
99 
100  bool Store(Tensor* out) {
101  std::lock_guard<std::mutex> lock(dictMutex_);
102  out->Resize(nextId_ - 1);
103  auto outData = out->template mutable_data<T>();
104  for (const auto& entry : dict_) {
105  outData[entry.second - 1] = entry.first;
106  }
107  return true;
108  }
109 
110  private:
111  void FrozenGet(const T* keys, int64_tValue* values, size_t numKeys) {
112  for (int i = 0; i < numKeys; ++i) {
113  auto it = dict_.find(keys[i]);
114  values[i] = it != dict_.end() ? it->second : 0;
115  }
116  }
117 
118  std::unordered_map<T, int64_tValue> dict_;
119 };
120 
121 } // namespace caffe2
122 
123 #endif // CAFFE2_OPERATORS_INDEX_OPS_H_
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13
TypeMeta is a thin class that allows us to store the type of a container such as a blob...
Definition: typeid.h:324