Caffe2 - C++ API
A deep learning, cross platform ML framework
irparser.cpp
#include <torch/csrc/jit/irparser.h>
#include <torch/csrc/jit/ir.h>
#include <torch/csrc/jit/script/lexer.h>
#include <torch/csrc/jit/script/parse_string_literal.h>
#include <torch/csrc/jit/script/schema_type_parser.h>

#include <functional>
#include <string>
#include <unordered_map>
#include <vector>
9 
10 namespace torch {
11 namespace jit {
12 namespace script {
13 
14 struct VarWithType;
15 struct ParsedLiteral;
16 
17 class IRParser {
18  friend void parseIR(const std::string& str, torch::jit::Graph* graph);
19  IRParser(const std::string& str, torch::jit::Graph* graph)
20  : L(str),
21  g(graph),
22  type_parser(L, /*parse_complete_tensor_types*/ true) {}
23 
24  std::string parseVar();
25  VarWithType parseVarWithType();
26  ParsedLiteral parseScalarLiteral(Node* n);
27 
28  void parse();
29  void parseGraphInputs();
30  void parseReturnOperator();
31 
32  void parseBlocks(Node* parentNode);
33  void parseBlock(Node* parentNode);
34  void parseBlockInputs(Block* b);
35  void parseBlockOutputs(Block* b);
36 
37  void parseOperatorsList(Block* b);
38  void parseOperator(Block* b);
39  void parseOperatorOutputs(std::vector<VarWithType>* outs);
40  std::string parseOperatorName();
41  void parseOperatorInputs(Node* n);
42  void parseAttrs(Node* n);
43  void parseAttr(Node* n);
44 
45  void parseList(
46  int begin,
47  int sep,
48  int end,
49  const std::function<void()>& callback);
50 
52  torch::jit::Graph* g = nullptr;
53  std::unordered_map<std::string, Value*> vmap;
54  SchemaTypeParser type_parser;
55 };
56 
57 struct ParsedLiteral {
58  ParsedLiteral() = default;
59 
60  AttributeKind k = AttributeKind::t;
61 
62  int64_t i = 0;
63  std::string s = "";
64  double f = 0.0;
65  std::vector<int64_t> is;
66  std::vector<std::string> ss;
67  std::vector<double> fs;
68 };
69 
70 struct VarWithType {
71  VarWithType() = default;
72  std::string name;
73  TypePtr type;
74 };
75 
76 void parseIR(const std::string& str, torch::jit::Graph* graph) {
77  torch::jit::script::IRParser p(str, graph);
78  p.parse();
79 }
80 
81 VarWithType IRParser::parseVarWithType() {
82  VarWithType r;
83  r.name = parseVar();
84  r.type = TensorType::get();
85  if (L.nextIf(':')) {
86  auto type_alias = type_parser.parseType();
87  AT_ASSERTM(!type_alias.second, "Parsing IR with Alias Info not handled");
88  r.type = type_alias.first;
89  }
90  return r;
91 }
92 
93 std::string IRParser::parseVar() {
94  L.expect('%');
95  if (L.cur().kind == TK_IDENT) {
96  auto name = L.expect(TK_IDENT).text();
97  if (L.cur().kind == TK_NUMBER) {
98  auto suffix = L.expect(TK_NUMBER).text();
99  AT_ASSERT(suffix[0] == '.');
100  name += suffix;
101  }
102  return name;
103  } else {
104  return L.expect(TK_NUMBER).text();
105  }
106 }
107 
108 void IRParser::parseOperatorOutputs(std::vector<VarWithType>* outs) {
109  if (L.cur().kind != '%') {
110  return;
111  }
112  parseList(TK_NOTHING, ',', TK_NOTHING, [&] {
113  outs->push_back(parseVarWithType());
114  });
115  L.expect('=');
116 }
117 
118 // Parse string or numeric literal and return it along with its type.
119 ParsedLiteral IRParser::parseScalarLiteral(Node* n) {
120  auto token = L.cur();
121  std::string str;
122  ParsedLiteral r;
123  switch (token.kind) {
124  case TK_STRINGLITERAL:
125  r.k = AttributeKind::s;
126  r.s = parseStringLiteral(token.range, token.text());
127  L.next();
128  return r;
129  case '-':
130  str = "-";
131  L.next();
132  L.expect(TK_NUMBER);
133  // Fallthrough
134  case TK_NUMBER:
135  str += L.cur().text();
136 
137  if (str.find('.') != std::string::npos ||
138  str.find('e') != std::string::npos) {
139  r.k = AttributeKind::f;
140  r.f = std::stod(str);
141  } else {
142  r.k = AttributeKind::i;
143  r.i = std::stoll(str);
144  }
145  L.next();
146  return r;
147  default:
148  throw ErrorReport(token.range)
149  << "Could not parse literal" << token.text();
150  }
151 }
152 
165 void IRParser::parseAttr(Node* n) {
166  std::string attrname = L.expect(TK_IDENT).text();
167  L.expect('=');
168  if (L.cur().kind == '[') {
169  // list
170  AttributeKind k = AttributeKind::ts;
171  std::vector<int64_t> is;
172  std::vector<std::string> ss;
173  std::vector<double> fs;
174  int elem_num = 0;
175  parseList('[', ',', ']', [&] {
176  ParsedLiteral r = parseScalarLiteral(n);
177  switch (r.k) {
178  case AttributeKind::s:
179  ss.push_back(r.s);
180  AT_ASSERT(!elem_num++ || k == AttributeKind::ss);
181  k = AttributeKind::ss;
182  break;
183  case AttributeKind::i:
184  is.push_back(r.i);
185  AT_ASSERT(!elem_num++ || k == AttributeKind::is);
186  k = AttributeKind::is;
187  break;
188  case AttributeKind::f:
189  fs.push_back(r.f);
190  AT_ASSERT(!elem_num++ || k == AttributeKind::fs);
191  k = AttributeKind::fs;
192  break;
193  default:
194  throw ErrorReport(L.cur().range) << "Unexpected attr type";
195  }
196  });
197  switch (k) {
198  case AttributeKind::ts:
199  n->ts_(Symbol::attr(attrname), {});
200  break;
201  case AttributeKind::ss:
202  n->ss_(Symbol::attr(attrname), ss);
203  break;
204  case AttributeKind::fs:
205  n->fs_(Symbol::attr(attrname), fs);
206  break;
207  case AttributeKind::is:
208  n->is_(Symbol::attr(attrname), is);
209  break;
210  default:
211  throw ErrorReport(L.cur().range) << "Unexpected attr type";
212  }
213  } else {
214  // scalar
215  ParsedLiteral r = parseScalarLiteral(n);
216  switch (r.k) {
217  case AttributeKind::s:
218  n->s_(Symbol::attr(attrname), r.s);
219  break;
220  case AttributeKind::i:
221  n->i_(Symbol::attr(attrname), r.i);
222  break;
223  case AttributeKind::f:
224  n->f_(Symbol::attr(attrname), r.f);
225  break;
226  default:
227  throw ErrorReport(L.cur().range) << "Unexpected attr type";
228  }
229  return;
230  }
231 }
232 
233 void IRParser::parseAttrs(Node* n) {
234  parseList('[', ',', ']', [&] { parseAttr(n); });
235 }
236 
237 void IRParser::parseOperatorInputs(Node* n) {
238  if (L.cur().kind == '[') {
239  parseAttrs(n);
240  }
241  parseList('(', ',', ')', [&] {
242  std::string var_name = parseVar();
243  AT_ASSERT(vmap.count(var_name));
244  n->addInput(vmap[var_name]);
245  });
246 }
247 
248 void IRParser::parseBlocks(Node* parentNode) {
249  L.expect(TK_INDENT);
250  while (L.cur().kind != TK_DEDENT) {
251  parseBlock(parentNode);
252  }
253  L.expect(TK_DEDENT);
254 }
255 
256 void IRParser::parseBlockInputs(Block* b) {
257  parseList('(', ',', ')', [&] {
258  VarWithType v = parseVarWithType();
259  // If the name isn't valid, don't use it
260  std::string uniq_name = Value::isValidName(v.name) ? v.name : "";
261  vmap[v.name] = b->addInput(uniq_name);
262  vmap[v.name]->setType(v.type);
263  });
264 }
265 
266 void IRParser::parseBlockOutputs(Block* b) {
267  L.expect(TK_ARROW);
268  parseList('(', ',', ')', [&] {
269  std::string var_name = parseVar();
270  AT_ASSERT(vmap.count(var_name));
271  b->registerOutput(vmap[var_name]);
272  });
273  L.expect(TK_NEWLINE);
274  L.expect(TK_DEDENT);
275 }
276 
287 void IRParser::parseBlock(Node* parentNode) {
288  Block* b = parentNode->addBlock();
289  L.expect(TK_IDENT).text(); // Block name is not used anywhere.
290  parseBlockInputs(b);
291  L.expect(':');
292  parseOperatorsList(b);
293  parseBlockOutputs(b);
294 }
295 
301 void IRParser::parseOperatorsList(Block* b) {
302  L.expect(TK_INDENT);
303  while (L.cur().kind != TK_ARROW && L.cur().kind != TK_RETURN) {
304  parseOperator(b);
305  }
306 }
307 
308 std::string IRParser::parseOperatorName() {
309  std::string name = L.expect(TK_IDENT).text();
310  L.expect(':');
311  L.expect(':');
312  name += "::" + L.expect(TK_IDENT).text();
313  return name;
314 }
315 
323 void IRParser::parseOperator(Block* b) {
324  // Parse lefthand side.
325  std::vector<VarWithType> outs;
326  parseOperatorOutputs(&outs);
327 
328  // Parse the name and create the corresponding node in the graph.
329  std::string name = parseOperatorName();
330  Node* n = g->create(Symbol::fromQualString(name), {}, outs.size());
331 
332  // Parse attributes and inputs.
333  parseOperatorInputs(n);
334 
335  // Register outputs.
336  int idx = 0;
337  for (const VarWithType& v : outs) {
338  vmap[v.name] = n->outputs()[idx++];
339  vmap[v.name]->setType(v.type);
340  }
341 
342  // Insert the new node into block B.
343  b->appendNode(n);
344 
345  // If the statement has nested blocks, parse them:
346  if (L.cur().kind == TK_INDENT) {
347  parseBlocks(n);
348  }
349  L.nextIf(TK_NEWLINE);
350 }
351 
352 void IRParser::parseGraphInputs() {
353  parseList('(', ',', ')', [&] {
354  VarWithType v = parseVarWithType();
355  // If the name isn't valid, don't use it
356  std::string uniq_name = Value::isValidName(v.name) ? v.name : "";
357  vmap[v.name] = g->addInput(uniq_name);
358  vmap[v.name]->setType(v.type);
359  });
360 }
361 
367 void IRParser::parseReturnOperator() {
368  L.expect(TK_RETURN);
369 
370  // Parse output names and types
371  parseList('(', ',', ')', [&] {
372  std::string var_name = parseVar();
373  // Outputs should already be in VMAP, otherwise we're trying to return
374  // undefined value.
375  AT_ASSERT(vmap.count(var_name));
376  g->registerOutput(vmap.at(var_name));
377  });
378 
379  // Consume ending tokens
380  if (L.cur().kind != TK_EOF) {
381  L.expect(TK_NEWLINE);
382  L.expect(TK_DEDENT);
383  }
384 }
385 
396 void IRParser::parse() {
397  // Parse graph definition, it should look like the following:
398  // graphName (input1, input2, ... inputN):
399  std::string graphName = L.expect(TK_IDENT).text();
400  parseGraphInputs();
401  L.expect(':');
402 
403  // After the definition we should have a list of statements, parse it:
404  parseOperatorsList(g->block());
405 
406  // The last statement should be return, which specifies graph outputs
407  parseReturnOperator();
408 }
409 
410 void IRParser::parseList(
411  int begin,
412  int sep,
413  int end,
414  const std::function<void()>& callback) {
415  if (begin != TK_NOTHING) {
416  L.expect(begin);
417  }
418  if (L.cur().kind != end) {
419  do {
420  callback();
421  } while (L.nextIf(sep));
422  }
423  if (end != TK_NOTHING) {
424  L.expect(end);
425  }
426 }
427 } // namespace script
428 } // namespace jit
429 } // namespace torch
Definition: jit_type.h:17