Caffe2 - C++ API
A deep learning, cross platform ML framework
operator_schema.cc
1 
17 #include "caffe2/core/operator_schema.h"
18 
19 #include "caffe2/core/logging.h"
20 
21 namespace caffe2 {
22 
23 bool OpSchema::Verify(const OperatorDef& def) const {
24  // Check the number of inputs.
25  if (def.input_size() < min_input_ || def.input_size() > max_input_) {
26  LOG(ERROR) << "Input size " << def.input_size()
27  << " not in range [min=" << min_input_ << ", max="
28  << max_input_ << "].";
29  return false;
30  }
31  if (!num_inputs_allowed_(def.input_size())) {
32  LOG(ERROR) << "Input size " << def.input_size()
33  << " not in allowed input sizes.";
34  return false;
35  }
36  // Check the number of outputs.
37  if (def.output_size() < min_output_ || def.output_size() > max_output_) {
38  LOG(ERROR) << "Output size " << def.output_size()
39  << " not in range [min=" << min_output_ << ", max="
40  << max_output_ << "].";
41  return false;
42  }
43  if (!num_outputs_allowed_(def.output_size())) {
44  LOG(ERROR) << "Output size " << def.output_size()
45  << " not in allowed output sizes.";
46  return false;
47  }
48  if (!num_inputs_outputs_allowed_(def.input_size(), def.output_size())) {
49  LOG(ERROR) << "Combination of input size " << def.input_size()
50  << "and output size " << def.output_size() << " not in allowed.";
51  return false;
52  }
53  // If the number of outputs can be calculated, check if the number matches.
54  if (calculate_output_) {
55  int expected_nout = calculate_output_(def.input_size());
56  if (expected_nout != kCannotComputeNumOutputs &&
57  def.output_size() != expected_nout) {
58  LOG(ERROR) << "Output size " << def.output_size()
59  << " not matching expected output size, which is "
60  << expected_nout;
61  return false;
62  }
63  }
64 
65  // Check in-place settings.
66  for (int in_idx = 0; in_idx < def.input_size(); ++in_idx) {
67  for (int out_idx = 0; out_idx < def.output_size(); ++out_idx) {
68  // If an input is the same as an output but in-place is not opt-in
69  // either as allowed or enforced, we will fail the verification.
70  if (def.input(in_idx) == def.output(out_idx) &&
71  (!inplace_allowed_(in_idx, out_idx)
72  && !inplace_enforced_(in_idx, out_idx))) {
73  LOG(ERROR) << "Input index " << in_idx << " and output idx " << out_idx
74  << " (" << def.input(in_idx) << ")"
75  << " are set to be in-place but this is actually not "
76  << "supported by op " << def.type();
77  return false;
78  }
79  if (def.input(in_idx) != def.output(out_idx) &&
80  inplace_enforced_(in_idx, out_idx)) {
81  LOG(ERROR) << "Input index " << in_idx << " (" << def.input(in_idx) << ")"
82  << " and output idx " << out_idx
83  << " (" << def.output(in_idx) << ")"
84  << " are not in-place but should be as required by op "
85  << def.type();
86  return false;
87  }
88  }
89  }
90 
91  std::set<std::string> present_args{};
92  for (const auto& arg : def.arg()) {
93  present_args.insert(arg.name());
94  }
95 
96  for (const auto& arg : args()) {
97  if (arg.is_required() &&
98  present_args.find(arg.name()) == present_args.end()) {
99  LOG(ERROR) << "Argument '" << arg.name() << "' is required for Operator '"
100  << def.type() << "'.";
101  return false;
102  }
103  }
104 
105  // Phew. All verifications passed.
106  return true;
107 }
108 
109 OpSchema& OpSchema::NumInputs(int min, int max) {
110  min_input_ = min;
111  max_input_ = max;
112  return *this;
113 }
114 
116  return NumInputs(n, n);
117 }
118 
119 OpSchema& OpSchema::NumInputs(std::function<bool(int)> func) {
120  num_inputs_allowed_ = func;
121  return *this;
122 }
123 
124 OpSchema& OpSchema::NumInputs(set<int> allowed_input_nums) {
125  return NumInputs(
126  [allowed_input_nums](int n)->bool {
127  return allowed_input_nums.count(n);
128  });
129 }
130 
131 OpSchema& OpSchema::NumOutputs(int min, int max) {
132  min_output_ = min;
133  max_output_ = max;
134  return *this;
135 }
136 
138  return NumOutputs(n, n);
139 }
140 
141 OpSchema& OpSchema::NumOutputs(std::function<bool(int)> func) {
142  num_outputs_allowed_ = func;
143  return *this;
144 }
145 
146 OpSchema& OpSchema::NumOutputs(set<int> allowed_output_nums) {
147  return NumOutputs(
148  [allowed_output_nums](int n)->bool {
149  return allowed_output_nums.count(n);
150  });
151 }
152 
153 OpSchema& OpSchema::NumInputsOutputs(std::function<bool(int, int)> func) {
154  num_inputs_outputs_allowed_ = func;
155  return *this;
156 }
157 
158 OpSchema& OpSchema::OutputCalculator(std::function<int(int)> calc) {
159  calculate_output_ = calc;
160  return *this;
161 }
162 
164  return OutputCalculator([](int n)->int { return n; } );
165 }
166 
167 OpSchema& OpSchema::AllowInplace(std::function<bool(int, int)> inplace) {
168  inplace_allowed_ = inplace;
169  return *this;
170 }
171 
172 OpSchema& OpSchema::AllowInplace(set<std::pair<int, int>> inplace) {
173  return AllowInplace(
174  [inplace](int in, int out)->bool {
175  return inplace.count(std::make_pair(in, out));
176  });
177 }
178 
179 OpSchema& OpSchema::AllowOneToOneInplace() {
180  return AllowInplace([](int in, int out) { return in == out; });
181 }
182 
183 OpSchema& OpSchema::EnforceInplace(std::function<bool(int, int)> inplace) {
184  inplace_enforced_ = inplace;
185  return *this;
186 }
187 
188 OpSchema& OpSchema::EnforceInplace(set<std::pair<int, int>> inplace) {
189  return EnforceInplace(
190  [inplace](int in, int out)->bool {
191  return inplace.count(std::make_pair(in, out));
192  });
193 }
194 
195 OpSchema& OpSchema::EnforceOneToOneInplace() {
196  return EnforceInplace([](int in, int out) { return in == out; });
197 }
198 
199 OpSchema& OpSchema::Private() {
200  private_ = true;
201  return *this;
202 }
203 
204 OpSchema& OpSchema::InputsCanCrossDevices() {
205  inputs_can_cross_devices_ = true;
206  return *this;
207 }
208 
210  TensorInferenceFunctionType function) {
211  tensor_inference_function_ = function;
212  return *this;
213 }
214 
217  [](const OperatorDef&, const vector<TensorShape>& input_types) {
218  return vector<TensorShape>(input_types);
219  });
220 }
221 
222 OpSchema& OpSchema::IdenticalTypeAndShapeOfInput(int idx) {
224  [idx](const OperatorDef&, const vector<TensorShape>& input_types) {
225  vector<TensorShape> out(1);
226  out[0] = input_types[idx];
227  return out;
228  });
229 }
230 
231 OpSchema& OpSchema::IdenticalTypeAndShapeOfInputDim(int idx, int dim) {
233  [idx, dim](const OperatorDef&, const vector<TensorShape>& input_types) {
234  vector<TensorShape> out(1);
235  out[0].add_dims(input_types[idx].dims(dim));
236  out[0].set_data_type(input_types[idx].data_type());
237  return out;
238  });
239 }
240 
241 OpSchema& OpSchema::ScalarType(::caffe2::TensorProto_DataType dt) {
243  [dt](const OperatorDef&, const vector<TensorShape>& /*input_types*/) {
244  vector<TensorShape> out(1);
245  out[0].set_data_type(dt);
246  return out;
247  });
248 }
249 
251  cost_inference_function_ =
252  caffe2::make_unique<CostInferenceFunctionType>(function);
253  return *this;
254 }
255 
256 OpSchema& OpSchema::DeviceInferenceFunction(
257  DeviceInferenceFunctionType function) {
258  device_inference_function_ = function;
259  return *this;
260 }
261 
262 OpSchema& OpSchema::SetDoc(const string& doc) {
263  doc_ = doc;
264  return *this;
265 }
266 
267 OpSchema&
268 OpSchema::Arg(const char* name, const char* description, bool required) {
269  args_.push_back(Argument(name, description, required));
270  return *this;
271 }
272 
273 #define DEFINE_STANDARG_ARG(name, str) \
274  CAFFE2_API const char* OpSchema::Arg_##name = #str; \
275  CAFFE2_API OpSchema& OpSchema::Arg##name(const char* description) { \
276  return Arg(#str, description, true); \
277  }
278 
279 DEFINE_STANDARG_ARG(IsTest, is_test)
280 
281 #undef DEFINE_STANDARG_ARG
282 
283 OpSchema& OpSchema::Input(const int n, const char* name, const char* description) {
284  if (input_desc_.size() <= n) {
285  input_desc_.resize(n + 1);
286  }
287  input_desc_[n] = std::make_pair(name, description);
288  return *this;
289 }
290 
291 OpSchema& OpSchema::Output(const int n, const char* name, const char* description) {
292  if (output_desc_.size() <= n) {
293  output_desc_.resize(n + 1);
294  }
295  output_desc_[n] = std::make_pair(name, description);
296  return *this;
297 }
298 
299 OpSchema& OpSchema::FillUsing(std::function<void(OpSchema&)> populator) {
300  if (populator) {
301  populator(*this);
302  }
303  return *this;
304 }
305 
306 int OpSchema::CalculateOutput(int num_input) const {
307  if (min_output_ == max_output_) {
308  return min_output_;
309  } else if (calculate_output_) {
310  return calculate_output_(num_input);
311  } else {
312  return kCannotComputeNumOutputs;
313  }
314 }
315 
316 std::ostream& operator<<(std::ostream& out, const OpSchema& schema) {
317  if (!schema.args().empty()) {
318  out << "Arguments:" << std::endl;
319  for (const auto& arg : schema.args()) {
320  out << " " << arg.name() << " : " << arg.description() << std::endl;
321  }
322  }
323  if (schema.max_input_ > 0) {
324  out << "Inputs:" << std::endl;
325  if (!schema.input_desc_.empty()) {
326  for (int i = 0; i < schema.input_desc_.size(); ++i) {
327  const auto& p = schema.input_desc_[i];
328  out << " " << i << ", " << (p.first ? p.first : "(unnamed)") << " : "
329  << (p.second ? p.second : "(no doc)") << std::endl;
330  }
331  } else {
332  out << " (no explicit description available)" << std::endl;
333  }
334  }
335  if (schema.max_output_ > 0) {
336  out << "Outputs:" << std::endl;
337  if (!schema.output_desc_.empty()) {
338  for (int i = 0; i < schema.output_desc_.size(); ++i) {
339  const auto& p = schema.output_desc_[i];
340  out << " " << i << ", " << (p.first ? p.first : "(unnamed)") << " : "
341  << (p.second ? p.second : "(no doc)") << std::endl;
342  }
343  } else {
344  out << " (no explicit description available)" << std::endl;
345  }
346  }
347  out << std::endl;
348  if (schema.doc()) {
349  out << schema.doc();
350  } else {
351  out << "(no documentation yet)" << std::endl;
352  }
353  out << std::endl;
354  if (schema.line_) {
355  out << "Defined at " << schema.file_ << ":" << schema.line_ << std::endl;
356  }
357  return out;
358 }
359 
360 CaffeMap<string, OpSchema>& OpSchemaRegistry::map() {
361  static CaffeMap<string, OpSchema> map;
362  return map;
363 }
364 
365 } // namespace caffe2
std::function< std::pair< std::vector< DeviceOption >, std::vector< DeviceOption >>(const OperatorDef &def)> DeviceInferenceFunctionType
Returns the required device location of inputs and outputs.
OpSchema & NumInputs(int n)
A single input.
A class to record the schema of an op.
bool Verify(const OperatorDef &def) const
Verifies if an operator definition protobuf matches the pattern specified in the schema.
const char * doc() const
Returns the docstring of the op schema.
OpSchema & OutputCalculator(std::function< int(int)> calc)
Set the output calculator to a user-defined function.
OpSchema & IdenticalTypeAndShape()
Sets the tensor inference function to produce the same output as the input.
OpSchema & SameNumberOfOutput()
Set the number of outputs to be the same as the number of inputs.
Copyright (c) 2016-present, Facebook, Inc.
OpSchema & CostInferenceFunction(CostInferenceFunctionType function)
Register the Cost inference function.
OpSchema & NumInputsOutputs(std::function< bool(int, int)> func)
Relationship between inputs and outputs is checked with a specified function.
OpSchema & TensorInferenceFunction(TensorInferenceFunctionType function)
Sets the tensor inference function, which is a std::function object defined in operator_schema.h.
int CalculateOutput(int num_input) const
A function to allow one to get the number of outputs based on the number of inputs, if this schema supports it.
OpSchema & NumOutputs(int n)
A single output.
std::function< struct Cost(const OperatorDef &, const vector< TensorShape > &)> CostInferenceFunctionType
Registers a function that takes in an OperatorDef and a series of input shapes and returns the total ...