2 #include <c10/util/Exception.h> 3 #include <torch/csrc/jit/source_range.h> 4 #include <torch/csrc/utils/memory.h> 11 #include <unordered_map> 27 #define TC_FORALL_TOKEN_KINDS(_) \ 28 _(TK_EOF, "eof", "") \ 29 _(TK_WHITESPACE, "whitespace", "") \ 30 _(TK_WHITESPACE_EOF, "whitespace_eof", "") \ 31 _(TK_NUMBER, "number", "") \ 32 _(TK_NEWLINE, "newline", "") \ 33 _(TK_INDENT, "indent", "") \ 34 _(TK_DEDENT, "dedent", "") \ 35 _(TK_DEF, "def", "def") \ 36 _(TK_EQUIVALENT, "equivalent", "<=>") \ 37 _(TK_IDENT, "ident", "") \ 38 _(TK_STRING, "string", "") \ 39 _(TK_STRINGLITERAL, "string_literal", "") \ 40 _(TK_CONST, "const", "") \ 41 _(TK_LIST, "list", "") \ 42 _(TK_DICT, "dict", "") \ 43 _(TK_OPTION, "option", "") \ 44 _(TK_APPLY, "apply", "") \ 45 _(TK_COMPREHENSION, "comprehension", "") \ 46 _(TK_RANGE_CONSTRAINT, "range_constraint", "") \ 47 _(TK_PARAM, "param", "") \ 48 _(TK_INFERRED, "inferred", "") \ 49 _(TK_ACCESS, "access", "") \ 50 _(TK_ASSIGN, "assign", "") \ 51 _(TK_AUG_ASSIGN, "aug_assign", "") \ 52 _(TK_ATTRIBUTE, "attribute", "") \ 53 _(TK_IF, "if", "if") \ 54 _(TK_ELSE, "else", "else") \ 55 _(TK_ELIF, "elif", "elif") \ 56 _(TK_WHILE, "while", "while") \ 57 _(TK_EXPR_STMT, "expression statement", "") \ 58 _(TK_RETURN, "return", "return") \ 59 _(TK_IS, "is", "is") \ 60 _(TK_ISNOT, "is not", "is not") \ 61 _(TK_NE, "ne", "!=") \ 62 _(TK_EQ, "eq", "==") \ 63 _(TK_LE, "le", "<=") \ 64 _(TK_GE, "ge", ">=") \ 65 _(TK_FLOOR_DIV, "floordiv", "//") \ 66 _(TK_IF_EXPR, "if", "") \ 67 _(TK_TRUE, "True", "True") \ 68 _(TK_FALSE, "False", "False") \ 69 _(TK_NONE, "None", "None") \ 70 _(TK_AND, "and", "and") \ 71 _(TK_OR, "or", "or") \ 72 _(TK_NOT, "not", "not") \ 73 _(TK_CAST, "cast", "") \ 74 _(TK_PLUS_EQ, "+=", "+=") \ 75 _(TK_MINUS_EQ, "-=", "-=") \ 76 _(TK_TIMES_EQ, "*=", "*=") \ 77 _(TK_DIV_EQ, "/=", "/=") \ 78 _(TK_GLOBAL, "global", "global") \ 79 _(TK_BUILT_IN, "built-in", "") \ 80 _(TK_SUBSCRIPT, "subscript", "") \ 81 _(TK_VAR, "variable", "") \ 82 
_(TK_NOTHING, "nothing", "") \ 83 _(TK_DICT_LITERAL, "dict-literal", "") \ 84 _(TK_LIST_LITERAL, "list-literal", "") \ 85 _(TK_TUPLE_LITERAL, "tuple-literal", "") \ 86 _(TK_FOR, "for", "for") \ 87 _(TK_IN, "in", "in") \ 88 _(TK_STARRED, "starred", "") \ 89 _(TK_UNARY_MINUS, "unary minus", "") \ 90 _(TK_POW, "pow operator", "**") \ 91 _(TK_ARROW, "arrow", "->") \ 92 _(TK_DECL, "decl", "") \ 93 _(TK_SLICE_EXPR, "slice expr", "") \ 94 _(TK_TYPE_COMMENT, "type comment", "# type:") \ 95 _(TK_RAISE, "raise", "raise") \ 96 _(TK_ASSERT, "assert", "assert") \ 97 _(TK_DOTS, "dots", "...") \ 98 _(TK_PASS, "pass", "pass") \ 99 _(TK_CLASS_DEF, "class", "class") 101 static const char* valid_single_char_tokens =
"+-*/%@()[]:,={}><.?!&^|";
107 TK_DUMMY_START = 256,
108 #define DEFINE_TOKEN(tok, _, _2) tok, 109 TC_FORALL_TOKEN_KINDS(DEFINE_TOKEN)
113 std::string kindToString(
int kind);
114 int stringToKind(
const std::string& str);
118 using TokenTrieRef = std::unique_ptr<TokenTrie>;
121 void insert(
const char* str,
int tok) {
123 AT_ASSERT(kind == 0);
128 for (
size_t i = 0, e = child_chars.size(); i < e; ++i) {
129 if (child_chars[i] == *str) {
130 child_tries[i]->insert(str + 1, tok);
135 child_chars.emplace_back(*str);
136 child_tries.emplace_back(torch::make_unique<TokenTrie>());
137 child_tries.back()->insert(str + 1, tok);
141 std::vector<char> child_chars;
142 std::vector<TokenTrieRef> child_tries;
149 std::stringstream ss;
150 for (
const char* c = valid_single_char_tokens; *c; c++) {
151 std::string str(1, *c);
152 head->insert(str.c_str(), *c);
155 #define ADD_CASE(tok, _, tokstring) \ 156 if (*(tokstring) != '\0') { \ 157 head->insert((tokstring), (tok)); \ 159 TC_FORALL_TOKEN_KINDS(ADD_CASE)
163 static double strtod_c(
const char* str,
char** end) {
165 static _locale_t loc = _create_locale(LC_ALL,
"C");
166 return _strtod_l(str, end, loc);
static double strtod_c(const char* str, char** end) {
  // POSIX variant: parse a double using a "C" locale object rather than the
  // process-wide locale. The locale handle is created once on first use
  // (function-local static) and is never freed here.
  static locale_t c_locale = newlocale(LC_ALL_MASK, "C", nullptr);
  return strtod_l(str, end, c_locale);
}
178 bool isNumber(
const std::string& str,
size_t start,
size_t* len) {
179 char first = str[start];
184 if (first ==
'-' || first ==
'+' || isalpha(first))
186 const char* startptr = str.c_str() + start;
188 strtod_c(startptr, &endptr);
189 *len = endptr - startptr;
bool isCharCount(char c, const std::string& str, size_t start, int len) {
  // Reports whether the window [start, start + len) lies entirely inside
  // `str` and every character in it equals `c`. isString uses this to tell a
  // triple quote from a single quote.
  if (start + len > str.size()) {
    return false;
  }
  const auto window_begin = str.begin() + start;
  return std::count(window_begin, window_begin + len, c) == len;
}
203 bool isString(
const std::string& str,
size_t start,
size_t* len) {
204 char quote = str[start];
205 if (quote !=
'\"' && quote !=
'\'')
207 int quote_len = isCharCount(quote, str, start, 3) ? 3 : 1;
210 size_t end = start + quote_len;
211 while (end < str.size() && !isCharCount(quote, str, end, quote_len)) {
212 if (str[end] ==
'\n' && quote_len != 3) {
219 if (str[end] ==
'\\') {
225 *len = end - start + quote_len;
228 return end < str.size();
231 bool isblank(
int n) {
232 return isspace(n) && n !=
'\n';
235 bool isTypeComment(
const std::string& str,
size_t pos) {
236 const std::string type_string =
"# type:";
237 if (str.size() < pos + type_string.length()) {
240 auto match_string = str.substr(pos, type_string.size());
241 return match_string == type_string;
246 const std::string& str,
250 bool whitespace_token,
256 while (pos < str.size() && isblank(str[pos]))
260 if (pos < str.size()) {
261 if (str[pos] ==
'#' && !isTypeComment(str, pos)) {
263 while (pos < str.size() && str[pos] !=
'\n')
267 str, pos, continuation, whitespace_token, kind, start, len);
269 if (str[pos] ==
'\\' && pos + 1 < str.size() && str[pos + 1] ==
'\n' &&
271 return match(str, pos + 2, continuation,
false, kind, start, len);
273 if (str[pos] ==
'\n') {
275 str, pos + 1, continuation, !continuation, kind, start, len);
283 if (whitespace_token) {
284 *kind = pos == str.size() ? TK_WHITESPACE_EOF : TK_WHITESPACE;
288 if (pos == str.size()) {
297 if (isNumber(str, pos, len)) {
302 if (isString(str, pos, len)) {
303 *kind = TK_STRINGLITERAL;
310 bool matched =
false;
313 for (
size_t i = 0; pos + i < str.size() && (ident || cur !=
nullptr); i++) {
314 ident = ident && validIdent(i, str[pos + i]);
324 size_t child_offset = 0;
325 for (
size_t e = cur->child_chars.size(); child_offset < e;
327 if (cur->child_chars[child_offset] == str[pos + i])
331 cur = (child_offset == cur->child_chars.size())
333 : cur->child_tries[child_offset].get();
335 if (cur && cur->kind != 0) {
344 bool isUnary(
int kind,
int* prec);
345 bool isBinary(
int kind,
int* prec);
346 bool isRightAssociative(
int kind) {
bool validIdent(size_t i, char n) {
  // Whether character `n` may appear at offset `i` of an identifier: letters
  // and '_' anywhere, digits anywhere except the first position.
  // `n` is a plain char and may be negative for non-ASCII bytes; passing a
  // negative value (other than EOF) to isalpha/isdigit is undefined
  // behaviour, so classify via unsigned char.
  const auto ch = static_cast<unsigned char>(n);
  return isalpha(ch) || n == '_' || (i > 0 && isdigit(ch));
}
372 std::string kindString()
const {
373 return kindToString(kind);
378 explicit Lexer(
const std::string& str)
379 : file(std::make_shared<std::string>(str)),
384 shared(sharedParserData()) {
385 auto first_indent = lexRaw(
true);
386 indent_stack.push_back(first_indent.range.size());
391 if (next_tokens.size() == 0)
392 reportError(
"Lexer invariant violated: empty token queue");
393 Token r = next_tokens.front();
394 next_tokens.erase(next_tokens.begin());
395 if (next_tokens.size() == 0) {
401 bool nextIf(
int kind) {
402 if (cur().kind != kind)
408 [[noreturn]]
void reportError(
const std::string& what) {
409 reportError(what, cur());
411 [[noreturn]]
void reportError(
const std::string& what,
const Token& t) {
412 std::stringstream ss;
414 t.range.highlight(ss);
415 throw std::runtime_error(ss.str());
417 [[noreturn]]
void expected(
const std::string& what,
const Token& t) {
418 std::stringstream ss;
419 ss <<
"expected " << what <<
" but found '" << t.kindString()
421 t.range.highlight(ss);
422 throw std::runtime_error(ss.str());
424 [[noreturn]]
void expected(
const std::string& what) {
425 expected(what, cur());
429 Token expect(
int kind) {
430 if (cur().kind != kind) {
431 expected(kindToString(kind));
436 if (next_tokens.size() < 2) {
439 return next_tokens[1];
442 return next_tokens.front();
460 case TK_WHITESPACE_EOF: {
462 r.kind == TK_WHITESPACE_EOF ? indent_stack.front() : r.range.size();
469 if (depth > indent_stack.back()) {
470 indent_stack.push_back(depth);
472 }
else if (depth == indent_stack.back()) {
475 next_tokens.emplace_back(TK_NEWLINE, r.range);
476 while (indent_stack.back() != depth) {
477 indent_stack.pop_back();
478 next_tokens.emplace_back(TK_DEDENT, r.range);
479 if (indent_stack.size() == 0) {
480 reportError(
"invalid indent level " + std::to_string(depth), r);
489 next_tokens.push_back(std::move(r));
491 Token lexRaw(
bool whitespace_token =
false) {
509 pos = start + length;
513 std::shared_ptr<std::string> file;
516 std::vector<int> indent_stack;
518 std::vector<Token> next_tokens;
static double strtod_c(const char *str, char **end)