Caffe2 - C++ API
A deep learning, cross platform ML framework
tree.h
1 #pragma once
2 
3 #include <functional>
4 #include <memory>
5 #include <vector>
6 
7 #include <torch/csrc/jit/script/lexer.h>
8 
9 namespace torch {
10 namespace jit {
11 namespace script {
12 
13 // Trees are used to represent all forms of TC IR, pre- and post-typechecking.
14 // Rather than having a full class hierarchy for all TC statements,
15 // trees are modeled as a slight variation of Lisp S-expressions.
16 // For instance, the expression a*b+1 is represented as:
17 // (+ (* (ident a) (ident b)) (const 1))
18 // Atoms like 'a', 'b', and '1' are represented by subclasses of Tree which
19 // define stringValue().
20 // Everything else is a Compound object, which has a 'kind' that is a token from
21 // Lexer.h's TokenKind enum, and contains a list of subtrees.
22 // As with TokenKind, single-character operators like '+' are represented by the
23 // character itself, so add.kind() == '+'.
24 // Compound objects are also always associated with a SourceRange for
25 // reporting error messages.
26 
27 // Memory management of trees is done using shared_ptr.
28 
// Forward declaration so the aliases below can name Tree before it is defined.
struct Tree;

// Shared-ownership handle to a tree node.
using TreeRef = std::shared_ptr<Tree>;
// Ordered sequence of subtree handles.
using TreeList = std::vector<TreeRef>;

// Empty list handed out by Tree::trees() for nodes that have no subtrees.
static const TreeList empty_trees{};
34 
35 struct Tree : std::enable_shared_from_this<Tree> {
36  Tree(int kind_) : kind_(kind_) {}
37  int kind() const {
38  return kind_;
39  }
40  virtual bool isAtom() const {
41  return true;
42  }
43  virtual const SourceRange& range() const {
44  throw std::runtime_error("is an Atom");
45  }
46  virtual const std::string& stringValue() const {
47  throw std::runtime_error("stringValue can only be called on TK_STRING");
48  }
49  virtual const TreeList& trees() const {
50  return empty_trees;
51  }
52  const TreeRef& tree(size_t i) const {
53  return trees().at(i);
54  }
55  virtual TreeRef map(const std::function<TreeRef(TreeRef)>& fn) {
56  (void)fn;
57  return shared_from_this();
58  }
59  template <typename... Args>
60  void match(int k, Args&... args) {
61  matchD(k, "unknown", 0, args...);
62  }
63  template <typename... Args>
64  void matchD(int k, const char* filename, int lineno, Args&... args) {
65  std::initializer_list<TreeRef*> vars = {&args...};
66  matchNumSubtreesD(k, filename, lineno, vars.size(), true);
67  size_t i = 0;
68  for (TreeRef* v : vars) {
69  *v = trees()[i++];
70  }
71  }
72  void matchNumSubtrees(int k, size_t expected_subtrees) {
73  return matchNumSubtreesD(k, "unknown", 0, expected_subtrees, false);
74  }
75  void matchNumSubtreesD(
76  int k,
77  const char* filename,
78  int lineno,
79  size_t expected_subtrees,
80  bool allow_more) {
81  if (kind() != k) {
82  std::stringstream ss;
83  ss << filename << ":" << lineno << ": expecting kind '" << kindToString(k)
84  << "' but found '" << kindToString(kind()) << "'\n";
85  range().highlight(ss);
86  throw std::runtime_error(ss.str());
87  }
88  if (trees().size() < expected_subtrees ||
89  (!allow_more && trees().size() != expected_subtrees)) {
90  std::stringstream ss;
91  ss << filename << ":" << lineno << ": expected at least "
92  << expected_subtrees << " subtrees, but found only " << trees().size()
93  << "\n";
94  range().highlight(ss);
95  throw std::runtime_error(ss.str());
96  }
97  }
98  virtual ~Tree() = default;
99 
100  private:
101  int kind_;
102 };
103 
104 struct String : public Tree {
105  String(std::string value) : Tree(TK_STRING), value_(std::move(value)) {}
106  const std::string& stringValue() const override {
107  return value_;
108  }
109  template <typename... Args>
110  static TreeRef create(Args&&... args) {
111  return std::make_shared<String>(std::forward<Args>(args)...);
112  }
113 
114  private:
115  std::string value_;
116 };
117 
118 static SourceRange mergeRanges(SourceRange c, const TreeList& others) {
119  for (const auto& t : others) {
120  if (t->isAtom())
121  continue;
122  size_t s = std::min(c.start(), t->range().start());
123  size_t e = std::max(c.end(), t->range().end());
124  c = SourceRange(c.file_ptr(), s, e);
125  }
126  return c;
127 }
128 
129 struct Compound : public Tree {
130  Compound(int kind, SourceRange range)
131  : Tree(kind), range_(std::move(range)) {}
132  Compound(int kind, const SourceRange& range_, TreeList&& trees_)
133  : Tree(kind),
134  range_(mergeRanges(range_, trees_)),
135  trees_(std::move(trees_)) {}
136  const TreeList& trees() const override {
137  return trees_;
138  }
139  static TreeRef create(
140  int kind,
141  const SourceRange& range_,
142  TreeList&& trees_) {
143  return std::make_shared<Compound>(kind, range_, std::move(trees_));
144  }
145  bool isAtom() const override {
146  return false;
147  }
148  TreeRef map(const std::function<TreeRef(TreeRef)>& fn) override {
149  TreeList trees_;
150  for (auto& t : trees()) {
151  trees_.push_back(fn(t));
152  }
153  return Compound::create(kind(), range(), std::move(trees_));
154  }
155 
156  const SourceRange& range() const override {
157  return range_;
158  }
159 
160  private:
161  SourceRange range_;
162  TreeList trees_;
163 };
164 
165 // tree pretty printer
166 struct pretty_tree {
167  pretty_tree(const TreeRef& tree, size_t col = 40) : tree(tree), col(col) {}
168  const TreeRef& tree;
169  size_t col;
170  std::unordered_map<TreeRef, std::string> flat_strings;
171  const std::string& get_flat(const TreeRef& t) {
172  auto it = flat_strings.find(t);
173  if (it != flat_strings.end())
174  return it->second;
175 
176  std::stringstream out;
177  switch (t->kind()) {
178  case TK_STRING:
179  out << t->stringValue();
180  break;
181  default:
182  out << "(" << kindToString(t->kind());
183  for (const auto& e : t->trees()) {
184  out << " " << get_flat(e);
185  }
186  out << ")";
187  break;
188  }
189  auto it_ = flat_strings.emplace(t, out.str());
190  return it_.first->second;
191  }
192  void print(std::ostream& out, const TreeRef& t, int indent) {
193  const std::string& s = get_flat(t);
194  if (indent + s.size() < col || t->isAtom()) {
195  out << s;
196  return;
197  }
198  std::string k = kindToString(t->kind());
199  out << "(" << k;
200  for (const auto& e : t->trees()) {
201  out << "\n" << std::string(indent + 2, ' ');
202  print(out, e, indent + 2);
203  }
204  out << ")";
205  }
206 };
207 
208 static inline std::ostream& operator<<(std::ostream& out, pretty_tree t_) {
209  t_.print(out, t_.tree, 0);
210  return out << std::endl;
211 }
212 
213 static inline std::ostream& operator<<(std::ostream& out, const TreeRef& t) {
214  return out << pretty_tree(t);
215 }
216 
217 } // namespace script
218 } // namespace jit
219 } // namespace torch
Definition: jit_type.h:17