Caffe2 - C++ API
A deep learning, cross platform ML framework
index_ops.cc
1 #include "caffe2/operators/index_ops.h"
2 #include <atomic>
3 #include <limits>
4 #include <mutex>
5 #include <sstream>
6 #include <unordered_map>
7 #include <vector>
8 #include "caffe2/core/blob_serialization.h"
9 #include "caffe2/core/operator.h"
10 #include "caffe2/core/tensor.h"
11 
12 namespace caffe2 {
13 
14 // TODO(azzolini): support sizes larger than int32
15 template <class T>
16 class IndexCreateOp : public Operator<CPUContext> {
17  public:
18  template <class... Args>
19  explicit IndexCreateOp(Args&&... args)
20  : Operator(std::forward<Args>(args)...),
21  maxElements_(OperatorBase::GetSingleArgument<int>(
22  "max_elements",
23  std::numeric_limits<int>::max())) {}
24 
25  bool RunOnDevice() override {
26  *OperatorBase::Output<std::unique_ptr<IndexBase>>(0) =
27  std::unique_ptr<IndexBase>(new Index<T>(maxElements_));
28  return true;
29  }
30 
31  private:
32  int64_tValue maxElements_;
33 };
34 
35 class IndexGetOp : public Operator<CPUContext> {
36  public:
37  template <class... Args>
38  explicit IndexGetOp(Args&&... args) : Operator(std::forward<Args>(args)...) {}
39 
40  bool RunOnDevice() override {
42  }
43  template <typename T>
44  bool DoRunWithType() {
45  auto& base = OperatorBase::Input<std::unique_ptr<IndexBase>>(0);
46  auto* dict = dynamic_cast_if_rtti<Index<T>*>(base.get());
47  CAFFE_ENFORCE(dict, "Wrong dictionary type given input keys.");
48  const auto& keys = Input(1);
49 
50  auto* values = Output(0, keys.sizes(), at::dtype<int64_tValue>());
51  dict->Get(
52  keys.data<T>(),
53  values->template mutable_data<int64_tValue>(),
54  keys.numel());
55  return true;
56  }
57 };
58 
59 class IndexLoadOp : public Operator<CPUContext> {
60  public:
61  template <class... Args>
62  explicit IndexLoadOp(Args&&... args)
63  : Operator(std::forward<Args>(args)...),
64  skipFirstEntry_(
65  OperatorBase::GetSingleArgument<int>("skip_first_entry", 0)) {}
66 
67  bool RunOnDevice() override {
69  }
70  template <typename T>
71  bool DoRunWithType() {
72  auto& base = OperatorBase::Input<std::unique_ptr<IndexBase>>(0);
73  auto* dict = dynamic_cast_if_rtti<Index<T>*>(base.get());
74  CAFFE_ENFORCE(dict, "Wrong dictionary type given input keys.");
75  const auto& keys = Input(1);
76  const auto* keys_data = keys.data<T>();
77  auto keys_size = keys.numel();
78  if (skipFirstEntry_) {
79  CAFFE_ENFORCE(keys.numel() > 0);
80  ++keys_data;
81  --keys_size;
82  }
83  return dict->Load(keys_data, keys_size);
84  }
85 
86  private:
87  bool skipFirstEntry_;
88 };
89 
90 class IndexStoreOp : public Operator<CPUContext> {
91  public:
92  template <class... Args>
93  explicit IndexStoreOp(Args&&... args)
94  : Operator(std::forward<Args>(args)...) {}
95 
96  bool RunOnDevice() override {
97  auto& base = OperatorBase::Input<std::unique_ptr<IndexBase>>(0);
98  return DispatchHelper<IndexKeyTypes>::call(this, base->Type());
99  }
100 
101  template <typename T>
102  bool DoRunWithType() {
103  auto& base = OperatorBase::Input<std::unique_ptr<IndexBase>>(0);
104  auto* dict = dynamic_cast_if_rtti<Index<T>*>(base.get());
105  CAFFE_ENFORCE(dict);
106  return dict->Store(Output(0));
107  }
108 };
109 
110 class IndexFreezeOp : public Operator<CPUContext> {
111  public:
112  template <class... Args>
113  explicit IndexFreezeOp(Args&&... args)
114  : Operator(std::forward<Args>(args)...) {}
115 
116  bool RunOnDevice() override {
117  auto& base = OperatorBase::Input<std::unique_ptr<IndexBase>>(0);
118  base->Freeze();
119  return true;
120  }
121 };
122 
123 class IndexSizeOp : public Operator<CPUContext> {
124  public:
125  template <class... Args>
126  explicit IndexSizeOp(Args&&... args)
127  : Operator(std::forward<Args>(args)...) {}
128 
129  bool RunOnDevice() override {
130  auto& base = OperatorBase::Input<std::unique_ptr<IndexBase>>(0);
131 
132  auto* out = Output(0, std::vector<int64_t>{}, at::dtype<int64_tValue>());
133  *out->template mutable_data<int64_tValue>() = base->Size();
134  return true;
135  }
136 };
137 
138 REGISTER_CPU_OPERATOR(IntIndexCreate, IndexCreateOp<int32_t>);
139 REGISTER_CPU_OPERATOR(LongIndexCreate, IndexCreateOp<int64_t>);
140 REGISTER_CPU_OPERATOR(StringIndexCreate, IndexCreateOp<std::string>);
141 
142 REGISTER_CPU_OPERATOR(IndexGet, IndexGetOp);
143 REGISTER_CPU_OPERATOR(IndexLoad, IndexLoadOp);
144 REGISTER_CPU_OPERATOR(IndexStore, IndexStoreOp);
145 REGISTER_CPU_OPERATOR(IndexFreeze, IndexFreezeOp);
146 REGISTER_CPU_OPERATOR(IndexSize, IndexSizeOp);
147 
148 OPERATOR_SCHEMA(IntIndexCreate)
149  .NumInputs(0)
150  .NumOutputs(1)
151  .SetDoc(R"DOC(
152 Creates a dictionary that maps int32 keys to consecutive integers
153 from 1 to max_elements. Zero is reserved for unknown keys.
154 )DOC")
155  .Arg("max_elements", "Max number of elements, including the zero entry.")
156  .Output(0, "handler", "Pointer to an Index instance.")
157  .ScalarType(TensorProto_DataType_UNDEFINED);
158 
159 OPERATOR_SCHEMA(LongIndexCreate)
160  .NumInputs(0)
161  .NumOutputs(1)
162  .SetDoc(R"DOC(
163 Creates a dictionary that maps int64 keys to consecutive integers
164 from 1 to max_elements. Zero is reserved for unknown keys.
165 )DOC")
166  .Arg("max_elements", "Max number of elements, including the zero entry.")
167  .Output(0, "handler", "Pointer to an Index instance.")
168  .ScalarType(TensorProto_DataType_UNDEFINED);
169 
170 OPERATOR_SCHEMA(StringIndexCreate)
171  .NumInputs(0)
172  .NumOutputs(1)
173  .SetDoc(R"DOC(
174 Creates a dictionary that maps string keys to consecutive integers
175 from 1 to max_elements. Zero is reserved for unknown keys.
176 )DOC")
177  .Arg("max_elements", "Max number of elements, including the zero entry.")
178  .Output(0, "handle", "Pointer to an Index instance.")
179  .ScalarType(TensorProto_DataType_UNDEFINED);
180 
181 OPERATOR_SCHEMA(IndexGet)
182  .NumInputs(2)
183  .NumOutputs(1)
184  .SetDoc(R"DOC(
185 Given an index handle and a tensor of keys, return an Int tensor of same shape
186 containing the indices for each of the keys. If the index is frozen, unknown
187 entries are given index 0. Otherwise, new entries are added into the index.
188 If an insert is necessary but max_elements has been reached, fail.
189 )DOC")
190  .Input(0, "handle", "Pointer to an Index instance.")
191  .Input(1, "keys", "Tensor of keys to be looked up.")
192  .Output(0, "indices", "Indices for each of the keys.")
193  .ScalarType(TensorProto::INT64);
194 
195 OPERATOR_SCHEMA(IndexFreeze)
196  .NumInputs(1)
197  .NumOutputs(1)
198  .SetDoc(R"DOC(
199 Freezes the given index, disallowing creation of new index entries.
200 Should not be called concurrently with IndexGet.
201 )DOC")
202  .Input(0, "handle", "Pointer to an Index instance.")
203  .Output(0, "handle", "The input handle.")
204  .EnforceInplace({{0, 0}})
205  .ScalarType(TensorProto_DataType_UNDEFINED);
206 
207 OPERATOR_SCHEMA(IndexLoad)
208  .NumInputs(2)
209  .NumOutputs(1)
210  .SetDoc(R"DOC(
211 Loads the index from the given 1-D tensor. Elements in the tensor will be given
212 consecutive indexes starting at 1. Fails if tensor contains repeated elements.
213 )DOC")
214  .Input(0, "handle", "Pointer to an Index instance.")
215  .Input(1, "items", "1-D tensor with elements starting with index 1.")
216  .Output(0, "handle", "The input handle.")
217  .EnforceInplace({{0, 0}})
218  .Arg(
219  "skip_first_entry",
220  "If set, skips the first entry of the tensor. This allows "
221  "to load tensors that are aligned with an embedding, where the first "
222  "entry corresponds to the default 0 index entry.")
223  .ScalarType(TensorProto_DataType_UNDEFINED);
224 
225 OPERATOR_SCHEMA(IndexStore)
226  .NumInputs(1)
227  .NumOutputs(1)
228  .SetDoc(R"DOC(
229 Stores the keys of this index in a 1-D tensor. Since element 0 is reserved
230 for unknowns, the first element of the output tensor will be element of index 1.
231 )DOC")
232  .Input(0, "handle", "Pointer to an Index instance.")
233  .Output(0, "items", "1-D tensor with elements starting with index 1.");
234 
235 OPERATOR_SCHEMA(IndexSize)
236  .NumInputs(1)
237  .NumOutputs(1)
238  .SetDoc(R"DOC(
239 Returns the number of entries currently present in the index.
240 )DOC")
241  .Input(0, "handle", "Pointer to an Index instance.")
242  .Output(0, "items", "Scalar int64 tensor with number of entries.");
243 
244 NO_GRADIENT(IndexGetOp);
245 NO_GRADIENT(IntIndexCreate);
246 NO_GRADIENT(LongIndexCreate);
247 NO_GRADIENT(StringIndexCreate);
248 SHOULD_NOT_DO_GRADIENT(IndexFreeze);
249 SHOULD_NOT_DO_GRADIENT(IndexLoad);
250 SHOULD_NOT_DO_GRADIENT(IndexStore);
251 SHOULD_NOT_DO_GRADIENT(IndexSize);
252 
254  public:
255  IndexSerializer() {}
256  ~IndexSerializer() override {}
257 
258  void Serialize(
259  const void* pointer,
260  TypeMeta typeMeta,
261  const string& name,
262  SerializationAcceptor acceptor) override {
263  CAFFE_ENFORCE(typeMeta.Match<std::unique_ptr<IndexBase>>());
264  const auto& base = *static_cast<const std::unique_ptr<IndexBase>*>(pointer);
265  Blob tensor_blob;
266  auto* tensor_out = BlobGetMutableTensor(&tensor_blob, CPU);
267 
268  if (base->Type().Match<std::string>()) {
269  doStore<std::string>(base, tensor_out);
270  } else if (base->Type().Match<int32_t>()) {
271  doStore<int32_t>(base, tensor_out);
272  } else if (base->Type().Match<int64_t>()) {
273  doStore<int64_t>(base, tensor_out);
274  } else {
275  CAFFE_THROW("Index of this type can't be serialized.");
276  }
277 
278  CAFFE_ENFORCE(
279  tensor_out->numel() <= std::numeric_limits<int32_t>::max(),
280  "Index too large to be serialized.");
281  BlobProto blob_proto;
282  TensorSerializer ser;
283  ser.Serialize(
284  *tensor_out, name, blob_proto.mutable_tensor(), 0, tensor_out->numel());
285  blob_proto.set_name(name);
286  blob_proto.set_type("std::unique_ptr<caffe2::IndexBase>");
287 
288  std::ostringstream os;
289  os << base->maxElements() << " " << base->isFrozen();
290  blob_proto.set_content(os.str());
291 
292  acceptor(name, SerializeBlobProtoAsString_EnforceCheck(blob_proto));
293  }
294 
295  private:
296  template <typename T>
297  void doStore(const std::unique_ptr<IndexBase>& base, Tensor* tensor_out) {
298  auto* dict = dynamic_cast_if_rtti<Index<T>*>(base.get());
299  CAFFE_ENFORCE(dict, "Wrong dictionary type.");
300  dict->Store(tensor_out);
301  }
302 };
303 
305  public:
306  void Deserialize(const BlobProto& proto, Blob* blob) override {
307  TensorDeserializer deser;
308  Blob tensor_blob;
309  deser.Deserialize(proto, &tensor_blob);
310 
311  std::istringstream is(proto.content());
312  int64_t maxElements{std::numeric_limits<int64_t>::max()};
313  bool isFrozen{false};
314  is >> maxElements >> isFrozen;
315 
316  auto& tensor_in = tensor_blob.template Get<Tensor>();
317  auto* base = blob->template GetMutable<std::unique_ptr<IndexBase>>();
318 
319  if (tensor_in.IsType<std::string>()) {
320  doLoad<std::string>(base, maxElements, tensor_in);
321  } else if (tensor_in.IsType<int32_t>()) {
322  doLoad<int32_t>(base, maxElements, tensor_in);
323  } else if (tensor_in.IsType<int64_t>()) {
324  doLoad<int64_t>(base, maxElements, tensor_in);
325  } else {
326  CAFFE_THROW("Index of this type cannot be deserialized.");
327  }
328 
329  if (isFrozen) {
330  (*base)->Freeze();
331  }
332  }
333 
334  private:
335  template <typename T>
336  void doLoad(
337  std::unique_ptr<IndexBase>* base,
338  int64_t maxElements,
339  const Tensor& tensor_in) {
340  base->reset(new Index<T>(maxElements));
341  auto* dict = dynamic_cast_if_rtti<Index<T>*>(base->get());
342  dict->Load(tensor_in.data<T>(), tensor_in.numel());
343  }
344 };
345 
346 CAFFE_KNOWN_TYPE(std::unique_ptr<caffe2::IndexBase>);
347 
348 REGISTER_BLOB_SERIALIZER(
349  (TypeMeta::Id<std::unique_ptr<caffe2::IndexBase>>()),
351 REGISTER_BLOB_DESERIALIZER(
352  std::unique_ptr<caffe2::IndexBase>,
354 
355 } // namespace caffe2
Blob is a general container that hosts a typed pointer.
Definition: blob.h:24
TensorSerializer is the serializer for Tensors.
BlobDeserializerBase is an abstract class that deserializes a blob from a BlobProto or a TensorProto...
void Serialize(const void *pointer, TypeMeta typeMeta, const string &name, SerializationAcceptor acceptor) override
Serializes a Blob.
const Tensor & Input(int idx, DeviceType type=CPUContext::GetDeviceType())
Retrieve a non-owning reference to the input at position 'idx' for this operator. ...
Definition: operator.h:702
TensorDeserializer is the deserializer for Tensors.
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13
TypeMeta is a thin class that allows us to store the type of a container such as a blob...
Definition: typeid.h:324
BlobSerializerBase is an abstract class that serializes a blob to a string.