Caffe2 - C++ API
A deep learning, cross-platform ML framework
ir.cpp
1 #include <torch/csrc/jit/ir.h>
2 
3 #include <c10/util/Exception.h>
4 #include <torch/csrc/jit/constants.h>
5 #include <torch/csrc/jit/operator.h>
6 #include <torch/csrc/jit/passes/python_print.h>
7 #include <torch/csrc/jit/script/schema_matching.h>
8 
9 #include <algorithm>
10 #include <iostream>
11 #include <set>
12 #include <sstream>
13 #include <stack>
14 #include <string>
15 #include <unordered_map>
16 #include <unordered_set>
17 #include <utility>
18 
19 namespace torch {
20 namespace jit {
21 
22 void printQuotedString(std::ostream& stmt, const std::string& str);
23 
24 // Constants relating to maintaining the topological index of nodes.
25 //
26 // Lower and upper bounds of the index. Inclusive range.
27 static constexpr topo_position_t kLowerBound = INT64_MIN;
28 static constexpr topo_position_t kUpperBound = INT64_MAX;
29 static constexpr topo_position_t kMidPoint = 0;
30 // How far away to space nodes that are appended to the graph.
31 // should be 2^n, where:
32 // - n is the maximum number of repeated insertions without a re-index
33 // - 2^(64-n) is the maximum number of appends to the end without reindex
34 static constexpr topo_position_t kAppendInterval = 1099511627776ULL /* 2^40 */;
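// Illustrative arithmetic for the chosen value (not part of the original
// listing): with kAppendInterval = 2^40, roughly 40 in-place insertions can
// keep halving a gap before a re-index is forced, and about 2^(64-40) = 2^24
// (~16.7 million) plain appends fit before running off the end of the
// int64_t range.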
35 
36 // Sigh, see
37 // https://stackoverflow.com/questions/8016780/undefined-reference-to-static-constexpr-char
38 constexpr Symbol PythonOp::Kind;
39 
40 static void printValueRef(std::ostream& out, const Value* n) {
41  out << "%" << n->uniqueName();
42 }
43 
44 // NB: This overload will become ambiguous with the one Caffe2 provides in its
45 // logging, if they ever intersect.
46 template <typename T>
47 std::ostream& operator<<(std::ostream& out, const std::vector<T>& nodes) {
48  out << at::ArrayRef<T>{nodes};
49  return out;
50 }
51 
52 template <typename T>
53 static std::ostream& printValueRefs(
54  std::ostream& out,
55  const at::ArrayRef<T>& nodes) {
56  size_t i = 0;
57  for (auto n : nodes) {
58  if (i++ > 0) {
59  out << ", ";
60  }
61  printValueRef(out, n);
62  }
63  return out;
64 }
65 
66 // Can't make these two overloads directly a template, it'll be ambiguous with
67 // the global printer for operator<<.
68 
69 std::ostream& operator<<(
70  std::ostream& out,
71  const at::ArrayRef<const Value*>& nodes) {
72  return printValueRefs(out, nodes);
73 }
74 
75 std::ostream& operator<<(std::ostream& out, const at::ArrayRef<Value*>& nodes) {
76  return printValueRefs(out, nodes);
77 }
78 
79 struct const_value_list_with_types {
80  const ArrayRef<const Value*> values;
81  std::string delim;
82  const_value_list_with_types(
83  ArrayRef<const Value*> values,
84  std::string delim_ = ", ")
85  : values(values), delim(std::move(delim_)) {}
86 };
87 
88 std::ostream& operator<<(std::ostream& out, const_value_list_with_types l) {
89  size_t i = 0;
90  for (auto n : l.values) {
91  if (i++ > 0) {
92  out << l.delim;
93  }
94  printValueRef(out, n);
95  out << " : ";
96  out << *n->type();
97  }
98  return out;
99 }
100 
101 template <typename T>
102 static void printPrimList(std::ostream& out, const std::vector<T>& items) {
103  out << "[";
104  int i = 0;
105  for (auto& item : items) {
106  if (i++ > 0) {
107  out << ", ";
108  }
109  out << item;
110  }
111  out << "]";
112 }
113 
114 static void printStrList(
115  std::ostream& out,
116  const std::vector<std::string>& items) {
117  out << "[";
118  int i = 0;
119  for (auto& item : items) {
120  if (i++ > 0)
121  out << ", ";
122  printQuotedString(out, item);
123  }
124  out << "]";
125 }
126 
127 void Node::printAttrValue(std::ostream& out, const Symbol& name) const {
128  switch (kindOf(name)) {
129  case AttributeKind::f:
130  out << f(name);
131  break;
132  case AttributeKind::fs:
133  printPrimList(out, fs(name));
134  break;
135  case AttributeKind::i:
136  out << i(name);
137  break;
138  case AttributeKind::is:
139  printPrimList(out, is(name));
140  break;
141  case AttributeKind::s:
142  printQuotedString(out, s(name));
143  break;
144  case AttributeKind::ss:
145  printStrList(out, ss(name));
146  break;
147  case AttributeKind::t: {
148  at::Tensor tensor = t(name);
149  // 1-element tensors are usually boxed scalars, so print them like scalars
150  if (tensor.numel() == 1) {
151  auto scalar_tensor = tensor.view({}).item();
152  out << "{";
153  if (scalar_tensor.isFloatingPoint()) {
154  out << scalar_tensor.toDouble();
155  } else {
156  out << scalar_tensor.toLong();
157  }
158  out << "}";
159  } else if (tensor.numel() <= max_tensor_display_size) {
160  // TODO: This is awful code. Also it doesn't work on Windows.
161  std::ostringstream tensor_ss;
162  tensor_ss << tensor;
163  std::string tensor_s{tensor_ss.str()};
164  // Remove newlines
165  std::replace(tensor_s.begin(), tensor_s.end(), '\n', ' ');
166  out << tensor_s;
167  } else {
168  out << "<Tensor>";
169  }
170  break;
171  }
172  case AttributeKind::ts:
173  out << "[<Tensors>]";
174  break;
175  case AttributeKind::g:
176  out << "<Graph>";
177  break;
178  case AttributeKind::gs:
179  out << "[<Graphs>]";
180  break;
181  }
182 }
183 
184 void Node::printAttributes(std::ostream& out, bool ignore_subgraph = false)
185  const {
186  out << "[";
187  auto names = attributeNames();
188  int i = 0;
189  for (auto name : names) {
190  if (ignore_subgraph && name == attr::Subgraph) {
191  continue;
192  }
193  if (i++ > 0) {
194  out << ", ";
195  }
196  // TODO: debugging mode to see the qualifier. We definitely
197  // don't want to print the qualifier since it should always
198  // be attribute, but you might be able to track down a weird
199  // bug by printing it out.
200  out << name.toUnqualString() << "=";
201 
202  printAttrValue(out, name);
203  }
204  out << "]";
205 }
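// Illustrative output of printAttributes (attribute names and values assumed):
// a node carrying i(attr::value) = 1 and s(attr::name) = "foo" would print as
// [value=1, name="foo"].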
206 
207 static std::ostream& indent(std::ostream& out, size_t level) {
208  for (size_t i = 0; i < level; ++i) {
209  out << " ";
210  }
211  return out;
212 }
213 
214 std::ostream& Node::print(
215  std::ostream& out,
216  size_t level,
217  std::vector<const Node*>* groups) const {
218  auto outs = outputs();
219  indent(out, level) << const_value_list_with_types(outs);
220  out << " = ";
221  if (kind() == prim::PythonOp) {
222  auto* pyOp = static_cast<const ::torch::jit::PythonOp*>(this);
223  out << "^" << pyOp->name();
224  pyOp->writeScalars(out);
225  } else {
226  if (hasAttribute(attr::Subgraph) && groups) {
227  out << kind().toQualString() << "_" << groups->size();
228  if (numAttributes() > 1 && kind() != prim::DifferentiableGraph) {
229  printAttributes(out, /*ignore_subgraph=*/true);
230  }
231  groups->push_back(this);
232  } else {
233  out << kind().toQualString();
234  if (hasAttributes()) {
235  printAttributes(out);
236  }
237  }
238  }
239 
240  out << "(" << inputs() << ")";
241  std::string scName = scopeName();
242  if (scName.empty()) {
243  out << "\n";
244  } else {
245  out << ", ";
246  out << "scope: " << scName << "\n";
247  }
248  for (size_t i = 0; i < blocks().size(); ++i) {
249  auto b = blocks()[i];
250  indent(out, level + 1) << "block" << i << "("
251  << const_value_list_with_types(b->inputs())
252  << "):\n";
253  for (auto nested : b->nodes()) {
254  nested->print(out, level + 2, groups);
255  }
256  indent(out, level + 2) << "-> (" << b->outputs() << ")\n";
257  }
258  return out;
259 }
260 
261 std::ostream& operator<<(std::ostream& out, const Node& n) {
262  return n.print(out, 0, nullptr);
263 }
264 
265 std::ostream& operator<<(std::ostream& out, const Graph& g) {
266  out << "graph(" << const_value_list_with_types(g.inputs(), ",\n ")
267  << "):\n";
268  std::vector<const Node*> groups;
269  for (auto n : g.nodes()) {
270  n->print(out, 1, &groups);
271  }
272  out << " return (" << g.outputs() << ")\n";
273  size_t i = 0;
274  for (auto fg : groups) {
275  out << "with " << fg->kind().toQualString() << "_" << i++ << " = "
276  << *fg->g(attr::Subgraph);
277  }
278  /*
279  // Uncomment this to debug all_nodes issues
280  {
281  out << "\n";
282  out << "all_nodes:\n";
283  for (auto& n : g.all_nodes) {
284  printNode(out, const_cast<Node*>(n), nullptr);
285  }
286  }
287  */
288  return out;
289 }
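// Illustrative shape of the printed graph (node and type names assumed),
// roughly:
//   graph(%x : Float(2, 2)):
//     %1 : Float(2, 2) = aten::relu(%x)
//     return (%1)
// Fusion/DifferentiableGraph subgraphs collected in `groups` are appended
// afterwards as "with <kind>_<i> = <subgraph>".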
290 
291 std::ostream& Graph::prettyPrint(std::ostream& out) {
292  std::vector<at::Tensor> tensor_table;
293  std::vector<ClassTypePtr> class_table;
294  PythonPrint(out, *this, tensor_table, class_table);
295  return out;
296 }
297 
298 void Graph::dumpPretty() {
299  std::vector<at::Tensor> tensor_table;
300  std::vector<ClassTypePtr> class_table;
301  PythonPrint(std::cout, *this, tensor_table, class_table);
302 }
303 
304 static void checkSameDevice(const Node* node) {
305  bool has_device = false;
306  c10::optional<at::Device> device = c10::nullopt;
307  auto checkValue = [&](const Value* v) {
308  if (CompleteTensorTypePtr type = v->type()->cast<CompleteTensorType>()) {
309  if (!has_device) {
310  has_device = true;
311  device = type->device();
312  } else {
313  AT_ASSERT(device == type->device());
314  }
315  }
316  };
317  for (auto input : node->inputs()) {
318  checkValue(input);
319  }
320  for (auto output : node->outputs()) {
321  checkValue(output);
322  }
323 }
324 
325 using node_set = std::set<const Node*>;
326 #define ALL_OF(container) container.begin(), container.end()
327 
328 // These functions purposely operate on the internal members directly, to force
329 // you to think about how the invariants change if you change the data
330 // representation (even if the external API does not change.)
331 
332 // NB: This assert is written to assume you don't have any unattached
333 // nodes. Unattached nodes can occur while manipulations to the
334 // graph are occurring.
335 void Node::lint() const {
336  // Node invariants
337  // - if node should live in list, nodes_iter is consistent
338  // - Inputs are all marked as a use by the nodes they refer to
339  // - Owning graph is non-null and consistent
340  // - The "Select" invariant, when the node is MultiReturn
341  //
342  // The handle invariant:
343  // If a node takes a handle as an input, it is always the
344  // LAST input of the node. There is at most one handle input.
345 
346  {
347  size_t i = 0;
348  for (auto input : inputs_) {
349  // WARNING: O(n^2)
350  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
351  AT_ASSERT(
352  std::find(ALL_OF(input->uses_), Use(const_cast<Node*>(this), i)) !=
353  input->uses_.end());
354  AT_ASSERT(graph_->all_nodes.count(this) == 1);
355  i++;
356  }
357  }
358 
359  for (auto o : outputs()) {
360  size_t i = 0;
361  for (auto use : o->uses()) {
362  // Use invariants
363  // - Use is consistent with inputs
364  // - Every user node is live (checked in Graph)
365  AT_ASSERT(use.user->inputs_[use.offset] == o);
366  i++;
367  }
368  }
369 
370  // Node subclass invariants
371  switch (kind()) {
372  case prim::Constant:
373  AT_ASSERT(inputs_.size() == 0);
374  break;
375  case prim::Return:
376  // Return uses is zero
377  AT_ASSERT(outputs().size() == 0);
378  break;
379  case prim::Param:
380  // Param inputs is zero
381  AT_ASSERT(inputs_.size() == 0);
382  break;
383  case prim::PythonOp: {
384  // Python operator cconv is correct
385  size_t n_scalars = 0, n_tensors = 0;
386  auto* value = static_cast<const PythonOp*>(this);
387  for (auto c : value->cconv) {
388  if (c == 'c') {
389  n_scalars++;
390  } else if (c == 'd') {
391  n_tensors++;
392  } else {
393  AT_ASSERT(0);
394  }
395  AT_ASSERT(static_cast<bool>(value->pyobj));
396  }
397  AT_ASSERT(n_scalars == value->scalar_args.size());
398  AT_ASSERT(n_tensors == inputs_.size());
399  break;
400  }
401  case prim::Eval:
402  // TODO: add invariants
403  // TODO: It's not good for these ops to be top-level, it makes cases
404  // longer.
405  break;
406  case prim::FusionGroup:
407  checkSameDevice(this);
408  // TODO: Typecheck the parameters
409  g(attr::Subgraph)->lint();
410  break;
411  }
412 }
413 
414 // TODO: When lint fails, give better indication about which
415 // instruction triggered the failure.
416 void Graph::lint() const {
417  // Graph invariants
418 
419  // Uncomment the following to see the graph
420  // std::cout << *const_cast<Graph*>(this);
421 
422  // nodes
423  // - nodes_ is a valid topological ordering for inputs
424  // - No repeated nodes
425  // - Params and return do NOT occur in nodes
426  // - next_unique_ is greater than all uniques in graph
427  // - uniques in all_nodes are unique
428  // - every use will occur later in the topsort
429 
430  struct LintScope {
431  LintScope() = default;
432  LintScope(std::unique_ptr<LintScope> parent) : parent(std::move(parent)) {}
433  bool contains(const Value* v) {
434  return values.count(v) > 0 || (parent && parent->contains(v));
435  }
436  bool contains(const Node* n) {
437  return nodes.count(n) > 0 || (parent && parent->contains(n));
438  }
439  void insert(const Value* v) {
440  AT_ASSERT(!contains(v));
441  values.insert(v);
442  }
443  void insert(const Node* n) {
444  AT_ASSERT(!contains(n));
445  nodes.insert(n);
446  }
447  std::unique_ptr<LintScope> parent;
448 
449  private:
450  std::unordered_set<const Value*> values;
451  std::unordered_set<const Node*> nodes;
452  };
453  // Struct enables mutual recursion in linting methods.
454  // Putting it inside Graph::lint enables access to private Graph members
455  struct LintImpl {
456  LintImpl(const Graph& g)
457  : g(g),
458  scope(new LintScope()),
459  all_nodes_set(ALL_OF(g.all_nodes)) {} // NB: all_nodes is *unordered*
460  const Graph& g;
461  std::unique_ptr<LintScope> scope;
462  std::unordered_set<size_t> seen_uniques;
463  std::unordered_map<const Node*, int64_t> anticipated_uses;
464  node_set all_nodes_set;
465  node_set sum_set;
466 
467  void check_value(const Value* v) {
468  scope->insert(v);
469  auto b2 = seen_uniques.insert(v->unique());
470  AT_ASSERT(b2.second); // insertion took place
471  AT_ASSERT(v->unique() < g.next_unique_);
472 
473  for (auto use : v->uses()) {
474  AT_ASSERT(!scope->contains(use.user));
475  AT_ASSERT(g.all_nodes.count(use.user) == 1);
476  anticipated_uses[use.user]++; // int default constructs to 0
477  }
478  }
479  void check_node(const Node* n) {
480  for (auto input : n->inputs_) {
481  if (!scope->contains(input)) {
482  AT_ASSERTM(0, input->unique(), " not in scope");
483  }
484  }
485  AT_ASSERT(anticipated_uses[n] == static_cast<int64_t>(n->inputs_.size()));
486  anticipated_uses[n] = -1; // we saw the anticipated user!
487  scope->insert(n);
488  for (auto block : n->blocks()) {
489  std::unique_ptr<LintScope> new_scope(new LintScope(std::move(scope)));
490  scope = std::move(new_scope);
491  check_block(block);
492  scope = std::move(scope->parent);
493  }
494  size_t i = 0;
495  for (auto o : n->outputs()) {
496  AT_ASSERT(o->node() == n);
497  AT_ASSERT(i++ == o->offset_);
498  check_value(o);
499  }
500  n->lint();
501  }
502  void check_block(const Block* b) {
503  // Check topological ordering
504  AT_ASSERT(b->param_node()->isBefore(*b->nodes().begin()));
505  auto curNode = *b->nodes().begin();
506  while (curNode != b->return_node()) {
507  AT_ASSERT(curNode->isBefore(curNode->next()));
508  curNode = curNode->next();
509  }
510 
511  for (auto input : b->inputs()) {
512  check_value(input);
513  AT_ASSERT(input->node()->kind_ == prim::Param);
514  }
515 
516  for (auto n : b->nodes()) {
517  AT_ASSERT(n->kind_ != prim::Param);
518  AT_ASSERT(n->kind_ != prim::Return);
519  check_node(n);
520  }
521 
522  AT_ASSERT(b->output_->kind() == prim::Return);
523  check_node(b->output_);
524 
525  // all_nodes
526  // - inputs_, output_ and nodes_ are all included in all_nodes
527  // - all_nodes does not contain dead nodes??? (likely to be temporarily
528  // suspended). Weaker: all_nodes contains all inputs and returns
529  // - only one return node???
530 
531  node_set nodes_set(ALL_OF(b->nodes()));
532  node_set inputs_set{b->input_};
533  node_set output_set{b->output_};
534  // TODO: Make a more type safe std::includes wrapper which disallows use
535  // on non-ordered containers
536  AT_ASSERT(std::includes(ALL_OF(all_nodes_set), ALL_OF(nodes_set)));
537  AT_ASSERT(std::includes(ALL_OF(all_nodes_set), ALL_OF(inputs_set)));
538  AT_ASSERT(std::includes(ALL_OF(all_nodes_set), ALL_OF(output_set)));
539 
540  sum_set.insert(ALL_OF(nodes_set));
541  sum_set.insert(ALL_OF(inputs_set));
542  sum_set.insert(ALL_OF(output_set));
543  }
544  void check_graph() {
545  node_set all_nodes_set(
546  ALL_OF(g.all_nodes)); // NB: all_nodes is *unordered*
547 
548  check_block(g.block_);
549  for (auto kv : anticipated_uses) {
550  AT_ASSERT(kv.second == -1);
551  }
552  AT_ASSERT(std::includes(ALL_OF(sum_set), ALL_OF(all_nodes_set)));
553  }
554  };
555  LintImpl(*this).check_graph();
556 }
557 
558 void Graph::dump() const {
559  std::cout << *this << "\n";
560 }
561 
562 void LintGraph(std::shared_ptr<Graph>& graph) {
563  graph->lint();
564 }
565 
566 Block::Block(Graph* graph_, Node* node_)
567  : graph_(graph_),
568  output_(initOutput(graph_->create(prim::Return, 0))),
569  input_(graph_->create(prim::Param, 0)),
570  owning_node_(node_) {
571  graph_->all_blocks.emplace(this);
572  output_->owning_block_ = this;
573  output_->topo_position_ = kUpperBound;
574  input_->owning_block_ = this;
575  input_->topo_position_ = kLowerBound;
576 }
577 
578 void Block::reIndexTopology() {
579  auto curPos = kLowerBound;
580  for (auto node : nodes()) {
581  AT_ASSERT(curPos <= (kUpperBound - kAppendInterval));
582  curPos += kAppendInterval;
583  node->topo_position_ = curPos;
584  }
585 }
586 
587 void Block::cloneFrom(Block* src, std::function<Value*(Value*)> value_map) {
588  std::unordered_map<Value*, Value*> local_map;
589  auto env = [&](Value* v) {
590  auto it = local_map.find(v);
591  if (it != local_map.end()) {
592  return it->second;
593  }
594  return value_map(v);
595  };
596 
597  auto graph = owningGraph();
598  for (auto input : src->inputs()) {
599  local_map[input] = this->addInput()->copyMetadata(input);
600  }
601 
602  for (auto node : src->nodes()) {
603  auto new_node = this->appendNode(graph->createClone(node, env));
604  for (size_t i = 0; i < node->outputs().size(); ++i) {
605  auto oo = node->outputs()[i];
606  auto no = new_node->outputs()[i];
607  local_map[oo] = no;
608  no->copyMetadata(oo);
609  }
610  }
611  for (auto output : src->outputs()) {
612  this->registerOutput(env(output));
613  }
614 }
615 
616 void Block::destroy() {
617  // we cannot destroy the output because it is used as the sentinel
618  // for the nodes() list and has to remain valid for the loop
619  output_->removeAllInputs();
620  for (auto it = this->nodes().reverse().begin(),
621  end = this->nodes().reverse().end();
622  it != end;
623  ++it) {
624  it.destroyCurrent();
625  }
626  output_->destroy();
627  input_->destroy();
628  graph_->freeBlock(this);
629 }
630 
631 std::shared_ptr<Graph> Graph::copy() {
632  auto new_g = std::make_shared<Graph>();
633  auto env = [](Value* v) -> Value* {
634  AT_ERROR(
635  "Graph::copy() encountered a use of a value not in scope. Run lint!");
636  };
637  new_g->block()->cloneFrom(this->block(), env);
638  return new_g;
639 }
640 
641 bool Value::mustBeNone() const {
642  return node_->mustBeNone();
643 }
644 
645 std::string Value::uniqueNameBase() const {
646  std::string name = uniqueName();
647  std::string name_base = name;
648  auto last_dot_pos = name.find_last_of('.');
649  if (last_dot_pos != std::string::npos && last_dot_pos + 1 != name.size()) {
650  if (name.find_first_not_of("0123456789", last_dot_pos + 1) ==
651  std::string::npos) {
652  name_base = name.substr(0, last_dot_pos);
653  }
654  }
655  return name_base;
656 }
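// Illustrative: uniqueNameBase() of "x.3" is "x"; names without a purely
// numeric ".<digits>" suffix, such as "x" or "x.y", are returned unchanged.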
657 
658 bool Value::isValidName(const std::string& name) {
659  // Empty strings are legal
660  if (!name.size()) {
661  return true;
662  }
663 
664  // Numbers are not legal
665  if (name.find_first_not_of("0123456789") == std::string::npos) {
666  return false;
667  }
668 
669  return true;
670 }
671 
672 Value* Value::setUniqueName(const std::string& name) {
673  if (!isValidName(name)) {
674  throw std::runtime_error("Invalid name: '" + name + "'");
675  }
676 
677  auto& names = node()->owningGraph()->unique_names_;
678 
679  // clear any old name from the map
680  if (hasUniqueName()) {
681  names.erase(unique_name_);
682  unique_name_ = "";
683  }
684 
685  // allow "" to clear the uniquename
686  if (name == "") {
687  return this;
688  }
689 
690  // if someone else has this name, then rename the other value
691  auto old_owner_of_name = names.find(name);
692  if (old_owner_of_name != names.end()) {
693  size_t suffix = 1;
694  std::string name_base = name;
695  auto last_dot_pos = name.find_last_of('.');
696  if (last_dot_pos != std::string::npos && last_dot_pos + 1 != name.size()) {
697  if (name.find_first_not_of("0123456789", last_dot_pos + 1) ==
698  std::string::npos) {
699  suffix = std::stoll(name.substr(last_dot_pos + 1));
700  name_base = name.substr(0, last_dot_pos);
701  }
702  }
703  std::string replacement_name;
704  do {
705  std::stringstream ss;
706  ss << name_base << "." << suffix++;
707  replacement_name = ss.str();
708  } while (names.count(replacement_name) > 0);
709  old_owner_of_name->second->setUniqueName(replacement_name);
710  }
711 
712  names[name] = this;
713  unique_name_ = name;
714  return this;
715 }
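// Illustrative collision handling: if some value already owns the name "x",
// calling setUniqueName("x") on another value renames the previous owner to
// the first free "x.<n>" (x.1, x.2, ...) and gives this value the plain
// name "x".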
716 
717 Value* Value::copyMetadata(Value* from) {
718  setType(from->type());
719  if (from->hasUniqueName()) {
720  setUniqueName(from->uniqueName());
721  }
722  return this;
723 }
724 
725 void Value::replaceFirstUseWith(Value* newValue) {
726  AT_ASSERT(owningGraph() == newValue->owningGraph());
727  auto u = uses()[0];
728  u.user->inputs_[u.offset] = newValue;
729  newValue->uses_.push_back(u);
730  uses_.erase(uses_.begin());
731 }
732 
733 void Value::replaceAllUsesWith(Value* newValue) {
734  while (!uses().empty()) {
735  replaceFirstUseWith(newValue);
736  }
737 }
738 
739 size_t findArgument(const FunctionSchema& the_schema, Symbol name) {
740  auto name_str = name.toUnqualString();
741  for (size_t i = 0; i < the_schema.arguments().size(); ++i) {
742  const Argument* arg = &the_schema.arguments()[i];
743  if (arg->name() == name_str) {
744  return i;
745  }
746  }
747  throw std::runtime_error(
748  std::string("Couldn't find an argument called ") + name.toQualString());
749 }
750 
751 c10::optional<IValue> Node::get(Symbol name) const {
752  return toIValue(namedInput(name));
753 }
754 
755 Value* Node::namedInput(Symbol name) const {
756  return input(findArgument(schema(), name));
757 }
758 
759 bool Node::matches(
760  const char* signature_literal,
761  at::ArrayRef<Symbol> const_inputs) const {
762  if (!sig(signature_literal).matches(this)) {
763  return false;
764  }
765  for (Symbol s : const_inputs) {
766  if (!is_constant(s)) {
767  return false;
768  }
769  }
770  return true;
771 }
772 
773 bool Node::mustBeNone() const {
774  return kind_ == prim::Constant && !this->hasAttributes() &&
775  (output()->type()->cast<OptionalType>() ||
776  output()->type() == NoneType::get());
777 }
778 
779 void Node::dump() const {
780  std::cout << *this << "\n";
781 }
782 
783 void Node::findSchema() const {
784  schema_ = &getOperatorFor(this).schema();
785 }
786 
787 const FunctionSchema* Node::maybeSchema() const {
788  if (!schema_) {
789  if (auto op = findOperatorFor(this)) {
790  schema_ = &op->schema();
791  }
792  }
793  return schema_;
794 }
795 
796 bool Node::isNondeterministic() const {
797  static const OperatorSet nondeterministic_ops = {
798  "aten::dropout(Tensor input, float p, bool train) -> Tensor",
799  "aten::_fused_dropout(Tensor self, float p, Generator? generator) -> (Tensor, Tensor)",
800  "aten::_standard_gamma(Tensor self, Generator? generator) -> Tensor",
801  "aten::bernoulli(Tensor self, *, Generator? generator) -> Tensor",
802  "aten::bernoulli(Tensor self, float p, *, Generator? generator) -> Tensor",
803  "aten::multinomial(Tensor self, int num_samples, bool replacement, *, Generator? generator) -> Tensor",
804  "aten::normal(Tensor mean, Tensor std, *, Generator? generator) -> Tensor",
805  "aten::normal(float mean, Tensor std, *, Generator? generator) -> Tensor",
806  "aten::normal(Tensor mean, float std, *, Generator? generator) -> Tensor",
807  "aten::poisson(Tensor self, Generator? generator) -> Tensor",
808  "aten::rrelu(Tensor self, Scalar lower, Scalar upper, bool training, Generator? generator) -> Tensor",
809  "aten::rrelu_with_noise(Tensor self, Tensor noise, Scalar lower, Scalar upper, bool training, Generator? generator) -> Tensor",
810  "aten::rand(int[] size, *, int? dtype, int? layout, Device? device) -> Tensor",
811  "aten::rand_like(Tensor self) -> Tensor",
812  "aten::rand_like(Tensor self, *, int dtype, int layout, Device device) -> Tensor",
813  "aten::randint(int high, int[] size, *, int? dtype, int? layout, Device? device) -> Tensor",
814  "aten::randint(int low, int high, int[] size, *, int? dtype, int? layout, Device? device) -> Tensor",
815  "aten::randint_like(Tensor self, int high) -> Tensor",
816  "aten::randint_like(Tensor self, int low, int high) -> Tensor",
817  "aten::randint_like(Tensor self, int high, *, int dtype, int layout, Device device) -> Tensor",
818  "aten::randint_like(Tensor self, int low, int high, *, int dtype, int layout, Device device) -> Tensor",
819  "aten::randn(int[] size, *, int? dtype, int? layout, Device? device) -> Tensor",
820  "aten::randn_like(Tensor self) -> Tensor",
821  "aten::randn_like(Tensor self, *, int dtype, int layout, Device device) -> Tensor",
822  "aten::randperm(int n, *, int? dtype, int? layout, Device? device) -> Tensor"};
823 
824  if (nondeterministic_ops.find(this) == nullptr) {
825  return false;
826  }
827  // Dropout with train = False is deterministic
828  if (matches("aten::dropout(Tensor input, float p, bool train) -> Tensor") &&
829  is_constant(attr::train) && !get<bool>(attr::train).value()) {
830  return false;
831  }
832  return true;
833 }
834 
835 bool Node::hasSideEffects() const {
836  switch (kind_) {
837  case prim::PythonOp:
838  case prim::IgnoredPythonOp:
839  case prim::Print:
840  case prim::RaiseException:
841  case prim::SetAttr:
842  case aten::warn:
843  return true;
844  }
845  return false;
846 }
847 
848 // Assign this node a topological position, to facilitate fast isBefore() and
849 // isAfter() queries. Must be called right after a node is inserted into the
850 // node list.
851 //
852 // The basic scheme is: assign every node a position (uint64_t). The common
853 // case (appending to the end of the graph) is made more efficient by advancing
854 // a fixed interval past the previous node and placing `this` there. Otherwise,
855 // assign `this` a position at the midpoint between its prev() and next()
856 // nodes.
857 //
858 // If we ever run out of space (by, e.g. inserting too much in place), we
859 // reindex by spreading out all the nodes again.
860 void Node::assignTopoPosition() {
861  auto returnNode = owningBlock()->return_node();
862  const auto prevPos = prev()->topo_position_;
863  const auto nextPos = next()->topo_position_;
864 
865  // Append to the end of the graph
866  if (next() == returnNode) {
867  if (next() == prev()) {
868  // the node list is empty, assign the first position
869  topo_position_ = kMidPoint;
870  return;
871  }
872 
873  if (prevPos >= (kUpperBound - kAppendInterval)) {
874  // we're running off the edge
875  owningBlock()->reIndexTopology();
876  return;
877  }
878 
879  topo_position_ = prevPos + kAppendInterval;
880 
881  // Prepend to the graph
882  } else if (prev() == returnNode) {
883  // next() is the first element in the block list
884  if (nextPos <= (kLowerBound + kAppendInterval)) {
885  // we're running off the edge
886  owningBlock()->reIndexTopology();
887  return;
888  }
889 
890  topo_position_ = nextPos - kAppendInterval;
891 
892  // insert between two existing nodes
893  } else {
894  const auto posBetween = prevPos + (nextPos - prevPos) / 2;
895  if (posBetween == prevPos) {
896  // There was no room
897  owningBlock()->reIndexTopology();
898  return;
899  }
900  topo_position_ = posBetween;
901  }
902 }
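// Illustrative positions (values assumed): appending after a node at position
// p places the new node at p + kAppendInterval; inserting between neighbours
// at positions p and q places it at p + (q - p) / 2, and when that midpoint
// collapses onto p the whole block is re-spaced via reIndexTopology().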
903 
904 Node::Node(Graph* graph_, NodeKind kind_)
905  : kind_(kind_),
906  graph_(graph_),
907  owning_block_(nullptr),
908  scope_(graph_->current_scope_),
909  schema_(nullptr),
910  topo_position_(0) {
911  graph_->all_nodes.emplace(this);
912 }
913 
914 void Node::eraseOutput(size_t i) {
915  AT_ASSERT(i < outputs_.size());
916  AT_ASSERT(outputs_[i]->uses().empty());
917  schema_ = nullptr;
918  Value* n = outputs_[i];
919  outputs_.erase(outputs_.begin() + i);
920  owningGraph()->freeValue(n);
921  for (size_t j = i; j < outputs_.size(); j++) {
922  outputs_[j]->offset_--;
923  }
924 }
925 
926 Block* Node::addBlock() {
927  schema_ = nullptr;
928  blocks_.push_back(new Block(owningGraph(), this));
929  return blocks_.back();
930 }
931 
932 void Node::eraseBlock(size_t i) {
933  AT_ASSERT(i < blocks_.size());
934  schema_ = nullptr;
935  Block* n = blocks_[i];
936  blocks_.erase(blocks_.begin() + i);
937  n->destroy();
938 }
939 
940 void Node::destroy() {
941  while (!outputs().empty()) {
942  eraseOutput(outputs().size() - 1);
943  }
944  while (!blocks().empty()) {
945  eraseBlock(blocks().size() - 1);
946  }
947  removeAllInputs();
948  if (inBlockList()) {
949  removeFromList();
950  }
951  graph_->freeNode(this);
952 }
953 
954 void Node::cloneFrom(Node* s) {
955  setSourceLocation(s->getSourceLocation());
956  if (s->scope_ && !s->scope_->isBlank()) {
957  scope_ = s->scope_;
958  }
959  copyAttributes(*s);
960 }
961 
962 void Node::replaceAllUsesWith(Node* n) {
963  AT_ASSERT(outputs().size() == n->outputs().size());
964  size_t nOutputs = outputs().size();
965  for (size_t i = 0; i < nOutputs; i++) {
966  outputs()[i]->replaceAllUsesWith(n->outputs()[i]);
967  }
968 }
969 
970 Value* Node::insertInput(size_t i, Value* value) {
971  AT_ASSERT(graph_ == value->owningGraph());
972  schema_ = nullptr;
973  // First we update the offsets for all existing inputs that will reside
974  // after the one we're inserting. Concretely, these are the inputs at
975  // indices [i, # input). Since we're inserting one input before all of
976  // these inputs, increment their use offsets for this value by 1
977  for (size_t use_itr = i; use_itr < inputs_.size(); ++use_itr) {
978  // See Note [User node does not uniquely identify use]
979  auto use = findUseForInput(use_itr);
980  use->offset += 1;
981  }
982  // Insert the actual input at the specified index
983  inputs_.insert(inputs_.begin() + i, value);
984  // Register the new use of the value we're inserted as an input.
985  value->uses_.emplace_back(this, i);
986  return value;
987 }
988 
989 Value* Node::addInput(Value* value) {
990  AT_ASSERT(graph_ == value->owningGraph());
991  schema_ = nullptr;
992  value->uses_.emplace_back(this, inputs_.size());
993  inputs_.push_back(value);
994  return value;
995 }
996 
997 Value* Node::replaceInput(size_t i, Value* newValue) {
998  AT_ASSERT(newValue->owningGraph() == graph_);
999  schema_ = nullptr;
1000  Value* old = dropInput(i);
1001  inputs_[i] = newValue;
1002  newValue->uses_.emplace_back(this, i);
1003  return old;
1004 }
1005 
1006 void Node::replaceInputWith(Value* from, Value* to) {
1007  AT_ASSERT(from->owningGraph() == graph_);
1008  AT_ASSERT(to->owningGraph() == graph_);
1009  schema_ = nullptr;
1010  size_t i = 0;
1011  for (auto input : inputs()) {
1012  if (input == from) {
1013  replaceInput(i, to);
1014  }
1015  i++;
1016  }
1017 }
1018 
1019 Value* Node::addOutput() {
1020  outputs_.push_back(new Value(this, outputs_.size()));
1021  schema_ = nullptr;
1022  return outputs_.back();
1023 }
1024 
1025 Value* Node::insertOutput(size_t i) {
1026  schema_ = nullptr;
1027  outputs_.insert(outputs_.begin() + i, new Value(this, i));
1028  for (size_t itr = i + 1; itr < outputs_.size(); ++itr) {
1029  outputs_[itr]->setOffset(outputs_[itr]->offset() + 1);
1030  }
1031  return outputs_.at(i);
1032 }
1033 
1034 bool Node::isBeforeOrAfter(const Node* n, MoveSide moveSide) const {
1035  if (this->owningBlock() == n->owningBlock()) {
1036  if (moveSide == MoveSide::BEFORE) {
1037  return this->topo_position_ < n->topo_position_;
1038  }
1039 
1040  if (moveSide == MoveSide::AFTER) {
1041  return this->topo_position_ > n->topo_position_;
1042  }
1043 
1044  AT_ASSERT(this == n);
1045  return false;
1046  }
1047 
1048  // These nodes don't share a common block. Traverse the blockchains upward
1049  // until we find the first common block.
1050  auto lhs = this;
1051  while (lhs) {
1052  AT_ASSERT(lhs->owningBlock());
1053 
1054  auto rhs = n;
1055  while (rhs) {
1056  if (!rhs->owningBlock()) {
1057  break;
1058  }
1059 
1060  if (lhs->owningBlock() == rhs->owningBlock()) {
1061  return lhs->isBeforeOrAfter(rhs, moveSide);
1062  }
1063  rhs = rhs->owningBlock()->owningNode();
1064  }
1065 
1066  lhs = lhs->owningBlock()->owningNode();
1067  }
1068  // should never reach here, since both nodes are ultimately in the same graph
1069  AT_ASSERT(false);
1070 }
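// Illustrative: for nodes in different (possibly nested) blocks, the
// comparison walks each node up through owningBlock()->owningNode() until the
// first common block is found and compares the enclosing nodes' positions
// there.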
1071 
1072 bool Node::isBefore(const Node* n) const {
1073  return isBeforeOrAfter(n, MoveSide::BEFORE);
1074 }
1075 
1076 bool Node::isAfter(const Node* n) const {
1077  return isBeforeOrAfter(n, MoveSide::AFTER);
1078 }
1079 
1080 Node* Node::insertBefore(Node* n) {
1081  AT_ASSERT(n->inBlockList());
1082  insertAfter(n->prev());
1083  return this;
1084 }
1085 
1086 Node* Node::insertAfter(Node* n) {
1087  AT_ASSERT(!inBlockList() && n->inBlockList());
1088  AT_ASSERT(n->owningBlock());
1089  this->owning_block_ = n->owningBlock();
1090  Node* next = n->next();
1091  n->next() = this;
1092  this->prev() = n;
1093  this->next() = next;
1094  next->prev() = this;
1095  assignTopoPosition();
1096  return this;
1097 }
1098 
1099 void Node::moveAfter(Node* n) {
1100  removeFromList();
1101  insertAfter(n);
1102 }
1103 
1104 void Node::moveBefore(Node* n) {
1105  removeFromList();
1106  insertBefore(n);
1107 }
1108 
1109 void Node::removeInput(size_t i) {
1110  schema_ = nullptr;
1111  dropInput(i);
1112  // everything after this input shifts left,
1113  // so we need to update their use offsets to match
1114  for (size_t j = i + 1; j < inputs_.size(); j++) {
1115  auto it = findUseForInput(j);
1116  it->offset--;
1117  }
1118  inputs_.erase(inputs_.begin() + i);
1119 }
1120 
1121 void Node::removeAllInputs() {
1122  schema_ = nullptr;
1123  for (size_t i = 0; i < inputs().size(); ++i) {
1124  dropInput(i);
1125  }
1126  inputs_.clear();
1127 }
1128 
1129 use_list::iterator Node::findUseForInput(size_t i) {
1130  auto& input_uses = inputs_[i]->uses_;
1131  // O(N) on the use list, but unless we get nodes with +100 uses
1132  // vector traversal still is probably faster than linked list
1133  auto use_it = std::find(input_uses.begin(), input_uses.end(), Use(this, i));
1134  AT_ASSERT(use_it != input_uses.end());
1135  return use_it;
1136 }
1137 
1138 Value* Node::dropInput(size_t i) {
1139  AT_ASSERT(i < inputs_.size());
1140  auto input_node = inputs_[i];
1141  auto use_it = findUseForInput(i);
1142  input_node->uses_.erase(use_it);
1143  inputs_[i] = nullptr;
1144  return input_node;
1145 }
1146 
1147 void Node::removeFromList() {
1148  AT_ASSERT(inBlockList());
1149  this->owning_block_ = nullptr;
1150  Node* next = this->next();
1151  Node* prev = this->prev();
1152  prev->next() = next;
1153  next->prev() = prev;
1154  this->next() = nullptr;
1155  this->prev() = nullptr;
1156 }
1157 
1158 inline const SourceRange& fakeRange() {
1159  static SourceRange range(
1160  std::make_shared<std::string>("<internally-created-node>"), 0, 1);
1161  return range;
1162 }
1163 
1164 Value* Graph::insert(
1165  Symbol opname,
1166  at::ArrayRef<NamedValue> args,
1167  at::ArrayRef<NamedValue> kwargs,
1168  const c10::optional<SourceRange>& range) {
1169  return script::emitBuiltinCall(
1170  range.value_or(fakeRange()),
1171  *this,
1172  opname,
1173  c10::nullopt,
1174  args,
1175  kwargs,
1176  /*required=*/true);
1177 }
1178 
1179 Node* Graph::create(NodeKind kind, size_t num_outputs) {
1180  // NB: Node constructor adds node to all_nodes
1181  auto n = new Node(this, kind);
1182  for (size_t i = 0; i < num_outputs; i++) {
1183  n->addOutput();
1184  }
1185  return n;
1186 }
1187 
1188 Node* Graph::create(
1189  NodeKind kind,
1190  ArrayRef<Value*> inputs,
1191  size_t num_outputs) {
1192  auto n = create(kind, num_outputs);
1193  for (auto i : inputs) {
1194  n->addInput(i);
1195  }
1196  return n;
1197 }
1198 
1199 Node* Graph::createAutogradZero() {
1200  return create(prim::AutogradZero);
1201 }
1202 
1203 Node* Graph::createNone(TypePtr typ) {
1204  Node* n = create(prim::Constant);
1205  n->output()->setType(OptionalType::create(std::move(typ)));
1206  return n;
1207 }
1208 
1209 Node* Graph::createFusionGroup() {
1210  auto n = create(prim::FusionGroup, 0);
1211  n->g_(attr::Subgraph, std::make_shared<Graph>(current_scope()));
1212  return n;
1213 }
1214 
1215 Node* Graph::createTuple(
1216  at::ArrayRef<Value*> values,
1217  c10::OptNameList field_names) {
1218  auto types = fmap(values, [](Value* v) { return v->type(); });
1219  auto tt = TupleType::create(std::move(types), std::move(field_names));
1220  auto n = create(prim::TupleConstruct, values);
1221  n->output()->setType(tt);
1222  return n;
1223 }
1224 
1225 Node* Graph::createTupleUnpack(Value* v) {
1226  TupleTypePtr tt = v->type()->expect<TupleType>();
1227  auto n = create(prim::TupleUnpack, {v}, 0);
1228  for (auto& element : tt->elements()) {
1229  n->addOutput()->setType(element);
1230  }
1231  return n;
1232 }
1233 
1234 Node* Graph::createTupleIndex(Value* tup, int64_t index) {
1235  auto n = create(prim::TupleIndex, {tup});
1236  n->i_(attr::index, index);
1237  auto tuple_type = tup->type()->expect<TupleType>();
1238  n->output()->setType(tuple_type->elements().at(index));
1239  return n;
1240 }
1241 
1242 Node* Graph::createTupleSlice(Value* tup, int64_t beg, int64_t end) {
1243  auto n = create(prim::TupleSlice, {tup});
1244  auto tuple_type = tup->type()->expect<TupleType>();
1245  n->i_(attr::beg, beg);
1246  n->i_(attr::end, end);
1247  std::vector<TypePtr> output_types;
1248  for (auto i = beg; i < end; ++i) {
1249  output_types.push_back(tuple_type->elements().at(i));
1250  }
1251  auto tt = TupleType::create(std::move(output_types));
1252  n->output()->setType(tt);
1253  return n;
1254 }
1255 
1256 Node* Graph::createList(const TypePtr& elem_type, at::ArrayRef<Value*> values) {
1257  auto n = create(prim::ListConstruct, values);
1258  for (const auto& v : values) {
1259  AT_ASSERT(v->type()->isSubtypeOf(elem_type));
1260  }
1261  n->output()->setType(ListType::create(elem_type));
1262  return n;
1263 }
1264 Node* Graph::createListUnpack(Value* v, size_t size) {
1265  ListTypePtr list_type = v->type()->expect<ListType>();
1266  TypePtr elem_type = list_type->getElementType();
1267  auto n = create(prim::ListUnpack, {v}, 0);
1268  for (size_t i = 0; i < size; ++i) {
1269  n->addOutput()->setType(elem_type);
1270  }
1271  return n;
1272 }
1273 
1274 Node* Graph::createDict(
1275  const TypePtr& key_type,
1276  const TypePtr& value_type,
1277  at::ArrayRef<Value*> keys,
1278  at::ArrayRef<Value*> values) {
1279  AT_ASSERT(keys.size() == values.size());
1280  auto n = create(prim::DictConstruct, 1);
1281  for (size_t i = 0; i < keys.size(); ++i) {
1282  AT_ASSERT(keys[i]->type()->isSubtypeOf(key_type));
1283  AT_ASSERT(values[i]->type()->isSubtypeOf(value_type));
1284 
1285  n->addInput(keys[i]);
1286  n->addInput(values[i]);
1287  }
1288  n->output()->setType(DictType::create(key_type, value_type));
1289  return n;
1290 }
1291 
1292 Node* Graph::createDictIndex(Value* dict, Value* index) {
1293  auto dict_type = dict->type()->expect<DictType>();
1294  AT_ASSERT(index->type()->isSubtypeOf(dict_type->getKeyType()));
1295 
1296  auto n = create(prim::DictIndex, {dict, index});
1297  n->output()->setType(dict_type->getValueType());
1298  return n;
1299 }
1300 
1301 Node* Graph::createNumToTensor(Value* value) {
1302  auto typ = value->type();
1303  Node* result = create(prim::NumToTensor, {value});
1304  result->output()->setType(CompleteTensorType::fromNumberType(std::move(typ)));
1305  return result;
1306 }
1307 
1308 Node* Graph::createImplicitTensorToNum(const TypePtr& type, Value* value) {
1309  auto* result = create(prim::ImplicitTensorToNum, {value});
1310  result->output()->setType(type);
1311  return result;
1312 }
1313 
1314 Node* Graph::createObject(const ClassTypePtr& type) {
1315  auto result = create(prim::CreateObject);
1316  result->output()->setType(type);
1317  return result;
1318 }
1319 
1320 Node* Graph::createSetAttr(
1321  Value* obj,
1322  const std::string& field,
1323  Value* newValue) {
1324  auto n = create(prim::SetAttr, {obj, newValue}, /*num_outputs=*/0);
1325  n->s_(attr::name, field);
1326  return n;
1327 }
1328 
1329 Node* Graph::createGetAttr(Value* obj, const std::string& field) {
1330  const auto classType = obj->type()->expect<ClassType>();
1331 
1332  auto n = create(prim::GetAttr, {obj}, /*num_outputs=*/1);
1333  n->s_(attr::name, field);
1334 
1335  const auto outputType = classType->getAttribute(field);
1336  n->output()->setType(outputType);
1337  return n;
1338 }
1339 
1340 Node* Graph::createClone(
1341  Node* n,
1342  const std::function<Value*(Value*)>& value_map,
1343  bool copy_blocks) {
1344  // n can be from a different graph
1345  Node* r = n->allocNewInstance(this);
1346  for (auto o : n->outputs()) {
1347  r->addOutput()->copyMetadata(o);
1348  }
1349  r->cloneFrom(n);
1350  for (auto i : n->inputs()) {
1351  r->addInput(value_map(i));
1352  }
1353  if (copy_blocks) {
1354  for (auto b : n->blocks()) {
1355  r->addBlock()->cloneFrom(b, value_map);
1356  }
1357  }
1358  return r;
1359 }
1360 
1361 Value* Graph::insertConstant(
1362  IValue val,
1363  const TypePtr& result_type,
1364  c10::optional<SourceRange> loc,
1365  c10::optional<ScopePtr> scope) {
1366  return jit::insertConstant(
1367  *this, std::move(val), result_type, std::move(loc), std::move(scope));
1368 }
1369 
1370 std::string Graph::toString() const {
1371  std::ostringstream oss;
1372  oss << *this;
1373  return oss.str();
1374 }
1375 
1376 Graph::~Graph() {
1377  for (const Node* n : all_nodes) {
1378  delete n;
1379  }
1380  for (const Value* v : all_values) {
1381  delete v;
1382  }
1383  for (const Block* b : all_blocks) {
1384  delete b;
1385  }
1386 }
1387 
1388 void Graph::freeNode(Node* n) {
1389  auto it = all_nodes.find(n);
1390  AT_ASSERT(it != all_nodes.end());
1391  delete *it;
1392  all_nodes.erase(it);
1393 }
1394 void Graph::freeValue(Value* v) {
1395  v->setUniqueName("");
1396  auto it = all_values.find(v);
1397  AT_ASSERT(it != all_values.end());
1398  delete *it;
1399  all_values.erase(it);
1400 }
1401 void Graph::freeBlock(Block* b) {
1402  auto it = all_blocks.find(b);
1403  AT_ASSERT(it != all_blocks.end());
1404  delete *it;
1405  all_blocks.erase(it);
1406 }
1407 
1408 at::ArrayRef<Value*> createTupleUnpack(Value* v) {
1409  // small peephole optimization to ensure IntArrayRef attributes can still turn
1410  // into constants e.g. in x.expand([3, 4])
1411  if (v->node()->kind() == prim::TupleConstruct) {
1412  return v->node()->inputs();
1413  }
1414  auto& g = *v->owningGraph();
1415  return g.insertNode(g.createTupleUnpack(v))->outputs();
1416 }
1417 
1418 std::vector<Value*> inlineCallTo(
1419  Graph& g,
1420  Graph& callee,
1421  ArrayRef<Value*> inputs,
1422  bool unpack_outputs) {
1423  std::unordered_map<Value*, Value*> value_map;
1424  auto value_map_func = [&](Value* v) { return value_map.at(v); };
1425  AT_ASSERT(callee.inputs().size() == inputs.size());
1426  for (size_t i = 0; i < inputs.size(); ++i) {
1427  value_map[callee.inputs()[i]] = inputs[i];
1428  }
1429  for (auto* node : callee.nodes()) {
1430  auto* new_node = g.insertNode(g.createClone(node, value_map_func));
1431  for (size_t i = 0; i < node->outputs().size(); ++i) {
1432  value_map[node->outputs()[i]] = new_node->outputs()[i];
1433  }
1434  }
1435 
1436  std::vector<Value*> outputs;
1437  for (auto* output : callee.outputs()) {
1438  outputs.push_back(value_map_func(output));
1439  }
1440 
1441  if (unpack_outputs && outputs.size() == 1 &&
1442  callee.outputs().at(0)->type()->kind() == TupleType::Kind) {
1443  auto tup = outputs[0];
1444  outputs.clear();
1445  for (Value* v : createTupleUnpack(tup)) {
1446  outputs.emplace_back(v);
1447  }
1448  // if this was a peephole tuple unpack we can just get rid of
1449  // the tuple construct here and prevent needing DCE
1450  if (tup->node()->kind() == prim::TupleConstruct &&
1451  !tup->node()->hasUses()) {
1452  tup->node()->destroy();
1453  }
1454  }
1455 
1456  return outputs;
1457 }
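// Usage note (illustrative, caller-side): inlineCallTo clones every node of
// `callee` into `g` at the current insert point, mapping callee.inputs()[i]
// to inputs[i], and returns the values corresponding to callee.outputs();
// with unpack_outputs=true a single tuple output is expanded through
// createTupleUnpack so callers receive the individual elements.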
1458 
1459 PythonOp* defaultAllocPythonOp(Graph* g) {
1460  throw std::runtime_error(
1461  "Trying to allocate a Python object without python bindings loaded");
1462 }
1463 std::atomic<decltype(&defaultAllocPythonOp)> alloc_python_op;
1464 
1465 // patched in when python bindings are loaded
1466 PythonOp* allocPythonOp(Graph* g) {
1467  return alloc_python_op.load()(g);
1468 }
1469 void setAllocPythonOp(PythonOp* (*v)(Graph* g)) {
1470  alloc_python_op.store(v);
1471 }
1472 
1473 } // namespace jit
1474 } // namespace torch