Caffe2 - C++ API
A deep learning, cross platform ML framework
proto_utils.cc
1 #include "caffe2/utils/proto_utils.h"
2 
3 #include <c10/core/DeviceType.h>
4 
5 #include <fcntl.h>
6 #include <cerrno>
7 #include <fstream>
8 #include <unordered_set>
9 
10 #include <google/protobuf/io/coded_stream.h>
11 
12 #ifndef CAFFE2_USE_LITE_PROTO
13 #include <google/protobuf/text_format.h>
14 #include <google/protobuf/io/zero_copy_stream_impl.h>
15 #else
16 #include <google/protobuf/io/zero_copy_stream_impl_lite.h>
17 #endif // !CAFFE2_USE_LITE_PROTO
18 
19 #include "caffe2/core/logging.h"
20 
21 using ::google::protobuf::MessageLite;
22 
23 namespace caffe2 {
24 
25 C10_EXPORT std::string DeviceTypeName(const int32_t& d) {
26  return at::DeviceTypeName(static_cast<at::DeviceType>(d));
27 }
28 
29 C10_EXPORT int DeviceId(const DeviceOption& option) {
30  switch (option.device_type()) {
31  case PROTO_CPU:
32  return option.numa_node_id();
33  case PROTO_CUDA:
34  case PROTO_HIP:
35  return option.device_id();
36  case PROTO_MKLDNN:
37  return option.numa_node_id();
38  default:
39  CAFFE_THROW("Unknown device id for device type: ", option.device_type());
40  }
41 }
42 
43 C10_EXPORT bool IsSameDevice(const DeviceOption& lhs, const DeviceOption& rhs) {
44  return (
45  lhs.device_type() == rhs.device_type() &&
46  lhs.device_id() == rhs.device_id() &&
47  lhs.node_name() == rhs.node_name() &&
48  lhs.numa_node_id() == rhs.numa_node_id());
49 }
50 
51 C10_EXPORT bool IsCPUDeviceType(int device_type) {
52  static const std::unordered_set<int> cpu_types{
53  PROTO_CPU,
54  PROTO_MKLDNN,
55  PROTO_IDEEP,
56  PROTO_ONLY_FOR_TEST,
57  };
58  return cpu_types.count(device_type);
59 }
60 
61 C10_EXPORT bool IsGPUDeviceType(int device_type) {
62  static const std::unordered_set<int> gpu_types{
63  PROTO_CUDA,
64  PROTO_HIP,
65  };
66  return gpu_types.count(device_type);
67 }
68 
69 C10_EXPORT bool ReadStringFromFile(const char* filename, string* str) {
70  std::ifstream ifs(filename, std::ios::in);
71  if (!ifs) {
72  VLOG(1) << "File cannot be opened: " << filename
73  << " error: " << ifs.rdstate();
74  return false;
75  }
76  ifs.seekg(0, std::ios::end);
77  size_t n = ifs.tellg();
78  str->resize(n);
79  ifs.seekg(0);
80  ifs.read(&(*str)[0], n);
81  return true;
82 }
83 
84 C10_EXPORT bool WriteStringToFile(const string& str, const char* filename) {
85  std::ofstream ofs(filename, std::ios::out | std::ios::trunc);
86  if (!ofs.is_open()) {
87  VLOG(1) << "File cannot be created: " << filename
88  << " error: " << ofs.rdstate();
89  return false;
90  }
91  ofs << str;
92  return true;
93 }
94 
95 // IO-specific proto functions: we will deal with the protocol buffer lite and
96 // full versions differently.
97 
98 #ifdef CAFFE2_USE_LITE_PROTO
99 
100 // Lite runtime.
101 
102 namespace {
103 class IfstreamInputStream : public ::google::protobuf::io::CopyingInputStream {
104  public:
105  explicit IfstreamInputStream(const string& filename)
106  : ifs_(filename.c_str(), std::ios::in | std::ios::binary) {}
107  ~IfstreamInputStream() { ifs_.close(); }
108 
109  int Read(void* buffer, int size) {
110  if (!ifs_) {
111  return -1;
112  }
113  ifs_.read(static_cast<char*>(buffer), size);
114  return ifs_.gcount();
115  }
116 
117  private:
118  std::ifstream ifs_;
119 };
120 } // namespace
121 
122 C10_EXPORT string ProtoDebugString(const MessageLite& proto) {
123  string serialized = proto.SerializeAsString();
124  for (char& c : serialized) {
125  if (c < 0x20 || c >= 0x7f) {
126  c = '?';
127  }
128  }
129  return serialized;
130 }
131 
132 C10_EXPORT bool ParseProtoFromLargeString(
133  const string& str,
134  MessageLite* proto) {
135  ::google::protobuf::io::ArrayInputStream input_stream(str.data(), str.size());
136  ::google::protobuf::io::CodedInputStream coded_stream(&input_stream);
137  // Set PlanDef message size limit to 2G.
138  coded_stream.SetTotalBytesLimit(2147483647, 512LL << 20);
139  return proto->ParseFromCodedStream(&coded_stream);
140 }
141 
142 C10_EXPORT bool ReadProtoFromBinaryFile(
143  const char* filename,
144  MessageLite* proto) {
145  ::google::protobuf::io::CopyingInputStreamAdaptor stream(
146  new IfstreamInputStream(filename));
147  stream.SetOwnsCopyingStream(true);
148  // Total bytes hard limit / warning limit are set to 2GB and 512MB
149  // respectively.
150  ::google::protobuf::io::CodedInputStream coded_stream(&stream);
151  coded_stream.SetTotalBytesLimit(2147483647, 512LL << 20);
152  return proto->ParseFromCodedStream(&coded_stream);
153 }
154 
155 C10_EXPORT void WriteProtoToBinaryFile(
156  const MessageLite& /*proto*/,
157  const char* /*filename*/) {
158  LOG(FATAL) << "Not implemented yet.";
159 }
160 
161 #else // CAFFE2_USE_LITE_PROTO
162 
163 // Full protocol buffer.
164 
165 using ::google::protobuf::io::FileInputStream;
166 using ::google::protobuf::io::FileOutputStream;
167 using ::google::protobuf::io::ZeroCopyInputStream;
168 using ::google::protobuf::io::CodedInputStream;
169 using ::google::protobuf::io::ZeroCopyOutputStream;
170 using ::google::protobuf::io::CodedOutputStream;
171 using ::google::protobuf::Message;
172 
173 namespace TextFormat {
174 C10_EXPORT bool ParseFromString(const string& spec, Message* proto) {
175  string bc_spec = spec;
176 
177  {
178  auto num_replaced = c10::ReplaceAll(bc_spec, "cuda_gpu_id", "device_id");
179  if (num_replaced) {
180  LOG(ERROR) << "Your model was serialized in Protobuf TextFormat and "
181  << "it has "
182  << num_replaced
183  << " places using the deprecated field name 'cuda_gpu_id'!\n"
184  << spec
185  << "\nPlease re-export your model in Protobuf binary format "
186  << "to make it backward compatible for field renaming.";
187  }
188  }
189 
190  return ::google::protobuf::TextFormat::ParseFromString(std::move(bc_spec), proto);
191 }
192 } // namespace TextFormat
193 
194 C10_EXPORT string ProtoDebugString(const Message& proto) {
195  return proto.ShortDebugString();
196 }
197 
198 C10_EXPORT bool ParseProtoFromLargeString(const string& str, Message* proto) {
199  ::google::protobuf::io::ArrayInputStream input_stream(str.data(), str.size());
200  ::google::protobuf::io::CodedInputStream coded_stream(&input_stream);
201  // Set PlanDef message size limit to 2G.
202  coded_stream.SetTotalBytesLimit(2147483647, 512LL << 20);
203  return proto->ParseFromCodedStream(&coded_stream);
204 }
205 
206 C10_EXPORT bool ReadProtoFromTextFile(const char* filename, Message* proto) {
207  int fd = open(filename, O_RDONLY);
208  CAFFE_ENFORCE_NE(fd, -1, "File not found: ", filename);
209  FileInputStream* input = new FileInputStream(fd);
210  bool success = google::protobuf::TextFormat::Parse(input, proto);
211  delete input;
212  close(fd);
213  return success;
214 }
215 
216 C10_EXPORT void WriteProtoToTextFile(
217  const Message& proto,
218  const char* filename) {
219  int fd = open(filename, O_WRONLY | O_CREAT | O_TRUNC, 0644);
220  FileOutputStream* output = new FileOutputStream(fd);
221  CAFFE_ENFORCE(google::protobuf::TextFormat::Print(proto, output));
222  delete output;
223  close(fd);
224 }
225 
226 C10_EXPORT bool ReadProtoFromBinaryFile(
227  const char* filename,
228  MessageLite* proto) {
229 #if defined (_MSC_VER) // for MSC compiler binary flag needs to be specified
230  int fd = open(filename, O_RDONLY | O_BINARY);
231 #else
232  int fd = open(filename, O_RDONLY);
233 #endif
234  CAFFE_ENFORCE_NE(fd, -1, "File not found: ", filename);
235  std::unique_ptr<ZeroCopyInputStream> raw_input(new FileInputStream(fd));
236  std::unique_ptr<CodedInputStream> coded_input(
237  new CodedInputStream(raw_input.get()));
238  // A hack to manually allow using very large protocol buffers.
239  coded_input->SetTotalBytesLimit(2147483647, 536870912);
240  bool success = proto->ParseFromCodedStream(coded_input.get());
241  coded_input.reset();
242  raw_input.reset();
243  close(fd);
244  return success;
245 }
246 
247 C10_EXPORT void WriteProtoToBinaryFile(
248  const MessageLite& proto,
249  const char* filename) {
250  int fd = open(filename, O_WRONLY | O_CREAT | O_TRUNC, 0644);
251  CAFFE_ENFORCE_NE(
252  fd, -1, "File cannot be created: ", filename, " error number: ", errno);
253  std::unique_ptr<ZeroCopyOutputStream> raw_output(new FileOutputStream(fd));
254  std::unique_ptr<CodedOutputStream> coded_output(
255  new CodedOutputStream(raw_output.get()));
256  CAFFE_ENFORCE(proto.SerializeToCodedStream(coded_output.get()));
257  coded_output.reset();
258  raw_output.reset();
259  close(fd);
260 }
261 
262 #endif // CAFFE2_USE_LITE_PROTO
263 
264 C10_EXPORT ArgumentHelper::ArgumentHelper(const OperatorDef& def) {
265  for (auto& arg : def.arg()) {
266  if (arg_map_.count(arg.name())) {
267  if (arg.SerializeAsString() != arg_map_[arg.name()].SerializeAsString()) {
268  // If there are two arguments of the same name but different contents,
269  // we will throw an error.
270  CAFFE_THROW(
271  "Found argument of the same name ",
272  arg.name(),
273  "but with different contents.",
274  ProtoDebugString(def));
275  } else {
276  LOG(WARNING) << "Duplicated argument name [" << arg.name()
277  << "] found in operator def: "
278  << ProtoDebugString(def);
279  }
280  }
281  arg_map_[arg.name()] = arg;
282  }
283 }
284 
285 C10_EXPORT ArgumentHelper::ArgumentHelper(const NetDef& netdef) {
286  for (auto& arg : netdef.arg()) {
287  CAFFE_ENFORCE(
288  arg_map_.count(arg.name()) == 0,
289  "Duplicated argument name [", arg.name(), "] found in net def: ",
290  ProtoDebugString(netdef));
291  arg_map_[arg.name()] = arg;
292  }
293 }
294 
295 C10_EXPORT bool ArgumentHelper::HasArgument(const string& name) const {
296  return arg_map_.count(name);
297 }
298 
namespace {
// Checks that `value` survives a round trip through TargetType: it is cast
// to TargetType, cast back to InputType, and compared with the original.
// Returns true when the round-tripped value compares equal to `value`.
template <typename InputType, typename TargetType>
bool SupportsLosslessConversion(const InputType& value) {
  const TargetType converted = static_cast<TargetType>(value);
  // The cast back and the comparison stay in one expression so that any
  // temporary created by the cast lives for the whole comparison.
  return static_cast<InputType>(converted) == value;
}
} // namespace
307 
308 bool operator==(const NetDef& l, const NetDef& r) {
309  return l.SerializeAsString() == r.SerializeAsString();
310 }
311 
312 std::ostream& operator<<(std::ostream& output, const NetDef& n) {
313  output << n.SerializeAsString();
314  return output;
315 }
316 
// Defines, for a C++ type T stored in the Argument field `fieldname`, the
// explicit specializations of:
//   - ArgumentHelper::GetSingleArgument<T>(name, default_value):
//       returns default_value (logging at VLOG(1)) when no argument called
//       `name` exists; otherwise enforces that the argument has `fieldname`
//       set and returns that field cast to T. When
//       enforce_lossless_conversion is true, it additionally enforces that
//       the value survives the round trip through T
//       (SupportsLosslessConversion).
//   - ArgumentHelper::HasSingleArgumentOfType<T>(name):
//       true when the argument exists and has `fieldname` set.
// NOTE(review): in the lossless-conversion enforce message the pieces "Value"
// and "cannot be represented ..." have no separating space next to the
// streamed value/name, so the rendered message runs words together.
#define INSTANTIATE_GET_SINGLE_ARGUMENT( \
    T, fieldname, enforce_lossless_conversion) \
  template <> \
  C10_EXPORT T ArgumentHelper::GetSingleArgument<T>( \
      const string& name, const T& default_value) const { \
    if (arg_map_.count(name) == 0) { \
      VLOG(1) << "Using default parameter value " << default_value \
              << " for parameter " << name; \
      return default_value; \
    } \
    CAFFE_ENFORCE( \
        arg_map_.at(name).has_##fieldname(), \
        "Argument ", \
        name, \
        " does not have the right field: expected field " #fieldname); \
    auto value = arg_map_.at(name).fieldname(); \
    if (enforce_lossless_conversion) { \
      auto supportsConversion = \
          SupportsLosslessConversion<decltype(value), T>(value); \
      CAFFE_ENFORCE( \
          supportsConversion, \
          "Value", \
          value, \
          " of argument ", \
          name, \
          "cannot be represented correctly in a target type"); \
    } \
    return static_cast<T>(value); \
  } \
  template <> \
  C10_EXPORT bool ArgumentHelper::HasSingleArgumentOfType<T>( \
      const string& name) const { \
    if (arg_map_.count(name) == 0) { \
      return false; \
    } \
    return arg_map_.at(name).has_##fieldname(); \
  }

// Floating-point types read field `f`, integral types (and bool) read field
// `i`, strings read `s` and NetDefs read `n`. The lossless-conversion check
// is enabled for the integer types only, not for bool.
INSTANTIATE_GET_SINGLE_ARGUMENT(float, f, false)
INSTANTIATE_GET_SINGLE_ARGUMENT(double, f, false)
INSTANTIATE_GET_SINGLE_ARGUMENT(bool, i, false)
INSTANTIATE_GET_SINGLE_ARGUMENT(int8_t, i, true)
INSTANTIATE_GET_SINGLE_ARGUMENT(int16_t, i, true)
INSTANTIATE_GET_SINGLE_ARGUMENT(int, i, true)
INSTANTIATE_GET_SINGLE_ARGUMENT(int64_t, i, true)
INSTANTIATE_GET_SINGLE_ARGUMENT(uint8_t, i, true)
INSTANTIATE_GET_SINGLE_ARGUMENT(uint16_t, i, true)
INSTANTIATE_GET_SINGLE_ARGUMENT(size_t, i, true)
INSTANTIATE_GET_SINGLE_ARGUMENT(string, s, false)
INSTANTIATE_GET_SINGLE_ARGUMENT(NetDef, n, false)
#undef INSTANTIATE_GET_SINGLE_ARGUMENT
368 
// Defines, for an element type T stored in the repeated Argument field
// `fieldname`, the explicit specialization of
// ArgumentHelper::GetRepeatedArgument<T>(name, default_value):
//   - returns default_value when no argument called `name` exists;
//   - otherwise returns a vector holding every element of `fieldname` cast
//     to T, in order. When enforce_lossless_conversion is true, each element
//     must survive the round trip through T (SupportsLosslessConversion) or
//     CAFFE_ENFORCE fails.
// Here `v` is a const reference, so decltype(v) passed to
// SupportsLosslessConversion is a reference type.
#define INSTANTIATE_GET_REPEATED_ARGUMENT( \
    T, fieldname, enforce_lossless_conversion) \
  template <> \
  C10_EXPORT vector<T> ArgumentHelper::GetRepeatedArgument<T>( \
      const string& name, const std::vector<T>& default_value) const { \
    if (arg_map_.count(name) == 0) { \
      return default_value; \
    } \
    vector<T> values; \
    for (const auto& v : arg_map_.at(name).fieldname()) { \
      if (enforce_lossless_conversion) { \
        auto supportsConversion = \
            SupportsLosslessConversion<decltype(v), T>(v); \
        CAFFE_ENFORCE( \
            supportsConversion, \
            "Value", \
            v, \
            " of argument ", \
            name, \
            "cannot be represented correctly in a target type"); \
      } \
      values.push_back(static_cast<T>(v)); \
    } \
    return values; \
  }

