Caffe2 - C++ API
A deep learning, cross platform ML framework
tree_views.h
1 #pragma once
2 #include <torch/csrc/jit/script/error_report.h>
3 #include <torch/csrc/jit/script/tree.h>
4 
5 #include <functional>
6 #include <string>
7 
8 namespace torch {
9 namespace jit {
10 namespace script {
11 
12 // clang-format off
13 // TreeView provides a statically-typed way to traverse the tree, which should
14 // be formed according to the grammar below.
15 //
16 // A few notes on types and their aliases:
17 // - List<T> is really a Tree with kind TK_LIST and elements as subtrees
18 // - Maybe<T> is really a Tree with kind TK_OPTION that has 0 or 1 subtree of type T
19 // - Builtin types are: Ident (TK_IDENT), String (TK_STRING)
20 //
// Param = Param(Ident name, Expr type, Maybe<Expr> default, TK_TRUE|TK_FALSE kwarg_only) TK_PARAM
22 //
23 // Decl = Decl(List<Param> params, Maybe<Expr> return_type) TK_DECL
24 // Def = Def(Ident name, Decl decl, List<Stmt> body) TK_DEF
25 // ClassDef = ClassDef(Ident name, List<Def> body) TK_CLASS_DEF
26 //
27 // Stmt = If(Expr cond, List<Stmt> true_body, List<Stmt> false_body) TK_IF
28 // | For(List<Expr> targets, List<Expr> iters, List<Stmt> body) TK_FOR
29 // | While(Expr cond, List<Stmt> body) TK_WHILE
30 // | Global(List<Ident> idents) TK_GLOBAL
31 // -- NB: the only type of Expr's allowed on lhs are Var
32 // Or a tuple containing Var with an optional terminating Starred
33 // | Assign(Expr lhs, Expr rhs) TK_ASSIGN
34 // | AugAssign(Expr lhs, AugAssignKind aug_op, Expr rhs) TK_AUG_ASSIGN
// | Return(Expr value) TK_RETURN
// | ExprStmt(Expr expr) TK_EXPR_STMT
// | Raise(Maybe<Expr> expr) TK_RAISE
// | Assert(Expr test, Maybe<Expr> msg) TK_ASSERT
// | Pass() TK_PASS
38 // | Def TK_DEF
39 //
40 // Expr = TernaryIf(Expr cond, Expr true_expr, Expr false_expr) TK_IF_EXPR
41 // | BinOp(Expr lhs, Expr rhs)
42 // | And TK_AND
43 // | Or TK_OR
44 // | Lt '<'
45 // | Gt '>'
46 // | Eq TK_EQ
47 // | Le TK_LE
48 // | Ge TK_GE
49 // | Ne TK_NE
50 // | Is TK_IS
51 // | IsNot TK_ISNOT
52 // | Add '+'
53 // | Sub '-'
54 // | Mul '*'
55 // | Div '/'
56 // | Mod '%'
57 // | MatMult '@'
58 // | Pow TK_POW
59 // | UnaryOp(Expr expr)
60 // | Not TK_NOT
61 // | USub '-'
62 // | Const(String value) TK_CONST
63 // -- NB: x.name(y) is desugared into name(x, y)
// | Apply(Expr callee, List<Expr> inputs, List<Attribute> attributes) TK_APPLY
65 // | Select(Expr value, Ident selector) '.'
66 // | Subscript(Expr value, List<Expr> subscript_exprs) TK_SUBSCRIPT
67 // | SliceExpr(Maybe<Expr> start, Maybe<Expr> end) TK_SLICE_EXPR
68 // | Var(Ident name) TK_VAR
69 // | ListLiteral(List<Expr> inputs) TK_LIST_LITERAL
70 // | TupleLiteral(List<Expr> inputs) TK_TUPLE_LITERAL
71 // | Starred(Expr expr) TK_STARRED
72 //
73 // -- NB: only allowed expressions are Const or List(Const)
74 // (List as a value, not type constructor)
75 // Attribute = Attribute(Ident name, Expr value) TK_ATTRIBUTE
76 //
77 // AugAssignKind =
// | Add() '+'
// | Sub() '-'
// | Mul() '*'
// | Div() '/'
82 //
83 
84 // Each subclass of TreeView should provide:
85 // 1. Constructor that takes a TreeRef, and checks that it's of the right type.
86 // 2. Accessors that get underlying information out of the object. If they
87 // return subtrees, they should wrap them in appropriate views too.
// 3. Static method 'create' that creates the underlying TreeRef object.
//    For every TreeRef kind that has a TreeView, the parser always uses
//    (e.g.) Ident::create rather than Compound::create. This means that
//    changes to the structure of Ident are always made right here rather
//    than both in the parser and in this code.
93 // XXX: these structs should have no fields to prevent slicing when passing by value
94 // clang-format on
95 struct TreeView {
96  explicit TreeView(TreeRef tree) : tree_(std::move(tree)) {}
97  TreeRef tree() const {
98  return tree_;
99  }
100  const SourceRange& range() const {
101  return tree_->range();
102  }
103  operator TreeRef() const {
104  return tree_;
105  }
106  const TreeRef& get() const {
107  return tree_;
108  }
109  int kind() const {
110  return tree_->kind();
111  }
112  void dump() const {
113  std::cout << tree_;
114  }
115 
116  protected:
117  const TreeRef& subtree(size_t i) const {
118  return tree_->trees().at(i);
119  }
120  TreeRef tree_;
121 };
122 
123 template <typename T>
124 struct ListIterator {
125  ListIterator(TreeList::const_iterator it) : it(it) {}
126  bool operator!=(const ListIterator& rhs) const {
127  return it != rhs.it;
128  }
129  bool operator==(const ListIterator& rhs) const {
130  return it == rhs.it;
131  }
132  T operator*() const {
133  return T(*it);
134  }
135  ListIterator& operator+=(std::ptrdiff_t n) {
136  it += n;
137  return *this;
138  }
139  ListIterator& operator++() {
140  ++it;
141  return *this;
142  }
143  ListIterator& operator--() {
144  --it;
145  return *this;
146  }
147 
148  private:
149  TreeList::const_iterator it;
150 };
151 
152 template <typename T>
153 struct List : public TreeView {
154  using iterator = ListIterator<T>;
156 
157  List(const TreeRef& tree) : TreeView(tree) {
158  tree->match(TK_LIST);
159  // Iterate over list to temporarily instantiate Ts that will check the type
160  for (const T& elem : *this) {
161  (void)elem; // silence unused warning
162  }
163  }
164  iterator begin() const {
165  return iterator(tree_->trees().begin());
166  }
167  iterator end() const {
168  return iterator(tree_->trees().end());
169  }
170  bool empty() const {
171  return tree_->trees().begin() == tree_->trees().end();
172  }
173  T operator[](size_t i) const {
174  return T(subtree(i));
175  }
176  TreeRef map(const std::function<TreeRef(const T&)>& fn) {
177  return tree_->map([&](TreeRef v) { return fn(T(v)); });
178  }
179  static List create(const SourceRange& range, const std::vector<T>& subtrees) {
180  TreeList type_erased_sub{subtrees.begin(), subtrees.end()};
181  return List(Compound::create(TK_LIST, range, std::move(type_erased_sub)));
182  }
183  static List unsafeCreate(const SourceRange& range, TreeList&& subtrees) {
184  return List(Compound::create(TK_LIST, range, std::move(subtrees)));
185  }
186  size_t size() const {
187  return tree_->trees().size();
188  }
189 };
190 
191 template <typename T>
192 struct Maybe : public TreeView {
193  explicit Maybe(const TreeRef& tree) : TreeView(tree) {
194  tree_->match(TK_OPTION);
195  if (tree_->trees().size() > 1)
196  throw ErrorReport(tree) << "Maybe trees can have at most one subtree";
197  }
198  /* implicit */ Maybe(const T& tree) : TreeView(tree) {}
199  bool present() const {
200  return tree_->trees().size() > 0;
201  }
202  T get() const {
203  return T(tree_->trees().at(0));
204  }
205  TreeRef map(const std::function<TreeRef(const T&)>& fn) {
206  return tree_->map([&](TreeRef v) { return fn(T(v)); });
207  }
208  static Maybe<T> create(const SourceRange& range) {
209  return Maybe<T>(Compound::create(TK_OPTION, range, {}));
210  }
211  static Maybe<T> create(const SourceRange& range, const T& value) {
212  return Maybe<T>(Compound::create(TK_OPTION, range, {value}));
213  }
214 };
215 
216 struct Ident : public TreeView {
217  explicit Ident(const TreeRef& tree) : TreeView(tree) {
218  tree_->match(TK_IDENT);
219  }
220  const std::string& name() const {
221  return subtree(0)->stringValue();
222  }
223  static Ident create(const SourceRange& range, const std::string& name) {
224  return Ident(Compound::create(TK_IDENT, range, {String::create(name)}));
225  }
226 };
227 
229 // Base types (production LHS)
231 
232 struct Stmt : public TreeView {
233  explicit Stmt(const TreeRef& tree) : TreeView(tree) {
234  switch (tree->kind()) {
235  case TK_IF:
236  case TK_FOR:
237  case TK_WHILE:
238  case TK_GLOBAL:
239  case TK_ASSIGN:
240  case TK_AUG_ASSIGN:
241  case TK_RETURN:
242  case TK_EXPR_STMT:
243  case TK_RAISE:
244  case TK_ASSERT:
245  case TK_PASS:
246  case TK_DEF:
247  return;
248  default:
249  throw ErrorReport(tree)
250  << kindToString(tree->kind()) << " is not a valid Stmt";
251  }
252  }
253 };
254 
255 struct Expr : public TreeView {
256  explicit Expr(const TreeRef& tree) : TreeView(tree) {
257  switch (tree->kind()) {
258  case TK_IF_EXPR:
259  case TK_AND:
260  case TK_OR:
261  case '<':
262  case '>':
263  case TK_IS:
264  case TK_ISNOT:
265  case TK_EQ:
266  case TK_LE:
267  case TK_GE:
268  case TK_NE:
269  case '+':
270  case '-':
271  case TK_UNARY_MINUS:
272  case '*':
273  case TK_STARRED:
274  case '/':
275  case '%':
276  case TK_NOT:
277  case TK_CONST:
278  case TK_STRINGLITERAL:
279  case TK_TRUE:
280  case TK_FALSE:
281  case TK_NONE:
282  case TK_CAST:
283  case TK_APPLY:
284  case '.':
285  case TK_SUBSCRIPT:
286  case TK_SLICE_EXPR:
287  case TK_VAR:
288  case TK_LIST_LITERAL:
289  case TK_TUPLE_LITERAL:
290  case TK_DICT_LITERAL:
291  case '@':
292  case TK_POW:
293  case TK_FLOOR_DIV:
294  case '&':
295  case '^':
296  case '|':
297  return;
298  default:
299  throw ErrorReport(tree)
300  << kindToString(tree->kind()) << " is not a valid Expr";
301  }
302  }
303 };
304 
306 // Helper nodes (mostly for function arguments)
308 
309 struct Attribute : public TreeView {
310  explicit Attribute(const TreeRef& tree) : TreeView(tree) {
311  tree_->match(TK_ATTRIBUTE);
312  }
313  Ident name() const {
314  return Ident(subtree(0));
315  }
316  Expr value() const {
317  return Expr(subtree(1));
318  }
319  static Attribute create(
320  const SourceRange& range,
321  const Ident& name,
322  const TreeRef& value) {
323  return Attribute(Compound::create(TK_ATTRIBUTE, range, {name, value}));
324  }
325 };
326 
327 struct Param : public TreeView {
328  explicit Param(const TreeRef& tree) : TreeView(tree) {
329  tree_->match(TK_PARAM);
330  }
331  static Param create(
332  const SourceRange& range,
333  const Ident& ident,
334  const Expr& type,
335  const Maybe<Expr>& def,
336  bool kwarg_only) {
337  TreeRef kwarg_only_tree =
338  Compound::create(kwarg_only ? TK_TRUE : TK_FALSE, range, {});
339  return Param(
340  Compound::create(TK_PARAM, range, {ident, type, def, kwarg_only_tree}));
341  }
342  Ident ident() const {
343  return Ident(subtree(0));
344  }
345  Expr type() const {
346  return Expr(subtree(1));
347  }
348  Maybe<Expr> defaultValue() const {
349  return Maybe<Expr>(subtree(2));
350  }
351  bool kwarg_only() const {
352  return TK_TRUE == subtree(3)->kind();
353  }
354  Param withType(const Expr& typ) const {
355  return Param::create(range(), ident(), typ, defaultValue(), kwarg_only());
356  }
357 };
358 
360 // Top level definitions
362 
363 struct Decl : public TreeView {
364  explicit Decl(const TreeRef& tree) : TreeView(tree) {
365  tree->match(TK_DECL);
366  }
367  List<Param> params() const {
368  return List<Param>(subtree(0));
369  }
370  Maybe<Expr> return_type() const {
371  return Maybe<Expr>(subtree(1));
372  }
373  static Decl create(
374  const SourceRange& range,
375  const List<Param>& params,
376  const Maybe<Expr>& return_type) {
377  return Decl(Compound::create(TK_DECL, range, {params, return_type}));
378  }
379 };
380 
381 struct Def : public TreeView {
382  explicit Def(const TreeRef& tree) : TreeView(tree) {
383  tree->match(TK_DEF);
384  }
385  Def withName(std::string new_name) const {
386  auto new_ident = Ident::create(name().range(), std::move(new_name));
387  return create(range(), new_ident, decl(), statements());
388  }
389  Ident name() const {
390  return Ident(subtree(0));
391  }
392  Decl decl() const {
393  return Decl(subtree(1));
394  }
395  List<Stmt> statements() const {
396  return List<Stmt>(subtree(2));
397  }
398  static Def create(
399  const SourceRange& range,
400  const Ident& name,
401  const Decl& decl,
402  const List<Stmt>& stmts) {
403  return Def(Compound::create(TK_DEF, range, {name, decl, stmts}));
404  }
405 };
406 
407 struct ClassDef : public TreeView {
408  explicit ClassDef(const TreeRef& tree) : TreeView(tree) {
409  tree->match(TK_CLASS_DEF);
410  }
411  ClassDef withName(std::string new_name) const {
412  auto new_ident = Ident::create(name().range(), std::move(new_name));
413  return create(range(), new_ident, defs());
414  }
415  Ident name() const {
416  return Ident(subtree(0));
417  }
418  List<Def> defs() const {
419  return List<Def>(subtree(1));
420  }
421  static ClassDef create(
422  const SourceRange& range,
423  const Ident& name,
424  const List<Def>& defs) {
425  return ClassDef(Compound::create(TK_CLASS_DEF, range, {name, defs}));
426  }
427 };
428 
430 // Statements
432 
433 struct If : public Stmt {
434  explicit If(const TreeRef& tree) : Stmt(tree) {
435  tree_->match(TK_IF);
436  }
437  Expr cond() const {
438  return Expr(subtree(0));
439  }
440  List<Stmt> trueBranch() const {
441  return List<Stmt>(subtree(1));
442  }
443  List<Stmt> falseBranch() const {
444  return List<Stmt>(subtree(2));
445  }
446  If withNewBranches(
447  const List<Stmt>& true_branch,
448  const List<Stmt>& false_branch) const {
449  return create(range(), cond(), true_branch, false_branch);
450  }
451  static If create(
452  const SourceRange& range,
453  const Expr& cond,
454  const List<Stmt>& true_branch,
455  const List<Stmt>& false_branch) {
456  return If(
457  Compound::create(TK_IF, range, {cond, true_branch, false_branch}));
458  }
459 };
460 
461 struct While : public Stmt {
462  explicit While(const TreeRef& tree) : Stmt(tree) {
463  tree_->match(TK_WHILE);
464  }
465  Expr cond() const {
466  return Expr(subtree(0));
467  }
468  List<Stmt> body() const {
469  return List<Stmt>(subtree(1));
470  }
471  static While create(
472  const SourceRange& range,
473  const Expr& cond,
474  const List<Stmt>& body) {
475  return While(Compound::create(TK_WHILE, range, {cond, body}));
476  }
477 };
478 
479 struct For : public Stmt {
480  explicit For(const TreeRef& tree) : Stmt(tree) {
481  tree->match(TK_FOR);
482  }
483  List<Expr> targets() const {
484  return List<Expr>(subtree(0));
485  }
486  List<Expr> itrs() const {
487  return List<Expr>(subtree(1));
488  }
489  List<Stmt> body() const {
490  return List<Stmt>(subtree(2));
491  }
492  static For create(
493  const SourceRange& range,
494  const List<Expr>& targets,
495  const List<Expr>& itrs,
496  const List<Stmt>& body) {
497  return For(Compound::create(TK_FOR, range, {targets, itrs, body}));
498  }
499 };
500 
501 struct Global : public Stmt {
502  explicit Global(const TreeRef& tree) : Stmt(tree) {
503  tree_->match(TK_GLOBAL);
504  }
505  List<Ident> names() {
506  return List<Ident>(subtree(0));
507  }
508  static Global create(const SourceRange& range, const List<Ident>& names) {
509  return Global(Compound::create(TK_GLOBAL, range, {names}));
510  }
511 };
512 
513 struct AugAssignKind : public TreeView {
514  explicit AugAssignKind(const TreeRef& tree) : TreeView(tree) {
515  switch (tree->kind()) {
516  case '+':
517  case '-':
518  case '*':
519  case '/':
520  return;
521  default:
522  throw ErrorReport(tree) << "is not a valid AugAssignKind";
523  }
524  }
525 };
526 
527 // Augmented assignment, like "foo += bar"
528 struct AugAssign : public Stmt {
529  explicit AugAssign(const TreeRef& tree) : Stmt(tree) {
530  tree_->match(TK_AUG_ASSIGN);
531  }
532  static AugAssign create(
533  const SourceRange& range,
534  const Expr& lhs,
535  const AugAssignKind& aug_op,
536  const Expr& rhs) {
537  return AugAssign(
538  Compound::create(TK_AUG_ASSIGN, range, {lhs, aug_op, rhs}));
539  }
540  Expr lhs() const {
541  return Expr(subtree(0));
542  }
543  int aug_op() const {
544  return subtree(1)->kind();
545  }
546  Expr rhs() const {
547  return Expr(subtree(2));
548  }
549 };
550 
551 struct Assign : public Stmt {
552  explicit Assign(const TreeRef& tree) : Stmt(tree) {
553  tree_->match(TK_ASSIGN);
554  }
555  static Assign create(
556  const SourceRange& range,
557  const Expr& lhs,
558  const Expr& rhs) {
559  return Assign(Compound::create(TK_ASSIGN, range, {lhs, rhs}));
560  }
561  Expr lhs() const {
562  return Expr(subtree(0));
563  }
564  Expr rhs() const {
565  return Expr(subtree(1));
566  }
567 };
568 
569 struct Return : public Stmt {
570  explicit Return(const TreeRef& tree) : Stmt(tree) {
571  tree_->match(TK_RETURN);
572  }
573  Expr expr() const {
574  return Expr(subtree(0));
575  }
576  static Return create(const SourceRange& range, const Expr& value) {
577  return Return(Compound::create(TK_RETURN, range, {value}));
578  }
579 };
580 
581 struct Raise : public Stmt {
582  explicit Raise(const TreeRef& tree) : Stmt(tree) {
583  tree_->match(TK_RAISE);
584  }
585  Maybe<Expr> expr() const {
586  return Maybe<Expr>(subtree(0));
587  }
588  static Raise create(const SourceRange& range, const Maybe<Expr>& expr) {
589  return Raise(Compound::create(TK_RAISE, range, {expr}));
590  }
591 };
592 
593 struct Assert : public Stmt {
594  explicit Assert(const TreeRef& tree) : Stmt(tree) {
595  tree_->match(TK_ASSERT);
596  }
597  Expr test() const {
598  return Expr(subtree(0));
599  }
600  Maybe<Expr> msg() const {
601  return Maybe<Expr>(subtree(1));
602  }
603  static Assert create(
604  const SourceRange& range,
605  const Expr& test,
606  const Maybe<Expr>& msg) {
607  return Assert(Compound::create(TK_ASSERT, range, {test, msg}));
608  }
609 };
610 
611 struct Pass : public Stmt {
612  explicit Pass(const TreeRef& tree) : Stmt(tree) {
613  tree_->match(TK_PASS);
614  }
615  static Pass create(const SourceRange& range) {
616  return Pass(Compound::create(TK_PASS, range, {}));
617  }
618 };
619 
620 struct ExprStmt : public Stmt {
621  explicit ExprStmt(const TreeRef& tree) : Stmt(tree) {
622  tree_->match(TK_EXPR_STMT);
623  }
624  Expr expr() {
625  return Expr(subtree(0));
626  }
627  static ExprStmt create(const SourceRange& range, const Expr& list) {
628  return ExprStmt(Compound::create(TK_EXPR_STMT, range, {list}));
629  }
630 };
631 
633 // Expressions
635 
636 struct BinOp : public Expr {
637  explicit BinOp(const TreeRef& tree) : Expr(tree) {
638  switch (tree->kind()) {
639  case TK_AND:
640  case TK_OR:
641  case '<':
642  case '>':
643  case TK_IS:
644  case TK_ISNOT:
645  case TK_EQ:
646  case TK_LE:
647  case TK_GE:
648  case TK_NE:
649  case '+':
650  case '*':
651  case '/':
652  case '-':
653  case '@':
654  case TK_POW:
655  case '%':
656  case '&':
657  case '^':
658  case '|':
659  case TK_FLOOR_DIV:
660  if (tree->trees().size() != 2)
661  throw ErrorReport(tree)
662  << "BinOp expected 2 subtrees, found " << tree->trees().size();
663  return;
664  default:
665  throw ErrorReport(tree)
666  << kindToString(tree->kind()) << " is not a valid BinOp";
667  }
668  }
669  Expr lhs() const {
670  return Expr(subtree(0));
671  }
672  Expr rhs() const {
673  return Expr(subtree(1));
674  }
675  static BinOp create(
676  const SourceRange& range,
677  int kind,
678  const Expr& lhs,
679  const Expr& rhs) {
680  return BinOp(Compound::create(kind, range, {lhs, rhs}));
681  }
682 };
683 
684 struct UnaryOp : public Expr {
685  explicit UnaryOp(const TreeRef& tree) : Expr(tree) {
686  switch (tree->kind()) {
687  case TK_UNARY_MINUS:
688  case TK_NOT:
689  if (tree->trees().size() != 1)
690  throw ErrorReport(tree)
691  << "UnaryOp expected 1 subtree, found " << tree->trees().size();
692  return;
693  default:
694  throw ErrorReport(tree)
695  << kindToString(tree->kind()) << " is not a valid UnaryOp";
696  }
697  }
698  static UnaryOp create(const SourceRange& range, int kind, const Expr& expr) {
699  return UnaryOp(Compound::create(kind, range, {expr}));
700  }
701 };
702 
// View over a TK_CONST tree: a numeric literal whose source text is kept
// verbatim in its single String child and parsed on demand.
struct Const : public Expr {
  explicit Const(const TreeRef& tree) : Expr(tree) {
    tree_->matchNumSubtrees(TK_CONST, 1);
  }
  // Heuristic classification: the literal counts as floating point when its
  // text contains '.', 'e' or 'E'.
  // NOTE(review): a hex literal containing 'e'/'E' would be misclassified if
  // the lexer ever produced one; confirm the lexer only emits decimal text.
  bool isFloatingPoint() const {
    return subtree(0)->stringValue().find_first_of(".eE") != std::string::npos;
  }
  // Exact negation of isFloatingPoint().
  bool isIntegral() const {
    return !isFloatingPoint();
  }
  // Parses the text with std::stoll, which throws std::invalid_argument or
  // std::out_of_range on malformed or out-of-range text.
  int64_t asIntegral() const {
    return std::stoll(subtree(0)->stringValue());
  }
  // Parses the text with strtod_c (declared in lexer.h).
  // NOTE(review): presumably a locale-independent ("C" locale) strtod; confirm
  // in lexer.h.
  double asFloatingPoint() const {
    return SharedParserData::strtod_c(
        subtree(0)->stringValue().c_str(), nullptr);
  }
  // The literal's raw source text.
  const std::string& text() const {
    return subtree(0)->stringValue();
  }
  static Const create(const SourceRange& range, const std::string& value) {
    return Const(Compound::create(TK_CONST, range, {String::create(value)}));
  }
};
727 
728 struct StringLiteral : public Expr {
729  explicit StringLiteral(const TreeRef& tree) : Expr(tree) {
730  tree_->matchNumSubtrees(TK_STRINGLITERAL, 1);
731  }
732  const std::string& text() const {
733  return subtree(0)->stringValue();
734  }
735  static StringLiteral create(
736  const SourceRange& range,
737  const std::string& value) {
738  return StringLiteral(
739  Compound::create(TK_STRINGLITERAL, range, {String::create(value)}));
740  }
741 };
742 
743 struct Apply : public Expr {
744  explicit Apply(const TreeRef& tree) : Expr(tree) {
745  tree_->match(TK_APPLY);
746  }
747  Expr callee() const {
748  return Expr(subtree(0));
749  }
750  List<Expr> inputs() const {
751  return List<Expr>(subtree(1));
752  }
753  List<Attribute> attributes() const {
754  return List<Attribute>(subtree(2));
755  }
756  static Apply create(
757  const SourceRange& range,
758  const Expr& callee,
759  const List<Expr>& inputs,
760  const List<Attribute>& attributes) {
761  return Apply(
762  Compound::create(TK_APPLY, range, {callee, inputs, attributes}));
763  }
764 };
765 
766 struct Select : public Expr {
767  explicit Select(const TreeRef& tree) : Expr(tree) {
768  tree_->match('.');
769  }
770  Expr value() const {
771  return Expr(subtree(0));
772  }
773  Ident selector() const {
774  return Ident(subtree(1));
775  }
776  static Select create(
777  const SourceRange& range,
778  const Expr& value,
779  const Ident& selector) {
780  return Select(Compound::create('.', range, {value, selector}));
781  }
782 };
783 
784 struct SliceExpr : public Expr {
785  explicit SliceExpr(const TreeRef& tree) : Expr(tree) {
786  tree_->match(TK_SLICE_EXPR);
787  }
788  Maybe<Expr> start() const {
789  return Maybe<Expr>(subtree(0));
790  }
791  Maybe<Expr> end() const {
792  return Maybe<Expr>(subtree(1));
793  }
794  Expr startOr(int alternative) const {
795  const auto startOption = start();
796  return startOption.present() ? startOption.get() : createInt(alternative);
797  }
798  Expr endOr(int alternative) const {
799  const auto endOption = end();
800  return endOption.present() ? endOption.get() : createInt(alternative);
801  }
802  static SliceExpr create(
803  const SourceRange& range,
804  const Maybe<Expr>& start,
805  const Maybe<Expr>& end) {
806  return SliceExpr(Compound::create(TK_SLICE_EXPR, range, {start, end}));
807  }
808 
809  private:
810  Expr createInt(int value) const {
811  return Expr(Const::create(range(), std::to_string(value)));
812  }
813 };
814 
815 struct Subscript : public Expr {
816  explicit Subscript(const TreeRef& tree) : Expr(tree) {
817  tree_->match(TK_SUBSCRIPT);
818  }
819  Expr value() const {
820  return Expr(subtree(0));
821  }
822  List<Expr> subscript_exprs() const {
823  return List<Expr>(subtree(1));
824  }
825  static Subscript create(
826  const SourceRange& range,
827  const Expr& value,
828  const List<Expr>& subscript_exprs) {
829  return Subscript(
830  Compound::create(TK_SUBSCRIPT, range, {value, subscript_exprs}));
831  }
832 };
833 
834 struct Var : public Expr {
835  explicit Var(const TreeRef& tree) : Expr(tree) {
836  tree_->match(TK_VAR);
837  };
838  Ident name() const {
839  return Ident(subtree(0));
840  }
841  static Var create(const SourceRange& range, const Ident& name) {
842  return Var(Compound::create(TK_VAR, range, {name}));
843  }
844 };
845 
846 struct TernaryIf : public Expr {
847  explicit TernaryIf(const TreeRef& tree) : Expr(tree) {
848  tree_->matchNumSubtrees(TK_IF_EXPR, 3);
849  };
850  Expr cond() const {
851  return Expr(subtree(0));
852  }
853  Expr true_expr() const {
854  return Expr(subtree(1));
855  }
856  Expr false_expr() const {
857  return Expr(subtree(2));
858  }
859  static TernaryIf create(
860  const SourceRange& range,
861  const Expr& cond,
862  const Expr& true_expr,
863  const Expr& false_expr) {
864  return TernaryIf(
865  Compound::create(TK_IF_EXPR, range, {cond, true_expr, false_expr}));
866  };
867 };
868 
869 struct ListLiteral : public Expr {
870  explicit ListLiteral(const TreeRef& tree) : Expr(tree) {
871  tree_->match(TK_LIST_LITERAL);
872  }
873  List<Expr> inputs() const {
874  return subtree(0);
875  }
876  static ListLiteral create(
877  const SourceRange& range,
878  const List<Expr>& inputs) {
879  return ListLiteral(Compound::create(TK_LIST_LITERAL, range, {inputs}));
880  }
881 };
882 
883 struct TupleLiteral : public Expr {
884  explicit TupleLiteral(const TreeRef& tree) : Expr(tree) {
885  tree_->match(TK_TUPLE_LITERAL);
886  }
887  List<Expr> inputs() const {
888  return subtree(0);
889  }
890  static TupleLiteral create(
891  const SourceRange& range,
892  const List<Expr>& inputs) {
893  return TupleLiteral(Compound::create(TK_TUPLE_LITERAL, range, {inputs}));
894  }
895 };
896 
897 struct DictLiteral : public Expr {
898  explicit DictLiteral(const TreeRef& tree) : Expr(tree) {
899  tree_->match(TK_DICT_LITERAL);
900  }
901  List<Expr> key_inputs() const {
902  return subtree(0);
903  }
904  List<Expr> value_inputs() const {
905  return subtree(1);
906  }
907  static DictLiteral create(
908  const SourceRange& range,
909  const List<Expr>& keys,
910  const List<Expr>& values) {
911  return DictLiteral(
912  Compound::create(TK_DICT_LITERAL, range, {keys, values}));
913  }
914 };
915 
916 struct Starred : public Expr {
917  explicit Starred(const TreeRef& tree) : Expr(tree) {
918  tree_->match(TK_STARRED);
919  }
920  Expr expr() const {
921  return Expr(subtree(0));
922  }
923  static Starred create(const SourceRange& range, const Expr& expr) {
924  return Starred(Compound::create(TK_STARRED, range, {expr}));
925  }
926 };
927 
928 } // namespace script
929 } // namespace jit
930 } // namespace torch
931 
namespace std {

// Lets standard algorithms query ListIterator<T> by borrowing every trait
// (category, difference_type, value_type, ...) from the underlying
// TreeList::const_iterator.
// NOTE(review): value_type/reference therefore describe TreeRef, while
// ListIterator<T>::operator* returns a T by value; confirm no caller relies
// on those two agreeing.
template <typename T>
struct iterator_traits<torch::jit::script::ListIterator<T>>
    : std::iterator_traits<torch::jit::script::TreeList::const_iterator> {};

} // namespace std
Definition: module.cpp:17
static double strtod_c(const char *str, char **end)
Definition: lexer.h:169
Definition: jit_type.h:17