// NOTE(review): this file appears to be a corrupted listing. Original line
// numbers are fused into the text (e.g. "12", "24") and many source lines,
// including closing braces, are missing, so it cannot compile as-is.
// Restore it from version control rather than patching it piecemeal.
//
// mergeTypesFromTypeComment (declared at the end of the next line):
// builds a Decl whose parameters are those of `decl`, each retyped with the
// corresponding type from `type_annotation_decl`. The return type is taken
// from type_annotation_decl.return_type().
// When `is_method` is true, the expected annotation count is reduced by one
// and iteration starts at decl's second parameter. The first parameter
// (presumably `self`) appears to be copied over unchanged.
// Throws ErrorReport if the annotation count does not match.
1 #include <torch/csrc/jit/script/parser.h> 2 #include <c10/util/Optional.h> 3 #include <torch/csrc/jit/script/lexer.h> 4 #include <torch/csrc/jit/script/parse_string_literal.h> 5 #include <torch/csrc/jit/script/tree.h> 6 #include <torch/csrc/jit/script/tree_views.h> 12 Decl mergeTypesFromTypeComment(
14 const Decl& type_annotation_decl,
16 auto expected_num_annotations = decl.params().size();
// Presumably guarded by `if (is_method)` on a line missing here.
19 expected_num_annotations -= 1;
21 if (expected_num_annotations != type_annotation_decl.params().size()) {
22 throw ErrorReport(type_annotation_decl.range())
23 <<
"Number of type annotations (" 24 << type_annotation_decl.params().size()
25 <<
") did not match the number of " 26 <<
"function parameters (" << expected_num_annotations <<
")";
28 auto old = decl.params();
29 auto _new = type_annotation_decl.params();
32 std::vector<Param> new_params;
// `i` indexes decl's params (skipping the first for methods).
// `j` presumably indexes the annotations, starting at 0.
33 size_t i = is_method ? 1 : 0;
36 new_params.push_back(old[0]);
38 for (; i < decl.params().size(); ++i, ++j) {
// Keep the parameter's name; replace only its type.
39 new_params.emplace_back(old[i].withType(_new[j].type()));
43 List<Param>::create(decl.range(), new_params),
44 type_annotation_decl.return_type());
// Presumably the ParserImpl constructor; its signature line is missing here.
// The member initializer list builds the lexer `L` from `str` and sets
// `shared` to the result of sharedParserData().
49 : L(str), shared(sharedParserData()) {}
// Presumably the body of parseIdent(); its signature line is missing here.
// Consumes a TK_IDENT token via L.expect and returns it as an Ident tree
// carrying the token's range and text.
52 auto t = L.expect(TK_IDENT);
56 return Ident::create(t.range, t.text());
// Builds a call node for the callee `expr`.
// The argument list is parsed by parseOperatorArguments, which fills
// `inputs` (positional) and `attributes` (keyword).
// The positional inputs are wrapped in a list at the current token's range.
// NOTE(review): parts of this body are missing here; the node created is
// presumably an Apply.
58 TreeRef createApply(
const Expr& expr) {
60 auto range = L.cur().range;
62 parseOperatorArguments(inputs, attributes);
66 List<Expr>(makeList(range, std::move(inputs))),
// Predicate on a token kind. Its body is missing here.
// parseExpOrExpTuple calls it right after consuming a ',', presumably to
// detect a trailing comma (a token that cannot start another element).
70 static bool followsTuple(
int kind) {
// Parses one expression. If a ',' follows, keeps parsing comma-separated
// expressions and wraps them all in a TupleLiteral.
86 Expr parseExpOrExpTuple() {
87 auto prefix = parseExp();
88 if (L.cur().kind ==
',') {
89 std::vector<Expr> exprs = {prefix};
90 while (L.nextIf(
',')) {
// Presumably stops here on a trailing comma (the break line is missing).
91 if (followsTuple(L.cur().kind))
93 exprs.push_back(parseExp());
96 prefix = TupleLiteral::create(list.range(), list);
// Presumably the body of parseBaseExp(); its signature and most case labels
// are missing here.
// Dispatches on the current token to parse a primary expression, then
// applies postfix operations: attribute access (Select), calls
// (createApply) and subscripts (parseSubscript).
104 switch (L.cur().kind) {
// Numeric constant.
106 prefix = parseConst();
// Token kinds represented by a bare compound node with no children.
111 auto k = L.cur().kind;
112 auto r = L.cur().range;
113 prefix = c(k, r, {});
// Presumably the empty tuple `()`.
120 std::vector<Expr> vecExpr;
122 prefix = TupleLiteral::create(L.cur().range, listExpr);
// Parenthesized expression or tuple.
125 prefix = parseExpOrExpTuple();
// List literal: '[' expr, ... ']'.
129 auto list = parseList(
'[',
',',
']', &ParserImpl::parseExp);
130 prefix = ListLiteral::create(list.range(),
List<Expr>(list));
// Dict literal: comma-separated key/value pairs, possibly empty.
134 std::vector<Expr> keys;
135 std::vector<Expr> values;
136 auto range = L.cur().range;
137 if (L.cur().kind !=
'}') {
139 keys.push_back(parseExp());
141 values.push_back(parseExp());
142 }
while (L.nextIf(
','));
145 prefix = DictLiteral::create(
// Adjacent string literals are concatenated into one StringLiteral.
150 case TK_STRINGLITERAL: {
151 prefix = parseConcatenatedStringLiterals();
// Identifier: a variable reference.
154 Ident name = parseIdent();
155 prefix = Var::create(name.range(), name);
// Postfix loop. The '.' case builds a Select; '(' a call; '[' a subscript.
160 const auto name = parseIdent();
161 prefix = Select::create(name.range(),
Expr(prefix),
Ident(name));
162 }
else if (L.cur().kind ==
'(') {
163 prefix = createApply(
Expr(prefix));
164 }
else if (L.cur().kind ==
'[') {
165 prefix = parseSubscript(prefix);
// Parses an assignment operator and returns a childless node whose kind
// encodes it.
// For the operator tokens in the (missing) case labels, the kind is the
// first character of the token's text; presumably these are the augmented
// forms such as "+=", which yield '+'.
// Otherwise the kind is '='.
172 TreeRef parseAssignmentOp() {
173 auto r = L.cur().range;
174 switch (L.cur().kind) {
179 int modifier = L.next().text()[0];
180 return c(modifier, r, {});
184 return c(
'=', r, {});
// Completes a conditional expression whose true branch has already been
// parsed. It is presumably called after the `if` of `a if cond else b`.
// Returns a TK_IF_EXPR node with children (cond, true_branch, false_branch).
// The false branch is parsed at `binary_prec`.
188 TreeRef parseTrinary(
192 auto cond = parseExp();
194 auto false_branch = parseExp(binary_prec);
195 return c(TK_IF_EXPR, range, {cond, std::move(true_branch), false_branch});
// Precedence-climbing expression parser.
// First parses a prefix: either a unary operator applied to a
// sub-expression, or (presumably, on a missing line) a base expression.
// Then it repeatedly folds binary operators whose precedence is greater
// than `precedence`.
204 Expr parseExp(
int precedence) {
205 TreeRef prefix =
nullptr;
207 if (shared.isUnary(L.cur().kind, &unary_prec)) {
208 auto kind = L.cur().kind;
209 auto pos = L.cur().range;
212 kind ==
'*' ? TK_STARRED : kind ==
'-' ? TK_UNARY_MINUS : kind;
// '*' maps to TK_STARRED and '-' to TK_UNARY_MINUS; other unary kinds are
// kept as-is. The operand is parsed at the unary operator's precedence.
213 auto subexp = parseExp(unary_prec);
// Fold unary minus applied to a constant into a negative constant literal.
216 if (unary_kind == TK_UNARY_MINUS && subexp.kind() == TK_CONST) {
217 prefix = Const::create(subexp.range(),
"-" +
Const(subexp).text());
219 prefix = c(unary_kind, pos, {subexp});
225 while (shared.isBinary(L.cur().kind, &binary_prec)) {
// Stop when the next operator binds no tighter than the caller's level.
226 if (binary_prec <= precedence)
230 int kind = L.cur().kind;
231 auto pos = L.cur().range;
// Presumably lowers binary_prec so right-associative operators nest to the
// right (the adjusting line is missing here).
233 if (shared.isRightAssociative(kind))
// Presumably the `if` of a conditional expression.
238 prefix = parseTrinary(prefix, pos, binary_prec);
242 prefix = c(kind, pos, {prefix, parseExp(binary_prec)});
// Presumably parseSequence(begin, sep, end, parse); the head of its
// signature is missing here.
// Parses `begin [elem (sep elem)*] end`, calling `parse` once per element.
// The `begin` and `end` tokens are skipped when they are TK_NOTHING.
250 const std::function<
void()>& parse) {
251 if (begin != TK_NOTHING)
// Empty sequence when the next token is already `end`; otherwise keep
// parsing elements while a `sep` token can be consumed.
253 if (L.cur().kind != end) {
256 }
while (L.nextIf(sep));
258 if (end != TK_NOTHING)
// Generic list parser. Runs the member-function element parser `parse` over
// a begin/sep/end sequence and collects each result into `elements`.
// Presumably it returns them as a list tree at range `r`; the signature and
// return lines are missing here.
261 template <
typename T>
263 auto r = L.cur().range;
264 std::vector<T> elements;
266 begin, sep, end, [&] { elements.emplace_back((this->*parse)()); });
// Presumably the body of parseConst(); its signature line is missing here.
// Consumes a TK_NUMBER token and returns a Const with the token's range and
// text.
// NOTE(review): `range` is computed but not used by the visible code.
271 auto range = L.cur().range;
272 auto t = L.expect(TK_NUMBER);
273 return Const::create(t.range, t.text());
// Presumably the body of parseConcatenatedStringLiterals(); its signature
// line is missing here.
// Decodes each consecutive TK_STRINGLITERAL token with parseStringLiteral and
// concatenates the results into one StringLiteral.
// The result carries the range of the first literal.
277 auto range = L.cur().range;
278 std::stringstream ss;
279 while (L.cur().kind == TK_STRINGLITERAL) {
280 auto literal_range = L.cur().range;
281 ss << parseStringLiteral(literal_range, L.next().text());
283 return StringLiteral::create(range, ss.str());
// Parses the value of a `name=value` call argument. Its body is missing here.
286 Expr parseAttributeValue() {
// Parses a comma-separated call argument list, unless the next token is ')'.
// `ident=value` arguments are appended to `attributes` as Attribute nodes.
// All other arguments are parsed as expressions and appended to `inputs`.
290 void parseOperatorArguments(TreeList& inputs, TreeList& attributes) {
292 if (L.cur().kind !=
')') {
// One-token lookahead tells a keyword argument apart from an expression that
// merely starts with an identifier.
294 if (L.cur().kind == TK_IDENT && L.lookahead().kind ==
'=') {
295 auto ident = parseIdent();
297 auto v = parseAttributeValue();
298 attributes.push_back(
299 Attribute::create(ident.range(),
Ident(ident), v));
301 inputs.push_back(parseExp());
303 }
while (L.nextIf(
','));
// Parses one subscript element. Judging by the visible checks, this covers
// forms like `a`, `a:b`, `a:`, `:b` and `:`.
// The slice bounds are optional; an absent bound becomes an empty
// Maybe<Expr> at `range`, and the result is a SliceExpr.
// NOTE(review): the lines that parse the bounds (and presumably handle a
// plain non-slice subscript) are missing here.
309 Expr parseSubscriptExp() {
310 TreeRef first, second;
311 auto range = L.cur().range;
312 if (L.cur().kind !=
':') {
316 if (L.cur().kind !=
',' && L.cur().kind !=
']') {
320 : Maybe<Expr>::create(range);
321 auto maybe_second = second ? Maybe<Expr>::create(range,
Expr(second))
322 : Maybe<Expr>::create(range);
323 return SliceExpr::create(range, maybe_first, maybe_second);
// Parses `[e1, e2, ...]` following `value`, each element via
// parseSubscriptExp, and returns a Subscript node of `value`.
329 TreeRef parseSubscript(
const TreeRef& value) {
330 const auto range = L.cur().range;
332 auto subscript_exprs =
333 parseList(
'[',
',',
']', &ParserImpl::parseSubscriptExp);
334 return Subscript::create(range,
Expr(value), subscript_exprs);
// Parses a single function parameter, starting with its name.
// NOTE(review): most of this body is missing here. The visible code suggests
// that when no type annotation is given, the type defaults to a variable
// named "Tensor". `kwarg_only` is presumably stored on the resulting Param.
337 TreeRef parseParam(
bool kwarg_only) {
338 auto ident = parseIdent();
343 type = Var::create(L.cur().range, Ident::create(L.cur().range,
"Tensor"));
351 return Param::create(
// Parses a type expression that has no parameter name attached, and wraps it
// in a Param whose name is the empty identifier "".
// Used as the element parser for the parameter list of a type comment.
355 Param parseBareTypeAnnotation() {
356 auto type = parseExp();
357 return Param::create(
359 Ident::create(type.range(),
""),
// Parses a type-comment signature:
//   TK_TYPE_COMMENT '(' type, ... ')' ['->' return_type]
// Returns a Decl whose parameters carry only types (their names are empty).
365 Decl parseTypeComment() {
366 auto range = L.cur().range;
367 L.expect(TK_TYPE_COMMENT);
369 parseList(
'(',
',',
')', &ParserImpl::parseBareTypeAnnotation);
// The return annotation is optional.
371 if (L.nextIf(TK_ARROW)) {
372 auto return_type_range = L.cur().range;
377 return Decl::create(range, param_types,
Maybe<Expr>(return_type));
// Parses the rest of an assignment statement after `lhs`: the operator, the
// right-hand side (expression or tuple), and the terminating newline.
// A plain '=' yields an Assign.
// For augmented assignment, a tuple-literal LHS is reported as an error. The
// remainder of the augmented case is missing here.
383 TreeRef parseAssign(
const Expr& lhs) {
384 auto op = parseAssignmentOp();
385 auto rhs = parseExpOrExpTuple();
386 L.expect(TK_NEWLINE);
387 if (op->kind() ==
'=') {
388 return Assign::create(lhs.range(), lhs,
Expr(rhs));
391 if (lhs.kind() == TK_TUPLE_LITERAL) {
393 <<
" augmented assignment can only have one LHS expression";
// Parses one statement by dispatching on the current token. Most of the case
// labels are missing here; the node types created identify each case.
399 TreeRef parseStmt() {
400 switch (L.cur().kind) {
// global: a comma-separated list of identifiers.
408 auto range = L.next().range;
410 parseList(TK_NOTHING,
',', TK_NOTHING, &ParserImpl::parseIdent);
411 L.expect(TK_NEWLINE);
412 return Global::create(range, idents);
// return: a bare `return` yields a TK_NONE expression as its value.
415 auto range = L.next().range;
416 Expr value = L.cur().kind != TK_NEWLINE ? parseExpOrExpTuple()
417 :
Expr(c(TK_NONE, range, {}));
418 L.expect(TK_NEWLINE);
419 return Return::create(range, value);
// raise <expr>
422 auto range = L.next().range;
423 auto expr = parseExp();
424 L.expect(TK_NEWLINE);
425 return Raise::create(range, expr);
// assert <cond>, presumably with an optional message expression.
428 auto range = L.next().range;
429 auto cond = parseExp();
432 auto msg = parseExp();
435 L.expect(TK_NEWLINE);
436 return Assert::create(range, cond, maybe_first);
// pass
439 auto range = L.next().range;
440 L.expect(TK_NEWLINE);
441 return Pass::create(range);
// A nested function definition (not a method).
444 return parseFunction(
false);
// Default: an expression statement, or an assignment if anything other
// than a newline follows.
447 auto lhs = parseExpOrExpTuple();
448 if (L.cur().kind != TK_NEWLINE) {
449 return parseAssign(lhs);
451 L.expect(TK_NEWLINE);
452 return ExprStmt::create(lhs.range(), lhs);
// Parses an optional parenthesized, comma-separated list of identifiers.
// When the next token is not '(', the result is presumably an empty
// TK_LIST at the current range (the `else` line is missing here).
457 TreeRef parseOptionalIdentList() {
458 TreeRef list =
nullptr;
459 if (L.cur().kind ==
'(') {
460 list = parseList(
'(',
',',
')', &ParserImpl::parseIdent);
462 list = c(TK_LIST, L.cur().range, {});
// Parses an if statement: the condition, the true-branch statements, then
// either an `else` block or an `elif` chain.
// The false branch defaults to an empty list.
// An `elif` is parsed recursively with expect_if=false, presumably because
// TK_ELIF already stood in for the `if` keyword. The result is wrapped in a
// one-element statement list.
466 TreeRef parseIf(
bool expect_if =
true) {
467 auto r = L.cur().range;
470 auto cond = parseExp();
472 auto true_branch = parseStatements();
473 auto false_branch = makeList(L.cur().range, {});
474 if (L.nextIf(TK_ELSE)) {
476 false_branch = parseStatements();
477 }
else if (L.nextIf(TK_ELIF)) {
481 auto range = L.cur().range;
482 false_branch = makeList(range, {parseIf(
false)});
// Parses a while loop: the condition expression followed by a statement
// body. Keyword/':' handling and the return line are missing here.
487 TreeRef parseWhile() {
488 auto r = L.cur().range;
490 auto cond = parseExp();
492 auto body = parseStatements();
// Presumably the body of parseFor(); its signature line is missing here.
// Parses comma-separated loop targets, then comma-separated iterables
// (presumably after `in`), then the statement body. Returns them as a For
// node.
496 auto r = L.cur().range;
499 parseList(TK_NOTHING,
',', TK_NOTHING, &ParserImpl::parseExp);
501 auto itrs = parseList(TK_NOTHING,
',', TK_NOTHING, &ParserImpl::parseExp);
503 auto body = parseStatements();
504 return For::create(r, targets, itrs, body);
// Parses statements until a TK_DEDENT is consumed and returns them as a
// TK_LIST.
// With expect_indent (default true), an indented block presumably has to be
// opened first; those lines are missing here.
507 TreeRef parseStatements(
bool expect_indent =
true) {
508 auto r = L.cur().range;
514 stmts.push_back(parseStmt());
515 }
while (!L.nextIf(TK_DEDENT));
516 return c(TK_LIST, r, std::move(stmts));
// Presumably the body of parseReturnAnnotation(); its signature line is
// missing here.
// Parses a return type only when a TK_ARROW ('->') token is present.
520 if (L.nextIf(TK_ARROW)) {
522 auto return_type_range = L.cur().range;
// Presumably the body of parseDecl(); its signature line is missing here.
// Parses a parenthesized, comma-separated parameter list and an optional
// return annotation into a Decl.
// A '*' entry (when not already in keyword-only mode) presumably switches
// the following parameters to keyword-only.
530 auto r = L.cur().range;
531 std::vector<Param> params;
532 bool kwarg_only =
false;
533 parseSequence(
'(',
',',
')', [&] {
534 if (!kwarg_only && L.nextIf(
'*')) {
537 params.emplace_back(parseParam(kwarg_only));
546 Maybe<Expr> return_annotation = parseReturnAnnotation();
549 paramlist.range(),
List<Param>(paramlist), return_annotation);
// Parses a class definition: TK_CLASS_DEF, the class name, then method
// definitions (each via parseFunction(true)) until a TK_DEDENT.
// Returns a ClassDef.
552 TreeRef parseClass() {
553 L.expect(TK_CLASS_DEF);
554 const auto name = parseIdent();
559 std::vector<Def> methods;
560 while (L.cur().kind != TK_DEDENT) {
561 methods.push_back(
Def(parseFunction(
true)));
565 return ClassDef::create(
// Parses a function definition: the name, then the parameter declaration.
// An optional type comment may follow; its types are merged into the
// declaration, with `is_method` passed through to the merge.
// Finally the body statements are parsed.
569 TreeRef parseFunction(
bool is_method) {
571 auto name = parseIdent();
572 auto decl = parseDecl();
// A type comment overrides parameter and return types from the signature.
577 if (L.cur().kind == TK_TYPE_COMMENT) {
578 auto type_annotation_decl =
Decl(parseTypeComment());
579 L.expect(TK_NEWLINE);
580 decl = mergeTypesFromTypeComment(decl, type_annotation_decl, is_method);
// expect_indent=false, presumably because the indent was consumed on a
// line missing here.
583 auto stmts_list = parseStatements(
false);
// Shorthand: creates a Compound tree node of `kind` at `range`, moving
// `trees` in as its children.
593 TreeRef c(
int kind,
const SourceRange& range, TreeList&& trees) {
594 return Compound::create(kind, range, std::move(trees));
// Creates a TK_LIST compound node at `range` holding `trees`.
596 TreeRef makeList(
const SourceRange& range, TreeList&& trees) {
597 return c(TK_LIST, range, std::move(trees));
// Constructs the public Parser over the source text `src`. All work is
// delegated to a heap-allocated ParserImpl held in `pImpl`.
603 Parser::Parser(
const std::string& src) : pImpl(
new ParserImpl(src)) {}
// Defaulted out of line, presumably so that ParserImpl is a complete type
// where `pImpl` is destroyed.
605 Parser::~Parser() =
default;
// Forwards to the implementation's parseFunction(is_method).
607 TreeRef Parser::parseFunction(
bool is_method) {
608 return pImpl->parseFunction(is_method);
// Forwards to the implementation's parseClass().
610 TreeRef Parser::parseClass() {
611 return pImpl->parseClass();
// Exposes the implementation's lexer by reference.
613 Lexer& Parser::lexer() {
614 return pImpl->lexer();
// Forwards to the implementation's parseTypeComment().
616 Decl Parser::parseTypeComment() {
617 return pImpl->parseTypeComment();
// Forwards to the implementation's parseExp(), using its default precedence.
619 Expr Parser::parseExp() {
620 return pImpl->parseExp();