Caffe2 - C++ API
A deep learning, cross-platform ML framework
proto_utils.h
1 
17 #ifndef CAFFE2_UTILS_PROTO_UTILS_H_
18 #define CAFFE2_UTILS_PROTO_UTILS_H_
19 
20 #ifdef CAFFE2_USE_LITE_PROTO
21 #include <google/protobuf/message_lite.h>
22 #else // CAFFE2_USE_LITE_PROTO
23 #include <google/protobuf/message.h>
24 #endif // !CAFFE2_USE_LITE_PROTO
25 
26 #include "caffe2/core/logging.h"
27 #include "caffe2/proto/caffe2.pb.h"
28 
29 namespace caffe2 {
30 
31 using std::string;
32 using ::google::protobuf::MessageLite;
33 
34 // A wrapper function to return device name string for use in blob serialization
35 // / deserialization. This should have one to one correspondence with
36 // caffe2/proto/caffe2.proto: enum DeviceType.
37 //
38 // Note that we can't use DeviceType_Name, because that is only available in
39 // protobuf-full, and some platforms (like mobile) may want to use
40 // protobuf-lite instead.
41 std::string DeviceTypeName(const int32_t& d);
42 
43 // Returns if the two DeviceOptions are pointing to the same device.
44 bool IsSameDevice(const DeviceOption& lhs, const DeviceOption& rhs);
45 
46 // Common interfaces that reads file contents into a string.
47 bool ReadStringFromFile(const char* filename, string* str);
48 bool WriteStringToFile(const string& str, const char* filename);
49 
50 // Common interfaces that are supported by both lite and full protobuf.
51 bool ReadProtoFromBinaryFile(const char* filename, MessageLite* proto);
52 inline bool ReadProtoFromBinaryFile(const string filename, MessageLite* proto) {
53  return ReadProtoFromBinaryFile(filename.c_str(), proto);
54 }
55 
56 void WriteProtoToBinaryFile(const MessageLite& proto, const char* filename);
57 inline void WriteProtoToBinaryFile(const MessageLite& proto,
58  const string& filename) {
59  return WriteProtoToBinaryFile(proto, filename.c_str());
60 }
61 
62 #ifdef CAFFE2_USE_LITE_PROTO
63 
64 inline string ProtoDebugString(const MessageLite& proto) {
65  return proto.SerializeAsString();
66 }
67 
68 // Text format MessageLite wrappers: these functions do nothing but just
69 // allowing things to compile. It will produce a runtime error if you are using
70 // MessageLite but still want text support.
71 inline bool ReadProtoFromTextFile(
72  const char* /*filename*/,
73  MessageLite* /*proto*/) {
74  LOG(FATAL) << "If you are running lite version, you should not be "
75  << "calling any text-format protobuffers.";
76  return false; // Just to suppress compiler warning.
77 }
78 inline bool ReadProtoFromTextFile(const string filename, MessageLite* proto) {
79  return ReadProtoFromTextFile(filename.c_str(), proto);
80 }
81 
82 inline void WriteProtoToTextFile(
83  const MessageLite& /*proto*/,
84  const char* /*filename*/) {
85  LOG(FATAL) << "If you are running lite version, you should not be "
86  << "calling any text-format protobuffers.";
87 }
88 inline void WriteProtoToTextFile(const MessageLite& proto,
89  const string& filename) {
90  return WriteProtoToTextFile(proto, filename.c_str());
91 }
92 
93 inline bool ReadProtoFromFile(const char* filename, MessageLite* proto) {
94  return (ReadProtoFromBinaryFile(filename, proto) ||
95  ReadProtoFromTextFile(filename, proto));
96 }
97 
98 inline bool ReadProtoFromFile(const string& filename, MessageLite* proto) {
99  return ReadProtoFromFile(filename.c_str(), proto);
100 }
101 
102 #else // CAFFE2_USE_LITE_PROTO
103 
104 using ::google::protobuf::Message;
105 
106 inline string ProtoDebugString(const Message& proto) {
107  return proto.ShortDebugString();
108 }
109 
110 bool ReadProtoFromTextFile(const char* filename, Message* proto);
111 inline bool ReadProtoFromTextFile(const string filename, Message* proto) {
112  return ReadProtoFromTextFile(filename.c_str(), proto);
113 }
114 
115 void WriteProtoToTextFile(const Message& proto, const char* filename);
116 inline void WriteProtoToTextFile(const Message& proto, const string& filename) {
117  return WriteProtoToTextFile(proto, filename.c_str());
118 }
119 
120 // Read Proto from a file, letting the code figure out if it is text or binary.
121 inline bool ReadProtoFromFile(const char* filename, Message* proto) {
122  return (ReadProtoFromBinaryFile(filename, proto) ||
123  ReadProtoFromTextFile(filename, proto));
124 }
125 
126 inline bool ReadProtoFromFile(const string& filename, Message* proto) {
127  return ReadProtoFromFile(filename.c_str(), proto);
128 }
129 
130 #endif // CAFFE2_USE_LITE_PROTO
131 
132 template <
133  class IterableInputs = std::initializer_list<string>,
134  class IterableOutputs = std::initializer_list<string>,
135  class IterableArgs = std::initializer_list<Argument>>
136 OperatorDef CreateOperatorDef(
137  const string& type,
138  const string& name,
139  const IterableInputs& inputs,
140  const IterableOutputs& outputs,
141  const IterableArgs& args,
142  const DeviceOption& device_option = DeviceOption(),
143  const string& engine = "") {
144  OperatorDef def;
145  def.set_type(type);
146  def.set_name(name);
147  for (const string& in : inputs) {
148  def.add_input(in);
149  }
150  for (const string& out : outputs) {
151  def.add_output(out);
152  }
153  for (const Argument& arg : args) {
154  def.add_arg()->CopyFrom(arg);
155  }
156  if (device_option.has_device_type()) {
157  def.mutable_device_option()->CopyFrom(device_option);
158  }
159  if (engine.size()) {
160  def.set_engine(engine);
161  }
162  return def;
163 }
164 
165 // A simplified version compared to the full CreateOperator, if you do not need
166 // to specify args.
167 template <
168  class IterableInputs = std::initializer_list<string>,
169  class IterableOutputs = std::initializer_list<string>>
170 inline OperatorDef CreateOperatorDef(
171  const string& type,
172  const string& name,
173  const IterableInputs& inputs,
174  const IterableOutputs& outputs,
175  const DeviceOption& device_option = DeviceOption(),
176  const string& engine = "") {
177  return CreateOperatorDef(
178  type,
179  name,
180  inputs,
181  outputs,
182  std::vector<Argument>(),
183  device_option,
184  engine);
185 }
186 
187 bool HasOutput(const OperatorDef& op, const std::string& output);
188 bool HasInput(const OperatorDef& op, const std::string& input);
189 
199  public:
200  template <typename Def>
201  static bool HasArgument(const Def& def, const string& name) {
202  return ArgumentHelper(def).HasArgument(name);
203  }
204 
205  template <typename Def, typename T>
206  static T GetSingleArgument(
207  const Def& def,
208  const string& name,
209  const T& default_value) {
210  return ArgumentHelper(def).GetSingleArgument<T>(name, default_value);
211  }
212 
213  template <typename Def, typename T>
214  static bool HasSingleArgumentOfType(const Def& def, const string& name) {
215  return ArgumentHelper(def).HasSingleArgumentOfType<T>(name);
216  }
217 
218  template <typename Def, typename T>
219  static vector<T> GetRepeatedArgument(
220  const Def& def,
221  const string& name,
222  const std::vector<T>& default_value = std::vector<T>()) {
223  return ArgumentHelper(def).GetRepeatedArgument<T>(name, default_value);
224  }
225 
226  template <typename Def, typename MessageType>
227  static MessageType GetMessageArgument(const Def& def, const string& name) {
228  return ArgumentHelper(def).GetMessageArgument<MessageType>(name);
229  }
230 
231  template <typename Def, typename MessageType>
232  static vector<MessageType> GetRepeatedMessageArgument(
233  const Def& def,
234  const string& name) {
235  return ArgumentHelper(def).GetRepeatedMessageArgument<MessageType>(name);
236  }
237 
238  explicit ArgumentHelper(const OperatorDef& def);
239  explicit ArgumentHelper(const NetDef& netdef);
240  bool HasArgument(const string& name) const;
241 
242  template <typename T>
243  T GetSingleArgument(const string& name, const T& default_value) const;
244  template <typename T>
245  bool HasSingleArgumentOfType(const string& name) const;
246  template <typename T>
247  vector<T> GetRepeatedArgument(
248  const string& name,
249  const std::vector<T>& default_value = std::vector<T>()) const;
250 
251  template <typename MessageType>
252  MessageType GetMessageArgument(const string& name) const {
253  CAFFE_ENFORCE(arg_map_.count(name), "Cannot find parameter named ", name);
254  MessageType message;
255  if (arg_map_.at(name).has_s()) {
256  CAFFE_ENFORCE(
257  message.ParseFromString(arg_map_.at(name).s()),
258  "Faild to parse content from the string");
259  } else {
260  VLOG(1) << "Return empty message for parameter " << name;
261  }
262  return message;
263  }
264 
265  template <typename MessageType>
266  vector<MessageType> GetRepeatedMessageArgument(const string& name) const {
267  CAFFE_ENFORCE(arg_map_.count(name), "Cannot find parameter named ", name);
268  vector<MessageType> messages(arg_map_.at(name).strings_size());
269  for (int i = 0; i < messages.size(); ++i) {
270  CAFFE_ENFORCE(
271  messages[i].ParseFromString(arg_map_.at(name).strings(i)),
272  "Faild to parse content from the string");
273  }
274  return messages;
275  }
276 
277  private:
278  CaffeMap<string, Argument> arg_map_;
279 };
280 
281 const Argument& GetArgument(const OperatorDef& def, const string& name);
282 bool GetFlagArgument(
283  const OperatorDef& def,
284  const string& name,
285  bool def_value = false);
286 
287 Argument* GetMutableArgument(
288  const string& name,
289  const bool create_if_missing,
290  OperatorDef* def);
291 
292 template <typename T>
293 Argument MakeArgument(const string& name, const T& value);
294 
295 template <typename T>
296 inline void AddArgument(const string& name, const T& value, OperatorDef* def) {
297  GetMutableArgument(name, true, def)->CopyFrom(MakeArgument(name, value));
298 }
299 
300 bool inline operator==(const DeviceOption& dl, const DeviceOption& dr) {
301  return IsSameDevice(dl, dr);
302 }
303 
304 } // namespace caffe2
305 
306 namespace std {
307 template <>
308 struct hash<caffe2::DeviceOption> {
309  typedef caffe2::DeviceOption argument_type;
310  typedef std::size_t result_type;
311  result_type operator()(argument_type const& device_option) const {
312  std::string serialized;
313  CAFFE_ENFORCE(device_option.SerializeToString(&serialized));
314  return std::hash<std::string>{}(serialized);
315  }
316 };
317 } // namespace std
318 
319 #endif // CAFFE2_UTILS_PROTO_UTILS_H_
Definition: types.h:88
A helper class to index into arguments.
Definition: proto_utils.h:198
Copyright (c) 2016-present, Facebook, Inc.