Caffe2 - C++ API
A deep learning, cross-platform ML framework
map_ops.h
1 
17 #ifndef CAFFE2_OPERATORS_MAP_OPS_H_
18 #define CAFFE2_OPERATORS_MAP_OPS_H_
19 
20 #include <algorithm>
21 #include <iterator>
22 #include <string>
23 #include <typeinfo>
24 #include <unordered_map>
25 #include <utility>
26 #include <vector>
27 
28 #include "caffe2/core/blob_serialization.h"
29 #include "caffe2/core/context.h"
30 #include "caffe2/core/operator.h"
31 
32 namespace caffe2 {
33 
// Maps a C++ type to the textual name used when composing the type string of
// a map blob (consumed by MapTypeTraits::MapTypeName). Any type without an
// explicit specialization reports "unknown".
template <typename T>
struct TypeNameTraits {
  static constexpr const char* name = "unknown";
};

// Name reported for int64_t keys/values.
template <>
struct TypeNameTraits<int64_t> {
  static constexpr const char* name = "int64_t";
};

// Name reported for int32_t keys/values.
template <>
struct TypeNameTraits<int32_t> {
  static constexpr const char* name = "int32_t";
};
48 
49 template <typename KEY_T, typename VALUE_T>
50 struct MapTypeTraits {
51  using MapType = std::unordered_map<KEY_T, VALUE_T>;
52  static string MapTypeName() {
53  return string("(std::unordered_map<") + TypeNameTraits<KEY_T>::name + ", " +
55  }
56 };
57 
58 using MapType64To64 = MapTypeTraits<int64_t, int64_t>::MapType;
59 using MapType64To32 = MapTypeTraits<int64_t, int32_t>::MapType;
60 using MapType32To32 = MapTypeTraits<int32_t, int32_t>::MapType;
61 using MapType32To64 = MapTypeTraits<int32_t, int64_t>::MapType;
62 
63 template <class Context>
64 class CreateMapOp final : public Operator<Context> {
65  public:
66  USE_OPERATOR_CONTEXT_FUNCTIONS;
67  CreateMapOp(const OperatorDef& operator_def, Workspace* ws)
68  : Operator<Context>(operator_def, ws) {}
69  ~CreateMapOp() {}
70 
71  bool RunOnDevice() override {
72  TensorProto::DataType key_dtype =
73  static_cast<TensorProto::DataType>(OperatorBase::GetSingleArgument<int>(
74  "key_dtype", TensorProto_DataType_INT32));
75 
77  this, DataTypeToTypeMeta(key_dtype));
78  }
79 
80  template <typename KEY_T>
81  bool DoRunWithType() {
82  TensorProto::DataType value_dtype =
83  static_cast<TensorProto::DataType>(OperatorBase::GetSingleArgument<int>(
84  "value_dtype", TensorProto_DataType_INT32));
85 
86  return DispatchHelper<
88  KEY_T>::call(this, DataTypeToTypeMeta(value_dtype));
89  }
90 
91  template <typename KEY_T, typename VALUE_T>
92  bool DoRunWithType2() {
93  // clear to make sure the map is empty
94  OperatorBase::Output<typename MapTypeTraits<KEY_T, VALUE_T>::MapType>(MAP)
95  ->clear();
96  return true;
97  }
98 
99  template <typename KEY_T>
100  bool DoRunWithOtherType2() {
101  TensorProto::DataType value_dtype =
102  static_cast<TensorProto::DataType>(OperatorBase::GetSingleArgument<int>(
103  "value_dtype", TensorProto_DataType_INT32));
104 
105  CAFFE_THROW(
106  "CreateMap is not implemented on value tensor of type ",
107  DataTypeToTypeMeta(value_dtype).name(),
108  "Consider adding it a type in the list DispatchHelper");
109  }
110 
111  OUTPUT_TAGS(MAP);
112 };
113 
114 template <class Context>
115 class KeyValueToMapOp final : public Operator<Context> {
116  public:
117  USE_OPERATOR_CONTEXT_FUNCTIONS;
118  KeyValueToMapOp(const OperatorDef& operator_def, Workspace* ws)
119  : Operator<Context>(operator_def, ws) {}
120  ~KeyValueToMapOp() {}
121 
122  bool RunOnDevice() override {
124  this, Input(KEYS));
125  }
126 
127  template <typename KEY_T>
128  bool DoRunWithType() {
129  return DispatchHelper<
131  KEY_T>::call(this, Input(VALUES));
132  }
133 
134  template <typename KEY_T, typename VALUE_T>
135  bool DoRunWithType2() {
136  using MapType = typename MapTypeTraits<KEY_T, VALUE_T>::MapType;
137  const auto& key_input = Input(KEYS);
138  const auto& value_input = Input(VALUES);
139 
140  CAFFE_ENFORCE_EQ(key_input.size(), value_input.size());
141 
142  auto* key_data = key_input.template data<KEY_T>();
143  auto* value_data = value_input.template data<VALUE_T>();
144 
145  auto* map_data = OperatorBase::Output<MapType>(MAP);
146 
147  for (int i = 0; i < key_input.size(); ++i) {
148  map_data->emplace(key_data[i], value_data[i]);
149  }
150 
151  return true;
152  }
153 
154  template <typename KEY_T>
155  bool DoRunWithOtherType2() {
156  CAFFE_THROW(
157  "KeyValueToMap is not implemented on value tensor of type ",
158  Input(VALUES).meta().name(),
159  "Consider adding it a type in the list DispatchHelper");
160  }
161 
162  INPUT_TAGS(KEYS, VALUES);
163  OUTPUT_TAGS(MAP);
164 };
165 
166 template <class Context>
167 class MapToKeyValueOp final : public Operator<Context> {
168  public:
169  USE_OPERATOR_CONTEXT_FUNCTIONS;
170  MapToKeyValueOp(const OperatorDef& operator_def, Workspace* ws)
171  : Operator<Context>(operator_def, ws) {}
172  ~MapToKeyValueOp() {}
173 
174  bool RunOnDevice() override {
176  MapType64To64,
177  MapType64To32,
178  MapType32To32,
179  MapType32To64>>::call(this, OperatorBase::InputBlob(MAP));
180  }
181 
182  template <typename MAP_T>
183  bool DoRunWithType() {
184  using key_type = typename MAP_T::key_type;
185  using mapped_type = typename MAP_T::mapped_type;
186  auto& map_data = OperatorBase::Input<MAP_T>(MAP);
187  auto* key_output = Output(KEYS);
188  auto* value_output = Output(VALUES);
189  key_output->Resize(map_data.size());
190  value_output->Resize(map_data.size());
191  auto* key_data = key_output->template mutable_data<key_type>();
192  auto* value_data = value_output->template mutable_data<mapped_type>();
193 
194  for (const auto& it : map_data) {
195  *key_data = it.first;
196  *value_data = it.second;
197  key_data++;
198  value_data++;
199  }
200 
201  return true;
202  }
203 
204  INPUT_TAGS(MAP);
205  OUTPUT_TAGS(KEYS, VALUES);
206 };
207 
208 template <typename KEY_T, typename VALUE_T>
210  public:
211  using MapType = typename MapTypeTraits<KEY_T, VALUE_T>::MapType;
212 
213  void Serialize(
214  const Blob& blob,
215  const string& name,
216  BlobSerializerBase::SerializationAcceptor acceptor) override {
217  CAFFE_ENFORCE(blob.IsType<MapType>());
218  const MapType& map_data = blob.template Get<MapType>();
219  TIndex sz = map_data.size();
220  Tensor<CPUContext> key_tensor;
221  key_tensor.Resize(sz);
222  Tensor<CPUContext> value_tensor;
223  value_tensor.Resize(sz);
224  auto* key_data = key_tensor.mutable_data<KEY_T>();
225  auto* value_data = value_tensor.mutable_data<VALUE_T>();
226  for (const auto& it : map_data) {
227  *key_data = it.first;
228  *value_data = it.second;
229  key_data++;
230  value_data++;
231  }
232 
233  TensorProtos tensor_protos;
235  ser.Serialize(
236  key_tensor, name, tensor_protos.add_protos(), 0, key_tensor.size());
237  ser.Serialize(
238  value_tensor, name, tensor_protos.add_protos(), 0, value_tensor.size());
239 
240  BlobProto blob_proto;
241  blob_proto.set_name(name);
242  blob_proto.set_type(MapTypeTraits<KEY_T, VALUE_T>::MapTypeName());
243  blob_proto.set_content(tensor_protos.SerializeAsString());
244  acceptor(name, blob_proto.SerializeAsString());
245  }
246 };
247 
248 template <typename KEY_T, typename VALUE_T>
250  public:
251  using MapType = typename MapTypeTraits<KEY_T, VALUE_T>::MapType;
252 
253  void Deserialize(const BlobProto& proto, Blob* blob) override {
254  TensorProtos tensor_protos;
255  CAFFE_ENFORCE(
256  tensor_protos.ParseFromString(proto.content()),
257  "Fail to parse TensorProtos");
259  Tensor<CPUContext> key_tensor, value_tensor;
260  deser.Deserialize(tensor_protos.protos(0), &key_tensor);
261  deser.Deserialize(tensor_protos.protos(1), &value_tensor);
262  auto* key_data = key_tensor.data<KEY_T>();
263  auto* value_data = value_tensor.data<VALUE_T>();
264 
265  auto* map_ptr = blob->template GetMutable<MapType>();
266  for (int i = 0; i < key_tensor.size(); ++i) {
267  map_ptr->emplace(key_data[i], value_data[i]);
268  }
269  }
270 };
271 
272 } // namespace caffe2
273 
274 #endif // CAFFE2_OPERATORS_MAP_OPS_H_
Blob is a general container that hosts a typed pointer.
Definition: blob.h:41
const T * data() const
Returns a typed pointer of the underlying storage.
Definition: tensor.h:500
TensorSerializer is the serializer for Tensors.
BlobDeserializerBase is an abstract class that deserializes a blob from a BlobProto or a TensorProto...
TIndex size() const
Returns the size (i.e., the number of items) of the tensor.
Definition: tensor.h:609
T * mutable_data()
Returns a typed pointer of the underlying storage.
Definition: tensor.h:594
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:63
void Serialize(const Blob &blob, const string &name, SerializationAcceptor acceptor) override
Serializes a Blob.
void Resize(Ts...dim_source)
Resizes a tensor.
Definition: tensor.h:304
TensorDeserializer is the deserializer for Tensors.
Copyright (c) 2016-present, Facebook, Inc.
bool IsType() const
Checks if the content stored in the blob is of type T.
Definition: blob.h:74
BlobSerializerBase is an abstract class that serializes a blob to a string.