// Floating-point types read `floats`, integral types (and bool) read `ints`,
// strings read `strings` and NetDefs read `nets`. The lossless-conversion
// check is enabled for the integer types only, not for bool.
INSTANTIATE_GET_REPEATED_ARGUMENT(float, floats, false)
INSTANTIATE_GET_REPEATED_ARGUMENT(double, floats, false)
INSTANTIATE_GET_REPEATED_ARGUMENT(bool, ints, false)
INSTANTIATE_GET_REPEATED_ARGUMENT(int8_t, ints, true)
INSTANTIATE_GET_REPEATED_ARGUMENT(int16_t, ints, true)
INSTANTIATE_GET_REPEATED_ARGUMENT(int, ints, true)
INSTANTIATE_GET_REPEATED_ARGUMENT(int64_t, ints, true)
INSTANTIATE_GET_REPEATED_ARGUMENT(uint8_t, ints, true)
INSTANTIATE_GET_REPEATED_ARGUMENT(uint16_t, ints, true)
INSTANTIATE_GET_REPEATED_ARGUMENT(size_t, ints, true)
INSTANTIATE_GET_REPEATED_ARGUMENT(string, strings, false)
INSTANTIATE_GET_REPEATED_ARGUMENT(NetDef, nets, false)
#undef INSTANTIATE_GET_REPEATED_ARGUMENT
408 
// Defines the MakeArgument specialization for a single value of type T:
// builds an Argument whose name is `name` and whose `fieldname` field is set
// to `value`.
#define CAFFE2_MAKE_SINGULAR_ARGUMENT(T, fieldname) \
  template <> \
  C10_EXPORT Argument MakeArgument(const string& name, const T& value) { \
    Argument arg; \
    arg.set_name(name); \
    arg.set_##fieldname(value); \
    return arg; \
  }

