1 #ifndef CAFFE2_CORE_OPERATOR_SCHEMA_H_ 2 #define CAFFE2_CORE_OPERATOR_SCHEMA_H_ 6 #include <initializer_list> 10 #include <unordered_map> 12 #include "c10/util/Registry.h" 13 #include "caffe2/core/common.h" 14 #include "caffe2/core/logging.h" 15 #include "caffe2/proto/caffe2_pb.h" 16 #include "caffe2/utils/filler.h" 22 constexpr
int kCannotComputeNumOutputs = -1;
41 OpSchema() : type_(
"unknown"), file_(
"unknown"), line_(0) {}
42 OpSchema(
const string& type,
const string& file,
const int line)
43 : type_(type), file_(file), line_(line) {}
48 inline const string&
file()
const {
62 inline const char*
doc()
const {
63 return doc_.empty() ?
nullptr : doc_.c_str();
70 bool Verify(
const OperatorDef& def)
const;
82 OpSchema& NumInputs(
int min,
int max);
86 OpSchema& NumInputs(set<int> allowed_input_nums);
90 OpSchema& NumInputs(std::function<
bool(
int)> func);
102 OpSchema& NumOutputs(
int min,
int max);
106 OpSchema& NumOutputs(set<int> allowed_output_nums);
110 OpSchema& NumOutputs(std::function<
bool(
int)> func);
116 OpSchema& NumInputsOutputs(std::function<
bool(
int,
int)> func);
123 OpSchema& OutputCalculator(std::function<
int(
int)> calc);
130 OpSchema& AllowInplace(std::function<
bool(
int,
int)> inplace);
131 OpSchema& AllowInplace(
set<std::pair<int, int>> inplace);
134 OpSchema& EnforceInplace(std::function<
bool(
int,
int)> inplace);
135 OpSchema& EnforceInplace(
set<std::pair<int, int>> inplace);
142 typedef std::function<
143 vector<TensorShape>(
const OperatorDef&,
const vector<TensorShape>&)>
144 TensorInferenceFunctionType;
150 OpSchema& TensorInferenceFunction(TensorInferenceFunctionType
function);
157 static TensorInferenceFunctionType NeedsAllInputShapes(
158 TensorInferenceFunctionType f);
163 OpSchema& InheritOnnxSchema(
const std::string& onnx_schema_name);
169 return InheritOnnxSchema(type_);
177 OpSchema& IdenticalTypeAndShapeOfInput(
int idx);
178 OpSchema& IdenticalTypeAndShapeOfInputDim(
int idx,
int dim);
179 OpSchema& IdenticalTypeAndShapeOfMultipleInputs(
const vector<int>& indices);
180 OpSchema& ScalarType(::caffe2::TensorProto_DataType dt);
187 const OperatorDef& def,
188 const vector<TensorShape>& input_type_shape)
const {
189 return tensor_inference_function_(def, input_type_shape);
198 uint64_t bytes_read{0};
199 uint64_t bytes_written{0};
200 uint64_t params_bytes{0};
207 typedef std::function<
208 struct Cost(const OperatorDef&,
const vector<TensorShape>&)>
216 #if 0 // def _MSC_VER 220 template <
typename T,
221 typename = std::enable_if<
222 std::is_same<CostInferenceFunctionType&&, T>:value
224 inline OpSchema& CostInferenceFunction(
T func) {
232 bool HasCostInferenceFunction()
const {
233 return !!cost_inference_function_;
236 inline struct Cost InferCost(
237 const OperatorDef& def,
238 const vector<TensorShape>& input_tensor_shape)
const {
240 cost_inference_function_,
"Cost inference function not defined.");
241 return (*cost_inference_function_)(def, input_tensor_shape);
245 OpSchema& SetDoc(
const string& doc);
248 Argument(
const char* name,
const char* description,
bool required)
249 : name_{name}, description_{description}, required_{required} {}
251 const char* name()
const {
255 const char* description()
const {
259 bool is_required()
const {
265 const char* description_;
266 const bool required_;
270 Arg(
const char* name,
const char* description,
bool required =
false);
272 #define DECLARE_STANDARD_ARG(name, str) \ 273 static const char* Arg_##name; \ 274 OpSchema& Arg##name(const char* description); 276 DECLARE_STANDARD_ARG(IsTest, is_test)
278 #undef DECLARE_STANDARD_ARG 280 OpSchema& Input(
const int n,
const char* name,
const char* description);
281 OpSchema& Output(
const int n,
const char* name,
const char* description);
296 int CalculateOutput(
int num_input)
const;
298 const std::string& onnx_schema()
const {
302 int min_input()
const {
306 int max_input()
const {
310 int min_output()
const {
314 int max_output()
const {
318 bool num_inputs_allowed(
int x)
const {
319 return num_inputs_allowed_(x);
322 bool num_outputs_allowed(
int x)
const {
323 return num_outputs_allowed_(x);
326 bool num_inputs_outputs_allowed(
int x,
int y)
const {
327 return num_inputs_outputs_allowed_(x, y);
331 return std::numeric_limits<int>::max();
334 bool inplace_enforced(
int x,
int y)
const {
335 return inplace_enforced_(x, y);
338 CAFFE2_API
friend std::ostream& operator<<(std::ostream& out,
const OpSchema& schema);
340 const std::vector<Argument>& args()
const {
344 const std::vector<std::pair<const char*, const char*>>& input_desc()
const {
347 const std::vector<std::pair<const char*, const char*>>& output_desc()
const {
353 bool inputs_can_cross_devices()
const {
354 return inputs_can_cross_devices_;
361 std::pair<std::vector<DeviceOption>, std::vector<DeviceOption>>(
362 const OperatorDef& def)>;
369 inline std::pair<std::vector<DeviceOption>, std::vector<DeviceOption>>
371 return device_inference_function_(def);
381 OpSchema& WeightedValueKeyLengthInputFillers(
385 size_t weight_index);
393 OpSchema& ValueKeyLengthInputFillers(
396 size_t length_index);
402 OpSchema& ValueLengthInputFillers(
size_t value_index,
size_t length_index);
406 std::vector<TensorFiller> InputFillers(
407 const std::vector<std::vector<int64_t>>& shapes)
const;
410 std::vector<TensorFiller> SupplyDenseFillers(
411 const std::vector<std::vector<int64_t>>& shapes);
418 std::vector<Argument> args_{};
419 std::vector<std::pair<const char*, const char*>> input_desc_{};
420 std::vector<std::pair<const char*, const char*>> output_desc_{};
423 int max_input_ = std::numeric_limits<int>::max();
425 int max_output_ = std::numeric_limits<int>::max();
426 bool private_ =
false;
427 bool inputs_can_cross_devices_ =
false;
// Default input-count predicate: accepts any count until a NumInputs()
// overload installs a stricter check.
std::function<bool(int)> num_inputs_allowed_ = [](int) { return true; };
// Default output-count predicate: accepts any count until a NumOutputs()
// overload installs a stricter check.
std::function<bool(int)> num_outputs_allowed_ = [](int) { return true; };
430 std::function<bool(int, int)> num_inputs_outputs_allowed_ = [](int, int) {
433 std::function<int(int)> calculate_output_;
435 std::function<bool(int, int)> inplace_allowed_ = [](int, int) {
438 std::function<bool(int, int)> inplace_enforced_ = [](int, int) {
441 TensorInferenceFunctionType tensor_inference_function_ =
442 [](
const OperatorDef& def,
const vector<TensorShape>&) {
443 vector<TensorShape> out;
444 for (
int i = 0; i < def.output_size(); i++) {
446 ts.set_unknown_shape(
true);
451 std::unique_ptr<CostInferenceFunctionType> cost_inference_function_ =
nullptr;
453 [](
const OperatorDef& def) {
455 def.has_device_option() ? def.device_option() : DeviceOption();
456 vector<DeviceOption> in_dev(def.input_size(), op_device);
457 vector<DeviceOption> out_dev(def.output_size(), op_device);
458 return std::make_pair(in_dev, out_dev);
461 std::function<std::vector<TensorFiller>(
462 const std::vector<std::vector<int64_t>>&)>
464 [
this](
const std::vector<std::vector<int64_t>>& shapes) {
465 return SupplyDenseFillers(shapes);
475 NewSchema(
const string& key,
const string& file,
const int line) {
477 auto it = m.find(key);
479 const auto& schema = it->second;
480 std::ios_base::Init init;
481 std::cerr <<
"Trying to register schema with name " << key
482 <<
" from file " << file <<
" line " << line
483 <<
", but it is already registered from file " << schema.file()
484 <<
" line " << schema.line();
487 m.emplace(std::make_pair(key,
OpSchema(key, file, line)));
491 static const OpSchema* Schema(
const string& key) {
493 auto it = m.find(key);
515 static CaffeMap<string, OpSchema>& map();
519 template <
typename T_I =
int>
520 inline TensorShape CreateTensorShape(
522 ::caffe2::TensorProto_DataType dt) {
527 ts.set_data_type(dt);
532 inline vector<int64_t> GetDimsVector(
const TensorShape& shape) {
533 vector<int64_t> dims;
534 for (
auto d : shape.dims()) {
541 inline uint64_t nElemFromDim(
const TensorShape& X,
int dim = 0) {
542 CAFFE_ENFORCE_GE(dim, 0,
"Invalid maximum index specified");
545 for (
int i = dim; i < X.dims_size(); ++i) {
552 inline uint64_t nElemBetweenDim(
const TensorShape& X,
int start,
int stop) {
553 CAFFE_ENFORCE_GE(start, 0,
"Invalid maximum index specified");
554 CAFFE_ENFORCE_LE(stop, X.dims_size(),
"Invalid maximum index specified");
557 for (
int i = start; i < stop; ++i) {
564 inline std::pair<std::vector<DeviceOption>, std::vector<DeviceOption>>
565 InferOpInputOutputDevice(
const OperatorDef& op) {
566 auto op_schema = OpSchemaRegistry::Schema(op.type());
569 return op_schema->InferDevice(op);
574 return temp_schema.InferDevice(op);
578 template <u
int64_t OpsPerPo
int>
581 const vector<TensorShape>& inputs) {
583 const TensorShape X = inputs[0];
584 uint64_t nElemX = nElemFromDim(X);
585 uint64_t nElemRead = 0;
586 for (
size_t i = 0; i < inputs.size(); ++i) {
587 nElemRead += nElemFromDim(inputs[i]);
590 c.flops = nElemX * OpsPerPoint;
591 c.bytes_read = nElemRead *
sizeof(X.data_type());
592 c.bytes_written = nElemX *
sizeof(X.data_type());
598 #ifndef CAFFE2_NO_OPERATOR_SCHEMA 600 #define OPERATOR_SCHEMA(name) \ 601 C10_EXPORT void CAFFE2_PLEASE_ADD_OPERATOR_SCHEMA_FOR_##name(){}; \ 602 static OpSchema* C10_ANONYMOUS_VARIABLE(name) CAFFE2_UNUSED = \ 603 &OpSchemaRegistry::NewSchema(#name, __FILE__, __LINE__) 605 #else // CAFFE2_NO_OPERATOR_SCHEMA 607 #define OPERATOR_SCHEMA(name) \ 608 C10_EXPORT void CAFFE2_PLEASE_ADD_OPERATOR_SCHEMA_FOR_##name(){}; \ 609 static OpSchema* C10_ANONYMOUS_VARIABLE(name) CAFFE2_UNUSED = \ 610 1 ? nullptr : &OpSchemaRegistry::NewSchema(#name, __FILE__, __LINE__) 612 #endif // CAFFE2_NO_OPERATOR_SCHEMA 614 #ifdef CAFFE2_NO_GRADIENT_OPS 616 #define GRADIENT_OPERATOR_SCHEMA(name) \ 617 C10_EXPORT void CAFFE2_PLEASE_ADD_OPERATOR_SCHEMA_FOR_##name(){}; \ 618 static OpSchema* C10_ANONYMOUS_VARIABLE(name) CAFFE2_UNUSED = \ 619 1 ? nullptr : &OpSchemaRegistry::NewSchema(#name, __FILE__, __LINE__) 623 #define GRADIENT_OPERATOR_SCHEMA(name) OPERATOR_SCHEMA(name) 626 #endif // CAFFE2_CORE_OPERATOR_SCHEMA_H_ std::function< std::pair< std::vector< DeviceOption >, std::vector< DeviceOption >>(const OperatorDef &def)> DeviceInferenceFunctionType
Returns the required device location of inputs and outputs.
A class to record the schema of an op.
vector< TensorShape > InferTensor(const OperatorDef &def, const vector< TensorShape > &input_type_shape) const
A function to allow one to infer the type and shape from the op schema.
A registry to hold all the operator schemas.
int line() const
Returns the line in file that the op schema is registered from.
const char * doc() const
Returns the docstring of the op schema.
std::pair< std::vector< DeviceOption >, std::vector< DeviceOption > > InferDevice(const OperatorDef &def) const
Infer required device location of an op's inputs and outputs.
A global dictionary that holds information about what Caffe2 modules have been loaded in the current runtime.
OpSchema & InheritOnnxSchema()
Shortcut to InheritOnnxSchema(type_)
const string & file() const
Returns the file that the op schema is registered from.
std::function< struct Cost(const OperatorDef &, const vector< TensorShape > &)> CostInferenceFunctionType
Registers a function that takes in an OperatorDef and a series of input shapes and returns the total cost of running the operator.