1 #ifndef CAFFE2_OPERATORS_INDEX_OPS_H_ 2 #define CAFFE2_OPERATORS_INDEX_OPS_H_ 7 #include <unordered_map> 9 #include "caffe2/core/blob_serialization.h" 10 #include "caffe2/core/operator.h" 11 #include "caffe2/core/tensor.h" 15 using IndexKeyTypes = TensorTypes<int32_t, int64_t, std::string>;
16 using int64_tValue = int64_t;
22 : maxElements_{maxElements}, meta_(type), frozen_{
false} {}
28 bool isFrozen()
const {
32 int64_t maxElements()
const {
43 std::lock_guard<std::mutex> guard(dictMutex_);
50 int64_tValue nextId_{1};
51 std::atomic<bool> frozen_{
false};
52 std::mutex dictMutex_;
57 explicit Index(int64_tValue maxElements)
58 :
IndexBase(maxElements, TypeMeta::Make<T>()) {}
60 void Get(
const T* keys, int64_tValue* values,
size_t numKeys) {
62 FrozenGet(keys, values, numKeys);
65 std::lock_guard<std::mutex> lock(dictMutex_);
66 for (
int i = 0; i < numKeys; ++i) {
67 auto it = dict_.find(keys[i]);
68 if (it != dict_.end()) {
69 values[i] = it->second;
70 }
else if (nextId_ < maxElements_) {
71 auto newValue = nextId_++;
72 dict_.insert({keys[i], newValue});
75 CAFFE_THROW(
"Dict max size reached");
80 bool Load(
const T* keys,
size_t numKeys) {
82 numKeys <= maxElements_,
83 "Cannot load index: Tensor is larger than max_elements.");
85 for (
int i = 0; i < numKeys; ++i) {
87 dict.insert({keys[i], i + 1}).second,
88 "Repeated elements found: cannot load into dictionary.");
92 std::lock_guard<std::mutex> lock(dictMutex_);
95 nextId_ = numKeys + 1;
101 std::lock_guard<std::mutex> lock(dictMutex_);
102 out->Resize(nextId_ - 1);
103 auto outData = out->template mutable_data<T>();
104 for (
const auto& entry : dict_) {
105 outData[entry.second - 1] = entry.first;
111 void FrozenGet(
const T* keys, int64_tValue* values,
size_t numKeys) {
112 for (
int i = 0; i < numKeys; ++i) {
113 auto it = dict_.find(keys[i]);
114 values[i] = it != dict_.end() ? it->second : 0;
118 std::unordered_map<T, int64_tValue> dict_;
123 #endif // CAFFE2_OPERATORS_INDEX_OPS_H_
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...