Caffe2 - C++ API
A deep learning, cross platform ML framework
proto_utils.cc
1 
17 #include "caffe2/utils/proto_utils.h"
18 
#include <fcntl.h>
#include <cerrno>
#include <fstream>
#include <type_traits>
22 
23 #include <google/protobuf/io/coded_stream.h>
24 #include <google/protobuf/io/zero_copy_stream_impl.h>
25 
26 #ifndef CAFFE2_USE_LITE_PROTO
27 #include <google/protobuf/text_format.h>
28 #endif // !CAFFE2_USE_LITE_PROTO
29 
30 #include "caffe2/core/logging.h"
31 
32 using ::google::protobuf::Message;
33 using ::google::protobuf::MessageLite;
34 
35 namespace caffe2 {
36 
37 std::string DeviceTypeName(const int32_t& d) {
38  switch (d) {
39  case CPU:
40  return "CPU";
41  case CUDA:
42  return "CUDA";
43  case OPENGL:
44  return "OPENGL";
45  case MKLDNN:
46  return "MKLDNN";
47  default:
48  CAFFE_THROW(
49  "Unknown device: ",
50  d,
51  ". If you have recently updated the caffe2.proto file to add a new "
52  "device type, did you forget to update the TensorDeviceTypeName() "
53  "function to reflect such recent changes?");
54  // The below code won't run but is needed to suppress some compiler
55  // warnings.
56  return "";
57  }
58 };
59 
60 bool IsSameDevice(const DeviceOption& lhs, const DeviceOption& rhs) {
61  return (
62  lhs.device_type() == rhs.device_type() &&
63  lhs.cuda_gpu_id() == rhs.cuda_gpu_id() &&
64  lhs.node_name() == rhs.node_name());
65 }
66 
67 bool ReadStringFromFile(const char* filename, string* str) {
68  std::ifstream ifs(filename, std::ios::in);
69  if (!ifs) {
70  VLOG(1) << "File cannot be opened: " << filename
71  << " error: " << ifs.rdstate();
72  return false;
73  }
74  ifs.seekg(0, std::ios::end);
75  size_t n = ifs.tellg();
76  str->resize(n);
77  ifs.seekg(0);
78  ifs.read(&(*str)[0], n);
79  return true;
80 }
81 
82 bool WriteStringToFile(const string& str, const char* filename) {
83  std::ofstream ofs(filename, std::ios::out | std::ios::trunc);
84  if (!ofs.is_open()) {
85  VLOG(1) << "File cannot be created: " << filename
86  << " error: " << ofs.rdstate();
87  return false;
88  }
89  ofs << str;
90  return true;
91 }
92 
93 // IO-specific proto functions: we will deal with the protocol buffer lite and
94 // full versions differently.
95 
96 #ifdef CAFFE2_USE_LITE_PROTO
97 
98 // Lite runtime.
99 
100 namespace {
101 class IfstreamInputStream : public ::google::protobuf::io::CopyingInputStream {
102  public:
103  explicit IfstreamInputStream(const string& filename)
104  : ifs_(filename.c_str(), std::ios::in | std::ios::binary) {}
105  ~IfstreamInputStream() { ifs_.close(); }
106 
107  int Read(void* buffer, int size) {
108  if (!ifs_) {
109  return -1;
110  }
111  ifs_.read(static_cast<char*>(buffer), size);
112  return ifs_.gcount();
113  }
114 
115  private:
116  std::ifstream ifs_;
117 };
118 } // namespace
119 
120 bool ReadProtoFromBinaryFile(const char* filename, MessageLite* proto) {
121  ::google::protobuf::io::CopyingInputStreamAdaptor stream(
122  new IfstreamInputStream(filename));
123  stream.SetOwnsCopyingStream(true);
124  // Total bytes hard limit / warning limit are set to 1GB and 512MB
125  // respectively.
126  ::google::protobuf::io::CodedInputStream coded_stream(&stream);
127  coded_stream.SetTotalBytesLimit(1024LL << 20, 512LL << 20);
128  return proto->ParseFromCodedStream(&coded_stream);
129 }
130 
131 void WriteProtoToBinaryFile(
132  const MessageLite& /*proto*/,
133  const char* /*filename*/) {
134  LOG(FATAL) << "Not implemented yet.";
135 }
136 
137 #else // CAFFE2_USE_LITE_PROTO
138 
139 // Full protocol buffer.
140 
141 using ::google::protobuf::io::FileInputStream;
142 using ::google::protobuf::io::FileOutputStream;
143 using ::google::protobuf::io::ZeroCopyInputStream;
144 using ::google::protobuf::io::CodedInputStream;
145 using ::google::protobuf::io::ZeroCopyOutputStream;
146 using ::google::protobuf::io::CodedOutputStream;
147 
148 bool ReadProtoFromTextFile(const char* filename, Message* proto) {
149  int fd = open(filename, O_RDONLY);
150  CAFFE_ENFORCE_NE(fd, -1, "File not found: ", filename);
151  FileInputStream* input = new FileInputStream(fd);
152  bool success = google::protobuf::TextFormat::Parse(input, proto);
153  delete input;
154  close(fd);
155  return success;
156 }
157 
158 void WriteProtoToTextFile(const Message& proto, const char* filename) {
159  int fd = open(filename, O_WRONLY | O_CREAT | O_TRUNC, 0644);
160  FileOutputStream* output = new FileOutputStream(fd);
161  CAFFE_ENFORCE(google::protobuf::TextFormat::Print(proto, output));
162  delete output;
163  close(fd);
164 }
165 
166 bool ReadProtoFromBinaryFile(const char* filename, MessageLite* proto) {
167 #if defined (_MSC_VER) // for MSC compiler binary flag needs to be specified
168  int fd = open(filename, O_RDONLY | O_BINARY);
169 #else
170  int fd = open(filename, O_RDONLY);
171 #endif
172  CAFFE_ENFORCE_NE(fd, -1, "File not found: ", filename);
173  std::unique_ptr<ZeroCopyInputStream> raw_input(new FileInputStream(fd));
174  std::unique_ptr<CodedInputStream> coded_input(
175  new CodedInputStream(raw_input.get()));
176  // A hack to manually allow using very large protocol buffers.
177  coded_input->SetTotalBytesLimit(1073741824, 536870912);
178  bool success = proto->ParseFromCodedStream(coded_input.get());
179  coded_input.reset();
180  raw_input.reset();
181  close(fd);
182  return success;
183 }
184 
185 void WriteProtoToBinaryFile(const MessageLite& proto, const char* filename) {
186  int fd = open(filename, O_WRONLY | O_CREAT | O_TRUNC, 0644);
187  CAFFE_ENFORCE_NE(
188  fd, -1, "File cannot be created: ", filename, " error number: ", errno);
189  std::unique_ptr<ZeroCopyOutputStream> raw_output(new FileOutputStream(fd));
190  std::unique_ptr<CodedOutputStream> coded_output(
191  new CodedOutputStream(raw_output.get()));
192  CAFFE_ENFORCE(proto.SerializeToCodedStream(coded_output.get()));
193  coded_output.reset();
194  raw_output.reset();
195  close(fd);
196 }
197 
198 #endif // CAFFE2_USE_LITE_PROTO
199 
200 ArgumentHelper::ArgumentHelper(const OperatorDef& def) {
201  for (auto& arg : def.arg()) {
202  if (arg_map_.count(arg.name())) {
203  if (arg.SerializeAsString() != arg_map_[arg.name()].SerializeAsString()) {
204  // If there are two arguments of the same name but different contents,
205  // we will throw an error.
206  CAFFE_THROW(
207  "Found argument of the same name ",
208  arg.name(),
209  "but with different contents.",
210  ProtoDebugString(def));
211  } else {
212  LOG(WARNING) << "Duplicated argument name [" << arg.name()
213  << "] found in operator def: "
214  << ProtoDebugString(def);
215  }
216  }
217  arg_map_[arg.name()] = arg;
218  }
219 }
220 
221 ArgumentHelper::ArgumentHelper(const NetDef& netdef) {
222  for (auto& arg : netdef.arg()) {
223  CAFFE_ENFORCE(
224  arg_map_.count(arg.name()) == 0,
225  "Duplicated argument name [", arg.name(), "] found in net def: ",
226  ProtoDebugString(netdef));
227  arg_map_[arg.name()] = arg;
228  }
229 }
230 
231 bool ArgumentHelper::HasArgument(const string& name) const {
232  return arg_map_.count(name);
233 }
234 
namespace {
// Sign test for signed types (integral or floating point).
template <typename T>
bool IsNegative(const T& value, std::true_type /* is_signed */) {
  return value < T(0);
}

// Unsigned and non-arithmetic types (strings, messages) are never negative.
template <typename T>
bool IsNegative(const T& /* value */, std::false_type /* is_signed */) {
  return false;
}

// Helper function to verify that conversion between types won't lose any
// significant bit. Besides requiring that the value survives a round trip
// through TargetType, the converted value must keep its sign: for example
// int64_t(-1) -> size_t -> int64_t round-trips on two's-complement platforms
// even though size_t cannot represent -1.
template <typename InputType, typename TargetType>
bool SupportsLosslessConversion(const InputType& value) {
  // decay strips reference/cv qualifiers, e.g. when InputType comes from
  // decltype() of a `const auto&` loop variable.
  typedef typename std::decay<InputType>::type Source;
  typedef typename std::decay<TargetType>::type Target;
  const Target converted = static_cast<Target>(value);
  return static_cast<InputType>(converted) == value &&
      IsNegative(value, std::is_signed<Source>()) ==
      IsNegative(converted, std::is_signed<Target>());
}
} // namespace
243 
244 bool operator==(const NetDef& l, const NetDef& r) {
245  return l.SerializeAsString() == r.SerializeAsString();
246 }
247 
248 std::ostream& operator<<(std::ostream& output, const NetDef& n) {
249  output << n.SerializeAsString();
250  return output;
251 }
252 
253 #define INSTANTIATE_GET_SINGLE_ARGUMENT( \
254  T, fieldname, enforce_lossless_conversion) \
255  template <> \
256  T ArgumentHelper::GetSingleArgument<T>( \
257  const string& name, const T& default_value) const { \
258  if (arg_map_.count(name) == 0) { \
259  VLOG(1) << "Using default parameter value " << default_value \
260  << " for parameter " << name; \
261  return default_value; \
262  } \
263  CAFFE_ENFORCE( \
264  arg_map_.at(name).has_##fieldname(), \
265  "Argument ", \
266  name, \
267  " does not have the right field: expected field " #fieldname); \
268  auto value = arg_map_.at(name).fieldname(); \
269  if (enforce_lossless_conversion) { \
270  auto supportsConversion = \
271  SupportsLosslessConversion<decltype(value), T>(value); \
272  CAFFE_ENFORCE( \
273  supportsConversion, \
274  "Value", \
275  value, \
276  " of argument ", \
277  name, \
278  "cannot be represented correctly in a target type"); \
279  } \
280  return static_cast<T>(value); \
281  } \
282  template <> \
283  bool ArgumentHelper::HasSingleArgumentOfType<T>(const string& name) const { \
284  if (arg_map_.count(name) == 0) { \
285  return false; \
286  } \
287  return arg_map_.at(name).has_##fieldname(); \
288  }
289 
290 INSTANTIATE_GET_SINGLE_ARGUMENT(float, f, false)
291 INSTANTIATE_GET_SINGLE_ARGUMENT(double, f, false)
292 INSTANTIATE_GET_SINGLE_ARGUMENT(bool, i, false)
293 INSTANTIATE_GET_SINGLE_ARGUMENT(int8_t, i, true)
294 INSTANTIATE_GET_SINGLE_ARGUMENT(int16_t, i, true)
295 INSTANTIATE_GET_SINGLE_ARGUMENT(int, i, true)
296 INSTANTIATE_GET_SINGLE_ARGUMENT(int64_t, i, true)
297 INSTANTIATE_GET_SINGLE_ARGUMENT(uint8_t, i, true)
298 INSTANTIATE_GET_SINGLE_ARGUMENT(uint16_t, i, true)
299 INSTANTIATE_GET_SINGLE_ARGUMENT(size_t, i, true)
300 INSTANTIATE_GET_SINGLE_ARGUMENT(string, s, false)
301 INSTANTIATE_GET_SINGLE_ARGUMENT(NetDef, n, false)
302 #undef INSTANTIATE_GET_SINGLE_ARGUMENT
303 
304 #define INSTANTIATE_GET_REPEATED_ARGUMENT( \
305  T, fieldname, enforce_lossless_conversion) \
306  template <> \
307  vector<T> ArgumentHelper::GetRepeatedArgument<T>( \
308  const string& name, const std::vector<T>& default_value) const { \
309  if (arg_map_.count(name) == 0) { \
310  return default_value; \
311  } \
312  vector<T> values; \
313  for (const auto& v : arg_map_.at(name).fieldname()) { \
314  if (enforce_lossless_conversion) { \
315  auto supportsConversion = \
316  SupportsLosslessConversion<decltype(v), T>(v); \
317  CAFFE_ENFORCE( \
318  supportsConversion, \
319  "Value", \
320  v, \
321  " of argument ", \
322  name, \
323  "cannot be represented correctly in a target type"); \
324  } \
325  values.push_back(static_cast<T>(v)); \
326  } \
327  return values; \
328  }
329 
330 INSTANTIATE_GET_REPEATED_ARGUMENT(float, floats, false)
331 INSTANTIATE_GET_REPEATED_ARGUMENT(double, floats, false)
332 INSTANTIATE_GET_REPEATED_ARGUMENT(bool, ints, false)
333 INSTANTIATE_GET_REPEATED_ARGUMENT(int8_t, ints, true)
334 INSTANTIATE_GET_REPEATED_ARGUMENT(int16_t, ints, true)
335 INSTANTIATE_GET_REPEATED_ARGUMENT(int, ints, true)
336 INSTANTIATE_GET_REPEATED_ARGUMENT(int64_t, ints, true)
337 INSTANTIATE_GET_REPEATED_ARGUMENT(uint8_t, ints, true)
338 INSTANTIATE_GET_REPEATED_ARGUMENT(uint16_t, ints, true)
339 INSTANTIATE_GET_REPEATED_ARGUMENT(size_t, ints, true)
340 INSTANTIATE_GET_REPEATED_ARGUMENT(string, strings, false)
341 INSTANTIATE_GET_REPEATED_ARGUMENT(NetDef, nets, false)
342 #undef INSTANTIATE_GET_REPEATED_ARGUMENT
343 
344 #define CAFFE2_MAKE_SINGULAR_ARGUMENT(T, fieldname) \
345 template <> \
346 Argument MakeArgument(const string& name, const T& value) { \
347  Argument arg; \
348  arg.set_name(name); \
349  arg.set_##fieldname(value); \
350  return arg; \
351 }
352 
353 CAFFE2_MAKE_SINGULAR_ARGUMENT(bool, i)
354 CAFFE2_MAKE_SINGULAR_ARGUMENT(float, f)
355 CAFFE2_MAKE_SINGULAR_ARGUMENT(int, i)
356 CAFFE2_MAKE_SINGULAR_ARGUMENT(int64_t, i)
357 CAFFE2_MAKE_SINGULAR_ARGUMENT(string, s)
358 #undef CAFFE2_MAKE_SINGULAR_ARGUMENT
359 
360 template <>
361 Argument MakeArgument(const string& name, const MessageLite& value) {
362  Argument arg;
363  arg.set_name(name);
364  arg.set_s(value.SerializeAsString());
365  return arg;
366 }
367 
368 #define CAFFE2_MAKE_REPEATED_ARGUMENT(T, fieldname) \
369 template <> \
370 Argument MakeArgument(const string& name, const vector<T>& value) { \
371  Argument arg; \
372  arg.set_name(name); \
373  for (const auto& v : value) { \
374  arg.add_##fieldname(v); \
375  } \
376  return arg; \
377 }
378 
379 CAFFE2_MAKE_REPEATED_ARGUMENT(float, floats)
380 CAFFE2_MAKE_REPEATED_ARGUMENT(int, ints)
381 CAFFE2_MAKE_REPEATED_ARGUMENT(int64_t, ints)
382 CAFFE2_MAKE_REPEATED_ARGUMENT(string, strings)
383 #undef CAFFE2_MAKE_REPEATED_ARGUMENT
384 
385 bool HasOutput(const OperatorDef& op, const std::string& output) {
386  for (const auto& outp : op.output()) {
387  if (outp == output) {
388  return true;
389  }
390  }
391  return false;
392 }
393 
394 bool HasInput(const OperatorDef& op, const std::string& input) {
395  for (const auto& inp : op.input()) {
396  if (inp == input) {
397  return true;
398  }
399  }
400  return false;
401 }
402 
403 const Argument& GetArgument(const OperatorDef& def, const string& name) {
404  for (const Argument& arg : def.arg()) {
405  if (arg.name() == name) {
406  return arg;
407  }
408  }
409  CAFFE_THROW(
410  "Argument named ",
411  name,
412  " does not exist in operator ",
413  ProtoDebugString(def));
414 }
415 
416 bool GetFlagArgument(
417  const OperatorDef& def,
418  const string& name,
419  bool def_value) {
420  for (const Argument& arg : def.arg()) {
421  if (arg.name() == name) {
422  CAFFE_ENFORCE(
423  arg.has_i(), "Can't parse argument as bool: ", ProtoDebugString(arg));
424  return arg.i();
425  }
426  }
427  return def_value;
428 }
429 
430 Argument* GetMutableArgument(
431  const string& name,
432  const bool create_if_missing,
433  OperatorDef* def) {
434  for (int i = 0; i < def->arg_size(); ++i) {
435  if (def->arg(i).name() == name) {
436  return def->mutable_arg(i);
437  }
438  }
439  // If no argument of the right name is found...
440  if (create_if_missing) {
441  Argument* arg = def->add_arg();
442  arg->set_name(name);
443  return arg;
444  } else {
445  return nullptr;
446  }
447 }
448 
449 } // namespace caffe2
Definition: types.h:88
Copyright (c) 2016-present, Facebook, Inc.