2 #include <torch/csrc/jit/script/error_report.h> 3 #include <torch/csrc/jit/script/tree.h> 96 explicit TreeView(TreeRef tree) : tree_(std::move(tree)) {}
97 TreeRef tree()
const {
101 return tree_->range();
103 operator TreeRef()
const {
106 const TreeRef&
get()
const {
110 return tree_->kind();
117 const TreeRef& subtree(
size_t i)
const {
118 return tree_->trees().at(i);
123 template <
typename T>
132 T operator*()
const {
149 TreeList::const_iterator it;
152 template <
typename T>
158 tree->match(TK_LIST);
160 for (
const T& elem : *
this) {
165 return iterator(tree_->trees().begin());
168 return iterator(tree_->trees().end());
171 return tree_->trees().begin() == tree_->trees().end();
173 T operator[](
size_t i)
const {
174 return T(subtree(i));
176 TreeRef map(
const std::function<TreeRef(
const T&)>& fn) {
177 return tree_->map([&](TreeRef v) {
return fn(
T(v)); });
179 static List create(
const SourceRange& range,
const std::vector<T>& subtrees) {
180 TreeList type_erased_sub{subtrees.begin(), subtrees.end()};
181 return List(Compound::create(TK_LIST, range, std::move(type_erased_sub)));
183 static List unsafeCreate(
const SourceRange& range, TreeList&& subtrees) {
184 return List(Compound::create(TK_LIST, range, std::move(subtrees)));
186 size_t size()
const {
187 return tree_->trees().size();
191 template <
typename T>
194 tree_->match(TK_OPTION);
195 if (tree_->trees().size() > 1)
196 throw ErrorReport(tree) <<
"Maybe trees can have at most one subtree";
199 bool present()
const {
200 return tree_->trees().size() > 0;
203 return T(tree_->trees().at(0));
205 TreeRef map(
const std::function<TreeRef(
const T&)>& fn) {
206 return tree_->map([&](TreeRef v) {
return fn(
T(v)); });
209 return Maybe<T>(Compound::create(TK_OPTION, range, {}));
212 return Maybe<T>(Compound::create(TK_OPTION, range, {value}));
218 tree_->match(TK_IDENT);
220 const std::string& name()
const {
221 return subtree(0)->stringValue();
224 return Ident(Compound::create(TK_IDENT, range, {String::create(name)}));
234 switch (tree->kind()) {
250 << kindToString(tree->kind()) <<
" is not a valid Stmt";
257 switch (tree->kind()) {
278 case TK_STRINGLITERAL:
288 case TK_LIST_LITERAL:
289 case TK_TUPLE_LITERAL:
290 case TK_DICT_LITERAL:
300 << kindToString(tree->kind()) <<
" is not a valid Expr";
311 tree_->match(TK_ATTRIBUTE);
314 return Ident(subtree(0));
317 return Expr(subtree(1));
322 const TreeRef& value) {
323 return Attribute(Compound::create(TK_ATTRIBUTE, range, {name, value}));
329 tree_->match(TK_PARAM);
337 TreeRef kwarg_only_tree =
338 Compound::create(kwarg_only ? TK_TRUE : TK_FALSE, range, {});
340 Compound::create(TK_PARAM, range, {ident, type, def, kwarg_only_tree}));
342 Ident ident()
const {
343 return Ident(subtree(0));
346 return Expr(subtree(1));
351 bool kwarg_only()
const {
352 return TK_TRUE == subtree(3)->kind();
355 return Param::create(range(), ident(), typ, defaultValue(), kwarg_only());
365 tree->match(TK_DECL);
377 return Decl(Compound::create(TK_DECL, range, {params, return_type}));
382 explicit Def(
const TreeRef& tree) :
TreeView(tree) {
385 Def withName(std::string new_name)
const {
386 auto new_ident = Ident::create(name().range(), std::move(new_name));
387 return create(range(), new_ident, decl(), statements());
390 return Ident(subtree(0));
393 return Decl(subtree(1));
403 return Def(Compound::create(TK_DEF, range, {name, decl, stmts}));
409 tree->match(TK_CLASS_DEF);
411 ClassDef withName(std::string new_name)
const {
412 auto new_ident = Ident::create(name().range(), std::move(new_name));
413 return create(range(), new_ident, defs());
416 return Ident(subtree(0));
425 return ClassDef(Compound::create(TK_CLASS_DEF, range, {name, defs}));
434 explicit If(
const TreeRef& tree) :
Stmt(tree) {
438 return Expr(subtree(0));
449 return create(range(), cond(), true_branch, false_branch);
457 Compound::create(TK_IF, range, {cond, true_branch, false_branch}));
462 explicit While(
const TreeRef& tree) :
Stmt(tree) {
463 tree_->match(TK_WHILE);
466 return Expr(subtree(0));
475 return While(Compound::create(TK_WHILE, range, {cond, body}));
480 explicit For(
const TreeRef& tree) :
Stmt(tree) {
497 return For(Compound::create(TK_FOR, range, {targets, itrs, body}));
502 explicit Global(
const TreeRef& tree) :
Stmt(tree) {
503 tree_->match(TK_GLOBAL);
509 return Global(Compound::create(TK_GLOBAL, range, {names}));
515 switch (tree->kind()) {
522 throw ErrorReport(tree) <<
"is not a valid AugAssignKind";
530 tree_->match(TK_AUG_ASSIGN);
538 Compound::create(TK_AUG_ASSIGN, range, {lhs, aug_op, rhs}));
541 return Expr(subtree(0));
544 return subtree(1)->kind();
547 return Expr(subtree(2));
552 explicit Assign(
const TreeRef& tree) :
Stmt(tree) {
553 tree_->match(TK_ASSIGN);
559 return Assign(Compound::create(TK_ASSIGN, range, {lhs, rhs}));
562 return Expr(subtree(0));
565 return Expr(subtree(1));
570 explicit Return(
const TreeRef& tree) :
Stmt(tree) {
571 tree_->match(TK_RETURN);
574 return Expr(subtree(0));
577 return Return(Compound::create(TK_RETURN, range, {value}));
582 explicit Raise(
const TreeRef& tree) :
Stmt(tree) {
583 tree_->match(TK_RAISE);
589 return Raise(Compound::create(TK_RAISE, range, {expr}));
594 explicit Assert(
const TreeRef& tree) :
Stmt(tree) {
595 tree_->match(TK_ASSERT);
598 return Expr(subtree(0));
607 return Assert(Compound::create(TK_ASSERT, range, {test, msg}));
612 explicit Pass(
const TreeRef& tree) :
Stmt(tree) {
613 tree_->match(TK_PASS);
616 return Pass(Compound::create(TK_PASS, range, {}));
622 tree_->match(TK_EXPR_STMT);
625 return Expr(subtree(0));
628 return ExprStmt(Compound::create(TK_EXPR_STMT, range, {list}));
637 explicit BinOp(
const TreeRef& tree) :
Expr(tree) {
638 switch (tree->kind()) {
660 if (tree->trees().size() != 2)
662 <<
"BinOp expected 2 subtrees, found " << tree->trees().size();
666 << kindToString(tree->kind()) <<
" is not a valid BinOp";
670 return Expr(subtree(0));
673 return Expr(subtree(1));
680 return BinOp(Compound::create(kind, range, {lhs, rhs}));
685 explicit UnaryOp(
const TreeRef& tree) :
Expr(tree) {
686 switch (tree->kind()) {
689 if (tree->trees().size() != 1)
691 <<
"UnaryOp expected 1 subtree, found " << tree->trees().size();
695 << kindToString(tree->kind()) <<
" is not a valid UnaryOp";
699 return UnaryOp(Compound::create(kind, range, {expr}));
704 explicit Const(
const TreeRef& tree) :
Expr(tree) {
705 tree_->matchNumSubtrees(TK_CONST, 1);
707 bool isFloatingPoint()
const {
708 return subtree(0)->stringValue().find_first_of(
".eE") != std::string::npos;
710 bool isIntegral()
const {
711 return !isFloatingPoint();
713 int64_t asIntegral()
const {
714 return std::stoll(subtree(0)->stringValue());
716 double asFloatingPoint()
const {
718 subtree(0)->stringValue().c_str(),
nullptr);
720 const std::string& text()
const {
721 return subtree(0)->stringValue();
724 return Const(Compound::create(TK_CONST, range, {String::create(value)}));
730 tree_->matchNumSubtrees(TK_STRINGLITERAL, 1);
732 const std::string& text()
const {
733 return subtree(0)->stringValue();
737 const std::string& value) {
739 Compound::create(TK_STRINGLITERAL, range, {String::create(value)}));
744 explicit Apply(
const TreeRef& tree) :
Expr(tree) {
745 tree_->match(TK_APPLY);
747 Expr callee()
const {
748 return Expr(subtree(0));
762 Compound::create(TK_APPLY, range, {callee, inputs, attributes}));
767 explicit Select(
const TreeRef& tree) :
Expr(tree) {
771 return Expr(subtree(0));
773 Ident selector()
const {
774 return Ident(subtree(1));
779 const Ident& selector) {
780 return Select(Compound::create(
'.', range, {value, selector}));
786 tree_->match(TK_SLICE_EXPR);
794 Expr startOr(
int alternative)
const {
795 const auto startOption = start();
796 return startOption.present() ? startOption.get() : createInt(alternative);
798 Expr endOr(
int alternative)
const {
799 const auto endOption = end();
800 return endOption.present() ? endOption.get() : createInt(alternative);
806 return SliceExpr(Compound::create(TK_SLICE_EXPR, range, {start, end}));
810 Expr createInt(
int value)
const {
811 return Expr(Const::create(range(), std::to_string(value)));
817 tree_->match(TK_SUBSCRIPT);
820 return Expr(subtree(0));
830 Compound::create(TK_SUBSCRIPT, range, {value, subscript_exprs}));
835 explicit Var(
const TreeRef& tree) :
Expr(tree) {
836 tree_->match(TK_VAR);
839 return Ident(subtree(0));
842 return Var(Compound::create(TK_VAR, range, {name}));
848 tree_->matchNumSubtrees(TK_IF_EXPR, 3);
851 return Expr(subtree(0));
853 Expr true_expr()
const {
854 return Expr(subtree(1));
856 Expr false_expr()
const {
857 return Expr(subtree(2));
862 const Expr& true_expr,
863 const Expr& false_expr) {
865 Compound::create(TK_IF_EXPR, range, {cond, true_expr, false_expr}));
871 tree_->match(TK_LIST_LITERAL);
879 return ListLiteral(Compound::create(TK_LIST_LITERAL, range, {inputs}));
885 tree_->match(TK_TUPLE_LITERAL);
893 return TupleLiteral(Compound::create(TK_TUPLE_LITERAL, range, {inputs}));
899 tree_->match(TK_DICT_LITERAL);
912 Compound::create(TK_DICT_LITERAL, range, {keys, values}));
917 explicit Starred(
const TreeRef& tree) :
Expr(tree) {
918 tree_->match(TK_STARRED);
921 return Expr(subtree(0));
924 return Starred(Compound::create(TK_STARRED, range, {expr}));
934 template <
typename T>
935 struct iterator_traits<
torch::jit::script::ListIterator<T>>
936 : std::iterator_traits<torch::jit::script::TreeList::const_iterator> {};
static double strtod_c(const char *str, char **end)