Caffe2 - C++ API
A deep learning, cross-platform ML framework
schema_matching.cpp
#include <torch/csrc/jit/script/schema_matching.h>
#include <torch/csrc/jit/operator.h>
#include <torch/csrc/jit/script/builtin_functions.h>
#include <torch/csrc/jit/script/error_report.h>

namespace torch {
namespace jit {
namespace script {

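// If the given type is Optional[T], return the contained element type T;
// otherwise return the type unchanged.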
inline TypePtr unwrapOptional(TypePtr opt_type) {
  if (auto unwrap_list_type = opt_type->cast<OptionalType>()) {
    return unwrap_list_type->getElementType();
  }
  return opt_type;
}

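// True when a bare int or float is passed where a fixed-size list such as
// int[N] or float[N] of the same element type is expected; the caller then
// repeats the scalar to the declared length.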
static inline bool isIntOrFloatUsedAsList(
    const Value* value,
    const Argument& arg) {
  // Look for int[N] or float[N]
  const auto& v_type = value->type();
  if (v_type != FloatType::get() && v_type != IntType::get())
    return false;
  auto arg_type = unwrapOptional(arg.type());
  auto list_type = arg_type->cast<ListType>();
  return list_type && list_type->getElementType() == v_type && arg.N();
}

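// A type is convertible to list_type_ if it is already a subtype of it, or if
// it is a tuple whose elements are all subtypes of the list's element type.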
inline bool convertibleToList(const TypePtr& type, const TypePtr& list_type_) {
  auto list_type = list_type_->cast<ListType>();
  if (!list_type) {
    return false;
  }
  if (type->isSubtypeOf(list_type_)) {
    return true;
  }
  if (auto tuple = type->cast<TupleType>()) {
    return std::all_of(
        tuple->elements().begin(),
        tuple->elements().end(),
        [&](const TypePtr& t) {
          return t->isSubtypeOf(list_type->getElementType());
        });
  }
  return false;
}

// Applies implicit conversions to `value`, trying to turn it into type
// `concrete_type`. It succeeds if the returned value->isSubtypeOf(concrete_type).
Value* tryConvertToType(
    const SourceRange& loc,
    Graph& graph,
    const TypePtr& concrete_type,
    Value* value,
    bool allow_conversions) {
  if (auto value_tuple = value->type()->cast<TupleType>()) {
    // Allow homogeneous tuples to be cast implicitly to lists of the
    // appropriate type
    if (convertibleToList(value->type(), unwrapOptional(concrete_type))) {
      auto unpacked = createTupleUnpack(value);
      auto elem_type =
          unwrapOptional(concrete_type)->expect<ListType>()->getElementType();
      value = graph.insertNode(graph.createList(elem_type, unpacked))->output();
    }
    // inductively apply implicit conversions to tuples
    if (auto concrete_tuple = concrete_type->cast<TupleType>()) {
      if (!value_tuple->isSubtypeOf(concrete_tuple) &&
          concrete_tuple->elements().size() == value_tuple->elements().size()) {
        auto unpacked = createTupleUnpack(value);
        std::vector<Value*> converted;
        for (size_t i = 0; i < concrete_tuple->elements().size(); ++i) {
          converted.emplace_back(tryConvertToType(
              loc,
              graph,
              concrete_tuple->elements().at(i),
              unpacked.at(i),
              allow_conversions));
        }
        value = graph.insertNode(graph.createTuple(converted))->output();
      }
    }
  }

  if (value->type()->isSubtypeOf(NoneType::get()) &&
      !concrete_type->isSubtypeOf(NoneType::get())) {
    if (auto optional_type = concrete_type->cast<OptionalType>()) {
      value =
          graph.insertNode(graph.createNone(optional_type->getElementType()))
              ->output();
    } else {
      // When trying to convert None to a non-optional concrete type, create a
      // None node with the return value type of Optional[concrete_type]
      value = graph.insertNode(graph.createNone(concrete_type))->output();
    }
  }

  // implicit conversions
  if (allow_conversions) {
    if (concrete_type->isSubtypeOf(NumberType::get()) &&
        value->type()->isSubtypeOf(TensorType::get())) {
      auto n = graph.createImplicitTensorToNum(concrete_type, value);
      value = graph.insertNode(n)
                  ->setSourceLocation(std::make_shared<SourceRange>(loc))
                  ->output();
    }
    if (value->type()->isSubtypeOf(StringType::get()) &&
        DeviceObjType::get()->isSubtypeOf(concrete_type)) {
      return graph.insert(aten::device, {value}, {}, loc);
    }
    if (concrete_type == FloatType::get() &&
        value->type() == NumberType::get()) {
      return graph.insert(prim::Float, {value}, {}, loc);
    }
  }

  return value;
}

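// Try to produce a value matching `arg` from `named_value`: expand a bare
// int/float into a broadcasting list, resolve type variables against type_env,
// then apply implicit conversions. Returns nullptr and reports through `err`
// on failure.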
Value* tryMatchArgument(
    const Argument& arg,
    Graph& graph,
    const SourceRange& loc,
    const NamedValue& named_value,
    const std::function<std::ostream&()>& err,
    bool allow_conversions,
    TypeEnv& type_env) {
  Value* value = named_value.value(graph);

  // Some functions that take lists of integers or floats for fixed-size arrays
  // also allow single ints/floats to be passed in their place.
  // The single int/float is then repeated to the length of the list.
  if (isIntOrFloatUsedAsList(value, arg)) {
    std::vector<Value*> repeated(*arg.N(), value);
    value =
        graph.insertNode(graph.createList(value->type(), repeated))->output();
  }

  const MatchTypeReturn matched_type =
      matchTypeVariables(arg.type(), value->type(), type_env);
  if (!matched_type.type) {
    err() << "could not match type " << value->type()->str() << " to "
          << arg.type()->str() << " in argument '" << arg.name()
          << "': " << matched_type.errMsg << "\n"
          << named_value.locOr(loc);
    return nullptr;
  }
  const auto concrete_type = *matched_type.type;

  value = tryConvertToType(loc, graph, concrete_type, value, allow_conversions);

  if (!value->type()->isSubtypeOf(concrete_type)) {
    auto& ostream = err() << "expected a value of type " << concrete_type->str()
                          << " for argument '" << arg.name() << "' but found "
                          << value->type()->str() << "\n";

    if (value->type() == NumberType::get() &&
        value->node()->kind() == aten::item) {
      ostream
          << "Use int(tensor) or float(tensor) to retrieve item() from a tensor with the appropriate type\n";
    }
    ostream << named_value.locOr(loc);
    return nullptr;
  }
  return value;
}

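// Return the index of the keyword argument with the given name, if any.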
c10::optional<size_t> findInputWithName(
    const std::string& name,
    at::ArrayRef<NamedValue> kwargs) {
  for (size_t i = 0; i < kwargs.size(); ++i) {
    if (kwargs[i].name() == name)
      return i;
  }
  return c10::nullopt;
}

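// Match every vararg against elem_type and pack the matched values into a
// list; returns nullptr if any element fails to match.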
Value* tryCreateList(
    const TypePtr& elem_type,
    Graph& graph,
    const SourceRange& loc,
    at::ArrayRef<NamedValue> varargs,
    const std::function<std::ostream&()>& err,
    bool convert_tensor_to_num,
    TypeEnv& type_env) {
  Argument elem_arg("<varargs>", elem_type);
  std::vector<Value*> list_ctor;
  for (const auto& a : varargs) {
    Value* av = tryMatchArgument(
        elem_arg, graph, loc, a, err, convert_tensor_to_num, type_env);
    if (!av)
      return nullptr;
    list_ctor.push_back(av);
  }
  return graph.insertNode(graph.createList(elem_type, list_ctor))->output();
}

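// Try to match self, args, and kwargs against `schema`, emitting any
// conversion nodes that are required. On success returns the flattened
// positional inputs and return types; on failure appends a message to
// failure_messages and returns nullopt.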
c10::optional<MatchedSchema> tryMatchSchema(
    const FunctionSchema& schema,
    const SourceRange& loc,
    Graph& graph,
    c10::optional<NamedValue> self,
    at::ArrayRef<NamedValue> args,
    at::ArrayRef<NamedValue> kwargs,
    std::ostream& failure_messages,
    bool allow_conversions) {
  auto err = [&]() -> std::ostream& {
    failure_messages << "\nfor operator " << schema << ":\n";
    return failure_messages;
  };

  TypeEnv type_env;
  std::vector<Value*> positional_inputs;
  std::vector<bool> used_kwarg(kwargs.size(), false);

  // if we finish the loop, will we have consumed all arguments?
  size_t used_args = 0;
  for (size_t schema_i = 0; schema_i < schema.arguments().size(); ++schema_i) {
    const auto& arg = schema.arguments()[schema_i];
    c10::optional<NamedValue> v;
    if (arg.name() == "self" && self) {
      v = self;
      self = c10::nullopt;
    } else if (!arg.kwarg_only() && used_args < args.size()) {
      // allow zeros(IntArrayRef sizes) to work with zeros(1, 2) or zeros(1)
      if (allow_conversions &&
          arg.type()->kind() ==
              TypeKind::ListType && // the formal must be a list
          !arg.N() && // it must not be a broadcasting list like int[3],
                      // otherwise a single int is a valid input
          (schema_i + 1 == schema.arguments().size() ||
           schema.arguments()[schema_i + 1]
               .kwarg_only())) { // must be the last positional argument
        auto actual_type = args[used_args].value(graph)->type();
        if (actual_type->kind() != TypeKind::ListType &&
            !convertibleToList(
                actual_type,
                unwrapOptional(arg.type()))) { // and the actual should not be
                                               // a list already
          auto elem_type =
              unwrapOptional(arg.type())->expect<ListType>()->getElementType();
          Value* list = tryCreateList(
              elem_type,
              graph,
              loc,
              at::ArrayRef<NamedValue>(args).slice(used_args),
              err,
              allow_conversions,
              type_env);
          if (!list)
            return c10::nullopt;
          used_args = args.size();
          positional_inputs.push_back(list);
          continue;
        }
      }

      v = args[used_args];
      used_args++;
    } else if (auto idx = findInputWithName(arg.name(), kwargs)) {
      const NamedValue& nv = kwargs[*idx];
      if (used_kwarg[*idx]) {
        err() << "argument " << nv.name()
              << " specified twice in schema, submit a bug report!\n"
              << nv.locOr(loc);
        return c10::nullopt;
      }
      used_kwarg[*idx] = true;
      v = nv;
    } else if (arg.default_value()) {
      v = NamedValue(*arg.default_value());
    } else {
      err() << "argument " << schema.arguments()[schema_i].name()
            << " not provided.\n"
            << loc;
      return c10::nullopt;
    }
    Value* positional =
        tryMatchArgument(arg, graph, loc, *v, err, allow_conversions, type_env);
    if (!positional)
      return c10::nullopt;
    positional_inputs.push_back(positional);
  }
  // check for unused self argument
  if (self != c10::nullopt) {
    err() << "provided self argument not used in schema\n";
  }

  if (schema.is_vararg()) {
    for (; used_args < args.size(); ++used_args) {
      positional_inputs.push_back(args[used_args].value(graph));
    }
  }

  // check for unused positional arguments
  if (used_args < args.size()) {
    err() << "expected at most " << used_args << " arguments "
          << "but found " << args.size() << " positional arguments.\n"
          << loc << "\n";
    return c10::nullopt;
  }
  // check for unused kwargs
  for (size_t i = 0; i < kwargs.size(); ++i) {
    const auto& nv = kwargs[i];
    if (!used_kwarg[i]) {
      if (!schema.argumentIndexWithName(nv.name())) {
        err() << "keyword argument " << nv.name() << " unknown\n";
      } else {
        err() << "keyword argument " << nv.name() << " specified twice\n";
      }
      return c10::nullopt;
    }
  }

  const auto& returns = schema.returns();
  auto return_types = fmap(returns, [&](const Argument& r) {
    return evalTypeVariables(r.type(), type_env);
  });
  // Codegen does not support returning namedtuples with undefined field names.
  // Therefore, either all of the returns have field names or none do.
  bool return_has_field_names =
      std::all_of(returns.begin(), returns.end(), [&](const Argument& r) {
        return r.name().length() > 0;
      });
  c10::OptNameList return_field_names = c10::nullopt;
  if (return_has_field_names) {
    return_field_names =
        fmap(returns, [&](const Argument& r) { return r.name(); });
  }
  return MatchedSchema{std::move(positional_inputs),
                       std::move(return_types),
                       std::move(return_field_names)};
}

// Pack the outputs of a function following Python rules. If there is a single
// value, return a SimpleValue; otherwise pack all the values into a tuple.
Value* packOutputs(
    Graph& g,
    at::ArrayRef<Value*> values,
    c10::OptNameList field_names) {
  if (values.size() == 1) {
    return values[0];
  }
  return g.insertNode(g.createTuple(values, std::move(field_names)))->output();
}

// Given a successful match between operator schema and symbol, emit a node
// with the appropriate inputs and outputs.
static Value* emitBuiltinNode(
    const MatchedSchema& matched_schema,
    const SourceRange& loc,
    Graph& graph,
    Symbol name) {
  auto n = graph.insertNode(graph.create(name, matched_schema.inputs, 0))
               ->setSourceLocation(std::make_shared<SourceRange>(loc));

  for (auto& ret : matched_schema.return_types) {
    n->addOutput()->setType(ret);
  }

  // Assert that we did indeed create an op that has an implementation;
  // otherwise schema and dispatch are not in sync.
  getOperation(n);

  return packOutputs(graph, n->outputs(), matched_schema.return_field_names);
}

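// Prepend `prefix` to every line of `str` (used to indent the collected
// failure messages in the final error report).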
static std::string prefixLine(
    const std::string& str,
    const std::string& prefix) {
  std::stringstream ss;
  bool was_newline = true;
  for (auto c : str) {
    if (was_newline)
      ss << prefix;
    ss.put(c);
    was_newline = c == '\n';
  }
  return ss.str();
}

// Search for operators matching the provided symbol name and input types.
// If one is found, emit a node to the graph for that operator.
Value* emitBuiltinCall(
    const SourceRange& loc,
    Graph& graph,
    Symbol name,
    const c10::optional<NamedValue>& self,
    at::ArrayRef<NamedValue> inputs,
    at::ArrayRef<NamedValue> attributes,
    // if true, emitBuiltinCall will throw an exception if this builtin does not
    // exist, otherwise it will return nullptr if the builtin is not found.
    bool required) {
  const auto& variants = getAllOperatorsFor(name);
  const auto& builtin_functions = getAllBuiltinFunctionsFor(name);

  std::stringstream failure_messages;
  // First we try to match the schema without any conversions;
  // if no schema matches, then insert ImplicitTensorToNum.
  for (bool allow_conversions : {false, true}) {
    // clear previous error messages
    failure_messages.str("");
    for (const std::shared_ptr<Operator>& op : variants) {
      const auto matched_schema = tryMatchSchema(
          op->schema(),
          loc,
          graph,
          self,
          inputs,
          attributes,
          failure_messages,
          allow_conversions);
      if (matched_schema) {
        return emitBuiltinNode(*matched_schema, loc, graph, name);
      }
    }
    for (Method* method : builtin_functions) {
      if (auto result = try_emit_call_to(
              graph,
              loc,
              *method,
              self,
              inputs,
              attributes,
              failure_messages,
              nullptr,
              allow_conversions)) {
        return result;
      }
    }
  }

  // none of the options worked
  if (!required) {
    return nullptr;
  }
  // no operators found with the same name, print out similarly named operators
  if (variants.size() == 0) {
    const auto close_symbols = findSimilarOperators(name);
    auto error = ErrorReport(loc);
    const auto& user_function_name = name.toQualString();
    error << "unknown builtin op: " << user_function_name << "\n";
    if (close_symbols.size() == 0) {
      error
          << "Could not find any similar ops to " << user_function_name
          << ". This op may not exist or may not be currently supported in TorchScript\n";
    } else {
      error << "Here are some suggestions: \n";
      for (const auto& sym : close_symbols) {
        error << "\t" << sym.toQualString() << "\n";
      }
    }
    throw error;
  }

  throw ErrorReport(loc) << "arguments for call are not valid:\n"
                         << prefixLine(failure_messages.str(), "  ")
                         << "for call at";
}
} // namespace script
} // namespace jit
} // namespace torch