1 #ifndef CAFFE2_UTILS_PROTO_UTILS_H_ 2 #define CAFFE2_UTILS_PROTO_UTILS_H_ 4 #ifdef CAFFE2_USE_LITE_PROTO 5 #include <google/protobuf/message_lite.h> 6 #else // CAFFE2_USE_LITE_PROTO 7 #include <google/protobuf/message.h> 8 #endif // !CAFFE2_USE_LITE_PROTO 10 #include "caffe2/core/logging.h" 11 #include "caffe2/utils/proto_wrap.h" 12 #include "caffe2/proto/caffe2_pb.h" 17 using ::google::protobuf::MessageLite;
26 CAFFE2_API std::string DeviceTypeName(
const int32_t& d);
28 CAFFE2_API
int DeviceId(
const DeviceOption& option);
31 CAFFE2_API
bool IsSameDevice(
const DeviceOption& lhs,
const DeviceOption& rhs);
33 CAFFE2_API
bool IsCPUDeviceType(
int device_type);
34 CAFFE2_API
bool IsGPUDeviceType(
int device_type);
37 CAFFE2_API
bool ReadStringFromFile(
const char* filename,
string* str);
38 CAFFE2_API
bool WriteStringToFile(
const string& str,
const char* filename);
41 CAFFE2_API
bool ReadProtoFromBinaryFile(
const char* filename, MessageLite* proto);
42 inline bool ReadProtoFromBinaryFile(
const string filename, MessageLite* proto) {
43 return ReadProtoFromBinaryFile(filename.c_str(), proto);
46 CAFFE2_API
void WriteProtoToBinaryFile(
const MessageLite& proto,
const char* filename);
47 inline void WriteProtoToBinaryFile(
const MessageLite& proto,
48 const string& filename) {
49 return WriteProtoToBinaryFile(proto, filename.c_str());
52 #ifdef CAFFE2_USE_LITE_PROTO 54 namespace TextFormat {
55 inline bool ParseFromString(
const string& spec, MessageLite* proto) {
56 LOG(FATAL) <<
"If you are running lite version, you should not be " 57 <<
"calling any text-format protobuffers.";
62 CAFFE2_API
string ProtoDebugString(
const MessageLite& proto);
64 CAFFE2_API
bool ParseProtoFromLargeString(
const string& str, MessageLite* proto);
69 inline bool ReadProtoFromTextFile(
72 LOG(FATAL) <<
"If you are running lite version, you should not be " 73 <<
"calling any text-format protobuffers.";
76 inline bool ReadProtoFromTextFile(
const string filename, MessageLite* proto) {
77 return ReadProtoFromTextFile(filename.c_str(), proto);
80 inline void WriteProtoToTextFile(
83 LOG(FATAL) <<
"If you are running lite version, you should not be " 84 <<
"calling any text-format protobuffers.";
86 inline void WriteProtoToTextFile(
const MessageLite& proto,
87 const string& filename) {
88 return WriteProtoToTextFile(proto, filename.c_str());
91 inline bool ReadProtoFromFile(
const char* filename, MessageLite* proto) {
92 return (ReadProtoFromBinaryFile(filename, proto) ||
93 ReadProtoFromTextFile(filename, proto));
96 inline bool ReadProtoFromFile(
const string& filename, MessageLite* proto) {
97 return ReadProtoFromFile(filename.c_str(), proto);
100 #else // CAFFE2_USE_LITE_PROTO 102 using ::google::protobuf::Message;
104 namespace TextFormat {
105 CAFFE2_API
bool ParseFromString(
const string& spec, Message* proto);
108 CAFFE2_API
string ProtoDebugString(
const Message& proto);
110 CAFFE2_API
bool ParseProtoFromLargeString(
const string& str, Message* proto);
112 CAFFE2_API
bool ReadProtoFromTextFile(
const char* filename, Message* proto);
113 inline bool ReadProtoFromTextFile(
const string filename, Message* proto) {
114 return ReadProtoFromTextFile(filename.c_str(), proto);
117 CAFFE2_API
void WriteProtoToTextFile(
const Message& proto,
const char* filename);
118 inline void WriteProtoToTextFile(
const Message& proto,
const string& filename) {
119 return WriteProtoToTextFile(proto, filename.c_str());
123 inline bool ReadProtoFromFile(
const char* filename, Message* proto) {
124 return (ReadProtoFromBinaryFile(filename, proto) ||
125 ReadProtoFromTextFile(filename, proto));
128 inline bool ReadProtoFromFile(
const string& filename, Message* proto) {
129 return ReadProtoFromFile(filename.c_str(), proto);
132 #endif // CAFFE2_USE_LITE_PROTO 135 class IterableInputs = std::initializer_list<string>,
136 class IterableOutputs = std::initializer_list<string>,
137 class IterableArgs = std::initializer_list<Argument>>
138 OperatorDef CreateOperatorDef(
141 const IterableInputs& inputs,
142 const IterableOutputs& outputs,
143 const IterableArgs& args,
144 const DeviceOption& device_option = DeviceOption(),
145 const string& engine =
"") {
149 for (
const string& in : inputs) {
152 for (
const string& out : outputs) {
155 for (
const Argument& arg : args) {
156 def.add_arg()->CopyFrom(arg);
158 if (device_option.has_device_type()) {
159 def.mutable_device_option()->CopyFrom(device_option);
162 def.set_engine(engine);
170 class IterableInputs = std::initializer_list<string>,
171 class IterableOutputs = std::initializer_list<string>>
172 inline OperatorDef CreateOperatorDef(
175 const IterableInputs& inputs,
176 const IterableOutputs& outputs,
177 const DeviceOption& device_option = DeviceOption(),
178 const string& engine =
"") {
179 return CreateOperatorDef(
184 std::vector<Argument>(),
189 CAFFE2_API
bool HasOutput(
const OperatorDef& op,
const std::string& output);
190 CAFFE2_API
bool HasInput(
const OperatorDef& op,
const std::string& input);
202 template <
typename Def>
203 static bool HasArgument(
const Def& def,
const string& name) {
207 template <
typename Def,
typename T>
208 static T GetSingleArgument(
211 const T& default_value) {
215 template <
typename Def,
typename T>
216 static bool HasSingleArgumentOfType(
const Def& def,
const string& name) {
220 template <
typename Def,
typename T>
221 static vector<T> GetRepeatedArgument(
224 const std::vector<T>& default_value = std::vector<T>()) {
225 return ArgumentHelper(def).GetRepeatedArgument<
T>(name, default_value);
228 template <
typename Def,
typename MessageType>
229 static MessageType GetMessageArgument(
const Def& def,
const string& name) {
233 template <
typename Def,
typename MessageType>
234 static vector<MessageType> GetRepeatedMessageArgument(
236 const string& name) {
237 return ArgumentHelper(def).GetRepeatedMessageArgument<MessageType>(name);
240 template <
typename Def>
241 static bool RemoveArgument(Def& def,
int index) {
242 if (index >= def.arg_size()) {
245 if (index < def.arg_size() - 1) {
246 def.mutable_arg()->SwapElements(index, def.arg_size() - 1);
248 def.mutable_arg()->RemoveLast();
254 bool HasArgument(
const string& name)
const;
256 template <
typename T>
257 T GetSingleArgument(
const string& name,
const T& default_value)
const;
258 template <
typename T>
259 bool HasSingleArgumentOfType(
const string& name)
const;
260 template <
typename T>
261 vector<T> GetRepeatedArgument(
263 const std::vector<T>& default_value = std::vector<T>())
const;
265 template <
typename MessageType>
266 MessageType GetMessageArgument(
const string& name)
const {
267 CAFFE_ENFORCE(arg_map_.count(name),
"Cannot find parameter named ", name);
269 if (arg_map_.at(name).has_s()) {
271 message.ParseFromString(arg_map_.at(name).s()),
272 "Faild to parse content from the string");
274 VLOG(1) <<
"Return empty message for parameter " << name;
279 template <
typename MessageType>
280 vector<MessageType> GetRepeatedMessageArgument(
const string& name)
const {
281 CAFFE_ENFORCE(arg_map_.count(name),
"Cannot find parameter named ", name);
282 vector<MessageType> messages(arg_map_.at(name).strings_size());
283 for (
int i = 0; i < messages.size(); ++i) {
285 messages[i].ParseFromString(arg_map_.at(name).strings(i)),
286 "Faild to parse content from the string");
292 CaffeMap<string, Argument> arg_map_;
299 CAFFE2_API
const Argument& GetArgument(
const OperatorDef& def,
const string& name);
300 CAFFE2_API
const Argument& GetArgument(
const NetDef& def,
const string& name);
305 CAFFE2_API
bool GetFlagArgument(
306 const OperatorDef& def,
308 bool default_value =
false);
309 CAFFE2_API
bool GetFlagArgument(
312 bool default_value =
false);
314 CAFFE2_API
Argument* GetMutableArgument(
316 const bool create_if_missing,
319 template <
typename T>
320 CAFFE2_API
Argument MakeArgument(
const string& name,
const T& value);
322 template <
typename T>
323 inline void AddArgument(
const string& name,
const T& value, OperatorDef* def) {
324 GetMutableArgument(name,
true, def)->CopyFrom(MakeArgument(name, value));
328 bool inline operator==(
const DeviceOption& dl,
const DeviceOption& dr) {
329 return IsSameDevice(dl, dr);
338 typedef caffe2::DeviceOption argument_type;
339 typedef std::size_t result_type;
340 result_type operator()(argument_type
const& device_option)
const {
341 std::string serialized;
342 CAFFE_ENFORCE(device_option.SerializeToString(&serialized));
343 return std::hash<std::string>{}(serialized);
348 #endif // CAFFE2_UTILS_PROTO_UTILS_H_
A helper class to index into arguments.
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...