// bool, int and int64_t are stored in field `i`, float in `f`, string in `s`.
CAFFE2_MAKE_SINGULAR_ARGUMENT(bool, i)
CAFFE2_MAKE_SINGULAR_ARGUMENT(float, f)
CAFFE2_MAKE_SINGULAR_ARGUMENT(int, i)
CAFFE2_MAKE_SINGULAR_ARGUMENT(int64_t, i)
CAFFE2_MAKE_SINGULAR_ARGUMENT(string, s)
#undef CAFFE2_MAKE_SINGULAR_ARGUMENT
424 
// Declarations of the explicit specializations of
// ArgumentHelper::RemoveArgument for OperatorDef and NetDef. Only the
// declarations appear here; the definitions are not visible at this point.
// NOTE(review): only the OperatorDef specialization carries C10_EXPORT --
// confirm whether the NetDef one is meant to be exported as well.
template <>
C10_EXPORT bool ArgumentHelper::RemoveArgument(OperatorDef& def, int index);
template <>
bool ArgumentHelper::RemoveArgument(NetDef& def, int index);
429 
430 template <>
431 C10_EXPORT Argument MakeArgument(const string& name, const MessageLite& value) {
432  Argument arg;
433  arg.set_name(name);
434  arg.set_s(value.SerializeAsString());
435  return arg;
436 }
437 
// Defines the MakeArgument specialization for a vector<T>: builds an
// Argument whose name is `name` and appends every element of `value`, in
// order, to the repeated field `fieldname`.
#define CAFFE2_MAKE_REPEATED_ARGUMENT(T, fieldname) \
  template <> \
  C10_EXPORT Argument MakeArgument( \
      const string& name, const vector<T>& value) { \
    Argument arg; \
    arg.set_name(name); \
    for (const auto& v : value) { \
      arg.add_##fieldname(v); \
    } \
    return arg; \
  }

