Caffe2 - C++ API
A deep learning, cross platform ML framework
parser.cpp
1 #include <torch/csrc/jit/script/parser.h>
2 #include <c10/util/Optional.h>
3 #include <torch/csrc/jit/script/lexer.h>
4 #include <torch/csrc/jit/script/parse_string_literal.h>
5 #include <torch/csrc/jit/script/tree.h>
6 #include <torch/csrc/jit/script/tree_views.h>
7 
8 namespace torch {
9 namespace jit {
10 namespace script {
11 
12 Decl mergeTypesFromTypeComment(
13  const Decl& decl,
14  const Decl& type_annotation_decl,
15  bool is_method) {
16  auto expected_num_annotations = decl.params().size();
17  if (is_method) {
18  // `self` argument
19  expected_num_annotations -= 1;
20  }
21  if (expected_num_annotations != type_annotation_decl.params().size()) {
22  throw ErrorReport(type_annotation_decl.range())
23  << "Number of type annotations ("
24  << type_annotation_decl.params().size()
25  << ") did not match the number of "
26  << "function parameters (" << expected_num_annotations << ")";
27  }
28  auto old = decl.params();
29  auto _new = type_annotation_decl.params();
30  // Merge signature idents and ranges with annotation types
31 
32  std::vector<Param> new_params;
33  size_t i = is_method ? 1 : 0;
34  size_t j = 0;
35  if (is_method) {
36  new_params.push_back(old[0]);
37  }
38  for (; i < decl.params().size(); ++i, ++j) {
39  new_params.emplace_back(old[i].withType(_new[j].type()));
40  }
41  return Decl::create(
42  decl.range(),
43  List<Param>::create(decl.range(), new_params),
44  type_annotation_decl.return_type());
45 }
46 
47 struct ParserImpl {
48  explicit ParserImpl(const std::string& str)
49  : L(str), shared(sharedParserData()) {}
50 
51  Ident parseIdent() {
52  auto t = L.expect(TK_IDENT);
53  // whenever we parse something that has a TreeView type we always
54  // use its create method so that the accessors and the constructor
55  // of the Compound tree are in the same place.
56  return Ident::create(t.range, t.text());
57  }
58  TreeRef createApply(const Expr& expr) {
59  TreeList attributes;
60  auto range = L.cur().range;
61  TreeList inputs;
62  parseOperatorArguments(inputs, attributes);
63  return Apply::create(
64  range,
65  expr,
66  List<Expr>(makeList(range, std::move(inputs))),
67  List<Attribute>(makeList(range, std::move(attributes))));
68  }
69 
70  static bool followsTuple(int kind) {
71  switch (kind) {
72  case TK_PLUS_EQ:
73  case TK_MINUS_EQ:
74  case TK_TIMES_EQ:
75  case TK_DIV_EQ:
76  case TK_NEWLINE:
77  case '=':
78  case ')':
79  return true;
80  default:
81  return false;
82  }
83  }
84 
85  // exp | expr, | expr, expr, ...
86  Expr parseExpOrExpTuple() {
87  auto prefix = parseExp();
88  if (L.cur().kind == ',') {
89  std::vector<Expr> exprs = {prefix};
90  while (L.nextIf(',')) {
91  if (followsTuple(L.cur().kind))
92  break;
93  exprs.push_back(parseExp());
94  }
95  auto list = List<Expr>::create(prefix.range(), exprs);
96  prefix = TupleLiteral::create(list.range(), list);
97  }
98  return prefix;
99  }
100  // things like a 1.0 or a(4) that are not unary/binary expressions
101  // and have higher precedence than all of them
102  TreeRef parseBaseExp() {
103  TreeRef prefix;
104  switch (L.cur().kind) {
105  case TK_NUMBER: {
106  prefix = parseConst();
107  } break;
108  case TK_TRUE:
109  case TK_FALSE:
110  case TK_NONE: {
111  auto k = L.cur().kind;
112  auto r = L.cur().range;
113  prefix = c(k, r, {});
114  L.next();
115  } break;
116  case '(': {
117  L.next();
118  if (L.nextIf(')')) {
120  std::vector<Expr> vecExpr;
121  List<Expr> listExpr = List<Expr>::create(L.cur().range, vecExpr);
122  prefix = TupleLiteral::create(L.cur().range, listExpr);
123  break;
124  }
125  prefix = parseExpOrExpTuple();
126  L.expect(')');
127  } break;
128  case '[': {
129  auto list = parseList('[', ',', ']', &ParserImpl::parseExp);
130  prefix = ListLiteral::create(list.range(), List<Expr>(list));
131  } break;
132  case '{': {
133  L.next();
134  std::vector<Expr> keys;
135  std::vector<Expr> values;
136  auto range = L.cur().range;
137  if (L.cur().kind != '}') {
138  do {
139  keys.push_back(parseExp());
140  L.expect(':');
141  values.push_back(parseExp());
142  } while (L.nextIf(','));
143  }
144  L.expect('}');
145  prefix = DictLiteral::create(
146  range,
147  List<Expr>::create(range, keys),
148  List<Expr>::create(range, values));
149  } break;
150  case TK_STRINGLITERAL: {
151  prefix = parseConcatenatedStringLiterals();
152  } break;
153  default: {
154  Ident name = parseIdent();
155  prefix = Var::create(name.range(), name);
156  } break;
157  }
158  while (true) {
159  if (L.nextIf('.')) {
160  const auto name = parseIdent();
161  prefix = Select::create(name.range(), Expr(prefix), Ident(name));
162  } else if (L.cur().kind == '(') {
163  prefix = createApply(Expr(prefix));
164  } else if (L.cur().kind == '[') {
165  prefix = parseSubscript(prefix);
166  } else {
167  break;
168  }
169  }
170  return prefix;
171  }
172  TreeRef parseAssignmentOp() {
173  auto r = L.cur().range;
174  switch (L.cur().kind) {
175  case TK_PLUS_EQ:
176  case TK_MINUS_EQ:
177  case TK_TIMES_EQ:
178  case TK_DIV_EQ: {
179  int modifier = L.next().text()[0];
180  return c(modifier, r, {});
181  } break;
182  default: {
183  L.expect('=');
184  return c('=', r, {}); // no reduction
185  } break;
186  }
187  }
188  TreeRef parseTrinary(
189  TreeRef true_branch,
190  const SourceRange& range,
191  int binary_prec) {
192  auto cond = parseExp();
193  L.expect(TK_ELSE);
194  auto false_branch = parseExp(binary_prec);
195  return c(TK_IF_EXPR, range, {cond, std::move(true_branch), false_branch});
196  }
197  // parse the longest expression whose binary operators have
198  // precedence strictly greater than 'precedence'
199  // precedence == 0 will parse _all_ expressions
200  // this is the core loop of 'top-down precedence parsing'
201  Expr parseExp() {
202  return parseExp(0);
203  }
204  Expr parseExp(int precedence) {
205  TreeRef prefix = nullptr;
206  int unary_prec;
207  if (shared.isUnary(L.cur().kind, &unary_prec)) {
208  auto kind = L.cur().kind;
209  auto pos = L.cur().range;
210  L.next();
211  auto unary_kind =
212  kind == '*' ? TK_STARRED : kind == '-' ? TK_UNARY_MINUS : kind;
213  auto subexp = parseExp(unary_prec);
214  // fold '-' into constant numbers, so that attributes can accept
215  // things like -1
216  if (unary_kind == TK_UNARY_MINUS && subexp.kind() == TK_CONST) {
217  prefix = Const::create(subexp.range(), "-" + Const(subexp).text());
218  } else {
219  prefix = c(unary_kind, pos, {subexp});
220  }
221  } else {
222  prefix = parseBaseExp();
223  }
224  int binary_prec;
225  while (shared.isBinary(L.cur().kind, &binary_prec)) {
226  if (binary_prec <= precedence) // not allowed to parse something which is
227  // not greater than 'precedence'
228  break;
229 
230  int kind = L.cur().kind;
231  auto pos = L.cur().range;
232  L.next();
233  if (shared.isRightAssociative(kind))
234  binary_prec--;
235 
236  // special case for trinary operator
237  if (kind == TK_IF) {
238  prefix = parseTrinary(prefix, pos, binary_prec);
239  continue;
240  }
241 
242  prefix = c(kind, pos, {prefix, parseExp(binary_prec)});
243  }
244  return Expr(prefix);
245  }
246  void parseSequence(
247  int begin,
248  int sep,
249  int end,
250  const std::function<void()>& parse) {
251  if (begin != TK_NOTHING)
252  L.expect(begin);
253  if (L.cur().kind != end) {
254  do {
255  parse();
256  } while (L.nextIf(sep));
257  }
258  if (end != TK_NOTHING)
259  L.expect(end);
260  }
261  template <typename T>
262  List<T> parseList(int begin, int sep, int end, T (ParserImpl::*parse)()) {
263  auto r = L.cur().range;
264  std::vector<T> elements;
265  parseSequence(
266  begin, sep, end, [&] { elements.emplace_back((this->*parse)()); });
267  return List<T>::create(r, elements);
268  }
269 
270  Const parseConst() {
271  auto range = L.cur().range;
272  auto t = L.expect(TK_NUMBER);
273  return Const::create(t.range, t.text());
274  }
275 
276  StringLiteral parseConcatenatedStringLiterals() {
277  auto range = L.cur().range;
278  std::stringstream ss;
279  while (L.cur().kind == TK_STRINGLITERAL) {
280  auto literal_range = L.cur().range;
281  ss << parseStringLiteral(literal_range, L.next().text());
282  }
283  return StringLiteral::create(range, ss.str());
284  }
285 
286  Expr parseAttributeValue() {
287  return parseExp();
288  }
289 
290  void parseOperatorArguments(TreeList& inputs, TreeList& attributes) {
291  L.expect('(');
292  if (L.cur().kind != ')') {
293  do {
294  if (L.cur().kind == TK_IDENT && L.lookahead().kind == '=') {
295  auto ident = parseIdent();
296  L.expect('=');
297  auto v = parseAttributeValue();
298  attributes.push_back(
299  Attribute::create(ident.range(), Ident(ident), v));
300  } else {
301  inputs.push_back(parseExp());
302  }
303  } while (L.nextIf(','));
304  }
305  L.expect(')');
306  }
307 
308  // Parse expr's of the form [a:], [:b], [a:b], [:]
309  Expr parseSubscriptExp() {
310  TreeRef first, second;
311  auto range = L.cur().range;
312  if (L.cur().kind != ':') {
313  first = parseExp();
314  }
315  if (L.nextIf(':')) {
316  if (L.cur().kind != ',' && L.cur().kind != ']') {
317  second = parseExp();
318  }
319  auto maybe_first = first ? Maybe<Expr>::create(range, Expr(first))
320  : Maybe<Expr>::create(range);
321  auto maybe_second = second ? Maybe<Expr>::create(range, Expr(second))
322  : Maybe<Expr>::create(range);
323  return SliceExpr::create(range, maybe_first, maybe_second);
324  } else {
325  return Expr(first);
326  }
327  }
328 
329  TreeRef parseSubscript(const TreeRef& value) {
330  const auto range = L.cur().range;
331 
332  auto subscript_exprs =
333  parseList('[', ',', ']', &ParserImpl::parseSubscriptExp);
334  return Subscript::create(range, Expr(value), subscript_exprs);
335  }
336 
337  TreeRef parseParam(bool kwarg_only) {
338  auto ident = parseIdent();
339  TreeRef type;
340  if (L.nextIf(':')) {
341  type = parseExp();
342  } else {
343  type = Var::create(L.cur().range, Ident::create(L.cur().range, "Tensor"));
344  }
345  TreeRef def;
346  if (L.nextIf('=')) {
347  def = Maybe<Expr>::create(L.cur().range, parseExp());
348  } else {
349  def = Maybe<Expr>::create(L.cur().range);
350  }
351  return Param::create(
352  type->range(), Ident(ident), Expr(type), Maybe<Expr>(def), kwarg_only);
353  }
354 
355  Param parseBareTypeAnnotation() {
356  auto type = parseExp();
357  return Param::create(
358  type.range(),
359  Ident::create(type.range(), ""),
360  type,
361  Maybe<Expr>::create(type.range()),
362  /*kwarg_only=*/false);
363  }
364 
365  Decl parseTypeComment() {
366  auto range = L.cur().range;
367  L.expect(TK_TYPE_COMMENT);
368  auto param_types =
369  parseList('(', ',', ')', &ParserImpl::parseBareTypeAnnotation);
370  TreeRef return_type;
371  if (L.nextIf(TK_ARROW)) {
372  auto return_type_range = L.cur().range;
373  return_type = Maybe<Expr>::create(return_type_range, parseExp());
374  } else {
375  return_type = Maybe<Expr>::create(L.cur().range);
376  }
377  return Decl::create(range, param_types, Maybe<Expr>(return_type));
378  }
379 
380  // 'first' has already been parsed since expressions can exist
381  // alone on a line:
382  // first[,other,lhs] = rhs
383  TreeRef parseAssign(const Expr& lhs) {
384  auto op = parseAssignmentOp();
385  auto rhs = parseExpOrExpTuple();
386  L.expect(TK_NEWLINE);
387  if (op->kind() == '=') {
388  return Assign::create(lhs.range(), lhs, Expr(rhs));
389  } else {
390  // this is an augmented assignment
391  if (lhs.kind() == TK_TUPLE_LITERAL) {
392  throw ErrorReport(lhs.range())
393  << " augmented assignment can only have one LHS expression";
394  }
395  return AugAssign::create(lhs.range(), lhs, AugAssignKind(op), Expr(rhs));
396  }
397  }
398 
399  TreeRef parseStmt() {
400  switch (L.cur().kind) {
401  case TK_IF:
402  return parseIf();
403  case TK_WHILE:
404  return parseWhile();
405  case TK_FOR:
406  return parseFor();
407  case TK_GLOBAL: {
408  auto range = L.next().range;
409  auto idents =
410  parseList(TK_NOTHING, ',', TK_NOTHING, &ParserImpl::parseIdent);
411  L.expect(TK_NEWLINE);
412  return Global::create(range, idents);
413  }
414  case TK_RETURN: {
415  auto range = L.next().range;
416  Expr value = L.cur().kind != TK_NEWLINE ? parseExpOrExpTuple()
417  : Expr(c(TK_NONE, range, {}));
418  L.expect(TK_NEWLINE);
419  return Return::create(range, value);
420  }
421  case TK_RAISE: {
422  auto range = L.next().range;
423  auto expr = parseExp();
424  L.expect(TK_NEWLINE);
425  return Raise::create(range, expr);
426  }
427  case TK_ASSERT: {
428  auto range = L.next().range;
429  auto cond = parseExp();
430  Maybe<Expr> maybe_first = Maybe<Expr>::create(range);
431  if (L.nextIf(',')) {
432  auto msg = parseExp();
433  maybe_first = Maybe<Expr>::create(range, Expr(msg));
434  }
435  L.expect(TK_NEWLINE);
436  return Assert::create(range, cond, maybe_first);
437  }
438  case TK_PASS: {
439  auto range = L.next().range;
440  L.expect(TK_NEWLINE);
441  return Pass::create(range);
442  }
443  case TK_DEF: {
444  return parseFunction(/*is_method=*/false);
445  }
446  default: {
447  auto lhs = parseExpOrExpTuple();
448  if (L.cur().kind != TK_NEWLINE) {
449  return parseAssign(lhs);
450  } else {
451  L.expect(TK_NEWLINE);
452  return ExprStmt::create(lhs.range(), lhs);
453  }
454  }
455  }
456  }
457  TreeRef parseOptionalIdentList() {
458  TreeRef list = nullptr;
459  if (L.cur().kind == '(') {
460  list = parseList('(', ',', ')', &ParserImpl::parseIdent);
461  } else {
462  list = c(TK_LIST, L.cur().range, {});
463  }
464  return list;
465  }
466  TreeRef parseIf(bool expect_if = true) {
467  auto r = L.cur().range;
468  if (expect_if)
469  L.expect(TK_IF);
470  auto cond = parseExp();
471  L.expect(':');
472  auto true_branch = parseStatements();
473  auto false_branch = makeList(L.cur().range, {});
474  if (L.nextIf(TK_ELSE)) {
475  L.expect(':');
476  false_branch = parseStatements();
477  } else if (L.nextIf(TK_ELIF)) {
478  // NB: this needs to be a separate statement, since the call to parseIf
479  // mutates the lexer state, and thus causes a heap-use-after-free in
480  // compilers which evaluate argument expressions LTR
481  auto range = L.cur().range;
482  false_branch = makeList(range, {parseIf(false)});
483  }
484  return If::create(
485  r, Expr(cond), List<Stmt>(true_branch), List<Stmt>(false_branch));
486  }
487  TreeRef parseWhile() {
488  auto r = L.cur().range;
489  L.expect(TK_WHILE);
490  auto cond = parseExp();
491  L.expect(':');
492  auto body = parseStatements();
493  return While::create(r, Expr(cond), List<Stmt>(body));
494  }
495  TreeRef parseFor() {
496  auto r = L.cur().range;
497  L.expect(TK_FOR);
498  auto targets =
499  parseList(TK_NOTHING, ',', TK_NOTHING, &ParserImpl::parseExp);
500  L.expect(TK_IN);
501  auto itrs = parseList(TK_NOTHING, ',', TK_NOTHING, &ParserImpl::parseExp);
502  L.expect(':');
503  auto body = parseStatements();
504  return For::create(r, targets, itrs, body);
505  }
506 
507  TreeRef parseStatements(bool expect_indent = true) {
508  auto r = L.cur().range;
509  if (expect_indent) {
510  L.expect(TK_INDENT);
511  }
512  TreeList stmts;
513  do {
514  stmts.push_back(parseStmt());
515  } while (!L.nextIf(TK_DEDENT));
516  return c(TK_LIST, r, std::move(stmts));
517  }
518 
519  Maybe<Expr> parseReturnAnnotation() {
520  if (L.nextIf(TK_ARROW)) {
521  // Exactly one expression for return type annotation
522  auto return_type_range = L.cur().range;
523  return Maybe<Expr>::create(return_type_range, parseExp());
524  } else {
525  return Maybe<Expr>::create(L.cur().range);
526  }
527  }
528 
529  List<Param> parseParams() {
530  auto r = L.cur().range;
531  std::vector<Param> params;
532  bool kwarg_only = false;
533  parseSequence('(', ',', ')', [&] {
534  if (!kwarg_only && L.nextIf('*')) {
535  kwarg_only = true;
536  } else {
537  params.emplace_back(parseParam(kwarg_only));
538  }
539  });
540  return List<Param>::create(r, params);
541  }
542  Decl parseDecl() {
543  // Parse return type annotation
544  List<Param> paramlist = parseParams();
545  TreeRef return_type;
546  Maybe<Expr> return_annotation = parseReturnAnnotation();
547  L.expect(':');
548  return Decl::create(
549  paramlist.range(), List<Param>(paramlist), return_annotation);
550  }
551 
552  TreeRef parseClass() {
553  L.expect(TK_CLASS_DEF);
554  const auto name = parseIdent();
555  // TODO no inheritance or () allowed right now
556  L.expect(':');
557 
558  L.expect(TK_INDENT);
559  std::vector<Def> methods;
560  while (L.cur().kind != TK_DEDENT) {
561  methods.push_back(Def(parseFunction(/*is_method=*/true)));
562  }
563  L.expect(TK_DEDENT);
564 
565  return ClassDef::create(
566  name.range(), name, List<Def>::create(name.range(), methods));
567  }
568 
569  TreeRef parseFunction(bool is_method) {
570  L.expect(TK_DEF);
571  auto name = parseIdent();
572  auto decl = parseDecl();
573 
574  // Handle type annotations specified in a type comment as the first line of
575  // the function.
576  L.expect(TK_INDENT);
577  if (L.cur().kind == TK_TYPE_COMMENT) {
578  auto type_annotation_decl = Decl(parseTypeComment());
579  L.expect(TK_NEWLINE);
580  decl = mergeTypesFromTypeComment(decl, type_annotation_decl, is_method);
581  }
582 
583  auto stmts_list = parseStatements(false);
584  return Def::create(
585  name.range(), Ident(name), Decl(decl), List<Stmt>(stmts_list));
586  }
587  Lexer& lexer() {
588  return L;
589  }
590 
591  private:
592  // short helpers to create nodes
593  TreeRef c(int kind, const SourceRange& range, TreeList&& trees) {
594  return Compound::create(kind, range, std::move(trees));
595  }
596  TreeRef makeList(const SourceRange& range, TreeList&& trees) {
597  return c(TK_LIST, range, std::move(trees));
598  }
599  Lexer L;
600  SharedParserData& shared;
601 };
602 
603 Parser::Parser(const std::string& src) : pImpl(new ParserImpl(src)) {}
604 
605 Parser::~Parser() = default;
606 
607 TreeRef Parser::parseFunction(bool is_method) {
608  return pImpl->parseFunction(is_method);
609 }
610 TreeRef Parser::parseClass() {
611  return pImpl->parseClass();
612 }
613 Lexer& Parser::lexer() {
614  return pImpl->lexer();
615 }
616 Decl Parser::parseTypeComment() {
617  return pImpl->parseTypeComment();
618 }
619 Expr Parser::parseExp() {
620  return pImpl->parseExp();
621 }
622 
623 } // namespace script
624 } // namespace jit
625 } // namespace torch
Definition: jit_type.h:17