Caffe2 - C++ API
A deep learning, cross platform ML framework
lexer.h
#pragma once
#include <c10/util/Exception.h>
#include <torch/csrc/jit/source_range.h>
#include <torch/csrc/utils/memory.h>
#include <algorithm>
#include <cctype>
#include <clocale>
#include <cstdlib>
#include <iostream>
#include <memory>
#include <sstream>
#include <string>
#include <unordered_map>
#include <vector>
13 
14 namespace torch {
15 namespace jit {
16 namespace script {
17 
// single character tokens are just the character itself '+'
// multi-character tokens need an entry here
// if the third entry is not the empty string, it is used
// in the lexer to match this token.

// These kinds are also used in Tree.h as the kind of the AST node.
// Some kinds TK_APPLY, TK_LIST are only used in the AST and are not seen in the
// lexer.
// X-macro table of multi-character token kinds: _(kind, name-for-printing,
// match-text). When the third column is non-empty it is the exact text the
// lexer matches for this token; empty-text kinds are produced elsewhere
// (e.g. by the AST builder or the indent tracker).
#define TC_FORALL_TOKEN_KINDS(_) \
  _(TK_EOF, "eof", "") \
  _(TK_WHITESPACE, "whitespace", "") \
  _(TK_WHITESPACE_EOF, "whitespace_eof", "") \
  _(TK_NUMBER, "number", "") \
  _(TK_NEWLINE, "newline", "") \
  _(TK_INDENT, "indent", "") \
  _(TK_DEDENT, "dedent", "") \
  _(TK_DEF, "def", "def") \
  _(TK_EQUIVALENT, "equivalent", "<=>") \
  _(TK_IDENT, "ident", "") \
  _(TK_STRING, "string", "") \
  _(TK_STRINGLITERAL, "string_literal", "") \
  _(TK_CONST, "const", "") \
  _(TK_LIST, "list", "") \
  _(TK_DICT, "dict", "") \
  _(TK_OPTION, "option", "") \
  _(TK_APPLY, "apply", "") \
  _(TK_COMPREHENSION, "comprehension", "") \
  _(TK_RANGE_CONSTRAINT, "range_constraint", "") \
  _(TK_PARAM, "param", "") \
  _(TK_INFERRED, "inferred", "") \
  _(TK_ACCESS, "access", "") \
  _(TK_ASSIGN, "assign", "") \
  _(TK_AUG_ASSIGN, "aug_assign", "") \
  _(TK_ATTRIBUTE, "attribute", "") \
  _(TK_IF, "if", "if") \
  _(TK_ELSE, "else", "else") \
  _(TK_ELIF, "elif", "elif") \
  _(TK_WHILE, "while", "while") \
  _(TK_EXPR_STMT, "expression statement", "") \
  _(TK_RETURN, "return", "return") \
  _(TK_IS, "is", "is") \
  _(TK_ISNOT, "is not", "is not") \
  _(TK_NE, "ne", "!=") \
  _(TK_EQ, "eq", "==") \
  _(TK_LE, "le", "<=") \
  _(TK_GE, "ge", ">=") \
  _(TK_FLOOR_DIV, "floordiv", "//") \
  _(TK_IF_EXPR, "if", "") \
  _(TK_TRUE, "True", "True") \
  _(TK_FALSE, "False", "False") \
  _(TK_NONE, "None", "None") \
  _(TK_AND, "and", "and") \
  _(TK_OR, "or", "or") \
  _(TK_NOT, "not", "not") \
  _(TK_CAST, "cast", "") \
  _(TK_PLUS_EQ, "+=", "+=") \
  _(TK_MINUS_EQ, "-=", "-=") \
  _(TK_TIMES_EQ, "*=", "*=") \
  _(TK_DIV_EQ, "/=", "/=") \
  _(TK_GLOBAL, "global", "global") \
  _(TK_BUILT_IN, "built-in", "") \
  _(TK_SUBSCRIPT, "subscript", "") \
  _(TK_VAR, "variable", "") \
  _(TK_NOTHING, "nothing", "") \
  _(TK_DICT_LITERAL, "dict-literal", "") \
  _(TK_LIST_LITERAL, "list-literal", "") \
  _(TK_TUPLE_LITERAL, "tuple-literal", "") \
  _(TK_FOR, "for", "for") \
  _(TK_IN, "in", "in") \
  _(TK_STARRED, "starred", "") \
  _(TK_UNARY_MINUS, "unary minus", "") \
  _(TK_POW, "pow operator", "**") \
  _(TK_ARROW, "arrow", "->") \
  _(TK_DECL, "decl", "") \
  _(TK_SLICE_EXPR, "slice expr", "") \
  _(TK_TYPE_COMMENT, "type comment", "# type:") \
  _(TK_RAISE, "raise", "raise") \
  _(TK_ASSERT, "assert", "assert") \
  _(TK_DOTS, "dots", "...") \
  _(TK_PASS, "pass", "pass") \
  _(TK_CLASS_DEF, "class", "class")
100 
101 static const char* valid_single_char_tokens = "+-*/%@()[]:,={}><.?!&^|";
102 
103 enum TokenKind {
104  // we use characters to represent themselves so skip all valid characters
105  // before
106  // assigning enum values to multi-char tokens.
107  TK_DUMMY_START = 256,
108 #define DEFINE_TOKEN(tok, _, _2) tok,
109  TC_FORALL_TOKEN_KINDS(DEFINE_TOKEN)
110 #undef DEFINE_TOKEN
111 };
112 
113 std::string kindToString(int kind);
114 int stringToKind(const std::string& str);
115 
// a character-by-character trie describing which strings are valid tokens.
117 struct TokenTrie;
118 using TokenTrieRef = std::unique_ptr<TokenTrie>;
119 struct TokenTrie {
120  TokenTrie() : kind(0) {}
121  void insert(const char* str, int tok) {
122  if (*str == '\0') {
123  AT_ASSERT(kind == 0);
124  kind = tok;
125  return;
126  }
127 
128  for (size_t i = 0, e = child_chars.size(); i < e; ++i) {
129  if (child_chars[i] == *str) {
130  child_tries[i]->insert(str + 1, tok);
131  return;
132  }
133  }
134 
135  child_chars.emplace_back(*str);
136  child_tries.emplace_back(torch::make_unique<TokenTrie>());
137  child_tries.back()->insert(str + 1, tok);
138  }
139  int kind; // 0 == invalid token
140 
141  std::vector<char> child_chars;
142  std::vector<TokenTrieRef> child_tries;
143 };
144 
// stuff that is shared across all TC lexers/parsers and is initialized only
// once.
148  SharedParserData() : head(new TokenTrie()) {
149  std::stringstream ss;
150  for (const char* c = valid_single_char_tokens; *c; c++) {
151  std::string str(1, *c);
152  head->insert(str.c_str(), *c);
153  }
154 
155 #define ADD_CASE(tok, _, tokstring) \
156  if (*(tokstring) != '\0') { \
157  head->insert((tokstring), (tok)); \
158  }
159  TC_FORALL_TOKEN_KINDS(ADD_CASE)
160 #undef ADD_CASE
161  }
#ifdef _WIN32
  // Locale-independent strtod: always parses with the "C" locale so '.' is
  // the decimal separator regardless of the user's locale settings.
  // The locale object is created once and intentionally never freed.
  static double strtod_c(const char* str, char** end) {
    static _locale_t loc = _create_locale(LC_ALL, "C");
    return _strtod_l(str, end, loc);
  }
#else
  // POSIX spelling of the same locale-independent strtod (see above).
  static double strtod_c(const char* str, char** end) {
    static locale_t loc = newlocale(LC_ALL_MASK, "C", nullptr);
    return strtod_l(str, end, loc);
  }
#endif
  // the helpers below implement the lexing steps used by match():
  // 1. skip whitespace
  // 2. handle comment or newline
  //
178  bool isNumber(const std::string& str, size_t start, size_t* len) {
179  char first = str[start];
180  // strtod allows numbers to start with + or - or nan or inf
181  // http://en.cppreference.com/w/cpp/string/byte/strtof
182  // but we want only the number part, otherwise 1+3 will turn into two
183  // adjacent numbers in the lexer
184  if (first == '-' || first == '+' || isalpha(first))
185  return false;
186  const char* startptr = str.c_str() + start;
187  char* endptr;
188  strtod_c(startptr, &endptr);
189  *len = endptr - startptr;
190  return *len > 0;
191  }
192 
193  bool isCharCount(char c, const std::string& str, size_t start, int len) {
194  // count checks from [start, start + len)
195  return start + len <= str.size() &&
196  std::count(str.begin() + start, str.begin() + start + len, c) == len;
197  }
198 
  // python concatenates all adjacent strings "a" "b" == "ab"
  // strings can be enclosed with 1 or 3 single or double quotes
  // if enclosed with 3 quotes newlines are valid
  // as elsewhere, backslash and new line should be ignored
203  bool isString(const std::string& str, size_t start, size_t* len) {
204  char quote = str[start];
205  if (quote != '\"' && quote != '\'')
206  return false;
207  int quote_len = isCharCount(quote, str, start, 3) ? 3 : 1;
208 
209  // end is now set past the opening quotation marks
210  size_t end = start + quote_len;
211  while (end < str.size() && !isCharCount(quote, str, end, quote_len)) {
212  if (str[end] == '\n' && quote_len != 3) {
213  return false;
214  }
215  // handle escaped characters. advances past escaped quotation marks,
216  // escaped newlines and escaped backslashes
217  // multi-char escapes like \x1A are handled fine here because the
218  // remainder of the escape are valid string characters anyway
219  if (str[end] == '\\') {
220  end++;
221  }
222  end++;
223  }
224  // set length equal to the complete string including quotations
225  *len = end - start + quote_len;
226  // if end finished without going past the last character of the string than
227  // there is a match
228  return end < str.size();
229  }
230 
231  bool isblank(int n) {
232  return isspace(n) && n != '\n';
233  }
234  // Make an exception ignoring comments for type annotation comments
235  bool isTypeComment(const std::string& str, size_t pos) {
236  const std::string type_string = "# type:";
237  if (str.size() < pos + type_string.length()) {
238  return false;
239  }
240  auto match_string = str.substr(pos, type_string.size());
241  return match_string == type_string;
242  }
  // find the longest match of str.substring(pos) against a token, return true
  // if successful, filling in kind, start, and len
245  bool match(
246  const std::string& str,
247  size_t pos,
248  bool continuation, // are we inside a scope where newlines don't count
249  // (e.g. inside parens)
250  bool whitespace_token, // should we treat whitespace as a token
251  int* kind,
252  size_t* start,
253  size_t* len) {
254  *start = pos;
255  // skip whitespace
256  while (pos < str.size() && isblank(str[pos]))
257  pos++;
258 
259  // special handling
260  if (pos < str.size()) {
261  if (str[pos] == '#' && !isTypeComment(str, pos)) {
262  // skip comments
263  while (pos < str.size() && str[pos] != '\n')
264  pos++;
265  // tail call, handle whitespace and more comments
266  return match(
267  str, pos, continuation, whitespace_token, kind, start, len);
268  }
269  if (str[pos] == '\\' && pos + 1 < str.size() && str[pos + 1] == '\n' &&
270  !whitespace_token) {
271  return match(str, pos + 2, continuation, false, kind, start, len);
272  }
273  if (str[pos] == '\n') {
274  return match(
275  str, pos + 1, continuation, !continuation, kind, start, len);
276  }
277  }
278  // we handle white space before EOF because in the case we have something
279  // like the following where we need to generate the dedent token if foo:
280  // ...
281  // else:
282  // pass
283  if (whitespace_token) {
284  *kind = pos == str.size() ? TK_WHITESPACE_EOF : TK_WHITESPACE;
285  *len = pos - *start;
286  return true;
287  }
288  if (pos == str.size()) {
289  *kind = TK_EOF;
290  *start = pos;
291  *len = 0;
292  return true;
293  }
294  // invariant: the next token is not whitespace or newline
295  *start = pos;
296  // check for a valid number
297  if (isNumber(str, pos, len)) {
298  *kind = TK_NUMBER;
299  return true;
300  }
301  // check for string
302  if (isString(str, pos, len)) {
303  *kind = TK_STRINGLITERAL;
304  return true;
305  }
306 
307  // check for either an ident or a token
308  // ident tracks whether what we have scanned so far could be an identifier
309  // matched indicates if we have found any match.
310  bool matched = false;
311  bool ident = true;
312  TokenTrie* cur = head.get();
313  for (size_t i = 0; pos + i < str.size() && (ident || cur != nullptr); i++) {
314  ident = ident && validIdent(i, str[pos + i]);
315  if (ident) {
316  matched = true;
317  *len = i + 1;
318  *kind = TK_IDENT;
319  }
320  // check for token second, so that e.g. 'max' matches the token TK_MAX
321  // rather the
322  // identifier 'max'
323  if (cur) {
324  size_t child_offset = 0;
325  for (size_t e = cur->child_chars.size(); child_offset < e;
326  ++child_offset) {
327  if (cur->child_chars[child_offset] == str[pos + i])
328  break;
329  }
330 
331  cur = (child_offset == cur->child_chars.size())
332  ? nullptr
333  : cur->child_tries[child_offset].get();
334 
335  if (cur && cur->kind != 0) {
336  matched = true;
337  *len = i + 1;
338  *kind = cur->kind;
339  }
340  }
341  }
342  return matched;
343  }
344  bool isUnary(int kind, int* prec);
345  bool isBinary(int kind, int* prec);
346  bool isRightAssociative(int kind) {
347  switch (kind) {
348  case '?':
349  case TK_POW:
350  return true;
351  default:
352  return false;
353  }
354  }
355 
356  private:
357  bool validIdent(size_t i, char n) {
358  return isalpha(n) || n == '_' || (i > 0 && isdigit(n));
359  }
360  TokenTrieRef head;
361 };
362 
363 SharedParserData& sharedParserData();
364 
365 struct Token {
366  int kind;
367  SourceRange range;
368  Token(int kind, SourceRange range) : kind(kind), range(std::move(range)) {}
369  std::string text() {
370  return range.text();
371  }
372  std::string kindString() const {
373  return kindToString(kind);
374  }
375 };
376 
377 struct Lexer {
378  explicit Lexer(const std::string& str)
379  : file(std::make_shared<std::string>(str)),
380  pos(0),
381  nesting(0),
382  indent_stack(),
383  next_tokens(),
384  shared(sharedParserData()) {
385  auto first_indent = lexRaw(true);
386  indent_stack.push_back(first_indent.range.size());
387  lex();
388  }
389  // Return the current token, and then move to the next one
390  Token next() {
391  if (next_tokens.size() == 0)
392  reportError("Lexer invariant violated: empty token queue");
393  Token r = next_tokens.front();
394  next_tokens.erase(next_tokens.begin());
395  if (next_tokens.size() == 0) {
396  lex();
397  }
398  return r;
399  }
400  // Skip the current token if it matches the given kind
401  bool nextIf(int kind) {
402  if (cur().kind != kind)
403  return false;
404  next();
405  return true;
406  }
407 
408  [[noreturn]] void reportError(const std::string& what) {
409  reportError(what, cur());
410  }
411  [[noreturn]] void reportError(const std::string& what, const Token& t) {
412  std::stringstream ss;
413  ss << what << ":\n";
414  t.range.highlight(ss);
415  throw std::runtime_error(ss.str());
416  }
417  [[noreturn]] void expected(const std::string& what, const Token& t) {
418  std::stringstream ss;
419  ss << "expected " << what << " but found '" << t.kindString()
420  << "' here:\n";
421  t.range.highlight(ss);
422  throw std::runtime_error(ss.str());
423  }
424  [[noreturn]] void expected(const std::string& what) {
425  expected(what, cur());
426  }
427  // Check that the current token has a given kind, return the current token,
428  // and advance to the next one.
429  Token expect(int kind) {
430  if (cur().kind != kind) {
431  expected(kindToString(kind));
432  }
433  return next();
434  }
435  Token& lookahead() {
436  if (next_tokens.size() < 2) {
437  lex();
438  }
439  return next_tokens[1];
440  }
441  Token& cur() {
442  return next_tokens.front();
443  }
444 
445  private:
446  void lex() {
447  auto r = lexRaw();
448  switch (r.kind) {
449  case '(':
450  case '[':
451  case '{':
452  nesting++;
453  break;
454  case ')':
455  case ']':
456  case '}':
457  nesting--;
458  break;
459  case TK_WHITESPACE:
460  case TK_WHITESPACE_EOF: {
461  int depth =
462  r.kind == TK_WHITESPACE_EOF ? indent_stack.front() : r.range.size();
463  // note: TK_WHITESPACE_EOF is whitespace right before the EOF token
464  // just like we allow the code to be indented to a particular initial
465  // indent level, we allow the final indent to be anything and set
466  // it back to the initial indent level. This allows the code to be
467  // put into string literals inside code without worrying about final
468  // whitespace
469  if (depth > indent_stack.back()) {
470  indent_stack.push_back(depth);
471  r.kind = TK_INDENT;
472  } else if (depth == indent_stack.back()) {
473  r.kind = TK_NEWLINE;
474  } else {
475  next_tokens.emplace_back(TK_NEWLINE, r.range);
476  while (indent_stack.back() != depth) {
477  indent_stack.pop_back();
478  next_tokens.emplace_back(TK_DEDENT, r.range);
479  if (indent_stack.size() == 0) {
480  reportError("invalid indent level " + std::to_string(depth), r);
481  }
482  }
483  return; // We've already queued the tokens
484  }
485  } break;
486  default:
487  break;
488  }
489  next_tokens.push_back(std::move(r));
490  }
491  Token lexRaw(bool whitespace_token = false) {
492  int kind;
493  size_t start;
494  size_t length;
495  AT_ASSERT(file);
496  if (!shared.match(
497  *file,
498  pos,
499  nesting > 0,
500  whitespace_token,
501  &kind,
502  &start,
503  &length)) {
504  expected(
505  "a valid token",
506  Token((*file)[start], SourceRange(file, start, start + 1)));
507  }
508  auto t = Token(kind, SourceRange(file, start, start + length));
509  pos = start + length;
510  return t;
511  }
512 
513  std::shared_ptr<std::string> file;
514  size_t pos;
515  size_t nesting; // depth of ( [ { nesting...
516  std::vector<int> indent_stack; // stack of identation level of blocks
517  // Invariant: this should always contain at least a single element
518  std::vector<Token> next_tokens;
519  SharedParserData& shared;
520 };
521 } // namespace script
522 } // namespace jit
523 } // namespace torch
static double strtod_c(const char *str, char **end)
Definition: lexer.h:169
Definition: jit_type.h:17