// OpSchema::Verify (declared as `bool Verify(const OperatorDef &def) const`):
// checks an OperatorDef against this schema -- input/output counts, the
// output-count calculator, in-place rules and required arguments -- and logs
// an error via LOG(ERROR) for each violated constraint it reaches.
// NOTE(review): this listing is a lossy extraction; the signature, the
// per-check early returns and the closing braces are not present, so only
// visible tokens were edited here.
1 #include "caffe2/core/operator_schema.h" 2 #include "caffe2/core/logging.h" 8 if (def.input_size() < min_input_ || def.input_size() > max_input_) {
9 LOG(ERROR) <<
"Input size " << def.input_size()
10 <<
" not in range [min=" << min_input_ <<
", max=" 11 << max_input_ <<
"].";
// Beyond the [min, max] range, the input count must satisfy the registered
// predicate.
14 if (!num_inputs_allowed_(def.input_size())) {
15 LOG(ERROR) <<
"Input size " << def.input_size()
16 <<
" not in allowed input sizes.";
// Same two checks for the output count.
20 if (def.output_size() < min_output_ || def.output_size() > max_output_) {
21 LOG(ERROR) <<
"Output size " << def.output_size()
22 <<
" not in range [min=" << min_output_ <<
", max=" 23 << max_output_ <<
"].";
26 if (!num_outputs_allowed_(def.output_size())) {
27 LOG(ERROR) <<
"Output size " << def.output_size()
28 <<
" not in allowed output sizes.";
// Joint constraint on the (input count, output count) pair.
31 if (!num_inputs_outputs_allowed_(def.input_size(), def.output_size())) {
32 LOG(ERROR) <<
"Combination of input size " << def.input_size()
33 <<
// Leading space keeps the input size and this phrase from running together.
" and output size " << def.output_size() <<
" not in allowed.";
// When an output calculator is set and can answer for this input count,
// the actual output count must match it exactly.
37 if (calculate_output_) {
38 int expected_nout = calculate_output_(def.input_size());
39 if (expected_nout != kCannotComputeNumOutputs &&
40 def.output_size() != expected_nout) {
41 LOG(ERROR) <<
"Output size " << def.output_size()
42 <<
// NOTE(review): the rest of this message and the end of this check are not
// in the listing; the in-place checks below start on the same line.
" not matching expected output size, which is " 49 for (
int in_idx = 0; in_idx < def.input_size(); ++in_idx) {
50 for (
int out_idx = 0; out_idx < def.output_size(); ++out_idx) {
// An input and output sharing a blob name are accepted only when the schema
// allows or enforces in-place for that (in_idx, out_idx) pair.
53 if (def.input(in_idx) == def.output(out_idx) &&
54 (!inplace_allowed_(in_idx, out_idx)
55 && !inplace_enforced_(in_idx, out_idx))) {
56 LOG(ERROR) <<
"Input index " << in_idx <<
" and output idx " << out_idx
57 <<
" (" << def.input(in_idx) <<
")" 58 <<
" are set to be in-place but this is actually not " 59 <<
"supported by op " << def.type();
// Conversely, a pair the schema enforces as in-place must use the same name.
62 if (def.input(in_idx) != def.output(out_idx) &&
63 inplace_enforced_(in_idx, out_idx)) {
64 LOG(ERROR) <<
"Input index " << in_idx <<
" (" << def.input(in_idx) <<
")" 65 <<
" and output idx " << out_idx
66 <<
// Outputs must be indexed with out_idx: indexing with in_idx prints the wrong
// blob and reads past the outputs whenever in_idx >= def.output_size().
" (" << def.output(out_idx) <<
")" 67 <<
// Required-argument check: collect the argument names present in def.arg(),
// then require every argument the schema marks as required to be among them.
" are not in-place but should be as required by op " 74 std::set<std::string> present_args{};
75 for (
const auto& arg : def.arg()) {
76 present_args.insert(arg.name());
79 for (
const auto& arg : args()) {
80 if (arg.is_required() &&
81 present_args.find(arg.name()) == present_args.end()) {
82 LOG(ERROR) <<
"Argument '" << arg.name() <<
"' is required for Operator '" 83 << def.type() <<
"'.";
// Presumably the NumInputs(...) overloads; their signatures, `return *this;`
// and closing braces are not in this listing.
// Predicate overload: stores the caller's predicate in num_inputs_allowed_,
// which Verify() calls with the OperatorDef's input count.
103 num_inputs_allowed_ = func;
// Set overload: wraps the allowed counts in a predicate. The set is captured
// by value, so the predicate owns its own copy.
109 [allowed_input_nums](
int n)->
bool {
110 return allowed_input_nums.count(n);
// Presumably the NumOutputs(...) overloads; their signatures, `return *this;`
// and closing braces are not in this listing.
// Predicate overload: stores the caller's predicate in num_outputs_allowed_,
// which Verify() calls with the OperatorDef's output count.
125 num_outputs_allowed_ = func;
// Set overload: wraps the allowed counts in a predicate. The set is captured
// by value, so the predicate owns its own copy.
131 [allowed_output_nums](
int n)->
bool {
132 return allowed_output_nums.count(n);
// NumInputsOutputs: stores a predicate over (input count, output count);
// Verify() rejects an OperatorDef for which it returns false.
137 num_inputs_outputs_allowed_ = func;
// OutputCalculator: stores a function mapping an input count to the expected
// output count. Both Verify() and CalculateOutput() call it, and a result of
// kCannotComputeNumOutputs means "cannot tell".
142 calculate_output_ = calc;
// AllowInplace(predicate): stores a predicate over (input index, output index)
// pairs. Verify() accepts an input and output sharing a blob name only when
// this predicate, or the enforce-in-place one, returns true for that pair.
150 OpSchema& OpSchema::AllowInplace(std::function<
bool(
int,
int)> inplace) {
151 inplace_allowed_ = inplace;
// AllowInplace(set): a pair is allowed iff it is a member of `inplace`. The
// set is captured by value, so the predicate owns its own copy.
// NOTE(review): the line that passes this lambda to the predicate overload is
// not present in this listing (compare EnforceInplace(set)).
155 OpSchema& OpSchema::AllowInplace(
set<std::pair<int, int>> inplace) {
157 [inplace](
int in,
int out)->
bool {
158 return inplace.count(std::make_pair(in, out));
// AllowOneToOneInplace: lets input i share its blob with output i, for every i.
162 OpSchema& OpSchema::AllowOneToOneInplace() {
163 return AllowInplace([](
int in,
int out) {
return in == out; });
// EnforceInplace(predicate): stores a predicate marking (input index, output
// index) pairs that must share a blob name. Verify() logs an error when such a
// pair uses different names, and accepts the pair when the names match.
166 OpSchema& OpSchema::EnforceInplace(std::function<
bool(
int,
int)> inplace) {
167 inplace_enforced_ = inplace;
// EnforceInplace(set): forwards a membership test over `inplace` to the
// predicate overload. The set is captured by value.
171 OpSchema& OpSchema::EnforceInplace(
set<std::pair<int, int>> inplace) {
172 return EnforceInplace(
173 [inplace](
int in,
int out)->
bool {
174 return inplace.count(std::make_pair(in, out));
// EnforceOneToOneInplace: requires input i and output i to share a blob, for
// every i.
178 OpSchema& OpSchema::EnforceOneToOneInplace() {
179 return EnforceInplace([](
int in,
int out) {
return in == out; });
// InputsCanCrossDevices: sets inputs_can_cross_devices_ to true. As the name
// suggests, this is presumably read elsewhere to permit inputs on different
// devices -- its consumer is not in this listing.
187 OpSchema& OpSchema::InputsCanCrossDevices() {
188 inputs_can_cross_devices_ =
true;
// TensorInferenceFunction: stores the caller-supplied shape/type inference
// function on the schema (the signature's first line is not in this listing).
193 TensorInferenceFunctionType
function) {
194 tensor_inference_function_ =
function;
// NeedsAllInputShapes(f): wraps `f` (captured by value). If any input shape
// is unknown, the wrapper builds one TensorShape per output of `def`, each
// marked unknown. Presumably it returns those instead of calling `f`; the
// rest of the lambda is not in this listing.
199 TensorInferenceFunctionType f) {
200 return [f](
const OperatorDef& def,
const vector<TensorShape>& in) {
201 for (
const auto& in_ts : in) {
202 if (in_ts.unknown_shape()) {
203 vector<TensorShape> out(def.output().size());
204 for (
auto& out_ts : out) {
205 out_ts.set_unknown_shape(
true);
// InheritOnnxSchema: records the name of the ONNX schema this op inherits.
215 onnx_schema_ = onnx_schema_name;
// IdenticalTypeAndShape: inference function whose outputs are a copy of the
// input shapes, so output i mirrors input i.
221 [](
const OperatorDef&,
const vector<TensorShape>& input_types) {
222 return vector<TensorShape>(input_types);
// IdenticalTypeAndShapeOfInput(idx): inference function producing exactly one
// output whose shape and type are copied from input `idx`.
// NOTE(review): input_types[idx] is not bounds-checked against the number of
// inputs.
226 OpSchema& OpSchema::IdenticalTypeAndShapeOfInput(
int idx) {
228 [idx](
const OperatorDef&,
const vector<TensorShape>& input_types) {
229 vector<TensorShape> out(1);
230 out[0] = input_types[idx];
// IdenticalTypeAndShapeOfMultipleInputs(indices): inference function whose
// output i copies the shape and type of input indices[i]. `indices` is
// captured by value, so the caller's vector need not outlive the schema.
// NOTE(review): indices.at(i) is bounds-checked, but input_types[...] is not
// checked against the number of inputs.
235 OpSchema& OpSchema::IdenticalTypeAndShapeOfMultipleInputs(
236 const vector<int>& indices) {
238 [indices](
const OperatorDef&,
const vector<TensorShape>& input_types) {
239 vector<TensorShape> out(indices.size());
// size_t matches indices.size(), avoiding a signed/unsigned comparison.
240 for (
size_t i = 0; i < indices.size(); ++i) {
241 out[i] = input_types[indices.at(i)];
// IdenticalTypeAndShapeOfInputDim(idx, dim): inference function producing one
// output whose only dimension equals dimension `dim` of input `idx`, with that
// input's data type.
// NOTE(review): neither idx nor dim is range-checked here.
247 OpSchema& OpSchema::IdenticalTypeAndShapeOfInputDim(
int idx,
int dim) {
249 [idx, dim](
const OperatorDef&,
const vector<TensorShape>& input_types) {
250 vector<TensorShape> out(1);
251 out[0].add_dims(input_types[idx].dims(dim));
252 out[0].set_data_type(input_types[idx].data_type());
// ScalarType(dt): inference function giving every output of `def` a copy of
// `shape` whose data type is `dt`; input shapes are ignored. The declaration
// of `shape` is not in this listing.
257 OpSchema& OpSchema::ScalarType(::caffe2::TensorProto_DataType dt) {
259 [dt](
const OperatorDef& def,
const vector<TensorShape>& ) {
261 shape.set_data_type(dt);
262 vector<TensorShape> out(def.output_size(), shape);
// CostInferenceFunction: stores the cost function in a uniquely owned heap
// copy (caffe2::make_unique), replacing any previous one.
268 cost_inference_function_ =
269 caffe2::make_unique<CostInferenceFunctionType>(
function);
// DeviceInferenceFunction: stores the function that reports the required
// device placement of inputs and outputs (the parameter line is not in this
// listing).
273 OpSchema& OpSchema::DeviceInferenceFunction(
275 device_inference_function_ =
function;
// Arg(name, description, required): appends an argument declaration to args_,
// preserving declaration order. Verify() reports an OperatorDef that omits an
// argument declared with required == true.
285 OpSchema::Arg(
const char* name,
const char* description,
bool required) {
286 args_.push_back(
Argument(name, description, required));
// DEFINE_STANDARG_ARG(name, str) defines two things for a standard argument:
// the string constant OpSchema::Arg_<name> (equal to "str"), and a fluent
// setter OpSchema::Arg<name>(description) that registers it via Arg() with
// required == true. It is instantiated here for IsTest / "is_test".
290 #define DEFINE_STANDARG_ARG(name, str) \ 291 CAFFE2_API const char* OpSchema::Arg_##name = #str; \ 292 CAFFE2_API OpSchema& OpSchema::Arg##name(const char* description) { \ 293 return Arg(#str, description, true); \ 296 DEFINE_STANDARG_ARG(IsTest, is_test)
// (The #undef of DEFINE_STANDARG_ARG shares this extracted line.)
// Input(n, name, description): records the name/description of input slot n,
// growing input_desc_ so that index n exists.
// NOTE(review): `(unsigned)n` does not reject a negative n -- consider
// validating n >= 0 before indexing input_desc_.
298 #undef DEFINE_STANDARG_ARG 300 OpSchema& OpSchema::Input(
const int n,
const char* name,
const char* description) {
301 if (input_desc_.size() <= (unsigned)n) {
302 input_desc_.resize(n + 1);
304 input_desc_[n] = std::make_pair(name, description);
// Output(n, name, description): records the name/description of output slot
// n, growing output_desc_ so that index n exists.
// NOTE(review): `(unsigned)n` does not reject a negative n -- consider
// validating n >= 0 before indexing output_desc_.
308 OpSchema& OpSchema::Output(
const int n,
const char* name,
const char* description) {
309 if (output_desc_.size() <= (unsigned)n) {
310 output_desc_.resize(n + 1);
312 output_desc_[n] = std::make_pair(name, description);
// CalculateOutput(num_input): number of outputs for a given input count.
// - A fixed output count (min_output_ == max_output_) is handled first; that
//   branch's return is not in this listing.
// - Otherwise the output calculator is used if one was set.
// - Failing both, returns kCannotComputeNumOutputs.
324 if (min_output_ == max_output_) {
326 }
else if (calculate_output_) {
327 return calculate_output_(num_input);
329 return kCannotComputeNumOutputs;
// SparseLengthsFillerHelper: requires the lengths tensor at length_index to be
// 1-D, then configures its filler with SparseLengths(first dimension of the
// tensor at value_index). The value_index/length_index parameter lines are
// not in this listing.
334 void SparseLengthsFillerHelper(
335 const std::vector<std::vector<int64_t>>& shapes,
338 std::vector<TensorFiller>* fillers) {
339 CAFFE_ENFORCE_EQ(shapes[length_index].size(), 1);
341 (*fillers)[length_index].SparseLengths(shapes[value_index].front());
// SparseWeightsFillerHelper: configures the filler at weight_index with a Max
// equal to the first dimension of the weight tensor. Other chained settings
// and the weight_index parameter line are not in this listing.
344 void SparseWeightsFillerHelper(
345 const std::vector<std::vector<int64_t>>& shapes,
347 std::vector<TensorFiller>* fillers) {
348 (*fillers)[weight_index]
350 .Max(shapes[weight_index].front())
// SparseSegmentsFillerHelper: requires the segment tensor at segment_index to
// be 1-D. It then:
// - caps the value filler's Max at twice the first dimension of the value
//   tensor;
// - configures the segment filler with SparseSegments(that dimension - 1).
// Some chained settings and the value_index parameter line are not in this
// listing.
354 void SparseSegmentsFillerHelper(
355 const std::vector<std::vector<int64_t>>& shapes,
357 size_t segment_index,
358 std::vector<TensorFiller>* fillers) {
359 CAFFE_ENFORCE_EQ(shapes[segment_index].size(), 1);
361 (*fillers)[value_index]
363 .Max(shapes[value_index].front() * 2)
365 (*fillers)[segment_index].SparseSegments(shapes[value_index].front() - 1);
// ValueKeyLengthInputFillers: installs a filler supplier that starts from
// dense fillers for every input, then:
// - makes the lengths consistent with the key tensor;
// - makes the key filler produce segment ids into the value tensor.
// NOTE(review): the supplier captures `this`, and OpSchemaRegistry stores
// schemas by value. If this schema is copied or moved afterwards, the copy's
// supplier still points at the original object -- confirm that never happens.
374 OpSchema& OpSchema::ValueKeyLengthInputFillers(
377 size_t length_index) {
378 filler_supplier_ = [
this, value_index, key_index, length_index](
379 const std::vector<std::vector<int64_t>>& shapes) {
380 auto fillers = SupplyDenseFillers(shapes);
382 SparseLengthsFillerHelper(shapes, key_index, length_index, &fillers);
384 SparseSegmentsFillerHelper(shapes, value_index, key_index, &fillers);
// WeightedValueKeyLengthInputFillers: like ValueKeyLengthInputFillers, and
// additionally bounds the weight filler at weight_index.
// NOTE(review): the supplier captures `this`; confirm the schema is not
// copied or moved after this call, otherwise the copy's supplier dangles.
396 OpSchema& OpSchema::WeightedValueKeyLengthInputFillers(
400 size_t weight_index) {
401 filler_supplier_ = [
this, value_index, key_index, length_index, weight_index](
402 const std::vector<std::vector<int64_t>>& shapes) {
403 auto fillers = SupplyDenseFillers(shapes);
405 SparseLengthsFillerHelper(shapes, key_index, length_index, &fillers);
407 SparseSegmentsFillerHelper(shapes, value_index, key_index, &fillers);
409 SparseWeightsFillerHelper(shapes, weight_index, &fillers);
// ValueLengthInputFillers: installs a filler supplier that starts from dense
// fillers and makes the lengths at length_index consistent with the value
// tensor at value_index.
// NOTE(review): the supplier captures `this`; confirm the schema is not
// copied or moved after this call, otherwise the copy's supplier dangles.
419 OpSchema& OpSchema::ValueLengthInputFillers(
421 size_t length_index) {
422 filler_supplier_ = [
this, value_index, length_index](
423 const std::vector<std::vector<int64_t>>& shapes) {
424 auto fillers = SupplyDenseFillers(shapes);
426 SparseLengthsFillerHelper(shapes, value_index, length_index, &fillers);
// DisallowInputFillers: installs a supplier that always throws
// std::invalid_argument naming this op's type_.
// The trailing return is unreachable, but it makes the lambda's deduced
// return type std::vector<TensorFiller> rather than void.
// NOTE(review): `this` is captured to read type_ when the supplier runs.
434 [
this](
const std::vector<std::vector<int64_t>>& ) {
435 throw std::invalid_argument(type_ +
" does not have input fillers");
436 return std::vector<TensorFiller>();
// InputFillers(shapes): delegates to the installed filler supplier, returning
// whatever fillers it builds for the given input shapes.
441 std::vector<TensorFiller> OpSchema::InputFillers(
442 const std::vector<std::vector<int64_t>>& shapes)
const {
443 return filler_supplier_(shapes);
// SupplyDenseFillers(shapes): builds one TensorFiller per input shape, in
// input order, constructing each filler directly from its shape.
446 std::vector<TensorFiller> OpSchema::SupplyDenseFillers(
447 const std::vector<std::vector<int64_t>>& shapes) {
448 std::vector<TensorFiller> fillers;
449 for (
const auto& shape : shapes) {
450 fillers.emplace_back(shape);
// Human-readable dump of a schema. Sections are printed in this order:
// - declared arguments;
// - inputs, then outputs, each with index, name and doc;
// - a documentation section;
// - the file:line where the schema was defined.
// Several lines (else branches, the doc check) are not in this listing.
455 C10_EXPORT std::ostream& operator<<(std::ostream& out,
const OpSchema& schema) {
456 if (!schema.args().empty()) {
457 out <<
"Arguments:" << std::endl;
458 for (
const auto& arg : schema.args()) {
459 out <<
" " << arg.name() <<
" : " << arg.description() << std::endl;
// The inputs section is printed only for ops that accept at least one input.
// Null name/doc pointers print as placeholders.
462 if (schema.max_input_ > 0) {
463 out <<
"Inputs:" << std::endl;
464 if (!schema.input_desc_.empty()) {
465 for (
size_t i = 0; i < schema.input_desc_.size(); ++i) {
466 const auto& p = schema.input_desc_[i];
467 out <<
" " << i <<
", " << (p.first ? p.first :
"(unnamed)") <<
" : " 468 << (p.second ? p.second :
"(no doc)") << std::endl;
// Presumably the branch taken when no input descriptions were recorded.
471 out <<
" (no explicit description available)" << std::endl;
// Same layout for outputs.
474 if (schema.max_output_ > 0) {
475 out <<
"Outputs:" << std::endl;
476 if (!schema.output_desc_.empty()) {
477 for (
size_t i = 0; i < schema.output_desc_.size(); ++i) {
478 const auto& p = schema.output_desc_[i];
479 out <<
" " << i <<
", " << (p.first ? p.first :
"(unnamed)") <<
" : " 480 << (p.second ? p.second :
"(no doc)") << std::endl;
483 out <<
" (no explicit description available)" << std::endl;
// Presumably printed when the schema has no doc string.
490 out <<
"(no documentation yet)" << std::endl;
494 out <<
"Defined at " << schema.file_ <<
":" << schema.line_ << std::endl;
// OpSchemaRegistry::map(): the process-wide schema table, keyed by string
// (presumably the operator type name). It is held in a function-local static,
// so it is constructed on first use, and C++11 makes that initialization
// thread-safe. Likely chosen to avoid static-initialization-order problems
// for schemas registered from other translation units -- confirm. Schemas are
// stored by value. The return statement is not present in this listing.
499 CaffeMap<string, OpSchema>& OpSchemaRegistry::map() {
500 static CaffeMap<string, OpSchema> map;
std::function< std::pair< std::vector< DeviceOption >, std::vector< DeviceOption >>(const OperatorDef &def)> DeviceInferenceFunctionType
Returns the required device location of inputs and outputs.
OpSchema & NumInputs(int n)
A single input.
A class to record the schema of an op.
bool Verify(const OperatorDef &def) const
Verifies if an operator definition protobuf matches the pattern specified in the schema.
const char * doc() const
Returns the docstring of the op schema.
OpSchema & OutputCalculator(std::function< int(int)> calc)
Set the output calculator to a user-defined function.
OpSchema & IdenticalTypeAndShape()
Sets the tensor inference function to produce the same output as the input.
OpSchema & SameNumberOfOutput()
Set the number of outputs to be the same as the number of inputs.
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
OpSchema & CostInferenceFunction(CostInferenceFunctionType function)
Register the Cost inference function.
OpSchema & InheritOnnxSchema()
Shortcut to InheritOnnxSchema(type_)
OpSchema & NumInputsOutputs(std::function< bool(int, int)> func)
Relationship between inputs and outputs is checked with a specified function.
static TensorInferenceFunctionType NeedsAllInputShapes(TensorInferenceFunctionType f)
A wrapper that makes a tensor inference function return unknown shapes for all outputs if any one of the inputs has an unknown shape.
OpSchema & TensorInferenceFunction(TensorInferenceFunctionType function)
Sets the tensor inference function, which is a std::function object defined in operator_schema.h.
int CalculateOutput(int num_input) const
A function to allow one to get the number of outputs based on the number of inputs, if this schema supports it.
OpSchema & NumOutputs(int n)
A single output.
std::function< struct Cost(const OperatorDef &, const vector< TensorShape > &)> CostInferenceFunctionType
Registers a function that takes in an OperatorDef and a series of input shapes and returns the total cost as a `Cost` struct.