1 #include <torch/csrc/jit/operator.h> 3 #include <torch/csrc/jit/alias_info.h> 4 #include <torch/csrc/jit/passes/alias_analysis.h> 5 #include <torch/csrc/jit/passes/python_print.h> 6 #include <torch/csrc/jit/script/edit_distance.h> 7 #include <torch/csrc/jit/script/error_report.h> 8 #include <torch/csrc/jit/script/lexer.h> 9 #include <torch/csrc/jit/script/parse_string_literal.h> 10 #include <torch/csrc/jit/script/schema_type_parser.h> 11 #include <torch/csrc/jit/script/tree.h> 25 : L(str), type_parser(L,
false) {}
28 std::string name = L.expect(TK_IDENT).text();
31 name = name +
"::" + L.expect(TK_IDENT).text();
33 std::string overload_name =
"";
35 overload_name = L.expect(TK_IDENT).text();
37 std::vector<Argument> arguments;
38 std::vector<Argument> returns;
39 bool kwarg_only =
false;
40 bool is_vararg =
false;
42 parseList(
'(',
',',
')', [&] {
45 <<
"... must be the last element of the argument list";
48 }
else if (L.nextIf(TK_DOTS)) {
51 arguments.push_back(parseArgument(
52 idx++,
false, kwarg_only));
57 if (L.cur().kind ==
'(') {
58 parseList(
'(',
',',
')', [&] {
60 parseArgument(idx++,
true,
false));
64 parseArgument(0,
true,
false));
67 std::move(name), std::move(overload_name), std::move(arguments), std::move(returns), is_vararg,
false};
70 std::vector<FunctionSchema> parseDeclarations() {
71 std::vector<FunctionSchema> results;
73 results.push_back(parseDeclaration());
74 }
while (L.nextIf(TK_NEWLINE));
79 TreeRef parseIdent() {
80 return String::create(L.expect(TK_IDENT).text());
83 Argument parseArgument(
size_t idx,
bool is_return,
bool kwarg_only) {
85 auto p = type_parser.parseType();
86 auto type = std::move(p.first);
87 auto alias_info = std::move(p.second);
94 type = ListType::create(type);
95 N = std::stoll(L.expect(TK_NUMBER).text());
97 auto container = type_parser.parseAliasAnnotation();
98 if (container && alias_info) {
99 container->addContainedType(std::move(*alias_info));
101 alias_info = std::move(container);
105 if (L.cur().kind == TK_IDENT) {
106 name = L.next().text();
111 name = L.expect(TK_IDENT).text();
113 default_value = parseDefaultValue(type, N);
120 std::move(default_value),
121 !is_return && kwarg_only,
122 std::move(alias_info));
124 IValue parseSingleConstant(TypeKind kind) {
125 switch (L.cur().kind) {
135 case TK_STRINGLITERAL: {
136 auto token = L.next();
137 return parseStringLiteral(token.range, token.text());
141 auto text = tok.text();
142 if (
"float" == text) {
143 return static_cast<int64_t
>(at::kFloat);
144 }
else if (
"long" == text) {
145 return static_cast<int64_t
>(at::kLong);
146 }
else if (
"strided" == text) {
147 return static_cast<int64_t
>(at::kStrided);
148 }
else if (
"Mean" == text) {
149 return static_cast<int64_t
>(Reduction::Mean);
151 throw ErrorReport(L.cur().range) <<
"invalid numeric default value";
157 n =
"-" + L.expect(TK_NUMBER).text();
159 n = L.expect(TK_NUMBER).text();
160 if (kind == TypeKind::FloatType || n.find(
'.') != std::string::npos ||
161 n.find(
'e') != std::string::npos) {
164 int64_t v = std::stoll(n);
172 std::vector<IValue> vs) {
174 case TypeKind::FloatType:
175 return fmap(vs, [](
IValue v) {
return v.toDouble(); });
176 case TypeKind::IntType:
177 return fmap(vs, [](
IValue v) {
return v.toInt(); });
178 case TypeKind::BoolType:
179 return fmap(vs, [](
IValue v) {
return v.toBool(); });
182 <<
"lists are only supported for float or int types.";
185 IValue parseConstantList(TypeKind kind) {
186 auto tok = L.expect(
'[');
187 std::vector<IValue> vs;
188 if (L.cur().kind !=
']') {
190 vs.push_back(parseSingleConstant(kind));
191 }
while (L.nextIf(
','));
194 return convertToList(kind, tok.range, std::move(vs));
202 const TypePtr& arg_type,
204 auto range = L.cur().range;
205 switch (arg_type->kind()) {
206 case TypeKind::TensorType:
207 case TypeKind::GeneratorType: {
208 return parseTensorDefault(range);
210 case TypeKind::StringType:
211 case TypeKind::OptionalType:
212 case TypeKind::NumberType:
213 case TypeKind::IntType:
214 case TypeKind::BoolType:
215 case TypeKind::FloatType:
216 return parseSingleConstant(arg_type->kind());
218 case TypeKind::DeviceObjType: {
220 parseStringLiteral(range, L.expect(TK_STRINGLITERAL).text());
224 case TypeKind::ListType: {
225 auto elem_kind = arg_type->cast<
ListType>()->getElementType();
226 if (L.cur().kind == TK_IDENT) {
227 return parseTensorDefault(range);
228 }
else if (arg_N && L.cur().kind !=
'[') {
229 IValue v = parseSingleConstant(elem_kind->kind());
230 std::vector<IValue> repeated(*arg_N, v);
231 return convertToList(elem_kind->kind(), range, repeated);
233 return parseConstantList(elem_kind->kind());
237 throw ErrorReport(range) <<
"unexpected type, file a bug report";
246 const std::function<
void()>& callback) {
247 auto r = L.cur().range;
248 if (begin != TK_NOTHING)
250 if (L.cur().kind != end) {
253 }
while (L.nextIf(sep));
255 if (end != TK_NOTHING)
265 std::unordered_map<Symbol, std::vector<std::shared_ptr<Operator>>>;
266 struct OperatorRegistry {
269 OperatorMap operators;
272 std::vector<std::shared_ptr<Operator>> to_register;
283 std::unordered_map<std::string, std::shared_ptr<Operator>> operators_by_sig;
284 std::unordered_map<const char*, std::shared_ptr<Operator>>
285 operators_by_sig_literal;
288 void registerPendingOperators() {
289 for (
const auto& op : to_register) {
290 Symbol sym = Symbol::fromQualString(op->schema().name());
291 operators[sym].push_back(op);
292 operators_by_sig[canonicalSchemaString(op->schema())] = op;
298 void registerOperator(
Operator&& op) {
299 std::lock_guard<std::mutex> guard(lock);
300 to_register.push_back(std::make_shared<Operator>(std::move(op)));
303 const std::shared_ptr<Operator>& lookupByLiteral(
const char* name) {
304 std::lock_guard<std::mutex> guard(lock);
305 registerPendingOperators();
306 auto it = operators_by_sig_literal.find(name);
307 if (it == operators_by_sig_literal.end()) {
309 operators_by_sig.find(canonicalSchemaString(parseSchema(name)));
312 if (op_ptr_it == operators_by_sig.end()) {
313 for (
auto & entry : operators_by_sig) {
314 std::cout << entry.first << std::endl;
319 op_ptr_it != operators_by_sig.end(),
320 "Couldn't find an operator for ",
322 ". Do you have to update a set of hardcoded JIT ops?");
323 it = operators_by_sig_literal.emplace_hint(it, name, op_ptr_it->second);
328 const std::vector<std::shared_ptr<Operator>>& getOperators(
Symbol name) {
329 std::lock_guard<std::mutex> guard(lock);
330 registerPendingOperators();
331 static std::vector<std::shared_ptr<Operator>> empty;
332 auto it = operators.find(name);
333 if (it != operators.end())
338 std::vector<Symbol> findSimilarOperators(
Symbol input_op) {
339 std::lock_guard<std::mutex> guard(lock);
340 registerPendingOperators();
342 using EntryPair = std::pair<int64_t, Symbol>;
343 auto cmp = [](
const EntryPair& lhs,
const EntryPair& rhs) {
344 return lhs.first > rhs.first;
347 std::priority_queue<EntryPair, std::vector<EntryPair>, decltype(cmp)>
349 static constexpr
size_t MAX_EDIT_DIST = 2u;
350 for (
const auto& op : operators) {
351 auto edit_dist = script::ComputeEditDistance(
352 input_op.toQualString(), op.first.toQualString(), MAX_EDIT_DIST);
353 if (edit_dist <= MAX_EDIT_DIST) {
354 rankings.emplace(edit_dist, op.first);
357 std::vector<Symbol> ret;
358 while (!rankings.empty()) {
359 ret.push_back(rankings.top().second);
366 OperatorRegistry& getRegistry() {
367 static OperatorRegistry r;
372 void registerOperator(
Operator&& op) {
373 if (op.schema().is_varret()) {
374 Symbol s = Symbol::fromQualString(op.schema().name());
375 if (!printerHasSpecialCaseFor(s)) {
377 "Missing special case in python printer for non-schematized" 380 ". File a bug to add a case for this operator.\n");
382 if (!aliasAnalysisHasSpecialCaseFor(s)) {
384 "Missing special case in alias analysis for non-schematized" 387 ". File a bug to add a case for this operator.\n");
391 getRegistry().registerOperator(std::move(op));
394 const std::vector<std::shared_ptr<Operator>>& getAllOperatorsFor(
Symbol name) {
395 return getRegistry().getOperators(name);
398 std::vector<Symbol> findSimilarOperators(
Symbol input_op) {
399 return getRegistry().findSimilarOperators(input_op);
402 Operator& sig(
const char* signature) {
403 return *getRegistry().lookupByLiteral(signature);
411 std::ostringstream out;
413 out << schema.name();
416 bool seen_kwarg_only =
false;
417 for (
size_t i = 0; i < schema.arguments().size(); ++i) {
420 if (schema.arguments()[i].kwarg_only() && !seen_kwarg_only) {
422 seen_kwarg_only =
true;
424 const auto& arg = schema.arguments()[i];
425 out << arg.type()->str() <<
" " << arg.name();
429 if (schema.returns().size() == 1) {
430 out << schema.returns().at(0).type()->str();
431 }
else if (schema.returns().size() > 1) {
433 for (
size_t i = 0; i < schema.returns().size(); ++i) {
436 out << schema.returns()[i].type()->str();
443 bool Operator::matches(
const Node* node)
const {
445 if (node->kind().toQualString() != schema().name()) {
449 const auto& formals = schema().arguments();
452 if (actuals.
size() < formals.size())
456 for (
size_t i = 0; i < formals.size(); ++i) {
458 matchTypeVariables(formals[i].type(), actuals[i]->type(), type_env);
459 if (!matched_type.type) {
462 TypePtr formal = *matched_type.type;
463 if (!actuals[i]->type()->isSubtypeOf(formal)) {
469 if (!schema().is_vararg() && actuals.
size() != formals.size()) {
478 std::shared_ptr<Operator> findOperatorFor(
const Node* node) {
479 const auto& candidates = getAllOperatorsFor(node->kind());
480 for (
const auto& candidate : candidates) {
481 if (candidate->matches(node)) {
489 auto op = findOperatorFor(node);
494 er <<
"Schema not found for node. File a bug report.\n";
495 er <<
"Node: " << *node <<
"\n";
496 er <<
"Input types:";
497 for (
size_t i = 0; i < node->inputs().size(); ++i) {
500 er << *node->inputs()[i]->type();
502 er <<
"\ncandidates were:\n";
503 const auto& candidates = getAllOperatorsFor(node->kind());
504 for (
auto& candidate : candidates) {
505 er <<
" " << candidate->schema() <<
"\n";
507 er << *node->owningGraph() <<
"\n";
511 OperatorSet::OperatorSet(std::initializer_list<const char*> sig_literals) {
512 auto& registry = getRegistry();
513 for (
const char* sig : sig_literals) {
514 auto op = registry.lookupByLiteral(sig);
515 ops[Symbol::fromQualString(op->schema().name())].push_back(op);
520 auto it = ops.find(n->kind());
521 if (it == ops.end()) {
524 for (
auto& op : it->second) {
525 if (op->matches(n)) {
Represents a a compute device on which a tensor is located.
constexpr size_t size() const
size - Get the array size.
ArrayRef - Represent a constant reference to an array (0 or more elements consecutively in memory)...