// NOTE(review): this listing has original source line numbers inlined at the
// start of many lines (e.g. "1", "20", "25") and omits lines wherever that
// numbering jumps, so the enclosing namespace/struct declarations are not
// visible here.
// Header guard and includes for the map operators, followed by per-type
// name-string constants.
1 #ifndef CAFFE2_OPERATORS_MAP_OPS_H_ 2 #define CAFFE2_OPERATORS_MAP_OPS_H_ 8 #include <unordered_map> 12 #include "caffe2/core/blob_serialization.h" 13 #include "caffe2/core/context.h" 14 #include "caffe2/core/operator.h" 20 static constexpr
// Name string "unknown" -- presumably the fallback for types without a
// dedicated entry below (the enclosing declaration is not visible).
const char* name =
"unknown";
// Name string "int64_t" -- presumably the entry for the int64_t type
// (enclosing specialization not visible).
25 static constexpr
const char* name =
"int64_t";
// Name string "int32_t" -- presumably the entry for the int32_t type
// (enclosing specialization not visible).
30 static constexpr
const char* name =
"int32_t";
// Traits parameterized on a key type and a value type; presumably the body of
// MapTypeTraits<KEY_T, VALUE_T> (referenced below as MapTypeTraits<...>::MapType).
// The struct header (source line 34) is not visible in this listing.
33 template <
typename KEY_T,
typename VALUE_T>
// Map representation used for this key/value pairing: std::unordered_map.
35 using MapType = std::unordered_map<KEY_T, VALUE_T>;
// Returns a string name for MapType. Its body (lines 37-41) is not visible;
// presumably it is built from the per-type name strings -- TODO confirm.
36 static string MapTypeName() {
42 using MapType64To64 = MapTypeTraits<int64_t, int64_t>::MapType;
43 using MapType64To32 = MapTypeTraits<int64_t, int32_t>::MapType;
44 using MapType32To32 = MapTypeTraits<int32_t, int32_t>::MapType;
45 using MapType32To64 = MapTypeTraits<int32_t, int64_t>::MapType;
// Operator template over a device Context. The class header (source line 48)
// and the constructor forwarding Args... (lines 52-55) are not visible here;
// the error text later in this class names the operator "CreateMap".
47 template <
class Context>
50 USE_OPERATOR_CONTEXT_FUNCTIONS;
51 template <
class... Args>
// Reads the "key_dtype" argument (default INT32) and dispatches on the
// TypeMeta for that dtype. The head of the dispatch expression (line 61) is
// not visible; presumably it selects DoRunWithType<KEY_T> -- TODO confirm.
56 bool RunOnDevice()
override {
57 TensorProto::DataType key_dtype =
58 static_cast<TensorProto::DataType
>(this->
template GetSingleArgument<int>(
59 "key_dtype", TensorProto_DataType_INT32));
62 this, DataTypeToTypeMeta(key_dtype));
// First-stage handler, instantiated with the resolved key type KEY_T.
// Reads the "value_dtype" argument (default INT32) and dispatches again with
// KEY_T fixed (the `KEY_T>::call(` tail below); the head of that dispatch
// expression (lines 70-72) is not visible in this listing.
65 template <
typename KEY_T>
66 bool DoRunWithType() {
67 TensorProto::DataType value_dtype =
68 static_cast<TensorProto::DataType
>(this->
template GetSingleArgument<int>(
69 "value_dtype", TensorProto_DataType_INT32));
73 KEY_T>::call(
this, DataTypeToTypeMeta(value_dtype));
// Second-stage handler with both key and value types resolved: obtains the
// MAP output typed as MapTypeTraits<KEY_T, VALUE_T>::MapType. The statement
// continues on lines 80-83, which are not visible; presumably it resets the
// map and returns -- TODO confirm.
76 template <
typename KEY_T,
typename VALUE_T>
77 bool DoRunWithType2() {
79 this->
template Output<typename MapTypeTraits<KEY_T, VALUE_T>::MapType>(MAP)
84 template <
typename KEY_T>
85 bool DoRunWithOtherType2() {
86 TensorProto::DataType value_dtype =
87 static_cast<TensorProto::DataType
>(this->
template GetSingleArgument<int>(
88 "value_dtype", TensorProto_DataType_INT32));
91 "CreateMap is not implemented on value tensor of type ",
92 DataTypeToTypeMeta(value_dtype).name(),
93 "consider adding it as a type in the DispatchHelper list");
// KeyValueToMapOp (named by its destructor below): per DoRunWithType2 further
// down, it fills a map from paired KEYS and VALUES input tensors. The class
// header (lines 100-101) and constructor (lines 104-105) are not visible.
99 template <
class Context>
102 USE_OPERATOR_CONTEXT_FUNCTIONS;
103 template <
class... Args>
// Empty user-declared destructor; it does no work of its own.
106 ~KeyValueToMapOp() {}
// Body (lines 109-112) not visible; presumably dispatches on the KEYS
// tensor's dtype to DoRunWithType<KEY_T> -- TODO confirm.
108 bool RunOnDevice()
override {
// Second-stage dispatch with KEY_T fixed, keyed on the VALUES input
// (presumably its dtype selects VALUE_T). The head of the dispatch expression
// (lines 115-116) is not visible.
113 template <
typename KEY_T>
114 bool DoRunWithType() {
117 KEY_T>::call(
this, Input(VALUES));
120 template <
typename KEY_T,
typename VALUE_T>
121 bool DoRunWithType2() {
122 using MapType =
typename MapTypeTraits<KEY_T, VALUE_T>::MapType;
123 const auto& key_input = Input(KEYS);
124 const auto& value_input = Input(VALUES);
126 CAFFE_ENFORCE_EQ(key_input.numel(), value_input.numel());
128 auto* key_data = key_input.template data<KEY_T>();
129 auto* value_data = value_input.template data<VALUE_T>();
131 auto* map_data = this->
template Output<MapType>(MAP);
133 for (
int i = 0; i < key_input.numel(); ++i) {
134 map_data->emplace(key_data[i], value_data[i]);
140 template <
typename KEY_T>
141 bool DoRunWithOtherType2() {
143 "KeyValueToMap is not implemented on value tensor of type ",
144 Input(VALUES).dtype().name(),
145 "consider adding it as a type in the DispatchHelper list");
148 INPUT_TAGS(KEYS, VALUES);
// MapToKeyValueOp (named by its destructor below): per DoRunWithType further
// down, it flattens a map input into KEYS and VALUES output tensors. The class
// header (lines 153-154) and constructor (lines 157-158) are not visible.
152 template <
class Context>
155 USE_OPERATOR_CONTEXT_FUNCTIONS;
156 template <
class... Args>
// Empty user-declared destructor; it does no work of its own.
159 ~MapToKeyValueOp() {}
// Dispatches on the type of the MAP input blob. Only the tail of the type
// list (ending in MapType32To64) is visible; lines 162-165 presumably list the
// other MapType aliases -- TODO confirm.
161 bool RunOnDevice()
override {
166 MapType32To64>>::call(
this, OperatorBase::InputBlob(MAP));
169 template <
typename MAP_T>
170 bool DoRunWithType() {
171 using key_type =
typename MAP_T::key_type;
172 using mapped_type =
typename MAP_T::mapped_type;
173 auto& map_data = this->
template Input<MAP_T>(MAP);
175 auto* key_output = Output(KEYS, {
static_cast<int64_t
>(map_data.size())}, at::dtype<key_type>());
177 Output(VALUES, {
static_cast<int64_t
>(map_data.size())}, at::dtype<mapped_type>());
178 auto* key_data = key_output->template mutable_data<key_type>();
179 auto* value_data = value_output->template mutable_data<mapped_type>();
181 for (
const auto& it : map_data) {
182 *key_data = it.first;
183 *value_data = it.second;
192 OUTPUT_TAGS(KEYS, VALUES);
// Serializer for a blob holding MapTypeTraits<KEY_T, VALUE_T>::MapType. The
// class header (lines 196-197) and the start of the Serialize signature
// (lines 199-203) are not visible in this listing.
195 template <
typename KEY_T,
typename VALUE_T>
198 using MapType =
typename MapTypeTraits<KEY_T, VALUE_T>::MapType;
204 BlobSerializerBase::SerializationAcceptor acceptor)
override {
// Enforces that the blob's TypeMeta matches MapType, then views the
// type-erased pointer as that map.
205 CAFFE_ENFORCE(typeMeta.Match<MapType>());
206 const MapType& map_data = *
static_cast<const MapType*
>(pointer);
// Both tensors are resized to one element per map entry. Their declarations
// (lines 208 and 210) are not visible.
207 int64_t sz = map_data.size();
209 key_tensor.Resize(sz);
211 value_tensor.Resize(sz);
212 auto* key_data = key_tensor.mutable_data<KEY_T>();
213 auto* value_data = value_tensor.mutable_data<VALUE_T>();
// Copies each entry's key and value into the two tensors.
// NOTE(review): as listed, the loop never advances key_data/value_data --
// presumably the increments are on the omitted lines 217-218; confirm,
// otherwise every entry overwrites element 0.
214 for (
const auto& it : map_data) {
215 *key_data = it.first;
216 *value_data = it.second;
// Appends the key tensor, then the value tensor, to one TensorProtos message
// (the serializer calls that open on lines 222-223 and 225-227 are only
// partly visible).
221 TensorProtos tensor_protos;
224 key_tensor, name, tensor_protos.add_protos(), 0, key_tensor.numel());
228 tensor_protos.add_protos(),
230 value_tensor.numel());
// Stores the serialized TensorProtos as the content of a BlobProto named
// `name` (line 234 is not visible) and passes the serialized BlobProto to the
// acceptor under that same name.
232 BlobProto blob_proto;
233 blob_proto.set_name(name);
235 blob_proto.set_content(SerializeAsString_EnforceCheck(tensor_protos));
236 acceptor(name, SerializeBlobProtoAsString_EnforceCheck(blob_proto));
// Deserializer-side alias for MapTypeTraits<KEY_T, VALUE_T>::MapType. The
// enclosing class header (lines 241-242) is not visible in this listing.
240 template <
typename KEY_T,
typename VALUE_T>
243 using MapType =
typename MapTypeTraits<KEY_T, VALUE_T>::MapType;
245 void Deserialize(
const BlobProto& proto,
Blob* blob)
override {
246 TensorProtos tensor_protos;
248 tensor_protos.ParseFromString(proto.content()),
249 "Fail to parse TensorProtos");
251 Tensor key_tensor = deser.Deserialize(tensor_protos.protos(0));
252 Tensor value_tensor = deser.Deserialize(tensor_protos.protos(1));
253 auto* key_data = key_tensor.data<KEY_T>();
254 auto* value_data = value_tensor.data<VALUE_T>();
256 auto* map_ptr = blob->template GetMutable<MapType>();
257 for (
int i = 0; i < key_tensor.numel(); ++i) {
258 map_ptr->emplace(key_data[i], value_data[i]);
265 #endif // CAFFE2_OPERATORS_MAP_OPS_H_ Blob is a general container that hosts a typed pointer.
TensorSerializer is the serializer for Tensors.
BlobDeserializerBase is an abstract class that deserializes a blob from a BlobProto or a TensorProto...
void Serialize(const void *pointer, TypeMeta typeMeta, const string &name, SerializationAcceptor acceptor) override
Serializes a Blob.
TensorDeserializer is the deserializer for Tensors.
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
BlobSerializerBase is an abstract class that serializes a blob to a string.