1 #include <torch/csrc/jit/irparser.h> 2 #include <torch/csrc/jit/ir.h> 3 #include <torch/csrc/jit/script/lexer.h> 4 #include <torch/csrc/jit/script/parse_string_literal.h> 5 #include <torch/csrc/jit/script/schema_type_parser.h> 22 type_parser(L,
true) {}
24 std::string parseVar();
29 void parseGraphInputs();
30 void parseReturnOperator();
32 void parseBlocks(
Node* parentNode);
33 void parseBlock(
Node* parentNode);
34 void parseBlockInputs(
Block* b);
35 void parseBlockOutputs(
Block* b);
37 void parseOperatorsList(
Block* b);
38 void parseOperator(
Block* b);
39 void parseOperatorOutputs(std::vector<VarWithType>* outs);
40 std::string parseOperatorName();
41 void parseOperatorInputs(
Node* n);
42 void parseAttrs(
Node* n);
43 void parseAttr(
Node* n);
49 const std::function<
void()>& callback);
53 std::unordered_map<std::string, Value*> vmap;
60 AttributeKind k = AttributeKind::t;
65 std::vector<int64_t> is;
66 std::vector<std::string> ss;
67 std::vector<double> fs;
84 r.type = TensorType::get();
86 auto type_alias = type_parser.parseType();
87 AT_ASSERTM(!type_alias.second,
"Parsing IR with Alias Info not handled");
88 r.type = type_alias.first;
93 std::string IRParser::parseVar() {
95 if (L.cur().kind == TK_IDENT) {
96 auto name = L.expect(TK_IDENT).text();
97 if (L.cur().kind == TK_NUMBER) {
98 auto suffix = L.expect(TK_NUMBER).text();
99 AT_ASSERT(suffix[0] ==
'.');
104 return L.expect(TK_NUMBER).text();
108 void IRParser::parseOperatorOutputs(std::vector<VarWithType>* outs) {
109 if (L.cur().kind !=
'%') {
112 parseList(TK_NOTHING,
',', TK_NOTHING, [&] {
113 outs->push_back(parseVarWithType());
120 auto token = L.cur();
123 switch (token.kind) {
124 case TK_STRINGLITERAL:
125 r.k = AttributeKind::s;
126 r.s = parseStringLiteral(token.range, token.text());
135 str += L.cur().text();
137 if (str.find(
'.') != std::string::npos ||
138 str.find(
'e') != std::string::npos) {
139 r.k = AttributeKind::f;
140 r.f = std::stod(str);
142 r.k = AttributeKind::i;
143 r.i = std::stoll(str);
149 <<
"Could not parse literal" << token.text();
165 void IRParser::parseAttr(
Node* n) {
166 std::string attrname = L.expect(TK_IDENT).text();
168 if (L.cur().kind ==
'[') {
170 AttributeKind k = AttributeKind::ts;
171 std::vector<int64_t> is;
172 std::vector<std::string> ss;
173 std::vector<double> fs;
175 parseList(
'[',
',',
']', [&] {
178 case AttributeKind::s:
180 AT_ASSERT(!elem_num++ || k == AttributeKind::ss);
181 k = AttributeKind::ss;
183 case AttributeKind::i:
185 AT_ASSERT(!elem_num++ || k == AttributeKind::is);
186 k = AttributeKind::is;
188 case AttributeKind::f:
190 AT_ASSERT(!elem_num++ || k == AttributeKind::fs);
191 k = AttributeKind::fs;
194 throw ErrorReport(L.cur().range) <<
"Unexpected attr type";
198 case AttributeKind::ts:
199 n->ts_(Symbol::attr(attrname), {});
201 case AttributeKind::ss:
202 n->ss_(Symbol::attr(attrname), ss);
204 case AttributeKind::fs:
205 n->fs_(Symbol::attr(attrname), fs);
207 case AttributeKind::is:
208 n->is_(Symbol::attr(attrname), is);
211 throw ErrorReport(L.cur().range) <<
"Unexpected attr type";
217 case AttributeKind::s:
218 n->s_(Symbol::attr(attrname), r.s);
220 case AttributeKind::i:
221 n->i_(Symbol::attr(attrname), r.i);
223 case AttributeKind::f:
224 n->f_(Symbol::attr(attrname), r.f);
227 throw ErrorReport(L.cur().range) <<
"Unexpected attr type";
233 void IRParser::parseAttrs(
Node* n) {
234 parseList(
'[',
',',
']', [&] { parseAttr(n); });
237 void IRParser::parseOperatorInputs(
Node* n) {
238 if (L.cur().kind ==
'[') {
241 parseList(
'(',
',',
')', [&] {
242 std::string var_name = parseVar();
243 AT_ASSERT(vmap.count(var_name));
244 n->addInput(vmap[var_name]);
248 void IRParser::parseBlocks(
Node* parentNode) {
250 while (L.cur().kind != TK_DEDENT) {
251 parseBlock(parentNode);
256 void IRParser::parseBlockInputs(
Block* b) {
257 parseList(
'(',
',',
')', [&] {
260 std::string uniq_name = Value::isValidName(v.name) ? v.name :
"";
261 vmap[v.name] = b->addInput(uniq_name);
262 vmap[v.name]->setType(v.type);
266 void IRParser::parseBlockOutputs(
Block* b) {
268 parseList(
'(',
',',
')', [&] {
269 std::string var_name = parseVar();
270 AT_ASSERT(vmap.count(var_name));
271 b->registerOutput(vmap[var_name]);
273 L.expect(TK_NEWLINE);
287 void IRParser::parseBlock(
Node* parentNode) {
288 Block* b = parentNode->addBlock();
289 L.expect(TK_IDENT).text();
292 parseOperatorsList(b);
293 parseBlockOutputs(b);
301 void IRParser::parseOperatorsList(
Block* b) {
303 while (L.cur().kind != TK_ARROW && L.cur().kind != TK_RETURN) {
308 std::string IRParser::parseOperatorName() {
309 std::string name = L.expect(TK_IDENT).text();
312 name +=
"::" + L.expect(TK_IDENT).text();
323 void IRParser::parseOperator(
Block* b) {
325 std::vector<VarWithType> outs;
326 parseOperatorOutputs(&outs);
329 std::string name = parseOperatorName();
330 Node* n = g->create(Symbol::fromQualString(name), {}, outs.size());
333 parseOperatorInputs(n);
338 vmap[v.name] = n->outputs()[idx++];
339 vmap[v.name]->setType(v.type);
346 if (L.cur().kind == TK_INDENT) {
349 L.nextIf(TK_NEWLINE);
352 void IRParser::parseGraphInputs() {
353 parseList(
'(',
',',
')', [&] {
356 std::string uniq_name = Value::isValidName(v.name) ? v.name :
"";
357 vmap[v.name] = g->addInput(uniq_name);
358 vmap[v.name]->setType(v.type);
367 void IRParser::parseReturnOperator() {
371 parseList(
'(',
',',
')', [&] {
372 std::string var_name = parseVar();
375 AT_ASSERT(vmap.count(var_name));
376 g->registerOutput(vmap.at(var_name));
380 if (L.cur().kind != TK_EOF) {
381 L.expect(TK_NEWLINE);
396 void IRParser::parse() {
399 std::string graphName = L.expect(TK_IDENT).text();
404 parseOperatorsList(g->block());
407 parseReturnOperator();
410 void IRParser::parseList(
414 const std::function<
void()>& callback) {
415 if (begin != TK_NOTHING) {
418 if (L.cur().kind != end) {
421 }
while (L.nextIf(sep));
423 if (end != TK_NOTHING) {