// float goes to `floats`, int and int64_t to `ints`, string to `strings`.
CAFFE2_MAKE_REPEATED_ARGUMENT(float, floats)
CAFFE2_MAKE_REPEATED_ARGUMENT(int, ints)
CAFFE2_MAKE_REPEATED_ARGUMENT(int64_t, ints)
CAFFE2_MAKE_REPEATED_ARGUMENT(string, strings)
#undef CAFFE2_MAKE_REPEATED_ARGUMENT
455 
456 C10_EXPORT bool HasOutput(const OperatorDef& op, const std::string& output) {
457  for (const auto& outp : op.output()) {
458  if (outp == output) {
459  return true;
460  }
461  }
462  return false;
463 }
464 
465 C10_EXPORT bool HasInput(const OperatorDef& op, const std::string& input) {
466  for (const auto& inp : op.input()) {
467  if (inp == input) {
468  return true;
469  }
470  }
471  return false;
472 }
473 
474 // Return the argument index or -1 if it does not exist.
475 C10_EXPORT int GetArgumentIndex(
476  const google::protobuf::RepeatedPtrField<Argument>& args,
477  const string& name) {
478  int index = 0;
479  for (const Argument& arg : args) {
480  if (arg.name() == name) {
481  return index;
482  }
483  index++;
484  }
485  return -1;
486 }
487 
488 C10_EXPORT const Argument& GetArgument(
489  const OperatorDef& def,
490  const string& name) {
491  int index = GetArgumentIndex(def.arg(), name);
492  if (index != -1) {
493  return def.arg(index);
494  } else {
495  CAFFE_THROW(
496  "Argument named ",
497  name,
498  " does not exist in operator ",
499  ProtoDebugString(def));
500  }
501 }
502 
503 C10_EXPORT const Argument& GetArgument(const NetDef& def, const string& name) {
504  int index = GetArgumentIndex(def.arg(), name);
505  if (index != -1) {
506  return def.arg(index);
507  } else {
508  CAFFE_THROW(
509  "Argument named ",
510  name,
511  " does not exist in net ",
512  ProtoDebugString(def));
513  }
514 }
515 
516 C10_EXPORT bool GetFlagArgument(
517  const google::protobuf::RepeatedPtrField<Argument>& args,
518  const string& name,
519  bool default_value) {
520  int index = GetArgumentIndex(args, name);
521  if (index != -1) {
522  auto arg = args.Get(index);
523  CAFFE_ENFORCE(
524  arg.has_i(), "Can't parse argument as bool: ", ProtoDebugString(arg));
525  return arg.i();
526  }
527  return default_value;
528 }
529 
530 C10_EXPORT bool GetFlagArgument(
531  const OperatorDef& def,
532  const string& name,
533  bool default_value) {
534  return GetFlagArgument(def.arg(), name, default_value);
535 }
536 
537 C10_EXPORT bool
538 GetFlagArgument(const NetDef& def, const string& name, bool default_value) {
539  return GetFlagArgument(def.arg(), name, default_value);
540 }
541 
542 C10_EXPORT Argument* GetMutableArgument(
543  const string& name,
544  const bool create_if_missing,
545  OperatorDef* def) {
546  for (int i = 0; i < def->arg_size(); ++i) {
547  if (def->arg(i).name() == name) {
548  return def->mutable_arg(i);
549  }
550  }
551  // If no argument of the right name is found...
552  if (create_if_missing) {
553  Argument* arg = def->add_arg();
554  arg->set_name(name);
555  return arg;
556  } else {
557  return nullptr;
558  }
559 }
560 
561 } // namespace caffe2
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13