Caffe2 - C++ API
A deep learning, cross platform ML framework
proto_utils.h
1 #ifndef CAFFE2_UTILS_PROTO_UTILS_H_
2 #define CAFFE2_UTILS_PROTO_UTILS_H_
3 
4 #ifdef CAFFE2_USE_LITE_PROTO
5 #include <google/protobuf/message_lite.h>
6 #else // CAFFE2_USE_LITE_PROTO
7 #include <google/protobuf/message.h>
8 #endif // !CAFFE2_USE_LITE_PROTO
9 
10 #include "caffe2/core/logging.h"
11 #include "caffe2/utils/proto_wrap.h"
12 #include "caffe2/proto/caffe2_pb.h"
13 
14 namespace caffe2 {
15 
16 using std::string;
17 using ::google::protobuf::MessageLite;
18 
19 // A wrapper function to return device name string for use in blob serialization
20 // / deserialization. This should have one to one correspondence with
21 // caffe2/proto/caffe2.proto: enum DeviceType.
22 //
23 // Note that we can't use DeviceType_Name, because that is only available in
24 // protobuf-full, and some platforms (like mobile) may want to use
25 // protobuf-lite instead.
26 CAFFE2_API std::string DeviceTypeName(const int32_t& d);
27 
28 CAFFE2_API int DeviceId(const DeviceOption& option);
29 
30 // Returns if the two DeviceOptions are pointing to the same device.
31 CAFFE2_API bool IsSameDevice(const DeviceOption& lhs, const DeviceOption& rhs);
32 
33 CAFFE2_API bool IsCPUDeviceType(int device_type);
34 CAFFE2_API bool IsGPUDeviceType(int device_type);
35 
36 // Common interfaces that read file contents into a string or write a string to a file.
37 CAFFE2_API bool ReadStringFromFile(const char* filename, string* str);
38 CAFFE2_API bool WriteStringToFile(const string& str, const char* filename);
39 
40 // Common interfaces that are supported by both lite and full protobuf.
41 CAFFE2_API bool ReadProtoFromBinaryFile(const char* filename, MessageLite* proto);
42 inline bool ReadProtoFromBinaryFile(const string filename, MessageLite* proto) {
43  return ReadProtoFromBinaryFile(filename.c_str(), proto);
44 }
45 
46 CAFFE2_API void WriteProtoToBinaryFile(const MessageLite& proto, const char* filename);
47 inline void WriteProtoToBinaryFile(const MessageLite& proto,
48  const string& filename) {
49  return WriteProtoToBinaryFile(proto, filename.c_str());
50 }
51 
52 #ifdef CAFFE2_USE_LITE_PROTO
53 
54 namespace TextFormat {
55 inline bool ParseFromString(const string& spec, MessageLite* proto) {
56  LOG(FATAL) << "If you are running lite version, you should not be "
57  << "calling any text-format protobuffers.";
58 }
59 } // namespace TextFormat
60 
61 
62 CAFFE2_API string ProtoDebugString(const MessageLite& proto);
63 
64 CAFFE2_API bool ParseProtoFromLargeString(const string& str, MessageLite* proto);
65 
66 // Text format MessageLite wrappers: these functions do nothing except
67 // allow things to compile. They will produce a runtime error if you are using
68 // MessageLite but still want text support.
69 inline bool ReadProtoFromTextFile(
70  const char* /*filename*/,
71  MessageLite* /*proto*/) {
72  LOG(FATAL) << "If you are running lite version, you should not be "
73  << "calling any text-format protobuffers.";
74  return false; // Just to suppress compiler warning.
75 }
76 inline bool ReadProtoFromTextFile(const string filename, MessageLite* proto) {
77  return ReadProtoFromTextFile(filename.c_str(), proto);
78 }
79 
80 inline void WriteProtoToTextFile(
81  const MessageLite& /*proto*/,
82  const char* /*filename*/) {
83  LOG(FATAL) << "If you are running lite version, you should not be "
84  << "calling any text-format protobuffers.";
85 }
86 inline void WriteProtoToTextFile(const MessageLite& proto,
87  const string& filename) {
88  return WriteProtoToTextFile(proto, filename.c_str());
89 }
90 
91 inline bool ReadProtoFromFile(const char* filename, MessageLite* proto) {
92  return (ReadProtoFromBinaryFile(filename, proto) ||
93  ReadProtoFromTextFile(filename, proto));
94 }
95 
96 inline bool ReadProtoFromFile(const string& filename, MessageLite* proto) {
97  return ReadProtoFromFile(filename.c_str(), proto);
98 }
99 
100 #else // CAFFE2_USE_LITE_PROTO
101 
102 using ::google::protobuf::Message;
103 
104 namespace TextFormat {
105 CAFFE2_API bool ParseFromString(const string& spec, Message* proto);
106 } // namespace TextFormat
107 
108 CAFFE2_API string ProtoDebugString(const Message& proto);
109 
110 CAFFE2_API bool ParseProtoFromLargeString(const string& str, Message* proto);
111 
112 CAFFE2_API bool ReadProtoFromTextFile(const char* filename, Message* proto);
113 inline bool ReadProtoFromTextFile(const string filename, Message* proto) {
114  return ReadProtoFromTextFile(filename.c_str(), proto);
115 }
116 
117 CAFFE2_API void WriteProtoToTextFile(const Message& proto, const char* filename);
118 inline void WriteProtoToTextFile(const Message& proto, const string& filename) {
119  return WriteProtoToTextFile(proto, filename.c_str());
120 }
121 
122 // Read Proto from a file, letting the code figure out if it is text or binary.
123 inline bool ReadProtoFromFile(const char* filename, Message* proto) {
124  return (ReadProtoFromBinaryFile(filename, proto) ||
125  ReadProtoFromTextFile(filename, proto));
126 }
127 
128 inline bool ReadProtoFromFile(const string& filename, Message* proto) {
129  return ReadProtoFromFile(filename.c_str(), proto);
130 }
131 
132 #endif // CAFFE2_USE_LITE_PROTO
133 
134 template <
135  class IterableInputs = std::initializer_list<string>,
136  class IterableOutputs = std::initializer_list<string>,
137  class IterableArgs = std::initializer_list<Argument>>
138 OperatorDef CreateOperatorDef(
139  const string& type,
140  const string& name,
141  const IterableInputs& inputs,
142  const IterableOutputs& outputs,
143  const IterableArgs& args,
144  const DeviceOption& device_option = DeviceOption(),
145  const string& engine = "") {
146  OperatorDef def;
147  def.set_type(type);
148  def.set_name(name);
149  for (const string& in : inputs) {
150  def.add_input(in);
151  }
152  for (const string& out : outputs) {
153  def.add_output(out);
154  }
155  for (const Argument& arg : args) {
156  def.add_arg()->CopyFrom(arg);
157  }
158  if (device_option.has_device_type()) {
159  def.mutable_device_option()->CopyFrom(device_option);
160  }
161  if (engine.size()) {
162  def.set_engine(engine);
163  }
164  return def;
165 }
166 
167 // A simplified version of the full CreateOperatorDef, for when you do not need
168 // to specify args.
169 template <
170  class IterableInputs = std::initializer_list<string>,
171  class IterableOutputs = std::initializer_list<string>>
172 inline OperatorDef CreateOperatorDef(
173  const string& type,
174  const string& name,
175  const IterableInputs& inputs,
176  const IterableOutputs& outputs,
177  const DeviceOption& device_option = DeviceOption(),
178  const string& engine = "") {
179  return CreateOperatorDef(
180  type,
181  name,
182  inputs,
183  outputs,
184  std::vector<Argument>(),
185  device_option,
186  engine);
187 }
188 
189 CAFFE2_API bool HasOutput(const OperatorDef& op, const std::string& output);
190 CAFFE2_API bool HasInput(const OperatorDef& op, const std::string& input);
191 
200 class C10_EXPORT ArgumentHelper {
201  public:
202  template <typename Def>
203  static bool HasArgument(const Def& def, const string& name) {
204  return ArgumentHelper(def).HasArgument(name);
205  }
206 
207  template <typename Def, typename T>
208  static T GetSingleArgument(
209  const Def& def,
210  const string& name,
211  const T& default_value) {
212  return ArgumentHelper(def).GetSingleArgument<T>(name, default_value);
213  }
214 
215  template <typename Def, typename T>
216  static bool HasSingleArgumentOfType(const Def& def, const string& name) {
217  return ArgumentHelper(def).HasSingleArgumentOfType<T>(name);
218  }
219 
220  template <typename Def, typename T>
221  static vector<T> GetRepeatedArgument(
222  const Def& def,
223  const string& name,
224  const std::vector<T>& default_value = std::vector<T>()) {
225  return ArgumentHelper(def).GetRepeatedArgument<T>(name, default_value);
226  }
227 
228  template <typename Def, typename MessageType>
229  static MessageType GetMessageArgument(const Def& def, const string& name) {
230  return ArgumentHelper(def).GetMessageArgument<MessageType>(name);
231  }
232 
233  template <typename Def, typename MessageType>
234  static vector<MessageType> GetRepeatedMessageArgument(
235  const Def& def,
236  const string& name) {
237  return ArgumentHelper(def).GetRepeatedMessageArgument<MessageType>(name);
238  }
239 
240  template <typename Def>
241  static bool RemoveArgument(Def& def, int index) {
242  if (index >= def.arg_size()) {
243  return false;
244  }
245  if (index < def.arg_size() - 1) {
246  def.mutable_arg()->SwapElements(index, def.arg_size() - 1);
247  }
248  def.mutable_arg()->RemoveLast();
249  return true;
250  }
251 
252  explicit ArgumentHelper(const OperatorDef& def);
253  explicit ArgumentHelper(const NetDef& netdef);
254  bool HasArgument(const string& name) const;
255 
256  template <typename T>
257  T GetSingleArgument(const string& name, const T& default_value) const;
258  template <typename T>
259  bool HasSingleArgumentOfType(const string& name) const;
260  template <typename T>
261  vector<T> GetRepeatedArgument(
262  const string& name,
263  const std::vector<T>& default_value = std::vector<T>()) const;
264 
265  template <typename MessageType>
266  MessageType GetMessageArgument(const string& name) const {
267  CAFFE_ENFORCE(arg_map_.count(name), "Cannot find parameter named ", name);
268  MessageType message;
269  if (arg_map_.at(name).has_s()) {
270  CAFFE_ENFORCE(
271  message.ParseFromString(arg_map_.at(name).s()),
272  "Faild to parse content from the string");
273  } else {
274  VLOG(1) << "Return empty message for parameter " << name;
275  }
276  return message;
277  }
278 
279  template <typename MessageType>
280  vector<MessageType> GetRepeatedMessageArgument(const string& name) const {
281  CAFFE_ENFORCE(arg_map_.count(name), "Cannot find parameter named ", name);
282  vector<MessageType> messages(arg_map_.at(name).strings_size());
283  for (int i = 0; i < messages.size(); ++i) {
284  CAFFE_ENFORCE(
285  messages[i].ParseFromString(arg_map_.at(name).strings(i)),
286  "Faild to parse content from the string");
287  }
288  return messages;
289  }
290 
291  private:
292  CaffeMap<string, Argument> arg_map_;
293 };
294 
295 // **** Arguments Utils *****
296 
297 // Helper methods to get an argument from OperatorDef or NetDef given argument
298 // name. Throws if argument does not exist.
299 CAFFE2_API const Argument& GetArgument(const OperatorDef& def, const string& name);
300 CAFFE2_API const Argument& GetArgument(const NetDef& def, const string& name);
301 
302 // Helper methods to query a boolean argument flag from OperatorDef or NetDef
303 // given argument name. If argument does not exist, return default value.
304 // Throws if argument exists but the type is not boolean.
305 CAFFE2_API bool GetFlagArgument(
306  const OperatorDef& def,
307  const string& name,
308  bool default_value = false);
309 CAFFE2_API bool GetFlagArgument(
310  const NetDef& def,
311  const string& name,
312  bool default_value = false);
313 
314 CAFFE2_API Argument* GetMutableArgument(
315  const string& name,
316  const bool create_if_missing,
317  OperatorDef* def);
318 
319 template <typename T>
320 CAFFE2_API Argument MakeArgument(const string& name, const T& value);
321 
322 template <typename T>
323 inline void AddArgument(const string& name, const T& value, OperatorDef* def) {
324  GetMutableArgument(name, true, def)->CopyFrom(MakeArgument(name, value));
325 }
326 // **** End Arguments Utils *****
327 
328 bool inline operator==(const DeviceOption& dl, const DeviceOption& dr) {
329  return IsSameDevice(dl, dr);
330 }
331 
332 
333 } // namespace caffe2
334 
335 namespace std {
336 template <>
337 struct hash<caffe2::DeviceOption> {
338  typedef caffe2::DeviceOption argument_type;
339  typedef std::size_t result_type;
340  result_type operator()(argument_type const& device_option) const {
341  std::string serialized;
342  CAFFE_ENFORCE(device_option.SerializeToString(&serialized));
343  return std::hash<std::string>{}(serialized);
344  }
345 };
346 } // namespace std
347 
348 #endif // CAFFE2_UTILS_PROTO_UTILS_H_
A helper class to index into arguments.
Definition: proto_utils.h:200
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13