7 #include <torch/csrc/jit/script/lexer.h> 30 using TreeRef = std::shared_ptr<Tree>;
// A sequence of TreeRef handles.
31 using TreeList = std::vector<TreeRef>;
// An immutable TreeList that holds no elements.
// NOTE(review): presumably the value Tree::trees() hands back for nodes
// without subtrees -- that accessor's body is not visible here; confirm.
33 static const TreeList empty_trees = {};
35 struct Tree : std::enable_shared_from_this<Tree> {
// Constructs a node of the given kind.
// The parameter has the same name as the member it initializes: in
// `kind_(kind_)` the name being initialized is the member and the argument
// inside the parentheses is the constructor parameter.
// NOTE(review): the constructor is not marked `explicit`, so an int can
// convert to Tree implicitly.
36 Tree(
int kind_) : kind_(kind_) {}
40 virtual bool isAtom()
const {
44 throw std::runtime_error(
"is an Atom");
// Accessor for a node's string payload.
// This base-class version has nothing to return: it unconditionally throws
// std::runtime_error with the message below. The String node type declares
// an override of this function.
46 virtual const std::string& stringValue()
const {
47 throw std::runtime_error(
"stringValue can only be called on TK_STRING");
49 virtual const TreeList& trees()
const {
52 const TreeRef& tree(
size_t i)
const {
55 virtual TreeRef map(
const std::function<TreeRef(TreeRef)>& fn) {
57 return shared_from_this();
// Convenience form of matchD for call sites that have no source location:
// forwards `k` and `args` unchanged and supplies "unknown" as the file name
// and 0 as the line number.
59 template <
typename... Args>
60 void match(
int k, Args&... args) {
61 matchD(k,
"unknown", 0, args...);
63 template <
typename... Args>
64 void matchD(
int k,
const char* filename,
int lineno, Args&... args) {
65 std::initializer_list<TreeRef*> vars = {&args...};
66 matchNumSubtreesD(k, filename, lineno, vars.size(),
true);
68 for (TreeRef* v : vars) {
// Convenience form of matchNumSubtreesD for call sites that have no source
// location: forwards `k` and `expected_subtrees`, supplies "unknown" as the
// file name and 0 as the line number, and passes `false` as the last
// argument.
// NOTE(review): the last argument is presumably the `allow_more` flag that
// matchNumSubtreesD tests, which would make this an exact-count check --
// the callee's full parameter list is not visible here; confirm.
72 void matchNumSubtrees(
int k,
size_t expected_subtrees) {
73 return matchNumSubtreesD(k,
"unknown", 0, expected_subtrees,
false);
75 void matchNumSubtreesD(
79 size_t expected_subtrees,
83 ss << filename <<
":" << lineno <<
": expecting kind '" << kindToString(k)
84 <<
"' but found '" << kindToString(kind()) <<
"'\n";
85 range().highlight(ss);
86 throw std::runtime_error(ss.str());
88 if (trees().size() < expected_subtrees ||
89 (!allow_more && trees().size() != expected_subtrees)) {
91 ss << filename <<
":" << lineno <<
": expected at least " 92 << expected_subtrees <<
" subtrees, but found only " << trees().size()
94 range().highlight(ss);
95 throw std::runtime_error(ss.str());
// Virtual, compiler-generated destructor. Tree is used as a polymorphic
// base: it declares virtual member functions and the String node type
// initializes it as its base class.
98 virtual ~
Tree() =
default;
// Builds a string node: the Tree base is initialized with kind TK_STRING and
// `value`, which is taken by value, is moved into the value_ member.
105 String(std::string value) :
Tree(TK_STRING), value_(std::move(value)) {}
106 const std::string& stringValue()
const override {
// Factory for String nodes. Perfect-forwards every argument to a String
// constructor through std::make_shared and returns the result as a TreeRef.
109 template <
typename... Args>
110 static TreeRef create(Args&&... args) {
111 return std::make_shared<String>(std::forward<Args>(args)...);
119 for (
const auto& t : others) {
122 size_t s = std::min(c.start(), t->range().start());
123 size_t e = std::max(c.end(), t->range().end());
131 :
Tree(kind), range_(std::move(range)) {}
134 range_(mergeRanges(range_, trees_)),
135 trees_(std::move(trees_)) {}
136 const TreeList& trees()
const override {
139 static TreeRef create(
143 return std::make_shared<Compound>(kind, range_, std::move(trees_));
145 bool isAtom()
const override {
148 TreeRef map(
const std::function<TreeRef(TreeRef)>& fn)
override {
150 for (
auto& t : trees()) {
151 trees_.push_back(fn(t));
153 return Compound::create(kind(), range(), std::move(trees_));
// Stores the tree to be printed and a column limit (40 when not given).
// print() compares `indent + <flat string length>` against `col`.
167 pretty_tree(
const TreeRef& tree,
size_t col = 40) : tree(tree), col(col) {}
// Lookup table from a node to its flat string form. get_flat() searches it
// first and, after building a string, inserts it here and returns a
// reference to the stored copy. Keys are TreeRef (shared_ptr) values, so
// std::hash/equality on the key compare the held pointer, not tree contents.
170 std::unordered_map<TreeRef, std::string> flat_strings;
171 const std::string& get_flat(
const TreeRef& t) {
172 auto it = flat_strings.find(t);
173 if (it != flat_strings.end())
176 std::stringstream out;
179 out << t->stringValue();
182 out <<
"(" << kindToString(t->kind());
183 for (
const auto& e : t->trees()) {
184 out <<
" " << get_flat(e);
189 auto it_ = flat_strings.emplace(t, out.str());
190 return it_.first->second;
192 void print(std::ostream& out,
const TreeRef& t,
int indent) {
193 const std::string& s = get_flat(t);
194 if (indent + s.size() < col || t->isAtom()) {
198 std::string k = kindToString(t->kind());
200 for (
const auto& e : t->trees()) {
201 out <<
"\n" << std::string(indent + 2,
' ');
202 print(out, e, indent + 2);
// Streams a pretty_tree: prints its tree starting at indent 0, then writes
// std::endl (a newline followed by a flush) and returns the stream.
// `t_` is received by value, so print() runs on the callee's own copy of the
// pretty_tree, including its flat_strings table.
208 static inline std::ostream& operator<<(std::ostream& out,
pretty_tree t_) {
209 t_.print(out, t_.tree, 0);
210 return out << std::endl;
213 static inline std::ostream& operator<<(std::ostream& out,
const TreeRef& t) {