1 #include <torch/csrc/jit/script/python_tree_views.h> 3 #include <torch/csrc/jit/script/compiler.h> 4 #include <torch/csrc/jit/script/tree_views.h> 6 #include <pybind11/pybind11.h> 7 #include <pybind11/stl.h> 19 : source_(std::make_shared<std::string>(std::move(source))) {
22 line_len_prefix_sum_.push_back(pos);
24 }
while ((pos = source_->find(
'\n', pos)) != std::string::npos);
26 SourceRange create(
int line,
int start_col,
int end_col) {
33 auto line_start = line_len_prefix_sum_.at(line);
34 return SourceRange(source_, line_start + start_col, line_start + end_col);
37 std::shared_ptr<std::string> source_;
38 std::vector<size_t> line_len_prefix_sum_;
54 void initTreeViewBindings(PyObject* module) {
55 auto _C = py::handle(module).cast<py::module>();
56 auto m = _C.def_submodule(
"_jit_tree_views");
58 py::class_<SourceRange>(m,
"SourceRange")
62 std::ostringstream stream;
63 self.highlight(stream);
66 .def_property_readonly(
"start", &SourceRange::start)
67 .def_property_readonly(
"end", &SourceRange::end);
68 py::class_<SourceRangeFactory>(m,
"SourceRangeFactory")
69 .def(py::init<std::string&&>())
70 .def(
"make_range", &SourceRangeFactory::create)
80 py::class_<TreeView>(m,
"TreeView")
81 .def(
"range", &TreeView::range)
85 std::ostringstream stream;
89 .def(
"dump", [](
const TreeView& tree) { tree.dump(); });
91 py::class_<Ident, TreeView>(m,
"Ident")
92 .def(py::init(&Ident::create))
93 .def_property_readonly(
94 "name", [](
const Ident&
self) {
return self.name(); });
96 py::class_<Param, TreeView>(m,
"Param")
97 .def(py::init([](
const Expr& type,
const Ident& name,
bool kwarg_only) {
105 py::class_<Attribute, TreeView>(m,
"Attribute")
106 .def(py::init([](
const Ident& name,
const Expr& value) {
107 return Attribute::create(name.range(), name, value);
109 m.def(
"TrueLiteral", [](
const SourceRange& range) {
110 return Expr(Compound::create(TK_TRUE, range, {}));
112 m.def(
"FalseLiteral", [](
const SourceRange& range) {
113 return Expr(Compound::create(TK_FALSE, range, {}));
115 m.def(
"NoneLiteral", [](
const SourceRange& range) {
116 return Expr(Compound::create(TK_NONE, range, {}));
119 py::class_<Stmt, TreeView>(m,
"Stmt");
120 py::class_<Expr, TreeView>(m,
"Expr");
121 py::class_<Def, TreeView>(m,
"Def").def(
122 py::init([](
const Ident& name,
Decl decl, std::vector<Stmt> body) {
123 const auto& r = name.range();
124 return Def::create(r, name, decl, wrap_list(r, std::move(body)));
126 py::class_<ClassDef, TreeView>(m,
"ClassDef")
127 .def(py::init([](
const Ident& name, std::vector<Def> body) {
128 const auto& r = name.range();
129 return ClassDef::create(r, name, wrap_list(r, std::move(body)));
131 py::class_<Decl, TreeView>(m,
"Decl").def(py::init(
132 [](
const SourceRange& r, std::vector<Param> params,
Expr* return_type) {
134 r, wrap_list(r, std::move(params)), wrap_maybe(r, return_type));
137 py::class_<Assign, Stmt>(m,
"Assign")
138 .def(py::init([](
const Expr& lhs,
const Expr& rhs) {
139 return Assign::create(lhs.range(), lhs, rhs);
141 py::class_<AugAssign, Stmt>(m,
"AugAssign")
142 .def(py::init([](
const Expr& lhs, std::string kind_str,
const Expr& rhs) {
143 const auto& r = lhs.range();
145 AugAssignKind(Compound::create(stringToKind(kind_str), r, {}));
146 return AugAssign::create(r, lhs, kind, rhs);
148 py::class_<Return, Stmt>(m,
"Return")
150 return Return::create(
151 range, value ? *value :
Expr(Compound::create(TK_NONE, range, {})));
153 py::class_<Raise, Stmt>(m,
"Raise")
155 return Raise::create(range, wrap_maybe(range, expr));
157 py::class_<Assert, Stmt>(m,
"Assert")
159 return Assert::create(range, test, wrap_maybe(range, msg));
161 py::class_<Pass, Stmt>(m,
"Pass").def(
162 py::init([](
const SourceRange& range) {
return Pass::create(range); }));
163 py::class_<If, Stmt>(m,
"If").def(
166 std::vector<Stmt> true_branch,
167 std::vector<Stmt> false_branch) {
171 wrap_list(range, std::move(true_branch)),
172 wrap_list(range, std::move(false_branch)));
174 py::class_<While, Stmt>(m,
"While")
177 std::vector<Stmt> body) {
178 return While::create(range, cond, wrap_list(range, std::move(body)));
180 py::class_<For, Stmt>(m,
"For").def(py::init([](
const SourceRange range,
181 std::vector<Expr>& targets,
182 std::vector<Expr>& itrs,
183 std::vector<Stmt> body) {
186 wrap_list(range, std::move(targets)),
187 wrap_list(range, std::move(itrs)),
188 wrap_list(range, std::move(body)));
190 py::class_<ExprStmt, Stmt>(m,
"ExprStmt").def(py::init([](
const Expr& expr) {
191 return ExprStmt::create(expr.range(), expr);
194 py::class_<Var, Expr>(m,
"Var")
196 [](
const Ident& name) {
return Var::create(name.range(), name); }))
197 .def_property_readonly(
"name", [](
const Var& var) {
return var.name(); });
198 py::class_<BinOp, Expr>(m,
"BinOp")
199 .def(py::init([](std::string kind,
const Expr& lhs,
const Expr& rhs) {
200 return BinOp::create(lhs.range(), stringToKind(kind), lhs, rhs);
204 py::class_<UnaryOp, Expr>(m,
"UnaryOp")
207 auto resolved_kind = stringToKind(kind);
209 resolved_kind ==
'-' ? TK_UNARY_MINUS : resolved_kind;
210 return UnaryOp::create(range, resolved_kind, expr);
212 py::class_<Const, Expr>(m,
"Const")
213 .def(py::init([](
const SourceRange& range, std::string value) {
214 return Const::create(range, value);
216 py::class_<StringLiteral, Expr>(m,
"StringLiteral")
217 .def(py::init([](
const SourceRange& range, std::string value) {
218 return StringLiteral::create(range, value);
220 py::class_<Apply, Expr>(m,
"Apply")
221 .def(py::init([](
const Expr& expr,
222 std::vector<Expr> args,
223 std::vector<Attribute> kwargs) {
224 const auto& r = expr.range();
225 return Apply::create(
228 wrap_list(r, std::move(args)),
229 wrap_list(r, std::move(kwargs)));
231 py::class_<Select, Expr>(m,
"Select")
232 .def(py::init([](
const Expr& expr,
const Ident& field) {
233 return Select::create(expr.range(), expr, field);
235 py::class_<TernaryIf, Expr>(m,
"TernaryIf")
237 [](
const Expr& cond,
const Expr& true_expr,
const Expr& false_expr) {
238 return TernaryIf::create(cond.range(), cond, true_expr, false_expr);
240 py::class_<ListLiteral, Expr>(m,
"ListLiteral")
241 .def(py::init([](
const SourceRange& range, std::vector<Expr> args) {
242 return ListLiteral::create(range, wrap_list(range, std::move(args)));
244 py::class_<TupleLiteral, Expr>(m,
"TupleLiteral")
245 .def(py::init([](
const SourceRange& range, std::vector<Expr> args) {
246 return TupleLiteral::create(range, wrap_list(range, std::move(args)));
248 py::class_<DictLiteral, Expr>(m,
"DictLiteral")
250 std::vector<Expr> keys,
251 std::vector<Expr> values) {
252 return DictLiteral::create(
254 wrap_list(range, std::move(keys)),
255 wrap_list(range, std::move(values)));
257 py::class_<Subscript, Expr>(m,
"Subscript")
258 .def(py::init([](
const Expr& base, std::vector<Expr> subscript_exprs) {
259 return Subscript::create(
262 wrap_list(base.range(), std::move(subscript_exprs)));
264 py::class_<SliceExpr, Expr>(m,
"SliceExpr")
266 return SliceExpr::create(
267 range, wrap_maybe(range, lower), wrap_maybe(range, upper));
269 py::class_<Starred, Expr>(m,
"Starred")
271 return Starred::create(range, expr);