Caffe2 - C++ API
A deep learning, cross-platform ML framework
python_tree_views.cpp
1 #include <torch/csrc/jit/script/python_tree_views.h>
2 
3 #include <torch/csrc/jit/script/compiler.h>
4 #include <torch/csrc/jit/script/tree_views.h>
5 
6 #include <pybind11/pybind11.h>
7 #include <pybind11/stl.h>
8 
9 #include <sstream>
10 
11 namespace py = pybind11;
12 
13 namespace torch {
14 namespace jit {
15 namespace script {
16 
18  SourceRangeFactory(std::string source)
19  : source_(std::make_shared<std::string>(std::move(source))) {
20  size_t pos = 0;
21  do {
22  line_len_prefix_sum_.push_back(pos);
23  pos++;
24  } while ((pos = source_->find('\n', pos)) != std::string::npos);
25  }
26  SourceRange create(int line, int start_col, int end_col) {
27  // Python has a weird convention where col_offset points to the column
28  // *before* the token starts.
29  start_col++;
30  end_col++;
31  // Also, lines are counted from 1.
32  line--;
33  auto line_start = line_len_prefix_sum_.at(line);
34  return SourceRange(source_, line_start + start_col, line_start + end_col);
35  }
36 
37  std::shared_ptr<std::string> source_;
38  std::vector<size_t> line_len_prefix_sum_;
39 };
40 
41 template <typename T>
42 List<T> wrap_list(const SourceRange& fallback_pos, std::vector<T>&& vec) {
43  if (vec.empty())
44  return List<T>::create(fallback_pos, std::move(vec));
45  return List<T>::create(vec.front().range(), std::move(vec));
46 }
47 
48 template <typename T>
49 Maybe<T> wrap_maybe(const SourceRange& fallback_pos, T* val) {
50  return val ? Maybe<T>::create(val->range(), *val)
51  : Maybe<T>::create(fallback_pos);
52 }
53 
54 void initTreeViewBindings(PyObject* module) {
55  auto _C = py::handle(module).cast<py::module>();
56  auto m = _C.def_submodule("_jit_tree_views");
57 
58  py::class_<SourceRange>(m, "SourceRange")
59  .def(
60  "highlight",
61  [](const SourceRange& self) {
62  std::ostringstream stream;
63  self.highlight(stream);
64  return stream.str();
65  })
66  .def_property_readonly("start", &SourceRange::start)
67  .def_property_readonly("end", &SourceRange::end);
68  py::class_<SourceRangeFactory>(m, "SourceRangeFactory")
69  .def(py::init<std::string&&>())
70  .def("make_range", &SourceRangeFactory::create)
71  .def(
72  "make_raw_range",
73  [](const SourceRangeFactory& self, size_t start, size_t end) {
74  return SourceRange(self.source_, start, end);
75  })
76  .def_property_readonly("source", [](const SourceRangeFactory& self) {
77  return *self.source_;
78  });
79 
80  py::class_<TreeView>(m, "TreeView")
81  .def("range", &TreeView::range)
82  .def(
83  "__str__",
84  [](const TreeView& tree) {
85  std::ostringstream stream;
86  stream << tree.get();
87  return stream.str();
88  })
89  .def("dump", [](const TreeView& tree) { tree.dump(); });
90 
91  py::class_<Ident, TreeView>(m, "Ident")
92  .def(py::init(&Ident::create))
93  .def_property_readonly(
94  "name", [](const Ident& self) { return self.name(); });
95 
96  py::class_<Param, TreeView>(m, "Param")
97  .def(py::init([](const Expr& type, const Ident& name, bool kwarg_only) {
98  return Param::create(
99  name.range(),
100  name,
101  type,
102  Maybe<Expr>::create(name.range()),
103  kwarg_only);
104  }));
105  py::class_<Attribute, TreeView>(m, "Attribute")
106  .def(py::init([](const Ident& name, const Expr& value) {
107  return Attribute::create(name.range(), name, value);
108  }));
109  m.def("TrueLiteral", [](const SourceRange& range) {
110  return Expr(Compound::create(TK_TRUE, range, {}));
111  });
112  m.def("FalseLiteral", [](const SourceRange& range) {
113  return Expr(Compound::create(TK_FALSE, range, {}));
114  });
115  m.def("NoneLiteral", [](const SourceRange& range) {
116  return Expr(Compound::create(TK_NONE, range, {}));
117  });
118 
119  py::class_<Stmt, TreeView>(m, "Stmt"); // NOLINT(bugprone-unused-raii)
120  py::class_<Expr, TreeView>(m, "Expr"); // NOLINT(bugprone-unused-raii)
121  py::class_<Def, TreeView>(m, "Def").def(
122  py::init([](const Ident& name, Decl decl, std::vector<Stmt> body) {
123  const auto& r = name.range();
124  return Def::create(r, name, decl, wrap_list(r, std::move(body)));
125  }));
126  py::class_<ClassDef, TreeView>(m, "ClassDef")
127  .def(py::init([](const Ident& name, std::vector<Def> body) {
128  const auto& r = name.range();
129  return ClassDef::create(r, name, wrap_list(r, std::move(body)));
130  }));
131  py::class_<Decl, TreeView>(m, "Decl").def(py::init(
132  [](const SourceRange& r, std::vector<Param> params, Expr* return_type) {
133  return Decl::create(
134  r, wrap_list(r, std::move(params)), wrap_maybe(r, return_type));
135  }));
136 
137  py::class_<Assign, Stmt>(m, "Assign")
138  .def(py::init([](const Expr& lhs, const Expr& rhs) {
139  return Assign::create(lhs.range(), lhs, rhs);
140  }));
141  py::class_<AugAssign, Stmt>(m, "AugAssign")
142  .def(py::init([](const Expr& lhs, std::string kind_str, const Expr& rhs) {
143  const auto& r = lhs.range();
144  auto kind =
145  AugAssignKind(Compound::create(stringToKind(kind_str), r, {}));
146  return AugAssign::create(r, lhs, kind, rhs);
147  }));
148  py::class_<Return, Stmt>(m, "Return")
149  .def(py::init([](const SourceRange& range, Expr* value) {
150  return Return::create(
151  range, value ? *value : Expr(Compound::create(TK_NONE, range, {})));
152  }));
153  py::class_<Raise, Stmt>(m, "Raise")
154  .def(py::init([](const SourceRange& range, Expr* expr) {
155  return Raise::create(range, wrap_maybe(range, expr));
156  }));
157  py::class_<Assert, Stmt>(m, "Assert")
158  .def(py::init([](const SourceRange& range, const Expr& test, Expr* msg) {
159  return Assert::create(range, test, wrap_maybe(range, msg));
160  }));
161  py::class_<Pass, Stmt>(m, "Pass").def(
162  py::init([](const SourceRange& range) { return Pass::create(range); }));
163  py::class_<If, Stmt>(m, "If").def(
164  py::init([](const SourceRange& range,
165  const Expr& cond,
166  std::vector<Stmt> true_branch,
167  std::vector<Stmt> false_branch) {
168  return If::create(
169  range,
170  cond,
171  wrap_list(range, std::move(true_branch)),
172  wrap_list(range, std::move(false_branch)));
173  }));
174  py::class_<While, Stmt>(m, "While")
175  .def(py::init([](const SourceRange& range,
176  const Expr& cond,
177  std::vector<Stmt> body) {
178  return While::create(range, cond, wrap_list(range, std::move(body)));
179  }));
180  py::class_<For, Stmt>(m, "For").def(py::init([](const SourceRange range,
181  std::vector<Expr>& targets,
182  std::vector<Expr>& itrs,
183  std::vector<Stmt> body) {
184  return For::create(
185  range,
186  wrap_list(range, std::move(targets)),
187  wrap_list(range, std::move(itrs)),
188  wrap_list(range, std::move(body)));
189  }));
190  py::class_<ExprStmt, Stmt>(m, "ExprStmt").def(py::init([](const Expr& expr) {
191  return ExprStmt::create(expr.range(), expr);
192  }));
193 
194  py::class_<Var, Expr>(m, "Var")
195  .def(py::init(
196  [](const Ident& name) { return Var::create(name.range(), name); }))
197  .def_property_readonly("name", [](const Var& var) { return var.name(); });
198  py::class_<BinOp, Expr>(m, "BinOp")
199  .def(py::init([](std::string kind, const Expr& lhs, const Expr& rhs) {
200  return BinOp::create(lhs.range(), stringToKind(kind), lhs, rhs);
201  }));
202  // NB: we take range here, because unary ops precede their exprs, so we need
203  // to include them
204  py::class_<UnaryOp, Expr>(m, "UnaryOp")
205  .def(py::init(
206  [](const SourceRange& range, std::string kind, const Expr& expr) {
207  auto resolved_kind = stringToKind(kind);
208  resolved_kind =
209  resolved_kind == '-' ? TK_UNARY_MINUS : resolved_kind;
210  return UnaryOp::create(range, resolved_kind, expr);
211  }));
212  py::class_<Const, Expr>(m, "Const")
213  .def(py::init([](const SourceRange& range, std::string value) {
214  return Const::create(range, value);
215  }));
216  py::class_<StringLiteral, Expr>(m, "StringLiteral")
217  .def(py::init([](const SourceRange& range, std::string value) {
218  return StringLiteral::create(range, value);
219  }));
220  py::class_<Apply, Expr>(m, "Apply")
221  .def(py::init([](const Expr& expr,
222  std::vector<Expr> args,
223  std::vector<Attribute> kwargs) {
224  const auto& r = expr.range();
225  return Apply::create(
226  expr.range(),
227  expr,
228  wrap_list(r, std::move(args)),
229  wrap_list(r, std::move(kwargs)));
230  }));
231  py::class_<Select, Expr>(m, "Select")
232  .def(py::init([](const Expr& expr, const Ident& field) {
233  return Select::create(expr.range(), expr, field);
234  }));
235  py::class_<TernaryIf, Expr>(m, "TernaryIf")
236  .def(py::init(
237  [](const Expr& cond, const Expr& true_expr, const Expr& false_expr) {
238  return TernaryIf::create(cond.range(), cond, true_expr, false_expr);
239  }));
240  py::class_<ListLiteral, Expr>(m, "ListLiteral")
241  .def(py::init([](const SourceRange& range, std::vector<Expr> args) {
242  return ListLiteral::create(range, wrap_list(range, std::move(args)));
243  }));
244  py::class_<TupleLiteral, Expr>(m, "TupleLiteral")
245  .def(py::init([](const SourceRange& range, std::vector<Expr> args) {
246  return TupleLiteral::create(range, wrap_list(range, std::move(args)));
247  }));
248  py::class_<DictLiteral, Expr>(m, "DictLiteral")
249  .def(py::init([](const SourceRange& range,
250  std::vector<Expr> keys,
251  std::vector<Expr> values) {
252  return DictLiteral::create(
253  range,
254  wrap_list(range, std::move(keys)),
255  wrap_list(range, std::move(values)));
256  }));
257  py::class_<Subscript, Expr>(m, "Subscript")
258  .def(py::init([](const Expr& base, std::vector<Expr> subscript_exprs) {
259  return Subscript::create(
260  base.range(),
261  base,
262  wrap_list(base.range(), std::move(subscript_exprs)));
263  }));
264  py::class_<SliceExpr, Expr>(m, "SliceExpr")
265  .def(py::init([](const SourceRange& range, Expr* lower, Expr* upper) {
266  return SliceExpr::create(
267  range, wrap_maybe(range, lower), wrap_maybe(range, upper));
268  }));
269  py::class_<Starred, Expr>(m, "Starred")
270  .def(py::init([](const SourceRange& range, Expr expr) {
271  return Starred::create(range, expr);
272  }));
273 }
274 
275 } // namespace script
276 } // namespace jit
277 } // namespace torch
Definition: module.cpp:17
Definition: jit_type.h:17