Caffe2 - C++ API
A deep learning, cross-platform ML framework
operator.cpp
1 #include <torch/csrc/jit/operator.h>
2 #include <ATen/ATen.h>
3 #include <torch/csrc/jit/alias_info.h>
4 #include <torch/csrc/jit/passes/alias_analysis.h>
5 #include <torch/csrc/jit/passes/python_print.h>
6 #include <torch/csrc/jit/script/edit_distance.h>
7 #include <torch/csrc/jit/script/error_report.h>
8 #include <torch/csrc/jit/script/lexer.h>
9 #include <torch/csrc/jit/script/parse_string_literal.h>
10 #include <torch/csrc/jit/script/schema_type_parser.h>
11 #include <torch/csrc/jit/script/tree.h>
12 
13 #include <functional>
14 #include <memory>
15 #include <queue>
16 #include <utility>
17 #include <vector>
18 
19 namespace torch {
20 namespace jit {
21 
22 namespace script {
23 struct SchemaParser {
24  SchemaParser(const std::string& str)
25  : L(str), type_parser(L, /*parse_complete_tensor_types*/ false) {}
26 
27  FunctionSchema parseDeclaration() {
28  std::string name = L.expect(TK_IDENT).text();
29  if (L.nextIf(':')) {
30  L.expect(':');
31  name = name + "::" + L.expect(TK_IDENT).text();
32  }
33  std::string overload_name = "";
34  if (L.nextIf('.')) {
35  overload_name = L.expect(TK_IDENT).text();
36  }
37  std::vector<Argument> arguments;
38  std::vector<Argument> returns;
39  bool kwarg_only = false;
40  bool is_vararg = false;
41  size_t idx = 0;
42  parseList('(', ',', ')', [&] {
43  if (is_vararg)
44  throw ErrorReport(L.cur())
45  << "... must be the last element of the argument list";
46  if (L.nextIf('*')) {
47  kwarg_only = true;
48  } else if (L.nextIf(TK_DOTS)) {
49  is_vararg = true;
50  } else {
51  arguments.push_back(parseArgument(
52  idx++, /*is_return=*/false, /*kwarg_only=*/kwarg_only));
53  }
54  });
55  idx = 0;
56  L.expect(TK_ARROW);
57  if (L.cur().kind == '(') {
58  parseList('(', ',', ')', [&] {
59  returns.push_back(
60  parseArgument(idx++, /*is_return=*/true, /*kwarg_only=*/false));
61  });
62  } else {
63  returns.push_back(
64  parseArgument(0, /*is_return=*/true, /*kwarg_only=*/false));
65  }
66  return FunctionSchema{
67  std::move(name), std::move(overload_name), std::move(arguments), std::move(returns), is_vararg, false};
68  }
69 
70  std::vector<FunctionSchema> parseDeclarations() {
71  std::vector<FunctionSchema> results;
72  do {
73  results.push_back(parseDeclaration());
74  } while (L.nextIf(TK_NEWLINE));
75  L.expect(TK_EOF);
76  return results;
77  }
78 
79  TreeRef parseIdent() {
80  return String::create(L.expect(TK_IDENT).text());
81  }
82 
83  Argument parseArgument(size_t idx, bool is_return, bool kwarg_only) {
84  Argument result;
85  auto p = type_parser.parseType();
86  auto type = std::move(p.first);
87  auto alias_info = std::move(p.second);
89  c10::optional<IValue> default_value;
91  std::string name;
92  if (L.nextIf('[')) {
93  // note: an array with a size hint can only occur at the Argument level
94  type = ListType::create(type);
95  N = std::stoll(L.expect(TK_NUMBER).text());
96  L.expect(']');
97  auto container = type_parser.parseAliasAnnotation();
98  if (container && alias_info) {
99  container->addContainedType(std::move(*alias_info));
100  }
101  alias_info = std::move(container);
102  }
103  if (is_return) {
104  // optionally field names in return values
105  if (L.cur().kind == TK_IDENT) {
106  name = L.next().text();
107  } else {
108  name = "";
109  }
110  } else {
111  name = L.expect(TK_IDENT).text();
112  if (L.nextIf('=')) {
113  default_value = parseDefaultValue(type, N);
114  }
115  }
116  return Argument(
117  std::move(name),
118  std::move(type),
119  N,
120  std::move(default_value),
121  !is_return && kwarg_only,
122  std::move(alias_info));
123  }
124  IValue parseSingleConstant(TypeKind kind) {
125  switch (L.cur().kind) {
126  case TK_TRUE:
127  L.next();
128  return true;
129  case TK_FALSE:
130  L.next();
131  return false;
132  case TK_NONE:
133  L.next();
134  return IValue();
135  case TK_STRINGLITERAL: {
136  auto token = L.next();
137  return parseStringLiteral(token.range, token.text());
138  }
139  case TK_IDENT: {
140  auto tok = L.next();
141  auto text = tok.text();
142  if ("float" == text) {
143  return static_cast<int64_t>(at::kFloat);
144  } else if ("long" == text) {
145  return static_cast<int64_t>(at::kLong);
146  } else if ("strided" == text) {
147  return static_cast<int64_t>(at::kStrided);
148  } else if ("Mean" == text) {
149  return static_cast<int64_t>(Reduction::Mean);
150  } else {
151  throw ErrorReport(L.cur().range) << "invalid numeric default value";
152  }
153  }
154  default:
155  std::string n;
156  if (L.nextIf('-'))
157  n = "-" + L.expect(TK_NUMBER).text();
158  else
159  n = L.expect(TK_NUMBER).text();
160  if (kind == TypeKind::FloatType || n.find('.') != std::string::npos ||
161  n.find('e') != std::string::npos) {
162  return std::stod(n);
163  } else {
164  int64_t v = std::stoll(n);
165  return v;
166  }
167  }
168  }
169  IValue convertToList(
170  TypeKind kind,
171  const SourceRange& range,
172  std::vector<IValue> vs) {
173  switch (kind) {
174  case TypeKind::FloatType:
175  return fmap(vs, [](IValue v) { return v.toDouble(); });
176  case TypeKind::IntType:
177  return fmap(vs, [](IValue v) { return v.toInt(); });
178  case TypeKind::BoolType:
179  return fmap(vs, [](IValue v) { return v.toBool(); });
180  default:
181  throw ErrorReport(range)
182  << "lists are only supported for float or int types.";
183  }
184  }
185  IValue parseConstantList(TypeKind kind) {
186  auto tok = L.expect('[');
187  std::vector<IValue> vs;
188  if (L.cur().kind != ']') {
189  do {
190  vs.push_back(parseSingleConstant(kind));
191  } while (L.nextIf(','));
192  }
193  L.expect(']');
194  return convertToList(kind, tok.range, std::move(vs));
195  }
196 
197  IValue parseTensorDefault(const SourceRange& range) {
198  L.expect(TK_NONE);
199  return IValue();
200  }
201  IValue parseDefaultValue(
202  const TypePtr& arg_type,
203  c10::optional<int32_t> arg_N) {
204  auto range = L.cur().range;
205  switch (arg_type->kind()) {
206  case TypeKind::TensorType:
207  case TypeKind::GeneratorType: {
208  return parseTensorDefault(range);
209  } break;
210  case TypeKind::StringType:
211  case TypeKind::OptionalType:
212  case TypeKind::NumberType:
213  case TypeKind::IntType:
214  case TypeKind::BoolType:
215  case TypeKind::FloatType:
216  return parseSingleConstant(arg_type->kind());
217  break;
218  case TypeKind::DeviceObjType: {
219  auto device_text =
220  parseStringLiteral(range, L.expect(TK_STRINGLITERAL).text());
221  return c10::Device(device_text);
222  break;
223  }
224  case TypeKind::ListType: {
225  auto elem_kind = arg_type->cast<ListType>()->getElementType();
226  if (L.cur().kind == TK_IDENT) {
227  return parseTensorDefault(range);
228  } else if (arg_N && L.cur().kind != '[') {
229  IValue v = parseSingleConstant(elem_kind->kind());
230  std::vector<IValue> repeated(*arg_N, v);
231  return convertToList(elem_kind->kind(), range, repeated);
232  } else {
233  return parseConstantList(elem_kind->kind());
234  }
235  } break;
236  default:
237  throw ErrorReport(range) << "unexpected type, file a bug report";
238  }
239  return IValue(); // silence warnings
240  }
241 
242  void parseList(
243  int begin,
244  int sep,
245  int end,
246  const std::function<void()>& callback) {
247  auto r = L.cur().range;
248  if (begin != TK_NOTHING)
249  L.expect(begin);
250  if (L.cur().kind != end) {
251  do {
252  callback();
253  } while (L.nextIf(sep));
254  }
255  if (end != TK_NOTHING)
256  L.expect(end);
257  }
258  Lexer L;
259  SchemaTypeParser type_parser;
260 };
261 } // namespace script
262 
263 namespace {
264 using OperatorMap =
265  std::unordered_map<Symbol, std::vector<std::shared_ptr<Operator>>>;
266 struct OperatorRegistry {
267  private:
268  std::mutex lock;
269  OperatorMap operators;
270  // list of operators whose schema have not yet been parsed, and must
271  // be registered before any call to lookup an opeator
272  std::vector<std::shared_ptr<Operator>> to_register;
273  // Those two maps are used to implement lookupByLiteral, which is needed for
274  // the n->match(...) calls. Basically, every function schema is assigned a
275  // unique string you can use to match it. However, parsing those strings or
276  // comparing and hashing them character by character would be very slow, so we
277  // use a trick here! Every string literal in your program is guaranteed to
278  // have static storage duration and so its address won't change at runtime.
279  // This allows us to memoize answers for every pointer, which is done by the
280  // operators_by_sig_literal map. Still, this map is initially empty, and so we
281  // still need to do the complete string matching at the first time, which is
282  // implemented by performing a lookup in the operators_by_sig map.
283  std::unordered_map<std::string, std::shared_ptr<Operator>> operators_by_sig;
284  std::unordered_map<const char*, std::shared_ptr<Operator>>
285  operators_by_sig_literal;
286 
287  // XXX - caller must be holding lock
288  void registerPendingOperators() {
289  for (const auto& op : to_register) {
290  Symbol sym = Symbol::fromQualString(op->schema().name());
291  operators[sym].push_back(op);
292  operators_by_sig[canonicalSchemaString(op->schema())] = op;
293  }
294  to_register.clear();
295  }
296 
297  public:
298  void registerOperator(Operator&& op) {
299  std::lock_guard<std::mutex> guard(lock);
300  to_register.push_back(std::make_shared<Operator>(std::move(op)));
301  }
302 
303  const std::shared_ptr<Operator>& lookupByLiteral(const char* name) {
304  std::lock_guard<std::mutex> guard(lock);
305  registerPendingOperators();
306  auto it = operators_by_sig_literal.find(name);
307  if (it == operators_by_sig_literal.end()) {
308  auto op_ptr_it =
309  operators_by_sig.find(canonicalSchemaString(parseSchema(name)));
310  // Handy debugging code that dumps all operators we know about on mismatch
311 #if 0
312  if (op_ptr_it == operators_by_sig.end()) {
313  for (auto & entry : operators_by_sig) {
314  std::cout << entry.first << std::endl;
315  }
316  }
317 #endif
318  AT_CHECK(
319  op_ptr_it != operators_by_sig.end(),
320  "Couldn't find an operator for ",
321  name,
322  ". Do you have to update a set of hardcoded JIT ops?");
323  it = operators_by_sig_literal.emplace_hint(it, name, op_ptr_it->second);
324  }
325  return it->second;
326  }
327 
328  const std::vector<std::shared_ptr<Operator>>& getOperators(Symbol name) {
329  std::lock_guard<std::mutex> guard(lock);
330  registerPendingOperators();
331  static std::vector<std::shared_ptr<Operator>> empty;
332  auto it = operators.find(name);
333  if (it != operators.end())
334  return it->second;
335  return empty;
336  }
337 
338  std::vector<Symbol> findSimilarOperators(Symbol input_op) {
339  std::lock_guard<std::mutex> guard(lock);
340  registerPendingOperators();
341 
342  using EntryPair = std::pair<int64_t, Symbol>;
343  auto cmp = [](const EntryPair& lhs, const EntryPair& rhs) {
344  return lhs.first > rhs.first;
345  };
346 
347  std::priority_queue<EntryPair, std::vector<EntryPair>, decltype(cmp)>
348  rankings(cmp);
349  static constexpr size_t MAX_EDIT_DIST = 2u;
350  for (const auto& op : operators) {
351  auto edit_dist = script::ComputeEditDistance(
352  input_op.toQualString(), op.first.toQualString(), MAX_EDIT_DIST);
353  if (edit_dist <= MAX_EDIT_DIST) {
354  rankings.emplace(edit_dist, op.first);
355  }
356  }
357  std::vector<Symbol> ret;
358  while (!rankings.empty()) {
359  ret.push_back(rankings.top().second);
360  rankings.pop();
361  }
362  return ret;
363  }
364 };
365 
366 OperatorRegistry& getRegistry() {
367  static OperatorRegistry r;
368  return r;
369 }
370 } // anonymous namespace
371 
372 void registerOperator(Operator&& op) {
373  if (op.schema().is_varret()) {
374  Symbol s = Symbol::fromQualString(op.schema().name());
375  if (!printerHasSpecialCaseFor(s)) {
376  AT_ERROR(
377  "Missing special case in python printer for non-schematized"
378  " operator ",
379  op.schema().name(),
380  ". File a bug to add a case for this operator.\n");
381  }
382  if (!aliasAnalysisHasSpecialCaseFor(s)) {
383  AT_ERROR(
384  "Missing special case in alias analysis for non-schematized"
385  " operator ",
386  op.schema().name(),
387  ". File a bug to add a case for this operator.\n");
388  }
389  }
390 
391  getRegistry().registerOperator(std::move(op));
392 }
393 
394 const std::vector<std::shared_ptr<Operator>>& getAllOperatorsFor(Symbol name) {
395  return getRegistry().getOperators(name);
396 }
397 
398 std::vector<Symbol> findSimilarOperators(Symbol input_op) {
399  return getRegistry().findSimilarOperators(input_op);
400 }
401 
402 Operator& sig(const char* signature) {
403  return *getRegistry().lookupByLiteral(signature);
404 }
405 
406 FunctionSchema parseSchema(const std::string& schema) {
407  return script::SchemaParser(schema).parseDeclarations().at(0);
408 }
409 
410 std::string canonicalSchemaString(const FunctionSchema& schema) {
411  std::ostringstream out;
412 
413  out << schema.name();
414  out << "(";
415 
416  bool seen_kwarg_only = false;
417  for (size_t i = 0; i < schema.arguments().size(); ++i) {
418  if (i > 0)
419  out << ", ";
420  if (schema.arguments()[i].kwarg_only() && !seen_kwarg_only) {
421  out << "*, ";
422  seen_kwarg_only = true;
423  }
424  const auto& arg = schema.arguments()[i];
425  out << arg.type()->str() << " " << arg.name();
426  }
427 
428  out << ") -> ";
429  if (schema.returns().size() == 1) {
430  out << schema.returns().at(0).type()->str();
431  } else if (schema.returns().size() > 1) {
432  out << "(";
433  for (size_t i = 0; i < schema.returns().size(); ++i) {
434  if (i > 0)
435  out << ", ";
436  out << schema.returns()[i].type()->str();
437  }
438  out << ")";
439  }
440  return out.str();
441 }
442 
443 bool Operator::matches(const Node* node) const {
444  // wrong name
445  if (node->kind().toQualString() != schema().name()) {
446  return false;
447  }
448  at::ArrayRef<const Value*> actuals = node->inputs();
449  const auto& formals = schema().arguments();
450 
451  // not enough inputs
452  if (actuals.size() < formals.size())
453  return false;
454 
455  TypeEnv type_env;
456  for (size_t i = 0; i < formals.size(); ++i) {
457  const MatchTypeReturn matched_type =
458  matchTypeVariables(formals[i].type(), actuals[i]->type(), type_env);
459  if (!matched_type.type) {
460  return false;
461  }
462  TypePtr formal = *matched_type.type;
463  if (!actuals[i]->type()->isSubtypeOf(formal)) {
464  return false;
465  }
466  }
467 
468  // too many inputs
469  if (!schema().is_vararg() && actuals.size() != formals.size()) {
470  // std::cout << "not all inputs used\n" << input_i << " " << inputs_size <<
471  // "\n";
472  return false;
473  }
474 
475  return true;
476 }
477 
478 std::shared_ptr<Operator> findOperatorFor(const Node* node) {
479  const auto& candidates = getAllOperatorsFor(node->kind());
480  for (const auto& candidate : candidates) {
481  if (candidate->matches(node)) {
482  return candidate;
483  }
484  }
485  return nullptr;
486 }
487 
488 const Operator& getOperatorFor(const Node* node) {
489  auto op = findOperatorFor(node);
490  if (op)
491  return *op;
492 
493  auto er = script::ErrorReport(node->getSourceLocation());
494  er << "Schema not found for node. File a bug report.\n";
495  er << "Node: " << *node << "\n";
496  er << "Input types:";
497  for (size_t i = 0; i < node->inputs().size(); ++i) {
498  if (i > 0)
499  er << ", ";
500  er << *node->inputs()[i]->type();
501  }
502  er << "\ncandidates were:\n";
503  const auto& candidates = getAllOperatorsFor(node->kind());
504  for (auto& candidate : candidates) {
505  er << " " << candidate->schema() << "\n";
506  }
507  er << *node->owningGraph() << "\n";
508  throw er;
509 }
510 
511 OperatorSet::OperatorSet(std::initializer_list<const char*> sig_literals) {
512  auto& registry = getRegistry();
513  for (const char* sig : sig_literals) {
514  auto op = registry.lookupByLiteral(sig);
515  ops[Symbol::fromQualString(op->schema().name())].push_back(op);
516  }
517 }
518 
519 Operator* OperatorSet::find(const Node* n) const {
520  auto it = ops.find(n->kind());
521  if (it == ops.end()) {
522  return nullptr;
523  }
524  for (auto& op : it->second) {
525  if (op->matches(n)) {
526  return op.get();
527  }
528  }
529  return nullptr;
530 }
531 } // namespace jit
532 } // namespace torch
Represents a compute device on which a tensor is located.
Definition: Device.h:30
constexpr size_t size() const
size - Get the array size.
Definition: ArrayRef.h:138
Definition: jit_type.h:17
ArrayRef - Represent a constant reference to an array (0 or more elements consecutively in memory)...
Definition: ArrayRef.h:41