// Includes followed by a fragment of the index-creation operator (class
// header and several lines are missing from this excerpt); presumably the
// implementation behind the Int/Long/StringIndexCreate schemas. Templated on
// the key type T.
1 #include "caffe2/operators/index_ops.h" 6 #include <unordered_map> 8 #include "caffe2/core/blob_serialization.h" 9 #include "caffe2/core/operator.h" 10 #include "caffe2/core/tensor.h" 18 template <
class... Args>
20 :
Operator(std::forward<Args>(args)...),
// The element cap is read as an `int` argument (the argument-name line is not
// visible; the schemas document it as "max_elements") and defaults to INT_MAX.
// NOTE(review): it is stored in an int64_tValue member, but because it is
// parsed as `int`, caps larger than INT_MAX cannot be requested this way.
21 maxElements_(OperatorBase::GetSingleArgument<int>(
23 std::numeric_limits<int>::max())) {}
// Constructs a new Index<T> bounded by maxElements_ and assigns it, type-erased
// as std::unique_ptr<IndexBase>, to output 0 (replacing any previous handle).
25 bool RunOnDevice()
override {
26 *OperatorBase::Output<std::unique_ptr<IndexBase>>(0) =
27 std::unique_ptr<IndexBase>(
new Index<T>(maxElements_));
// Cap passed to the Index<T> constructor.
32 int64_tValue maxElements_;
// Index lookup operator fragment (class header and some lines missing);
// presumably the implementation of the IndexGet schema.
37 template <
class... Args>
// The body of RunOnDevice is not visible; presumably it dispatches on the key
// tensor's element type to DoRunWithType<T>.
40 bool RunOnDevice()
override {
// Looks up the keys in input 1 in the Index<T> held by input 0. Output 0 is an
// int64_tValue tensor with the same shape as the keys tensor; the lookup call
// that fills it (original lines 51-52) is only partially visible here.
44 bool DoRunWithType() {
45 auto& base = OperatorBase::Input<std::unique_ptr<IndexBase>>(0);
// The handle is type-erased; dynamic_cast_if_rtti presumably yields nullptr
// (with RTTI enabled) when the index was created for a key type other than T,
// which the enforce below turns into an error.
46 auto* dict = dynamic_cast_if_rtti<Index<T>*>(base.get());
47 CAFFE_ENFORCE(dict,
"Wrong dictionary type given input keys.");
48 const auto& keys =
Input(1);
50 auto* values = Output(0, keys.sizes(), at::dtype<int64_tValue>());
53 values->template mutable_data<int64_tValue>(),
// Index loading operator fragment (class header and some lines missing);
// presumably the implementation of the IndexLoad schema.
61 template <
class... Args>
63 :
Operator(std::forward<Args>(args)...),
// Reads the "skip_first_entry" int argument (default 0). The member it
// initializes is on a line not visible here (presumably skipFirstEntry_).
65 OperatorBase::GetSingleArgument<int>(
"skip_first_entry", 0)) {}
67 bool RunOnDevice()
override {
// Loads the keys of input 1 into the Index<T> held by input 0 and returns the
// result of Index<T>::Load as the op's result.
71 bool DoRunWithType() {
72 auto& base = OperatorBase::Input<std::unique_ptr<IndexBase>>(0);
73 auto* dict = dynamic_cast_if_rtti<Index<T>*>(base.get());
74 CAFFE_ENFORCE(dict,
"Wrong dictionary type given input keys.");
75 const auto& keys =
Input(1);
76 const auto* keys_data = keys.data<
T>();
77 auto keys_size = keys.numel();
// Skipping the first entry requires a non-empty keys tensor. The pointer/size
// adjustment itself (original lines 80-82) is not visible in this excerpt.
78 if (skipFirstEntry_) {
79 CAFFE_ENFORCE(keys.numel() > 0);
83 return dict->Load(keys_data, keys_size);
// Index storing operator fragment (class header and some lines missing);
// presumably the implementation of the IndexStore schema.
92 template <
class... Args>
94 :
Operator(std::forward<Args>(args)...) {}
// Reads the handle from input 0; the type dispatch that follows is not
// visible in this excerpt.
96 bool RunOnDevice()
override {
97 auto& base = OperatorBase::Input<std::unique_ptr<IndexBase>>(0);
// Writes the keys of the Index<T> held by input 0 into output 0 via
// Index<T>::Store and returns its result.
// NOTE(review): no null check on `dict` is visible before it is dereferenced;
// confirm the missing original line 105 enforces it, as the sibling ops do.
101 template <
typename T>
102 bool DoRunWithType() {
103 auto& base = OperatorBase::Input<std::unique_ptr<IndexBase>>(0);
104 auto* dict = dynamic_cast_if_rtti<Index<T>*>(base.get());
106 return dict->Store(Output(0));
// Index freezing operator fragment (class header and most of the body are
// missing); presumably the implementation of the IndexFreeze schema, which
// requires output 0 to alias input 0. Only the handle read is visible here.
112 template <
class... Args>
114 :
Operator(std::forward<Args>(args)...) {}
116 bool RunOnDevice()
override {
117 auto& base = OperatorBase::Input<std::unique_ptr<IndexBase>>(0);
// Index size operator fragment (class header and some lines missing);
// presumably the implementation of the IndexSize schema.
125 template <
class... Args>
127 :
Operator(std::forward<Args>(args)...) {}
// Writes base->Size() for the handle in input 0 into output 0, a scalar
// (zero-dimensional, shape {}) int64_tValue tensor.
129 bool RunOnDevice()
override {
130 auto& base = OperatorBase::Input<std::unique_ptr<IndexBase>>(0);
132 auto* out = Output(0, std::vector<int64_t>{}, at::dtype<int64_tValue>());
133 *out->template mutable_data<int64_tValue>() = base->Size();
// Schema for the int32-keyed index constructor.
// NOTE(review): output 0 is named "handler" here and in LongIndexCreate, but
// "handle" in StringIndexCreate and in every op that consumes the index.
148 OPERATOR_SCHEMA(IntIndexCreate)
152 Creates a dictionary that maps int32 keys to consecutive integers 153 from 1 to max_elements. Zero is reserved for unknown keys. 155 .Arg("max_elements",
"Max number of elements, including the zero entry.")
156 .Output(0,
"handler",
"Pointer to an Index instance.")
157 .ScalarType(TensorProto_DataType_UNDEFINED);
// Schema for the int64-keyed index constructor (same output-name note as above).
159 OPERATOR_SCHEMA(LongIndexCreate)
163 Creates a dictionary that maps int64 keys to consecutive integers 164 from 1 to max_elements. Zero is reserved for unknown keys. 166 .Arg("max_elements",
"Max number of elements, including the zero entry.")
167 .Output(0,
"handler",
"Pointer to an Index instance.")
168 .ScalarType(TensorProto_DataType_UNDEFINED);
// Schema for the string-keyed index constructor.
170 OPERATOR_SCHEMA(StringIndexCreate)
174 Creates a dictionary that maps string keys to consecutive integers 175 from 1 to max_elements. Zero is reserved for unknown keys. 177 .Arg("max_elements",
"Max number of elements, including the zero entry.")
178 .Output(0,
"handle",
"Pointer to an Index instance.")
179 .ScalarType(TensorProto_DataType_UNDEFINED);
// Lookup schema: (handle, keys) -> indices, declared with INT64 output type.
181 OPERATOR_SCHEMA(IndexGet)
185 Given an index handle and a tensor of keys, return an Int tensor of same shape 186 containing the indices for each of the keys. If the index is frozen, unknown 187 entries are given index 0. Otherwise, new entries are added into the index. 188 If an insert is necessary but max_elements has been reached, fail. 190 .Input(0, "handle",
"Pointer to an Index instance.")
191 .Input(1,
"keys",
"Tensor of keys to be looked up.")
192 .Output(0,
"indices",
"Indices for each of the keys.")
193 .ScalarType(TensorProto::INT64);
// Freeze schema: output 0 must be the same blob as input 0 (EnforceInplace).
195 OPERATOR_SCHEMA(IndexFreeze)
199 Freezes the given index, disallowing creation of new index entries. 200 Should not be called concurrently with IndexGet. 202 .Input(0, "handle",
"Pointer to an Index instance.")
203 .Output(0,
"handle",
"The input handle.")
204 .EnforceInplace({{0, 0}})
205 .ScalarType(TensorProto_DataType_UNDEFINED);
// Load schema: in-place on the handle; the opening of the "skip_first_entry"
// argument declaration is on lines missing from this excerpt.
207 OPERATOR_SCHEMA(IndexLoad)
211 Loads the index from the given 1-D tensor. Elements in the tensor will be given 212 consecutive indexes starting at 1. Fails if tensor contains repeated elements. 214 .Input(0, "handle",
"Pointer to an Index instance.")
215 .Input(1,
"items",
"1-D tensor with elements starting with index 1.")
216 .Output(0,
"handle",
"The input handle.")
217 .EnforceInplace({{0, 0}})
220 "If set, skips the first entry of the tensor. This allows " 221 "to load tensors that are aligned with an embedding, where the first " 222 "entry corresponds to the default 0 index entry.")
223 .ScalarType(TensorProto_DataType_UNDEFINED);
// Store schema: handle -> 1-D tensor of keys. No ScalarType is declared in the
// visible lines.
225 OPERATOR_SCHEMA(IndexStore)
229 Stores the keys of this index in a 1-D tensor. Since element 0 is reserved 230 for unknowns, the first element of the output tensor will be element of index 1. 232 .Input(0, "handle",
"Pointer to an Index instance.")
233 .Output(0,
"items",
"1-D tensor with elements starting with index 1.");
// Size schema: handle -> scalar count.
// NOTE(review): the output is named "items" although it is a scalar count, not
// a tensor of keys; the name is kept because callers may refer to it.
235 OPERATOR_SCHEMA(IndexSize)
239 Returns the number of entries currently present in the index. 241 .Input(0, "handle",
"Pointer to an Index instance.")
242 .Output(0,
"items",
"Scalar int64 tensor with number of entries.");
// Gradient registrations: the three create ops are registered with
// NO_GRADIENT, and the handle-manipulating ops with SHOULD_NOT_DO_GRADIENT.
// IndexGet does not appear in the visible lines.
245 NO_GRADIENT(IntIndexCreate);
246 NO_GRADIENT(LongIndexCreate);
247 NO_GRADIENT(StringIndexCreate);
248 SHOULD_NOT_DO_GRADIENT(IndexFreeze);
249 SHOULD_NOT_DO_GRADIENT(IndexLoad);
250 SHOULD_NOT_DO_GRADIENT(IndexStore);
251 SHOULD_NOT_DO_GRADIENT(IndexSize);
// Serialize fragment of the index blob serializer (class header and the start
// of the signature are not visible). Serializes a std::unique_ptr<IndexBase>:
// - the index keys are dumped into a temporary CPU tensor and written into
//   the BlobProto's tensor field;
// - maxElements and isFrozen are stored as text in the proto's content field;
// - the result is handed to `acceptor`.
262 SerializationAcceptor acceptor)
override {
263 CAFFE_ENFORCE(typeMeta.Match<std::unique_ptr<IndexBase>>());
264 const auto& base = *
static_cast<const std::unique_ptr<IndexBase>*
>(pointer);
266 auto* tensor_out = BlobGetMutableTensor(&tensor_blob, CPU);
// Dispatch on the index's key type; only string, int32 and int64 keys can be
// serialized, anything else throws.
268 if (base->Type().Match<std::string>()) {
269 doStore<std::string>(base, tensor_out);
270 }
else if (base->Type().Match<int32_t>()) {
271 doStore<int32_t>(base, tensor_out);
272 }
else if (base->Type().Match<int64_t>()) {
273 doStore<int64_t>(base, tensor_out);
275 CAFFE_THROW(
"Index of this type can't be serialized.");
// Indexes with more than INT32_MAX keys are rejected (the opening of this
// enforce is on a line not visible here).
279 tensor_out->numel() <= std::numeric_limits<int32_t>::max(),
280 "Index too large to be serialized.");
281 BlobProto blob_proto;
284 *tensor_out, name, blob_proto.mutable_tensor(), 0, tensor_out->numel());
285 blob_proto.set_name(name);
286 blob_proto.set_type(
"std::unique_ptr<caffe2::IndexBase>");
// Content is "<maxElements> <isFrozen>". Without std::boolalpha the bool is
// written as 0 or 1.
288 std::ostringstream os;
289 os << base->maxElements() <<
" " << base->isFrozen();
290 blob_proto.set_content(os.str());
292 acceptor(name, SerializeBlobProtoAsString_EnforceCheck(blob_proto));
// Copies the keys of the Index<T> behind `base` into tensor_out via
// Index<T>::Store, after enforcing that the handle really holds an Index<T>.
296 template <
typename T>
297 void doStore(
const std::unique_ptr<IndexBase>& base,
Tensor* tensor_out) {
298 auto* dict = dynamic_cast_if_rtti<Index<T>*>(base.get());
299 CAFFE_ENFORCE(dict,
"Wrong dictionary type.");
300 dict->Store(tensor_out);
// Deserialize fragment of the index blob deserializer (class header and some
// lines not visible). Rebuilds a std::unique_ptr<IndexBase> from a BlobProto:
// the tensor part is deserialized into a temporary blob, then maxElements and
// isFrozen are parsed from the proto's text content.
306 void Deserialize(
const BlobProto& proto,
Blob* blob)
override {
309 deser.Deserialize(proto, &tensor_blob);
// Empty content: the first extraction fails before converting anything, so
// both defaults (INT64_MAX, not frozen) are kept.
// NOTE(review): non-empty but non-numeric content makes the failed integer
// extraction store 0 into maxElements (C++11 num_get semantics), and isFrozen
// is then not read at all.
311 std::istringstream is(proto.content());
312 int64_t maxElements{std::numeric_limits<int64_t>::max()};
313 bool isFrozen{
false};
314 is >> maxElements >> isFrozen;
316 auto& tensor_in = tensor_blob.template Get<Tensor>();
317 auto* base = blob->template GetMutable<std::unique_ptr<IndexBase>>();
// The key type is taken from the deserialized tensor's element type; only
// string, int32 and int64 are supported.
319 if (tensor_in.IsType<std::string>()) {
320 doLoad<std::string>(base, maxElements, tensor_in);
321 }
else if (tensor_in.IsType<int32_t>()) {
322 doLoad<int32_t>(base, maxElements, tensor_in);
323 }
else if (tensor_in.IsType<int64_t>()) {
324 doLoad<int64_t>(base, maxElements, tensor_in);
326 CAFFE_THROW(
"Index of this type cannot be deserialized.");
// NOTE(review): isFrozen is parsed above, but its use (original lines 327-334)
// is not visible in this excerpt.
//
// doLoad: replaces *base with a new Index<T> bounded by maxElements and loads
// the tensor's elements into it.
// NOTE(review): in the visible lines the result of Load is discarded, whereas
// the IndexLoad op returns it as the op result; a failed load may therefore go
// unreported here.
335 template <
typename T>
337 std::unique_ptr<IndexBase>* base,
339 const Tensor& tensor_in) {
340 base->reset(
new Index<T>(maxElements));
341 auto* dict = dynamic_cast_if_rtti<Index<T>*>(base->get());
342 dict->Load(tensor_in.data<
T>(), tensor_in.numel());
// Registers std::unique_ptr<caffe2::IndexBase> as a known blob type and
// associates a blob serializer and deserializer with it. The remaining
// registration arguments are missing from this excerpt.
346 CAFFE_KNOWN_TYPE(std::unique_ptr<caffe2::IndexBase>);
348 REGISTER_BLOB_SERIALIZER(
349 (TypeMeta::Id<std::unique_ptr<caffe2::IndexBase>>()),
351 REGISTER_BLOB_DESERIALIZER(
352 std::unique_ptr<caffe2::IndexBase>,
Blob is a general container that hosts a typed pointer.
TensorSerializer is the serializer for Tensors.
BlobDeserializerBase is an abstract class that deserializes a blob from a BlobProto or a TensorProto...
void Serialize(const void *pointer, TypeMeta typeMeta, const string &name, SerializationAcceptor acceptor) override
Serializes a Blob.
const Tensor & Input(int idx, DeviceType type=CPUContext::GetDeviceType())
Retrieve a non-owning reference to the input at position 'idx' for this operator. ...
TensorDeserializer is the deserializer for Tensors.
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
BlobSerializerBase is an abstract class that serializes a blob to a string.