1 #include "caffe2/utils/proto_utils.h" 3 #include <c10/core/DeviceType.h>
#include <type_traits>
8 #include <unordered_set> 10 #include <google/protobuf/io/coded_stream.h> 12 #ifndef CAFFE2_USE_LITE_PROTO 13 #include <google/protobuf/text_format.h> 14 #include <google/protobuf/io/zero_copy_stream_impl.h> 16 #include <google/protobuf/io/zero_copy_stream_impl_lite.h> 17 #endif // !CAFFE2_USE_LITE_PROTO 19 #include "caffe2/core/logging.h" 21 using ::google::protobuf::MessageLite;
// Returns the human-readable name of a device type passed as a raw int32
// code: the code is cast to at::DeviceType and the lookup is delegated to
// at::DeviceTypeName.
25 C10_EXPORT std::string DeviceTypeName(
const int32_t& d) {
26 return at::DeviceTypeName(static_cast<at::DeviceType>(d));
// Returns the device index recorded in `option`. Depending on
// option.device_type(), the value is read from either numa_node_id() or
// device_id(); one branch raises CAFFE_THROW with
// "Unknown device id for device type: <type>" (presumably the default branch
// for unrecognised device types).
// NOTE(review): the `case`/`default` labels of this switch are not present in
// this block — confirm which device types select which field.
29 C10_EXPORT
int DeviceId(
const DeviceOption& option) {
30 switch (option.device_type()) {
32 return option.numa_node_id();
35 return option.device_id();
37 return option.numa_node_id();
39 CAFFE_THROW(
"Unknown device id for device type: ", option.device_type());
// Compares two DeviceOptions field by field. The conjunction below is true
// only when device_type, device_id, node_name and numa_node_id all match.
// NOTE(review): the statement that opens this expression (presumably a
// `return (`) is not present in this block.
43 C10_EXPORT
bool IsSameDevice(
const DeviceOption& lhs,
const DeviceOption& rhs) {
45 lhs.device_type() == rhs.device_type() &&
46 lhs.device_id() == rhs.device_id() &&
47 lhs.node_name() == rhs.node_name() &&
48 lhs.numa_node_id() == rhs.numa_node_id());
// Returns whether `device_type` is one of the CPU-class device types held in
// the function-local static set `cpu_types` (initialised once, on first use).
// The size_t result of count() is converted to bool on return.
// NOTE(review): the members of `cpu_types` are not present in this block.
51 C10_EXPORT
bool IsCPUDeviceType(
int device_type) {
52 static const std::unordered_set<int> cpu_types{
58 return cpu_types.count(device_type);
// Returns whether `device_type` is one of the GPU-class device types held in
// the function-local static set `gpu_types` (initialised once, on first use).
// The size_t result of count() is converted to bool on return.
// NOTE(review): the members of `gpu_types` are not present in this block.
61 C10_EXPORT
bool IsGPUDeviceType(
int device_type) {
62 static const std::unordered_set<int> gpu_types{
66 return gpu_types.count(device_type);
// Reads the contents of `filename` into `*str`.
// When the file cannot be opened, this is logged at VLOG(1) together with the
// stream's rdstate(). The guarding condition and the return statements are
// not present in this block.
// NOTE(review): the stream is opened without std::ios::binary, so on
// platforms that translate line endings the number of characters read can be
// smaller than the position reported by tellg().
69 C10_EXPORT
bool ReadStringFromFile(
const char* filename,
string* str) {
70 std::ifstream ifs(filename, std::ios::in);
72 VLOG(1) <<
"File cannot be opened: " << filename
73 <<
" error: " << ifs.rdstate();
// Determine the file size by seeking to the end and querying the position.
// NOTE(review): tellg() returns -1 on failure, which would wrap to a huge
// size_t here.
76 ifs.seekg(0, std::ios::end);
77 size_t n = ifs.tellg();
// NOTE(review): resizing `*str` to n and seeking back to the start are not
// present between tellg() and read() in this block — confirm upstream.
80 ifs.read(&(*str)[0], n);
// Intended to write `str` to `filename`. The file is opened with
// std::ios::trunc, so any existing contents are discarded. When the file
// cannot be created, this is logged at VLOG(1) together with the stream's
// rdstate().
// NOTE(review): the open-failure check, the write of `str` and the return
// statements are not present in this block.
84 C10_EXPORT
bool WriteStringToFile(
const string& str,
const char* filename) {
85 std::ofstream ofs(filename, std::ios::out | std::ios::trunc);
87 VLOG(1) <<
"File cannot be created: " << filename
88 <<
" error: " << ofs.rdstate();
// Lite-proto branch (CAFFE2_USE_LITE_PROTO).
// IfstreamInputStream adapts a std::ifstream opened in binary mode to
// protobuf's CopyingInputStream interface, so a CopyingInputStreamAdaptor can
// read a file through it (as ReadProtoFromBinaryFile in this branch does).
// NOTE(review): no `public:` access specifier is present before the
// constructor in this block; as written, the members would be private.
98 #ifdef CAFFE2_USE_LITE_PROTO 103 class IfstreamInputStream :
public ::google::protobuf::io::CopyingInputStream {
// Opens `filename` for binary reading.
105 explicit IfstreamInputStream(
const string& filename)
106 : ifs_(filename.c_str(),
std::ios::in |
std::ios::binary) {}
// Closes the underlying file stream.
107 ~IfstreamInputStream() { ifs_.close(); }
// Copies up to `size` bytes into `buffer` and returns how many were actually
// read (gcount()); 0 therefore signals end of file.
// NOTE(review): this overrides CopyingInputStream::Read but is not marked
// `override`; any early return for a failed stream is not present here.
109 int Read(
void* buffer,
int size) {
113 ifs_.read(static_cast<char*>(buffer), size);
114 return ifs_.gcount();
// Lite-proto debug string: serializes `proto` to its binary wire format and
// scans the bytes for non-printable ASCII (below 0x20, or 0x7f and above).
// With a signed `char`, bytes >= 0x80 are negative and are caught by
// `c < 0x20`; with an unsigned `char` they are caught by `c >= 0x7f`.
// NOTE(review): what is done with such bytes and the return statement are not
// present in this block — presumably they are replaced so the result prints.
122 C10_EXPORT
string ProtoDebugString(
const MessageLite& proto) {
123 string serialized = proto.SerializeAsString();
124 for (
char& c : serialized) {
125 if (c < 0x20 || c >= 0x7f) {
// Lite-proto: parses `proto` from the in-memory serialized buffer `str` and
// returns the result of ParseFromCodedStream.
// The CodedInputStream's total-bytes limit is raised to 2147483647 (INT_MAX),
// with 512LL << 20 (512 MiB) passed as the second argument.
// NOTE(review): in older protobuf releases the second argument of
// SetTotalBytesLimit is a warning threshold; newer releases ignore it or only
// offer the one-argument form — confirm against the protobuf version in use.
// NOTE(review): the declaration of the `str` parameter is not present here.
132 C10_EXPORT
bool ParseProtoFromLargeString(
134 MessageLite* proto) {
135 ::google::protobuf::io::ArrayInputStream input_stream(str.data(), str.size());
136 ::google::protobuf::io::CodedInputStream coded_stream(&input_stream);
138 coded_stream.SetTotalBytesLimit(2147483647, 512LL << 20);
139 return proto->ParseFromCodedStream(&coded_stream);
// Lite-proto: parses `proto` from the binary file `filename` and returns the
// result of ParseFromCodedStream.
// The heap-allocated IfstreamInputStream is handed to a
// CopyingInputStreamAdaptor; SetOwnsCopyingStream(true) makes the adaptor
// delete it when the adaptor is destroyed, so the raw `new` does not leak.
142 C10_EXPORT
bool ReadProtoFromBinaryFile(
143 const char* filename,
144 MessageLite* proto) {
145 ::google::protobuf::io::CopyingInputStreamAdaptor stream(
146 new IfstreamInputStream(filename));
147 stream.SetOwnsCopyingStream(
true);
// Raise the total-bytes limit to INT_MAX; 512 MiB is passed as the second
// (warning-threshold, in older protobuf) argument.
150 ::google::protobuf::io::CodedInputStream coded_stream(&stream);
151 coded_stream.SetTotalBytesLimit(2147483647, 512LL << 20);
152 return proto->ParseFromCodedStream(&coded_stream);
// Lite-proto: binary writing is not implemented in this build; the body only
// logs "Not implemented yet." via LOG(FATAL).
// NOTE(review): the parameter list is not present in this block.
155 C10_EXPORT
void WriteProtoToBinaryFile(
158 LOG(FATAL) <<
"Not implemented yet.";
161 #else // CAFFE2_USE_LITE_PROTO 165 using ::google::protobuf::io::FileInputStream;
166 using ::google::protobuf::io::FileOutputStream;
167 using ::google::protobuf::io::ZeroCopyInputStream;
168 using ::google::protobuf::io::CodedInputStream;
169 using ::google::protobuf::io::ZeroCopyOutputStream;
170 using ::google::protobuf::io::CodedOutputStream;
171 using ::google::protobuf::Message;
// Backward-compatible text-format parsing.
// Copies `spec`, rewrites every occurrence of the deprecated field name
// "cuda_gpu_id" to "device_id" with c10::ReplaceAll, and parses the rewritten
// text with protobuf's TextFormat::ParseFromString, returning its result.
// An error asking the user to re-export the model in binary format is logged
// via LOG(ERROR). The condition guarding it (presumably num_replaced > 0) is
// not present in this block.
173 namespace TextFormat {
174 C10_EXPORT
bool ParseFromString(
const string& spec, Message* proto) {
175 string bc_spec = spec;
178 auto num_replaced = c10::ReplaceAll(bc_spec,
"cuda_gpu_id",
"device_id");
180 LOG(ERROR) <<
"Your model was serialized in Protobuf TextFormat and " 183 <<
" places using the deprecated field name 'cuda_gpu_id'!\n" 185 <<
"\nPlease re-export your model in Protobuf binary format " 186 <<
"to make it backward compatible for field renaming.";
// NOTE(review): if protobuf's ParseFromString takes its input by const
// reference, this std::move is a harmless no-op — confirm for the protobuf
// version in use.
190 return ::google::protobuf::TextFormat::ParseFromString(std::move(bc_spec), proto);
// Full-proto debug string: returns Message::ShortDebugString(), protobuf's
// compact (reduced-whitespace) text-format rendering of `proto`.
194 C10_EXPORT
string ProtoDebugString(
const Message& proto) {
195 return proto.ShortDebugString();
// Full-proto: parses `proto` from the in-memory serialized buffer `str` and
// returns the result of ParseFromCodedStream.
// The CodedInputStream's total-bytes limit is raised to 2147483647 (INT_MAX),
// with 512LL << 20 (512 MiB) passed as the second argument.
// NOTE(review): in older protobuf releases the second argument of
// SetTotalBytesLimit is a warning threshold; newer releases ignore it or only
// offer the one-argument form — confirm against the protobuf version in use.
198 C10_EXPORT
bool ParseProtoFromLargeString(
const string& str, Message* proto) {
199 ::google::protobuf::io::ArrayInputStream input_stream(str.data(), str.size());
200 ::google::protobuf::io::CodedInputStream coded_stream(&input_stream);
202 coded_stream.SetTotalBytesLimit(2147483647, 512LL << 20);
203 return proto->ParseFromCodedStream(&coded_stream);
// Parses a text-format proto from `filename` into `proto`.
// The file is opened read-only with POSIX open(). If that fails,
// CAFFE_ENFORCE_NE throws with "File not found: <filename>"; the same message
// is used whatever the errno (e.g. also for permission errors).
// NOTE(review): `input` is allocated with a raw `new`. Its deletion, close(fd)
// and the return of `success` are not present in this block — confirm they
// exist (a std::unique_ptr would make the cleanup automatic).
206 C10_EXPORT
bool ReadProtoFromTextFile(
const char* filename, Message* proto) {
207 int fd = open(filename, O_RDONLY);
208 CAFFE_ENFORCE_NE(fd, -1,
"File not found: ", filename);
209 FileInputStream* input =
new FileInputStream(fd);
210 bool success = google::protobuf::TextFormat::Parse(input, proto);
// Writes `proto` in protobuf text format to `filename`. The file is created,
// or truncated if it exists, with permissions 0644.
// CAFFE_ENFORCE throws if TextFormat::Print reports failure.
// NOTE(review): the result of open() is used without a -1 check (unlike
// WriteProtoToBinaryFile), so an open failure is not reported directly. It
// can only surface indirectly through the output stream, if at all.
// Deleting `output` and closing `fd` are not present in this block.
216 C10_EXPORT
void WriteProtoToTextFile(
217 const Message& proto,
218 const char* filename) {
219 int fd = open(filename, O_WRONLY | O_CREAT | O_TRUNC, 0644);
220 FileOutputStream* output =
new FileOutputStream(fd);
221 CAFFE_ENFORCE(google::protobuf::TextFormat::Print(proto, output));
// Full-proto: parses a binary-serialized proto from `filename` into `proto`.
// Throws via CAFFE_ENFORCE_NE with "File not found: <filename>" if open()
// fails. The coded stream's total-bytes limit is raised to 2147483647
// (INT_MAX), with 536870912 (512 MiB) passed as the second argument.
// NOTE(review): in this block, the MSVC variant of open() (with O_BINARY)
// sits on the `#if defined (_MSC_VER)` line after a `//` comment, and no
// `#else`/`#endif` is present — confirm the preprocessor structure.
226 C10_EXPORT
bool ReadProtoFromBinaryFile(
227 const char* filename,
228 MessageLite* proto) {
229 #if defined (_MSC_VER) // for MSC compiler binary flag needs to be specified 230 int fd = open(filename, O_RDONLY | O_BINARY);
232 int fd = open(filename, O_RDONLY);
234 CAFFE_ENFORCE_NE(fd, -1,
"File not found: ", filename);
235 std::unique_ptr<ZeroCopyInputStream> raw_input(
new FileInputStream(fd));
// coded_input is declared after raw_input, so it is destroyed first and never
// outlives the stream it reads from.
236 std::unique_ptr<CodedInputStream> coded_input(
237 new CodedInputStream(raw_input.get()));
239 coded_input->SetTotalBytesLimit(2147483647, 536870912);
// NOTE(review): closing fd and returning `success` are not present in this
// block.
240 bool success = proto->ParseFromCodedStream(coded_input.get());
// Full-proto: serializes `proto` in binary wire format to `filename`. The
// file is created, or truncated if it exists, with permissions 0644.
// An open() failure is reported with the filename and errno; a serialization
// failure throws via CAFFE_ENFORCE.
// NOTE(review): the macro that opens the `fd, -1` check is not present in
// this block. Also, unlike ReadProtoFromBinaryFile, no O_BINARY flag is
// passed for MSVC; on Windows, open() without O_BINARY may use text mode and
// translate newline bytes — confirm.
247 C10_EXPORT
void WriteProtoToBinaryFile(
248 const MessageLite& proto,
249 const char* filename) {
250 int fd = open(filename, O_WRONLY | O_CREAT | O_TRUNC, 0644);
252 fd, -1,
"File cannot be created: ", filename,
" error number: ", errno);
253 std::unique_ptr<ZeroCopyOutputStream> raw_output(
new FileOutputStream(fd));
254 std::unique_ptr<CodedOutputStream> coded_output(
255 new CodedOutputStream(raw_output.get()));
256 CAFFE_ENFORCE(proto.SerializeToCodedStream(coded_output.get()));
// Destroy the CodedOutputStream explicitly, before raw_output, so the
// underlying stream is positioned right after the last byte written before
// any further cleanup.
// NOTE(review): the remaining cleanup (raw_output, close(fd)) is not present
// in this block.
257 coded_output.reset();
// Builds the name -> Argument map from an OperatorDef's arguments.
// When a name repeats and the serialized contents differ, an error is raised
// ("Found argument of the same name ... but with different contents.").
// Otherwise a LOG(WARNING) about the duplicate is emitted. Each argument
// appears to then be stored in arg_map_ under its name.
// NOTE(review): this first line also carries the `#endif` of the
// CAFFE2_USE_LITE_PROTO block. The macro that raises the error and some
// closing braces are not present in this block.
262 #endif // CAFFE2_USE_LITE_PROTO 264 C10_EXPORT ArgumentHelper::ArgumentHelper(
const OperatorDef& def) {
265 for (
auto& arg : def.arg()) {
266 if (arg_map_.count(arg.name())) {
267 if (arg.SerializeAsString() != arg_map_[arg.name()].SerializeAsString()) {
271 "Found argument of the same name ",
273 "but with different contents.",
274 ProtoDebugString(def));
276 LOG(WARNING) <<
"Duplicated argument name [" << arg.name()
277 <<
"] found in operator def: " 278 << ProtoDebugString(def);
281 arg_map_[arg.name()] = arg;
// Builds the name -> Argument map from a NetDef's arguments. Unlike the
// OperatorDef overload, any repeated name is rejected regardless of contents.
// The check `arg_map_.count(arg.name()) == 0` fails with
// "Duplicated argument name [<name>] found in net def: <netdef>".
// NOTE(review): the enforce macro wrapping that check is not present in this
// block.
285 C10_EXPORT ArgumentHelper::ArgumentHelper(
const NetDef& netdef) {
286 for (
auto& arg : netdef.arg()) {
288 arg_map_.count(arg.name()) == 0,
289 "Duplicated argument name [", arg.name(),
"] found in net def: ",
290 ProtoDebugString(netdef));
291 arg_map_[arg.name()] = arg;
// Returns whether an argument named `name` was recorded in arg_map_ when this
// helper was constructed (the size_t count is converted to bool).
295 C10_EXPORT
bool ArgumentHelper::HasArgument(
const string& name)
const {
296 return arg_map_.count(name);
// Tag-dispatched helper for SupportsLosslessConversion. This overload is
// selected when the input type is signed and the target type is unsigned.
// It rejects negative values, which no unsigned type can represent.
template <typename InputType>
bool NonNegativeForUnsignedTarget(const InputType& value, std::true_type) {
  return !(value < static_cast<InputType>(0));
}

// Overload for every other type pair, including non-arithmetic ones such as
// strings or protos: there is no sign to lose, so nothing is rejected here.
template <typename InputType>
bool NonNegativeForUnsignedTarget(
    const InputType& /* value */,
    std::false_type) {
  return true;
}

// Returns true when converting `value` to TargetType loses no information.
//
// The primary check is a round trip InputType -> TargetType -> InputType that
// must reproduce `value`. That check alone misses negative values converted
// to an unsigned type at least as wide as the input. For example,
// int64_t(-1) -> size_t -> int64_t yields -1 again, although the size_t
// holds SIZE_MAX. Such values are therefore rejected up front. The sign check
// runs first (short-circuit), so the casts are never evaluated for a
// rejected value.
//
// InputType may be a (const) reference type, as produced by decltype() on a
// range-for variable. std::decay strips it before the type traits are
// queried.
template <typename InputType, typename TargetType>
bool SupportsLosslessConversion(const InputType& value) {
  typedef typename std::decay<InputType>::type Input;
  typedef typename std::decay<TargetType>::type Target;
  typedef std::integral_constant<
      bool,
      std::is_signed<Input>::value && std::is_unsigned<Target>::value>
      SignCanBeLost;
  return NonNegativeForUnsignedTarget<Input>(value, SignCanBeLost()) &&
      static_cast<InputType>(static_cast<TargetType>(value)) == value;
}
// Two NetDefs compare equal when their binary serializations are identical.
// NOTE(review): protobuf does not promise a canonical serialization (e.g. for
// map fields or unknown fields), so semantically equal messages could compare
// unequal here.
308 bool operator==(
const NetDef& l,
const NetDef& r) {
309 return l.SerializeAsString() == r.SerializeAsString();
// Streams the binary wire-format serialization of `n` into `output`. This is
// not a human-readable rendering.
// NOTE(review): the return statement is not present in this block.
312 std::ostream& operator<<(std::ostream& output,
const NetDef& n) {
313 output << n.SerializeAsString();
// INSTANTIATE_GET_SINGLE_ARGUMENT(T, fieldname, enforce_lossless_conversion)
// defines, for type T:
//  - ArgumentHelper::GetSingleArgument<T>(name, default_value): returns
//    default_value (logged at VLOG(1)) when no argument named `name` exists.
//    Otherwise it enforces that the argument has `fieldname` set and reads
//    it. When enforce_lossless_conversion is true, it also enforces
//    SupportsLosslessConversion<decltype(value), T>(value). Finally it returns
//    static_cast<T>(value).
//  - ArgumentHelper::HasSingleArgumentOfType<T>(name): whether the argument
//    named `name` has `fieldname` set (the branch for a missing name is not
//    present here).
// Because the lossless check is a plain `if` on a constant, the
// SupportsLosslessConversion instantiation is compiled for every T, including
// string and NetDef.
// The instantiations below read integer targets from `i` with the lossless
// check enabled. float and double read `f`, bool reads `i`, string reads `s`
// and NetDef reads `n`, all without the check.
// NOTE(review): the macro definition is collapsed onto one line here, and
// parts of it (e.g. the enforce macro names) are not present — confirm
// against upstream before editing it.
317 #define INSTANTIATE_GET_SINGLE_ARGUMENT( \ 318 T, fieldname, enforce_lossless_conversion) \ 320 C10_EXPORT T ArgumentHelper::GetSingleArgument<T>( \ 321 const string& name, const T& default_value) const { \ 322 if (arg_map_.count(name) == 0) { \ 323 VLOG(1) << "Using default parameter value " << default_value \ 324 << " for parameter " << name; \ 325 return default_value; \ 328 arg_map_.at(name).has_##fieldname(), \ 331 " does not have the right field: expected field " #fieldname); \ 332 auto value = arg_map_.at(name).fieldname(); \ 333 if (enforce_lossless_conversion) { \ 334 auto supportsConversion = \ 335 SupportsLosslessConversion<decltype(value), T>(value); \ 337 supportsConversion, \ 342 "cannot be represented correctly in a target type"); \ 344 return static_cast<T>(value); \ 347 C10_EXPORT bool ArgumentHelper::HasSingleArgumentOfType<T>( \ 348 const string& name) const { \ 349 if (arg_map_.count(name) == 0) { \ 352 return arg_map_.at(name).has_##fieldname(); \ 355 INSTANTIATE_GET_SINGLE_ARGUMENT(
float, f,
false)
356 INSTANTIATE_GET_SINGLE_ARGUMENT(
double, f, false)
357 INSTANTIATE_GET_SINGLE_ARGUMENT(
bool, i, false)
358 INSTANTIATE_GET_SINGLE_ARGUMENT(int8_t, i, true)
359 INSTANTIATE_GET_SINGLE_ARGUMENT(int16_t, i, true)
360 INSTANTIATE_GET_SINGLE_ARGUMENT(
int, i, true)
361 INSTANTIATE_GET_SINGLE_ARGUMENT(int64_t, i, true)
362 INSTANTIATE_GET_SINGLE_ARGUMENT(uint8_t, i, true)
363 INSTANTIATE_GET_SINGLE_ARGUMENT(uint16_t, i, true)
364 INSTANTIATE_GET_SINGLE_ARGUMENT(
size_t, i, true)
365 INSTANTIATE_GET_SINGLE_ARGUMENT(
string, s, false)
366 INSTANTIATE_GET_SINGLE_ARGUMENT(NetDef, n, false)
// INSTANTIATE_GET_REPEATED_ARGUMENT(T, fieldname, enforce_lossless_conversion)
// defines ArgumentHelper::GetRepeatedArgument<T>(name, default_value). It
// returns default_value when no argument named `name` exists. Otherwise it
// walks the repeated `fieldname` field and appends static_cast<T>(v) to
// `values`; when enforce_lossless_conversion is true, each element is first
// checked with SupportsLosslessConversion.
// Integer targets read `ints` with the check enabled. float and double read
// `floats`, bool reads `ints`, string reads `strings` and NetDef reads `nets`,
// all without the check.
// NOTE(review): this line also carries the #undef of
// INSTANTIATE_GET_SINGLE_ARGUMENT. The declaration and return of `values` and
// the enforce macro name are not present in the macro body here.
367 #undef INSTANTIATE_GET_SINGLE_ARGUMENT 369 #define INSTANTIATE_GET_REPEATED_ARGUMENT( \ 370 T, fieldname, enforce_lossless_conversion) \ 372 C10_EXPORT vector<T> ArgumentHelper::GetRepeatedArgument<T>( \ 373 const string& name, const std::vector<T>& default_value) const { \ 374 if (arg_map_.count(name) == 0) { \ 375 return default_value; \ 378 for (const auto& v : arg_map_.at(name).fieldname()) { \ 379 if (enforce_lossless_conversion) { \ 380 auto supportsConversion = \ 381 SupportsLosslessConversion<decltype(v), T>(v); \ 383 supportsConversion, \ 388 "cannot be represented correctly in a target type"); \ 390 values.push_back(static_cast<T>(v)); \ 395 INSTANTIATE_GET_REPEATED_ARGUMENT(
float, floats,
false)
396 INSTANTIATE_GET_REPEATED_ARGUMENT(
double, floats, false)
397 INSTANTIATE_GET_REPEATED_ARGUMENT(
bool, ints, false)
398 INSTANTIATE_GET_REPEATED_ARGUMENT(int8_t, ints, true)
399 INSTANTIATE_GET_REPEATED_ARGUMENT(int16_t, ints, true)
400 INSTANTIATE_GET_REPEATED_ARGUMENT(
int, ints, true)
401 INSTANTIATE_GET_REPEATED_ARGUMENT(int64_t, ints, true)
402 INSTANTIATE_GET_REPEATED_ARGUMENT(uint8_t, ints, true)
403 INSTANTIATE_GET_REPEATED_ARGUMENT(uint16_t, ints, true)
404 INSTANTIATE_GET_REPEATED_ARGUMENT(
size_t, ints, true)
405 INSTANTIATE_GET_REPEATED_ARGUMENT(
string, strings, false)
406 INSTANTIATE_GET_REPEATED_ARGUMENT(NetDef, nets, false)
// CAFFE2_MAKE_SINGULAR_ARGUMENT(T, fieldname) defines
// MakeArgument(name, value) for a scalar T. The resulting Argument gets
// `name`, and its `fieldname` field is set to `value`. bool, int and int64_t
// use `i`, float uses `f`, and string uses `s`.
// NOTE(review): this line also carries the #undef of
// INSTANTIATE_GET_REPEATED_ARGUMENT. The declaration and return of `arg` are
// not present in the macro body here.
407 #undef INSTANTIATE_GET_REPEATED_ARGUMENT 409 #define CAFFE2_MAKE_SINGULAR_ARGUMENT(T, fieldname) \ 411 C10_EXPORT Argument MakeArgument(const string& name, const T& value) { \ 413 arg.set_name(name); \ 414 arg.set_##fieldname(value); \ 418 CAFFE2_MAKE_SINGULAR_ARGUMENT(
bool, i)
419 CAFFE2_MAKE_SINGULAR_ARGUMENT(
float, f)
420 CAFFE2_MAKE_SINGULAR_ARGUMENT(
int, i)
421 CAFFE2_MAKE_SINGULAR_ARGUMENT(int64_t, i)
422 CAFFE2_MAKE_SINGULAR_ARGUMENT(
string, s)
// NOTE(review): besides the #undef of CAFFE2_MAKE_SINGULAR_ARGUMENT, these
// lines hold two out-of-class declarations of ArgumentHelper::RemoveArgument
// (for OperatorDef and NetDef) that end in `;` instead of a body. A member
// function may not be redeclared outside its class unless it is being
// defined, so as written these are ill-formed. If they are meant to be
// explicit instantiations of a RemoveArgument template, they need a leading
// `template` keyword — confirm against the class declaration.
423 #undef CAFFE2_MAKE_SINGULAR_ARGUMENT 426 C10_EXPORT
bool ArgumentHelper::RemoveArgument(OperatorDef& def,
int index);
428 bool ArgumentHelper::RemoveArgument(NetDef& def,
int index);
// Builds an Argument (presumably named `name`) whose string field `s` holds
// the binary serialization of the arbitrary message `value`.
// NOTE(review): the declaration of `arg`, the set_name call and the return
// are not present in this block.
431 C10_EXPORT Argument MakeArgument(
const string& name,
const MessageLite& value) {
434 arg.set_s(value.SerializeAsString());
// CAFFE2_MAKE_REPEATED_ARGUMENT(T, fieldname) defines
// MakeArgument(name, vector<T>). It sets the Argument's name and appends each
// element of `value` via add_<fieldname>. float uses `floats`, int and
// int64_t use `ints`, and string uses `strings`.
// NOTE(review): the declaration and return of `arg` are not present in the
// macro body here.
438 #define CAFFE2_MAKE_REPEATED_ARGUMENT(T, fieldname) \ 440 C10_EXPORT Argument MakeArgument( \ 441 const string& name, const vector<T>& value) { \ 443 arg.set_name(name); \ 444 for (const auto& v : value) { \ 445 arg.add_##fieldname(v); \ 450 CAFFE2_MAKE_REPEATED_ARGUMENT(
float, floats)
451 CAFFE2_MAKE_REPEATED_ARGUMENT(
int, ints)
452 CAFFE2_MAKE_REPEATED_ARGUMENT(int64_t, ints)
453 CAFFE2_MAKE_REPEATED_ARGUMENT(
string, strings)
// Scans op.output() for an entry equal to `output` (presumably returning true
// on the first match and false otherwise).
// NOTE(review): this line also carries the #undef of
// CAFFE2_MAKE_REPEATED_ARGUMENT. The return statements are not present in
// this block.
454 #undef CAFFE2_MAKE_REPEATED_ARGUMENT 456 C10_EXPORT
bool HasOutput(
const OperatorDef& op,
const std::string& output) {
457 for (
const auto& outp : op.output()) {
458 if (outp == output) {
// Scans op.input(), presumably for an entry equal to `input`.
// NOTE(review): the comparison and return statements are not present in this
// block.
465 C10_EXPORT
bool HasInput(
const OperatorDef& op,
const std::string& input) {
466 for (
const auto& inp : op.input()) {
// Linear scan of `args` for an Argument whose name equals `name`.
// NOTE(review): the index counter and both return statements are not present
// in this block. Callers index def.arg() with the result, which suggests it
// returns the position of the match, plus a sentinel when there is none —
// confirm.
475 C10_EXPORT
int GetArgumentIndex(
476 const google::protobuf::RepeatedPtrField<Argument>& args,
477 const string& name) {
479 for (
const Argument& arg : args) {
480 if (arg.name() == name) {
// Returns the argument named `name` from `def`, looked up via
// GetArgumentIndex. The failure path (presumably taken when no such argument
// exists) raises an error whose message includes
// " does not exist in operator " and ProtoDebugString(def).
// NOTE(review): the guard on `index` and the raising macro are not present in
// this block.
488 C10_EXPORT
const Argument& GetArgument(
489 const OperatorDef& def,
490 const string& name) {
491 int index = GetArgumentIndex(def.arg(), name);
493 return def.arg(index);
498 " does not exist in operator ",
499 ProtoDebugString(def));
// Returns the argument named `name` from the NetDef `def`, looked up via
// GetArgumentIndex. The failure path (presumably taken when no such argument
// exists) raises an error whose message includes " does not exist in net "
// and ProtoDebugString(def).
// NOTE(review): the guard on `index` and the raising macro are not present in
// this block.
503 C10_EXPORT
const Argument& GetArgument(
const NetDef& def,
const string& name) {
504 int index = GetArgumentIndex(def.arg(), name);
506 return def.arg(index);
511 " does not exist in net ",
512 ProtoDebugString(def));
// Reads a boolean flag named `name` from `args`. A found argument must have
// its integer field `i` set, otherwise "Can't parse argument as bool: <arg>"
// is raised. default_value is returned in the other branch.
// NOTE(review): the `name` parameter, the branch on `index`, the enforce
// macro and the return of the flag value are not present in this block.
516 C10_EXPORT
bool GetFlagArgument(
517 const google::protobuf::RepeatedPtrField<Argument>& args,
519 bool default_value) {
520 int index = GetArgumentIndex(args, name);
// NOTE(review): `auto` deduces a by-value Argument here, copying the message;
// `const auto&` would avoid the copy.
522 auto arg = args.Get(index);
524 arg.has_i(),
"Can't parse argument as bool: ", ProtoDebugString(arg));
527 return default_value;
// Convenience overload for OperatorDef: forwards def.arg(), `name` and
// default_value to the RepeatedPtrField overload.
// NOTE(review): the `name` parameter declaration is not present in this block.
530 C10_EXPORT
bool GetFlagArgument(
531 const OperatorDef& def,
533 bool default_value) {
534 return GetFlagArgument(def.arg(), name, default_value);
// Convenience overload for NetDef: forwards def.arg(), `name` and
// default_value to the RepeatedPtrField overload.
// NOTE(review): the export macro and return type that precede this line are
// not present in this block.
538 GetFlagArgument(
const NetDef& def,
const string& name,
bool default_value) {
539 return GetFlagArgument(def.arg(), name, default_value);
// Returns a mutable pointer to the argument named `name` in `def`, searching
// def->arg() linearly. If none matches and create_if_missing is true, a new
// Argument is appended with def->add_arg().
// NOTE(review): several parameter declarations and the rest of the body
// (including the result when the argument is missing and create_if_missing
// is false) are not present here.
542 C10_EXPORT Argument* GetMutableArgument(
544 const bool create_if_missing,
546 for (
int i = 0; i < def->arg_size(); ++i) {
547 if (def->arg(i).name() == name) {
548 return def->mutable_arg(i);
552 if (create_if_missing) {
553 Argument* arg = def->add_arg();
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...