Caffe2 - C++ API
A deep learning, cross-platform ML framework
python_print.cpp
1 #include <c10/util/Exception.h>
2 #include <torch/csrc/jit/attributes.h>
3 #include <torch/csrc/jit/export.h>
4 #include <torch/csrc/jit/ir.h>
5 #include <torch/csrc/jit/ir_views.h>
6 #include <torch/csrc/jit/passes/python_print.h>
7 #include <torch/csrc/jit/resource_guard.h>
8 #include <torch/csrc/jit/script/error_report.h>
9 #include <torch/csrc/jit/script/module.h>
10 
11 namespace torch {
12 namespace jit {
13 
14 // like the Unix isprint, but insensitive to locale
15 static bool isPrint(char s) {
16  return s > 0x1f && s < 0x7f;
17 }
18 
19 void printQuotedString(std::ostream& stmt, const std::string& str) {
20  stmt << "\"";
21  for (auto s : str) {
22  switch (s) {
23  case '\\':
24  stmt << "\\\\";
25  break;
26  case '\'':
27  stmt << "\\'";
28  break;
29  case '\"':
30  stmt << "\\\"";
31  break;
32  case '\a':
33  stmt << "\\a";
34  break;
35  case '\b':
36  stmt << "\\b";
37  break;
38  case '\f':
39  stmt << "\\f";
40  break;
41  case '\n':
42  stmt << "\\n";
43  break;
44  case '\r':
45  stmt << "\\r";
46  break;
47  case '\t':
48  stmt << "\\t";
49  break;
50  case '\v':
51  stmt << "\\v";
52  break;
53  default:
54  if (isPrint(s)) {
55  stmt << s;
56  } else {
57  // C++ io has stateful formatting settings. Messing with
58  // them is probably worse than doing this manually.
59  char buf[4] = "000";
60  buf[2] += s % 8;
61  s /= 8;
62  buf[1] += s % 8;
63  s /= 8;
64  buf[0] += s;
65  stmt << "\\" << buf;
66  }
67  break;
68  }
69  }
70  stmt << "\"";
71 }
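// Example (illustrative): for a non-printable byte such as char(1), the
// fallback branch above emits a three-digit octal escape, so a string
// containing that byte between 'a' and 'b' is written as "a\001b"
// (including the surrounding double quotes).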
72 
73 static bool isValidIdentifierChar(char c, size_t pos) {
74  return islower(c) || isupper(c) || c == '_' || (pos > 0 && isdigit(c));
75 }
76 
77 static bool isValidIdentifier(const std::string& name) {
78  if (name.size() == 0)
79  return false;
80  for (size_t i = 0; i < name.size(); ++i) {
81  if (!isValidIdentifierChar(name[i], i))
82  return false;
83  }
84  return true;
85 }
86 
87 // handles names of the form, e.g., self.a.b
88 // if a field is not a valid identifier, then it will print as, e.g.
89 // getattr(self, "0").b
90 struct QualifiedName;
91 using QualifiedNamePtr = c10::intrusive_ptr<QualifiedName>;
92 struct QualifiedName : c10::intrusive_ptr_target {
93  QualifiedName(QualifiedNamePtr prefix, std::string name)
94  : prefix_(std::move(prefix)), name_(std::move(name)) {}
95  QualifiedNamePtr prefix_;
96  std::string name_;
97  static QualifiedNamePtr create(QualifiedNamePtr prefix, std::string name) {
98  return c10::make_intrusive<QualifiedName>(
99  std::move(prefix), std::move(name));
100  }
101  static QualifiedNamePtr create(std::string name) {
102  return c10::make_intrusive<QualifiedName>(
103  QualifiedNamePtr(), std::move(name));
104  }
105  std::string str() const {
106  std::stringstream ss;
107  emit(ss);
108  return ss.str();
109  }
110 
111  private:
112  void emit(std::ostream& out) const {
113  if (isValidIdentifier(name_)) {
114  if (prefix_) {
115  prefix_->emit(out);
116  out << ".";
117  }
118  out << name_;
119  } else {
120  AT_ASSERT(prefix_);
121  out << "getattr(";
122  prefix_->emit(out);
123  out << ", ";
124  printQuotedString(out, name_);
125  out << ")";
126  }
127  }
128 };
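// Example (illustrative; the field names are hypothetical): a nested name
// built as create(create(create("self"), "layers"), "0") prints via str()
// as getattr(self.layers, "0"), since "0" is not a valid identifier.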
129 
130 void createTensorToParameterNameMap(
131  const script::Module& module,
132  const QualifiedNamePtr& prefix,
133  std::unordered_map<IValue*, QualifiedNamePtr>& result) {
134  for (const auto& elem : module.get_parameters()) {
135  const script::NamedIValue& param = elem.value();
136  result[param.slot()] = QualifiedName::create(prefix, param.name_);
137  }
138  for (const auto& elem : module.get_attributes()) {
139  const script::NamedIValue& param = elem.value();
140  result[param.slot()] = QualifiedName::create(prefix, param.name_);
141  }
142  for (const auto& elem : module.get_modules()) {
143  createTensorToParameterNameMap(
144  *elem->module, QualifiedName::create(prefix, elem.key()), result);
145  }
146 }
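// Example (illustrative; module and parameter names are hypothetical): for a
// module with a submodule "conv" owning a parameter "weight", the recursion
// above maps that parameter's slot to the qualified name self.conv.weight.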
147 
148 // some names are valid identifiers but off limits because
149 // they are keywords or namespaces used in the output
150 const static std::unordered_set<std::string> reserved_names = {
151  // identifiers in the environment while parsing
152  "_", // avoid the confusing unnamed _
153  "aten",
154  "attribute",
155  "CONSTANTS",
156  "fork",
157  "getattr",
158  "inf",
159  "nan",
160  "ops",
161  "self",
162  // the python keywords
163  "and",
164  "as",
165  "assert",
166  "async",
167  "await",
168  "break",
169  "class",
170  "continue",
171  "def",
172  "del",
173  "elif",
174  "else",
175  "except",
176  "False",
177  "finally",
178  "for",
179  "from",
180  "global",
181  "if",
182  "import",
183  "in",
184  "is",
185  "lambda",
186  "None",
187  "nonlocal",
188  "not",
189  "or",
190  "pass",
191  "raise",
192  "return",
193  "True",
194  "try",
195  "while",
196  "with",
197  "yield",
198 };
199 
200 struct PythonPrintPass {
201  std::ostream& out;
202 
203  // constants are written to this table, and are then given the name CONSTANTS.cN
204  // where N is the index into this table.
205  std::vector<at::Tensor>& tensor_table_;
206 
207  // Any classes used are written to this table, to be later written out as
208  // dependencies.
209  std::vector<ClassTypePtr>& class_table_;
210  // Helper to avoid duplicating class types
211  void addToClassTable(const ClassTypePtr& classType) {
212  if (std::find(class_table_.cbegin(), class_table_.cend(), classType) ==
213  class_table_.cend()) {
214  class_table_.push_back(classType);
215  }
216  }
217 
218  // When printing this node, is it safe to write it inline (i.e. without
219  // assigning a temporary variable)
220  std::unordered_set<Node*> output_inline_;
221 
222  // when we print this, should we error if the resulting output would
223  // not be able to be reparsed?
224  bool enforce_importable_;
225 
226  // what valid identifiers are in use for the current function
227  std::unordered_set<std::string> used_names_;
228 
229  // used method names
230  std::unordered_set<std::string> used_method_names_;
231 
232  // for fork,
233  // subgraphs get added to the worklist, and will be printed later
234  std::vector<std::function<void(void)>> worklist;
235 
236  // scanValue, scanNode, scanBlock:
237  // decide if it is safe to omit the output of a temporary variable,
238  // and inline the expression into its use
239  // we only do this if
240  // (1) it is a constant, or
241  // (2) the temporary is unnamed, is single output, is used once,
242  // and would appear in the same order when the expression tree is
243  // reparsed.
244  // The last case can be checked
245  // because when we emit an expression tree in the parser,
246  // we do a left-to-right postorder traversal of the expression tree (emit
247  // children, then emit op). The reverse of this is a right-to-left preorder
248  // traversal of the tree. By doing a right-to-left preorder traversal of the
249  // inputs of a node, while also scanning the list of emitted nodes backward,
250  // we can see if they line up with what would happen when we parse the node as
251  // an expression. While they line up we collapse them into an inline
252  // expression.
253 
254  // The inductive step is that the right-most input should be produced by the
255  // node immediately before the current node if it is in tree order.
256 
257  bool canInline(Value* v) {
258  Node* n = v->node();
259  // there must be only 1 value, otherwise we need an assignment to handle
260  // the multiple output values
261  if (n->outputs().size() != 1)
262  return false;
263  // if it is used more than once, then we need a variable
264  if (v->uses().size() != 1)
265  return false;
266  auto use = v->uses().at(0);
267  // if it has a name set, then it was written as a variable so preserve that
268  // unless it is being fed directly to the end of the block.
269  // in which case it is not as useful to give it a name just to return it
270  if (v->hasUniqueName() && use.user->kind() != prim::Return)
271  return false;
272  // don't try to inline control blocks
273  if (n->blocks().size() != 0)
274  return false;
275  // if it is a loop-carried input, we need a variable
276  // otherwise the condition or trip count may be emitted in the wrong order
277  // w.r.t. to it
278  if (use.user->kind() == prim::Loop && use.offset >= 2)
279  return false;
280  return true;
281  }
282 
283  // block_point is the current node in the reverse linear scan of the emitted
284  // nodes. v is the current value in the tree traversal that may match with
285  // block_point's output.
286  Node* scanValue(Node* block_point, Value* v) {
287  Node* n = v->node();
288  AT_ASSERT(n->kind() == prim::Constant || output_inline_.count(n) == 0);
289 
290  if (n == block_point &&
291  canInline(v)) { // the node must be at the expected point of the typical
292  // tree traversal
293  // recursively see if we can inline the inputs to this input
294  block_point = scanNode(block_point);
295  output_inline_.insert(n);
296  } else if (n->kind() == prim::Constant) {
297  // constant nodes can always be inlined, we will de-dup them on parsing
298  // and put them at the top of the function regardless
299  output_inline_.insert(n);
300  }
301  return block_point;
302  }
303  Node* previousNonConstant(Node* n) {
304  do {
305  n = n->prev();
306  } while (n->kind() == prim::Constant);
307  return n;
308  }
309 
310  Node* scanNode(Node* n) {
311  // don't bother to scan nodes we have already determined to be inline
312  if (output_inline_.count(n)) {
313  return n;
314  }
315  for (auto b : n->blocks()) {
316  scanBlock(b);
317  }
318  Node* block_point = previousNonConstant(n);
319  for (auto it = n->inputs().rbegin(), end = n->inputs().rend(); it != end;
320  ++it) {
321  block_point = scanValue(block_point, *it);
322  }
323  return block_point;
324  }
325 
326  void scanBlock(Block* b) {
327  scanNode(b->return_node());
328  for (auto node : b->nodes().reverse()) {
329  scanNode(node);
330  }
331  }
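// Example (illustrative): for a block like
//   %1 = aten::mul(%a, %b)
//   %2 = aten::add(%1, %c)
//   return (%2)
// the backward scan finds that %1 is single-output, used once, unnamed, and
// sits immediately before %2, so both are marked inline and the printer can
// later emit roughly: return torch.add(torch.mul(a, b), c)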
332 
333  size_t getOrAddTensorConstant(at::Tensor t) {
334  // XXX - N^2 warning. This code does the exact same thing as
335  // ConstantPool, which is also N^2 in the size of the constants,
336  // because it doesn't hash any information about the tensors.
337  // We will probably need to optimize this at some point using hashing.
338  for (size_t i = 0; i < tensor_table_.size(); ++i) {
339  if (t.type() == tensor_table_[i].type() && t.equal(tensor_table_[i])) {
340  return i;
341  }
342  }
343  AT_ASSERT(t.is_variable());
344  tensor_table_.emplace_back(std::move(t));
345  return tensor_table_.size() - 1;
346  }
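// Note (illustrative): the index returned here is the N used when the
// constant is later printed as "CONSTANTS.cN" (see printConstant below);
// the tensors collected in tensor_table_ are left for the caller to
// serialize alongside the printed source.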
347 
348  std::unordered_set<Node*> seen_constants;
349  void buildConstantList(Node* n, std::vector<Node*>& constants) {
350  for (auto input : n->inputs()) {
351  if (input->node()->kind() == prim::Constant &&
352  seen_constants.count(input->node()) == 0) {
353  constants.push_back(input->node());
354  seen_constants.insert(input->node());
355  }
356  }
357  for (auto b : n->blocks()) {
358  buildConstantList(b, constants);
359  }
360  }
361  void buildConstantList(Block* b, std::vector<Node*>& constants) {
362  for (auto n : b->nodes())
363  buildConstantList(n, constants);
364  buildConstantList(b->return_node(), constants);
365  }
366 
367  // get a new name unique across calls to uniqueName() and
368  // anything we have used.
369  std::unordered_map<std::string, size_t> next_id;
370 
371  std::string genNameImpl(
372  const std::string& candidate,
373  std::unordered_set<std::string>& used) {
374  std::string name = candidate;
375  while (used.count(name) || reserved_names.count(name)) {
376  name = candidate + std::to_string(next_id[name]++);
377  }
378  used.insert(name);
379  return name;
380  }
381  std::string genName(const std::string& candidate) {
382  return genNameImpl(candidate, used_names_);
383  }
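// Example (illustrative): genName("x") returns "x" if it is unused and not
// reserved; otherwise numeric suffixes are appended ("x0", "x1", ...) until
// a free name is found.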
384 
385  // methods self.foo are in a different namespace than
386  // global identifiers, so they have a different procedure for finding a
387  // unique name
388  std::string genMethodName(const std::string& candidate) {
389  return genNameImpl(candidate, used_method_names_);
390  }
391 
392  // unique names might not be valid identifiers,
393  // force them to be by rewriting them
394  static std::string makeValidIdentifier(const std::string& candidate) {
395  std::stringstream ss;
396  if (candidate.size() == 0 || isdigit(candidate[0]))
397  ss << "_";
398  for (char c : candidate) {
399  if (isupper(c) || islower(c) || isdigit(c) || c == '_')
400  ss << c;
401  else
402  ss << '_';
403  }
404  return ss.str();
405  }
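// Example (illustrative): makeValidIdentifier("2x.y") yields "_2x_y": a
// leading underscore because the name starts with a digit, and '_' in place
// of the '.' that is not a legal identifier character.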
406  // if we have to assign 'v' a name, what should it be?
407  // use the uniqueName if it was set, otherwise generate a name.
408  std::string genUniqueNameFor(Value* v) {
409  return genName(
410  v->hasUniqueName() ? makeValidIdentifier(v->uniqueNameBase()) : "_");
411  }
412 
413  // map from Value to how it should be printed at each use
414  std::unordered_map<Value*, std::string> value_names_;
415 
416  std::string useOf(Value* v) const {
417  return value_names_.at(v);
418  }
419  void assignValue(Value* v, const std::string& s) {
420  value_names_[v] = s;
421  }
422  void assignValue(Value* v, Value* w) {
423  assignValue(v, useOf(w));
424  }
425  void assignValuesToTheirUniqueNames(at::ArrayRef<Value*> values) {
426  for (auto v : values) {
427  assignValue(v, genUniqueNameFor(v));
428  }
429  }
430 
431  size_t level = 0;
432  // indent to the current indent level
433  std::ostream& indent() {
434  for (size_t i = 0; i < level; ++i) {
435  out << " ";
436  }
437  return out;
438  }
439 
440  ResourceGuard WithIndented() {
441  level++;
442  return ResourceGuard([this] { level--; });
443  }
444 
445  template <class T0, class T1, class F>
446  void zipWith(at::ArrayRef<T0> list_a, at::ArrayRef<T1> list_b, F action)
447  const {
448  auto it_a = list_a.begin();
449  auto it_b = list_b.begin();
450 
451  if (list_a.size() != list_b.size()) {
452  AT_ERROR("Python printer expected 2 lists of same size");
453  }
454 
455  for (; it_a != list_a.end(); ++it_a, ++it_b) {
456  action(*it_a, *it_b);
457  }
458  }
459 
460  void printValueList(
461  std::ostream& stmt,
462  at::ArrayRef<Value*> list,
463  const char* begin = "",
464  const char* end = "") {
465  stmt << begin;
466  auto delimiter = "";
467  for (auto* value : list) {
468  stmt << delimiter;
469  stmt << useOf(value);
470  delimiter = ", ";
471  }
472  stmt << end;
473  }
474 
475  void printDict(
476  std::ostream& stmt,
477  at::ArrayRef<Value*> key_value_pairs,
478  const char* begin = "{",
479  const char* end = "}") {
480  stmt << begin;
481  auto delimiter = "";
482  for (size_t i = 0; i < key_value_pairs.size(); i += 2) {
483  stmt << delimiter;
484  auto key = key_value_pairs[i];
485  auto value = key_value_pairs[i + 1];
486 
487  stmt << useOf(key) << ": " << useOf(value);
488 
489  delimiter = ", ";
490  }
491  stmt << end;
492  }
493 
494  void printAssignment(at::ArrayRef<Value*> lhs, at::ArrayRef<Value*> rhs) {
495  if (lhs.size() > 0) {
496  indent();
497  printValueList(out, lhs);
498  out << " = ";
499  printValueList(out, rhs);
500  out << "\n";
501  }
502  }
503 
504  void printIf(IfView stmt) {
505  assignValuesToTheirUniqueNames(stmt.outputs());
506  indent() << "if " << useOf(stmt.cond()) << ":\n";
507  {
508  auto guard = WithIndented();
509  // Print node contents
510  printBlock(stmt.thenBlock(), stmt.outputs().size() > 0);
511  printAssignment(stmt.outputs(), stmt.thenOutputs());
512  }
513  indent() << "else:\n";
514  {
515  auto guard = WithIndented();
516  printBlock(stmt.elseBlock(), stmt.outputs().size() > 0);
517  printAssignment(stmt.outputs(), stmt.elseOutputs());
518  }
519  }
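// Example (illustrative; variable names are hypothetical): an If node with
// one output prints roughly as
//   if cond:
//     y = <then-branch value>
//   else:
//     y = <else-branch value>
// where y is the unique name assigned to the node's output above.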
520 
521  // our way of encoding loops makes them difficult to turn back into python
522  // syntax. we have to check properties of the condition and trip count inputs
523  // to figure out which one it initially was
524  static bool shouldEmitAsForLoop(LoopView stmt) {
525  auto trip_count = toIValue(stmt.maxTripCount());
526  auto cond_input = toIValue(stmt.inputCond());
527  auto cond_next = toIValue(stmt.nextCond());
528 
529  bool condition_is_always_true =
530  cond_input && cond_input->toBool() && cond_next && cond_next->toBool();
531  bool trip_count_is_specified = !trip_count || // trip is not a constant
532  trip_count->toInt() !=
533  std::numeric_limits<int64_t>::max() || // it is a constant but not
534  // the default one
535  stmt.currentTripCount()->uses().size() >
536  0; // it is actually being used in the body.
537 
538  if (condition_is_always_true) {
539  // if the trip count was not specified this was a user-written while True:
540  return trip_count_is_specified;
541  } else {
542  // this must be a while loop, but check that there isn't _also_ a trip
543  // count
544  if (trip_count_is_specified) {
545  throw script::ErrorReport(stmt.node()->getSourceLocation())
546  << "loop cannot be printed as python "
547  << "because it has gone through an optimization "
548  << "that combined while and for loops. File a bug.";
549  }
550  return false;
551  }
552  }
553 
554  void printLoop(LoopView stmt) {
555  // Loop carried dependencies are handled by assigning their initial
556  // values to the node->outputs() before the loop,
557  // and assign node->outputs() to the new values at the end of each trip.
558 
559  bool emit_as_for_loop = shouldEmitAsForLoop(stmt);
560 
561  assignValuesToTheirUniqueNames(stmt.carriedOutputs());
562  // Add aliases for loop-carried dependencies
563  zipWith(
564  stmt.bodyCarriedInputs(), // Start at 1 to ignore trip count
565  stmt.carriedOutputs(),
566  [&](Value* block_input, Value* node_output) {
567  assignValue(block_input, node_output);
568  });
569 
570  // Print initial assignments of loop node outputs = loop node inputs
571  printAssignment(stmt.carriedOutputs(), stmt.carriedInputs());
572 
573  assignValuesToTheirUniqueNames(stmt.currentTripCount());
574  // Loop header
575  if (emit_as_for_loop) {
576  indent();
577  out << "for " << useOf(stmt.currentTripCount()) << " in range("
578  << useOf(stmt.maxTripCount()) << "):\n";
579  } else {
580  // note: trip_count_in_block is unused because this is a while loop,
581  // so we reuse the Value* as a stand-in for the loop condition
582  printAssignment(stmt.currentTripCount(), stmt.inputCond());
583  indent();
584  out << "while " << useOf(stmt.currentTripCount()) << ":\n";
585  }
586  // Loop body
587  {
588  ResourceGuard indent = WithIndented();
589  // Update block outputs to block inputs for next loop iteration
590  // skip the assignment to the new condition in for loops because
591  // the condition is always True
592  size_t offset = emit_as_for_loop ? 1 : 0;
593  auto body_block = stmt.bodyBlock();
594  ArrayRef<Value*> loop_carried_block_inputs =
595  body_block->inputs().slice(offset);
596  printBlock(body_block, loop_carried_block_inputs.size() > 0);
597  printAssignment(
598  loop_carried_block_inputs, body_block->outputs().slice(offset));
599  }
600  }
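// Example (illustrative): a Loop whose condition is constantly true prints
// as a for loop,
//   for i in range(max_trip_count):
//     ...
// while a genuine while loop prints as
//   cond = <initial condition>
//   while cond:
//     ...
//     cond = <next condition>
// with loop-carried values assigned before the loop and re-assigned at the
// end of each iteration body.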
601 
602  bool isLongLine(const std::string& str) {
603  return str.size() + level * 2 >= 40;
604  }
605 
606  bool isLongInline(Node* node) {
607  return output_inline_.count(node) && isLongLine(useOf(node->output()));
608  }
609 
610  bool isNonConstantInline(Value* input) {
611  return input->node()->kind() != prim::Constant &&
612  output_inline_.count(input->node());
613  }
614 
615  // [reordering of inlines]
616  // We inline anything that is semantically legal to inline, but sometimes
617  // we find that these lines get too long. In that case we break the lines
618  // and it is important that we un-inline all the inputs
619  // preceding the long input, e.g.:
620  // r = foo(x.add_(b), some_long + expression)
621  // wrong!
622  // _0 = some_long + expression
623  // r = foo(x.add_(b), _0) # wrong! _0 runs before mutating add_
624  // legal!
625  // _0 = x.add_(b)
626  // _1 = some_long + expression
627  // r = foo(_0, _1)
628  void splitLongInlines(at::ArrayRef<Value*> inputs) {
629  size_t long_inline_slice = 0;
630  // find the last input that is too long
631  for (size_t i = 0; i < inputs.size(); ++i) {
632  if (isLongInline(inputs[i]->node())) {
633  long_inline_slice = i + 1;
634  }
635  }
636  // un-inline everything through the last long line
637  // constants are ignored since long constants are never inlined in the
638  // first place
639  for (size_t i = 0; i < long_inline_slice; ++i) {
640  if (isNonConstantInline(inputs[i])) {
641  printOutputDefinition(inputs[i]->node(), useOf(inputs[i]));
642  }
643  }
644  }
645 
646  void printOutputDefinition(Node* node, const std::string& str) {
647  assignValuesToTheirUniqueNames(node->outputs());
648  indent();
649  // Print outputs
650  if (node->outputs().size() > 0) {
651  printValueList(out, node->outputs());
652  out << " = ";
653  }
654  out << str << "\n";
655  }
656 
657  // Recursively check contained types for any class dependencies
658  void registerClassDependencies(const TypePtr& type) {
659  if (const auto classType = type->cast<ClassType>()) {
660  addToClassTable(classType);
661  }
662  for (const auto& containedType : type->containedTypes()) {
663  registerClassDependencies(containedType);
664  }
665  }
666 
667  void printNode(Node* node, bool print_const) {
668  // Check for class dependencies. If this node inputs or outputs a class
669  // type, we need to add it to our table of dependencies.
670  for (const auto input : node->inputs()) {
671  registerClassDependencies(input->type());
672  }
673  for (const auto output : node->outputs()) {
674  registerClassDependencies(output->type());
675  }
676 
677  if (!print_const && node->kind() == prim::Constant)
678  return;
679  if (node->kind() == prim::PythonOp) {
680  auto value = static_cast<const PythonOp*>(node);
681  if (enforce_importable_ && value->ignore_on_export) {
682  // Op has been marked as ignored, so insert an error in its place
683  indent();
684  out << "ops.prim.IgnoredPythonOp()\n";
685  return;
686  }
687  }
688  splitLongInlines(node->inputs());
689  switch (node->kind()) {
690  case prim::Return:
691  if (enforce_importable_ && node->inputs().size() != 1) {
692  throw script::ErrorReport(node->getSourceLocation())
693  << "Exportable methods must have a single return value. "
694  << "Normal use of ScriptMethods should enforce this.";
695  }
696  if (node->inputs().size() > 0) {
697  indent();
698  out << "return ";
699  printValueList(out, node->inputs());
700  out << "\n";
701  }
702  break;
703  case prim::Loop:
704  printLoop(LoopView(node));
705  break;
706  case prim::If:
707  printIf(IfView(node));
708  break;
709  case prim::TupleUnpack:
710  case prim::ListUnpack:
711  assignValuesToTheirUniqueNames(node->outputs());
712  indent();
713  // TupleUnpack(unpacked) turns into an assignment op that forces
714  // the unpack to be inserted when parsed back in:
715  // a, b, = unpacked
716  // a, = unpacked # trailing comma forces an unpack to happen
717  if (node->outputs().size() > 0) {
718  printValueList(out, node->outputs(), "", ", = ");
719  }
720  out << useOf(node->input()) << "\n";
721  break;
722  case prim::SetAttr: {
723  const auto obj = node->inputs().at(0);
724  const auto newVal = node->inputs().at(1);
725  const auto type = obj->type()->expect<ClassType>();
726  const auto& attrname = node->s(attr::name);
727  indent();
728  out << useOf(obj) << "." << attrname << " = " << useOf(newVal) << "\n";
729  } break;
730  default:
731  std::stringstream ss;
732  printRHS(ss, node);
733 
734  // we prevent long constants from inlining here.
735  // it is not safe to do the same thing for non-constants here
736  // because of [reordering of inlines]
737  if (output_inline_.count(node) == 0 ||
738  (node->kind() == prim::Constant && isLongLine(ss.str()))) {
739  printOutputDefinition(node, ss.str());
740  } else {
741  // this node is safe to inline, so assign the output value
742  // to that expression directly
743  assignValue(node->output(), ss.str());
744  }
745  }
746  }
747 
748  void printMaybeAnnotatedConstantList(
749  std::ostream& stmt,
750  const char* the_type,
751  size_t list_size,
752  const IValue& the_list) {
753  if (list_size == 0) {
754  stmt << "annotate(List[" << the_type << "], [])";
755  } else {
756  stmt << the_list;
757  }
758  }
759 
760  void printConstant(std::ostream& stmt, const IValue& v) {
761  if (v.isTensor()) {
762  stmt << "CONSTANTS.c" << getOrAddTensorConstant(v.toTensor());
763  } else if (v.isString()) {
764  printQuotedString(stmt, v.toStringRef());
765  } else if (v.isDevice()) {
766  std::stringstream ss;
767  ss << v.toDevice();
768  stmt << "torch.device(";
769  printQuotedString(stmt, ss.str());
770  stmt << ")";
771  } else if (v.isTensorList()) {
772  stmt << "[";
773  const char* delim = "";
774  for (const auto& t : v.toTensorListRef()) {
775  stmt << delim << "CONSTANTS.c" << getOrAddTensorConstant(t);
776  delim = ", ";
777  }
778  stmt << "]";
779  } else if (v.isBoolList()) {
780  printMaybeAnnotatedConstantList(
781  stmt, "bool", v.toBoolListRef().size(), v);
782  } else if (v.isIntList()) {
783  printMaybeAnnotatedConstantList(stmt, "int", v.toIntListRef().size(), v);
784  } else if (v.isDoubleList()) {
785  printMaybeAnnotatedConstantList(
786  stmt, "float", v.toDoubleListRef().size(), v);
787  } else {
788  stmt << v;
789  }
790  }
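// Example (illustrative): a string constant prints as a quoted literal, a
// tensor as CONSTANTS.cN, a device as torch.device("cuda:0"), and an empty
// typed list as, e.g., annotate(List[int], []).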
791 
792  void printNone(std::ostream& stmt, const Node* node) {
793  if (node->output()->type()->isSubtypeOf(NoneType::get())) {
794  stmt << "None";
795  return;
796  }
797  // XXX - when None has an Optional[T] type, we must ensure that type
798  // can be recovered on parsing. It cannot be recovered if it will be
799  // matched to schema with free variables. If it is used only in places
800  // where there is a schema and the schema has no free variables, then we
801  // can recover it without annotation. Otherwise, we annotate None with
802  // the right optional type
803  const auto& uses = node->output()->uses();
804  bool all_usable_schema =
805  std::all_of(uses.begin(), uses.end(), [](const Use& u) {
806  if (auto schema = u.user->maybeSchema()) {
807  if (u.offset >= schema->arguments().size()) {
808  return false;
809  }
810  return !schema->arguments().at(u.offset).type()->hasFreeVariables();
811  }
812  return false;
813  });
814 
815  if (all_usable_schema) {
816  stmt << "None";
817  } else {
818  stmt << "annotate(" << node->output()->type()->python_str() << ", None)";
819  }
820  }
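// Example (illustrative): a None whose uses cannot all be matched against
// schema arguments of concrete type prints as, e.g.,
// annotate(Optional[Tensor], None) so the optional type survives re-parsing.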
821 
822  // Prints the RHS value of a Node, e.g. `aten.add(x, y)`
823  void printRHS(std::ostream& stmt, Node* node) {
824  switch (node->kind()) {
825  case PythonOp::Kind: {
826  auto value = static_cast<const PythonOp*>(node);
827  if (enforce_importable_) {
828  throw script::ErrorReport(node->getSourceLocation())
829  << "could not export python function call " << value->name()
830  << ". Remove calls to Python functions before export. "
831  << "Did you forget to add @script annotation? "
832  << "If this is a modulelist, add it to __constants__.";
833  }
834 
835  stmt << "^" << value->name();
836  value->writeScalars(stmt);
837  printValueList(stmt, node->inputs(), "(", ")");
838  } break;
839  case prim::Constant: {
840  if (node->kind() == prim::Constant && !node->mustBeNone()) {
841  IValue v = toIValue(node->output()).value();
842  printConstant(stmt, v);
843  } else {
844  printNone(stmt, node);
845  }
846  } break;
847  case prim::ImplicitTensorToNum: {
848  stmt << "annotate(" << node->output()->type()->python_str() << ", "
849  << useOf(node->input()) << ")";
850  } break;
851  case prim::Int: {
852  printValueList(stmt, node->inputs(), "int(", ")");
853  } break;
854  case prim::Float: {
855  printValueList(stmt, node->inputs(), "float(", ")");
856  } break;
857  case prim::Bool: {
858  printValueList(stmt, node->inputs(), "bool(", ")");
859  } break;
860  case prim::Print: {
861  printValueList(stmt, node->inputs(), "print(", ")");
862  } break;
863  case prim::TupleConstruct: {
864  printValueList(
865  stmt, node->inputs(), "(", node->inputs().size() == 1 ? ",)" : ")");
866  } break;
867  case prim::TupleIndex: {
868  stmt << "(" << useOf(node->input()) << ")[" << node->i(attr::index)
869  << "]";
870  } break;
871  case prim::TupleSlice: {
872  stmt << "(" << useOf(node->input()) << ")[" << node->i(attr::beg) << ":"
873  << node->i(attr::end) << "]";
874  } break;
875  case prim::ListConstruct: {
876  // when the list is empty and is not a list of tensors,
877  // we need to annotate it, otherwise it won't be possible
878  // to infer the type on import
879  if (node->inputs().size() == 0 &&
880  !node->output()->type()->isSubtypeOf(TensorType::get())) {
881  stmt << "annotate(" << node->output()->type()->python_str()
882  << ", [])";
883  } else {
884  printValueList(stmt, node->inputs(), "[", "]");
885  }
886  } break;
887  case prim::DictConstruct: {
888  auto dict_type = node->output()->type()->expect<DictType>();
889  bool is_default_type =
890  dict_type->getKeyType()->isSubtypeOf(StringType::get()) &&
891  dict_type->getValueType()->isSubtypeOf(TensorType::get());
892  if (node->inputs().size() == 0 && !is_default_type) {
893  stmt << "annotate(" << node->output()->type()->python_str()
894  << ", {})";
895  } else {
896  printDict(stmt, node->inputs());
897  }
898  } break;
899  case prim::DictIndex: {
900  stmt << "(" << useOf(node->inputs().at(0)) << ")["
901  << useOf(node->inputs().at(1)) << "]";
902  } break;
903  case prim::fork: {
904  // the subgraph gets emitted as another function
905  auto name = genMethodName("__forked_function");
906  std::shared_ptr<Graph> graph = node->g(attr::Subgraph);
907  worklist.emplace_back(
908  [graph, name, this] { printFunctionDefinition(*graph, name); });
909  // and we put a call to fork which invokes that function.
910  stmt << "fork(self." << name;
911  for (Value* v : node->inputs()) {
912  stmt << ", " << useOf(v);
913  }
914  stmt << ")";
915  } break;
916  case prim::Function: {
917  if (enforce_importable_) {
918  throw script::ErrorReport(node->getSourceLocation())
919  << "closures are not exportable";
920  }
921  auto name = genMethodName("__lambda");
922  std::shared_ptr<Graph> graph = node->g(attr::Subgraph);
923  worklist.emplace_back(
924  [graph, name, this] { printFunctionDefinition(*graph, name); });
925  stmt << "self." << name;
926  } break;
927  case prim::CreateObject: {
928  const auto classType = node->output()->type()->expect<ClassType>();
929  stmt << classType->name() << ".__new__(" << classType->name() << ")";
930  } break;
931  case prim::GetAttr: {
932  const auto obj = node->inputs().at(0);
933  const auto classType = obj->type()->expect<ClassType>();
934  const auto& field = node->s(attr::name);
935  stmt << useOf(obj) << "." << field;
936  } break;
937  default: {
938  Symbol kind = node->kind();
939  if (kind.is_aten()) {
940  // special case aten -> torch because we want to rename
941  // the aten namespace, but this change will take more time
942  // doing it here ensures we do not have to fix up archives later
943  stmt << "torch." << kind.toUnqualString() << "(";
944  } else {
945  stmt << "ops." << kind.ns().toUnqualString() << "."
946  << kind.toUnqualString() << "(";
947  }
948  const FunctionSchema& schema = node->schema();
949  for (size_t i = 0; i < node->inputs().size(); ++i) {
950  if (i > 0) {
951  stmt << ", ";
952  }
953  auto v = useOf(node->inputs().at(i));
954  // print the kwarg name if it is a kwarg only argument.
955  if (i < schema.arguments().size()) {
956  auto arg = schema.arguments().at(i);
957  if (arg.kwarg_only()) {
958  stmt << arg.name() << "=";
959  }
960  } else {
961  // vararg functions like format can have extra arguments
962  AT_ASSERT(schema.is_vararg());
963  }
964  stmt << v;
965  }
966  stmt << ")";
967  } break;
968  }
969  }
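// Example (illustrative): in the default case an aten operator prints as
// torch.<op>(inputs...), with "name=" prefixes added only for kwarg-only
// schema arguments, while operators in other namespaces print as
// ops.<namespace>.<op>(inputs...).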
970 
971  std::ostream& printBlock(Block* root, bool block_has_other_statements) {
972  // Python's weird 'pass' syntax creates a bunch of places where we have to
973  // check if this block would be empty. But not everything in a block is a
974  // node. Sometimes if, loop, and return statements will follow this block
975  // and block_has_other_statements == true.
976  if (!block_has_other_statements &&
977  root->nodes().begin() == root->nodes().end()) {
978  indent();
979  out << "pass\n";
980  }
981  for (auto* node : root->nodes()) {
982  printNode(node, /*print_const=*/false);
983  }
984  return out;
985  }
986 
987  void printDefaultValue(
988  const TypePtr& typ,
989  std::ostream& stmt,
990  const IValue& value) {
991  // xxx - many weak script modules store default values for broadcasting
992  // lists that are not actually the same type as the argument. We can only
993  // serialize default values that will implicitly convert to their declared
994  // return type. Since we do not need to serialize these built-in modules with
995  // their defaults, we just drop them for now.
996  if (typ->kind() == ListType::Kind &&
997  (value.isInt() || value.isDouble() || value.isBool())) {
998  return;
999  }
1000  stmt << "=";
1001  printConstant(stmt, value);
1002  }
1003  void printFunctionDefinition(
1004  Graph& graph,
1005  const std::string& name,
1006  bool is_class = false,
1007  const std::vector<c10::optional<IValue>>& defaults = {},
1008  const std::vector<std::string>& param_names = {}) {
1009  used_names_.clear(); // each graph can reuse local names
1010 
1011  // we always print constants at the top of the function, in the order
1012  // in which they are used.
1013  std::vector<Node*> constants;
1014  buildConstantList(graph.block(), constants);
1015 
1016  // current graph is used to de-dup names within a single graph
1017  scanBlock(graph.block());
1018 
1019  // last param_names.size() arguments to the graph are parameters and not
1020  // actual inputs; we will print these as, e.g., self.foo.bar,
1021  // while we print the true_inputs out as parameters
1022  auto true_inputs =
1023  graph.inputs().slice(0, graph.inputs().size() - param_names.size());
1024  auto param_names_it = param_names.begin();
1025  for (auto param : graph.inputs().slice(true_inputs.size())) {
1026  assignValue(param, *param_names_it++);
1027  }
1028  assignValuesToTheirUniqueNames(true_inputs);
1029  auto defaults_offset = defaults.begin();
1030 
1031  indent();
1032  out << "def " << name << "(";
1033 
1034  auto input_iter = true_inputs.begin();
1035  // Print the `self` argument
1036  if (is_class) {
1037  // If this is a class, print the self var without a type annotation,
1038  // following Python convention
1039  AT_ASSERT(true_inputs.size() > 0);
1040  out << useOf(*input_iter);
1041  ++input_iter;
1042 
1043  AT_ASSERT(!defaults_offset->has_value());
1044  ++defaults_offset;
1045  } else {
1046  // If this is not a class, then we need to insert a "self".
1047  out << "self";
1048  }
1049 
1050  // Print the rest of the arguments
1051  for (; input_iter != true_inputs.end(); ++input_iter) {
1052  auto input = *input_iter;
1053  out << ",\n " << useOf(input) << ": " << input->type()->python_str();
1054  if (defaults_offset != defaults.end()) {
1055  const c10::optional<IValue>& def = *defaults_offset++;
1056  if (def) {
1057  printDefaultValue(input->type(), out, *def);
1058  }
1059  }
1060  }
1061 
1062  // have we used all the provided defaults?
1063  AT_ASSERT(defaults_offset == defaults.end());
1064 
1065  out << ") -> " << resultType(graph)->python_str() << ":\n";
1066  {
1067  auto guard = WithIndented();
1068  // Print initial constant table (most are just inlined into their use,
1069  // but some like long strings do get emitted)
1070  for (Node* n : constants) {
1071  printNode(n, /*print_const=*/true);
1072  }
1073  // Print body
1074  printBlock(
1075  graph.block(), graph.block()->return_node()->inputs().size() > 0);
1076  printNode(graph.block()->return_node(), /*print_const=*/false);
1077  }
1078  }
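// Example (illustrative; the method name and argument types are
// hypothetical): the emitted definition looks roughly like
//   def forward(self,
//     x: Tensor,
//     scale: float=1.0) -> Tensor:
//     <long constants, then the printed body>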
1079 
1080  public:
1081  PythonPrintPass(
1082  std::ostream& out_,
1083  std::vector<at::Tensor>& tensor_table,
1084  std::vector<ClassTypePtr>& class_table,
1085  bool enforce_importable)
1086  : out(out_),
1087  tensor_table_(tensor_table),
1088  class_table_(class_table),
1089  enforce_importable_(enforce_importable) {}
1090 
1091  // TODO: we should consider forcing functions to return a single value
1092  // instead of handling this tuple logic both in the compiler and the printer
1093  TypePtr resultType(const Graph& graph) {
1094  if (graph.outputs().size() == 1) {
1095  return graph.outputs().at(0)->type();
1096  } else {
1097  return TupleType::create(
1098  fmap(graph.outputs(), [&](const Value* v) { return v->type(); }));
1099  }
1100  }
1101 
1102  void printFunction(
1103  Graph& graph,
1104  const std::string& name,
1105  bool is_class,
1106  const std::vector<c10::optional<IValue>>& defaults = {},
1107  const std::vector<std::string>& param_names = {}) {
1108  printFunctionDefinition(graph, name, is_class, defaults, param_names);
1109  while (!worklist.empty()) {
1110  out << "\n\n";
1111  auto work = worklist.back();
1112  worklist.pop_back();
1113  work();
1114  }
1115  }
1116  void printMethod(script::Method& method) {
1117  std::unordered_map<IValue*, QualifiedNamePtr> extra_ivalue_names;
1118  createTensorToParameterNameMap(
1119  method.owner(), QualifiedName::create("self"), extra_ivalue_names);
1120  printMethod(method, /*is_class=*/false, extra_ivalue_names);
1121  }
1122  void printMethod(
1123  script::Method& method,
1124  bool is_class,
1125  const std::unordered_map<IValue*, QualifiedNamePtr>& extra_ivalue_names) {
1126  std::vector<std::string> ivalue_names = fmap(
1127  method.initial_ivalues(),
1128  [&](IValue* slot) { return extra_ivalue_names.at(slot)->str(); });
1129  const std::string& name = method.name();
1130  Graph& graph = *method.graph();
1131  auto defaults = fmap(
1132  method.getSchema().arguments(),
1133  [](const Argument& arg) { return arg.default_value(); });
1134  printFunction(graph, name, is_class, defaults, ivalue_names);
1135  }
1136  void printModule(script::Module& module) {
1137  std::unordered_map<IValue*, QualifiedNamePtr> extra_ivalue_names;
1138  createTensorToParameterNameMap(
1139  module, QualifiedName::create("self"), extra_ivalue_names);
1140  for (auto& method : module.get_methods()) {
1141  const std::string& name = method.value()->name();
1142  // we skip __forked_functions because they actually get inlined into their
1143  // callers; exporting them again will lead to more code generated on each
1144  // export
1145  if (name.find("__forked_function") == 0) {
1146  continue;
1147  }
1148  printMethod(*method.value(), /*is_class=*/false, extra_ivalue_names);
1149  }
1150  }
1151 
1152  void printClass(const ClassTypePtr& classType) {
1153  out << "class " << classType->name() << ":\n";
1154  {
1155  const auto guard = WithIndented();
1156  std::unordered_map<IValue*, QualifiedNamePtr> extra_ivalue_names;
1157  for (auto& method : classType->methods()) {
1158  printMethod(*method, /*is_class=*/true, extra_ivalue_names);
1159  }
1160  }
1161  }
1162 };
1163 
1164 TORCH_API void PythonPrint(
1165  std::ostream& out,
1166  const Graph& graph,
1167  std::vector<at::Tensor>& tensor_table,
1168  std::vector<ClassTypePtr>& class_table,
1169  bool enforce_importable) {
1170  PythonPrintPass pp(out, tensor_table, class_table, enforce_importable);
1171  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
1172  pp.printFunction(const_cast<Graph&>(graph), "graph", /*is_class=*/false);
1173 }
1174 
1175 TORCH_API void PythonPrint(
1176  std::ostream& out,
1177  const script::Method& method,
1178  std::vector<at::Tensor>& tensor_table,
1179  std::vector<ClassTypePtr>& class_table,
1180  bool enforce_importable) {
1181  PythonPrintPass pp(out, tensor_table, class_table, enforce_importable);
1182  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
1183  pp.printMethod(const_cast<script::Method&>(method));
1184 }
1185 
1186 TORCH_API void PythonPrint(
1187  std::ostream& out,
1188  const script::Module& module,
1189  std::vector<at::Tensor>& tensor_table,
1190  std::vector<ClassTypePtr>& class_table,
1191  bool enforce_importable) {
1192  PythonPrintPass pp(out, tensor_table, class_table, enforce_importable);
1193  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
1194  pp.printModule(const_cast<script::Module&>(module));
1195 }
1196 
1197 TORCH_API void PythonPrint(
1198  std::ostream& out,
1199  const ClassTypePtr& classType,
1200  std::vector<at::Tensor>& tensor_table,
1201  std::vector<ClassTypePtr>& class_table,
1202  bool enforce_importable) {
1203  PythonPrintPass pp(out, tensor_table, class_table, enforce_importable);
1204  pp.printClass(classType);
1205 }
1206 
1207 TORCH_API bool printerHasSpecialCaseFor(Symbol sym) {
1208  // WARNING: by adding a value to this set, you are asserting
1209  // that you have also added special handling of this symbol to
1210  // the printer above. Not adding handling will cause import and export
1211  // of modules with this new operator to fail. This is only required
1212  // for operators without schema. Prefer registering your operator with
1213  // schema to editing this list here. These cases should only be things
1214  // that require special handling because they do not fit normal schema
1215  const static std::unordered_set<Symbol> handled = {
1216  prim::Constant,
1217  prim::fork,
1218  prim::ListConstruct,
1219  prim::DictConstruct,
1220  prim::ListUnpack,
1221  prim::Print,
1222  prim::PythonOp,
1223  prim::TupleConstruct,
1224  prim::TupleIndex,
1225  prim::DictIndex,
1226  prim::TupleSlice,
1227  prim::TupleUnpack,
1228  prim::CreateObject,
1229  prim::GetAttr,
1230  prim::SetAttr,
1231  };
1232 
1233  // WARNING: by adding a value to this set, you are asserting that your
1234  // primitive is only ever added during optimization and does not need
1235  // to be correctly printed for export (a process that happens before
1236  // optimization passes run)
1237  const static std::unordered_set<Symbol> unneeded = {
1238  c10::onnx::Reshape, // only used in onnx
1239  c10::onnx::Shape, // only used in onnx
1240  prim::AutogradZero, // temporarily inserted by autograd
1241  prim::AutogradAnyNonZero, // temporarily inserted by autograd
1242  prim::AutogradAdd, // temporarily inserted by autograd
1243  prim::ConstantChunk, // optimization pass adds it
1244  prim::DifferentiableGraph, // optimization pass adds it
1245  prim::BroadcastSizes, // optimization pass (fuser) adds it
1246  prim::ChunkSizes, // optimization pass (fuser) adds it
1247  prim::Drop, // used in interpreter only
1248  prim::FusedConcat, // optimization pass adds it
1249  prim::FusionGroup, // optimization pass adds it
1250  prim::Load, // used in interpreter only
1251  prim::MMTreeReduce, // used as an optimization
1252  prim::MMBatchSide, // used as an optimization
1253  prim::Store, // used in interpreter only
1254 
1255  };
1256 
1257  return handled.count(sym) || unneeded.count(sym);
1258 }
1259 
1260 } // namespace jit
1261 } // namespace torch