Caffe2 - C++ API
A deep learning, cross-platform ML framework
map_ops.h
1 #ifndef CAFFE2_OPERATORS_MAP_OPS_H_
2 #define CAFFE2_OPERATORS_MAP_OPS_H_
3 
4 #include <algorithm>
5 #include <iterator>
6 #include <string>
7 #include <typeinfo>
8 #include <unordered_map>
9 #include <utility>
10 #include <vector>
11 
12 #include "caffe2/core/blob_serialization.h"
13 #include "caffe2/core/context.h"
14 #include "caffe2/core/operator.h"
15 
16 namespace caffe2 {
17 
/// Maps a C++ type to the human-readable name used when building map type
/// names (see MapTypeTraits::MapTypeName). Types without a specialization
/// report "unknown".
template <typename T>
struct TypeNameTraits {
  static constexpr const char* name = "unknown";
};

template <>
struct TypeNameTraits<int64_t> {
  static constexpr const char* name = "int64_t";
};

template <>
struct TypeNameTraits<int32_t> {
  static constexpr const char* name = "int32_t";
};

/// Describes the std::unordered_map type used to hold a KEY_T -> VALUE_T map
/// inside a blob.
template <typename KEY_T, typename VALUE_T>
struct MapTypeTraits {
  using MapType = std::unordered_map<KEY_T, VALUE_T>;

  /// Returns the type name written into BlobProto::type by the map
  /// serializer, e.g. "(std::unordered_map<int64_t, int32_t>)".
  static std::string MapTypeName() {
    return std::string("(std::unordered_map<") + TypeNameTraits<KEY_T>::name +
        ", " + TypeNameTraits<VALUE_T>::name + ">)";
  }
};

// Concrete key/value map types handled by the map operators below.
using MapType64To64 = MapTypeTraits<int64_t, int64_t>::MapType;
using MapType64To32 = MapTypeTraits<int64_t, int32_t>::MapType;
using MapType32To32 = MapTypeTraits<int32_t, int32_t>::MapType;
using MapType32To64 = MapTypeTraits<int32_t, int64_t>::MapType;
46 
47 template <class Context>
48 class CreateMapOp final : public Operator<Context> {
49  public:
50  USE_OPERATOR_CONTEXT_FUNCTIONS;
51  template <class... Args>
52  explicit CreateMapOp(Args&&... args)
53  : Operator<Context>(std::forward<Args>(args)...) {}
54  ~CreateMapOp() {}
55 
56  bool RunOnDevice() override {
57  TensorProto::DataType key_dtype =
58  static_cast<TensorProto::DataType>(this->template GetSingleArgument<int>(
59  "key_dtype", TensorProto_DataType_INT32));
60 
62  this, DataTypeToTypeMeta(key_dtype));
63  }
64 
65  template <typename KEY_T>
66  bool DoRunWithType() {
67  TensorProto::DataType value_dtype =
68  static_cast<TensorProto::DataType>(this->template GetSingleArgument<int>(
69  "value_dtype", TensorProto_DataType_INT32));
70 
71  return DispatchHelper<
73  KEY_T>::call(this, DataTypeToTypeMeta(value_dtype));
74  }
75 
76  template <typename KEY_T, typename VALUE_T>
77  bool DoRunWithType2() {
78  // clear to make sure the map is empty
79  this->template Output<typename MapTypeTraits<KEY_T, VALUE_T>::MapType>(MAP)
80  ->clear();
81  return true;
82  }
83 
84  template <typename KEY_T>
85  bool DoRunWithOtherType2() {
86  TensorProto::DataType value_dtype =
87  static_cast<TensorProto::DataType>(this->template GetSingleArgument<int>(
88  "value_dtype", TensorProto_DataType_INT32));
89 
90  CAFFE_THROW(
91  "CreateMap is not implemented on value tensor of type ",
92  DataTypeToTypeMeta(value_dtype).name(),
93  "consider adding it as a type in the DispatchHelper list");
94  }
95 
96  OUTPUT_TAGS(MAP);
97 };
98 
99 template <class Context>
100 class KeyValueToMapOp final : public Operator<Context> {
101  public:
102  USE_OPERATOR_CONTEXT_FUNCTIONS;
103  template <class... Args>
104  explicit KeyValueToMapOp(Args&&... args)
105  : Operator<Context>(std::forward<Args>(args)...) {}
106  ~KeyValueToMapOp() {}
107 
108  bool RunOnDevice() override {
110  this, Input(KEYS));
111  }
112 
113  template <typename KEY_T>
114  bool DoRunWithType() {
115  return DispatchHelper<
117  KEY_T>::call(this, Input(VALUES));
118  }
119 
120  template <typename KEY_T, typename VALUE_T>
121  bool DoRunWithType2() {
122  using MapType = typename MapTypeTraits<KEY_T, VALUE_T>::MapType;
123  const auto& key_input = Input(KEYS);
124  const auto& value_input = Input(VALUES);
125 
126  CAFFE_ENFORCE_EQ(key_input.numel(), value_input.numel());
127 
128  auto* key_data = key_input.template data<KEY_T>();
129  auto* value_data = value_input.template data<VALUE_T>();
130 
131  auto* map_data = this->template Output<MapType>(MAP);
132 
133  for (int i = 0; i < key_input.numel(); ++i) {
134  map_data->emplace(key_data[i], value_data[i]);
135  }
136 
137  return true;
138  }
139 
140  template <typename KEY_T>
141  bool DoRunWithOtherType2() {
142  CAFFE_THROW(
143  "KeyValueToMap is not implemented on value tensor of type ",
144  Input(VALUES).dtype().name(),
145  "consider adding it as a type in the DispatchHelper list");
146  }
147 
148  INPUT_TAGS(KEYS, VALUES);
149  OUTPUT_TAGS(MAP);
150 };
151 
152 template <class Context>
153 class MapToKeyValueOp final : public Operator<Context> {
154  public:
155  USE_OPERATOR_CONTEXT_FUNCTIONS;
156  template <class... Args>
157  explicit MapToKeyValueOp(Args&&... args)
158  : Operator<Context>(std::forward<Args>(args)...) {}
159  ~MapToKeyValueOp() {}
160 
161  bool RunOnDevice() override {
163  MapType64To64,
164  MapType64To32,
165  MapType32To32,
166  MapType32To64>>::call(this, OperatorBase::InputBlob(MAP));
167  }
168 
169  template <typename MAP_T>
170  bool DoRunWithType() {
171  using key_type = typename MAP_T::key_type;
172  using mapped_type = typename MAP_T::mapped_type;
173  auto& map_data = this->template Input<MAP_T>(MAP);
174 
175  auto* key_output = Output(KEYS, {static_cast<int64_t>(map_data.size())}, at::dtype<key_type>());
176  auto* value_output =
177  Output(VALUES, {static_cast<int64_t>(map_data.size())}, at::dtype<mapped_type>());
178  auto* key_data = key_output->template mutable_data<key_type>();
179  auto* value_data = value_output->template mutable_data<mapped_type>();
180 
181  for (const auto& it : map_data) {
182  *key_data = it.first;
183  *value_data = it.second;
184  key_data++;
185  value_data++;
186  }
187 
188  return true;
189  }
190 
191  INPUT_TAGS(MAP);
192  OUTPUT_TAGS(KEYS, VALUES);
193 };
194 
195 template <typename KEY_T, typename VALUE_T>
197  public:
198  using MapType = typename MapTypeTraits<KEY_T, VALUE_T>::MapType;
199 
200  void Serialize(
201  const void* pointer,
202  TypeMeta typeMeta,
203  const string& name,
204  BlobSerializerBase::SerializationAcceptor acceptor) override {
205  CAFFE_ENFORCE(typeMeta.Match<MapType>());
206  const MapType& map_data = *static_cast<const MapType*>(pointer);
207  int64_t sz = map_data.size();
208  Tensor key_tensor(CPU);
209  key_tensor.Resize(sz);
210  Tensor value_tensor(CPU);
211  value_tensor.Resize(sz);
212  auto* key_data = key_tensor.mutable_data<KEY_T>();
213  auto* value_data = value_tensor.mutable_data<VALUE_T>();
214  for (const auto& it : map_data) {
215  *key_data = it.first;
216  *value_data = it.second;
217  key_data++;
218  value_data++;
219  }
220 
221  TensorProtos tensor_protos;
222  TensorSerializer ser;
223  ser.Serialize(
224  key_tensor, name, tensor_protos.add_protos(), 0, key_tensor.numel());
225  ser.Serialize(
226  value_tensor,
227  name,
228  tensor_protos.add_protos(),
229  0,
230  value_tensor.numel());
231 
232  BlobProto blob_proto;
233  blob_proto.set_name(name);
234  blob_proto.set_type(MapTypeTraits<KEY_T, VALUE_T>::MapTypeName());
235  blob_proto.set_content(SerializeAsString_EnforceCheck(tensor_protos));
236  acceptor(name, SerializeBlobProtoAsString_EnforceCheck(blob_proto));
237  }
238 };
239 
240 template <typename KEY_T, typename VALUE_T>
242  public:
243  using MapType = typename MapTypeTraits<KEY_T, VALUE_T>::MapType;
244 
245  void Deserialize(const BlobProto& proto, Blob* blob) override {
246  TensorProtos tensor_protos;
247  CAFFE_ENFORCE(
248  tensor_protos.ParseFromString(proto.content()),
249  "Fail to parse TensorProtos");
250  TensorDeserializer deser;
251  Tensor key_tensor = deser.Deserialize(tensor_protos.protos(0));
252  Tensor value_tensor = deser.Deserialize(tensor_protos.protos(1));
253  auto* key_data = key_tensor.data<KEY_T>();
254  auto* value_data = value_tensor.data<VALUE_T>();
255 
256  auto* map_ptr = blob->template GetMutable<MapType>();
257  for (int i = 0; i < key_tensor.numel(); ++i) {
258  map_ptr->emplace(key_data[i], value_data[i]);
259  }
260  }
261 };
262 
263 } // namespace caffe2
264 
265 #endif // CAFFE2_OPERATORS_MAP_OPS_H_
Blob is a general container that hosts a typed pointer.
Definition: blob.h:24
TensorSerializer is the serializer for Tensors.
BlobDeserializerBase is an abstract class that deserializes a blob from a BlobProto or a TensorProto...
void Serialize(const void *pointer, TypeMeta typeMeta, const string &name, SerializationAcceptor acceptor) override
Serializes a Blob.
TensorDeserializer is the deserializer for Tensors.
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13
TypeMeta is a thin class that allows us to store the type of a container such as a blob...
Definition: typeid.h:324
BlobSerializerBase is an abstract class that serializes a blob to a string.